diff --git a/R/install-dev.bat b/R/install-dev.bat
index 008a5c668bc45..f32670b67de96 100644
--- a/R/install-dev.bat
+++ b/R/install-dev.bat
@@ -25,3 +25,8 @@ set SPARK_HOME=%~dp0..
MKDIR %SPARK_HOME%\R\lib
R.exe CMD INSTALL --library="%SPARK_HOME%\R\lib" %SPARK_HOME%\R\pkg\
+
+rem Zip the SparkR package so that it can be distributed to worker nodes on YARN
+pushd %SPARK_HOME%\R\lib
+%JAVA_HOME%\bin\jar.exe cfM "%SPARK_HOME%\R\lib\sparkr.zip" SparkR
+popd
diff --git a/R/install-dev.sh b/R/install-dev.sh
index 1edd551f8d243..4972bb9217072 100755
--- a/R/install-dev.sh
+++ b/R/install-dev.sh
@@ -34,7 +34,7 @@ LIB_DIR="$FWDIR/lib"
mkdir -p $LIB_DIR
-pushd $FWDIR
+pushd $FWDIR > /dev/null
# Generate Rd files if devtools is installed
Rscript -e ' if("devtools" %in% rownames(installed.packages())) { library(devtools); devtools::document(pkg="./pkg", roclets=c("rd")) }'
@@ -42,4 +42,8 @@ Rscript -e ' if("devtools" %in% rownames(installed.packages())) { library(devtoo
# Install SparkR to $LIB_DIR
R CMD INSTALL --library=$LIB_DIR $FWDIR/pkg/
-popd
+# Zip the SparkR package so that it can be distributed to worker nodes on YARN
+cd $LIB_DIR
+jar cfM "$LIB_DIR/sparkr.zip" SparkR
+
+popd > /dev/null
diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION
index efc85bbc4b316..4949d86d20c91 100644
--- a/R/pkg/DESCRIPTION
+++ b/R/pkg/DESCRIPTION
@@ -29,7 +29,7 @@ Collate:
'client.R'
'context.R'
'deserialize.R'
+ 'mllib.R'
'serialize.R'
'sparkR.R'
'utils.R'
- 'zzz.R'
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 7f857222452d4..7f7a8a2e4de24 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -10,6 +10,10 @@ export("sparkR.init")
export("sparkR.stop")
export("print.jobj")
+# MLlib integration
+exportMethods("glm",
+ "predict")
+
# Job group lifecycle management methods
export("setJobGroup",
"clearJobGroup",
@@ -22,6 +26,7 @@ exportMethods("arrange",
"collect",
"columns",
"count",
+ "crosstab",
"describe",
"distinct",
"dropna",
@@ -77,6 +82,7 @@ exportMethods("abs",
"atan",
"atan2",
"avg",
+ "between",
"cast",
"cbrt",
"ceiling",
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index 60702824acb46..06dd6b75dff3d 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -1314,7 +1314,7 @@ setMethod("except",
#' write.df(df, "myfile", "parquet", "overwrite")
#' }
setMethod("write.df",
- signature(df = "DataFrame", path = 'character'),
+ signature(df = "DataFrame", path = "character"),
function(df, path, source = NULL, mode = "append", ...){
if (is.null(source)) {
sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv)
@@ -1328,7 +1328,7 @@ setMethod("write.df",
jmode <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "saveMode", mode)
options <- varargsToEnv(...)
if (!is.null(path)) {
- options[['path']] = path
+ options[["path"]] <- path
}
callJMethod(df@sdf, "save", source, jmode, options)
})
@@ -1337,7 +1337,7 @@ setMethod("write.df",
#' @aliases saveDF
#' @export
setMethod("saveDF",
- signature(df = "DataFrame", path = 'character'),
+ signature(df = "DataFrame", path = "character"),
function(df, path, source = NULL, mode = "append", ...){
write.df(df, path, source, mode, ...)
})
@@ -1375,8 +1375,8 @@ setMethod("saveDF",
#' saveAsTable(df, "myfile")
#' }
setMethod("saveAsTable",
- signature(df = "DataFrame", tableName = 'character', source = 'character',
- mode = 'character'),
+ signature(df = "DataFrame", tableName = "character", source = "character",
+ mode = "character"),
function(df, tableName, source = NULL, mode="append", ...){
if (is.null(source)) {
sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv)
@@ -1554,3 +1554,31 @@ setMethod("fillna",
}
dataFrame(sdf)
})
+
+#' crosstab
+#'
+#' Computes a pair-wise frequency table of the given columns. Also known as a contingency
+#' table. The number of distinct values for each column should be less than 1e4. At most 1e6
+#' non-zero pair frequencies will be returned.
+#'
+#' @param col1 name of the first column. Distinct items will make the first item of each row.
+#' @param col2 name of the second column. Distinct items will make the column names of the output.
+#' @return a local R data.frame representing the contingency table. The first column of each row
+#' will be the distinct values of `col1` and the column names will be the distinct values
+#' of `col2`. The name of the first column will be `$col1_$col2`. Pairs that have no
+#' occurrences will have `null` as their counts.
+#'
+#' @rdname statfunctions
+#' @export
+#' @examples
+#' \dontrun{
+#' df <- jsonFile(sqlCtx, "/path/to/file.json")
+#' ct <- crosstab(df, "title", "gender")
+#' }
+setMethod("crosstab",
+ signature(x = "DataFrame", col1 = "character", col2 = "character"),
+ function(x, col1, col2) {
+ statFunctions <- callJMethod(x@sdf, "stat")
+ sct <- callJMethod(statFunctions, "crosstab", col1, col2)
+ collect(dataFrame(sct))
+ })
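
The new crosstab above is a thin wrapper: x@sdf is the JVM DataFrame, "stat" returns its DataFrameStatFunctions, and "crosstab" does the counting before the result is collected locally. As an illustration only (data and app name are made up, not taken from this patch), the equivalent direct Scala call is:

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

object CrosstabSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("crosstab-sketch").setMaster("local[2]"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    // Two categorical columns; crosstab counts how often each (a, b) pair occurs.
    val df = sc.parallelize(Seq(("a0", "b0"), ("a1", "b1"), ("a2", "b0"), ("a0", "b1"))).toDF("a", "b")

    // The same chain the R wrapper reaches through callJMethod: sdf.stat().crosstab(col1, col2).
    val ct = df.stat.crosstab("a", "b")
    ct.show()  // the first column is named "a_b"; the remaining columns are the distinct values of "b"

    sc.stop()
  }
}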
diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R
index 89511141d3ef7..d2d096709245d 100644
--- a/R/pkg/R/RDD.R
+++ b/R/pkg/R/RDD.R
@@ -165,7 +165,6 @@ setMethod("getJRDD", signature(rdd = "PipelinedRDD"),
serializedFuncArr,
rdd@env$prev_serializedMode,
packageNamesArr,
- as.character(.sparkREnv[["libname"]]),
broadcastArr,
callJMethod(prev_jrdd, "classTag"))
} else {
@@ -175,7 +174,6 @@ setMethod("getJRDD", signature(rdd = "PipelinedRDD"),
rdd@env$prev_serializedMode,
serializedMode,
packageNamesArr,
- as.character(.sparkREnv[["libname"]]),
broadcastArr,
callJMethod(prev_jrdd, "classTag"))
}
diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R
index 30978bb50d339..110117a18ccbc 100644
--- a/R/pkg/R/SQLContext.R
+++ b/R/pkg/R/SQLContext.R
@@ -457,7 +457,7 @@ dropTempTable <- function(sqlContext, tableName) {
read.df <- function(sqlContext, path = NULL, source = NULL, schema = NULL, ...) {
options <- varargsToEnv(...)
if (!is.null(path)) {
- options[['path']] <- path
+ options[["path"]] <- path
}
if (is.null(source)) {
sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv)
@@ -506,7 +506,7 @@ loadDF <- function(sqlContext, path = NULL, source = NULL, schema = NULL, ...) {
createExternalTable <- function(sqlContext, tableName, path = NULL, source = NULL, ...) {
options <- varargsToEnv(...)
if (!is.null(path)) {
- options[['path']] <- path
+ options[["path"]] <- path
}
sdf <- callJMethod(sqlContext, "createExternalTable", tableName, source, options)
dataFrame(sdf)
diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R
index 78c7a3037ffac..6f772158ddfe8 100644
--- a/R/pkg/R/client.R
+++ b/R/pkg/R/client.R
@@ -36,9 +36,9 @@ connectBackend <- function(hostname, port, timeout = 6000) {
determineSparkSubmitBin <- function() {
if (.Platform$OS.type == "unix") {
- sparkSubmitBinName = "spark-submit"
+ sparkSubmitBinName <- "spark-submit"
} else {
- sparkSubmitBinName = "spark-submit.cmd"
+ sparkSubmitBinName <- "spark-submit.cmd"
}
sparkSubmitBinName
}
diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R
index 8e4b0f5bf1c4d..2892e1416cc65 100644
--- a/R/pkg/R/column.R
+++ b/R/pkg/R/column.R
@@ -187,6 +187,23 @@ setMethod("substr", signature(x = "Column"),
column(jc)
})
+#' between
+#'
+#' Test if the column is between the lower bound and upper bound, inclusive.
+#'
+#' @rdname column
+#'
+#' @param bounds lower and upper bounds
+setMethod("between", signature(x = "Column"),
+ function(x, bounds) {
+ if (is.vector(bounds) && length(bounds) == 2) {
+ jc <- callJMethod(x@jc, "between", bounds[1], bounds[2])
+ column(jc)
+ } else {
+ stop("bounds should be a vector of lower and upper bounds")
+ }
+ })
+
#' Casts the column to a different data type.
#'
#' @rdname column
diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R
index d961bbc383688..7d1f6b0819ed0 100644
--- a/R/pkg/R/deserialize.R
+++ b/R/pkg/R/deserialize.R
@@ -23,6 +23,7 @@
# Int -> integer
# String -> character
# Boolean -> logical
+# Float -> double
# Double -> double
# Long -> double
# Array[Byte] -> raw
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index fad9d71158c51..836e0175c391f 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -59,6 +59,10 @@ setGeneric("count", function(x) { standardGeneric("count") })
# @export
setGeneric("countByValue", function(x) { standardGeneric("countByValue") })
+# @rdname statfunctions
+# @export
+setGeneric("crosstab", function(x, col1, col2) { standardGeneric("crosstab") })
+
# @rdname distinct
# @export
setGeneric("distinct", function(x, numPartitions = 1) { standardGeneric("distinct") })
@@ -567,6 +571,10 @@ setGeneric("asc", function(x) { standardGeneric("asc") })
#' @export
setGeneric("avg", function(x, ...) { standardGeneric("avg") })
+#' @rdname column
+#' @export
+setGeneric("between", function(x, bounds) { standardGeneric("between") })
+
#' @rdname column
#' @export
setGeneric("cast", function(x, dataType) { standardGeneric("cast") })
@@ -657,3 +665,7 @@ setGeneric("toRadians", function(x) { standardGeneric("toRadians") })
#' @rdname column
#' @export
setGeneric("upper", function(x) { standardGeneric("upper") })
+
+#' @rdname glm
+#' @export
+setGeneric("glm")
diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R
index 8f1c68f7c4d28..576ac72f40fc0 100644
--- a/R/pkg/R/group.R
+++ b/R/pkg/R/group.R
@@ -87,7 +87,7 @@ setMethod("count",
setMethod("agg",
signature(x = "GroupedData"),
function(x, ...) {
- cols = list(...)
+ cols <- list(...)
stopifnot(length(cols) > 0)
if (is.character(cols[[1]])) {
cols <- varargsToEnv(...)
@@ -97,7 +97,7 @@ setMethod("agg",
if (!is.null(ns)) {
for (n in ns) {
if (n != "") {
- cols[[n]] = alias(cols[[n]], n)
+ cols[[n]] <- alias(cols[[n]], n)
}
}
}
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
new file mode 100644
index 0000000000000..258e354081fc1
--- /dev/null
+++ b/R/pkg/R/mllib.R
@@ -0,0 +1,73 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# mllib.R: Provides methods for MLlib integration
+
+#' @title S4 class that represents a PipelineModel
+#' @param model A Java object reference to the backing Scala PipelineModel
+#' @export
+setClass("PipelineModel", representation(model = "jobj"))
+
+#' Fits a generalized linear model
+#'
+#' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package.
+#'
+#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
+#' operators are supported, including '~' and '+'.
+#' @param data DataFrame for training
+#' @param family Error distribution. "gaussian" -> linear regression, "binomial" -> logistic regression.
+#' @param lambda Regularization parameter
+#' @param alpha Elastic-net mixing parameter (see glmnet's documentation for details)
+#' @return a fitted MLlib model
+#' @rdname glm
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlContext <- sparkRSQL.init(sc)
+#' data(iris)
+#' df <- createDataFrame(sqlContext, iris)
+#' model <- glm(Sepal_Length ~ Sepal_Width, df)
+#'}
+setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFrame"),
+ function(formula, family = c("gaussian", "binomial"), data, lambda = 0, alpha = 0) {
+ family <- match.arg(family)
+ model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
+ "fitRModelFormula", deparse(formula), data@sdf, family, lambda,
+ alpha)
+ return(new("PipelineModel", model = model))
+ })
+
+#' Make predictions from a model
+#'
+#' Makes predictions from a model produced by glm(), similarly to R's predict().
+#'
+#' @param model A fitted MLlib model
+#' @param newData DataFrame for testing
+#' @return DataFrame containing predicted values
+#' @rdname glm
+#' @export
+#' @examples
+#'\dontrun{
+#' model <- glm(y ~ x, trainingData)
+#' predicted <- predict(model, testData)
+#' showDF(predicted)
+#'}
+setMethod("predict", signature(object = "PipelineModel"),
+ function(object, newData) {
+ return(dataFrame(callJMethod(object@model, "transform", newData@sdf)))
+ })
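
glm() and predict() above delegate to org.apache.spark.ml.api.r.SparkRWrappers, which is not part of this R-side file. Purely as a hedged sketch of how such a wrapper could be assembled from Spark ML's public pieces (the object and method below are illustrative, not the actual wrapper), a formula-driven fit maps family to an estimator and passes lambda/alpha through as regParam/elasticNetParam:

import org.apache.spark.ml.{Pipeline, PipelineModel, PipelineStage}
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.RFormula
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.sql.DataFrame

// Hypothetical sketch: "gaussian" -> linear regression, "binomial" -> logistic regression.
object RFormulaGlmSketch {
  def fit(formula: String, data: DataFrame, family: String,
          lambda: Double, alpha: Double): PipelineModel = {
    val rFormula = new RFormula().setFormula(formula)
    val estimator: PipelineStage = family match {
      case "gaussian" => new LinearRegression().setRegParam(lambda).setElasticNetParam(alpha)
      case "binomial" => new LogisticRegression().setRegParam(lambda).setElasticNetParam(alpha)
      case other => sys.error(s"Unsupported family: $other")
    }
    new Pipeline().setStages(Array(rFormula, estimator)).fit(data)
  }
}

predict() then amounts to calling transform(newData) on the fitted PipelineModel, which is exactly what the R method above does through callJMethod.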
diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R
index 0f1179e0aa51a..ebc6ff65e9d0f 100644
--- a/R/pkg/R/pairRDD.R
+++ b/R/pkg/R/pairRDD.R
@@ -215,7 +215,6 @@ setMethod("partitionBy",
serializedHashFuncBytes,
getSerializedMode(x),
packageNamesArr,
- as.character(.sparkREnv$libname),
broadcastArr,
callJMethod(jrdd, "classTag"))
diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R
index 15e2bdbd55d79..79c744ef29c23 100644
--- a/R/pkg/R/schema.R
+++ b/R/pkg/R/schema.R
@@ -69,11 +69,14 @@ structType.structField <- function(x, ...) {
#' @param ... further arguments passed to or from other methods
print.structType <- function(x, ...) {
cat("StructType\n",
- sapply(x$fields(), function(field) { paste("|-", "name = \"", field$name(),
- "\", type = \"", field$dataType.toString(),
- "\", nullable = ", field$nullable(), "\n",
- sep = "") })
- , sep = "")
+ sapply(x$fields(),
+ function(field) {
+ paste("|-", "name = \"", field$name(),
+ "\", type = \"", field$dataType.toString(),
+ "\", nullable = ", field$nullable(), "\n",
+ sep = "")
+ }),
+ sep = "")
}
#' structField
@@ -123,6 +126,7 @@ structField.character <- function(x, type, nullable = TRUE) {
}
options <- c("byte",
"integer",
+ "float",
"double",
"numeric",
"character",
diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R
index 78535eff0d2f6..311021e5d8473 100644
--- a/R/pkg/R/serialize.R
+++ b/R/pkg/R/serialize.R
@@ -140,8 +140,8 @@ writeType <- function(con, class) {
jobj = "j",
environment = "e",
Date = "D",
- POSIXlt = 't',
- POSIXct = 't',
+ POSIXlt = "t",
+ POSIXct = "t",
stop(paste("Unsupported type for serialization", class)))
writeBin(charToRaw(type), con)
}
diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R
index 048eb8ed541e4..79b79d70943cb 100644
--- a/R/pkg/R/sparkR.R
+++ b/R/pkg/R/sparkR.R
@@ -17,10 +17,6 @@
.sparkREnv <- new.env()
-sparkR.onLoad <- function(libname, pkgname) {
- .sparkREnv$libname <- libname
-}
-
# Utility function that returns TRUE if we have an active connection to the
# backend and FALSE otherwise
connExists <- function(env) {
@@ -80,7 +76,6 @@ sparkR.stop <- function() {
#' @param sparkEnvir Named list of environment variables to set on worker nodes.
#' @param sparkExecutorEnv Named list of environment variables to be used when launching executors.
#' @param sparkJars Character string vector of jar files to pass to the worker nodes.
-#' @param sparkRLibDir The path where R is installed on the worker nodes.
#' @param sparkPackages Character string vector of packages from spark-packages.org
#' @export
#' @examples
@@ -101,7 +96,6 @@ sparkR.init <- function(
sparkEnvir = list(),
sparkExecutorEnv = list(),
sparkJars = "",
- sparkRLibDir = "",
sparkPackages = "") {
if (exists(".sparkRjsc", envir = .sparkREnv)) {
@@ -146,7 +140,7 @@ sparkR.init <- function(
if (!file.exists(path)) {
stop("JVM is not ready after 10 seconds")
}
- f <- file(path, open='rb')
+ f <- file(path, open="rb")
backendPort <- readInt(f)
monitorPort <- readInt(f)
close(f)
@@ -170,10 +164,6 @@ sparkR.init <- function(
sparkHome <- normalizePath(sparkHome)
}
- if (nchar(sparkRLibDir) != 0) {
- .sparkREnv$libname <- sparkRLibDir
- }
-
sparkEnvirMap <- new.env()
for (varname in names(sparkEnvir)) {
sparkEnvirMap[[varname]] <- sparkEnvir[[varname]]
diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R
index ea629a64f7158..3f45589a50443 100644
--- a/R/pkg/R/utils.R
+++ b/R/pkg/R/utils.R
@@ -41,8 +41,8 @@ convertJListToRList <- function(jList, flatten, logicalUpperBound = NULL,
if (isInstanceOf(obj, "scala.Tuple2")) {
# JavaPairRDD[Array[Byte], Array[Byte]].
- keyBytes = callJMethod(obj, "_1")
- valBytes = callJMethod(obj, "_2")
+ keyBytes <- callJMethod(obj, "_1")
+ valBytes <- callJMethod(obj, "_2")
res <- list(unserialize(keyBytes),
unserialize(valBytes))
} else {
@@ -390,14 +390,17 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) {
for (i in 1:nodeLen) {
processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv)
}
- } else { # if node[[1]] is length of 1, check for some R special functions.
+ } else {
+ # if node[[1]] is length of 1, check for some R special functions.
nodeChar <- as.character(node[[1]])
- if (nodeChar == "{" || nodeChar == "(") { # Skip start symbol.
+ if (nodeChar == "{" || nodeChar == "(") {
+ # Skip start symbol.
for (i in 2:nodeLen) {
processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv)
}
} else if (nodeChar == "<-" || nodeChar == "=" ||
- nodeChar == "<<-") { # Assignment Ops.
+ nodeChar == "<<-") {
+ # Assignment Ops.
defVar <- node[[2]]
if (length(defVar) == 1 && typeof(defVar) == "symbol") {
# Add the defined variable name into defVars.
@@ -408,14 +411,16 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) {
for (i in 3:nodeLen) {
processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv)
}
- } else if (nodeChar == "function") { # Function definition.
+ } else if (nodeChar == "function") {
+ # Function definition.
# Add parameter names.
newArgs <- names(node[[2]])
lapply(newArgs, function(arg) { addItemToAccumulator(defVars, arg) })
for (i in 3:nodeLen) {
processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv)
}
- } else if (nodeChar == "$") { # Skip the field.
+ } else if (nodeChar == "$") {
+ # Skip the field.
processClosure(node[[2]], oldEnv, defVars, checkedFuncs, newEnv)
} else if (nodeChar == "::" || nodeChar == ":::") {
processClosure(node[[3]], oldEnv, defVars, checkedFuncs, newEnv)
@@ -429,7 +434,8 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) {
(typeof(node) == "symbol" || typeof(node) == "language")) {
# Base case: current AST node is a leaf node and a symbol or a function call.
nodeChar <- as.character(node)
- if (!nodeChar %in% defVars$data) { # Not a function parameter or local variable.
+ if (!nodeChar %in% defVars$data) {
+ # Not a function parameter or local variable.
func.env <- oldEnv
topEnv <- parent.env(.GlobalEnv)
# Search in function environment, and function's enclosing environments
@@ -439,20 +445,24 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) {
while (!identical(func.env, topEnv)) {
# Namespaces other than "SparkR" will not be searched.
if (!isNamespace(func.env) ||
- (getNamespaceName(func.env) == "SparkR" &&
- !(nodeChar %in% getNamespaceExports("SparkR")))) { # Only include SparkR internals.
+ (getNamespaceName(func.env) == "SparkR" &&
+ !(nodeChar %in% getNamespaceExports("SparkR")))) {
+ # Only include SparkR internals.
+
# Set parameter 'inherits' to FALSE since we do not need to search in
# attached package environments.
if (tryCatch(exists(nodeChar, envir = func.env, inherits = FALSE),
error = function(e) { FALSE })) {
obj <- get(nodeChar, envir = func.env, inherits = FALSE)
- if (is.function(obj)) { # If the node is a function call.
+ if (is.function(obj)) {
+ # If the node is a function call.
funcList <- mget(nodeChar, envir = checkedFuncs, inherits = F,
ifnotfound = list(list(NULL)))[[1]]
found <- sapply(funcList, function(func) {
ifelse(identical(func, obj), TRUE, FALSE)
})
- if (sum(found) > 0) { # If function has been examined, ignore.
+ if (sum(found) > 0) {
+ # If function has been examined, ignore.
break
}
# Function has not been examined, record it and recursively clean its closure.
@@ -495,7 +505,8 @@ cleanClosure <- function(func, checkedFuncs = new.env()) {
# environment. First, function's arguments are added to defVars.
defVars <- initAccumulator()
argNames <- names(as.list(args(func)))
- for (i in 1:(length(argNames) - 1)) { # Remove the ending NULL in pairlist.
+ for (i in 1:(length(argNames) - 1)) {
+ # Remove the ending NULL in pairlist.
addItemToAccumulator(defVars, argNames[i])
}
# Recursively examine variables in the function body.
diff --git a/R/pkg/R/zzz.R b/R/pkg/R/zzz.R
deleted file mode 100644
index 301feade65fa3..0000000000000
--- a/R/pkg/R/zzz.R
+++ /dev/null
@@ -1,20 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-.onLoad <- function(libname, pkgname) {
- sparkR.onLoad(libname, pkgname)
-}
diff --git a/R/pkg/inst/profile/general.R b/R/pkg/inst/profile/general.R
index 8fe711b622086..2a8a8213d0849 100644
--- a/R/pkg/inst/profile/general.R
+++ b/R/pkg/inst/profile/general.R
@@ -16,7 +16,7 @@
#
.First <- function() {
- home <- Sys.getenv("SPARK_HOME")
- .libPaths(c(file.path(home, "R", "lib"), .libPaths()))
+ packageDir <- Sys.getenv("SPARKR_PACKAGE_DIR")
+ .libPaths(c(packageDir, .libPaths()))
Sys.setenv(NOAWT=1)
}
diff --git a/R/pkg/inst/tests/test_binaryFile.R b/R/pkg/inst/tests/test_binaryFile.R
index ccaea18ecab2a..f2452ed97d2ea 100644
--- a/R/pkg/inst/tests/test_binaryFile.R
+++ b/R/pkg/inst/tests/test_binaryFile.R
@@ -20,7 +20,7 @@ context("functions on binary files")
# JavaSparkContext handle
sc <- sparkR.init()
-mockFile = c("Spark is pretty.", "Spark is awesome.")
+mockFile <- c("Spark is pretty.", "Spark is awesome.")
test_that("saveAsObjectFile()/objectFile() following textFile() works", {
fileName1 <- tempfile(pattern="spark-test", fileext=".tmp")
diff --git a/R/pkg/inst/tests/test_binary_function.R b/R/pkg/inst/tests/test_binary_function.R
index 3be8c65a6c1a0..dca0657c57e0d 100644
--- a/R/pkg/inst/tests/test_binary_function.R
+++ b/R/pkg/inst/tests/test_binary_function.R
@@ -76,7 +76,7 @@ test_that("zipPartitions() on RDDs", {
expect_equal(actual,
list(list(1, c(1,2), c(1,2,3)), list(2, c(3,4), c(4,5,6))))
- mockFile = c("Spark is pretty.", "Spark is awesome.")
+ mockFile <- c("Spark is pretty.", "Spark is awesome.")
fileName <- tempfile(pattern="spark-test", fileext=".tmp")
writeLines(mockFile, fileName)
diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R
new file mode 100644
index 0000000000000..a492763344ae6
--- /dev/null
+++ b/R/pkg/inst/tests/test_mllib.R
@@ -0,0 +1,42 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+library(testthat)
+
+context("MLlib functions")
+
+# Tests for MLlib functions in SparkR
+
+sc <- sparkR.init()
+
+sqlContext <- sparkRSQL.init(sc)
+
+test_that("glm and predict", {
+ training <- createDataFrame(sqlContext, iris)
+ test <- select(training, "Sepal_Length")
+ model <- glm(Sepal_Width ~ Sepal_Length, training, family = "gaussian")
+ prediction <- predict(model, test)
+ expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double")
+})
+
+test_that("predictions match with native glm", {
+ training <- createDataFrame(sqlContext, iris)
+ model <- glm(Sepal_Width ~ Sepal_Length, data = training)
+ vals <- collect(select(predict(model, training), "prediction"))
+ rVals <- predict(glm(Sepal.Width ~ Sepal.Length, data = iris), iris)
+ expect_true(all(abs(rVals - vals) < 1e-9), rVals - vals)
+})
diff --git a/R/pkg/inst/tests/test_rdd.R b/R/pkg/inst/tests/test_rdd.R
index b79692873cec3..6c3aaab8c711e 100644
--- a/R/pkg/inst/tests/test_rdd.R
+++ b/R/pkg/inst/tests/test_rdd.R
@@ -447,7 +447,7 @@ test_that("zipRDD() on RDDs", {
expect_equal(actual,
list(list(0, 1000), list(1, 1001), list(2, 1002), list(3, 1003), list(4, 1004)))
- mockFile = c("Spark is pretty.", "Spark is awesome.")
+ mockFile <- c("Spark is pretty.", "Spark is awesome.")
fileName <- tempfile(pattern="spark-test", fileext=".tmp")
writeLines(mockFile, fileName)
@@ -483,7 +483,7 @@ test_that("cartesian() on RDDs", {
actual <- collect(cartesian(rdd, emptyRdd))
expect_equal(actual, list())
- mockFile = c("Spark is pretty.", "Spark is awesome.")
+ mockFile <- c("Spark is pretty.", "Spark is awesome.")
fileName <- tempfile(pattern="spark-test", fileext=".tmp")
writeLines(mockFile, fileName)
diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R
index b0ea38854304e..62fe48a5d6c7b 100644
--- a/R/pkg/inst/tests/test_sparkSQL.R
+++ b/R/pkg/inst/tests/test_sparkSQL.R
@@ -57,9 +57,9 @@ test_that("infer types", {
expect_equal(infer_type(as.Date("2015-03-11")), "date")
expect_equal(infer_type(as.POSIXlt("2015-03-11 12:13:04.043")), "timestamp")
expect_equal(infer_type(c(1L, 2L)),
- list(type = 'array', elementType = "integer", containsNull = TRUE))
+ list(type = "array", elementType = "integer", containsNull = TRUE))
expect_equal(infer_type(list(1L, 2L)),
- list(type = 'array', elementType = "integer", containsNull = TRUE))
+ list(type = "array", elementType = "integer", containsNull = TRUE))
testStruct <- infer_type(list(a = 1L, b = "2"))
expect_equal(class(testStruct), "structType")
checkStructField(testStruct$fields()[[1]], "a", "IntegerType", TRUE)
@@ -108,6 +108,32 @@ test_that("create DataFrame from RDD", {
expect_equal(count(df), 10)
expect_equal(columns(df), c("a", "b"))
expect_equal(dtypes(df), list(c("a", "int"), c("b", "string")))
+
+ df <- jsonFile(sqlContext, jsonPathNa)
+ hiveCtx <- tryCatch({
+ newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc)
+ }, error = function(err) {
+    skip("Hive is not built with SparkSQL, skipped")
+ })
+ sql(hiveCtx, "CREATE TABLE people (name string, age double, height float)")
+ insertInto(df, "people")
+ expect_equal(sql(hiveCtx, "SELECT age from people WHERE name = 'Bob'"), c(16))
+ expect_equal(sql(hiveCtx, "SELECT height from people WHERE name ='Bob'"), c(176.5))
+
+ schema <- structType(structField("name", "string"), structField("age", "integer"),
+ structField("height", "float"))
+ df2 <- createDataFrame(sqlContext, df.toRDD, schema)
+ expect_equal(columns(df2), c("name", "age", "height"))
+ expect_equal(dtypes(df2), list(c("name", "string"), c("age", "int"), c("height", "float")))
+ expect_equal(collect(where(df2, df2$name == "Bob")), c("Bob", 16, 176.5))
+
+ localDF <- data.frame(name=c("John", "Smith", "Sarah"), age=c(19, 23, 18), height=c(164.10, 181.4, 173.7))
+ df <- createDataFrame(sqlContext, localDF, schema)
+ expect_is(df, "DataFrame")
+ expect_equal(count(df), 3)
+ expect_equal(columns(df), c("name", "age", "height"))
+ expect_equal(dtypes(df), list(c("name", "string"), c("age", "int"), c("height", "float")))
+ expect_equal(collect(where(df, df$name == "John")), c("John", 19, 164.10))
})
test_that("convert NAs to null type in DataFrames", {
@@ -612,6 +638,18 @@ test_that("column functions", {
c7 <- floor(c) + log(c) + log10(c) + log1p(c) + rint(c)
c8 <- sign(c) + sin(c) + sinh(c) + tan(c) + tanh(c)
c9 <- toDegrees(c) + toRadians(c)
+
+ df <- jsonFile(sqlContext, jsonPath)
+ df2 <- select(df, between(df$age, c(20, 30)), between(df$age, c(10, 20)))
+ expect_equal(collect(df2)[[2, 1]], TRUE)
+ expect_equal(collect(df2)[[2, 2]], FALSE)
+ expect_equal(collect(df2)[[3, 1]], FALSE)
+ expect_equal(collect(df2)[[3, 2]], TRUE)
+
+ df3 <- select(df, between(df$name, c("Apache", "Spark")))
+ expect_equal(collect(df3)[[1, 1]], TRUE)
+ expect_equal(collect(df3)[[2, 1]], FALSE)
+ expect_equal(collect(df3)[[3, 1]], TRUE)
})
test_that("column binary mathfunctions", {
@@ -949,6 +987,19 @@ test_that("fillna() on a DataFrame", {
expect_identical(expected, actual)
})
+test_that("crosstab() on a DataFrame", {
+ rdd <- lapply(parallelize(sc, 0:3), function(x) {
+ list(paste0("a", x %% 3), paste0("b", x %% 2))
+ })
+ df <- toDF(rdd, list("a", "b"))
+ ct <- crosstab(df, "a", "b")
+ ordered <- ct[order(ct$a_b),]
+ row.names(ordered) <- NULL
+ expected <- data.frame("a_b" = c("a0", "a1", "a2"), "b0" = c(1, 0, 1), "b1" = c(1, 1, 0),
+ stringsAsFactors = FALSE, row.names = NULL)
+ expect_identical(expected, ordered)
+})
+
unlink(parquetPath)
unlink(jsonPath)
unlink(jsonPathNa)
diff --git a/R/pkg/inst/tests/test_textFile.R b/R/pkg/inst/tests/test_textFile.R
index 58318dfef71ab..a9cf83dbdbdb1 100644
--- a/R/pkg/inst/tests/test_textFile.R
+++ b/R/pkg/inst/tests/test_textFile.R
@@ -20,7 +20,7 @@ context("the textFile() function")
# JavaSparkContext handle
sc <- sparkR.init()
-mockFile = c("Spark is pretty.", "Spark is awesome.")
+mockFile <- c("Spark is pretty.", "Spark is awesome.")
test_that("textFile() on a local file returns an RDD", {
fileName <- tempfile(pattern="spark-test", fileext=".tmp")
diff --git a/R/pkg/inst/tests/test_utils.R b/R/pkg/inst/tests/test_utils.R
index aa0d2a66b9082..12df4cf4f65b7 100644
--- a/R/pkg/inst/tests/test_utils.R
+++ b/R/pkg/inst/tests/test_utils.R
@@ -119,7 +119,7 @@ test_that("cleanClosure on R functions", {
# Test for overriding variables in base namespace (Issue: SparkR-196).
nums <- as.list(1:10)
rdd <- parallelize(sc, nums, 2L)
- t = 4 # Override base::t in .GlobalEnv.
+ t <- 4 # Override base::t in .GlobalEnv.
f <- function(x) { x > t }
newF <- cleanClosure(f)
env <- environment(newF)
diff --git a/bin/spark-shell b/bin/spark-shell
index a6dc863d83fc6..00ab7afd118b5 100755
--- a/bin/spark-shell
+++ b/bin/spark-shell
@@ -47,11 +47,11 @@ function main() {
# (see https://github.com/sbt/sbt/issues/562).
stty -icanon min 1 -echo > /dev/null 2>&1
export SPARK_SUBMIT_OPTS="$SPARK_SUBMIT_OPTS -Djline.terminal=unix"
- "$FWDIR"/bin/spark-submit --class org.apache.spark.repl.Main "$@"
+ "$FWDIR"/bin/spark-submit --class org.apache.spark.repl.Main --name "Spark shell" "$@"
stty icanon echo > /dev/null 2>&1
else
export SPARK_SUBMIT_OPTS
- "$FWDIR"/bin/spark-submit --class org.apache.spark.repl.Main "$@"
+ "$FWDIR"/bin/spark-submit --class org.apache.spark.repl.Main --name "Spark shell" "$@"
fi
}
diff --git a/bin/spark-shell2.cmd b/bin/spark-shell2.cmd
index 251309d67f860..b9b0f510d7f5d 100644
--- a/bin/spark-shell2.cmd
+++ b/bin/spark-shell2.cmd
@@ -32,4 +32,4 @@ if "x%SPARK_SUBMIT_OPTS%"=="x" (
set SPARK_SUBMIT_OPTS="%SPARK_SUBMIT_OPTS% -Dscala.usejavacp=true"
:run_shell
-%SPARK_HOME%\bin\spark-submit2.cmd --class org.apache.spark.repl.Main %*
+%SPARK_HOME%\bin\spark-submit2.cmd --class org.apache.spark.repl.Main --name "Spark shell" %*
diff --git a/build/mvn b/build/mvn
index e8364181e8230..f62f61ee1c416 100755
--- a/build/mvn
+++ b/build/mvn
@@ -112,10 +112,17 @@ install_scala() {
# the environment
ZINC_PORT=${ZINC_PORT:-"3030"}
+# Check for the `--force` flag dictating that `mvn` should be downloaded
+# regardless of whether the system already has a `mvn` install
+if [ "$1" == "--force" ]; then
+ FORCE_MVN=1
+ shift
+fi
+
# Install Maven if necessary
MVN_BIN="$(command -v mvn)"
-if [ ! "$MVN_BIN" ]; then
+if [ ! "$MVN_BIN" -o -n "$FORCE_MVN" ]; then
install_mvn
fi
@@ -139,5 +146,7 @@ fi
# Set any `mvn` options if not already present
export MAVEN_OPTS=${MAVEN_OPTS:-"$_COMPILE_JVM_OPTS"}
+echo "Using \`mvn\` from path: $MVN_BIN"
+
# Last, call the `mvn` command as usual
${MVN_BIN} "$@"
diff --git a/build/sbt-launch-lib.bash b/build/sbt-launch-lib.bash
index 504be48b358fa..7930a38b9674a 100755
--- a/build/sbt-launch-lib.bash
+++ b/build/sbt-launch-lib.bash
@@ -51,9 +51,13 @@ acquire_sbt_jar () {
printf "Attempting to fetch sbt\n"
JAR_DL="${JAR}.part"
if [ $(command -v curl) ]; then
- (curl --silent ${URL1} > "${JAR_DL}" || curl --silent ${URL2} > "${JAR_DL}") && mv "${JAR_DL}" "${JAR}"
+ (curl --fail --location --silent ${URL1} > "${JAR_DL}" ||\
+ (rm -f "${JAR_DL}" && curl --fail --location --silent ${URL2} > "${JAR_DL}")) &&\
+ mv "${JAR_DL}" "${JAR}"
elif [ $(command -v wget) ]; then
- (wget --quiet ${URL1} -O "${JAR_DL}" || wget --quiet ${URL2} -O "${JAR_DL}") && mv "${JAR_DL}" "${JAR}"
+ (wget --quiet ${URL1} -O "${JAR_DL}" ||\
+ (rm -f "${JAR_DL}" && wget --quiet ${URL2} -O "${JAR_DL}")) &&\
+ mv "${JAR_DL}" "${JAR}"
else
printf "You do not have curl or wget installed, please install sbt manually from http://www.scala-sbt.org/\n"
exit -1
diff --git a/core/pom.xml b/core/pom.xml
index 558cc3fb9f2f3..95f36eb348698 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -261,7 +261,7 @@
    </dependency>
    <dependency>
      <groupId>com.fasterxml.jackson.module</groupId>
-      <artifactId>jackson-module-scala_2.10</artifactId>
+      <artifactId>jackson-module-scala_${scala.binary.version}</artifactId>
    </dependency>
    <dependency>
      <groupId>org.apache.derby</groupId>
@@ -372,6 +372,11 @@
      <artifactId>junit-interface</artifactId>
      <scope>test</scope>
    </dependency>
+    <dependency>
+      <groupId>org.apache.curator</groupId>
+      <artifactId>curator-test</artifactId>
+      <scope>test</scope>
+    </dependency>
    <dependency>
      <groupId>net.razorvine</groupId>
      <artifactId>pyrolite</artifactId>
diff --git a/core/src/main/java/org/apache/spark/JavaSparkListener.java b/core/src/main/java/org/apache/spark/JavaSparkListener.java
index 646496f313507..fa9acf0a15b88 100644
--- a/core/src/main/java/org/apache/spark/JavaSparkListener.java
+++ b/core/src/main/java/org/apache/spark/JavaSparkListener.java
@@ -17,23 +17,7 @@
package org.apache.spark;
-import org.apache.spark.scheduler.SparkListener;
-import org.apache.spark.scheduler.SparkListenerApplicationEnd;
-import org.apache.spark.scheduler.SparkListenerApplicationStart;
-import org.apache.spark.scheduler.SparkListenerBlockManagerAdded;
-import org.apache.spark.scheduler.SparkListenerBlockManagerRemoved;
-import org.apache.spark.scheduler.SparkListenerEnvironmentUpdate;
-import org.apache.spark.scheduler.SparkListenerExecutorAdded;
-import org.apache.spark.scheduler.SparkListenerExecutorMetricsUpdate;
-import org.apache.spark.scheduler.SparkListenerExecutorRemoved;
-import org.apache.spark.scheduler.SparkListenerJobEnd;
-import org.apache.spark.scheduler.SparkListenerJobStart;
-import org.apache.spark.scheduler.SparkListenerStageCompleted;
-import org.apache.spark.scheduler.SparkListenerStageSubmitted;
-import org.apache.spark.scheduler.SparkListenerTaskEnd;
-import org.apache.spark.scheduler.SparkListenerTaskGettingResult;
-import org.apache.spark.scheduler.SparkListenerTaskStart;
-import org.apache.spark.scheduler.SparkListenerUnpersistRDD;
+import org.apache.spark.scheduler.*;
/**
* Java clients should extend this class instead of implementing
@@ -94,4 +78,8 @@ public void onExecutorAdded(SparkListenerExecutorAdded executorAdded) { }
@Override
public void onExecutorRemoved(SparkListenerExecutorRemoved executorRemoved) { }
+
+ @Override
+ public void onBlockUpdated(SparkListenerBlockUpdated blockUpdated) { }
+
}
diff --git a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java
index fbc5666959055..1214d05ba6063 100644
--- a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java
+++ b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java
@@ -112,4 +112,10 @@ public final void onExecutorAdded(SparkListenerExecutorAdded executorAdded) {
public final void onExecutorRemoved(SparkListenerExecutorRemoved executorRemoved) {
onEvent(executorRemoved);
}
+
+ @Override
+ public void onBlockUpdated(SparkListenerBlockUpdated blockUpdated) {
+ onEvent(blockUpdated);
+ }
+
}
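
With onBlockUpdated wired through, every callback on SparkFirehoseListener still funnels into the single onEvent method. A small Scala sketch of a listener built on that contract (the counting logic is illustrative):

import java.util.concurrent.atomic.AtomicLong

import org.apache.spark.SparkFirehoseListener
import org.apache.spark.scheduler.SparkListenerEvent

// Sees every event the listener bus delivers, block updates included, because all
// per-event callbacks (onBlockUpdated among them) forward to onEvent.
class EventCountingListener extends SparkFirehoseListener {
  val eventCount = new AtomicLong(0)

  override def onEvent(event: SparkListenerEvent): Unit = {
    eventCount.incrementAndGet()
  }
}
// Registered with sc.addSparkListener(new EventCountingListener()).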
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
index d3d6280284beb..0b8b604e18494 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
@@ -75,7 +75,7 @@ final class BypassMergeSortShuffleWriter<K, V> implements SortShuffleFileWriter<K, V> {
private final Serializer serializer;
/** Array of file writers, one for each partition */
- private BlockObjectWriter[] partitionWriters;
+ private DiskBlockObjectWriter[] partitionWriters;
public BypassMergeSortShuffleWriter(
SparkConf conf,
@@ -101,7 +101,7 @@ public void insertAll(Iterator<Product2<K, V>> records) throws IOException {
}
final SerializerInstance serInstance = serializer.newInstance();
final long openStartTime = System.nanoTime();
- partitionWriters = new BlockObjectWriter[numPartitions];
+ partitionWriters = new DiskBlockObjectWriter[numPartitions];
for (int i = 0; i < numPartitions; i++) {
      final Tuple2<TempShuffleBlockId, File> tempShuffleBlockIdPlusFile =
blockManager.diskBlockManager().createTempShuffleBlock();
@@ -121,7 +121,7 @@ public void insertAll(Iterator<Product2<K, V>> records) throws IOException {
partitionWriters[partitioner.getPartition(key)].write(key, record._2());
}
- for (BlockObjectWriter writer : partitionWriters) {
+ for (DiskBlockObjectWriter writer : partitionWriters) {
writer.commitAndClose();
}
}
@@ -169,7 +169,7 @@ public void stop() throws IOException {
if (partitionWriters != null) {
try {
final DiskBlockManager diskBlockManager = blockManager.diskBlockManager();
- for (BlockObjectWriter writer : partitionWriters) {
+ for (DiskBlockObjectWriter writer : partitionWriters) {
// This method explicitly does _not_ throw exceptions:
writer.revertPartialWritesAndClose();
if (!diskBlockManager.getFile(writer.blockId()).delete()) {
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
index 56289573209fb..1d460432be9ff 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
@@ -157,7 +157,7 @@ private void writeSortedFile(boolean isLastFile) throws IOException {
// Currently, we need to open a new DiskBlockObjectWriter for each partition; we can avoid this
// after SPARK-5581 is fixed.
- BlockObjectWriter writer;
+ DiskBlockObjectWriter writer;
// Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to
// be an API to directly transfer bytes from managed memory to the disk writer, we buffer
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
index 438742565c51d..bf1bc5dffba78 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
@@ -23,6 +23,7 @@
import org.apache.spark.annotation.Private;
import org.apache.spark.unsafe.types.UTF8String;
+import org.apache.spark.util.Utils;
@Private
public class PrefixComparators {
@@ -82,7 +83,7 @@ public static final class FloatPrefixComparator extends PrefixComparator {
public int compare(long aPrefix, long bPrefix) {
float a = Float.intBitsToFloat((int) aPrefix);
float b = Float.intBitsToFloat((int) bPrefix);
- return (a < b) ? -1 : (a > b) ? 1 : 0;
+ return Utils.nanSafeCompareFloats(a, b);
}
public long computePrefix(float value) {
@@ -97,7 +98,7 @@ public static final class DoublePrefixComparator extends PrefixComparator {
public int compare(long aPrefix, long bPrefix) {
double a = Double.longBitsToDouble(aPrefix);
double b = Double.longBitsToDouble(bPrefix);
- return (a < b) ? -1 : (a > b) ? 1 : 0;
+ return Utils.nanSafeCompareDoubles(a, b);
}
public long computePrefix(double value) {
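
The old ternary (a < b) ? -1 : (a > b) ? 1 : 0 returns 0 whenever either operand is NaN, because every comparison against NaN is false; that breaks the total order the sorter relies on. A Scala sketch of the NaN-safe semantics assumed here (NaN equals NaN and sorts after every other value); the real helpers live in org.apache.spark.util.Utils:

object NanSafeCompareSketch {
  // Total order over doubles: NaN == NaN, and NaN is greater than any non-NaN value.
  def nanSafeCompareDoubles(x: Double, y: Double): Int = {
    val xIsNan = java.lang.Double.isNaN(x)
    val yIsNan = java.lang.Double.isNaN(y)
    if ((xIsNan && yIsNan) || (x == y)) 0
    else if (xIsNan) 1
    else if (yIsNan) -1
    else if (x > y) 1
    else -1
  }

  def main(args: Array[String]): Unit = {
    println(nanSafeCompareDoubles(Double.NaN, 1.0))        // 1: NaN sorts last
    println(nanSafeCompareDoubles(Double.NaN, Double.NaN)) // 0: NaN equals itself
    println((Double.NaN < 1.0, Double.NaN > 1.0))          // (false, false): why the old form returned 0
  }
}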
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
index b8d66659804ad..71eed29563d4a 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
@@ -26,7 +26,7 @@
import org.apache.spark.serializer.DummySerializerInstance;
import org.apache.spark.storage.BlockId;
import org.apache.spark.storage.BlockManager;
-import org.apache.spark.storage.BlockObjectWriter;
+import org.apache.spark.storage.DiskBlockObjectWriter;
import org.apache.spark.storage.TempLocalBlockId;
import org.apache.spark.unsafe.PlatformDependent;
@@ -47,7 +47,7 @@ final class UnsafeSorterSpillWriter {
private final File file;
private final BlockId blockId;
private final int numRecordsToWrite;
- private BlockObjectWriter writer;
+ private DiskBlockObjectWriter writer;
private int numRecordsSpilled = 0;
public UnsafeSorterSpillWriter(
diff --git a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js
index 0b450dc76bc38..3c8ddddf07b1e 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js
+++ b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js
@@ -19,6 +19,9 @@
* to be registered after the page loads. */
$(function() {
$("span.expand-additional-metrics").click(function(){
+ var status = window.localStorage.getItem("expand-additional-metrics") == "true";
+ status = !status;
+
// Expand the list of additional metrics.
var additionalMetricsDiv = $(this).parent().find('.additional-metrics');
$(additionalMetricsDiv).toggleClass('collapsed');
@@ -26,17 +29,31 @@ $(function() {
// Switch the class of the arrow from open to closed.
$(this).find('.expand-additional-metrics-arrow').toggleClass('arrow-open');
$(this).find('.expand-additional-metrics-arrow').toggleClass('arrow-closed');
+
+ window.localStorage.setItem("expand-additional-metrics", "" + status);
});
+ if (window.localStorage.getItem("expand-additional-metrics") == "true") {
+ // Set it to false so that the click function can revert it
+ window.localStorage.setItem("expand-additional-metrics", "false");
+ $("span.expand-additional-metrics").trigger("click");
+ }
+
stripeSummaryTable();
$('input[type="checkbox"]').click(function() {
- var column = "table ." + $(this).attr("name");
+    var name = $(this).attr("name");
+ var column = "table ." + name;
+ var status = window.localStorage.getItem(name) == "true";
+ status = !status;
$(column).toggle();
stripeSummaryTable();
+ window.localStorage.setItem(name, "" + status);
});
$("#select-all-metrics").click(function() {
+ var status = window.localStorage.getItem("select-all-metrics") == "true";
+ status = !status;
if (this.checked) {
// Toggle all un-checked options.
$('input[type="checkbox"]:not(:checked)').trigger('click');
@@ -44,6 +61,21 @@ $(function() {
// Toggle all checked options.
$('input[type="checkbox"]:checked').trigger('click');
}
+ window.localStorage.setItem("select-all-metrics", "" + status);
+ });
+
+ if (window.localStorage.getItem("select-all-metrics") == "true") {
+    $("#select-all-metrics").attr('checked', true);
+ }
+
+ $("span.additional-metric-title").parent().find('input[type="checkbox"]').each(function() {
+    var name = $(this).attr("name");
+ // If name is undefined, then skip it because it's the "select-all-metrics" checkbox
+ if (name && window.localStorage.getItem(name) == "true") {
+ // Set it to false so that the click function can revert it
+ window.localStorage.setItem(name, "false");
+      $(this).trigger("click");
+ }
});
// Trigger a click on the checkbox if a user clicks the label next to it.
diff --git a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js
index 9fa53baaf4212..4a893bc0189aa 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js
+++ b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js
@@ -72,6 +72,14 @@ var StagePageVizConstants = {
rankSep: 40
};
+/*
+ * Return "expand-dag-viz-arrow-job" if forJob is true.
+ * Otherwise, return "expand-dag-viz-arrow-stage".
+ */
+function expandDagVizArrowKey(forJob) {
+ return forJob ? "expand-dag-viz-arrow-job" : "expand-dag-viz-arrow-stage";
+}
+
/*
* Show or hide the RDD DAG visualization.
*
@@ -79,6 +87,9 @@ var StagePageVizConstants = {
* This is the narrow interface called from the Scala UI code.
*/
function toggleDagViz(forJob) {
+ var status = window.localStorage.getItem(expandDagVizArrowKey(forJob)) == "true";
+ status = !status;
+
var arrowSelector = ".expand-dag-viz-arrow";
$(arrowSelector).toggleClass('arrow-closed');
$(arrowSelector).toggleClass('arrow-open');
@@ -93,8 +104,24 @@ function toggleDagViz(forJob) {
// Save the graph for later so we don't have to render it again
graphContainer().style("display", "none");
}
+
+ window.localStorage.setItem(expandDagVizArrowKey(forJob), "" + status);
}
+$(function (){
+ if (window.localStorage.getItem(expandDagVizArrowKey(false)) == "true") {
+ // Set it to false so that the click function can revert it
+ window.localStorage.setItem(expandDagVizArrowKey(false), "false");
+ toggleDagViz(false);
+ }
+
+ if (window.localStorage.getItem(expandDagVizArrowKey(true)) == "true") {
+ // Set it to false so that the click function can revert it
+ window.localStorage.setItem(expandDagVizArrowKey(true), "false");
+ toggleDagViz(true);
+ }
+});
+
/*
* Render the RDD DAG visualization.
*
diff --git a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js
index ca74ef9d7e94e..f4453c71df1ea 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js
+++ b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js
@@ -66,14 +66,27 @@ function drawApplicationTimeline(groupArray, eventObjArray, startTime) {
setupJobEventAction();
$("span.expand-application-timeline").click(function() {
+ var status = window.localStorage.getItem("expand-application-timeline") == "true";
+ status = !status;
+
$("#application-timeline").toggleClass('collapsed');
// Switch the class of the arrow from open to closed.
$(this).find('.expand-application-timeline-arrow').toggleClass('arrow-open');
$(this).find('.expand-application-timeline-arrow').toggleClass('arrow-closed');
+
+ window.localStorage.setItem("expand-application-timeline", "" + status);
});
}
+$(function (){
+ if (window.localStorage.getItem("expand-application-timeline") == "true") {
+ // Set it to false so that the click function can revert it
+ window.localStorage.setItem("expand-application-timeline", "false");
+ $("span.expand-application-timeline").trigger('click');
+ }
+});
+
function drawJobTimeline(groupArray, eventObjArray, startTime) {
var groups = new vis.DataSet(groupArray);
var items = new vis.DataSet(eventObjArray);
@@ -125,14 +138,27 @@ function drawJobTimeline(groupArray, eventObjArray, startTime) {
setupStageEventAction();
$("span.expand-job-timeline").click(function() {
+ var status = window.localStorage.getItem("expand-job-timeline") == "true";
+ status = !status;
+
$("#job-timeline").toggleClass('collapsed');
// Switch the class of the arrow from open to closed.
$(this).find('.expand-job-timeline-arrow').toggleClass('arrow-open');
$(this).find('.expand-job-timeline-arrow').toggleClass('arrow-closed');
+
+ window.localStorage.setItem("expand-job-timeline", "" + status);
});
}
+$(function (){
+ if (window.localStorage.getItem("expand-job-timeline") == "true") {
+ // Set it to false so that the click function can revert it
+ window.localStorage.setItem("expand-job-timeline", "false");
+ $("span.expand-job-timeline").trigger('click');
+ }
+});
+
function drawTaskAssignmentTimeline(groupArray, eventObjArray, minLaunchTime, maxFinishTime) {
var groups = new vis.DataSet(groupArray);
var items = new vis.DataSet(eventObjArray);
@@ -176,14 +202,27 @@ function drawTaskAssignmentTimeline(groupArray, eventObjArray, minLaunchTime, ma
setupZoomable("#task-assignment-timeline-zoom-lock", taskTimeline);
$("span.expand-task-assignment-timeline").click(function() {
+ var status = window.localStorage.getItem("expand-task-assignment-timeline") == "true";
+ status = !status;
+
$("#task-assignment-timeline").toggleClass("collapsed");
// Switch the class of the arrow from open to closed.
$(this).find(".expand-task-assignment-timeline-arrow").toggleClass("arrow-open");
$(this).find(".expand-task-assignment-timeline-arrow").toggleClass("arrow-closed");
+
+ window.localStorage.setItem("expand-task-assignment-timeline", "" + status);
});
}
+$(function (){
+ if (window.localStorage.getItem("expand-task-assignment-timeline") == "true") {
+ // Set it to false so that the click function can revert it
+ window.localStorage.setItem("expand-task-assignment-timeline", "false");
+ $("span.expand-task-assignment-timeline").trigger('click');
+ }
+});
+
function setupExecutorEventAction() {
$(".item.box.executor").each(function () {
$(this).hover(
diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala
index 5a8d17bd99933..2f4fcac890eef 100644
--- a/core/src/main/scala/org/apache/spark/Accumulators.scala
+++ b/core/src/main/scala/org/apache/spark/Accumulators.scala
@@ -20,7 +20,8 @@ package org.apache.spark
import java.io.{ObjectInputStream, Serializable}
import scala.collection.generic.Growable
-import scala.collection.mutable.Map
+import scala.collection.Map
+import scala.collection.mutable
import scala.ref.WeakReference
import scala.reflect.ClassTag
@@ -39,25 +40,44 @@ import org.apache.spark.util.Utils
* @param initialValue initial value of accumulator
* @param param helper object defining how to add elements of type `R` and `T`
* @param name human-readable name for use in Spark's web UI
+ * @param internal if this [[Accumulable]] is internal. Internal [[Accumulable]]s will be reported
+ * to the driver via heartbeats. For internal [[Accumulable]]s, `R` must be
+ * thread safe so that they can be reported correctly.
* @tparam R the full accumulated data (result type)
* @tparam T partial data that can be added in
*/
-class Accumulable[R, T] (
+class Accumulable[R, T] private[spark] (
@transient initialValue: R,
param: AccumulableParam[R, T],
- val name: Option[String])
+ val name: Option[String],
+ internal: Boolean)
extends Serializable {
+ private[spark] def this(
+ @transient initialValue: R, param: AccumulableParam[R, T], internal: Boolean) = {
+ this(initialValue, param, None, internal)
+ }
+
+ def this(@transient initialValue: R, param: AccumulableParam[R, T], name: Option[String]) =
+ this(initialValue, param, name, false)
+
def this(@transient initialValue: R, param: AccumulableParam[R, T]) =
this(initialValue, param, None)
val id: Long = Accumulators.newId
- @transient private var value_ = initialValue // Current value on master
+ @volatile @transient private var value_ : R = initialValue // Current value on master
val zero = param.zero(initialValue) // Zero value to be passed to workers
private var deserialized = false
- Accumulators.register(this, true)
+ Accumulators.register(this)
+
+ /**
+ * If this [[Accumulable]] is internal. Internal [[Accumulable]]s will be reported to the driver
+ * via heartbeats. For internal [[Accumulable]]s, `R` must be thread safe so that they can be
+ * reported correctly.
+ */
+ private[spark] def isInternal: Boolean = internal
/**
* Add more data to this accumulator / accumulable
@@ -132,7 +152,8 @@ class Accumulable[R, T] (
in.defaultReadObject()
value_ = zero
deserialized = true
- Accumulators.register(this, false)
+ val taskContext = TaskContext.get()
+ taskContext.registerAccumulator(this)
}
override def toString: String = if (value_ == null) "null" else value_.toString
@@ -284,16 +305,7 @@ private[spark] object Accumulators extends Logging {
* It keeps weak references to these objects so that accumulators can be garbage-collected
* once the RDDs and user-code that reference them are cleaned up.
*/
- val originals = Map[Long, WeakReference[Accumulable[_, _]]]()
-
- /**
- * This thread-local map holds per-task copies of accumulators; it is used to collect the set
- * of accumulator updates to send back to the driver when tasks complete. After tasks complete,
- * this map is cleared by `Accumulators.clear()` (see Executor.scala).
- */
- private val localAccums = new ThreadLocal[Map[Long, Accumulable[_, _]]]() {
- override protected def initialValue() = Map[Long, Accumulable[_, _]]()
- }
+ val originals = mutable.Map[Long, WeakReference[Accumulable[_, _]]]()
private var lastId: Long = 0
@@ -302,19 +314,8 @@ private[spark] object Accumulators extends Logging {
lastId
}
- def register(a: Accumulable[_, _], original: Boolean): Unit = synchronized {
- if (original) {
- originals(a.id) = new WeakReference[Accumulable[_, _]](a)
- } else {
- localAccums.get()(a.id) = a
- }
- }
-
- // Clear the local (non-original) accumulators for the current thread
- def clear() {
- synchronized {
- localAccums.get.clear()
- }
+ def register(a: Accumulable[_, _]): Unit = synchronized {
+ originals(a.id) = new WeakReference[Accumulable[_, _]](a)
}
def remove(accId: Long) {
@@ -323,15 +324,6 @@ private[spark] object Accumulators extends Logging {
}
}
- // Get the values of the local accumulators for the current thread (by ID)
- def values: Map[Long, Any] = synchronized {
- val ret = Map[Long, Any]()
- for ((id, accum) <- localAccums.get) {
- ret(id) = accum.localValue
- }
- return ret
- }
-
// Add values to the original accumulators with some given IDs
def add(values: Map[Long, Any]): Unit = synchronized {
for ((id, value) <- values) {
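
After this change the driver keeps only weak references to original accumulators, and per-task copies register with their TaskContext instead of a thread-local map. A minimal standalone Scala sketch of the weak-reference registry pattern (not Spark's Accumulators object itself):

import scala.collection.mutable
import scala.ref.WeakReference

// Entries do not keep their referents alive, so registered objects can still be
// garbage-collected once user code drops them -- the same idea as `originals` above.
class WeakRegistry[T <: AnyRef] {
  private val entries = mutable.Map[Long, WeakReference[T]]()
  private var lastId = 0L

  def register(value: T): Long = synchronized {
    lastId += 1
    entries(lastId) = new WeakReference(value)
    lastId
  }

  // None for an unknown id, or once the referent has been collected.
  def lookup(id: Long): Option[T] = synchronized {
    entries.get(id).flatMap(_.get)
  }

  def remove(id: Long): Unit = synchronized { entries -= id }
}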
diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
index 0c50b4002cf7b..648bcfe28cad2 100644
--- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
+++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
@@ -20,6 +20,7 @@ package org.apache.spark
import java.util.concurrent.TimeUnit
import scala.collection.mutable
+import scala.util.control.ControlThrowable
import com.codahale.metrics.{Gauge, MetricRegistry}
@@ -211,7 +212,16 @@ private[spark] class ExecutorAllocationManager(
listenerBus.addListener(listener)
val scheduleTask = new Runnable() {
- override def run(): Unit = Utils.logUncaughtExceptions(schedule())
+ override def run(): Unit = {
+ try {
+ schedule()
+ } catch {
+ case ct: ControlThrowable =>
+ throw ct
+ case t: Throwable =>
+ logWarning(s"Uncaught exception in thread ${Thread.currentThread().getName}", t)
+ }
+ }
}
executor.scheduleAtFixedRate(scheduleTask, 0, intervalMillis, TimeUnit.MILLISECONDS)
}
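
scheduleAtFixedRate silently cancels the periodic task if its Runnable ever throws, so the wrapper above logs ordinary failures and only rethrows ControlThrowable (which Scala uses for non-local control flow). A standalone Scala sketch of the same pattern, with an illustrative doWork() standing in for schedule():

import java.util.concurrent.{Executors, TimeUnit}

import scala.util.control.ControlThrowable

object ResilientSchedulingSketch {
  private def doWork(): Unit = {
    if (math.random < 0.5) throw new RuntimeException("transient failure")
  }

  def main(args: Array[String]): Unit = {
    val executor = Executors.newSingleThreadScheduledExecutor()
    val task = new Runnable {
      override def run(): Unit = {
        try {
          doWork()  // stand-in for schedule()
        } catch {
          case ct: ControlThrowable => throw ct  // never swallow flow-control throwables
          case t: Throwable =>
            // Logging instead of rethrowing keeps the fixed-rate task alive for the next tick.
            println(s"Uncaught exception in thread ${Thread.currentThread().getName}: $t")
        }
      }
    }
    executor.scheduleAtFixedRate(task, 0, 100, TimeUnit.MILLISECONDS)
    Thread.sleep(500)
    executor.shutdown()
  }
}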
diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
index 221b1dab43278..43dd4a170731d 100644
--- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
+++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
@@ -181,7 +181,9 @@ private[spark] class HeartbeatReceiver(sc: SparkContext, clock: Clock)
// Asynchronously kill the executor to avoid blocking the current thread
killExecutorThread.submit(new Runnable {
override def run(): Unit = Utils.tryLogNonFatalError {
- sc.killExecutor(executorId)
+ // Note: we want to get an executor back after expiring this one,
+ // so do not simply call `sc.killExecutor` here (SPARK-8119)
+ sc.killAndReplaceExecutor(executorId)
}
})
}
diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala
index 87ab099267b2f..f0598816d6c07 100644
--- a/core/src/main/scala/org/apache/spark/Logging.scala
+++ b/core/src/main/scala/org/apache/spark/Logging.scala
@@ -159,7 +159,7 @@ private object Logging {
try {
// We use reflection here to handle the case where users remove the
// slf4j-to-jul bridge order to route their logs to JUL.
- val bridgeClass = Class.forName("org.slf4j.bridge.SLF4JBridgeHandler")
+ val bridgeClass = Utils.classForName("org.slf4j.bridge.SLF4JBridgeHandler")
bridgeClass.getMethod("removeHandlersForRootLogger").invoke(null)
val installed = bridgeClass.getMethod("isInstalled").invoke(null).asInstanceOf[Boolean]
if (!installed) {
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 862ffe868f58f..92218832d256f 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -21,14 +21,14 @@ import java.io._
import java.util.concurrent.ConcurrentHashMap
import java.util.zip.{GZIPInputStream, GZIPOutputStream}
-import scala.collection.mutable.{HashMap, HashSet, Map}
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
import scala.collection.JavaConversions._
import scala.reflect.ClassTag
import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, RpcCallContext, RpcEndpoint}
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.shuffle.MetadataFetchFailedException
-import org.apache.spark.storage.BlockManagerId
+import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId}
import org.apache.spark.util._
private[spark] sealed trait MapOutputTrackerMessage
@@ -124,10 +124,18 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
}
/**
- * Called from executors to get the server URIs and output sizes of the map outputs of
- * a given shuffle.
+ * Called from executors to get the server URIs and output sizes for each shuffle block that
+   * needs to be read by a given reduce task.
+ *
+ * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId,
+ * and the second item is a sequence of (shuffle block id, shuffle block size) tuples
+ * describing the shuffle blocks that are stored at that block manager.
*/
- def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = {
+ def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int)
+ : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
+ logDebug(s"Fetching outputs for shuffle $shuffleId, reduce $reduceId")
+ val startTime = System.currentTimeMillis
+
val statuses = mapStatuses.get(shuffleId).orNull
if (statuses == null) {
logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
@@ -167,6 +175,9 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
}
}
}
+ logDebug(s"Fetching map output location for shuffle $shuffleId, reduce $reduceId took " +
+ s"${System.currentTimeMillis - startTime} ms")
+
if (fetchedStatuses != null) {
fetchedStatuses.synchronized {
return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses)
@@ -421,23 +432,38 @@ private[spark] object MapOutputTracker extends Logging {
}
}
- // Convert an array of MapStatuses to locations and sizes for a given reduce ID. If
- // any of the statuses is null (indicating a missing location due to a failed mapper),
- // throw a FetchFailedException.
+ /**
+ * Converts an array of MapStatuses for a given reduce ID to a sequence that, for each block
+ * manager ID, lists the shuffle block ids and corresponding shuffle block sizes stored at that
+ * block manager.
+ *
+ * If any of the statuses is null (indicating a missing location due to a failed mapper),
+ * throws a FetchFailedException.
+ *
+ * @param shuffleId Identifier for the shuffle
+ * @param reduceId Identifier for the reduce task
+ * @param statuses List of map statuses, indexed by map ID.
+ * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId,
+ * and the second item is a sequence of (shuffle block id, shuffle block size) tuples
+ * describing the shuffle blocks that are stored at that block manager.
+ */
private def convertMapStatuses(
shuffleId: Int,
reduceId: Int,
- statuses: Array[MapStatus]): Array[(BlockManagerId, Long)] = {
+ statuses: Array[MapStatus]): Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
assert (statuses != null)
- statuses.map {
- status =>
- if (status == null) {
- logError("Missing an output location for shuffle " + shuffleId)
- throw new MetadataFetchFailedException(
- shuffleId, reduceId, "Missing an output location for shuffle " + shuffleId)
- } else {
- (status.location, status.getSizeForBlock(reduceId))
- }
+ val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(BlockId, Long)]]
+ for ((status, mapId) <- statuses.zipWithIndex) {
+ if (status == null) {
+ val errorMessage = s"Missing an output location for shuffle $shuffleId"
+ logError(errorMessage)
+ throw new MetadataFetchFailedException(shuffleId, reduceId, errorMessage)
+ } else {
+ splitsByAddress.getOrElseUpdate(status.location, ArrayBuffer()) +=
+ ((ShuffleBlockId(shuffleId, mapId, reduceId), status.getSizeForBlock(reduceId)))
+ }
}
+
+ splitsByAddress.toSeq
}
}
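
Editor's note: the conversion above changes the reduce-side view from a flat Array[(BlockManagerId, Long)] to one entry per block manager, each carrying the shuffle block ids and sizes stored there. The grouping itself is a plain getOrElseUpdate into an ArrayBuffer; a self-contained sketch with simplified types (strings standing in for BlockManagerId and BlockId):

import scala.collection.mutable.{ArrayBuffer, HashMap}

object GroupByLocationSketch {
  def main(args: Array[String]): Unit = {
    // (location, blockId, size) triples, e.g. one per map output read by a reduce task.
    val statuses = Seq(
      ("exec-1", "shuffle_0_0_2", 128L),
      ("exec-2", "shuffle_0_1_2", 256L),
      ("exec-1", "shuffle_0_2_2", 64L))

    val splitsByAddress = new HashMap[String, ArrayBuffer[(String, Long)]]
    for ((location, blockId, size) <- statuses) {
      splitsByAddress.getOrElseUpdate(location, ArrayBuffer()) += ((blockId, size))
    }

    // Seq[(location, Seq[(blockId, size)])], ready to be fetched per executor.
    splitsByAddress.toSeq.foreach(println)
  }
}
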
diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala
index 82889bcd30988..ad68512dccb79 100644
--- a/core/src/main/scala/org/apache/spark/Partitioner.scala
+++ b/core/src/main/scala/org/apache/spark/Partitioner.scala
@@ -76,6 +76,8 @@ object Partitioner {
* produce an unexpected or incorrect result.
*/
class HashPartitioner(partitions: Int) extends Partitioner {
+ require(partitions >= 0, s"Number of partitions ($partitions) cannot be negative.")
+
def numPartitions: Int = partitions
def getPartition(key: Any): Int = key match {
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 82704b1ab2189..6a6b94a271cfc 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -471,7 +471,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
.orElse(Option(System.getenv("SPARK_MEM"))
.map(warnSparkMem))
.map(Utils.memoryStringToMb)
- .getOrElse(512)
+ .getOrElse(1024)
// Convert java options to env vars as a work around
// since we can't set env vars directly in sbt.
@@ -1419,6 +1419,12 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
/**
* :: DeveloperApi ::
* Request that the cluster manager kill the specified executors.
+ *
+ * Note: This is an indication to the cluster manager that the application wishes to adjust
+ * its resource usage downwards. If the application wishes to replace the executors it kills
+ * through this method with new ones, it should follow up explicitly with a call to
+ * {{SparkContext#requestExecutors}}.
+ *
* This is currently only supported in YARN mode. Return whether the request is received.
*/
@DeveloperApi
@@ -1436,12 +1442,42 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
/**
* :: DeveloperApi ::
- * Request that cluster manager the kill the specified executor.
- * This is currently only supported in Yarn mode. Return whether the request is received.
+ * Request that the cluster manager kill the specified executor.
+ *
+ * Note: This is an indication to the cluster manager that the application wishes to adjust
+ * its resource usage downwards. If the application wishes to replace the executor it kills
+ * through this method with a new one, it should follow up explicitly with a call to
+ * {{SparkContext#requestExecutors}}.
+ *
+ * This is currently only supported in YARN mode. Return whether the request is received.
*/
@DeveloperApi
override def killExecutor(executorId: String): Boolean = super.killExecutor(executorId)
+ /**
+ * Request that the cluster manager kill the specified executor without adjusting the
+ * application resource requirements.
+ *
+ * The effect is that a new executor will be launched in place of the one killed by
+ * this request. This assumes the cluster manager will automatically and eventually
+ * fulfill all missing application resource requests.
+ *
+   * Note: The replacement is by no means guaranteed; another application on the same cluster
+   * can steal the window of opportunity and acquire this application's resources in the
+   * meantime.
+ *
+ * This is currently only supported in YARN mode. Return whether the request is received.
+ */
+ private[spark] def killAndReplaceExecutor(executorId: String): Boolean = {
+ schedulerBackend match {
+ case b: CoarseGrainedSchedulerBackend =>
+ b.killExecutors(Seq(executorId), replace = true)
+ case _ =>
+ logWarning("Killing executors is only supported in coarse-grained mode")
+ false
+ }
+ }
+
/** The version of Spark on which this application is running. */
def version: String = SPARK_VERSION
@@ -1722,16 +1758,13 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
/**
* Run a function on a given set of partitions in an RDD and pass the results to the given
- * handler function. This is the main entry point for all actions in Spark. The allowLocal
- * flag specifies whether the scheduler can run the computation on the driver rather than
- * shipping it out to the cluster, for short actions like first().
+ * handler function. This is the main entry point for all actions in Spark.
*/
def runJob[T, U: ClassTag](
rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int],
- allowLocal: Boolean,
- resultHandler: (Int, U) => Unit) {
+ resultHandler: (Int, U) => Unit): Unit = {
if (stopped.get()) {
throw new IllegalStateException("SparkContext has been shutdown")
}
@@ -1741,54 +1774,104 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
if (conf.getBoolean("spark.logLineage", false)) {
logInfo("RDD's recursive dependencies:\n" + rdd.toDebugString)
}
- dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal,
- resultHandler, localProperties.get)
+ dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, resultHandler, localProperties.get)
progressBar.foreach(_.finishAll())
rdd.doCheckpoint()
}
/**
- * Run a function on a given set of partitions in an RDD and return the results as an array. The
- * allowLocal flag specifies whether the scheduler can run the computation on the driver rather
- * than shipping it out to the cluster, for short actions like first().
+ * Run a function on a given set of partitions in an RDD and return the results as an array.
+ */
+ def runJob[T, U: ClassTag](
+ rdd: RDD[T],
+ func: (TaskContext, Iterator[T]) => U,
+ partitions: Seq[Int]): Array[U] = {
+ val results = new Array[U](partitions.size)
+ runJob[T, U](rdd, func, partitions, (index, res) => results(index) = res)
+ results
+ }
+
+ /**
+ * Run a job on a given set of partitions of an RDD, but take a function of type
+ * `Iterator[T] => U` instead of `(TaskContext, Iterator[T]) => U`.
+ */
+ def runJob[T, U: ClassTag](
+ rdd: RDD[T],
+ func: Iterator[T] => U,
+ partitions: Seq[Int]): Array[U] = {
+ val cleanedFunc = clean(func)
+ runJob(rdd, (ctx: TaskContext, it: Iterator[T]) => cleanedFunc(it), partitions)
+ }
+
+
+ /**
+ * Run a function on a given set of partitions in an RDD and pass the results to the given
+ * handler function. This is the main entry point for all actions in Spark.
+ *
+ * The allowLocal flag is deprecated as of Spark 1.5.0+.
+ */
+ @deprecated("use the version of runJob without the allowLocal parameter", "1.5.0")
+ def runJob[T, U: ClassTag](
+ rdd: RDD[T],
+ func: (TaskContext, Iterator[T]) => U,
+ partitions: Seq[Int],
+ allowLocal: Boolean,
+ resultHandler: (Int, U) => Unit): Unit = {
+ if (allowLocal) {
+ logWarning("sc.runJob with allowLocal=true is deprecated in Spark 1.5.0+")
+ }
+ runJob(rdd, func, partitions, resultHandler)
+ }
+
+ /**
+ * Run a function on a given set of partitions in an RDD and return the results as an array.
+ *
+ * The allowLocal flag is deprecated as of Spark 1.5.0+.
*/
+ @deprecated("use the version of runJob without the allowLocal parameter", "1.5.0")
def runJob[T, U: ClassTag](
rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int],
allowLocal: Boolean
): Array[U] = {
- val results = new Array[U](partitions.size)
- runJob[T, U](rdd, func, partitions, allowLocal, (index, res) => results(index) = res)
- results
+ if (allowLocal) {
+ logWarning("sc.runJob with allowLocal=true is deprecated in Spark 1.5.0+")
+ }
+ runJob(rdd, func, partitions)
}
/**
* Run a job on a given set of partitions of an RDD, but take a function of type
* `Iterator[T] => U` instead of `(TaskContext, Iterator[T]) => U`.
+ *
+ * The allowLocal argument is deprecated as of Spark 1.5.0+.
*/
+ @deprecated("use the version of runJob without the allowLocal parameter", "1.5.0")
def runJob[T, U: ClassTag](
rdd: RDD[T],
func: Iterator[T] => U,
partitions: Seq[Int],
allowLocal: Boolean
): Array[U] = {
- val cleanedFunc = clean(func)
- runJob(rdd, (ctx: TaskContext, it: Iterator[T]) => cleanedFunc(it), partitions, allowLocal)
+ if (allowLocal) {
+ logWarning("sc.runJob with allowLocal=true is deprecated in Spark 1.5.0+")
+ }
+ runJob(rdd, func, partitions)
}
/**
* Run a job on all partitions in an RDD and return the results in an array.
*/
def runJob[T, U: ClassTag](rdd: RDD[T], func: (TaskContext, Iterator[T]) => U): Array[U] = {
- runJob(rdd, func, 0 until rdd.partitions.size, false)
+ runJob(rdd, func, 0 until rdd.partitions.length)
}
/**
* Run a job on all partitions in an RDD and return the results in an array.
*/
def runJob[T, U: ClassTag](rdd: RDD[T], func: Iterator[T] => U): Array[U] = {
- runJob(rdd, func, 0 until rdd.partitions.size, false)
+ runJob(rdd, func, 0 until rdd.partitions.length)
}
/**
@@ -1799,7 +1882,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
processPartition: (TaskContext, Iterator[T]) => U,
resultHandler: (Int, U) => Unit)
{
- runJob[T, U](rdd, processPartition, 0 until rdd.partitions.size, false, resultHandler)
+ runJob[T, U](rdd, processPartition, 0 until rdd.partitions.length, resultHandler)
}
/**
@@ -1811,7 +1894,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
resultHandler: (Int, U) => Unit)
{
val processFunc = (context: TaskContext, iter: Iterator[T]) => processPartition(iter)
- runJob[T, U](rdd, processFunc, 0 until rdd.partitions.size, false, resultHandler)
+ runJob[T, U](rdd, processFunc, 0 until rdd.partitions.length, resultHandler)
}
/**
@@ -1856,7 +1939,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
(context: TaskContext, iter: Iterator[T]) => cleanF(iter),
partitions,
callSite,
- allowLocal = false,
resultHandler,
localProperties.get)
new SimpleFutureAction(waiter, resultFunc)
@@ -1968,7 +2050,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
for (className <- listenerClassNames) {
// Use reflection to find the right constructor
val constructors = {
- val listenerClass = Class.forName(className)
+ val listenerClass = Utils.classForName(className)
listenerClass.getConstructors.asInstanceOf[Array[Constructor[_ <: SparkListener]]]
}
val constructorTakingSparkConf = constructors.find { c =>
@@ -2503,7 +2585,7 @@ object SparkContext extends Logging {
"\"yarn-standalone\" is deprecated as of Spark 1.0. Use \"yarn-cluster\" instead.")
}
val scheduler = try {
- val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClusterScheduler")
+ val clazz = Utils.classForName("org.apache.spark.scheduler.cluster.YarnClusterScheduler")
val cons = clazz.getConstructor(classOf[SparkContext])
cons.newInstance(sc).asInstanceOf[TaskSchedulerImpl]
} catch {
@@ -2515,7 +2597,7 @@ object SparkContext extends Logging {
}
val backend = try {
val clazz =
- Class.forName("org.apache.spark.scheduler.cluster.YarnClusterSchedulerBackend")
+ Utils.classForName("org.apache.spark.scheduler.cluster.YarnClusterSchedulerBackend")
val cons = clazz.getConstructor(classOf[TaskSchedulerImpl], classOf[SparkContext])
cons.newInstance(scheduler, sc).asInstanceOf[CoarseGrainedSchedulerBackend]
} catch {
@@ -2528,8 +2610,7 @@ object SparkContext extends Logging {
case "yarn-client" =>
val scheduler = try {
- val clazz =
- Class.forName("org.apache.spark.scheduler.cluster.YarnScheduler")
+ val clazz = Utils.classForName("org.apache.spark.scheduler.cluster.YarnScheduler")
val cons = clazz.getConstructor(classOf[SparkContext])
cons.newInstance(sc).asInstanceOf[TaskSchedulerImpl]
@@ -2541,7 +2622,7 @@ object SparkContext extends Logging {
val backend = try {
val clazz =
- Class.forName("org.apache.spark.scheduler.cluster.YarnClientSchedulerBackend")
+ Utils.classForName("org.apache.spark.scheduler.cluster.YarnClientSchedulerBackend")
val cons = clazz.getConstructor(classOf[TaskSchedulerImpl], classOf[SparkContext])
cons.newInstance(scheduler, sc).asInstanceOf[CoarseGrainedSchedulerBackend]
} catch {
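
Editor's note: for callers, the main change in SparkContext above is that runJob no longer takes an allowLocal flag; the old overloads remain but are deprecated and simply forward. A hedged usage sketch, assuming the Spark version from this patch is on the classpath and using local mode:

import org.apache.spark.{SparkConf, SparkContext}

object RunJobSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("runJob-sketch").setMaster("local[2]"))
    val rdd = sc.parallelize(1 to 100, 4)

    // New form: no allowLocal flag.
    val sums: Array[Int] = sc.runJob(rdd, (it: Iterator[Int]) => it.sum, Seq(0, 1))

    // Deprecated form still compiles; allowLocal = true now only logs a warning.
    // val sumsOld = sc.runJob(rdd, (it: Iterator[Int]) => it.sum, Seq(0, 1), allowLocal = true)

    println(sums.mkString(", "))
    sc.stop()
  }
}
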
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index d18fc599e9890..adfece4d6e7c0 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -261,7 +261,7 @@ object SparkEnv extends Logging {
// Create an instance of the class with the given name, possibly initializing it with our conf
def instantiateClass[T](className: String): T = {
- val cls = Class.forName(className, true, Utils.getContextOrSparkClassLoader)
+ val cls = Utils.classForName(className)
// Look for a constructor taking a SparkConf and a boolean isDriver, then one taking just
// SparkConf, then one taking no arguments
try {
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index d09e17dea0911..b48836d5c8897 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -21,6 +21,7 @@ import java.io.Serializable
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.metrics.source.Source
import org.apache.spark.unsafe.memory.TaskMemoryManager
import org.apache.spark.util.TaskCompletionListener
@@ -32,7 +33,20 @@ object TaskContext {
*/
def get(): TaskContext = taskContext.get
- private val taskContext: ThreadLocal[TaskContext] = new ThreadLocal[TaskContext]
+ /**
+   * Returns the partition id of the currently active TaskContext. It will return 0
+   * if there is no active TaskContext, for example during local execution.
+ */
+ def getPartitionId(): Int = {
+ val tc = taskContext.get()
+ if (tc eq null) {
+ 0
+ } else {
+ tc.partitionId()
+ }
+ }
+
+ private[this] val taskContext: ThreadLocal[TaskContext] = new ThreadLocal[TaskContext]
// Note: protected[spark] instead of private[spark] to prevent the following two from
// showing up in JavaDoc.
@@ -135,8 +149,34 @@ abstract class TaskContext extends Serializable {
@DeveloperApi
def taskMetrics(): TaskMetrics
+ /**
+ * ::DeveloperApi::
+ * Returns all metrics sources with the given name which are associated with the instance
+ * which runs the task. For more information see [[org.apache.spark.metrics.MetricsSystem!]].
+ */
+ @DeveloperApi
+ def getMetricsSources(sourceName: String): Seq[Source]
+
/**
* Returns the manager for this task's managed memory.
*/
private[spark] def taskMemoryManager(): TaskMemoryManager
+
+ /**
+   * Register an accumulator that belongs to this task. Accumulators must call this method when
+   * they are deserialized on executors.
+ */
+ private[spark] def registerAccumulator(a: Accumulable[_, _]): Unit
+
+ /**
+ * Return the local values of internal accumulators that belong to this task. The key of the Map
+ * is the accumulator id and the value of the Map is the latest accumulator local value.
+ */
+ private[spark] def collectInternalAccumulators(): Map[Long, Any]
+
+ /**
+ * Return the local values of accumulators that belong to this task. The key of the Map is the
+ * accumulator id and the value of the Map is the latest accumulator local value.
+ */
+ private[spark] def collectAccumulators(): Map[Long, Any]
}
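
Editor's note: getPartitionId above is just a null-safe read of a ThreadLocal that defaults to 0 when no TaskContext is active (e.g. local execution in tests). The same pattern in isolation, with an illustrative context type instead of Spark's TaskContext:

object ThreadLocalContextSketch {
  final case class Ctx(partitionId: Int)

  private val current = new ThreadLocal[Ctx]

  def set(ctx: Ctx): Unit = current.set(ctx)
  def unset(): Unit = current.remove()

  // Mirrors TaskContext.getPartitionId(): fall back to 0 when nothing is set.
  def getPartitionId(): Int = {
    val ctx = current.get()
    if (ctx eq null) 0 else ctx.partitionId
  }

  def main(args: Array[String]): Unit = {
    println(getPartitionId())   // 0, no active context
    set(Ctx(partitionId = 3))
    println(getPartitionId())   // 3
    unset()
  }
}
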
diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
index b4d572cb52313..9ee168ae016f8 100644
--- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -17,18 +17,21 @@
package org.apache.spark
+import scala.collection.mutable.{ArrayBuffer, HashMap}
+
import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.metrics.MetricsSystem
+import org.apache.spark.metrics.source.Source
import org.apache.spark.unsafe.memory.TaskMemoryManager
import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException}
-import scala.collection.mutable.ArrayBuffer
-
private[spark] class TaskContextImpl(
val stageId: Int,
val partitionId: Int,
override val taskAttemptId: Long,
override val attemptNumber: Int,
override val taskMemoryManager: TaskMemoryManager,
+ @transient private val metricsSystem: MetricsSystem,
val runningLocally: Boolean = false,
val taskMetrics: TaskMetrics = TaskMetrics.empty)
extends TaskContext
@@ -94,5 +97,21 @@ private[spark] class TaskContextImpl(
override def isRunningLocally(): Boolean = runningLocally
override def isInterrupted(): Boolean = interrupted
-}
+ override def getMetricsSources(sourceName: String): Seq[Source] =
+ metricsSystem.getSourcesByName(sourceName)
+
+ @transient private val accumulators = new HashMap[Long, Accumulable[_, _]]
+
+ private[spark] override def registerAccumulator(a: Accumulable[_, _]): Unit = synchronized {
+ accumulators(a.id) = a
+ }
+
+ private[spark] override def collectInternalAccumulators(): Map[Long, Any] = synchronized {
+ accumulators.filter(_._2.isInternal).mapValues(_.localValue).toMap
+ }
+
+ private[spark] override def collectAccumulators(): Map[Long, Any] = synchronized {
+ accumulators.mapValues(_.localValue).toMap
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
index c95615a5a9307..829fae1d1d9bf 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
@@ -364,7 +364,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
// This is useful for implementing `take` from other language frontends
// like Python where the data is serialized.
import scala.collection.JavaConversions._
- val res = context.runJob(rdd, (it: Iterator[T]) => it.toArray, partitionIds, true)
+ val res = context.runJob(rdd, (it: Iterator[T]) => it.toArray, partitionIds)
res.map(x => new java.util.ArrayList(x.toSeq)).toArray
}
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index dc9f62f39e6d5..598953ac3bcc8 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -358,12 +358,11 @@ private[spark] object PythonRDD extends Logging {
def runJob(
sc: SparkContext,
rdd: JavaRDD[Array[Byte]],
- partitions: JArrayList[Int],
- allowLocal: Boolean): Int = {
+ partitions: JArrayList[Int]): Int = {
type ByteArray = Array[Byte]
type UnrolledPartition = Array[ByteArray]
val allPartitions: Array[UnrolledPartition] =
- sc.runJob(rdd, (x: Iterator[ByteArray]) => x.toArray, partitions, allowLocal)
+ sc.runJob(rdd, (x: Iterator[ByteArray]) => x.toArray, partitions)
val flattenedPartition: UnrolledPartition = Array.concat(allPartitions: _*)
serveIterator(flattenedPartition.iterator,
s"serve RDD ${rdd.id} with partitions ${partitions.mkString(",")}")
diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
index 4b8f7fe9242e0..a5de10fe89c42 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
@@ -20,12 +20,14 @@ package org.apache.spark.api.r
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import scala.collection.mutable.HashMap
+import scala.language.existentials
import io.netty.channel.ChannelHandler.Sharable
import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler}
import org.apache.spark.Logging
import org.apache.spark.api.r.SerDe._
+import org.apache.spark.util.Utils
/**
* Handler for RBackend
@@ -88,21 +90,6 @@ private[r] class RBackendHandler(server: RBackend)
ctx.close()
}
- // Looks up a class given a class name. This function first checks the
- // current class loader and if a class is not found, it looks up the class
- // in the context class loader. Address [SPARK-5185]
- def getStaticClass(objId: String): Class[_] = {
- try {
- val clsCurrent = Class.forName(objId)
- clsCurrent
- } catch {
- // Use contextLoader if we can't find the JAR in the system class loader
- case e: ClassNotFoundException =>
- val clsContext = Class.forName(objId, true, Thread.currentThread().getContextClassLoader)
- clsContext
- }
- }
-
def handleMethodCall(
isStatic: Boolean,
objId: String,
@@ -113,7 +100,7 @@ private[r] class RBackendHandler(server: RBackend)
var obj: Object = null
try {
val cls = if (isStatic) {
- getStaticClass(objId)
+ Utils.classForName(objId)
} else {
JVMObjectTracker.get(objId) match {
case None => throw new IllegalArgumentException("Object not found " + objId)
diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
index ff1702f7dea48..23a470d6afcae 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
@@ -39,7 +39,6 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
deserializer: String,
serializer: String,
packageNames: Array[Byte],
- rLibDir: String,
broadcastVars: Array[Broadcast[Object]])
extends RDD[U](parent) with Logging {
protected var dataStream: DataInputStream = _
@@ -60,7 +59,7 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
// The stdout/stderr is shared by multiple tasks, because we use one daemon
// to launch child process as worker.
- val errThread = RRDD.createRWorker(rLibDir, listenPort)
+ val errThread = RRDD.createRWorker(listenPort)
// We use two sockets to separate input and output, then it's easy to manage
// the lifecycle of them to avoid deadlock.
@@ -235,11 +234,10 @@ private class PairwiseRRDD[T: ClassTag](
hashFunc: Array[Byte],
deserializer: String,
packageNames: Array[Byte],
- rLibDir: String,
broadcastVars: Array[Object])
extends BaseRRDD[T, (Int, Array[Byte])](
parent, numPartitions, hashFunc, deserializer,
- SerializationFormats.BYTE, packageNames, rLibDir,
+ SerializationFormats.BYTE, packageNames,
broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) {
override protected def readData(length: Int): (Int, Array[Byte]) = {
@@ -266,10 +264,9 @@ private class RRDD[T: ClassTag](
deserializer: String,
serializer: String,
packageNames: Array[Byte],
- rLibDir: String,
broadcastVars: Array[Object])
extends BaseRRDD[T, Array[Byte]](
- parent, -1, func, deserializer, serializer, packageNames, rLibDir,
+ parent, -1, func, deserializer, serializer, packageNames,
broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) {
override protected def readData(length: Int): Array[Byte] = {
@@ -293,10 +290,9 @@ private class StringRRDD[T: ClassTag](
func: Array[Byte],
deserializer: String,
packageNames: Array[Byte],
- rLibDir: String,
broadcastVars: Array[Object])
extends BaseRRDD[T, String](
- parent, -1, func, deserializer, SerializationFormats.STRING, packageNames, rLibDir,
+ parent, -1, func, deserializer, SerializationFormats.STRING, packageNames,
broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) {
override protected def readData(length: Int): String = {
@@ -392,9 +388,10 @@ private[r] object RRDD {
thread
}
- private def createRProcess(rLibDir: String, port: Int, script: String): BufferedStreamThread = {
+ private def createRProcess(port: Int, script: String): BufferedStreamThread = {
val rCommand = SparkEnv.get.conf.get("spark.sparkr.r.command", "Rscript")
val rOptions = "--vanilla"
+ val rLibDir = RUtils.sparkRPackagePath(isDriver = false)
val rExecScript = rLibDir + "/SparkR/worker/" + script
val pb = new ProcessBuilder(List(rCommand, rOptions, rExecScript))
// Unset the R_TESTS environment variable for workers.
@@ -413,7 +410,7 @@ private[r] object RRDD {
/**
* ProcessBuilder used to launch worker R processes.
*/
- def createRWorker(rLibDir: String, port: Int): BufferedStreamThread = {
+ def createRWorker(port: Int): BufferedStreamThread = {
val useDaemon = SparkEnv.get.conf.getBoolean("spark.sparkr.use.daemon", true)
if (!Utils.isWindows && useDaemon) {
synchronized {
@@ -421,7 +418,7 @@ private[r] object RRDD {
// we expect one connection
val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
val daemonPort = serverSocket.getLocalPort
- errThread = createRProcess(rLibDir, daemonPort, "daemon.R")
+ errThread = createRProcess(daemonPort, "daemon.R")
// the socket used to send out the input of task
serverSocket.setSoTimeout(10000)
val sock = serverSocket.accept()
@@ -443,7 +440,7 @@ private[r] object RRDD {
errThread
}
} else {
- createRProcess(rLibDir, port, "worker.R")
+ createRProcess(port, "worker.R")
}
}
diff --git a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala
new file mode 100644
index 0000000000000..d53abd3408c55
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.r
+
+import java.io.File
+
+import org.apache.spark.{SparkEnv, SparkException}
+
+private[spark] object RUtils {
+ /**
+ * Get the SparkR package path in the local spark distribution.
+ */
+ def localSparkRPackagePath: Option[String] = {
+ val sparkHome = sys.env.get("SPARK_HOME")
+ sparkHome.map(
+ Seq(_, "R", "lib").mkString(File.separator)
+ )
+ }
+
+ /**
+ * Get the SparkR package path in various deployment modes.
+ * This assumes that Spark properties `spark.master` and `spark.submit.deployMode`
+ * and environment variable `SPARK_HOME` are set.
+ */
+ def sparkRPackagePath(isDriver: Boolean): String = {
+ val (master, deployMode) =
+ if (isDriver) {
+ (sys.props("spark.master"), sys.props("spark.submit.deployMode"))
+ } else {
+ val sparkConf = SparkEnv.get.conf
+ (sparkConf.get("spark.master"), sparkConf.get("spark.submit.deployMode"))
+ }
+
+ val isYarnCluster = master.contains("yarn") && deployMode == "cluster"
+ val isYarnClient = master.contains("yarn") && deployMode == "client"
+
+    // In YARN mode, the SparkR package is distributed as an archive that is symbolically
+    // linked as "sparkr" in the current working directory. Note that this does not apply
+ // to the driver in client mode because it is run outside of the cluster.
+ if (isYarnCluster || (isYarnClient && !isDriver)) {
+ new File("sparkr").getAbsolutePath
+ } else {
+ // Otherwise, assume the package is local
+ // TODO: support this for Mesos
+ localSparkRPackagePath.getOrElse {
+ throw new SparkException("SPARK_HOME not set. Can't locate SparkR package.")
+ }
+ }
+ }
+}
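
Editor's note: RUtils centralizes where the SparkR package lives: on YARN executors (and the driver in cluster mode) it is the "sparkr" archive unpacked into the container's working directory, otherwise it is $SPARK_HOME/R/lib. A simplified standalone sketch of that decision, with master and deployMode passed in explicitly instead of being read from SparkConf or system properties:

import java.io.File

object SparkRPathSketch {
  def sparkRPackagePath(master: String, deployMode: String, isDriver: Boolean): String = {
    val isYarnCluster = master.contains("yarn") && deployMode == "cluster"
    val isYarnClient = master.contains("yarn") && deployMode == "client"
    if (isYarnCluster || (isYarnClient && !isDriver)) {
      // Distributed archive, symlinked as "sparkr" in the container working directory.
      new File("sparkr").getAbsolutePath
    } else {
      sys.env.get("SPARK_HOME")
        .map(home => Seq(home, "R", "lib").mkString(File.separator))
        .getOrElse(throw new IllegalStateException("SPARK_HOME not set"))
    }
  }

  def main(args: Array[String]): Unit = {
    println(sparkRPackagePath("yarn-cluster", "cluster", isDriver = true))
  }
}
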
diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
index 56adc857d4ce0..d5b4260bf4529 100644
--- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
@@ -179,6 +179,7 @@ private[spark] object SerDe {
// Int -> integer
// String -> character
// Boolean -> logical
+ // Float -> double
// Double -> double
// Long -> double
// Array[Byte] -> raw
@@ -215,6 +216,9 @@ private[spark] object SerDe {
case "long" | "java.lang.Long" =>
writeType(dos, "double")
writeDouble(dos, value.asInstanceOf[Long].toDouble)
+ case "float" | "java.lang.Float" =>
+ writeType(dos, "double")
+ writeDouble(dos, value.asInstanceOf[Float].toDouble)
case "double" | "java.lang.Double" =>
writeType(dos, "double")
writeDouble(dos, value.asInstanceOf[Double])
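
Editor's note: the new Float case simply widens the JVM value to Double before writing, since R has no single-precision numeric type. A tiny illustrative snippet of the widening over a DataOutputStream (not the actual SerDe wire format, which also writes a type tag first):

import java.io.{ByteArrayOutputStream, DataOutputStream}

object FloatAsDoubleSketch {
  def main(args: Array[String]): Unit = {
    val bos = new ByteArrayOutputStream()
    val dos = new DataOutputStream(bos)
    val value: Float = 3.14f
    dos.writeDouble(value.toDouble)  // 8 bytes, readable on the R side as a double
    dos.flush()
    println(s"wrote ${bos.size()} bytes")
  }
}
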
diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala
index 685313ac009ba..fac6666bb3410 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala
@@ -22,6 +22,7 @@ import java.util.concurrent.atomic.AtomicLong
import scala.reflect.ClassTag
import org.apache.spark._
+import org.apache.spark.util.Utils
private[spark] class BroadcastManager(
val isDriver: Boolean,
@@ -42,7 +43,7 @@ private[spark] class BroadcastManager(
conf.get("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory")
broadcastFactory =
- Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]
+ Utils.classForName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]
// Initialize appropriate BroadcastFactory and BroadcastObject
broadcastFactory.initialize(isDriver, conf, securityManager)
diff --git a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala
index 4165740312e03..c0cab22fa8252 100644
--- a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala
@@ -24,7 +24,7 @@ import scala.collection.JavaConversions._
import org.apache.hadoop.fs.Path
-import org.apache.spark.api.r.RBackend
+import org.apache.spark.api.r.{RBackend, RUtils}
import org.apache.spark.util.RedirectThread
/**
@@ -71,9 +71,10 @@ object RRunner {
val builder = new ProcessBuilder(Seq(rCommand, rFileNormalized) ++ otherArgs)
val env = builder.environment()
env.put("EXISTING_SPARKR_BACKEND_PORT", sparkRBackendPort.toString)
- val sparkHome = System.getenv("SPARK_HOME")
+ val rPackageDir = RUtils.sparkRPackagePath(isDriver = true)
+ env.put("SPARKR_PACKAGE_DIR", rPackageDir)
env.put("R_PROFILE_USER",
- Seq(sparkHome, "R", "lib", "SparkR", "profile", "general.R").mkString(File.separator))
+ Seq(rPackageDir, "SparkR", "profile", "general.R").mkString(File.separator))
builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize
val process = builder.start()
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
index 6d14590a1d192..e06b06e06fb4a 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
@@ -25,6 +25,7 @@ import java.util.{Arrays, Comparator}
import scala.collection.JavaConversions._
import scala.concurrent.duration._
import scala.language.postfixOps
+import scala.util.control.NonFatal
import com.google.common.primitives.Longs
import org.apache.hadoop.conf.Configuration
@@ -178,7 +179,7 @@ class SparkHadoopUtil extends Logging {
private def getFileSystemThreadStatisticsMethod(methodName: String): Method = {
val statisticsDataClass =
- Class.forName("org.apache.hadoop.fs.FileSystem$Statistics$StatisticsData")
+ Utils.classForName("org.apache.hadoop.fs.FileSystem$Statistics$StatisticsData")
statisticsDataClass.getDeclaredMethod(methodName)
}
@@ -238,6 +239,14 @@ class SparkHadoopUtil extends Logging {
}.getOrElse(Seq.empty[Path])
}
+ def globPathIfNecessary(pattern: Path): Seq[Path] = {
+ if (pattern.toString.exists("{}[]*?\\".toSet.contains)) {
+ globPath(pattern)
+ } else {
+ Seq(pattern)
+ }
+ }
+
/**
* Lists all the files in a directory with the specified prefix, and does not end with the
* given suffix. The returned {{FileStatus}} instances are sorted by the modification times of
@@ -248,19 +257,25 @@ class SparkHadoopUtil extends Logging {
dir: Path,
prefix: String,
exclusionSuffix: String): Array[FileStatus] = {
- val fileStatuses = remoteFs.listStatus(dir,
- new PathFilter {
- override def accept(path: Path): Boolean = {
- val name = path.getName
- name.startsWith(prefix) && !name.endsWith(exclusionSuffix)
+ try {
+ val fileStatuses = remoteFs.listStatus(dir,
+ new PathFilter {
+ override def accept(path: Path): Boolean = {
+ val name = path.getName
+ name.startsWith(prefix) && !name.endsWith(exclusionSuffix)
+ }
+ })
+ Arrays.sort(fileStatuses, new Comparator[FileStatus] {
+ override def compare(o1: FileStatus, o2: FileStatus): Int = {
+ Longs.compare(o1.getModificationTime, o2.getModificationTime)
}
})
- Arrays.sort(fileStatuses, new Comparator[FileStatus] {
- override def compare(o1: FileStatus, o2: FileStatus): Int = {
- Longs.compare(o1.getModificationTime, o2.getModificationTime)
- }
- })
- fileStatuses
+ fileStatuses
+ } catch {
+ case NonFatal(e) =>
+ logWarning("Error while attempting to list files from application staging dir", e)
+ Array.empty
+ }
}
/**
@@ -356,7 +371,7 @@ object SparkHadoopUtil {
System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE")))
if (yarnMode) {
try {
- Class.forName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil")
+ Utils.classForName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil")
.newInstance()
.asInstanceOf[SparkHadoopUtil]
} catch {
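
Editor's note: globPathIfNecessary avoids a glob expansion round trip when the path plainly contains no glob metacharacters. The test is just a character-set membership check; a standalone sketch:

object GlobCheckSketch {
  private val globChars = "{}[]*?\\".toSet

  def looksLikeGlob(path: String): Boolean = path.exists(globChars.contains)

  def main(args: Array[String]): Unit = {
    println(looksLikeGlob("/data/2015-07-*/events"))  // true  -> expand via globStatus
    println(looksLikeGlob("/data/2015-07-01/events")) // false -> use the path as-is
  }
}
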
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
index 4cec9017b8adb..0b39ee8fe3ba0 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -37,6 +37,7 @@ import org.apache.ivy.core.settings.IvySettings
import org.apache.ivy.plugins.matcher.GlobPatternMatcher
import org.apache.ivy.plugins.repository.file.FileRepository
import org.apache.ivy.plugins.resolver.{FileSystemResolver, ChainResolver, IBiblioResolver}
+import org.apache.spark.api.r.RUtils
import org.apache.spark.SPARK_VERSION
import org.apache.spark.deploy.rest._
import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils}
@@ -79,6 +80,7 @@ object SparkSubmit {
private val SPARK_SHELL = "spark-shell"
private val PYSPARK_SHELL = "pyspark-shell"
private val SPARKR_SHELL = "sparkr-shell"
+ private val SPARKR_PACKAGE_ARCHIVE = "sparkr.zip"
private val CLASS_NOT_FOUND_EXIT_STATUS = 101
@@ -262,6 +264,12 @@ object SparkSubmit {
}
}
+ // Update args.deployMode if it is null. It will be passed down as a Spark property later.
+ (args.deployMode, deployMode) match {
+ case (null, CLIENT) => args.deployMode = "client"
+ case (null, CLUSTER) => args.deployMode = "cluster"
+ case _ =>
+ }
val isYarnCluster = clusterManager == YARN && deployMode == CLUSTER
val isMesosCluster = clusterManager == MESOS && deployMode == CLUSTER
@@ -347,6 +355,23 @@ object SparkSubmit {
}
}
+ // In YARN mode for an R app, add the SparkR package archive to archives
+ // that can be distributed with the job
+ if (args.isR && clusterManager == YARN) {
+ val rPackagePath = RUtils.localSparkRPackagePath
+ if (rPackagePath.isEmpty) {
+ printErrorAndExit("SPARK_HOME does not exist for R application in YARN mode.")
+ }
+ val rPackageFile = new File(rPackagePath.get, SPARKR_PACKAGE_ARCHIVE)
+ if (!rPackageFile.exists()) {
+ printErrorAndExit(s"$SPARKR_PACKAGE_ARCHIVE does not exist for R application in YARN mode.")
+ }
+ val localURI = Utils.resolveURI(rPackageFile.getAbsolutePath)
+
+      // Assigns a symbolic link name "sparkr" to the shipped package.
+ args.archives = mergeFileLists(args.archives, localURI.toString + "#sparkr")
+ }
+
// If we're running a R app, set the main class to our specific R runner
if (args.isR && deployMode == CLIENT) {
if (args.primaryResource == SPARKR_SHELL) {
@@ -375,6 +400,8 @@ object SparkSubmit {
// All cluster managers
OptionAssigner(args.master, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.master"),
+ OptionAssigner(args.deployMode, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES,
+ sysProp = "spark.submit.deployMode"),
OptionAssigner(args.name, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.app.name"),
OptionAssigner(args.jars, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.jars"),
OptionAssigner(args.ivyRepoPath, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.jars.ivy"),
@@ -481,8 +508,14 @@ object SparkSubmit {
}
// Let YARN know it's a pyspark app, so it distributes needed libraries.
- if (clusterManager == YARN && args.isPython) {
- sysProps.put("spark.yarn.isPython", "true")
+ if (clusterManager == YARN) {
+ if (args.isPython) {
+ sysProps.put("spark.yarn.isPython", "true")
+ }
+ if (args.principal != null) {
+        require(args.keytab != null, "Keytab must be specified when principal is specified")
+ UserGroupInformation.loginUserFromKeytab(args.principal, args.keytab)
+ }
}
// In yarn-cluster mode, use yarn.Client as a wrapper around the user class
@@ -597,7 +630,7 @@ object SparkSubmit {
var mainClass: Class[_] = null
try {
- mainClass = Class.forName(childMainClass, true, loader)
+ mainClass = Utils.classForName(childMainClass)
} catch {
case e: ClassNotFoundException =>
e.printStackTrace(printStream)
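
Editor's note: for R apps on YARN, spark-submit now ships $SPARK_HOME/R/lib/sparkr.zip as an archive, and the "#sparkr" URI fragment tells YARN which link name to give the unpacked archive in each container; that link is what RUtils.sparkRPackagePath(isDriver = false) resolves on executors. A small sketch of building such a fragment-carrying URI (the path below is made up):

import java.io.File
import java.net.URI

object SparkRArchiveUriSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical local location of the zipped SparkR package.
    val rPackageFile = new File("/opt/spark/R/lib/sparkr.zip")
    val localUri = rPackageFile.toURI  // e.g. file:/opt/spark/R/lib/sparkr.zip

    // The fragment after '#' becomes the link name inside the YARN container.
    val archiveArg = localUri.toString + "#sparkr"
    println(archiveArg)
    println(new URI(archiveArg).getFragment)  // sparkr
  }
}
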
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
index ebb39c354dff1..b3710073e330c 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
@@ -576,7 +576,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
System.setSecurityManager(sm)
try {
- Class.forName(mainClass).getMethod("main", classOf[Array[String]])
+ Utils.classForName(mainClass).getMethod("main", classOf[Array[String]])
.invoke(null, Array(HELP))
} catch {
case e: InvocationTargetException =>
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
index 2cc465e55fceb..e3060ac3fa1a9 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
@@ -407,8 +407,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock)
/**
* Comparison function that defines the sort order for application attempts within the same
- * application. Order is: running attempts before complete attempts, running attempts sorted
- * by start time, completed attempts sorted by end time.
+ * application. Order is: attempts are sorted by descending start time.
+   * The most recent attempt's state matches the current state of the app.
*
* Normally applications should have a single running attempt; but failure to call sc.stop()
* may cause multiple running attempts to show up.
@@ -418,11 +418,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock)
private def compareAttemptInfo(
a1: FsApplicationAttemptInfo,
a2: FsApplicationAttemptInfo): Boolean = {
- if (a1.completed == a2.completed) {
- if (a1.completed) a1.endTime >= a2.endTime else a1.startTime >= a2.startTime
- } else {
- !a1.completed
- }
+ a1.startTime >= a2.startTime
}
/**
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala
index 10638afb74900..a076a9c3f984d 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala
@@ -228,7 +228,7 @@ object HistoryServer extends Logging {
val providerName = conf.getOption("spark.history.provider")
.getOrElse(classOf[FsHistoryProvider].getName())
- val provider = Class.forName(providerName)
+ val provider = Utils.classForName(providerName)
.getConstructor(classOf[SparkConf])
.newInstance(conf)
.asInstanceOf[ApplicationHistoryProvider]
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
index f459ed5b3a1a1..aa379d4cd61e7 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
@@ -21,9 +21,8 @@ import java.io._
import scala.reflect.ClassTag
-import akka.serialization.Serialization
-
import org.apache.spark.Logging
+import org.apache.spark.serializer.{DeserializationStream, SerializationStream, Serializer}
import org.apache.spark.util.Utils
@@ -32,11 +31,11 @@ import org.apache.spark.util.Utils
* Files are deleted when applications and workers are removed.
*
* @param dir Directory to store files. Created if non-existent (but not recursively).
- * @param serialization Used to serialize our objects.
+ * @param serializer Used to serialize our objects.
*/
private[master] class FileSystemPersistenceEngine(
val dir: String,
- val serialization: Serialization)
+ val serializer: Serializer)
extends PersistenceEngine with Logging {
new File(dir).mkdir()
@@ -57,27 +56,31 @@ private[master] class FileSystemPersistenceEngine(
private def serializeIntoFile(file: File, value: AnyRef) {
val created = file.createNewFile()
if (!created) { throw new IllegalStateException("Could not create file: " + file) }
- val serializer = serialization.findSerializerFor(value)
- val serialized = serializer.toBinary(value)
- val out = new FileOutputStream(file)
+ val fileOut = new FileOutputStream(file)
+ var out: SerializationStream = null
Utils.tryWithSafeFinally {
- out.write(serialized)
+ out = serializer.newInstance().serializeStream(fileOut)
+ out.writeObject(value)
} {
- out.close()
+ fileOut.close()
+ if (out != null) {
+ out.close()
+ }
}
}
private def deserializeFromFile[T](file: File)(implicit m: ClassTag[T]): T = {
- val fileData = new Array[Byte](file.length().asInstanceOf[Int])
- val dis = new DataInputStream(new FileInputStream(file))
+ val fileIn = new FileInputStream(file)
+ var in: DeserializationStream = null
try {
- dis.readFully(fileData)
+ in = serializer.newInstance().deserializeStream(fileIn)
+ in.readObject[T]()
} finally {
- dis.close()
+ fileIn.close()
+ if (in != null) {
+ in.close()
+ }
}
- val clazz = m.runtimeClass.asInstanceOf[Class[T]]
- val serializer = serialization.serializerFor(clazz)
- serializer.fromBinary(fileData).asInstanceOf[T]
}
}
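
Editor's note: the persistence engine now writes objects through a Spark Serializer's SerializationStream instead of Akka serialization. Stripped of Spark types, the shape is a stream per file with a guarded close; a plain-JDK sketch of the same structure using Java serialization (an illustration, not Spark's JavaSerializer):

import java.io._

object FilePersistenceSketch {
  def serializeIntoFile(file: File, value: AnyRef): Unit = {
    val fileOut = new FileOutputStream(file)
    var out: ObjectOutputStream = null
    try {
      out = new ObjectOutputStream(fileOut)
      out.writeObject(value)
    } finally {
      if (out != null) out.close()
      fileOut.close()
    }
  }

  def deserializeFromFile[T](file: File): T = {
    val fileIn = new FileInputStream(file)
    var in: ObjectInputStream = null
    try {
      in = new ObjectInputStream(fileIn)
      in.readObject().asInstanceOf[T]
    } finally {
      if (in != null) in.close()
      fileIn.close()
    }
  }

  def main(args: Array[String]): Unit = {
    val f = File.createTempFile("persist", ".bin")
    serializeIntoFile(f, "hello")
    println(deserializeFromFile[String](f))
    f.delete()
  }
}
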
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
index 48070768f6edb..4615febf17d24 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -27,11 +27,8 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
import scala.language.postfixOps
import scala.util.Random
-import akka.serialization.Serialization
-import akka.serialization.SerializationExtension
import org.apache.hadoop.fs.Path
-import org.apache.spark.rpc.akka.AkkaRpcEnv
import org.apache.spark.rpc._
import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException}
import org.apache.spark.deploy.{ApplicationDescription, DriverDescription,
@@ -44,6 +41,7 @@ import org.apache.spark.deploy.master.ui.MasterWebUI
import org.apache.spark.deploy.rest.StandaloneRestServer
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.scheduler.{EventLoggingListener, ReplayListenerBus}
+import org.apache.spark.serializer.{JavaSerializer, Serializer}
import org.apache.spark.ui.SparkUI
import org.apache.spark.util.{ThreadUtils, SignalLogger, Utils}
@@ -58,9 +56,6 @@ private[master] class Master(
private val forwardMessageThread =
ThreadUtils.newDaemonSingleThreadScheduledExecutor("master-forward-message-thread")
- // TODO Remove it once we don't use akka.serialization.Serialization
- private val actorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem
-
private val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)
private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs
@@ -161,20 +156,21 @@ private[master] class Master(
masterMetricsSystem.getServletHandlers.foreach(webUi.attachHandler)
applicationMetricsSystem.getServletHandlers.foreach(webUi.attachHandler)
+ val serializer = new JavaSerializer(conf)
val (persistenceEngine_, leaderElectionAgent_) = RECOVERY_MODE match {
case "ZOOKEEPER" =>
logInfo("Persisting recovery state to ZooKeeper")
val zkFactory =
- new ZooKeeperRecoveryModeFactory(conf, SerializationExtension(actorSystem))
+ new ZooKeeperRecoveryModeFactory(conf, serializer)
(zkFactory.createPersistenceEngine(), zkFactory.createLeaderElectionAgent(this))
case "FILESYSTEM" =>
val fsFactory =
- new FileSystemRecoveryModeFactory(conf, SerializationExtension(actorSystem))
+ new FileSystemRecoveryModeFactory(conf, serializer)
(fsFactory.createPersistenceEngine(), fsFactory.createLeaderElectionAgent(this))
case "CUSTOM" =>
- val clazz = Class.forName(conf.get("spark.deploy.recoveryMode.factory"))
- val factory = clazz.getConstructor(classOf[SparkConf], classOf[Serialization])
- .newInstance(conf, SerializationExtension(actorSystem))
+ val clazz = Utils.classForName(conf.get("spark.deploy.recoveryMode.factory"))
+ val factory = clazz.getConstructor(classOf[SparkConf], classOf[Serializer])
+ .newInstance(conf, serializer)
.asInstanceOf[StandaloneRecoveryModeFactory]
(factory.createPersistenceEngine(), factory.createLeaderElectionAgent(this))
case _ =>
@@ -213,7 +209,7 @@ private[master] class Master(
override def receive: PartialFunction[Any, Unit] = {
case ElectedLeader => {
- val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData()
+ val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData(rpcEnv)
state = if (storedApps.isEmpty && storedDrivers.isEmpty && storedWorkers.isEmpty) {
RecoveryState.ALIVE
} else {
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala
index a03d460509e03..58a00bceee6af 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala
@@ -18,6 +18,7 @@
package org.apache.spark.deploy.master
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.rpc.RpcEnv
import scala.reflect.ClassTag
@@ -80,8 +81,11 @@ abstract class PersistenceEngine {
* Returns the persisted data sorted by their respective ids (which implies that they're
* sorted by time of creation).
*/
- final def readPersistedData(): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) = {
- (read[ApplicationInfo]("app_"), read[DriverInfo]("driver_"), read[WorkerInfo]("worker_"))
+ final def readPersistedData(
+ rpcEnv: RpcEnv): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) = {
+ rpcEnv.deserialize { () =>
+ (read[ApplicationInfo]("app_"), read[DriverInfo]("driver_"), read[WorkerInfo]("worker_"))
+ }
}
def close() {}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala
index 351db8fab2041..c4c3283fb73f7 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala
@@ -17,10 +17,9 @@
package org.apache.spark.deploy.master
-import akka.serialization.Serialization
-
import org.apache.spark.{Logging, SparkConf}
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.serializer.Serializer
/**
* ::DeveloperApi::
@@ -30,7 +29,7 @@ import org.apache.spark.annotation.DeveloperApi
*
*/
@DeveloperApi
-abstract class StandaloneRecoveryModeFactory(conf: SparkConf, serializer: Serialization) {
+abstract class StandaloneRecoveryModeFactory(conf: SparkConf, serializer: Serializer) {
/**
* PersistenceEngine defines how the persistent data(Information about worker, driver etc..)
@@ -49,7 +48,7 @@ abstract class StandaloneRecoveryModeFactory(conf: SparkConf, serializer: Serial
* LeaderAgent in this case is a no-op. Since leader is forever leader as the actual
* recovery is made by restoring from filesystem.
*/
-private[master] class FileSystemRecoveryModeFactory(conf: SparkConf, serializer: Serialization)
+private[master] class FileSystemRecoveryModeFactory(conf: SparkConf, serializer: Serializer)
extends StandaloneRecoveryModeFactory(conf, serializer) with Logging {
val RECOVERY_DIR = conf.get("spark.deploy.recoveryDirectory", "")
@@ -64,7 +63,7 @@ private[master] class FileSystemRecoveryModeFactory(conf: SparkConf, serializer:
}
}
-private[master] class ZooKeeperRecoveryModeFactory(conf: SparkConf, serializer: Serialization)
+private[master] class ZooKeeperRecoveryModeFactory(conf: SparkConf, serializer: Serializer)
extends StandaloneRecoveryModeFactory(conf, serializer) {
def createPersistenceEngine(): PersistenceEngine = {
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
index 328d95a7a0c68..563831cc6b8dd 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
@@ -17,7 +17,7 @@
package org.apache.spark.deploy.master
-import akka.serialization.Serialization
+import java.nio.ByteBuffer
import scala.collection.JavaConversions._
import scala.reflect.ClassTag
@@ -27,9 +27,10 @@ import org.apache.zookeeper.CreateMode
import org.apache.spark.{Logging, SparkConf}
import org.apache.spark.deploy.SparkCuratorUtil
+import org.apache.spark.serializer.Serializer
-private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serialization: Serialization)
+private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serializer: Serializer)
extends PersistenceEngine
with Logging {
@@ -57,17 +58,16 @@ private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serializat
}
private def serializeIntoFile(path: String, value: AnyRef) {
- val serializer = serialization.findSerializerFor(value)
- val serialized = serializer.toBinary(value)
- zk.create().withMode(CreateMode.PERSISTENT).forPath(path, serialized)
+ val serialized = serializer.newInstance().serialize(value)
+ val bytes = new Array[Byte](serialized.remaining())
+ serialized.get(bytes)
+ zk.create().withMode(CreateMode.PERSISTENT).forPath(path, bytes)
}
private def deserializeFromFile[T](filename: String)(implicit m: ClassTag[T]): Option[T] = {
val fileData = zk.getData().forPath(WORKING_DIR + "/" + filename)
- val clazz = m.runtimeClass.asInstanceOf[Class[T]]
- val serializer = serialization.serializerFor(clazz)
try {
- Some(serializer.fromBinary(fileData).asInstanceOf[T])
+ Some(serializer.newInstance().deserialize[T](ByteBuffer.wrap(fileData)))
} catch {
case e: Exception => {
logWarning("Exception while reading persisted file, deleting", e)
diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala
index e6615a3174ce1..ef5a7e35ad562 100644
--- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala
@@ -128,7 +128,7 @@ private[spark] object SubmitRestProtocolMessage {
*/
def fromJson(json: String): SubmitRestProtocolMessage = {
val className = parseAction(json)
- val clazz = Class.forName(packagePrefix + "." + className)
+ val clazz = Utils.classForName(packagePrefix + "." + className)
.asSubclass[SubmitRestProtocolMessage](classOf[SubmitRestProtocolMessage])
fromJson(json, clazz)
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala
index 2d6be3042c905..6799f78ec0c19 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala
@@ -53,7 +53,7 @@ object DriverWrapper {
Thread.currentThread.setContextClassLoader(loader)
// Delegate to supplied main class
- val clazz = Class.forName(mainClass, true, loader)
+ val clazz = Utils.classForName(mainClass)
val mainMethod = clazz.getMethod("main", classOf[Array[String]])
mainMethod.invoke(null, extraArgs.toArray[String])
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala
index e89d076802215..5181142c5f80e 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala
@@ -149,6 +149,7 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) {
val ibmVendor = System.getProperty("java.vendor").contains("IBM")
var totalMb = 0
try {
+ // scalastyle:off classforname
val bean = ManagementFactory.getOperatingSystemMXBean()
if (ibmVendor) {
val beanClass = Class.forName("com.ibm.lang.management.OperatingSystemMXBean")
@@ -159,6 +160,7 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) {
val method = beanClass.getDeclaredMethod("getTotalPhysicalMemorySize")
totalMb = (method.invoke(bean).asInstanceOf[Long] / 1024 / 1024).toInt
}
+ // scalastyle:on classforname
} catch {
case e: Exception => {
totalMb = 2*1024
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index f7ef92bc80f91..e76664f1bd7b0 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -209,15 +209,19 @@ private[spark] class Executor(
// Run the actual task and measure its runtime.
taskStart = System.currentTimeMillis()
- val value = try {
- task.run(taskAttemptId = taskId, attemptNumber = attemptNumber)
+ var threwException = true
+ val (value, accumUpdates) = try {
+ val res = task.run(
+ taskAttemptId = taskId,
+ attemptNumber = attemptNumber,
+ metricsSystem = env.metricsSystem)
+ threwException = false
+ res
} finally {
- // Note: this memory freeing logic is duplicated in DAGScheduler.runLocallyWithinThread;
- // when changing this, make sure to update both copies.
val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory()
if (freedMemory > 0) {
val errMsg = s"Managed memory leak detected; size = $freedMemory bytes, TID = $taskId"
- if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) {
+ if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false) && !threwException) {
throw new SparkException(errMsg)
} else {
logError(errMsg)
@@ -247,7 +251,6 @@ private[spark] class Executor(
m.setResultSerializationTime(afterSerialization - beforeSerialization)
}
- val accumUpdates = Accumulators.values
val directResult = new DirectTaskResult(valueBytes, accumUpdates, task.metrics.orNull)
val serializedDirectResult = ser.serialize(directResult)
val resultSize = serializedDirectResult.limit
@@ -314,8 +317,6 @@ private[spark] class Executor(
env.shuffleMemoryManager.releaseMemoryForThisThread()
// Release memory used by this thread for unrolling blocks
env.blockManager.memoryStore.releaseUnrollMemoryForThisThread()
- // Release memory used by this thread for accumulators
- Accumulators.clear()
runningTasks.remove(taskId)
}
}
@@ -356,7 +357,7 @@ private[spark] class Executor(
logInfo("Using REPL class URI: " + classUri)
try {
val _userClassPathFirst: java.lang.Boolean = userClassPathFirst
- val klass = Class.forName("org.apache.spark.repl.ExecutorClassLoader")
+ val klass = Utils.classForName("org.apache.spark.repl.ExecutorClassLoader")
.asInstanceOf[Class[_ <: ClassLoader]]
val constructor = klass.getConstructor(classOf[SparkConf], classOf[String],
classOf[ClassLoader], classOf[Boolean])
@@ -424,6 +425,7 @@ private[spark] class Executor(
metrics.updateShuffleReadMetrics()
metrics.updateInputMetrics()
metrics.setJvmGCTime(curGCTime - taskRunner.startGCTime)
+ metrics.updateAccumulators()
if (isLocal) {
// JobProgressListener will hold a reference to it during
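The Executor change above makes task.run return the value together with accumulator updates, and only escalates the managed-memory-leak check to an exception when the task body completed normally, so the leak check cannot mask the task's own failure. A reduced sketch of that control flow, with placeholder stand-ins for the task body and for TaskMemoryManager.cleanUpAllAllocatedMemory:

object LeakCheckSketch {
  def runWithLeakCheck[T](runTask: () => T, cleanUp: () => Long): T = {
    var threwException = true
    try {
      val result = runTask()
      threwException = false
      result
    } finally {
      val freedMemory = cleanUp() // bytes the manager had to reclaim after the task
      if (freedMemory > 0) {
        val errMsg = s"Managed memory leak detected; size = $freedMemory bytes"
        // Throw only when the task body succeeded; otherwise just log, so the
        // original task exception still propagates.
        if (!threwException) throw new IllegalStateException(errMsg)
        else System.err.println(errMsg)
      }
    }
  }

  def main(args: Array[String]): Unit = {
    println(runWithLeakCheck(() => 42, () => 0L)) // 42
  }
}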
diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
index a3b4561b07e7f..42207a9553592 100644
--- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -17,11 +17,15 @@
package org.apache.spark.executor
+import java.io.{IOException, ObjectInputStream}
+import java.util.concurrent.ConcurrentHashMap
+
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.executor.DataReadMethod.DataReadMethod
import org.apache.spark.storage.{BlockId, BlockStatus}
+import org.apache.spark.util.Utils
/**
* :: DeveloperApi ::
@@ -210,10 +214,42 @@ class TaskMetrics extends Serializable {
private[spark] def updateInputMetrics(): Unit = synchronized {
inputMetrics.foreach(_.updateBytesRead())
}
+
+ @throws(classOf[IOException])
+ private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
+ in.defaultReadObject()
+ // Get the hostname from cached data: the number of distinct hostnames is on the order of the
+ // number of nodes in the cluster, so reusing cached hostnames reduces the number of String
+ // objects and alleviates GC overhead.
+ _hostname = TaskMetrics.getCachedHostName(_hostname)
+ }
+
+ private var _accumulatorUpdates: Map[Long, Any] = Map.empty
+ @transient private var _accumulatorsUpdater: () => Map[Long, Any] = null
+
+ private[spark] def updateAccumulators(): Unit = synchronized {
+ _accumulatorUpdates = _accumulatorsUpdater()
+ }
+
+ /**
+ * Return the latest updates of accumulators in this task.
+ */
+ def accumulatorUpdates(): Map[Long, Any] = _accumulatorUpdates
+
+ private[spark] def setAccumulatorsUpdater(accumulatorsUpdater: () => Map[Long, Any]): Unit = {
+ _accumulatorsUpdater = accumulatorsUpdater
+ }
}
private[spark] object TaskMetrics {
+ private val hostNameCache = new ConcurrentHashMap[String, String]()
+
def empty: TaskMetrics = new TaskMetrics
+
+ def getCachedHostName(host: String): String = {
+ val canonicalHost = hostNameCache.putIfAbsent(host, host)
+ if (canonicalHost != null) canonicalHost else host
+ }
}
/**
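TaskMetrics now interns hostnames through a ConcurrentHashMap so that metrics deserialized from many tasks share one String per host instead of one per task. The putIfAbsent idiom used there, in isolation:

import java.util.concurrent.ConcurrentHashMap

object HostNameCache {
  private val cache = new ConcurrentHashMap[String, String]()

  // putIfAbsent returns the previously mapped value, or null if this call
  // inserted the key, so callers always get back one canonical instance.
  def canonicalize(host: String): String = {
    val existing = cache.putIfAbsent(host, host)
    if (existing != null) existing else host
  }

  def main(args: Array[String]): Unit = {
    val a = canonicalize(new String("node-1"))
    val b = canonicalize(new String("node-1"))
    println(a eq b) // true: both callers see the same String object
  }
}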
diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala
index 0d8ac1f80a9f4..607d5a321efca 100644
--- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala
+++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala
@@ -63,8 +63,7 @@ private[spark] object CompressionCodec {
def createCodec(conf: SparkConf, codecName: String): CompressionCodec = {
val codecClass = shortCompressionCodecNames.getOrElse(codecName.toLowerCase, codecName)
val codec = try {
- val ctor = Class.forName(codecClass, true, Utils.getContextOrSparkClassLoader)
- .getConstructor(classOf[SparkConf])
+ val ctor = Utils.classForName(codecClass).getConstructor(classOf[SparkConf])
Some(ctor.newInstance(conf).asInstanceOf[CompressionCodec])
} catch {
case e: ClassNotFoundException => None
diff --git a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala
index 818f7a4c8d422..87df42748be44 100644
--- a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala
+++ b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala
@@ -26,6 +26,7 @@ import org.apache.hadoop.mapreduce.{OutputCommitter => MapReduceOutputCommitter}
import org.apache.spark.executor.CommitDeniedException
import org.apache.spark.{Logging, SparkEnv, TaskContext}
+import org.apache.spark.util.{Utils => SparkUtils}
private[spark]
trait SparkHadoopMapRedUtil {
@@ -64,10 +65,10 @@ trait SparkHadoopMapRedUtil {
private def firstAvailableClass(first: String, second: String): Class[_] = {
try {
- Class.forName(first)
+ SparkUtils.classForName(first)
} catch {
case e: ClassNotFoundException =>
- Class.forName(second)
+ SparkUtils.classForName(second)
}
}
}
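firstAvailableClass here (and in the mapreduce counterpart below) loads one class name and falls back to a second when the first is not on the classpath, which is how the code bridges the old and new Hadoop API locations. The fallback pattern on its own, using plain Class.forName rather than the Spark helper:

object ClassFallback {
  def firstAvailableClass(first: String, second: String): Class[_] =
    try {
      Class.forName(first)
    } catch {
      case _: ClassNotFoundException => Class.forName(second)
    }

  def main(args: Array[String]): Unit = {
    // "com.example.DoesNotExist" is a made-up name used only to trigger the fallback.
    val clazz = firstAvailableClass("com.example.DoesNotExist", "java.util.ArrayList")
    println(clazz.getName) // java.util.ArrayList
  }
}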
diff --git a/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala b/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala
index 390d148bc97f9..943ebcb7bd0a1 100644
--- a/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala
+++ b/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala
@@ -21,6 +21,7 @@ import java.lang.{Boolean => JBoolean, Integer => JInteger}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapreduce.{JobContext, JobID, TaskAttemptContext, TaskAttemptID}
+import org.apache.spark.util.Utils
private[spark]
trait SparkHadoopMapReduceUtil {
@@ -46,7 +47,7 @@ trait SparkHadoopMapReduceUtil {
isMap: Boolean,
taskId: Int,
attemptId: Int): TaskAttemptID = {
- val klass = Class.forName("org.apache.hadoop.mapreduce.TaskAttemptID")
+ val klass = Utils.classForName("org.apache.hadoop.mapreduce.TaskAttemptID")
try {
// First, attempt to use the old-style constructor that takes a boolean isMap
// (not available in YARN)
@@ -57,7 +58,7 @@ trait SparkHadoopMapReduceUtil {
} catch {
case exc: NoSuchMethodException => {
// If that failed, look for the new constructor that takes a TaskType (not available in 1.x)
- val taskTypeClass = Class.forName("org.apache.hadoop.mapreduce.TaskType")
+ val taskTypeClass = Utils.classForName("org.apache.hadoop.mapreduce.TaskType")
.asInstanceOf[Class[Enum[_]]]
val taskType = taskTypeClass.getMethod("valueOf", classOf[String]).invoke(
taskTypeClass, if (isMap) "MAP" else "REDUCE")
@@ -71,10 +72,10 @@ trait SparkHadoopMapReduceUtil {
private def firstAvailableClass(first: String, second: String): Class[_] = {
try {
- Class.forName(first)
+ Utils.classForName(first)
} catch {
case e: ClassNotFoundException =>
- Class.forName(second)
+ Utils.classForName(second)
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
index ed5131c79fdc5..4517f465ebd3b 100644
--- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
@@ -20,6 +20,8 @@ package org.apache.spark.metrics
import java.util.Properties
import java.util.concurrent.TimeUnit
+import org.apache.spark.util.Utils
+
import scala.collection.mutable
import com.codahale.metrics.{Metric, MetricFilter, MetricRegistry}
@@ -140,6 +142,9 @@ private[spark] class MetricsSystem private (
} else { defaultName }
}
+ def getSourcesByName(sourceName: String): Seq[Source] =
+ sources.filter(_.sourceName == sourceName)
+
def registerSource(source: Source) {
sources += source
try {
@@ -166,7 +171,7 @@ private[spark] class MetricsSystem private (
sourceConfigs.foreach { kv =>
val classPath = kv._2.getProperty("class")
try {
- val source = Class.forName(classPath).newInstance()
+ val source = Utils.classForName(classPath).newInstance()
registerSource(source.asInstanceOf[Source])
} catch {
case e: Exception => logError("Source class " + classPath + " cannot be instantiated", e)
@@ -182,7 +187,7 @@ private[spark] class MetricsSystem private (
val classPath = kv._2.getProperty("class")
if (null != classPath) {
try {
- val sink = Class.forName(classPath)
+ val sink = Utils.classForName(classPath)
.getConstructor(classOf[Properties], classOf[MetricRegistry], classOf[SecurityManager])
.newInstance(kv._2, registry, securityMgr)
if (kv._1 == "servlet") {
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
index 658e8c8b89318..130b58882d8ee 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -94,13 +94,14 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
}
override def getDependencies: Seq[Dependency[_]] = {
- rdds.map { rdd: RDD[_ <: Product2[K, _]] =>
+ rdds.map { rdd: RDD[_] =>
if (rdd.partitioner == Some(part)) {
logDebug("Adding one-to-one dependency with " + rdd)
new OneToOneDependency(rdd)
} else {
logDebug("Adding shuffle dependency with " + rdd)
- new ShuffleDependency[K, Any, CoGroupCombiner](rdd, part, serializer)
+ new ShuffleDependency[K, Any, CoGroupCombiner](
+ rdd.asInstanceOf[RDD[_ <: Product2[K, _]]], part, serializer)
}
}
}
@@ -133,7 +134,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
// A list of (rdd iterator, dependency number) pairs
val rddIterators = new ArrayBuffer[(Iterator[Product2[K, Any]], Int)]
for ((dep, depNum) <- dependencies.zipWithIndex) dep match {
- case oneToOneDependency: OneToOneDependency[Product2[K, Any]] =>
+ case oneToOneDependency: OneToOneDependency[Product2[K, Any]] @unchecked =>
val dependencyPartition = split.narrowDeps(depNum).get.split
// Read them from the parent
val it = oneToOneDependency.rdd.iterator(dependencyPartition, context)
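The CoGroupedRDD hunk adds @unchecked to the OneToOneDependency[Product2[K, Any]] pattern: the type argument is erased at runtime, so the annotation tells the compiler not to warn about a check it cannot actually perform. A small self-contained example of the same annotation:

object UncheckedMatch {
  def describe(dep: Any): String = dep match {
    // Erasure means this can only verify "is a List"; @unchecked suppresses
    // the unchecked-warning about the element type.
    case xs: List[String] @unchecked => s"a list with ${xs.size} element(s)"
    case other => s"something else: $other"
  }

  def main(args: Array[String]): Unit = {
    println(describe(List("a", "b"))) // a list with 2 element(s)
    println(describe(42))             // something else: 42
  }
}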
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
index 663eebb8e4191..90d9735cb3f69 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
@@ -69,7 +69,7 @@ private[spark] case class CoalescedRDDPartition(
* the preferred location of each new partition overlaps with as many preferred locations of its
* parent partitions
* @param prev RDD to be coalesced
- * @param maxPartitions number of desired partitions in the coalesced RDD
+ * @param maxPartitions number of desired partitions in the coalesced RDD (must be positive)
* @param balanceSlack used to trade-off balance and locality. 1.0 is all locality, 0 is all balance
*/
private[spark] class CoalescedRDD[T: ClassTag](
@@ -78,6 +78,9 @@ private[spark] class CoalescedRDD[T: ClassTag](
balanceSlack: Double = 0.10)
extends RDD[T](prev.context, Nil) { // Nil since we implement getDependencies
+ require(maxPartitions > 0 || maxPartitions == prev.partitions.length,
+ s"Number of partitions ($maxPartitions) must be positive.")
+
override def getPartitions: Array[Partition] = {
val pc = new PartitionCoalescer(maxPartitions, prev, balanceSlack)
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index bee59a437f120..f1c17369cb48c 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -383,11 +383,11 @@ private[spark] object HadoopRDD extends Logging {
private[spark] class SplitInfoReflections {
val inputSplitWithLocationInfo =
- Class.forName("org.apache.hadoop.mapred.InputSplitWithLocationInfo")
+ Utils.classForName("org.apache.hadoop.mapred.InputSplitWithLocationInfo")
val getLocationInfo = inputSplitWithLocationInfo.getMethod("getLocationInfo")
- val newInputSplit = Class.forName("org.apache.hadoop.mapreduce.InputSplit")
+ val newInputSplit = Utils.classForName("org.apache.hadoop.mapreduce.InputSplit")
val newGetLocationInfo = newInputSplit.getMethod("getLocationInfo")
- val splitLocationInfo = Class.forName("org.apache.hadoop.mapred.SplitLocationInfo")
+ val splitLocationInfo = Utils.classForName("org.apache.hadoop.mapred.SplitLocationInfo")
val isInMemory = splitLocationInfo.getMethod("isInMemory")
val getLocation = splitLocationInfo.getMethod("getLocation")
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index 91a6a2d039852..326fafb230a40 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -881,7 +881,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
}
buf
} : Seq[V]
- val res = self.context.runJob(self, process, Array(index), false)
+ val res = self.context.runJob(self, process, Array(index))
res(0)
case None =>
self.filter(_._1 == key).map(_._2).collect()
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 9f7ebae3e9af3..6d61d227382d7 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -897,7 +897,7 @@ abstract class RDD[T: ClassTag](
*/
def toLocalIterator: Iterator[T] = withScope {
def collectPartition(p: Int): Array[T] = {
- sc.runJob(this, (iter: Iterator[T]) => iter.toArray, Seq(p), allowLocal = false).head
+ sc.runJob(this, (iter: Iterator[T]) => iter.toArray, Seq(p)).head
}
(0 until partitions.length).iterator.flatMap(i => collectPartition(i))
}
@@ -1082,7 +1082,9 @@ abstract class RDD[T: ClassTag](
val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2)
// If creating an extra level doesn't help reduce
// the wall-clock time, we stop tree aggregation.
- while (numPartitions > scale + numPartitions / scale) {
+
+ // Don't trigger TreeAggregation when it doesn't save wall-clock time
+ while (numPartitions > scale + math.ceil(numPartitions.toDouble / scale)) {
numPartitions /= scale
val curNumPartitions = numPartitions
partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex {
@@ -1273,7 +1275,7 @@ abstract class RDD[T: ClassTag](
val left = num - buf.size
val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
- val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, p, allowLocal = true)
+ val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, p)
res.foreach(buf ++= _.take(num - buf.size))
partsScanned += numPartsToTry
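The treeAggregate loop condition above now compares against scale + ceil(numPartitions / scale) instead of integer division, so an extra aggregation level is added only when it actually reduces the amount of merging left to do. The arithmetic can be checked in isolation with a small sketch that mirrors the loop:

object TreeAggregateDepth {
  // Mirrors the loop in RDD.treeAggregate: keep adding levels while doing so
  // still shrinks the work below a single merge over numPartitions.
  def levels(initialPartitions: Int, scale: Int): Seq[Int] = {
    var numPartitions = initialPartitions
    val steps = scala.collection.mutable.ArrayBuffer(numPartitions)
    while (numPartitions > scale + math.ceil(numPartitions.toDouble / scale)) {
      numPartitions /= scale
      steps += numPartitions
    }
    steps
  }

  def main(args: Array[String]): Unit = {
    // 100 partitions with scale 10 collapses 100 -> 10 and then stops, because
    // 10 is not greater than 10 + ceil(10/10) = 11.
    println(levels(100, 10).mkString(" -> "))
  }
}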
diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala
index 523aaf2b860b5..e277ae28d588f 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala
@@ -50,8 +50,7 @@ class ZippedWithIndexRDD[T: ClassTag](@transient prev: RDD[T]) extends RDD[(T, L
prev.context.runJob(
prev,
Utils.getIteratorSize _,
- 0 until n - 1, // do not need to count the last partition
- allowLocal = false
+ 0 until n - 1 // do not need to count the last partition
).scanLeft(0L)(_ + _)
}
}
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
index 1709bdf560b6f..29debe8081308 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
@@ -39,8 +39,7 @@ private[spark] object RpcEnv {
val rpcEnvNames = Map("akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnvFactory")
val rpcEnvName = conf.get("spark.rpc", "akka")
val rpcEnvFactoryClassName = rpcEnvNames.getOrElse(rpcEnvName.toLowerCase, rpcEnvName)
- Class.forName(rpcEnvFactoryClassName, true, Utils.getContextOrSparkClassLoader).
- newInstance().asInstanceOf[RpcEnvFactory]
+ Utils.classForName(rpcEnvFactoryClassName).newInstance().asInstanceOf[RpcEnvFactory]
}
def create(
@@ -140,6 +139,12 @@ private[spark] abstract class RpcEnv(conf: SparkConf) {
* creating it manually because different [[RpcEnv]] may have different formats.
*/
def uriOf(systemName: String, address: RpcAddress, endpointName: String): String
+
+ /**
+ * [[RpcEndpointRef]] cannot be deserialized without [[RpcEnv]]. So when deserializing any object
+ * that contains [[RpcEndpointRef]]s, the deserialization codes should be wrapped by this method.
+ */
+ def deserialize[T](deserializationAction: () => T): T
}
diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
index f2d87f68341af..fc17542abf81d 100644
--- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
@@ -28,7 +28,7 @@ import akka.actor.{ActorSystem, ExtendedActorSystem, Actor, ActorRef, Props, Add
import akka.event.Logging.Error
import akka.pattern.{ask => akkaAsk}
import akka.remote.{AssociationEvent, AssociatedEvent, DisassociatedEvent, AssociationErrorEvent}
-import com.google.common.util.concurrent.MoreExecutors
+import akka.serialization.JavaSerializer
import org.apache.spark.{SparkException, Logging, SparkConf}
import org.apache.spark.rpc._
@@ -239,6 +239,12 @@ private[spark] class AkkaRpcEnv private[akka] (
}
override def toString: String = s"${getClass.getSimpleName}($actorSystem)"
+
+ override def deserialize[T](deserializationAction: () => T): T = {
+ JavaSerializer.currentSystem.withValue(actorSystem.asInstanceOf[ExtendedActorSystem]) {
+ deserializationAction()
+ }
+ }
}
private[spark] class AkkaRpcEnvFactory extends RpcEnvFactory {
@@ -315,6 +321,12 @@ private[akka] class AkkaRpcEndpointRef(
override def toString: String = s"${getClass.getSimpleName}($actorRef)"
+ final override def equals(that: Any): Boolean = that match {
+ case other: AkkaRpcEndpointRef => actorRef == other.actorRef
+ case _ => false
+ }
+
+ final override def hashCode(): Int = if (actorRef == null) 0 else actorRef.hashCode()
}
/**
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index f3d87ee5c4fd1..552dabcfa5139 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -22,7 +22,8 @@ import java.util.Properties
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicInteger
-import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map, Stack}
+import scala.collection.Map
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Stack}
import scala.concurrent.duration._
import scala.language.existentials
import scala.language.postfixOps
@@ -37,7 +38,6 @@ import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator
import org.apache.spark.rdd.RDD
import org.apache.spark.rpc.RpcTimeout
import org.apache.spark.storage._
-import org.apache.spark.unsafe.memory.TaskMemoryManager
import org.apache.spark.util._
import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat
@@ -127,10 +127,6 @@ class DAGScheduler(
// This is only safe because DAGScheduler runs in a single thread.
private val closureSerializer = SparkEnv.get.closureSerializer.newInstance()
-
- /** If enabled, we may run certain actions like take() and first() locally. */
- private val localExecutionEnabled = sc.getConf.getBoolean("spark.localExecution.enabled", false)
-
/** If enabled, FetchFailed will not cause stage retry, in order to surface the problem. */
private val disallowStageRetryForTest = sc.getConf.getBoolean("spark.test.noStageRetry", false)
@@ -514,7 +510,6 @@ class DAGScheduler(
func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int],
callSite: CallSite,
- allowLocal: Boolean,
resultHandler: (Int, U) => Unit,
properties: Properties): JobWaiter[U] = {
// Check to make sure we are not launching a task on a partition that does not exist.
@@ -534,7 +529,7 @@ class DAGScheduler(
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler)
eventProcessLoop.post(JobSubmitted(
- jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter,
+ jobId, rdd, func2, partitions.toArray, callSite, waiter,
SerializationUtils.clone(properties)))
waiter
}
@@ -544,11 +539,10 @@ class DAGScheduler(
func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int],
callSite: CallSite,
- allowLocal: Boolean,
resultHandler: (Int, U) => Unit,
properties: Properties): Unit = {
val start = System.nanoTime
- val waiter = submitJob(rdd, func, partitions, callSite, allowLocal, resultHandler, properties)
+ val waiter = submitJob(rdd, func, partitions, callSite, resultHandler, properties)
waiter.awaitResult() match {
case JobSucceeded =>
logInfo("Job %d finished: %s, took %f s".format
@@ -556,6 +550,9 @@ class DAGScheduler(
case JobFailed(exception: Exception) =>
logInfo("Job %d failed: %s, took %f s".format
(waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9))
+ // SPARK-8644: Include user stack trace in exceptions coming from DAGScheduler.
+ val callerStackTrace = Thread.currentThread().getStackTrace.tail
+ exception.setStackTrace(exception.getStackTrace ++ callerStackTrace)
throw exception
}
}
@@ -572,8 +569,7 @@ class DAGScheduler(
val partitions = (0 until rdd.partitions.size).toArray
val jobId = nextJobId.getAndIncrement()
eventProcessLoop.post(JobSubmitted(
- jobId, rdd, func2, partitions, allowLocal = false, callSite, listener,
- SerializationUtils.clone(properties)))
+ jobId, rdd, func2, partitions, callSite, listener, SerializationUtils.clone(properties)))
listener.awaitResult() // Will throw an exception if the job fails
}
@@ -650,73 +646,6 @@ class DAGScheduler(
}
}
- /**
- * Run a job on an RDD locally, assuming it has only a single partition and no dependencies.
- * We run the operation in a separate thread just in case it takes a bunch of time, so that we
- * don't block the DAGScheduler event loop or other concurrent jobs.
- */
- protected def runLocally(job: ActiveJob) {
- logInfo("Computing the requested partition locally")
- new Thread("Local computation of job " + job.jobId) {
- override def run() {
- runLocallyWithinThread(job)
- }
- }.start()
- }
-
- // Broken out for easier testing in DAGSchedulerSuite.
- protected def runLocallyWithinThread(job: ActiveJob) {
- var jobResult: JobResult = JobSucceeded
- try {
- val rdd = job.finalStage.rdd
- val split = rdd.partitions(job.partitions(0))
- val taskMemoryManager = new TaskMemoryManager(env.executorMemoryManager)
- val taskContext =
- new TaskContextImpl(
- job.finalStage.id,
- job.partitions(0),
- taskAttemptId = 0,
- attemptNumber = 0,
- taskMemoryManager = taskMemoryManager,
- runningLocally = true)
- TaskContext.setTaskContext(taskContext)
- try {
- val result = job.func(taskContext, rdd.iterator(split, taskContext))
- job.listener.taskSucceeded(0, result)
- } finally {
- taskContext.markTaskCompleted()
- TaskContext.unset()
- // Note: this memory freeing logic is duplicated in Executor.run(); when changing this,
- // make sure to update both copies.
- val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory()
- if (freedMemory > 0) {
- if (sc.getConf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) {
- throw new SparkException(s"Managed memory leak detected; size = $freedMemory bytes")
- } else {
- logError(s"Managed memory leak detected; size = $freedMemory bytes")
- }
- }
- }
- } catch {
- case e: Exception =>
- val exception = new SparkDriverExecutionException(e)
- jobResult = JobFailed(exception)
- job.listener.jobFailed(exception)
- case oom: OutOfMemoryError =>
- val exception = new SparkException("Local job aborted due to out of memory error", oom)
- jobResult = JobFailed(exception)
- job.listener.jobFailed(exception)
- } finally {
- val s = job.finalStage
- // clean up data structures that were populated for a local job,
- // but that won't get cleaned up via the normal paths through
- // completion events or stage abort
- stageIdToStage -= s.id
- jobIdToStageIds -= job.jobId
- listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), jobResult))
- }
- }
-
/** Finds the earliest-created active job that needs the stage */
// TODO: Probably should actually find among the active jobs that need this
// stage the one with the highest priority (highest-priority pool, earliest created).
@@ -779,7 +708,6 @@ class DAGScheduler(
finalRDD: RDD[_],
func: (TaskContext, Iterator[_]) => _,
partitions: Array[Int],
- allowLocal: Boolean,
callSite: CallSite,
listener: JobListener,
properties: Properties) {
@@ -797,29 +725,20 @@ class DAGScheduler(
if (finalStage != null) {
val job = new ActiveJob(jobId, finalStage, func, partitions, callSite, listener, properties)
clearCacheLocs()
- logInfo("Got job %s (%s) with %d output partitions (allowLocal=%s)".format(
- job.jobId, callSite.shortForm, partitions.length, allowLocal))
+ logInfo("Got job %s (%s) with %d output partitions".format(
+ job.jobId, callSite.shortForm, partitions.length))
logInfo("Final stage: " + finalStage + "(" + finalStage.name + ")")
logInfo("Parents of final stage: " + finalStage.parents)
logInfo("Missing parents: " + getMissingParentStages(finalStage))
- val shouldRunLocally =
- localExecutionEnabled && allowLocal && finalStage.parents.isEmpty && partitions.length == 1
val jobSubmissionTime = clock.getTimeMillis()
- if (shouldRunLocally) {
- // Compute very short actions like first() or take() with no parent stages locally.
- listenerBus.post(
- SparkListenerJobStart(job.jobId, jobSubmissionTime, Seq.empty, properties))
- runLocally(job)
- } else {
- jobIdToActiveJob(jobId) = job
- activeJobs += job
- finalStage.resultOfJob = Some(job)
- val stageIds = jobIdToStageIds(jobId).toArray
- val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo))
- listenerBus.post(
- SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, properties))
- submitStage(finalStage)
- }
+ jobIdToActiveJob(jobId) = job
+ activeJobs += job
+ finalStage.resultOfJob = Some(job)
+ val stageIds = jobIdToStageIds(jobId).toArray
+ val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo))
+ listenerBus.post(
+ SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, properties))
+ submitStage(finalStage)
}
submitWaitingStages()
}
@@ -853,7 +772,6 @@ class DAGScheduler(
// Get our pending tasks and remember them in our pendingTasks entry
stage.pendingTasks.clear()
-
// First figure out the indexes of partition ids to compute.
val partitionsToCompute: Seq[Int] = {
stage match {
@@ -914,7 +832,7 @@ class DAGScheduler(
partitionsToCompute.map { id =>
val locs = getPreferredLocs(stage.rdd, id)
val part = stage.rdd.partitions(id)
- new ShuffleMapTask(stage.id, taskBinary, part, locs)
+ new ShuffleMapTask(stage.id, stage.latestInfo.attemptId, taskBinary, part, locs)
}
case stage: ResultStage =>
@@ -923,7 +841,7 @@ class DAGScheduler(
val p: Int = job.partitions(id)
val part = stage.rdd.partitions(p)
val locs = getPreferredLocs(stage.rdd, p)
- new ResultTask(stage.id, taskBinary, part, locs, id)
+ new ResultTask(stage.id, stage.latestInfo.attemptId, taskBinary, part, locs, id)
}
}
} catch {
@@ -1065,10 +983,11 @@ class DAGScheduler(
val execId = status.location.executorId
logDebug("ShuffleMapTask finished on " + execId)
if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) {
- logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + execId)
+ logInfo(s"Ignoring possibly bogus $smt completion from executor $execId")
} else {
shuffleStage.addOutputLoc(smt.partitionId, status)
}
+
if (runningStages.contains(shuffleStage) && shuffleStage.pendingTasks.isEmpty) {
markStageAsFinished(shuffleStage)
logInfo("looking for newly runnable stages")
@@ -1128,38 +1047,48 @@ class DAGScheduler(
val failedStage = stageIdToStage(task.stageId)
val mapStage = shuffleToMapStage(shuffleId)
- // It is likely that we receive multiple FetchFailed for a single stage (because we have
- // multiple tasks running concurrently on different executors). In that case, it is possible
- // the fetch failure has already been handled by the scheduler.
- if (runningStages.contains(failedStage)) {
- logInfo(s"Marking $failedStage (${failedStage.name}) as failed " +
- s"due to a fetch failure from $mapStage (${mapStage.name})")
- markStageAsFinished(failedStage, Some(failureMessage))
- }
+ if (failedStage.latestInfo.attemptId != task.stageAttemptId) {
+ logInfo(s"Ignoring fetch failure from $task as it's from $failedStage attempt" +
+ s" ${task.stageAttemptId} and there is a more recent attempt for that stage " +
+ s"(attempt ID ${failedStage.latestInfo.attemptId}) running")
+ } else {
- if (disallowStageRetryForTest) {
- abortStage(failedStage, "Fetch failure will not retry stage due to testing config")
- } else if (failedStages.isEmpty) {
- // Don't schedule an event to resubmit failed stages if failed isn't empty, because
- // in that case the event will already have been scheduled.
- // TODO: Cancel running tasks in the stage
- logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " +
- s"$failedStage (${failedStage.name}) due to fetch failure")
- messageScheduler.schedule(new Runnable {
- override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages)
- }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS)
- }
- failedStages += failedStage
- failedStages += mapStage
- // Mark the map whose fetch failed as broken in the map stage
- if (mapId != -1) {
- mapStage.removeOutputLoc(mapId, bmAddress)
- mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
- }
+ // It is likely that we receive multiple FetchFailed for a single stage (because we have
+ // multiple tasks running concurrently on different executors). In that case, it is
+ // possible the fetch failure has already been handled by the scheduler.
+ if (runningStages.contains(failedStage)) {
+ logInfo(s"Marking $failedStage (${failedStage.name}) as failed " +
+ s"due to a fetch failure from $mapStage (${mapStage.name})")
+ markStageAsFinished(failedStage, Some(failureMessage))
+ } else {
+ logDebug(s"Received fetch failure from $task, but its from $failedStage which is no " +
+ s"longer running")
+ }
- // TODO: mark the executor as failed only if there were lots of fetch failures on it
- if (bmAddress != null) {
- handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch))
+ if (disallowStageRetryForTest) {
+ abortStage(failedStage, "Fetch failure will not retry stage due to testing config")
+ } else if (failedStages.isEmpty) {
+ // Don't schedule an event to resubmit failed stages if failed isn't empty, because
+ // in that case the event will already have been scheduled.
+ // TODO: Cancel running tasks in the stage
+ logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " +
+ s"$failedStage (${failedStage.name}) due to fetch failure")
+ messageScheduler.schedule(new Runnable {
+ override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages)
+ }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS)
+ }
+ failedStages += failedStage
+ failedStages += mapStage
+ // Mark the map whose fetch failed as broken in the map stage
+ if (mapId != -1) {
+ mapStage.removeOutputLoc(mapId, bmAddress)
+ mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
+ }
+
+ // TODO: mark the executor as failed only if there were lots of fetch failures on it
+ if (bmAddress != null) {
+ handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch))
+ }
}
case commitDenied: TaskCommitDenied =>
@@ -1471,9 +1400,8 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler
}
private def doOnReceive(event: DAGSchedulerEvent): Unit = event match {
- case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) =>
- dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite,
- listener, properties)
+ case JobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties) =>
+ dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties)
case StageCancelled(stageId) =>
dagScheduler.handleStageCancellation(stageId)
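Among the DAGScheduler changes above, SPARK-8644 appends the calling thread's stack trace to a failed job's exception before rethrowing it, so users see where in their own code the action was invoked rather than only scheduler internals. The technique in isolation (a sketch, not the scheduler code):

object UserStackTrace {
  def rethrowWithCallerTrace(exception: Exception): Nothing = {
    // Drop the first frame (getStackTrace itself) and append the rest, so the
    // final trace shows both the original failure and the caller's call site.
    val callerStackTrace = Thread.currentThread().getStackTrace.tail
    exception.setStackTrace(exception.getStackTrace ++ callerStackTrace)
    throw exception
  }

  def main(args: Array[String]): Unit = {
    try {
      rethrowWithCallerTrace(new RuntimeException("job aborted"))
    } catch {
      case e: RuntimeException => e.printStackTrace()
    }
  }
}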
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
index 2b6f7e4205c32..a213d419cf033 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
@@ -19,7 +19,7 @@ package org.apache.spark.scheduler
import java.util.Properties
-import scala.collection.mutable.Map
+import scala.collection.Map
import scala.language.existentials
import org.apache.spark._
@@ -40,7 +40,6 @@ private[scheduler] case class JobSubmitted(
finalRDD: RDD[_],
func: (TaskContext, Iterator[_]) => _,
partitions: Array[Int],
- allowLocal: Boolean,
callSite: CallSite,
listener: JobListener,
properties: Properties = null)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala
index 62b05033a9281..5a06ef02f5c57 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala
@@ -199,6 +199,9 @@ private[spark] class EventLoggingListener(
logEvent(event, flushLogger = true)
}
+ // No-op because logging every update would be overkill
+ override def onBlockUpdated(event: SparkListenerBlockUpdated): Unit = {}
+
// No-op because logging every update would be overkill
override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate): Unit = { }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
index c9a124113961f..9c2606e278c54 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -41,11 +41,12 @@ import org.apache.spark.rdd.RDD
*/
private[spark] class ResultTask[T, U](
stageId: Int,
+ stageAttemptId: Int,
taskBinary: Broadcast[Array[Byte]],
partition: Partition,
@transient locs: Seq[TaskLocation],
val outputId: Int)
- extends Task[U](stageId, partition.index) with Serializable {
+ extends Task[U](stageId, stageAttemptId, partition.index) with Serializable {
@transient private[this] val preferredLocs: Seq[TaskLocation] = {
if (locs == null) Nil else locs.toSet.toSeq
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index bd3dd23dfe1ac..14c8c00961487 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -40,14 +40,15 @@ import org.apache.spark.shuffle.ShuffleWriter
*/
private[spark] class ShuffleMapTask(
stageId: Int,
+ stageAttemptId: Int,
taskBinary: Broadcast[Array[Byte]],
partition: Partition,
@transient private var locs: Seq[TaskLocation])
- extends Task[MapStatus](stageId, partition.index) with Logging {
+ extends Task[MapStatus](stageId, stageAttemptId, partition.index) with Logging {
/** A constructor used only in test suites. This does not require passing in an RDD. */
def this(partitionId: Int) {
- this(0, null, new Partition { override def index: Int = 0 }, null)
+ this(0, 0, null, new Partition { override def index: Int = 0 }, null)
}
@transient private val preferredLocs: Seq[TaskLocation] = {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
index 9620915f495ab..896f1743332f1 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
@@ -26,7 +26,7 @@ import org.apache.spark.{Logging, TaskEndReason}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.scheduler.cluster.ExecutorInfo
-import org.apache.spark.storage.BlockManagerId
+import org.apache.spark.storage.{BlockManagerId, BlockUpdatedInfo}
import org.apache.spark.util.{Distribution, Utils}
@DeveloperApi
@@ -98,6 +98,9 @@ case class SparkListenerExecutorAdded(time: Long, executorId: String, executorIn
case class SparkListenerExecutorRemoved(time: Long, executorId: String, reason: String)
extends SparkListenerEvent
+@DeveloperApi
+case class SparkListenerBlockUpdated(blockUpdatedInfo: BlockUpdatedInfo) extends SparkListenerEvent
+
/**
* Periodic updates from executors.
* @param execId executor id
@@ -215,6 +218,11 @@ trait SparkListener {
* Called when the driver removes an executor.
*/
def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved) { }
+
+ /**
+ * Called when the driver receives a block update info.
+ */
+ def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated) { }
}
/**
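SparkListenerBlockUpdated above is a new event, and SparkListener gains an onBlockUpdated hook with a no-op default, so existing listeners keep compiling while new ones can opt in. A hedged sketch of a listener overriding the hook (registration is shown only as a comment, since no SparkContext is created here):

import org.apache.spark.scheduler.{SparkListener, SparkListenerBlockUpdated}

// Opts in to the new block-update callback; every other SparkListener callback
// keeps its default no-op implementation.
class BlockUpdateLogger extends SparkListener {
  override def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated): Unit = {
    println(s"Block updated: ${blockUpdated.blockUpdatedInfo}")
  }
}

// Registration (sc being an existing SparkContext): sc.addSparkListener(new BlockUpdateLogger())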
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
index 61e69ecc08387..04afde33f5aad 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
@@ -58,6 +58,8 @@ private[spark] trait SparkListenerBus extends ListenerBus[SparkListener, SparkLi
listener.onExecutorAdded(executorAdded)
case executorRemoved: SparkListenerExecutorRemoved =>
listener.onExecutorRemoved(executorRemoved)
+ case blockUpdated: SparkListenerBlockUpdated =>
+ listener.onBlockUpdated(blockUpdated)
case logStart: SparkListenerLogStart => // ignore event log metadata
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 15101c64f0503..d11a00956a9a9 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -22,6 +22,7 @@ import java.nio.ByteBuffer
import scala.collection.mutable.HashMap
+import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.{TaskContextImpl, TaskContext}
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.serializer.SerializerInstance
@@ -43,31 +44,46 @@ import org.apache.spark.util.Utils
* @param stageId id of the stage this task belongs to
* @param partitionId index of the partition in the RDD
*/
-private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {
+private[spark] abstract class Task[T](
+ val stageId: Int,
+ val stageAttemptId: Int,
+ var partitionId: Int) extends Serializable {
+
+ /**
+ * The key of the Map is the accumulator id and the value of the Map is the latest accumulator
+ * local value.
+ */
+ type AccumulatorUpdates = Map[Long, Any]
/**
* Called by [[Executor]] to run this task.
*
* @param taskAttemptId an identifier for this task attempt that is unique within a SparkContext.
* @param attemptNumber how many times this task has been attempted (0 for the first attempt)
- * @return the result of the task
+ * @return the result of the task along with updates of Accumulators.
*/
- final def run(taskAttemptId: Long, attemptNumber: Int): T = {
+ final def run(
+ taskAttemptId: Long,
+ attemptNumber: Int,
+ metricsSystem: MetricsSystem)
+ : (T, AccumulatorUpdates) = {
context = new TaskContextImpl(
stageId = stageId,
partitionId = partitionId,
taskAttemptId = taskAttemptId,
attemptNumber = attemptNumber,
taskMemoryManager = taskMemoryManager,
+ metricsSystem = metricsSystem,
runningLocally = false)
TaskContext.setTaskContext(context)
context.taskMetrics.setHostname(Utils.localHostName())
+ context.taskMetrics.setAccumulatorsUpdater(context.collectInternalAccumulators)
taskThread = Thread.currentThread()
if (_killed) {
kill(interruptThread = false)
}
try {
- runTask(context)
+ (runTask(context), context.collectAccumulators())
} finally {
context.markTaskCompleted()
TaskContext.unset()
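Task.run now pairs the task result with a map of accumulator updates, so the Executor can ship the updates back in the DirectTaskResult without reaching into a global registry. A reduced sketch of the shape of that return value, with the accumulator bookkeeping simplified to a plain map and a toy task body:

object TaskRunSketch {
  type AccumulatorUpdates = Map[Long, Any]

  // A toy "task": runs the body and pairs the result with whatever accumulator
  // values were collected while it ran.
  def run[T](body: () => T, collectAccumulators: () => AccumulatorUpdates)
      : (T, AccumulatorUpdates) = {
    try {
      (body(), collectAccumulators())
    } finally {
      // per-task cleanup would go here (markTaskCompleted / TaskContext.unset in the real code)
    }
  }

  def main(args: Array[String]): Unit = {
    var recordsRead = 0L
    val (value, updates) = run(
      () => { recordsRead += 3; "done" },
      () => Map(1L -> recordsRead)) // 1L stands in for an accumulator id
    println(s"value=$value, accumulator updates=$updates")
  }
}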
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
index 8b2a742b96988..b82c7f3fa54f8 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
@@ -20,7 +20,8 @@ package org.apache.spark.scheduler
import java.io._
import java.nio.ByteBuffer
-import scala.collection.mutable.Map
+import scala.collection.Map
+import scala.collection.mutable
import org.apache.spark.SparkEnv
import org.apache.spark.executor.TaskMetrics
@@ -69,10 +70,11 @@ class DirectTaskResult[T](var valueBytes: ByteBuffer, var accumUpdates: Map[Long
if (numUpdates == 0) {
accumUpdates = null
} else {
- accumUpdates = Map()
+ val _accumUpdates = mutable.Map[Long, Any]()
for (i <- 0 until numUpdates) {
- accumUpdates(in.readLong()) = in.readObject()
+ _accumUpdates(in.readLong()) = in.readObject()
}
+ accumUpdates = _accumUpdates
}
metrics = in.readObject().asInstanceOf[TaskMetrics]
valueObjectDeserialized = false
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index ed3dde0fc3055..1705e7f962de2 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -75,9 +75,9 @@ private[spark] class TaskSchedulerImpl(
// TaskSetManagers are not thread safe, so any access to one should be synchronized
// on this class.
- val activeTaskSets = new HashMap[String, TaskSetManager]
+ private val taskSetsByStageIdAndAttempt = new HashMap[Int, HashMap[Int, TaskSetManager]]
- val taskIdToTaskSetId = new HashMap[Long, String]
+ private[scheduler] val taskIdToTaskSetManager = new HashMap[Long, TaskSetManager]
val taskIdToExecutorId = new HashMap[Long, String]
@volatile private var hasReceivedTask = false
@@ -162,7 +162,17 @@ private[spark] class TaskSchedulerImpl(
logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
this.synchronized {
val manager = createTaskSetManager(taskSet, maxTaskFailures)
- activeTaskSets(taskSet.id) = manager
+ val stage = taskSet.stageId
+ val stageTaskSets =
+ taskSetsByStageIdAndAttempt.getOrElseUpdate(stage, new HashMap[Int, TaskSetManager])
+ stageTaskSets(taskSet.stageAttemptId) = manager
+ val conflictingTaskSet = stageTaskSets.exists { case (_, ts) =>
+ ts.taskSet != taskSet && !ts.isZombie
+ }
+ if (conflictingTaskSet) {
+ throw new IllegalStateException(s"more than one active taskSet for stage $stage:" +
+ s" ${stageTaskSets.toSeq.map{_._2.taskSet.id}.mkString(",")}")
+ }
schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
if (!isLocal && !hasReceivedTask) {
@@ -192,19 +202,21 @@ private[spark] class TaskSchedulerImpl(
override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized {
logInfo("Cancelling stage " + stageId)
- activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) =>
- // There are two possible cases here:
- // 1. The task set manager has been created and some tasks have been scheduled.
- // In this case, send a kill signal to the executors to kill the task and then abort
- // the stage.
- // 2. The task set manager has been created but no tasks has been scheduled. In this case,
- // simply abort the stage.
- tsm.runningTasksSet.foreach { tid =>
- val execId = taskIdToExecutorId(tid)
- backend.killTask(tid, execId, interruptThread)
+ taskSetsByStageIdAndAttempt.get(stageId).foreach { attempts =>
+ attempts.foreach { case (_, tsm) =>
+ // There are two possible cases here:
+ // 1. The task set manager has been created and some tasks have been scheduled.
+ // In this case, send a kill signal to the executors to kill the task and then abort
+ // the stage.
+ // 2. The task set manager has been created but no tasks have been scheduled. In this case,
+ // simply abort the stage.
+ tsm.runningTasksSet.foreach { tid =>
+ val execId = taskIdToExecutorId(tid)
+ backend.killTask(tid, execId, interruptThread)
+ }
+ tsm.abort("Stage %s cancelled".format(stageId))
+ logInfo("Stage %d was cancelled".format(stageId))
}
- tsm.abort("Stage %s cancelled".format(stageId))
- logInfo("Stage %d was cancelled".format(stageId))
}
}
@@ -214,7 +226,12 @@ private[spark] class TaskSchedulerImpl(
* cleaned up.
*/
def taskSetFinished(manager: TaskSetManager): Unit = synchronized {
- activeTaskSets -= manager.taskSet.id
+ taskSetsByStageIdAndAttempt.get(manager.taskSet.stageId).foreach { taskSetsForStage =>
+ taskSetsForStage -= manager.taskSet.stageAttemptId
+ if (taskSetsForStage.isEmpty) {
+ taskSetsByStageIdAndAttempt -= manager.taskSet.stageId
+ }
+ }
manager.parent.removeSchedulable(manager)
logInfo("Removed TaskSet %s, whose tasks have all completed, from pool %s"
.format(manager.taskSet.id, manager.parent.name))
@@ -235,7 +252,7 @@ private[spark] class TaskSchedulerImpl(
for (task <- taskSet.resourceOffer(execId, host, maxLocality)) {
tasks(i) += task
val tid = task.taskId
- taskIdToTaskSetId(tid) = taskSet.taskSet.id
+ taskIdToTaskSetManager(tid) = taskSet
taskIdToExecutorId(tid) = execId
executorsByHost(host) += execId
availableCpus(i) -= CPUS_PER_TASK
@@ -319,26 +336,24 @@ private[spark] class TaskSchedulerImpl(
failedExecutor = Some(execId)
}
}
- taskIdToTaskSetId.get(tid) match {
- case Some(taskSetId) =>
+ taskIdToTaskSetManager.get(tid) match {
+ case Some(taskSet) =>
if (TaskState.isFinished(state)) {
- taskIdToTaskSetId.remove(tid)
+ taskIdToTaskSetManager.remove(tid)
taskIdToExecutorId.remove(tid)
}
- activeTaskSets.get(taskSetId).foreach { taskSet =>
- if (state == TaskState.FINISHED) {
- taskSet.removeRunningTask(tid)
- taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
- } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) {
- taskSet.removeRunningTask(tid)
- taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData)
- }
+ if (state == TaskState.FINISHED) {
+ taskSet.removeRunningTask(tid)
+ taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
+ } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) {
+ taskSet.removeRunningTask(tid)
+ taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData)
}
case None =>
logError(
("Ignoring update with state %s for TID %s because its task set is gone (this is " +
- "likely the result of receiving duplicate task finished status updates)")
- .format(state, tid))
+ "likely the result of receiving duplicate task finished status updates)")
+ .format(state, tid))
}
} catch {
case e: Exception => logError("Exception in statusUpdate", e)
@@ -363,9 +378,9 @@ private[spark] class TaskSchedulerImpl(
val metricsWithStageIds: Array[(Long, Int, Int, TaskMetrics)] = synchronized {
taskMetrics.flatMap { case (id, metrics) =>
- taskIdToTaskSetId.get(id)
- .flatMap(activeTaskSets.get)
- .map(taskSetMgr => (id, taskSetMgr.stageId, taskSetMgr.taskSet.attempt, metrics))
+ taskIdToTaskSetManager.get(id).map { taskSetMgr =>
+ (id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId, metrics)
+ }
}
}
dagScheduler.executorHeartbeatReceived(execId, metricsWithStageIds, blockManagerId)
@@ -397,9 +412,12 @@ private[spark] class TaskSchedulerImpl(
def error(message: String) {
synchronized {
- if (activeTaskSets.nonEmpty) {
+ if (taskSetsByStageIdAndAttempt.nonEmpty) {
// Have each task set throw a SparkException with the error
- for ((taskSetId, manager) <- activeTaskSets) {
+ for {
+ attempts <- taskSetsByStageIdAndAttempt.values
+ manager <- attempts.values
+ } {
try {
manager.abort(message)
} catch {
@@ -520,6 +538,17 @@ private[spark] class TaskSchedulerImpl(
override def applicationAttemptId(): Option[String] = backend.applicationAttemptId()
+ private[scheduler] def taskSetManagerForAttempt(
+ stageId: Int,
+ stageAttemptId: Int): Option[TaskSetManager] = {
+ for {
+ attempts <- taskSetsByStageIdAndAttempt.get(stageId)
+ manager <- attempts.get(stageAttemptId)
+ } yield {
+ manager
+ }
+ }
+
}
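TaskSchedulerImpl now keys task set managers by stage id and stage attempt id in a nested HashMap and rejects a second non-zombie task set for the same stage. The getOrElseUpdate idiom for the nested index, in isolation (String stands in for TaskSetManager):

import scala.collection.mutable.HashMap

object NestedTaskSetIndex {
  // stage id -> (stage attempt id -> manager)
  private val byStageIdAndAttempt = new HashMap[Int, HashMap[Int, String]]

  def register(stageId: Int, stageAttemptId: Int, manager: String): Unit = {
    val attempts = byStageIdAndAttempt.getOrElseUpdate(stageId, new HashMap[Int, String])
    attempts(stageAttemptId) = manager
  }

  def managerFor(stageId: Int, stageAttemptId: Int): Option[String] =
    for {
      attempts <- byStageIdAndAttempt.get(stageId)
      manager <- attempts.get(stageAttemptId)
    } yield manager

  def main(args: Array[String]): Unit = {
    register(3, 0, "tsm-3.0")
    register(3, 1, "tsm-3.1")
    println(managerFor(3, 1)) // Some(tsm-3.1)
    println(managerFor(7, 0)) // None
  }
}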
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala
index c3ad325156f53..be8526ba9b94f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala
@@ -26,10 +26,10 @@ import java.util.Properties
private[spark] class TaskSet(
val tasks: Array[Task[_]],
val stageId: Int,
- val attempt: Int,
+ val stageAttemptId: Int,
val priority: Int,
val properties: Properties) {
- val id: String = stageId + "." + attempt
+ val id: String = stageId + "." + stageAttemptId
override def toString: String = "TaskSet " + id
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index 7c7f70d8a193b..c65b3e517773e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -169,9 +169,12 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
// Make fake resource offers on all executors
private def makeOffers() {
- launchTasks(scheduler.resourceOffers(executorDataMap.map { case (id, executorData) =>
+ // Filter out executors under killing
+ val activeExecutors = executorDataMap.filterKeys(!executorsPendingToRemove.contains(_))
+ val workOffers = activeExecutors.map { case (id, executorData) =>
new WorkerOffer(id, executorData.executorHost, executorData.freeCores)
- }.toSeq))
+ }.toSeq
+ launchTasks(scheduler.resourceOffers(workOffers))
}
override def onDisconnected(remoteAddress: RpcAddress): Unit = {
@@ -181,9 +184,13 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
// Make fake resource offers on just one executor
private def makeOffers(executorId: String) {
- val executorData = executorDataMap(executorId)
- launchTasks(scheduler.resourceOffers(
- Seq(new WorkerOffer(executorId, executorData.executorHost, executorData.freeCores))))
+ // Filter out executors under killing
+ if (!executorsPendingToRemove.contains(executorId)) {
+ val executorData = executorDataMap(executorId)
+ val workOffers = Seq(
+ new WorkerOffer(executorId, executorData.executorHost, executorData.freeCores))
+ launchTasks(scheduler.resourceOffers(workOffers))
+ }
}
// Launch tasks returned by a set of resource offers
@@ -191,15 +198,14 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
for (task <- tasks.flatten) {
val serializedTask = ser.serialize(task)
if (serializedTask.limit >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
- val taskSetId = scheduler.taskIdToTaskSetId(task.taskId)
- scheduler.activeTaskSets.get(taskSetId).foreach { taskSet =>
+ scheduler.taskIdToTaskSetManager.get(task.taskId).foreach { taskSetMgr =>
try {
var msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " +
"spark.akka.frameSize (%d bytes) - reserved (%d bytes). Consider increasing " +
"spark.akka.frameSize or using broadcast variables for large values."
msg = msg.format(task.taskId, task.index, serializedTask.limit, akkaFrameSize,
AkkaUtils.reservedSizeBytes)
- taskSet.abort(msg)
+ taskSetMgr.abort(msg)
} catch {
case e: Exception => logError("Exception in error callback", e)
}
@@ -371,26 +377,36 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
/**
* Request that the cluster manager kill the specified executors.
- * Return whether the kill request is acknowledged.
+ * @return whether the kill request is acknowledged.
*/
final override def killExecutors(executorIds: Seq[String]): Boolean = synchronized {
+ killExecutors(executorIds, replace = false)
+ }
+
+ /**
+ * Request that the cluster manager kill the specified executors.
+ *
+ * @param executorIds identifiers of executors to kill
+ * @param replace whether to replace the killed executors with new ones
+ * @return whether the kill request is acknowledged.
+ */
+ final def killExecutors(executorIds: Seq[String], replace: Boolean): Boolean = synchronized {
logInfo(s"Requesting to kill executor(s) ${executorIds.mkString(", ")}")
- val filteredExecutorIds = new ArrayBuffer[String]
- executorIds.foreach { id =>
- if (executorDataMap.contains(id)) {
- filteredExecutorIds += id
- } else {
- logWarning(s"Executor to kill $id does not exist!")
- }
+ val (knownExecutors, unknownExecutors) = executorIds.partition(executorDataMap.contains)
+ unknownExecutors.foreach { id =>
+ logWarning(s"Executor to kill $id does not exist!")
+ }
+
+ // If we do not wish to replace the executors we kill, sync the target number of executors
+ // with the cluster manager to avoid allocating new ones. When computing the new target,
+ // take into account executors that are pending to be added or removed.
+ if (!replace) {
+ doRequestTotalExecutors(numExistingExecutors + numPendingExecutors
+ - executorsPendingToRemove.size - knownExecutors.size)
}
- // Killing executors means effectively that we want less executors than before, so also update
- // the target number of executors to avoid having the backend allocate new ones.
- val newTotal = (numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size
- - filteredExecutorIds.size)
- doRequestTotalExecutors(newTotal)
- executorsPendingToRemove ++= filteredExecutorIds
- doKillExecutors(filteredExecutorIds)
+ executorsPendingToRemove ++= knownExecutors
+ doKillExecutors(knownExecutors)
}
/**
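For illustration, a minimal REPL-style sketch of the bookkeeping the new killExecutors(executorIds, replace) performs. The maps and counters below are simplified stand-ins, not Spark's internal types: the requested ids are split into known and unknown executors, the executor target is lowered only when the kills are not being replaced, and the known ids are recorded as pending removal.

import scala.collection.mutable

val executorDataMap = mutable.HashMap("exec-1" -> 4, "exec-2" -> 4)  // executorId -> free cores
val executorsPendingToRemove = mutable.HashSet.empty[String]
var numPendingExecutors = 0
var requestedTotal = executorDataMap.size

def killExecutors(executorIds: Seq[String], replace: Boolean): Unit = {
  val (knownExecutors, unknownExecutors) = executorIds.partition(executorDataMap.contains)
  unknownExecutors.foreach(id => println(s"Executor to kill $id does not exist!"))
  if (!replace) {
    // Without replacement, shrink the target so the cluster manager does not backfill the kills.
    requestedTotal = executorDataMap.size + numPendingExecutors -
      executorsPendingToRemove.size - knownExecutors.size
  }
  executorsPendingToRemove ++= knownExecutors
  println(s"killing: ${knownExecutors.mkString(", ")}; new target: $requestedTotal")
}

killExecutors(Seq("exec-1", "exec-9"), replace = false)  // -> killing: exec-1; new target: 1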
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
index cbade131494bc..b7fde0d9b3265 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
@@ -18,8 +18,8 @@
package org.apache.spark.scheduler.cluster.mesos
import java.io.File
-import java.util.{List => JList, Collections}
import java.util.concurrent.locks.ReentrantLock
+import java.util.{Collections, List => JList}
import scala.collection.JavaConversions._
import scala.collection.mutable.{HashMap, HashSet}
@@ -27,12 +27,11 @@ import scala.collection.mutable.{HashMap, HashSet}
import com.google.common.collect.HashBiMap
import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _}
import org.apache.mesos.{Scheduler => MScheduler, _}
-import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _}
-import org.apache.spark.{SparkContext, SparkEnv, SparkException, TaskState}
import org.apache.spark.rpc.RpcAddress
import org.apache.spark.scheduler.TaskSchedulerImpl
import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
import org.apache.spark.util.Utils
+import org.apache.spark.{SparkContext, SparkEnv, SparkException, TaskState}
/**
* A SchedulerBackend that runs tasks on Mesos, but uses "coarse-grained" tasks, where it holds
@@ -69,7 +68,7 @@ private[spark] class CoarseMesosSchedulerBackend(
/**
* The total number of executors we aim to have. Undefined when not using dynamic allocation
- * and before the ExecutorAllocatorManager calls [[doRequesTotalExecutors]].
+ * and before the ExecutorAllocatorManager calls [[doRequestTotalExecutors]].
*/
private var executorLimitOption: Option[Int] = None
@@ -103,8 +102,9 @@ private[spark] class CoarseMesosSchedulerBackend(
override def start() {
super.start()
- val fwInfo = FrameworkInfo.newBuilder().setUser(sc.sparkUser).setName(sc.appName).build()
- startScheduler(master, CoarseMesosSchedulerBackend.this, fwInfo)
+ val driver = createSchedulerDriver(
+ master, CoarseMesosSchedulerBackend.this, sc.sparkUser, sc.appName, sc.conf)
+ startScheduler(driver)
}
def createCommand(offer: Offer, numCores: Int, taskId: Int): CommandInfo = {
@@ -224,24 +224,29 @@ private[spark] class CoarseMesosSchedulerBackend(
taskIdToSlaveId(taskId) = slaveId
slaveIdsWithExecutors += slaveId
coresByTaskId(taskId) = cpusToUse
- val task = MesosTaskInfo.newBuilder()
+ // Gather cpu resources from the available resources and use them in the task.
+ val (remainingResources, cpuResourcesToUse) =
+ partitionResources(offer.getResourcesList, "cpus", cpusToUse)
+ val (_, memResourcesToUse) =
+ partitionResources(remainingResources, "mem", calculateTotalMemory(sc))
+ val taskBuilder = MesosTaskInfo.newBuilder()
.setTaskId(TaskID.newBuilder().setValue(taskId.toString).build())
.setSlaveId(offer.getSlaveId)
.setCommand(createCommand(offer, cpusToUse + extraCoresPerSlave, taskId))
.setName("Task " + taskId)
- .addResources(createResource("cpus", cpusToUse))
- .addResources(createResource("mem", calculateTotalMemory(sc)))
+ .addAllResources(cpuResourcesToUse)
+ .addAllResources(memResourcesToUse)
sc.conf.getOption("spark.mesos.executor.docker.image").foreach { image =>
MesosSchedulerBackendUtil
- .setupContainerBuilderDockerInfo(image, sc.conf, task.getContainerBuilder)
+ .setupContainerBuilderDockerInfo(image, sc.conf, taskBuilder.getContainerBuilder())
}
// accept the offer and launch the task
logDebug(s"Accepting offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus")
d.launchTasks(
Collections.singleton(offer.getId),
- Collections.singleton(task.build()), filters)
+ Collections.singleton(taskBuilder.build()), filters)
} else {
// Decline the offer
logDebug(s"Declining offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus")
@@ -255,7 +260,7 @@ private[spark] class CoarseMesosSchedulerBackend(
override def statusUpdate(d: SchedulerDriver, status: TaskStatus) {
val taskId = status.getTaskId.getValue.toInt
val state = status.getState
- logInfo("Mesos task " + taskId + " is now " + state)
+ logInfo(s"Mesos task $taskId is now $state")
stateLock.synchronized {
if (TaskState.isFinished(TaskState.fromMesos(state))) {
val slaveId = taskIdToSlaveId(taskId)
@@ -270,7 +275,7 @@ private[spark] class CoarseMesosSchedulerBackend(
if (TaskState.isFailed(TaskState.fromMesos(state))) {
failuresBySlaveId(slaveId) = failuresBySlaveId.getOrElse(slaveId, 0) + 1
if (failuresBySlaveId(slaveId) >= MAX_SLAVE_FAILURES) {
- logInfo("Blacklisting Mesos slave " + slaveId + " due to too many failures; " +
+ logInfo(s"Blacklisting Mesos slave $slaveId due to too many failures; " +
"is Spark installed on it?")
}
}
@@ -282,7 +287,7 @@ private[spark] class CoarseMesosSchedulerBackend(
}
override def error(d: SchedulerDriver, message: String) {
- logError("Mesos error: " + message)
+ logError(s"Mesos error: $message")
scheduler.error(message)
}
@@ -323,7 +328,7 @@ private[spark] class CoarseMesosSchedulerBackend(
}
override def slaveLost(d: SchedulerDriver, slaveId: SlaveID): Unit = {
- logInfo("Mesos slave lost: " + slaveId.getValue)
+ logInfo(s"Mesos slave lost: ${slaveId.getValue}")
executorTerminated(d, slaveId.getValue, "Mesos slave lost: " + slaveId.getValue)
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala
index d3a20f822176e..f078547e71352 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala
@@ -295,20 +295,24 @@ private[spark] class MesosClusterScheduler(
def start(): Unit = {
// TODO: Implement leader election to make sure only one framework running in the cluster.
val fwId = schedulerState.fetch[String]("frameworkId")
- val builder = FrameworkInfo.newBuilder()
- .setUser(Utils.getCurrentUserName())
- .setName(appName)
- .setWebuiUrl(frameworkUrl)
- .setCheckpoint(true)
- .setFailoverTimeout(Integer.MAX_VALUE) // Setting to max so tasks keep running on crash
fwId.foreach { id =>
- builder.setId(FrameworkID.newBuilder().setValue(id).build())
frameworkId = id
}
recoverState()
metricsSystem.registerSource(new MesosClusterSchedulerSource(this))
metricsSystem.start()
- startScheduler(master, MesosClusterScheduler.this, builder.build())
+ val driver = createSchedulerDriver(
+ master,
+ MesosClusterScheduler.this,
+ Utils.getCurrentUserName(),
+ appName,
+ conf,
+ Some(frameworkUrl),
+ Some(true),
+ Some(Integer.MAX_VALUE),
+ fwId)
+
+ startScheduler(driver)
ready = true
}
@@ -449,12 +453,8 @@ private[spark] class MesosClusterScheduler(
offer.cpu -= driverCpu
offer.mem -= driverMem
val taskId = TaskID.newBuilder().setValue(submission.submissionId).build()
- val cpuResource = Resource.newBuilder()
- .setName("cpus").setType(Value.Type.SCALAR)
- .setScalar(Value.Scalar.newBuilder().setValue(driverCpu)).build()
- val memResource = Resource.newBuilder()
- .setName("mem").setType(Value.Type.SCALAR)
- .setScalar(Value.Scalar.newBuilder().setValue(driverMem)).build()
+ val cpuResource = createResource("cpus", driverCpu)
+ val memResource = createResource("mem", driverMem)
val commandInfo = buildDriverCommand(submission)
val appName = submission.schedulerProperties("spark.app.name")
val taskInfo = TaskInfo.newBuilder()
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
index d72e2af456e15..3f63ec1c5832f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
@@ -32,6 +32,7 @@ import org.apache.spark.scheduler._
import org.apache.spark.scheduler.cluster.ExecutorInfo
import org.apache.spark.util.Utils
+
/**
* A SchedulerBackend for running fine-grained tasks on Mesos. Each Spark task is mapped to a
* separate Mesos task, allowing multiple applications to share cluster nodes both in space (tasks
@@ -45,8 +46,8 @@ private[spark] class MesosSchedulerBackend(
with MScheduler
with MesosSchedulerUtils {
- // Which slave IDs we have executors on
- val slaveIdsWithExecutors = new HashSet[String]
+ // Stores the slave ids that have launched a Mesos executor.
+ val slaveIdToExecutorInfo = new HashMap[String, MesosExecutorInfo]
val taskIdToSlaveId = new HashMap[Long, String]
// An ExecutorInfo for our tasks
@@ -66,12 +67,21 @@ private[spark] class MesosSchedulerBackend(
@volatile var appId: String = _
override def start() {
- val fwInfo = FrameworkInfo.newBuilder().setUser(sc.sparkUser).setName(sc.appName).build()
classLoader = Thread.currentThread.getContextClassLoader
- startScheduler(master, MesosSchedulerBackend.this, fwInfo)
+ val driver = createSchedulerDriver(
+ master, MesosSchedulerBackend.this, sc.sparkUser, sc.appName, sc.conf)
+ startScheduler(driver)
}
- def createExecutorInfo(execId: String): MesosExecutorInfo = {
+ /**
+ * Creates a MesosExecutorInfo that is used to launch a Mesos executor.
+ * @param availableResources Available resources that are offered by Mesos
+ * @param execId The executor id to assign to this new executor.
+ * @return A tuple of the new Mesos executor info and the remaining available resources.
+ */
+ def createExecutorInfo(
+ availableResources: JList[Resource],
+ execId: String): (MesosExecutorInfo, JList[Resource]) = {
val executorSparkHome = sc.conf.getOption("spark.mesos.executor.home")
.orElse(sc.getSparkHome()) // Fall back to driver Spark home for backward compatibility
.getOrElse {
@@ -115,32 +125,25 @@ private[spark] class MesosSchedulerBackend(
command.setValue(s"cd ${basename}*; $prefixEnv ./bin/spark-class $executorBackendName")
command.addUris(CommandInfo.URI.newBuilder().setValue(uri.get))
}
- val cpus = Resource.newBuilder()
- .setName("cpus")
- .setType(Value.Type.SCALAR)
- .setScalar(Value.Scalar.newBuilder()
- .setValue(mesosExecutorCores).build())
- .build()
- val memory = Resource.newBuilder()
- .setName("mem")
- .setType(Value.Type.SCALAR)
- .setScalar(
- Value.Scalar.newBuilder()
- .setValue(calculateTotalMemory(sc)).build())
- .build()
- val executorInfo = MesosExecutorInfo.newBuilder()
+ val builder = MesosExecutorInfo.newBuilder()
+ val (resourcesAfterCpu, usedCpuResources) =
+ partitionResources(availableResources, "cpus", scheduler.CPUS_PER_TASK)
+ val (resourcesAfterMem, usedMemResources) =
+ partitionResources(resourcesAfterCpu, "mem", calculateTotalMemory(sc))
+
+ builder.addAllResources(usedCpuResources)
+ builder.addAllResources(usedMemResources)
+ val executorInfo = builder
.setExecutorId(ExecutorID.newBuilder().setValue(execId).build())
.setCommand(command)
.setData(ByteString.copyFrom(createExecArg()))
- .addResources(cpus)
- .addResources(memory)
sc.conf.getOption("spark.mesos.executor.docker.image").foreach { image =>
MesosSchedulerBackendUtil
.setupContainerBuilderDockerInfo(image, sc.conf, executorInfo.getContainerBuilder())
}
- executorInfo.build()
+ (executorInfo.build(), resourcesAfterMem)
}
/**
@@ -183,6 +186,18 @@ private[spark] class MesosSchedulerBackend(
override def reregistered(d: SchedulerDriver, masterInfo: MasterInfo) {}
+ private def getTasksSummary(tasks: JArrayList[MesosTaskInfo]): String = {
+ val builder = new StringBuilder
+ tasks.foreach { t =>
+ builder.append("Task id: ").append(t.getTaskId.getValue).append("\n")
+ .append("Slave id: ").append(t.getSlaveId.getValue).append("\n")
+ .append("Task resources: ").append(t.getResourcesList).append("\n")
+ .append("Executor resources: ").append(t.getExecutor.getResourcesList)
+ .append("---------------------------------------------\n")
+ }
+ builder.toString()
+ }
+
/**
* Method called by Mesos to offer resources on slaves. We respond by asking our active task sets
* for tasks in order of priority. We fill each node with tasks in a round-robin manner so that
@@ -207,7 +222,7 @@ private[spark] class MesosSchedulerBackend(
val meetsRequirements =
(meetsConstraints && meetsMemoryRequirements && meetsCPURequirements) ||
- (slaveIdsWithExecutors.contains(slaveId) && cpus >= scheduler.CPUS_PER_TASK)
+ (slaveIdToExecutorInfo.contains(slaveId) && cpus >= scheduler.CPUS_PER_TASK)
// add some debug messaging
val debugstr = if (meetsRequirements) "Accepting" else "Declining"
@@ -221,7 +236,7 @@ private[spark] class MesosSchedulerBackend(
unUsableOffers.foreach(o => d.declineOffer(o.getId))
val workerOffers = usableOffers.map { o =>
- val cpus = if (slaveIdsWithExecutors.contains(o.getSlaveId.getValue)) {
+ val cpus = if (slaveIdToExecutorInfo.contains(o.getSlaveId.getValue)) {
getResource(o.getResourcesList, "cpus").toInt
} else {
// If the Mesos executor has not been started on this slave yet, set aside a few
@@ -236,6 +251,10 @@ private[spark] class MesosSchedulerBackend(
val slaveIdToOffer = usableOffers.map(o => o.getSlaveId.getValue -> o).toMap
val slaveIdToWorkerOffer = workerOffers.map(o => o.executorId -> o).toMap
+ val slaveIdToResources = new HashMap[String, JList[Resource]]()
+ usableOffers.foreach { o =>
+ slaveIdToResources(o.getSlaveId.getValue) = o.getResourcesList
+ }
val mesosTasks = new HashMap[String, JArrayList[MesosTaskInfo]]
@@ -245,15 +264,19 @@ private[spark] class MesosSchedulerBackend(
val acceptedOffers = scheduler.resourceOffers(workerOffers).filter(!_.isEmpty)
acceptedOffers
.foreach { offer =>
- offer.foreach { taskDesc =>
- val slaveId = taskDesc.executorId
- slaveIdsWithExecutors += slaveId
- slavesIdsOfAcceptedOffers += slaveId
- taskIdToSlaveId(taskDesc.taskId) = slaveId
- mesosTasks.getOrElseUpdate(slaveId, new JArrayList[MesosTaskInfo])
- .add(createMesosTask(taskDesc, slaveId))
+ offer.foreach { taskDesc =>
+ val slaveId = taskDesc.executorId
+ slavesIdsOfAcceptedOffers += slaveId
+ taskIdToSlaveId(taskDesc.taskId) = slaveId
+ val (mesosTask, remainingResources) = createMesosTask(
+ taskDesc,
+ slaveIdToResources(slaveId),
+ slaveId)
+ mesosTasks.getOrElseUpdate(slaveId, new JArrayList[MesosTaskInfo])
+ .add(mesosTask)
+ slaveIdToResources(slaveId) = remainingResources
+ }
}
- }
// Reply to the offers
val filters = Filters.newBuilder().setRefuseSeconds(1).build() // TODO: lower timeout?
@@ -264,6 +287,7 @@ private[spark] class MesosSchedulerBackend(
// TODO: Add support for log urls for Mesos
new ExecutorInfo(o.host, o.cores, Map.empty)))
)
+ logTrace(s"Launching Mesos tasks on slave '$slaveId', tasks:\n${getTasksSummary(tasks)}")
d.launchTasks(Collections.singleton(slaveIdToOffer(slaveId).getId), tasks, filters)
}
@@ -272,26 +296,32 @@ private[spark] class MesosSchedulerBackend(
for (o <- usableOffers if !slavesIdsOfAcceptedOffers.contains(o.getSlaveId.getValue)) {
d.declineOffer(o.getId)
}
-
}
}
- /** Turn a Spark TaskDescription into a Mesos task */
- def createMesosTask(task: TaskDescription, slaveId: String): MesosTaskInfo = {
+ /** Turn a Spark TaskDescription into a Mesos task, also returning the resources left unused by the task */
+ def createMesosTask(
+ task: TaskDescription,
+ resources: JList[Resource],
+ slaveId: String): (MesosTaskInfo, JList[Resource]) = {
val taskId = TaskID.newBuilder().setValue(task.taskId.toString).build()
- val cpuResource = Resource.newBuilder()
- .setName("cpus")
- .setType(Value.Type.SCALAR)
- .setScalar(Value.Scalar.newBuilder().setValue(scheduler.CPUS_PER_TASK).build())
- .build()
- MesosTaskInfo.newBuilder()
+ val (executorInfo, remainingResources) = if (slaveIdToExecutorInfo.contains(slaveId)) {
+ (slaveIdToExecutorInfo(slaveId), resources)
+ } else {
+ createExecutorInfo(resources, slaveId)
+ }
+ slaveIdToExecutorInfo(slaveId) = executorInfo
+ val (finalResources, cpuResources) =
+ partitionResources(remainingResources, "cpus", scheduler.CPUS_PER_TASK)
+ val taskInfo = MesosTaskInfo.newBuilder()
.setTaskId(taskId)
.setSlaveId(SlaveID.newBuilder().setValue(slaveId).build())
- .setExecutor(createExecutorInfo(slaveId))
+ .setExecutor(executorInfo)
.setName(task.name)
- .addResources(cpuResource)
+ .addAllResources(cpuResources)
.setData(MesosTaskLaunchData(task.serializedTask, task.attemptNumber).toByteString)
.build()
+ (taskInfo, finalResources)
}
override def statusUpdate(d: SchedulerDriver, status: TaskStatus) {
@@ -337,7 +367,7 @@ private[spark] class MesosSchedulerBackend(
private def removeExecutor(slaveId: String, reason: String) = {
synchronized {
listenerBus.post(SparkListenerExecutorRemoved(System.currentTimeMillis(), slaveId, reason))
- slaveIdsWithExecutors -= slaveId
+ slaveIdToExecutorInfo -= slaveId
}
}
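To see why createMesosTask now returns the unused resources, here is a REPL-style toy sketch (simplified types, not the Mesos protobufs) of the per-slave loop in resourceOffers: the executor info is created once per slave and cached in slaveIdToExecutorInfo, and the resources left over from each launched task are what the next task on that slave gets to consume.

import scala.collection.mutable

case class ExecInfo(slaveId: String, cpus: Double, mem: Double)   // stand-in for MesosExecutorInfo

val slaveIdToExecutorInfo = mutable.HashMap.empty[String, ExecInfo]
val slaveIdToCpus = mutable.HashMap("slave-1" -> 8.0)             // stand-in for slaveIdToResources
val CPUS_PER_TASK = 2.0

def launchTask(taskId: Long, slaveId: String): Unit = {
  // Reuse the cached executor info if this slave already hosts one; otherwise create it
  // (in the real backend this also carves cpu/mem for the executor out of the offer).
  val execInfo = slaveIdToExecutorInfo.getOrElseUpdate(slaveId, ExecInfo(slaveId, 1.0, 1024.0))
  val remaining = slaveIdToCpus(slaveId) - CPUS_PER_TASK
  slaveIdToCpus(slaveId) = remaining                              // thread leftovers to the next task
  println(s"task $taskId on ${execInfo.slaveId}: cpus left in offer = $remaining")
}

Seq(1L, 2L, 3L).foreach(launchTask(_, "slave-1"))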
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala
index 925702e63afd3..c04920e4f5873 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala
@@ -21,15 +21,17 @@ import java.util.{List => JList}
import java.util.concurrent.CountDownLatch
import scala.collection.JavaConversions._
+import scala.collection.mutable.ArrayBuffer
import scala.util.control.NonFatal
import com.google.common.base.Splitter
import org.apache.mesos.{MesosSchedulerDriver, SchedulerDriver, Scheduler, Protos}
import org.apache.mesos.Protos._
-import org.apache.mesos.protobuf.GeneratedMessage
-import org.apache.spark.{Logging, SparkContext}
+import org.apache.mesos.protobuf.{ByteString, GeneratedMessage}
+import org.apache.spark.{SparkException, SparkConf, Logging, SparkContext}
import org.apache.spark.util.Utils
+
/**
* Shared trait for implementing a Mesos Scheduler. This holds common state and helper
* methods and Mesos scheduler will use.
@@ -42,13 +44,63 @@ private[mesos] trait MesosSchedulerUtils extends Logging {
protected var mesosDriver: SchedulerDriver = null
/**
- * Starts the MesosSchedulerDriver with the provided information. This method returns
- * only after the scheduler has registered with Mesos.
- * @param masterUrl Mesos master connection URL
- * @param scheduler Scheduler object
- * @param fwInfo FrameworkInfo to pass to the Mesos master
+ * Creates a new MesosSchedulerDriver that communicates with the Mesos master.
+ * @param masterUrl The url to connect to the Mesos master
+ * @param scheduler the scheduler class to receive scheduler callbacks
+ * @param sparkUser User to impersonate when running tasks
+ * @param appName The framework name to display on the Mesos UI
+ * @param conf Spark configuration
+ * @param webuiUrl The WebUI url to link from Mesos UI
+ * @param checkpoint Option to checkpoint tasks for failover
+ * @param failoverTimeout Duration within which the Mesos master expects the scheduler to reconnect after a disconnect
+ * @param frameworkId The id of the new framework
*/
- def startScheduler(masterUrl: String, scheduler: Scheduler, fwInfo: FrameworkInfo): Unit = {
+ protected def createSchedulerDriver(
+ masterUrl: String,
+ scheduler: Scheduler,
+ sparkUser: String,
+ appName: String,
+ conf: SparkConf,
+ webuiUrl: Option[String] = None,
+ checkpoint: Option[Boolean] = None,
+ failoverTimeout: Option[Double] = None,
+ frameworkId: Option[String] = None): SchedulerDriver = {
+ val fwInfoBuilder = FrameworkInfo.newBuilder().setUser(sparkUser).setName(appName)
+ val credBuilder = Credential.newBuilder()
+ webuiUrl.foreach { url => fwInfoBuilder.setWebuiUrl(url) }
+ checkpoint.foreach { checkpoint => fwInfoBuilder.setCheckpoint(checkpoint) }
+ failoverTimeout.foreach { timeout => fwInfoBuilder.setFailoverTimeout(timeout) }
+ frameworkId.foreach { id =>
+ fwInfoBuilder.setId(FrameworkID.newBuilder().setValue(id).build())
+ }
+ conf.getOption("spark.mesos.principal").foreach { principal =>
+ fwInfoBuilder.setPrincipal(principal)
+ credBuilder.setPrincipal(principal)
+ }
+ conf.getOption("spark.mesos.secret").foreach { secret =>
+ credBuilder.setSecret(ByteString.copyFromUtf8(secret))
+ }
+ if (credBuilder.hasSecret && !fwInfoBuilder.hasPrincipal) {
+ throw new SparkException(
+ "spark.mesos.principal must be configured when spark.mesos.secret is set")
+ }
+ conf.getOption("spark.mesos.role").foreach { role =>
+ fwInfoBuilder.setRole(role)
+ }
+ if (credBuilder.hasPrincipal) {
+ new MesosSchedulerDriver(
+ scheduler, fwInfoBuilder.build(), masterUrl, credBuilder.build())
+ } else {
+ new MesosSchedulerDriver(scheduler, fwInfoBuilder.build(), masterUrl)
+ }
+ }
+
+ /**
+ * Starts the MesosSchedulerDriver and stores it as the currently running driver.
+ * The given driver is expected not to be running yet.
+ * This method returns only after the scheduler has registered with Mesos.
+ */
+ def startScheduler(newDriver: SchedulerDriver): Unit = {
synchronized {
if (mesosDriver != null) {
registerLatch.await()
@@ -59,11 +111,11 @@ private[mesos] trait MesosSchedulerUtils extends Logging {
setDaemon(true)
override def run() {
- mesosDriver = new MesosSchedulerDriver(scheduler, fwInfo, masterUrl)
+ mesosDriver = newDriver
try {
val ret = mesosDriver.run()
logInfo("driver.run() returned with code " + ret)
- if (ret.equals(Status.DRIVER_ABORTED)) {
+ if (ret != null && ret.equals(Status.DRIVER_ABORTED)) {
System.exit(1)
}
} catch {
@@ -82,18 +134,62 @@ private[mesos] trait MesosSchedulerUtils extends Logging {
/**
* Signal that the scheduler has registered with Mesos.
*/
+ protected def getResource(res: JList[Resource], name: String): Double = {
+ // A resource can have multiple values in the offer since it can either be from
+ // a specific role or the wildcard role.
+ res.filter(_.getName == name).map(_.getScalar.getValue).sum
+ }
+
protected def markRegistered(): Unit = {
registerLatch.countDown()
}
+ def createResource(name: String, amount: Double, role: Option[String] = None): Resource = {
+ val builder = Resource.newBuilder()
+ .setName(name)
+ .setType(Value.Type.SCALAR)
+ .setScalar(Value.Scalar.newBuilder().setValue(amount).build())
+
+ role.foreach { r => builder.setRole(r) }
+
+ builder.build()
+ }
+
/**
- * Get the amount of resources for the specified type from the resource list
+ * Partition the existing set of resources into two groups, those remaining to be
+ * scheduled and those requested to be used for a new task.
+ * @param resources The full list of available resources
+ * @param resourceName The name of the resource to take from the available resources
+ * @param amountToUse The amount of resources to take from the available resources
+ * @return The remaining resources list and the used resources list.
*/
- protected def getResource(res: JList[Resource], name: String): Double = {
- for (r <- res if r.getName == name) {
- return r.getScalar.getValue
+ def partitionResources(
+ resources: JList[Resource],
+ resourceName: String,
+ amountToUse: Double): (List[Resource], List[Resource]) = {
+ var remain = amountToUse
+ var requestedResources = new ArrayBuffer[Resource]
+ val remainingResources = resources.map {
+ case r => {
+ if (remain > 0 &&
+ r.getType == Value.Type.SCALAR &&
+ r.getScalar.getValue > 0.0 &&
+ r.getName == resourceName) {
+ val usage = Math.min(remain, r.getScalar.getValue)
+ requestedResources += createResource(resourceName, usage, Some(r.getRole))
+ remain -= usage
+ createResource(resourceName, r.getScalar.getValue - usage, Some(r.getRole))
+ } else {
+ r
+ }
+ }
}
- 0.0
+
+ // Filter out any resource that has been depleted.
+ val filteredResources =
+ remainingResources.filter(r => r.getType != Value.Type.SCALAR || r.getScalar.getValue > 0.0)
+
+ (filteredResources.toList, requestedResources.toList)
}
/** Helper method to get the key,value-set pair for a Mesos Attribute protobuf */
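To make the partitionResources contract concrete, here is a REPL-style sketch over a simplified scalar resource (not the Mesos protobuf Resource type): the requested amount is drawn from matching resources in offer order, so it can span several roles, and resources that end up fully consumed are dropped from the remaining list.

case class ScalarRes(name: String, role: String, value: Double)   // simplified stand-in

def partitionScalars(
    resources: List[ScalarRes],
    resourceName: String,
    amountToUse: Double): (List[ScalarRes], List[ScalarRes]) = {
  var remain = amountToUse
  val used = List.newBuilder[ScalarRes]
  val remaining = resources.map { r =>
    if (remain > 0 && r.name == resourceName && r.value > 0.0) {
      val usage = math.min(remain, r.value)
      used += r.copy(value = usage)
      remain -= usage
      r.copy(value = r.value - usage)
    } else {
      r
    }
  }
  (remaining.filter(_.value > 0.0), used.result())                // drop depleted resources
}

val offer = List(ScalarRes("cpus", "spark", 2.0), ScalarRes("cpus", "*", 6.0))
val (left, taken) = partitionScalars(offer, "cpus", 3.0)
// taken: cpus(spark)=2.0 and cpus(*)=1.0; left: cpus(*)=5.0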
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
index 776e5d330e3c7..4d48fcfea44e7 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
@@ -25,7 +25,8 @@ import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv, TaskState}
import org.apache.spark.TaskState.TaskState
import org.apache.spark.executor.{Executor, ExecutorBackend}
import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint}
-import org.apache.spark.scheduler.{SchedulerBackend, TaskSchedulerImpl, WorkerOffer}
+import org.apache.spark.scheduler._
+import org.apache.spark.scheduler.cluster.ExecutorInfo
private case class ReviveOffers()
@@ -50,8 +51,8 @@ private[spark] class LocalEndpoint(
private var freeCores = totalCores
- private val localExecutorId = SparkContext.DRIVER_IDENTIFIER
- private val localExecutorHostname = "localhost"
+ val localExecutorId = SparkContext.DRIVER_IDENTIFIER
+ val localExecutorHostname = "localhost"
private val executor = new Executor(
localExecutorId, localExecutorHostname, SparkEnv.get, userClassPath, isLocal = true)
@@ -99,8 +100,9 @@ private[spark] class LocalBackend(
extends SchedulerBackend with ExecutorBackend with Logging {
private val appId = "local-" + System.currentTimeMillis
- var localEndpoint: RpcEndpointRef = null
+ private var localEndpoint: RpcEndpointRef = null
private val userClassPath = getUserClasspath(conf)
+ private val listenerBus = scheduler.sc.listenerBus
/**
* Returns a list of URLs representing the user classpath.
@@ -113,9 +115,13 @@ private[spark] class LocalBackend(
}
override def start() {
- localEndpoint = SparkEnv.get.rpcEnv.setupEndpoint(
- "LocalBackendEndpoint",
- new LocalEndpoint(SparkEnv.get.rpcEnv, userClassPath, scheduler, this, totalCores))
+ val rpcEnv = SparkEnv.get.rpcEnv
+ val executorEndpoint = new LocalEndpoint(rpcEnv, userClassPath, scheduler, this, totalCores)
+ localEndpoint = rpcEnv.setupEndpoint("LocalBackendEndpoint", executorEndpoint)
+ listenerBus.post(SparkListenerExecutorAdded(
+ System.currentTimeMillis,
+ executorEndpoint.localExecutorId,
+ new ExecutorInfo(executorEndpoint.localExecutorHostname, totalCores, Map.empty)))
}
override def stop() {
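Because LocalBackend now posts SparkListenerExecutorAdded for the driver-local executor, an ordinary SparkListener sees local mode the same way it sees cluster executors. A hedged sketch; the listener class name and log text are illustrative, and sc stands for an existing SparkContext:

import org.apache.spark.scheduler.{SparkListener, SparkListenerExecutorAdded}

// Logs every executor registration, including the "driver" executor added by LocalBackend.
class ExecutorAddedLogger extends SparkListener {
  override def onExecutorAdded(executorAdded: SparkListenerExecutorAdded): Unit = {
    val info = executorAdded.executorInfo
    println(s"Executor ${executorAdded.executorId} added on ${info.executorHost} " +
      s"with ${info.totalCores} cores")
  }
}

// sc.addSparkListener(new ExecutorAddedLogger())   // attach to the running SparkContext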
diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
index 698d1384d580d..4a5274b46b7a0 100644
--- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
@@ -62,8 +62,11 @@ private[spark] class JavaDeserializationStream(in: InputStream, loader: ClassLoa
extends DeserializationStream {
private val objIn = new ObjectInputStream(in) {
- override def resolveClass(desc: ObjectStreamClass): Class[_] =
+ override def resolveClass(desc: ObjectStreamClass): Class[_] = {
+ // scalastyle:off classforname
Class.forName(desc.getName, false, loader)
+ // scalastyle:on classforname
+ }
}
def readObject[T: ClassTag](): T = objIn.readObject().asInstanceOf[T]
diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index ed35cffe968f8..7cb6e080533ad 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -102,6 +102,7 @@ class KryoSerializer(conf: SparkConf)
kryo.register(classOf[PythonBroadcast], new KryoJavaSerializer())
try {
+ // scalastyle:off classforname
// Use the default classloader when calling the user registrator.
Thread.currentThread.setContextClassLoader(classLoader)
// Register classes given through spark.kryo.classesToRegister.
@@ -111,6 +112,7 @@ class KryoSerializer(conf: SparkConf)
userRegistrator
.map(Class.forName(_, true, classLoader).newInstance().asInstanceOf[KryoRegistrator])
.foreach { reg => reg.registerClasses(kryo) }
+ // scalastyle:on classforname
} catch {
case e: Exception =>
throw new SparkException(s"Failed to register classes with Kryo", e)
diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala
index cc2f0506817d3..a1b1e1631eafb 100644
--- a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala
@@ -407,7 +407,9 @@ private[spark] object SerializationDebugger extends Logging {
/** ObjectStreamClass$ClassDataSlot.desc field */
val DescField: Field = {
+ // scalastyle:off classforname
val f = Class.forName("java.io.ObjectStreamClass$ClassDataSlot").getDeclaredField("desc")
+ // scalastyle:on classforname
f.setAccessible(true)
f
}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
index 6c3b3080d2605..f6a96d81e7aa9 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
@@ -35,7 +35,7 @@ import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVecto
/** A group of writers for a ShuffleMapTask, one writer per reducer. */
private[spark] trait ShuffleWriterGroup {
- val writers: Array[BlockObjectWriter]
+ val writers: Array[DiskBlockObjectWriter]
/** @param success Indicates all writes were successful. If false, no blocks will be recorded. */
def releaseWriters(success: Boolean)
@@ -113,15 +113,15 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf)
val openStartTime = System.nanoTime
val serializerInstance = serializer.newInstance()
- val writers: Array[BlockObjectWriter] = if (consolidateShuffleFiles) {
+ val writers: Array[DiskBlockObjectWriter] = if (consolidateShuffleFiles) {
fileGroup = getUnusedFileGroup()
- Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
+ Array.tabulate[DiskBlockObjectWriter](numBuckets) { bucketId =>
val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializerInstance, bufferSize,
writeMetrics)
}
} else {
- Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
+ Array.tabulate[DiskBlockObjectWriter](numBuckets) { bucketId =>
val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
val blockFile = blockManager.diskBlockManager.getFile(blockId)
// Because of previous failures, the shuffle file may already exist on this machine.
diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
index d9c63b6e7bbb9..fae69551e7330 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
@@ -114,7 +114,7 @@ private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends ShuffleB
}
private[spark] object IndexShuffleBlockResolver {
- // No-op reduce ID used in interactions with disk store and BlockObjectWriter.
+ // No-op reduce ID used in interactions with disk store and DiskBlockObjectWriter.
// The disk store currently expects puts to relate to a (map, reduce) pair, but in the sort
// shuffle outputs for several reduces are glommed into a single file.
// TODO: Avoid this entirely by having the DiskBlockObjectWriter not require a BlockId.
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
deleted file mode 100644
index 9d8e7e9f03aea..0000000000000
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
+++ /dev/null
@@ -1,85 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.shuffle.hash
-
-import java.io.InputStream
-
-import scala.collection.mutable.{ArrayBuffer, HashMap}
-import scala.util.{Failure, Success}
-
-import org.apache.spark._
-import org.apache.spark.shuffle.FetchFailedException
-import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockFetcherIterator,
- ShuffleBlockId}
-
-private[hash] object BlockStoreShuffleFetcher extends Logging {
- def fetchBlockStreams(
- shuffleId: Int,
- reduceId: Int,
- context: TaskContext,
- blockManager: BlockManager,
- mapOutputTracker: MapOutputTracker)
- : Iterator[(BlockId, InputStream)] =
- {
- logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
-
- val startTime = System.currentTimeMillis
- val statuses = mapOutputTracker.getServerStatuses(shuffleId, reduceId)
- logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format(
- shuffleId, reduceId, System.currentTimeMillis - startTime))
-
- val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]]
- for (((address, size), index) <- statuses.zipWithIndex) {
- splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size))
- }
-
- val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])] = splitsByAddress.toSeq.map {
- case (address, splits) =>
- (address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2)))
- }
-
- val blockFetcherItr = new ShuffleBlockFetcherIterator(
- context,
- blockManager.shuffleClient,
- blockManager,
- blocksByAddress,
- // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
- SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024)
-
- // Make sure that fetch failures are wrapped inside a FetchFailedException for the scheduler
- blockFetcherItr.map { blockPair =>
- val blockId = blockPair._1
- val blockOption = blockPair._2
- blockOption match {
- case Success(inputStream) => {
- (blockId, inputStream)
- }
- case Failure(e) => {
- blockId match {
- case ShuffleBlockId(shufId, mapId, _) =>
- val address = statuses(mapId.toInt)._1
- throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, e)
- case _ =>
- throw new SparkException(
- "Failed to get block " + blockId + ", which is not a shuffle block", e)
- }
- }
- }
- }
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
index d5c9880659dd3..de79fa56f017b 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
@@ -17,10 +17,10 @@
package org.apache.spark.shuffle.hash
-import org.apache.spark.{InterruptibleIterator, MapOutputTracker, SparkEnv, TaskContext}
+import org.apache.spark.{InterruptibleIterator, Logging, MapOutputTracker, SparkEnv, TaskContext}
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
-import org.apache.spark.storage.BlockManager
+import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator}
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.ExternalSorter
@@ -31,8 +31,8 @@ private[spark] class HashShuffleReader[K, C](
context: TaskContext,
blockManager: BlockManager = SparkEnv.get.blockManager,
mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker)
- extends ShuffleReader[K, C]
-{
+ extends ShuffleReader[K, C] with Logging {
+
require(endPartition == startPartition + 1,
"Hash shuffle currently only supports fetching one partition")
@@ -40,11 +40,16 @@ private[spark] class HashShuffleReader[K, C](
/** Read the combined key-values for this reduce task */
override def read(): Iterator[Product2[K, C]] = {
- val blockStreams = BlockStoreShuffleFetcher.fetchBlockStreams(
- handle.shuffleId, startPartition, context, blockManager, mapOutputTracker)
+ val blockFetcherItr = new ShuffleBlockFetcherIterator(
+ context,
+ blockManager.shuffleClient,
+ blockManager,
+ mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition),
+ // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
+ SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024)
// Wrap the streams for compression based on configuration
- val wrappedStreams = blockStreams.map { case (blockId, inputStream) =>
+ val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
blockManager.wrapForCompression(blockId, inputStream)
}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
index eb87cee15903c..41df70c602c30 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
@@ -22,7 +22,7 @@ import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle._
-import org.apache.spark.storage.BlockObjectWriter
+import org.apache.spark.storage.DiskBlockObjectWriter
private[spark] class HashShuffleWriter[K, V](
shuffleBlockResolver: FileShuffleBlockResolver,
@@ -102,7 +102,7 @@ private[spark] class HashShuffleWriter[K, V](
private def commitWritesAndBuildStatus(): MapStatus = {
// Commit the writes. Get the size of each bucket block (total block size).
- val sizes: Array[Long] = shuffle.writers.map { writer: BlockObjectWriter =>
+ val sizes: Array[Long] = shuffle.writers.map { writer: DiskBlockObjectWriter =>
writer.commitAndClose()
writer.fileSegment().length
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 1beafa1771448..86493673d958d 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -648,7 +648,7 @@ private[spark] class BlockManager(
file: File,
serializerInstance: SerializerInstance,
bufferSize: Int,
- writeMetrics: ShuffleWriteMetrics): BlockObjectWriter = {
+ writeMetrics: ShuffleWriteMetrics): DiskBlockObjectWriter = {
val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _)
val syncWrites = conf.getBoolean("spark.shuffle.sync", false)
new DiskBlockObjectWriter(blockId, file, serializerInstance, bufferSize, compressStream,
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
index 68ed9096731c5..5dc0c537cbb62 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
@@ -60,10 +60,11 @@ class BlockManagerMasterEndpoint(
register(blockManagerId, maxMemSize, slaveEndpoint)
context.reply(true)
- case UpdateBlockInfo(
+ case _updateBlockInfo @ UpdateBlockInfo(
blockManagerId, blockId, storageLevel, deserializedSize, size, externalBlockStoreSize) =>
context.reply(updateBlockInfo(
blockManagerId, blockId, storageLevel, deserializedSize, size, externalBlockStoreSize))
+ listenerBus.post(SparkListenerBlockUpdated(BlockUpdatedInfo(_updateBlockInfo)))
case GetLocations(blockId) =>
context.reply(getLocations(blockId))
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockStatusListener.scala b/core/src/main/scala/org/apache/spark/storage/BlockStatusListener.scala
new file mode 100644
index 0000000000000..2789e25b8d3ab
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/BlockStatusListener.scala
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import scala.collection.mutable
+
+import org.apache.spark.scheduler._
+
+private[spark] case class BlockUIData(
+ blockId: BlockId,
+ location: String,
+ storageLevel: StorageLevel,
+ memSize: Long,
+ diskSize: Long,
+ externalBlockStoreSize: Long)
+
+/**
+ * The aggregated status of stream blocks in an executor
+ */
+private[spark] case class ExecutorStreamBlockStatus(
+ executorId: String,
+ location: String,
+ blocks: Seq[BlockUIData]) {
+
+ def totalMemSize: Long = blocks.map(_.memSize).sum
+
+ def totalDiskSize: Long = blocks.map(_.diskSize).sum
+
+ def totalExternalBlockStoreSize: Long = blocks.map(_.externalBlockStoreSize).sum
+
+ def numStreamBlocks: Int = blocks.size
+
+}
+
+private[spark] class BlockStatusListener extends SparkListener {
+
+ private val blockManagers =
+ new mutable.HashMap[BlockManagerId, mutable.HashMap[BlockId, BlockUIData]]
+
+ override def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated): Unit = {
+ val blockId = blockUpdated.blockUpdatedInfo.blockId
+ if (!blockId.isInstanceOf[StreamBlockId]) {
+ // Now we only monitor StreamBlocks
+ return
+ }
+ val blockManagerId = blockUpdated.blockUpdatedInfo.blockManagerId
+ val storageLevel = blockUpdated.blockUpdatedInfo.storageLevel
+ val memSize = blockUpdated.blockUpdatedInfo.memSize
+ val diskSize = blockUpdated.blockUpdatedInfo.diskSize
+ val externalBlockStoreSize = blockUpdated.blockUpdatedInfo.externalBlockStoreSize
+
+ synchronized {
+ // Drop the update info if the block manager is not registered
+ blockManagers.get(blockManagerId).foreach { blocksInBlockManager =>
+ if (storageLevel.isValid) {
+ blocksInBlockManager.put(blockId,
+ BlockUIData(
+ blockId,
+ blockManagerId.hostPort,
+ storageLevel,
+ memSize,
+ diskSize,
+ externalBlockStoreSize)
+ )
+ } else {
+ // If isValid is not true, it means we should drop the block.
+ blocksInBlockManager -= blockId
+ }
+ }
+ }
+ }
+
+ override def onBlockManagerAdded(blockManagerAdded: SparkListenerBlockManagerAdded): Unit = {
+ synchronized {
+ blockManagers.put(blockManagerAdded.blockManagerId, mutable.HashMap())
+ }
+ }
+
+ override def onBlockManagerRemoved(
+ blockManagerRemoved: SparkListenerBlockManagerRemoved): Unit = synchronized {
+ blockManagers -= blockManagerRemoved.blockManagerId
+ }
+
+ def allExecutorStreamBlockStatus: Seq[ExecutorStreamBlockStatus] = synchronized {
+ blockManagers.map { case (blockManagerId, blocks) =>
+ ExecutorStreamBlockStatus(
+ blockManagerId.executorId, blockManagerId.hostPort, blocks.values.toSeq)
+ }.toSeq
+ }
+}
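A short usage sketch for the new listener. BlockStatusListener is private[spark], so the wiring below is assumed to live in Spark-internal code; sc stands for an active SparkContext.

val listener = new BlockStatusListener
sc.addSparkListener(listener)        // receives onBlockManagerAdded/Removed and onBlockUpdated events

// Later, e.g. when rendering a UI page, read the aggregated stream blocks per executor:
listener.allExecutorStreamBlockStatus.foreach { status =>
  println(s"executor=${status.executorId} blocks=${status.numStreamBlocks} " +
    s"mem=${status.totalMemSize} disk=${status.totalDiskSize}")
}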
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockUpdatedInfo.scala b/core/src/main/scala/org/apache/spark/storage/BlockUpdatedInfo.scala
new file mode 100644
index 0000000000000..a5790e4454a89
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/BlockUpdatedInfo.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.storage.BlockManagerMessages.UpdateBlockInfo
+
+/**
+ * :: DeveloperApi ::
+ * Stores information about the status of a block in a block manager.
+ */
+@DeveloperApi
+case class BlockUpdatedInfo(
+ blockManagerId: BlockManagerId,
+ blockId: BlockId,
+ storageLevel: StorageLevel,
+ memSize: Long,
+ diskSize: Long,
+ externalBlockStoreSize: Long)
+
+private[spark] object BlockUpdatedInfo {
+
+ private[spark] def apply(updateBlockInfo: UpdateBlockInfo): BlockUpdatedInfo = {
+ BlockUpdatedInfo(
+ updateBlockInfo.blockManagerId,
+ updateBlockInfo.blockId,
+ updateBlockInfo.storageLevel,
+ updateBlockInfo.memSize,
+ updateBlockInfo.diskSize,
+ updateBlockInfo.externalBlockStoreSize)
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
similarity index 83%
rename from core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
rename to core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
index 7eeabd1e0489c..49d9154f95a5b 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
@@ -26,66 +26,25 @@ import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.util.Utils
/**
- * An interface for writing JVM objects to some underlying storage. This interface allows
- * appending data to an existing block, and can guarantee atomicity in the case of faults
- * as it allows the caller to revert partial writes.
+ * A class for writing JVM objects directly to a file on disk. This class allows data to be appended
+ * to an existing block and can guarantee atomicity in the case of faults as it allows the caller to
+ * revert partial writes.
*
- * This interface does not support concurrent writes. Also, once the writer has
- * been opened, it cannot be reopened again.
- */
-private[spark] abstract class BlockObjectWriter(val blockId: BlockId) extends OutputStream {
-
- def open(): BlockObjectWriter
-
- def close()
-
- def isOpen: Boolean
-
- /**
- * Flush the partial writes and commit them as a single atomic block.
- */
- def commitAndClose(): Unit
-
- /**
- * Reverts writes that haven't been flushed yet. Callers should invoke this function
- * when there are runtime exceptions. This method will not throw, though it may be
- * unsuccessful in truncating written data.
- */
- def revertPartialWritesAndClose()
-
- /**
- * Writes a key-value pair.
- */
- def write(key: Any, value: Any)
-
- /**
- * Notify the writer that a record worth of bytes has been written with OutputStream#write.
- */
- def recordWritten()
-
- /**
- * Returns the file segment of committed data that this Writer has written.
- * This is only valid after commitAndClose() has been called.
- */
- def fileSegment(): FileSegment
-}
-
-/**
- * BlockObjectWriter which writes directly to a file on disk. Appends to the given file.
+ * This class does not support concurrent writes. Also, once the writer has been opened it cannot be
+ * reopened again.
*/
private[spark] class DiskBlockObjectWriter(
- blockId: BlockId,
+ val blockId: BlockId,
file: File,
serializerInstance: SerializerInstance,
bufferSize: Int,
compressStream: OutputStream => OutputStream,
syncWrites: Boolean,
- // These write metrics concurrently shared with other active BlockObjectWriter's who
+ // These write metrics are concurrently shared with other active DiskBlockObjectWriters that
// are themselves performing writes. All updates must be relative.
writeMetrics: ShuffleWriteMetrics)
- extends BlockObjectWriter(blockId)
- with Logging
-{
+ extends OutputStream
+ with Logging {
/** The file channel, used for repositioning / truncating the file. */
private var channel: FileChannel = null
@@ -122,7 +81,7 @@ private[spark] class DiskBlockObjectWriter(
*/
private var numRecordsWritten = 0
- override def open(): BlockObjectWriter = {
+ def open(): DiskBlockObjectWriter = {
if (hasBeenClosed) {
throw new IllegalStateException("Writer already closed. Cannot be reopened.")
}
@@ -159,9 +118,12 @@ private[spark] class DiskBlockObjectWriter(
}
}
- override def isOpen: Boolean = objOut != null
+ def isOpen: Boolean = objOut != null
- override def commitAndClose(): Unit = {
+ /**
+ * Flush the partial writes and commit them as a single atomic block.
+ */
+ def commitAndClose(): Unit = {
if (initialized) {
// NOTE: Because Kryo doesn't flush the underlying stream we explicitly flush both the
// serializer stream and the lower level stream.
@@ -177,9 +139,15 @@ private[spark] class DiskBlockObjectWriter(
commitAndCloseHasBeenCalled = true
}
- // Discard current writes. We do this by flushing the outstanding writes and then
- // truncating the file to its initial position.
- override def revertPartialWritesAndClose() {
+
+ /**
+ * Reverts writes that haven't been flushed yet. Callers should invoke this function
+ * when there are runtime exceptions. This method will not throw, though it may be
+ * unsuccessful in truncating written data.
+ */
+ def revertPartialWritesAndClose() {
+ // Discard current writes. We do this by flushing the outstanding writes and then
+ // truncating the file to its initial position.
try {
if (initialized) {
writeMetrics.decShuffleBytesWritten(reportedPosition - initialPosition)
@@ -201,7 +169,10 @@ private[spark] class DiskBlockObjectWriter(
}
}
- override def write(key: Any, value: Any) {
+ /**
+ * Writes a key-value pair.
+ */
+ def write(key: Any, value: Any) {
if (!initialized) {
open()
}
@@ -221,7 +192,10 @@ private[spark] class DiskBlockObjectWriter(
bs.write(kvBytes, offs, len)
}
- override def recordWritten(): Unit = {
+ /**
+ * Notify the writer that a record worth of bytes has been written with OutputStream#write.
+ */
+ def recordWritten(): Unit = {
numRecordsWritten += 1
writeMetrics.incShuffleRecordsWritten(1)
@@ -230,7 +204,11 @@ private[spark] class DiskBlockObjectWriter(
}
}
- override def fileSegment(): FileSegment = {
+ /**
+ * Returns the file segment of committed data that this Writer has written.
+ * This is only valid after commitAndClose() has been called.
+ */
+ def fileSegment(): FileSegment = {
if (!commitAndCloseHasBeenCalled) {
throw new IllegalStateException(
"fileSegment() is only valid after commitAndClose() has been called")
diff --git a/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala b/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala
index 291394ed34816..db965d54bafd6 100644
--- a/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala
@@ -192,7 +192,7 @@ private[spark] class ExternalBlockStore(blockManager: BlockManager, executorId:
.getOrElse(ExternalBlockStore.DEFAULT_BLOCK_MANAGER_NAME)
try {
- val instance = Class.forName(clsName)
+ val instance = Utils.classForName(clsName)
.newInstance()
.asInstanceOf[ExternalBlockManager]
instance.init(blockManager, executorId)
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index e49e39679e940..a759ceb96ec1e 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -21,18 +21,19 @@ import java.io.InputStream
import java.util.concurrent.LinkedBlockingQueue
import scala.collection.mutable.{ArrayBuffer, HashSet, Queue}
-import scala.util.{Failure, Try}
+import scala.util.control.NonFatal
-import org.apache.spark.{Logging, TaskContext}
+import org.apache.spark.{Logging, SparkException, TaskContext}
import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
+import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.util.Utils
/**
* An iterator that fetches multiple blocks. For local blocks, it fetches from the local block
* manager. For remote blocks, it fetches them using the provided BlockTransferService.
*
- * This creates an iterator of (BlockID, Try[InputStream]) tuples so the caller can handle blocks
+ * This creates an iterator of (BlockID, InputStream) tuples so the caller can handle blocks
* in a pipelined fashion as they are received.
*
* The implementation throttles the remote fetches so they don't exceed maxBytesInFlight to avoid
@@ -53,7 +54,7 @@ final class ShuffleBlockFetcherIterator(
blockManager: BlockManager,
blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
maxBytesInFlight: Long)
- extends Iterator[(BlockId, Try[InputStream])] with Logging {
+ extends Iterator[(BlockId, InputStream)] with Logging {
import ShuffleBlockFetcherIterator._
@@ -115,7 +116,7 @@ final class ShuffleBlockFetcherIterator(
private[storage] def releaseCurrentResultBuffer(): Unit = {
// Release the current buffer if necessary
currentResult match {
- case SuccessFetchResult(_, _, buf) => buf.release()
+ case SuccessFetchResult(_, _, _, buf) => buf.release()
case _ =>
}
currentResult = null
@@ -132,7 +133,7 @@ final class ShuffleBlockFetcherIterator(
while (iter.hasNext) {
val result = iter.next()
result match {
- case SuccessFetchResult(_, _, buf) => buf.release()
+ case SuccessFetchResult(_, _, _, buf) => buf.release()
case _ =>
}
}
@@ -157,7 +158,7 @@ final class ShuffleBlockFetcherIterator(
// Increment the ref count because we need to pass this to a different thread.
// This needs to be released after use.
buf.retain()
- results.put(new SuccessFetchResult(BlockId(blockId), sizeMap(blockId), buf))
+ results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf))
shuffleMetrics.incRemoteBytesRead(buf.size)
shuffleMetrics.incRemoteBlocksFetched(1)
}
@@ -166,7 +167,7 @@ final class ShuffleBlockFetcherIterator(
override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
- results.put(new FailureFetchResult(BlockId(blockId), e))
+ results.put(new FailureFetchResult(BlockId(blockId), address, e))
}
}
)
@@ -238,12 +239,12 @@ final class ShuffleBlockFetcherIterator(
shuffleMetrics.incLocalBlocksFetched(1)
shuffleMetrics.incLocalBytesRead(buf.size)
buf.retain()
- results.put(new SuccessFetchResult(blockId, 0, buf))
+ results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf))
} catch {
case e: Exception =>
// If we see an exception, stop immediately.
logError(s"Error occurred while fetching local blocks", e)
- results.put(new FailureFetchResult(blockId, e))
+ results.put(new FailureFetchResult(blockId, blockManager.blockManagerId, e))
return
}
}
@@ -275,12 +276,14 @@ final class ShuffleBlockFetcherIterator(
override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch
/**
- * Fetches the next (BlockId, Try[InputStream]). If a task fails, the ManagedBuffers
+ * Fetches the next (BlockId, InputStream). If a task fails, the ManagedBuffers
* underlying each InputStream will be freed by the cleanup() method registered with the
* TaskCompletionListener. However, callers should close() these InputStreams
* as soon as they are no longer needed, in order to release memory as early as possible.
+ *
+ * Throws a FetchFailedException if the next block could not be fetched.
*/
- override def next(): (BlockId, Try[InputStream]) = {
+ override def next(): (BlockId, InputStream) = {
numBlocksProcessed += 1
val startFetchWait = System.currentTimeMillis()
currentResult = results.take()
@@ -289,7 +292,7 @@ final class ShuffleBlockFetcherIterator(
shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait)
result match {
- case SuccessFetchResult(_, size, _) => bytesInFlight -= size
+ case SuccessFetchResult(_, _, size, _) => bytesInFlight -= size
case _ =>
}
// Send fetch requests up to maxBytesInFlight
@@ -298,19 +301,28 @@ final class ShuffleBlockFetcherIterator(
sendRequest(fetchRequests.dequeue())
}
- val iteratorTry: Try[InputStream] = result match {
- case FailureFetchResult(_, e) =>
- Failure(e)
- case SuccessFetchResult(blockId, _, buf) =>
- // There is a chance that createInputStream can fail (e.g. fetching a local file that does
- // not exist, SPARK-4085). In that case, we should propagate the right exception so
- // the scheduler gets a FetchFailedException.
- Try(buf.createInputStream()).map { inputStream =>
- new BufferReleasingInputStream(inputStream, this)
+ result match {
+ case FailureFetchResult(blockId, address, e) =>
+ throwFetchFailedException(blockId, address, e)
+
+ case SuccessFetchResult(blockId, address, _, buf) =>
+ try {
+ (result.blockId, new BufferReleasingInputStream(buf.createInputStream(), this))
+ } catch {
+ case NonFatal(t) =>
+ throwFetchFailedException(blockId, address, t)
}
}
+ }
- (result.blockId, iteratorTry)
+ private def throwFetchFailedException(blockId: BlockId, address: BlockManagerId, e: Throwable) = {
+ blockId match {
+ case ShuffleBlockId(shufId, mapId, reduceId) =>
+ throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, e)
+ case _ =>
+ throw new SparkException(
+ "Failed to get block " + blockId + ", which is not a shuffle block", e)
+ }
}
}
@@ -366,16 +378,22 @@ object ShuffleBlockFetcherIterator {
*/
private[storage] sealed trait FetchResult {
val blockId: BlockId
+ val address: BlockManagerId
}
/**
 * Result of a successful fetch from a remote block.
* @param blockId block id
+ * @param address BlockManager that the block was fetched from.
* @param size estimated size of the block, used to calculate bytesInFlight.
* Note that this is NOT the exact bytes.
* @param buf [[ManagedBuffer]] for the content.
*/
- private[storage] case class SuccessFetchResult(blockId: BlockId, size: Long, buf: ManagedBuffer)
+ private[storage] case class SuccessFetchResult(
+ blockId: BlockId,
+ address: BlockManagerId,
+ size: Long,
+ buf: ManagedBuffer)
extends FetchResult {
require(buf != null)
require(size >= 0)
@@ -384,8 +402,12 @@ object ShuffleBlockFetcherIterator {
/**
 * Result of a failed fetch from a remote block.
* @param blockId block id
+ * @param address BlockManager that the block fetch was attempted from.
* @param e the failure exception
*/
- private[storage] case class FailureFetchResult(blockId: BlockId, e: Throwable)
+ private[storage] case class FailureFetchResult(
+ blockId: BlockId,
+ address: BlockManagerId,
+ e: Throwable)
extends FetchResult
}
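
With `next()` now returning `(BlockId, InputStream)` and signalling failures by throwing `FetchFailedException`, downstream readers no longer pattern-match on `Try`. A minimal consumption sketch, using hypothetical stand-ins for the private types (this is not Spark's actual shuffle reader code):

```scala
import java.io.InputStream
import scala.io.Source

// Hypothetical stand-ins for BlockId and Spark's FetchFailedException.
case class FakeBlockId(name: String)
class FakeFetchFailedException(msg: String, cause: Throwable) extends Exception(msg, cause)

// A caller just maps over the pairs; a failed fetch never reaches this code because the
// iterator throws from next() instead of yielding a Failure.
def consume(blocks: Iterator[(FakeBlockId, InputStream)]): Seq[(String, String)] = {
  blocks.map { case (blockId, in) =>
    try {
      // Close each stream promptly, as the updated scaladoc recommends, so the underlying
      // ManagedBuffer can be released early.
      blockId.name -> Source.fromInputStream(in).mkString
    } finally {
      in.close()
    }
  }.toSeq
}
```
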
diff --git a/core/src/main/scala/org/apache/spark/ui/PagedTable.scala b/core/src/main/scala/org/apache/spark/ui/PagedTable.scala
new file mode 100644
index 0000000000000..17d7b39c2d951
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/PagedTable.scala
@@ -0,0 +1,246 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui
+
+import scala.xml.{Node, Unparsed}
+
+/**
+ * A data source that provides data for a page.
+ *
+ * @param pageSize the number of rows in a page
+ */
+private[ui] abstract class PagedDataSource[T](val pageSize: Int) {
+
+ if (pageSize <= 0) {
+ throw new IllegalArgumentException("Page size must be positive")
+ }
+
+ /**
+ * Return the size of all data.
+ */
+ protected def dataSize: Int
+
+ /**
+ * Slice a range of data.
+ */
+ protected def sliceData(from: Int, to: Int): Seq[T]
+
+ /**
+ * Slice the data for this page
+ */
+ def pageData(page: Int): PageData[T] = {
+ val totalPages = (dataSize + pageSize - 1) / pageSize
+ if (page <= 0 || page > totalPages) {
+ throw new IndexOutOfBoundsException(
+ s"Page $page is out of range. Please select a page number between 1 and $totalPages.")
+ }
+ val from = (page - 1) * pageSize
+ val to = dataSize.min(page * pageSize)
+ PageData(totalPages, sliceData(from, to))
+ }
+
+}
+
+/**
+ * The data returned by `PagedDataSource.pageData`, including the total number of pages and the
+ * data in this page.
+ */
+private[ui] case class PageData[T](totalPage: Int, data: Seq[T])
+
+/**
+ * A paged table that will generate an HTML table for a specified page and also the page navigation.
+ */
+private[ui] trait PagedTable[T] {
+
+ def tableId: String
+
+ def tableCssClass: String
+
+ def dataSource: PagedDataSource[T]
+
+ def headers: Seq[Node]
+
+ def row(t: T): Seq[Node]
+
+ def table(page: Int): Seq[Node] = {
+ val _dataSource = dataSource
+ try {
+ val PageData(totalPages, data) = _dataSource.pageData(page)
+
+   * If the totalPages is 1, the page navigation will be empty
+ *
+ * If the totalPages is more than 1, it will create a page navigation including a group of
+ * page numbers and a form to submit the page number.
+ *
+ *
+ *
+ * Here are some examples of the page navigation:
+ * {{{
+ * << < 11 12 13* 14 15 16 17 18 19 20 > >>
+ *
+ * This is the first group, so "<<" is hidden.
+ * < 1 2* 3 4 5 6 7 8 9 10 > >>
+ *
+ * This is the first group and the first page, so "<<" and "<" are hidden.
+ * 1* 2 3 4 5 6 7 8 9 10 > >>
+ *
+ * Assume totalPages is 19. This is the last group, so ">>" is hidden.
+ * << < 11 12 13* 14 15 16 17 18 19 >
+ *
+ * Assume totalPages is 19. This is the last group and the last page, so ">>" and ">" are hidden.
+ * << < 11 12 13 14 15 16 17 18 19*
+ *
+ * * means the current page number
+ * << means jumping to the first page of the previous group.
+ * < means jumping to the previous page.
+ * >> means jumping to the first page of the next group.
+ * > means jumping to the next page.
+ * }}}
+ */
+ private[ui] def pageNavigation(page: Int, pageSize: Int, totalPages: Int): Seq[Node] = {
+ if (totalPages == 1) {
+ Nil
+ } else {
+      // A group contains all the page numbers that will be shown in the page navigation.
+      // A group size of 10 means 10 page numbers will be shown at a time.
+      // The first group is pages 1 to 10, the second is 11 to 20, and so on.
+ val groupSize = 10
+ val firstGroup = 0
+ val lastGroup = (totalPages - 1) / groupSize
+ val currentGroup = (page - 1) / groupSize
+ val startPage = currentGroup * groupSize + 1
+ val endPage = totalPages.min(startPage + groupSize - 1)
+ val pageTags = (startPage to endPage).map { p =>
+ if (p == page) {
+ // The current page should be disabled so that it cannot be clicked.
+
+ }
+ }
+
+ /**
+ * Return a link to jump to a page.
+ */
+ def pageLink(page: Int): String
+
+ /**
+   * Only the implementation knows how to create the URL with a page number and the page size, so we
+   * leave this to the implementation. The implementation should create a JavaScript method that
+   * accepts a page number along with the page size and jumps to the page. The return value is the
+   * name of this method and its JavaScript code.
+ */
+ def goButtonJavascriptFunction: (String, String)
+}
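
A standalone analogue of the `PagedDataSource` contract introduced above (`dataSize` plus `sliceData`, 1-based page numbers, ceiling division for the page count); class and method names here are illustrative only, not Spark's private[ui] API:

```scala
// Illustrative sketch: paging over an in-memory Seq using the same contract as PagedDataSource.
abstract class SimplePagedSource[T](val pageSize: Int) {
  require(pageSize > 0, "Page size must be positive")
  protected def dataSize: Int
  protected def sliceData(from: Int, to: Int): Seq[T]

  /** Returns (totalPages, rows) for a 1-based page number, mirroring pageData above. */
  def page(page: Int): (Int, Seq[T]) = {
    val totalPages = (dataSize + pageSize - 1) / pageSize   // ceiling division
    require(page >= 1 && page <= totalPages, s"page $page out of 1..$totalPages")
    val from = (page - 1) * pageSize
    val to = math.min(dataSize, page * pageSize)
    (totalPages, sliceData(from, to))
  }
}

class SeqPagedSource[T](data: Seq[T], pageSize: Int) extends SimplePagedSource[T](pageSize) {
  protected def dataSize: Int = data.size
  protected def sliceData(from: Int, to: Int): Seq[T] = data.slice(from, to)
}

// For 25 rows and pageSize = 10: page 1 -> rows 0..9, page 3 -> rows 20..24, totalPages = 3.
```
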
diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
index 7898039519201..718aea7e1dc22 100644
--- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
@@ -27,7 +27,7 @@ import org.apache.spark.ui.scope.RDDOperationGraph
/** Utility functions for generating XML pages with spark content. */
private[spark] object UIUtils extends Logging {
- val TABLE_CLASS_NOT_STRIPED = "table table-bordered table-condensed sortable"
+ val TABLE_CLASS_NOT_STRIPED = "table table-bordered table-condensed"
val TABLE_CLASS_STRIPED = TABLE_CLASS_NOT_STRIPED + " table-striped"
// SimpleDateFormat is not thread-safe. Don't expose it to avoid improper use.
@@ -267,9 +267,17 @@ private[spark] object UIUtils extends Logging {
fixedWidth: Boolean = false,
id: Option[String] = None,
headerClasses: Seq[String] = Seq.empty,
- stripeRowsWithCss: Boolean = true): Seq[Node] = {
+ stripeRowsWithCss: Boolean = true,
+ sortable: Boolean = true): Seq[Node] = {
- val listingTableClass = if (stripeRowsWithCss) TABLE_CLASS_STRIPED else TABLE_CLASS_NOT_STRIPED
+ val listingTableClass = {
+ val _tableClass = if (stripeRowsWithCss) TABLE_CLASS_STRIPED else TABLE_CLASS_NOT_STRIPED
+ if (sortable) {
+ _tableClass + " sortable"
+ } else {
+ _tableClass
+ }
+ }
val colWidth = 100.toDouble / headers.size
val colWidthAttr = if (fixedWidth) colWidth + "%" else ""
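
A small sketch of the CSS-class composition added above, restated outside `UIUtils` so the effect of the new `sortable` flag is easy to see (local names, not Spark's):

```scala
// Minimal sketch of how the sortable flag composes with striping.
val TABLE_CLASS_NOT_STRIPED = "table table-bordered table-condensed"
val TABLE_CLASS_STRIPED = TABLE_CLASS_NOT_STRIPED + " table-striped"

def tableClass(stripeRowsWithCss: Boolean, sortable: Boolean): String = {
  val base = if (stripeRowsWithCss) TABLE_CLASS_STRIPED else TABLE_CLASS_NOT_STRIPED
  if (sortable) base + " sortable" else base
}

// tableClass(stripeRowsWithCss = true,  sortable = true)
//   == "table table-bordered table-condensed table-striped sortable"
// tableClass(stripeRowsWithCss = false, sortable = false)
//   == "table table-bordered table-condensed"   // sorttable.js leaves this table alone
```
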
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
index 60e3c6343122c..cf04b5e59239b 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -17,6 +17,7 @@
package org.apache.spark.ui.jobs
+import java.net.URLEncoder
import java.util.Date
import javax.servlet.http.HttpServletRequest
@@ -27,13 +28,14 @@ import org.apache.commons.lang3.StringEscapeUtils
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo}
-import org.apache.spark.ui.{ToolTips, WebUIPage, UIUtils}
+import org.apache.spark.ui._
import org.apache.spark.ui.jobs.UIData._
-import org.apache.spark.ui.scope.RDDOperationGraph
import org.apache.spark.util.{Utils, Distribution}
/** Page showing statistics and task list for a given stage */
private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
+ import StagePage._
+
private val progressListener = parent.progressListener
private val operationGraphListener = parent.operationGraphListener
@@ -74,6 +76,16 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
val parameterAttempt = request.getParameter("attempt")
require(parameterAttempt != null && parameterAttempt.nonEmpty, "Missing attempt parameter")
+ val parameterTaskPage = request.getParameter("task.page")
+ val parameterTaskSortColumn = request.getParameter("task.sort")
+ val parameterTaskSortDesc = request.getParameter("task.desc")
+ val parameterTaskPageSize = request.getParameter("task.pageSize")
+
+ val taskPage = Option(parameterTaskPage).map(_.toInt).getOrElse(1)
+ val taskSortColumn = Option(parameterTaskSortColumn).getOrElse("Index")
+ val taskSortDesc = Option(parameterTaskSortDesc).map(_.toBoolean).getOrElse(false)
+ val taskPageSize = Option(parameterTaskPageSize).map(_.toInt).getOrElse(100)
+
// If this is set, expand the dag visualization by default
val expandDagVizParam = request.getParameter("expandDagViz")
val expandDagViz = expandDagVizParam != null && expandDagVizParam.toBoolean
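
The four `task.*` request parameters above drive the new paged task table. A hedged sketch of the same defaulting logic over a plain `Map`, with the servlet plumbing omitted (the URL in the comment is illustrative only):

```scala
// A stage page URL carrying these parameters might look like
// .../stages/stage/?id=3&attempt=0&task.page=2&task.sort=Duration&task.desc=true&task.pageSize=50
def taskTableParams(query: Map[String, String]): (Int, String, Boolean, Int) = {
  val page = query.get("task.page").map(_.toInt).getOrElse(1)           // default: first page
  val sortColumn = query.get("task.sort").getOrElse("Index")            // default sort column
  val desc = query.get("task.desc").map(_.toBoolean).getOrElse(false)   // ascending by default
  val pageSize = query.get("task.pageSize").map(_.toInt).getOrElse(100) // 100 tasks per page
  (page, sortColumn, desc, pageSize)
}
```
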
@@ -231,52 +243,47 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
accumulableRow,
accumulables.values.toSeq)
- val taskHeadersAndCssClasses: Seq[(String, String)] =
- Seq(
- ("Index", ""), ("ID", ""), ("Attempt", ""), ("Status", ""), ("Locality Level", ""),
- ("Executor ID / Host", ""), ("Launch Time", ""), ("Duration", ""),
- ("Scheduler Delay", TaskDetailsClassNames.SCHEDULER_DELAY),
- ("Task Deserialization Time", TaskDetailsClassNames.TASK_DESERIALIZATION_TIME),
- ("GC Time", ""),
- ("Result Serialization Time", TaskDetailsClassNames.RESULT_SERIALIZATION_TIME),
- ("Getting Result Time", TaskDetailsClassNames.GETTING_RESULT_TIME)) ++
- {if (hasAccumulators) Seq(("Accumulators", "")) else Nil} ++
- {if (stageData.hasInput) Seq(("Input Size / Records", "")) else Nil} ++
- {if (stageData.hasOutput) Seq(("Output Size / Records", "")) else Nil} ++
- {if (stageData.hasShuffleRead) {
- Seq(("Shuffle Read Blocked Time", TaskDetailsClassNames.SHUFFLE_READ_BLOCKED_TIME),
- ("Shuffle Read Size / Records", ""),
- ("Shuffle Remote Reads", TaskDetailsClassNames.SHUFFLE_READ_REMOTE_SIZE))
- } else {
- Nil
- }} ++
- {if (stageData.hasShuffleWrite) {
- Seq(("Write Time", ""), ("Shuffle Write Size / Records", ""))
- } else {
- Nil
- }} ++
- {if (stageData.hasBytesSpilled) {
- Seq(("Shuffle Spill (Memory)", ""), ("Shuffle Spill (Disk)", ""))
- } else {
- Nil
- }} ++
- Seq(("Errors", ""))
-
- val unzipped = taskHeadersAndCssClasses.unzip
-
val currentTime = System.currentTimeMillis()
- val taskTable = UIUtils.listingTable(
- unzipped._1,
- taskRow(
+ val (taskTable, taskTableHTML) = try {
+ val _taskTable = new TaskPagedTable(
+ UIUtils.prependBaseUri(parent.basePath) +
+ s"/stages/stage?id=${stageId}&attempt=${stageAttemptId}",
+ tasks,
hasAccumulators,
stageData.hasInput,
stageData.hasOutput,
stageData.hasShuffleRead,
stageData.hasShuffleWrite,
stageData.hasBytesSpilled,
- currentTime),
- tasks,
- headerClasses = unzipped._2)
+ currentTime,
+ pageSize = taskPageSize,
+ sortColumn = taskSortColumn,
+ desc = taskSortDesc
+ )
+ (_taskTable, _taskTable.table(taskPage))
+ } catch {
+ case e @ (_ : IllegalArgumentException | _ : IndexOutOfBoundsException) =>
+ (null,
{e.getMessage}
)
+ }
+
+ val jsForScrollingDownToTaskTable =
+
+
+ val taskIdsInPage = if (taskTable == null) Set.empty[Long]
+ else taskTable.dataSource.slicedTaskIds
+
// Excludes tasks which failed and have incomplete metrics
val validTasks = tasks.filter(t => t.taskInfo.status == "SUCCESS" && t.taskMetrics.isDefined)
@@ -332,7 +339,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
+: getFormattedTimeQuantiles(serializationTimes)
val gettingResultTimes = validTasks.map { case TaskUIData(info, _, _) =>
- getGettingResultTime(info).toDouble
+ getGettingResultTime(info, currentTime).toDouble
}
val gettingResultQuantiles =
@@ -346,7 +353,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
// machine and to send back the result (but not the time to fetch the task result,
// if it needed to be fetched from the block manager on the worker).
val schedulerDelays = validTasks.map { case TaskUIData(info, metrics, _) =>
- getSchedulerDelay(info, metrics.get).toDouble
+ getSchedulerDelay(info, metrics.get, currentTime).toDouble
}
val schedulerDelayTitle =
Scheduler Delay
@@ -499,12 +506,15 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
dagViz ++
maybeExpandDagViz ++
showAdditionalMetrics ++
- makeTimeline(stageData.taskData.values.toSeq, currentTime) ++
+ makeTimeline(
+ // Only show the tasks in the table
+ stageData.taskData.values.toSeq.filter(t => taskIdsInPage.contains(t.taskInfo.taskId)),
+ currentTime) ++
Summary Metrics for {numCompleted} Completed Tasks
++
{summaryTable.getOrElse("No tasks have reported metrics yet.")}
+ }
+ }
+
+ private val streamBlockTableHeader = Seq(
+ "Block ID",
+ "Replication Level",
+ "Location",
+ "Storage Level",
+ "Size")
+
+ /** Render a stream block */
+ private def streamBlockTableRow(block: (BlockId, Seq[BlockUIData])): Seq[Node] = {
+ val replications = block._2
+ assert(replications.size > 0) // This must be true because it's the result of "groupBy"
+ if (replications.size == 1) {
+ streamBlockTableSubrow(block._1, replications.head, replications.size, true)
+ } else {
+ streamBlockTableSubrow(block._1, replications.head, replications.size, true) ++
+ replications.tail.map(streamBlockTableSubrow(block._1, _, replications.size, false)).flatten
+ }
+ }
+
+ private def streamBlockTableSubrow(
+ blockId: BlockId, block: BlockUIData, replication: Int, firstSubrow: Boolean): Seq[Node] = {
+ val (storageLevel, size) = streamBlockStorageLevelDescriptionAndSize(block)
+
+
+ {
+ if (firstSubrow) {
+
+ {block.blockId.toString}
+
+
+ {replication.toString}
+
+ }
+ }
+
{block.location}
+
{storageLevel}
+
{Utils.bytesToString(size)}
+
+ }
+
+ private[storage] def streamBlockStorageLevelDescriptionAndSize(
+ block: BlockUIData): (String, Long) = {
+ if (block.storageLevel.useDisk) {
+ ("Disk", block.diskSize)
+ } else if (block.storageLevel.useMemory && block.storageLevel.deserialized) {
+ ("Memory", block.memSize)
+ } else if (block.storageLevel.useMemory && !block.storageLevel.deserialized) {
+ ("Memory Serialized", block.memSize)
+ } else if (block.storageLevel.useOffHeap) {
+ ("External", block.externalBlockStoreSize)
+ } else {
+ throw new IllegalStateException(s"Invalid Storage Level: ${block.storageLevel}")
+ }
+ }
+
}
diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala
index 0351749700962..22e2993b3b5bd 100644
--- a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala
+++ b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala
@@ -39,7 +39,8 @@ private[ui] class StorageTab(parent: SparkUI) extends SparkUITab(parent, "storag
* This class is thread-safe (unlike JobProgressListener)
*/
@DeveloperApi
-class StorageListener(storageStatusListener: StorageStatusListener) extends SparkListener {
+class StorageListener(storageStatusListener: StorageStatusListener) extends BlockStatusListener {
+
private[ui] val _rddInfoMap = mutable.Map[Int, RDDInfo]() // exposed for testing
def storageStatusList: Seq[StorageStatus] = storageStatusListener.storageStatusList
diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
index c179833e5b06a..78e7ddc27d1c7 100644
--- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
@@ -128,7 +128,7 @@ private[spark] object AkkaUtils extends Logging {
/** Returns the configured max frame size for Akka messages in bytes. */
def maxFrameSizeBytes(conf: SparkConf): Int = {
- val frameSizeInMB = conf.getInt("spark.akka.frameSize", 10)
+ val frameSizeInMB = conf.getInt("spark.akka.frameSize", 128)
if (frameSizeInMB > AKKA_MAX_FRAME_SIZE_IN_MB) {
throw new IllegalArgumentException(
s"spark.akka.frameSize should not be greater than $AKKA_MAX_FRAME_SIZE_IN_MB MB")
diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
index 305de4c75539d..ebead830c6466 100644
--- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
@@ -49,45 +49,28 @@ private[spark] object ClosureCleaner extends Logging {
cls.getName.contains("$anonfun$")
}
- // Get a list of the classes of the outer objects of a given closure object, obj;
+ // Get a list of the outer objects and their classes of a given closure object, obj;
// the outer objects are defined as any closures that obj is nested within, plus
// possibly the class that the outermost closure is in, if any. We stop searching
// for outer objects beyond that because cloning the user's object is probably
// not a good idea (whereas we can clone closure objects just fine since we
// understand how all their fields are used).
- private def getOuterClasses(obj: AnyRef): List[Class[_]] = {
+ private def getOuterClassesAndObjects(obj: AnyRef): (List[Class[_]], List[AnyRef]) = {
for (f <- obj.getClass.getDeclaredFields if f.getName == "$outer") {
f.setAccessible(true)
val outer = f.get(obj)
// The outer pointer may be null if we have cleaned this closure before
if (outer != null) {
if (isClosure(f.getType)) {
- return f.getType :: getOuterClasses(outer)
+ val recurRet = getOuterClassesAndObjects(outer)
+ return (f.getType :: recurRet._1, outer :: recurRet._2)
} else {
- return f.getType :: Nil // Stop at the first $outer that is not a closure
+ return (f.getType :: Nil, outer :: Nil) // Stop at the first $outer that is not a closure
}
}
}
- Nil
+ (Nil, Nil)
}
-
- // Get a list of the outer objects for a given closure object.
- private def getOuterObjects(obj: AnyRef): List[AnyRef] = {
- for (f <- obj.getClass.getDeclaredFields if f.getName == "$outer") {
- f.setAccessible(true)
- val outer = f.get(obj)
- // The outer pointer may be null if we have cleaned this closure before
- if (outer != null) {
- if (isClosure(f.getType)) {
- return outer :: getOuterObjects(outer)
- } else {
- return outer :: Nil // Stop at the first $outer that is not a closure
- }
- }
- }
- Nil
- }
-
/**
* Return a list of classes that represent closures enclosed in the given closure object.
*/
@@ -205,8 +188,7 @@ private[spark] object ClosureCleaner extends Logging {
// A list of enclosing objects and their respective classes, from innermost to outermost
// An outer object at a given index is of type outer class at the same index
- val outerClasses = getOuterClasses(func)
- val outerObjects = getOuterObjects(func)
+ val (outerClasses, outerObjects) = getOuterClassesAndObjects(func)
// For logging purposes only
val declaredFields = func.getClass.getDeclaredFields
@@ -448,10 +430,12 @@ private class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM
    if (op == INVOKESPECIAL && name == "<init>" && argTypes.length > 0
&& argTypes(0).toString.startsWith("L") // is it an object?
&& argTypes(0).getInternalName == myName) {
+ // scalastyle:off classforname
output += Class.forName(
owner.replace('/', '.'),
false,
Thread.currentThread.getContextClassLoader)
+ // scalastyle:on classforname
}
}
}
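
`getOuterClassesAndObjects` merges the two former traversals into a single pass over the `$outer` chain. An illustrative, simplified sketch of that reflection walk (it does not apply the `isClosure` stopping rule, and none of these names are Spark's):

```scala
import scala.annotation.tailrec

// Walk the chain of $outer references, collecting the field's class and the object in one pass.
def outerChain(obj: AnyRef): (List[Class[_]], List[AnyRef]) = {
  @tailrec
  def loop(o: AnyRef,
           classes: List[Class[_]],
           objects: List[AnyRef]): (List[Class[_]], List[AnyRef]) = {
    o.getClass.getDeclaredFields.find(_.getName == "$outer") match {
      case Some(f) =>
        f.setAccessible(true)
        val outer = f.get(o)
        // A null outer pointer means the closure was already cleaned; stop there.
        if (outer == null) (classes.reverse, objects.reverse)
        else loop(outer, f.getType :: classes, outer :: objects)
      case None => (classes.reverse, objects.reverse)
    }
  }
  loop(obj, Nil, Nil)
}
```
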
diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
index adf69a4e78e71..c600319d9ddb4 100644
--- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
@@ -92,8 +92,10 @@ private[spark] object JsonProtocol {
executorRemovedToJson(executorRemoved)
case logStart: SparkListenerLogStart =>
logStartToJson(logStart)
- // These aren't used, but keeps compiler happy
- case SparkListenerExecutorMetricsUpdate(_, _) => JNothing
+ case metricsUpdate: SparkListenerExecutorMetricsUpdate =>
+ executorMetricsUpdateToJson(metricsUpdate)
+ case blockUpdated: SparkListenerBlockUpdated =>
+ throw new MatchError(blockUpdated) // TODO(ekl) implement this
}
}
@@ -224,6 +226,19 @@ private[spark] object JsonProtocol {
("Spark Version" -> SPARK_VERSION)
}
+ def executorMetricsUpdateToJson(metricsUpdate: SparkListenerExecutorMetricsUpdate): JValue = {
+ val execId = metricsUpdate.execId
+ val taskMetrics = metricsUpdate.taskMetrics
+ ("Event" -> Utils.getFormattedClassName(metricsUpdate)) ~
+ ("Executor ID" -> execId) ~
+ ("Metrics Updated" -> taskMetrics.map { case (taskId, stageId, stageAttemptId, metrics) =>
+ ("Task ID" -> taskId) ~
+ ("Stage ID" -> stageId) ~
+ ("Stage Attempt ID" -> stageAttemptId) ~
+ ("Task Metrics" -> taskMetricsToJson(metrics))
+ })
+ }
+
/** ------------------------------------------------------------------- *
* JSON serialization methods for classes SparkListenerEvents depend on |
* -------------------------------------------------------------------- */
@@ -463,6 +478,7 @@ private[spark] object JsonProtocol {
val executorAdded = Utils.getFormattedClassName(SparkListenerExecutorAdded)
val executorRemoved = Utils.getFormattedClassName(SparkListenerExecutorRemoved)
val logStart = Utils.getFormattedClassName(SparkListenerLogStart)
+ val metricsUpdate = Utils.getFormattedClassName(SparkListenerExecutorMetricsUpdate)
(json \ "Event").extract[String] match {
case `stageSubmitted` => stageSubmittedFromJson(json)
@@ -481,6 +497,7 @@ private[spark] object JsonProtocol {
case `executorAdded` => executorAddedFromJson(json)
case `executorRemoved` => executorRemovedFromJson(json)
case `logStart` => logStartFromJson(json)
+ case `metricsUpdate` => executorMetricsUpdateFromJson(json)
}
}
@@ -598,6 +615,18 @@ private[spark] object JsonProtocol {
SparkListenerLogStart(sparkVersion)
}
+ def executorMetricsUpdateFromJson(json: JValue): SparkListenerExecutorMetricsUpdate = {
+ val execInfo = (json \ "Executor ID").extract[String]
+ val taskMetrics = (json \ "Metrics Updated").extract[List[JValue]].map { json =>
+ val taskId = (json \ "Task ID").extract[Long]
+ val stageId = (json \ "Stage ID").extract[Int]
+ val stageAttemptId = (json \ "Stage Attempt ID").extract[Int]
+ val metrics = taskMetricsFromJson(json \ "Task Metrics")
+ (taskId, stageId, stageAttemptId, metrics)
+ }
+ SparkListenerExecutorMetricsUpdate(execInfo, taskMetrics)
+ }
+
/** --------------------------------------------------------------------- *
* JSON deserialization methods for classes SparkListenerEvents depend on |
* ---------------------------------------------------------------------- */
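
A hedged round-trip sketch for the newly serializable `SparkListenerExecutorMetricsUpdate`. It assumes the 1.x APIs visible in this diff (`sparkEventToJson`/`sparkEventFromJson` and a publicly constructible `TaskMetrics`); since `JsonProtocol` is private[spark], it would only compile inside Spark's own packages or tests:

```scala
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.scheduler.SparkListenerExecutorMetricsUpdate
import org.apache.spark.util.JsonProtocol

val metrics = new TaskMetrics()
// taskMetrics is a Seq of (taskId, stageId, stageAttemptId, TaskMetrics) tuples, matching
// the fields written by executorMetricsUpdateToJson above.
val event = SparkListenerExecutorMetricsUpdate("exec-1", Seq((42L, 3, 0, metrics)))

val json = JsonProtocol.sparkEventToJson(event)      // now serialized instead of returning JNothing
val parsed = JsonProtocol.sparkEventFromJson(json)   // dispatches on "Event" -> metricsUpdate
assert(parsed.isInstanceOf[SparkListenerExecutorMetricsUpdate])
```
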
diff --git a/core/src/main/scala/org/apache/spark/util/SerializableConfiguration.scala b/core/src/main/scala/org/apache/spark/util/SerializableConfiguration.scala
index 30bcf1d2f24d5..3354a923273ff 100644
--- a/core/src/main/scala/org/apache/spark/util/SerializableConfiguration.scala
+++ b/core/src/main/scala/org/apache/spark/util/SerializableConfiguration.scala
@@ -20,8 +20,6 @@ import java.io.{ObjectInputStream, ObjectOutputStream}
import org.apache.hadoop.conf.Configuration
-import org.apache.spark.util.Utils
-
private[spark]
class SerializableConfiguration(@transient var value: Configuration) extends Serializable {
private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
diff --git a/core/src/main/scala/org/apache/spark/util/SerializableJobConf.scala b/core/src/main/scala/org/apache/spark/util/SerializableJobConf.scala
index afbcc6efc850c..cadae472b3f85 100644
--- a/core/src/main/scala/org/apache/spark/util/SerializableJobConf.scala
+++ b/core/src/main/scala/org/apache/spark/util/SerializableJobConf.scala
@@ -21,8 +21,6 @@ import java.io.{ObjectInputStream, ObjectOutputStream}
import org.apache.hadoop.mapred.JobConf
-import org.apache.spark.util.Utils
-
private[spark]
class SerializableJobConf(@transient var value: JobConf) extends Serializable {
private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala
index 0180399c9dad5..7d84468f62ab1 100644
--- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala
+++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala
@@ -124,9 +124,11 @@ object SizeEstimator extends Logging {
val server = ManagementFactory.getPlatformMBeanServer()
// NOTE: This should throw an exception in non-Sun JVMs
+ // scalastyle:off classforname
val hotSpotMBeanClass = Class.forName("com.sun.management.HotSpotDiagnosticMXBean")
val getVMMethod = hotSpotMBeanClass.getDeclaredMethod("getVMOption",
Class.forName("java.lang.String"))
+ // scalastyle:on classforname
val bean = ManagementFactory.newPlatformMXBeanProxy(server,
hotSpotMBeanName, hotSpotMBeanClass)
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index b6b932104a94d..c5816949cd360 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -113,8 +113,11 @@ private[spark] object Utils extends Logging {
def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = {
val bis = new ByteArrayInputStream(bytes)
val ois = new ObjectInputStream(bis) {
- override def resolveClass(desc: ObjectStreamClass): Class[_] =
+ override def resolveClass(desc: ObjectStreamClass): Class[_] = {
+ // scalastyle:off classforname
Class.forName(desc.getName, false, loader)
+ // scalastyle:on classforname
+ }
}
ois.readObject.asInstanceOf[T]
}
@@ -177,12 +180,16 @@ private[spark] object Utils extends Logging {
/** Determines whether the provided class is loadable in the current thread. */
def classIsLoadable(clazz: String): Boolean = {
+ // scalastyle:off classforname
Try { Class.forName(clazz, false, getContextOrSparkClassLoader) }.isSuccess
+ // scalastyle:on classforname
}
+ // scalastyle:off classforname
/** Preferred alternative to Class.forName(className) */
def classForName(className: String): Class[_] = {
Class.forName(className, true, getContextOrSparkClassLoader)
+ // scalastyle:on classforname
}
/**
@@ -1579,6 +1586,34 @@ private[spark] object Utils extends Logging {
hashAbs
}
+ /**
+ * NaN-safe version of [[java.lang.Double.compare()]] which allows NaN values to be compared
+ * according to semantics where NaN == NaN and NaN > any non-NaN double.
+ */
+ def nanSafeCompareDoubles(x: Double, y: Double): Int = {
+ val xIsNan: Boolean = java.lang.Double.isNaN(x)
+ val yIsNan: Boolean = java.lang.Double.isNaN(y)
+ if ((xIsNan && yIsNan) || (x == y)) 0
+ else if (xIsNan) 1
+ else if (yIsNan) -1
+ else if (x > y) 1
+ else -1
+ }
+
+ /**
+ * NaN-safe version of [[java.lang.Float.compare()]] which allows NaN values to be compared
+ * according to semantics where NaN == NaN and NaN > any non-NaN float.
+ */
+ def nanSafeCompareFloats(x: Float, y: Float): Int = {
+ val xIsNan: Boolean = java.lang.Float.isNaN(x)
+ val yIsNan: Boolean = java.lang.Float.isNaN(y)
+ if ((xIsNan && yIsNan) || (x == y)) 0
+ else if (xIsNan) 1
+ else if (yIsNan) -1
+ else if (x > y) 1
+ else -1
+ }
+
  /** Returns the system properties map that is thread-safe to iterate over. It gets the
* properties which have been set explicitly, as well as those for which only a default value
* has been defined. */
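
A worked example of the NaN-safe ordering added above, restated as a local function so the expected results can be checked without access to the private `Utils` object:

```scala
def nanSafeCompare(x: Double, y: Double): Int = {
  val xIsNan = java.lang.Double.isNaN(x)
  val yIsNan = java.lang.Double.isNaN(y)
  if ((xIsNan && yIsNan) || (x == y)) 0
  else if (xIsNan) 1
  else if (yIsNan) -1
  else if (x > y) 1
  else -1
}

assert(nanSafeCompare(Double.NaN, Double.NaN) == 0)       // NaN == NaN
assert(nanSafeCompare(Double.NaN, Double.MaxValue) == 1)  // NaN > any non-NaN double
assert(nanSafeCompare(-1.0, Double.NaN) == -1)
assert(nanSafeCompare(0.0, -0.0) == 0)                    // unlike java.lang.Double.compare
```
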
@@ -2266,7 +2301,7 @@ private [util] class SparkShutdownHookManager {
val hookTask = new Runnable() {
override def run(): Unit = runAll()
}
- Try(Class.forName("org.apache.hadoop.util.ShutdownHookManager")) match {
+ Try(Utils.classForName("org.apache.hadoop.util.ShutdownHookManager")) match {
case Success(shmClass) =>
val fsPriority = classOf[FileSystem].getField("SHUTDOWN_HOOK_PRIORITY").get()
.asInstanceOf[Int]
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala
index 516aaa44d03fc..ae60f3b0cb555 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala
@@ -37,7 +37,7 @@ private[spark] class ChainedBuffer(chunkSize: Int) {
private var _size: Long = 0
/**
- * Feed bytes from this buffer into a BlockObjectWriter.
+ * Feed bytes from this buffer into a DiskBlockObjectWriter.
*
* @param pos Offset in the buffer to read from.
   * @param os OutputStream to write the bytes into.
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index 1e4531ef395ae..d166037351c31 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -26,7 +26,7 @@ import scala.collection.mutable.ArrayBuffer
import com.google.common.io.ByteStreams
-import org.apache.spark.{Logging, SparkEnv}
+import org.apache.spark.{Logging, SparkEnv, TaskContext}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.serializer.{DeserializationStream, Serializer}
import org.apache.spark.storage.{BlockId, BlockManager}
@@ -470,14 +470,27 @@ class ExternalAppendOnlyMap[K, V, C](
item
}
- // TODO: Ensure this gets called even if the iterator isn't drained.
private def cleanup() {
batchIndex = batchOffsets.length // Prevent reading any other batch
val ds = deserializeStream
- deserializeStream = null
- fileStream = null
- ds.close()
- file.delete()
+ if (ds != null) {
+ ds.close()
+ deserializeStream = null
+ }
+ if (fileStream != null) {
+ fileStream.close()
+ fileStream = null
+ }
+ if (file.exists()) {
+ file.delete()
+ }
+ }
+
+ val context = TaskContext.get()
+ // context is null in some tests of ExternalAppendOnlyMapSuite because these tests don't run in
+ // a TaskContext.
+ if (context != null) {
+ context.addTaskCompletionListener(context => cleanup())
}
}
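
The cleanup is now registered as a task-completion callback so the spill file is removed even when the iterator is not fully drained. A hedged sketch of that pattern with an arbitrary `Closeable` (the helper name is made up; `TaskContext.addTaskCompletionListener` is the public API used above):

```scala
import java.io.Closeable
import org.apache.spark.TaskContext

def withTaskCleanup(resource: Closeable): Unit = {
  val context = TaskContext.get()
  // TaskContext.get() can be null outside of a running task (e.g. in local unit tests),
  // which is exactly the case the code above guards against.
  if (context != null) {
    context.addTaskCompletionListener { _ => resource.close() }
  }
}
```
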
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 757dec66c203b..ba7ec834d622d 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -30,7 +30,7 @@ import org.apache.spark._
import org.apache.spark.serializer._
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.shuffle.sort.{SortShuffleFileWriter, SortShuffleWriter}
-import org.apache.spark.storage.{BlockId, BlockObjectWriter}
+import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter}
/**
* Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner
@@ -250,7 +250,7 @@ private[spark] class ExternalSorter[K, V, C](
// These variables are reset after each flush
var objectsWritten: Long = 0
var spillMetrics: ShuffleWriteMetrics = null
- var writer: BlockObjectWriter = null
+ var writer: DiskBlockObjectWriter = null
def openWriter(): Unit = {
assert (writer == null && spillMetrics == null)
spillMetrics = new ShuffleWriteMetrics
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala
index 04bb7fc78c13b..f5844d5353be7 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala
@@ -19,7 +19,6 @@ package org.apache.spark.util.collection
import java.util.Comparator
-import org.apache.spark.storage.BlockObjectWriter
import org.apache.spark.util.collection.WritablePartitionedPairCollection._
/**
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala
index ae9a48729e201..87a786b02d651 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala
@@ -21,9 +21,8 @@ import java.io.InputStream
import java.nio.IntBuffer
import java.util.Comparator
-import org.apache.spark.SparkEnv
import org.apache.spark.serializer.{JavaSerializerInstance, SerializerInstance}
-import org.apache.spark.storage.BlockObjectWriter
+import org.apache.spark.storage.DiskBlockObjectWriter
import org.apache.spark.util.collection.PartitionedSerializedPairBuffer._
/**
@@ -136,7 +135,7 @@ private[spark] class PartitionedSerializedPairBuffer[K, V](
// current position in the meta buffer in ints
var pos = 0
- def writeNext(writer: BlockObjectWriter): Unit = {
+ def writeNext(writer: DiskBlockObjectWriter): Unit = {
val keyStart = getKeyStartPos(metaBuffer, pos)
val keyValLen = metaBuffer.get(pos + KEY_VAL_LEN)
pos += RECORD_SIZE
diff --git a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala
index 7bc59898658e4..38848e9018c6c 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala
@@ -19,7 +19,7 @@ package org.apache.spark.util.collection
import java.util.Comparator
-import org.apache.spark.storage.BlockObjectWriter
+import org.apache.spark.storage.DiskBlockObjectWriter
/**
* A common interface for size-tracking collections of key-value pairs that
@@ -51,7 +51,7 @@ private[spark] trait WritablePartitionedPairCollection[K, V] {
new WritablePartitionedIterator {
private[this] var cur = if (it.hasNext) it.next() else null
- def writeNext(writer: BlockObjectWriter): Unit = {
+ def writeNext(writer: DiskBlockObjectWriter): Unit = {
writer.write(cur._1._2, cur._2)
cur = if (it.hasNext) it.next() else null
}
@@ -91,11 +91,11 @@ private[spark] object WritablePartitionedPairCollection {
}
/**
- * Iterator that writes elements to a BlockObjectWriter instead of returning them. Each element
+ * Iterator that writes elements to a DiskBlockObjectWriter instead of returning them. Each element
* has an associated partition.
*/
private[spark] trait WritablePartitionedIterator {
- def writeNext(writer: BlockObjectWriter): Unit
+ def writeNext(writer: DiskBlockObjectWriter): Unit
def hasNext(): Boolean
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index dfd86d3e51e7d..e948ca33471a4 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -1011,7 +1011,7 @@ public void persist() {
@Test
public void iterator() {
    JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
- TaskContext context = new TaskContextImpl(0, 0, 0L, 0, null, false, new TaskMetrics());
+ TaskContext context = new TaskContextImpl(0, 0, 0L, 0, null, null, false, new TaskMetrics());
Assert.assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue());
}
@@ -1783,7 +1783,7 @@ public void testGuavaOptional() {
// Stop the context created in setUp() and start a local-cluster one, to force usage of the
// assembly.
sc.stop();
- JavaSparkContext localCluster = new JavaSparkContext("local-cluster[1,1,512]", "JavaAPISuite");
+ JavaSparkContext localCluster = new JavaSparkContext("local-cluster[1,1,1024]", "JavaAPISuite");
try {
      JavaRDD<Integer> rdd1 = localCluster.parallelize(Arrays.asList(1, 2, null), 3);
      JavaRDD<Optional<Integer>> rdd2 = rdd1.map(
diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
index af81e46a657d3..618a5fb24710f 100644
--- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
@@ -65,7 +65,7 @@ class CacheManagerSuite extends SparkFunSuite with LocalSparkContext with Before
// in blockManager.put is a losing battle. You have been warned.
blockManager = sc.env.blockManager
cacheManager = sc.env.cacheManager
- val context = new TaskContextImpl(0, 0, 0, 0, null)
+ val context = new TaskContextImpl(0, 0, 0, 0, null, null)
val computeValue = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
val getValue = blockManager.get(RDDBlockId(rdd.id, split.index))
assert(computeValue.toList === List(1, 2, 3, 4))
@@ -77,7 +77,7 @@ class CacheManagerSuite extends SparkFunSuite with LocalSparkContext with Before
val result = new BlockResult(Array(5, 6, 7).iterator, DataReadMethod.Memory, 12)
when(blockManager.get(RDDBlockId(0, 0))).thenReturn(Some(result))
- val context = new TaskContextImpl(0, 0, 0, 0, null)
+ val context = new TaskContextImpl(0, 0, 0, 0, null, null)
val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
assert(value.toList === List(5, 6, 7))
}
@@ -86,14 +86,14 @@ class CacheManagerSuite extends SparkFunSuite with LocalSparkContext with Before
// Local computation should not persist the resulting value, so don't expect a put().
when(blockManager.get(RDDBlockId(0, 0))).thenReturn(None)
- val context = new TaskContextImpl(0, 0, 0, 0, null, true)
+ val context = new TaskContextImpl(0, 0, 0, 0, null, null, true)
val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
assert(value.toList === List(1, 2, 3, 4))
}
test("verify task metrics updated correctly") {
cacheManager = sc.env.cacheManager
- val context = new TaskContextImpl(0, 0, 0, 0, null)
+ val context = new TaskContextImpl(0, 0, 0, 0, null, null)
cacheManager.getOrCompute(rdd3, split, context, StorageLevel.MEMORY_ONLY)
assert(context.taskMetrics.updatedBlocks.getOrElse(Seq()).size === 2)
}
diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
index 501fe186bfd7c..26858ef2774fc 100644
--- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
@@ -292,7 +292,7 @@ class ContextCleanerSuite extends ContextCleanerSuiteBase {
sc.stop()
val conf2 = new SparkConf()
- .setMaster("local-cluster[2, 1, 512]")
+ .setMaster("local-cluster[2, 1, 1024]")
.setAppName("ContextCleanerSuite")
.set("spark.cleaner.referenceTracking.blocking", "true")
.set("spark.cleaner.referenceTracking.blocking.shuffle", "true")
@@ -370,7 +370,7 @@ class SortShuffleContextCleanerSuite extends ContextCleanerSuiteBase(classOf[Sor
sc.stop()
val conf2 = new SparkConf()
- .setMaster("local-cluster[2, 1, 512]")
+ .setMaster("local-cluster[2, 1, 1024]")
.setAppName("ContextCleanerSuite")
.set("spark.cleaner.referenceTracking.blocking", "true")
.set("spark.cleaner.referenceTracking.blocking.shuffle", "true")
diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
index 2300bcff4f118..600c1403b0344 100644
--- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala
+++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
@@ -29,7 +29,7 @@ class NotSerializableExn(val notSer: NotSerializableClass) extends Throwable() {
class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContext {
- val clusterUrl = "local-cluster[2,1,512]"
+ val clusterUrl = "local-cluster[2,1,1024]"
test("task throws not serializable exception") {
// Ensures that executors do not crash when an exn is not serializable. If executors crash,
@@ -40,7 +40,7 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex
val numSlaves = 3
val numPartitions = 10
- sc = new SparkContext("local-cluster[%s,1,512]".format(numSlaves), "test")
+ sc = new SparkContext("local-cluster[%s,1,1024]".format(numSlaves), "test")
val data = sc.parallelize(1 to 100, numPartitions).
map(x => throw new NotSerializableExn(new NotSerializableClass))
intercept[SparkException] {
@@ -50,16 +50,16 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex
}
test("local-cluster format") {
- sc = new SparkContext("local-cluster[2,1,512]", "test")
+ sc = new SparkContext("local-cluster[2,1,1024]", "test")
assert(sc.parallelize(1 to 2, 2).count() == 2)
resetSparkContext()
- sc = new SparkContext("local-cluster[2 , 1 , 512]", "test")
+ sc = new SparkContext("local-cluster[2 , 1 , 1024]", "test")
assert(sc.parallelize(1 to 2, 2).count() == 2)
resetSparkContext()
- sc = new SparkContext("local-cluster[2, 1, 512]", "test")
+ sc = new SparkContext("local-cluster[2, 1, 1024]", "test")
assert(sc.parallelize(1 to 2, 2).count() == 2)
resetSparkContext()
- sc = new SparkContext("local-cluster[ 2, 1, 512 ]", "test")
+ sc = new SparkContext("local-cluster[ 2, 1, 1024 ]", "test")
assert(sc.parallelize(1 to 2, 2).count() == 2)
resetSparkContext()
}
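
These tests use the `local-cluster[numWorkers, coresPerWorker, memoryPerWorkerMB]` master format; the change only raises per-worker memory from 512 MB to 1024 MB. A hedged sketch of the same setup outside a test harness:

```scala
import org.apache.spark.{SparkConf, SparkContext}

val conf = new SparkConf()
  .setMaster("local-cluster[2, 1, 1024]")   // 2 workers, 1 core each, 1024 MB each
  .setAppName("local-cluster-example")
val sc = new SparkContext(conf)
assert(sc.parallelize(1 to 2, 2).count() == 2)
sc.stop()
```
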
@@ -276,7 +276,7 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex
DistributedSuite.amMaster = true
// Using more than two nodes so we don't have a symmetric communication pattern and might
// cache a partially correct list of peers.
- sc = new SparkContext("local-cluster[3,1,512]", "test")
+ sc = new SparkContext("local-cluster[3,1,1024]", "test")
for (i <- 1 to 3) {
val data = sc.parallelize(Seq(true, false, false, false), 4)
data.persist(StorageLevel.MEMORY_ONLY_2)
@@ -294,7 +294,7 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex
test("unpersist RDDs") {
DistributedSuite.amMaster = true
- sc = new SparkContext("local-cluster[3,1,512]", "test")
+ sc = new SparkContext("local-cluster[3,1,1024]", "test")
val data = sc.parallelize(Seq(true, false, false, false), 4)
data.persist(StorageLevel.MEMORY_ONLY_2)
data.count
diff --git a/core/src/test/scala/org/apache/spark/DriverSuite.scala b/core/src/test/scala/org/apache/spark/DriverSuite.scala
index b2262033ca238..454b7e607a51b 100644
--- a/core/src/test/scala/org/apache/spark/DriverSuite.scala
+++ b/core/src/test/scala/org/apache/spark/DriverSuite.scala
@@ -29,7 +29,7 @@ class DriverSuite extends SparkFunSuite with Timeouts {
ignore("driver should exit after finishing without cleanup (SPARK-530)") {
val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!"))
- val masters = Table("master", "local", "local-cluster[2,1,512]")
+ val masters = Table("master", "local", "local-cluster[2,1,1024]")
forAll(masters) { (master: String) =>
val process = Utils.executeCommand(
Seq(s"$sparkHome/bin/spark-class", "org.apache.spark.DriverWithoutCleanup", master),
diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala
index 140012226fdbb..c38d70252add1 100644
--- a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala
@@ -51,7 +51,7 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll {
// This test ensures that the external shuffle service is actually in use for the other tests.
test("using external shuffle service") {
- sc = new SparkContext("local-cluster[2,1,512]", "test", conf)
+ sc = new SparkContext("local-cluster[2,1,1024]", "test", conf)
sc.env.blockManager.externalShuffleServiceEnabled should equal(true)
sc.env.blockManager.shuffleClient.getClass should equal(classOf[ExternalShuffleClient])
diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala
index b099cd3fb7965..69cb4b44cf7ef 100644
--- a/core/src/test/scala/org/apache/spark/FailureSuite.scala
+++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala
@@ -141,5 +141,30 @@ class FailureSuite extends SparkFunSuite with LocalSparkContext {
FailureSuiteState.clear()
}
+ test("managed memory leak error should not mask other failures (SPARK-9266") {
+ val conf = new SparkConf().set("spark.unsafe.exceptionOnMemoryLeak", "true")
+ sc = new SparkContext("local[1,1]", "test", conf)
+
+ // If a task leaks memory but fails due to some other cause, then make sure that the original
+ // cause is preserved
+ val thrownDueToTaskFailure = intercept[SparkException] {
+ sc.parallelize(Seq(0)).mapPartitions { iter =>
+ TaskContext.get().taskMemoryManager().allocate(128)
+ throw new Exception("intentional task failure")
+ iter
+ }.count()
+ }
+ assert(thrownDueToTaskFailure.getMessage.contains("intentional task failure"))
+
+ // If the task succeeded but memory was leaked, then the task should fail due to that leak
+ val thrownDueToMemoryLeak = intercept[SparkException] {
+ sc.parallelize(Seq(0)).mapPartitions { iter =>
+ TaskContext.get().taskMemoryManager().allocate(128)
+ iter
+ }.count()
+ }
+ assert(thrownDueToMemoryLeak.getMessage.contains("memory leak"))
+ }
+
// TODO: Need to add tests with shuffle fetch failures.
}
diff --git a/core/src/test/scala/org/apache/spark/FileServerSuite.scala b/core/src/test/scala/org/apache/spark/FileServerSuite.scala
index 876418aa13029..1255e71af6c0b 100644
--- a/core/src/test/scala/org/apache/spark/FileServerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/FileServerSuite.scala
@@ -139,7 +139,7 @@ class FileServerSuite extends SparkFunSuite with LocalSparkContext {
}
test("Distributing files on a standalone cluster") {
- sc = new SparkContext("local-cluster[1,1,512]", "test", newConf)
+ sc = new SparkContext("local-cluster[1,1,1024]", "test", newConf)
sc.addFile(tmpFile.toString)
val testData = Array((1, 1), (1, 1), (2, 1), (3, 5), (2, 2), (3, 0))
val result = sc.parallelize(testData).reduceByKey {
@@ -153,7 +153,7 @@ class FileServerSuite extends SparkFunSuite with LocalSparkContext {
}
test ("Dynamically adding JARS on a standalone cluster") {
- sc = new SparkContext("local-cluster[1,1,512]", "test", newConf)
+ sc = new SparkContext("local-cluster[1,1,1024]", "test", newConf)
sc.addJar(tmpJarUrl)
val testData = Array((1, 1))
sc.parallelize(testData).foreach { x =>
@@ -164,7 +164,7 @@ class FileServerSuite extends SparkFunSuite with LocalSparkContext {
}
test ("Dynamically adding JARS on a standalone cluster using local: URL") {
- sc = new SparkContext("local-cluster[1,1,512]", "test", newConf)
+ sc = new SparkContext("local-cluster[1,1,1024]", "test", newConf)
sc.addJar(tmpJarUrl.replace("file", "local"))
val testData = Array((1, 1))
sc.parallelize(testData).foreach { x =>
diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala
index 1d8fade90f398..418763f4e5ffa 100644
--- a/core/src/test/scala/org/apache/spark/FileSuite.scala
+++ b/core/src/test/scala/org/apache/spark/FileSuite.scala
@@ -179,6 +179,7 @@ class FileSuite extends SparkFunSuite with LocalSparkContext {
}
test("object files of classes from a JAR") {
+ // scalastyle:off classforname
val original = Thread.currentThread().getContextClassLoader
val className = "FileSuiteObjectFileTest"
val jar = TestUtils.createJarWithClasses(Seq(className))
@@ -201,6 +202,7 @@ class FileSuite extends SparkFunSuite with LocalSparkContext {
finally {
Thread.currentThread().setContextClassLoader(original)
}
+ // scalastyle:on classforname
}
test("write SequenceFile using new Hadoop API") {
diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala
index b31b09196608f..5a2670e4d1cf0 100644
--- a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala
+++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala
@@ -17,6 +17,9 @@
package org.apache.spark
+import java.util.concurrent.{ExecutorService, TimeUnit}
+
+import scala.collection.mutable
import scala.language.postfixOps
import org.scalatest.{BeforeAndAfterEach, PrivateMethodTester}
@@ -25,11 +28,16 @@ import org.mockito.Matchers
import org.mockito.Matchers._
import org.apache.spark.executor.TaskMetrics
-import org.apache.spark.rpc.RpcEndpointRef
+import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv, RpcEndpointRef}
import org.apache.spark.scheduler._
+import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
+import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.util.ManualClock
+/**
+ * A test suite for the heartbeating behavior between the driver and the executors.
+ */
class HeartbeatReceiverSuite
extends SparkFunSuite
with BeforeAndAfterEach
@@ -40,23 +48,40 @@ class HeartbeatReceiverSuite
private val executorId2 = "executor-2"
// Shared state that must be reset before and after each test
- private var scheduler: TaskScheduler = null
+ private var scheduler: TaskSchedulerImpl = null
private var heartbeatReceiver: HeartbeatReceiver = null
private var heartbeatReceiverRef: RpcEndpointRef = null
private var heartbeatReceiverClock: ManualClock = null
+ // Helper private method accessors for HeartbeatReceiver
+ private val _executorLastSeen = PrivateMethod[collection.Map[String, Long]]('executorLastSeen)
+ private val _executorTimeoutMs = PrivateMethod[Long]('executorTimeoutMs)
+ private val _killExecutorThread = PrivateMethod[ExecutorService]('killExecutorThread)
+
+ /**
+ * Before each test, set up the SparkContext and a custom [[HeartbeatReceiver]]
+ * that uses a manual clock.
+ */
override def beforeEach(): Unit = {
- sc = spy(new SparkContext("local[2]", "test"))
- scheduler = mock(classOf[TaskScheduler])
+ val conf = new SparkConf()
+ .setMaster("local[2]")
+ .setAppName("test")
+ .set("spark.dynamicAllocation.testing", "true")
+ sc = spy(new SparkContext(conf))
+ scheduler = mock(classOf[TaskSchedulerImpl])
when(sc.taskScheduler).thenReturn(scheduler)
+ when(scheduler.sc).thenReturn(sc)
heartbeatReceiverClock = new ManualClock
heartbeatReceiver = new HeartbeatReceiver(sc, heartbeatReceiverClock)
heartbeatReceiverRef = sc.env.rpcEnv.setupEndpoint("heartbeat", heartbeatReceiver)
when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(true)
}
+ /**
+ * After each test, clean up all state and stop the [[SparkContext]].
+ */
override def afterEach(): Unit = {
- resetSparkContext()
+ super.afterEach()
scheduler = null
heartbeatReceiver = null
heartbeatReceiverRef = null
@@ -75,7 +100,7 @@ class HeartbeatReceiverSuite
heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId2, null))
triggerHeartbeat(executorId1, executorShouldReregister = false)
triggerHeartbeat(executorId2, executorShouldReregister = false)
- val trackedExecutors = executorLastSeen(heartbeatReceiver)
+ val trackedExecutors = heartbeatReceiver.invokePrivate(_executorLastSeen())
assert(trackedExecutors.size === 2)
assert(trackedExecutors.contains(executorId1))
assert(trackedExecutors.contains(executorId2))
@@ -83,15 +108,15 @@ class HeartbeatReceiverSuite
test("reregister if scheduler is not ready yet") {
heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId1, null))
- // Task scheduler not set in HeartbeatReceiver
+ // Task scheduler is not set yet in HeartbeatReceiver, so executors should reregister
triggerHeartbeat(executorId1, executorShouldReregister = true)
}
test("reregister if heartbeat from unregistered executor") {
heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet)
- // Received heartbeat from unknown receiver, so we ask it to re-register
+ // Received heartbeat from unknown executor, so we ask it to re-register
triggerHeartbeat(executorId1, executorShouldReregister = true)
- assert(executorLastSeen(heartbeatReceiver).isEmpty)
+ assert(heartbeatReceiver.invokePrivate(_executorLastSeen()).isEmpty)
}
test("reregister if heartbeat from removed executor") {
@@ -104,14 +129,14 @@ class HeartbeatReceiverSuite
// A heartbeat from the second executor should require reregistering
triggerHeartbeat(executorId1, executorShouldReregister = false)
triggerHeartbeat(executorId2, executorShouldReregister = true)
- val trackedExecutors = executorLastSeen(heartbeatReceiver)
+ val trackedExecutors = heartbeatReceiver.invokePrivate(_executorLastSeen())
assert(trackedExecutors.size === 1)
assert(trackedExecutors.contains(executorId1))
assert(!trackedExecutors.contains(executorId2))
}
test("expire dead hosts") {
- val executorTimeout = executorTimeoutMs(heartbeatReceiver)
+ val executorTimeout = heartbeatReceiver.invokePrivate(_executorTimeoutMs())
heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet)
heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId1, null))
heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId2, null))
@@ -124,12 +149,61 @@ class HeartbeatReceiverSuite
heartbeatReceiverRef.askWithRetry[Boolean](ExpireDeadHosts)
// Only the second executor should be expired as a dead host
verify(scheduler).executorLost(Matchers.eq(executorId2), any())
- val trackedExecutors = executorLastSeen(heartbeatReceiver)
+ val trackedExecutors = heartbeatReceiver.invokePrivate(_executorLastSeen())
assert(trackedExecutors.size === 1)
assert(trackedExecutors.contains(executorId1))
assert(!trackedExecutors.contains(executorId2))
}
+ test("expire dead hosts should kill executors with replacement (SPARK-8119)") {
+ // Set up a fake backend and cluster manager to simulate killing executors
+ val rpcEnv = sc.env.rpcEnv
+ val fakeClusterManager = new FakeClusterManager(rpcEnv)
+ val fakeClusterManagerRef = rpcEnv.setupEndpoint("fake-cm", fakeClusterManager)
+ val fakeSchedulerBackend = new FakeSchedulerBackend(scheduler, rpcEnv, fakeClusterManagerRef)
+ when(sc.schedulerBackend).thenReturn(fakeSchedulerBackend)
+
+ // Register fake executors with our fake scheduler backend
+ // This is necessary because the backend refuses to kill executors it does not know about
+ fakeSchedulerBackend.start()
+ val dummyExecutorEndpoint1 = new FakeExecutorEndpoint(rpcEnv)
+ val dummyExecutorEndpoint2 = new FakeExecutorEndpoint(rpcEnv)
+ val dummyExecutorEndpointRef1 = rpcEnv.setupEndpoint("fake-executor-1", dummyExecutorEndpoint1)
+ val dummyExecutorEndpointRef2 = rpcEnv.setupEndpoint("fake-executor-2", dummyExecutorEndpoint2)
+ fakeSchedulerBackend.driverEndpoint.askWithRetry[RegisteredExecutor.type](
+ RegisterExecutor(executorId1, dummyExecutorEndpointRef1, "dummy:4040", 0, Map.empty))
+ fakeSchedulerBackend.driverEndpoint.askWithRetry[RegisteredExecutor.type](
+ RegisterExecutor(executorId2, dummyExecutorEndpointRef2, "dummy:4040", 0, Map.empty))
+ heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet)
+ heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId1, null))
+ heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId2, null))
+ triggerHeartbeat(executorId1, executorShouldReregister = false)
+ triggerHeartbeat(executorId2, executorShouldReregister = false)
+
+ // Adjust the target number of executors on the cluster manager side
+ assert(fakeClusterManager.getTargetNumExecutors === 0)
+ sc.requestTotalExecutors(2)
+ assert(fakeClusterManager.getTargetNumExecutors === 2)
+ assert(fakeClusterManager.getExecutorIdsToKill.isEmpty)
+
+ // Expire the executors. This should trigger our fake backend to kill the executors.
+ // Since the kill request is sent to the cluster manager asynchronously, we need to block
+ // on the kill thread to ensure that the cluster manager actually received our requests.
+ // Here we use a timeout of O(seconds), but in practice this whole test takes O(10ms).
+ val executorTimeout = heartbeatReceiver.invokePrivate(_executorTimeoutMs())
+ heartbeatReceiverClock.advance(executorTimeout * 2)
+ heartbeatReceiverRef.askWithRetry[Boolean](ExpireDeadHosts)
+ val killThread = heartbeatReceiver.invokePrivate(_killExecutorThread())
+ killThread.shutdown() // needed for awaitTermination
+ killThread.awaitTermination(10L, TimeUnit.SECONDS)
+
+ // The target number of executors should not change! Otherwise, having an expired
+ // executor means we permanently adjust the target number downwards until we
+ // explicitly request new executors. For more detail, see SPARK-8119.
+ assert(fakeClusterManager.getTargetNumExecutors === 2)
+ assert(fakeClusterManager.getExecutorIdsToKill === Set(executorId1, executorId2))
+ }
+
/** Manually send a heartbeat and return the response. */
private def triggerHeartbeat(
executorId: String,
@@ -148,14 +222,49 @@ class HeartbeatReceiverSuite
}
}
- // Helper methods to access private fields in HeartbeatReceiver
- private val _executorLastSeen = PrivateMethod[collection.Map[String, Long]]('executorLastSeen)
- private val _executorTimeoutMs = PrivateMethod[Long]('executorTimeoutMs)
- private def executorLastSeen(receiver: HeartbeatReceiver): collection.Map[String, Long] = {
- receiver invokePrivate _executorLastSeen()
+}
+
+// TODO: use these classes to add end-to-end tests for dynamic allocation!
+
+/**
+ * Dummy RPC endpoint to simulate executors.
+ */
+private class FakeExecutorEndpoint(override val rpcEnv: RpcEnv) extends RpcEndpoint
+
+/**
+ * Dummy scheduler backend to simulate executor allocation requests to the cluster manager.
+ */
+private class FakeSchedulerBackend(
+ scheduler: TaskSchedulerImpl,
+ rpcEnv: RpcEnv,
+ clusterManagerEndpoint: RpcEndpointRef)
+ extends CoarseGrainedSchedulerBackend(scheduler, rpcEnv) {
+
+ protected override def doRequestTotalExecutors(requestedTotal: Int): Boolean = {
+ clusterManagerEndpoint.askWithRetry[Boolean](RequestExecutors(requestedTotal))
}
- private def executorTimeoutMs(receiver: HeartbeatReceiver): Long = {
- receiver invokePrivate _executorTimeoutMs()
+
+ protected override def doKillExecutors(executorIds: Seq[String]): Boolean = {
+ clusterManagerEndpoint.askWithRetry[Boolean](KillExecutors(executorIds))
}
+}
+
+/**
+ * Dummy cluster manager to simulate responses to executor allocation requests.
+ */
+private class FakeClusterManager(override val rpcEnv: RpcEnv) extends RpcEndpoint {
+ private var targetNumExecutors = 0
+ private val executorIdsToKill = new mutable.HashSet[String]
+
+ def getTargetNumExecutors: Int = targetNumExecutors
+ def getExecutorIdsToKill: Set[String] = executorIdsToKill.toSet
+
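+ // Record incoming allocation and kill requests so tests can assert on what the backend sent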
+ override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+ case RequestExecutors(requestedTotal) =>
+ targetNumExecutors = requestedTotal
+ context.reply(true)
+ case KillExecutors(executorIds) =>
+ executorIdsToKill ++= executorIds
+ context.reply(true)
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
index 340a9e327107e..1168eb0b802f2 100644
--- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
@@ -64,7 +64,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft
test("cluster mode, FIFO scheduler") {
val conf = new SparkConf().set("spark.scheduler.mode", "FIFO")
- sc = new SparkContext("local-cluster[2,1,512]", "test", conf)
+ sc = new SparkContext("local-cluster[2,1,1024]", "test", conf)
testCount()
testTake()
// Make sure we can still launch tasks.
@@ -75,7 +75,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft
val conf = new SparkConf().set("spark.scheduler.mode", "FAIR")
val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile()
conf.set("spark.scheduler.allocation.file", xmlPath)
- sc = new SparkContext("local-cluster[2,1,512]", "test", conf)
+ sc = new SparkContext("local-cluster[2,1,1024]", "test", conf)
testCount()
testTake()
// Make sure we can still launch tasks.
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index 7a1961137cce5..af4e68950f75a 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -17,13 +17,15 @@
package org.apache.spark
+import scala.collection.mutable.ArrayBuffer
+
import org.mockito.Mockito._
import org.mockito.Matchers.{any, isA}
import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcCallContext, RpcEnv}
import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus}
import org.apache.spark.shuffle.FetchFailedException
-import org.apache.spark.storage.BlockManagerId
+import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId}
class MapOutputTrackerSuite extends SparkFunSuite {
private val conf = new SparkConf
@@ -55,9 +57,11 @@ class MapOutputTrackerSuite extends SparkFunSuite {
Array(1000L, 10000L)))
tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000),
Array(10000L, 1000L)))
- val statuses = tracker.getServerStatuses(10, 0)
- assert(statuses.toSeq === Seq((BlockManagerId("a", "hostA", 1000), size1000),
- (BlockManagerId("b", "hostB", 1000), size10000)))
+ val statuses = tracker.getMapSizesByExecutorId(10, 0)
+ assert(statuses.toSet ===
+ Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))),
+ (BlockManagerId("b", "hostB", 1000), ArrayBuffer((ShuffleBlockId(10, 1, 0), size10000))))
+ .toSet)
tracker.stop()
rpcEnv.shutdown()
}
@@ -75,10 +79,10 @@ class MapOutputTrackerSuite extends SparkFunSuite {
tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000),
Array(compressedSize10000, compressedSize1000)))
assert(tracker.containsShuffle(10))
- assert(tracker.getServerStatuses(10, 0).nonEmpty)
+ assert(tracker.getMapSizesByExecutorId(10, 0).nonEmpty)
tracker.unregisterShuffle(10)
assert(!tracker.containsShuffle(10))
- assert(tracker.getServerStatuses(10, 0).isEmpty)
+ assert(tracker.getMapSizesByExecutorId(10, 0).isEmpty)
tracker.stop()
rpcEnv.shutdown()
@@ -104,7 +108,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
// The remaining reduce task might try to grab the output despite the shuffle failure;
// this should cause it to fail, and the scheduler will ignore the failure due to the
// stage already being aborted.
- intercept[FetchFailedException] { tracker.getServerStatuses(10, 1) }
+ intercept[FetchFailedException] { tracker.getMapSizesByExecutorId(10, 1) }
tracker.stop()
rpcEnv.shutdown()
@@ -126,23 +130,23 @@ class MapOutputTrackerSuite extends SparkFunSuite {
masterTracker.registerShuffle(10, 1)
masterTracker.incrementEpoch()
slaveTracker.updateEpoch(masterTracker.getEpoch)
- intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
+ intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) }
val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L))
masterTracker.registerMapOutput(10, 0, MapStatus(
BlockManagerId("a", "hostA", 1000), Array(1000L)))
masterTracker.incrementEpoch()
slaveTracker.updateEpoch(masterTracker.getEpoch)
- assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
- Seq((BlockManagerId("a", "hostA", 1000), size1000)))
+ assert(slaveTracker.getMapSizesByExecutorId(10, 0) ===
+ Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000)))))
masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000))
masterTracker.incrementEpoch()
slaveTracker.updateEpoch(masterTracker.getEpoch)
- intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
+ intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) }
// failure should be cached
- intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
+ intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) }
masterTracker.stop()
slaveTracker.stop()
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index c3c2b1ffc1efa..d91b799ecfc08 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -47,7 +47,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
}
test("shuffle non-zero block size") {
- sc = new SparkContext("local-cluster[2,1,512]", "test", conf)
+ sc = new SparkContext("local-cluster[2,1,1024]", "test", conf)
val NUM_BLOCKS = 3
val a = sc.parallelize(1 to 10, 2)
@@ -66,14 +66,14 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
// All blocks must have non-zero size
(0 until NUM_BLOCKS).foreach { id =>
- val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id)
- assert(statuses.forall(s => s._2 > 0))
+ val statuses = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(shuffleId, id)
+ assert(statuses.forall(_._2.forall(blockIdSizePair => blockIdSizePair._2 > 0)))
}
}
test("shuffle serializer") {
// Use a local cluster with 2 processes to make sure there are both local and remote blocks
- sc = new SparkContext("local-cluster[2,1,512]", "test", conf)
+ sc = new SparkContext("local-cluster[2,1,1024]", "test", conf)
val a = sc.parallelize(1 to 10, 2)
val b = a.map { x =>
(x, new NonJavaSerializableClass(x * 2))
@@ -89,7 +89,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
test("zero sized blocks") {
// Use a local cluster with 2 processes to make sure there are both local and remote blocks
- sc = new SparkContext("local-cluster[2,1,512]", "test", conf)
+ sc = new SparkContext("local-cluster[2,1,1024]", "test", conf)
// 201 partitions (greater than "spark.shuffle.sort.bypassMergeThreshold") from 4 keys
val NUM_BLOCKS = 201
@@ -105,8 +105,8 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
assert(c.count === 4)
val blockSizes = (0 until NUM_BLOCKS).flatMap { id =>
- val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id)
- statuses.map(x => x._2)
+ val statuses = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(shuffleId, id)
+ statuses.flatMap(_._2.map(_._2))
}
val nonEmptyBlocks = blockSizes.filter(x => x > 0)
@@ -116,7 +116,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
test("zero sized blocks without kryo") {
// Use a local cluster with 2 processes to make sure there are both local and remote blocks
- sc = new SparkContext("local-cluster[2,1,512]", "test", conf)
+ sc = new SparkContext("local-cluster[2,1,1024]", "test", conf)
// 201 partitions (greater than "spark.shuffle.sort.bypassMergeThreshold") from 4 keys
val NUM_BLOCKS = 201
@@ -130,8 +130,8 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
assert(c.count === 4)
val blockSizes = (0 until NUM_BLOCKS).flatMap { id =>
- val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id)
- statuses.map(x => x._2)
+ val statuses = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(shuffleId, id)
+ statuses.flatMap(_._2.map(_._2))
}
val nonEmptyBlocks = blockSizes.filter(x => x > 0)
@@ -141,7 +141,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
test("shuffle on mutable pairs") {
// Use a local cluster with 2 processes to make sure there are both local and remote blocks
- sc = new SparkContext("local-cluster[2,1,512]", "test", conf)
+ sc = new SparkContext("local-cluster[2,1,1024]", "test", conf)
def p[T1, T2](_1: T1, _2: T2): MutablePair[T1, T2] = MutablePair(_1, _2)
val data = Array(p(1, 1), p(1, 2), p(1, 3), p(2, 1))
val pairs: RDD[MutablePair[Int, Int]] = sc.parallelize(data, 2)
@@ -154,7 +154,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
test("sorting on mutable pairs") {
// This is not in SortingSuite because of the local cluster setup.
// Use a local cluster with 2 processes to make sure there are both local and remote blocks
- sc = new SparkContext("local-cluster[2,1,512]", "test", conf)
+ sc = new SparkContext("local-cluster[2,1,1024]", "test", conf)
def p[T1, T2](_1: T1, _2: T2): MutablePair[T1, T2] = MutablePair(_1, _2)
val data = Array(p(1, 11), p(3, 33), p(100, 100), p(2, 22))
val pairs: RDD[MutablePair[Int, Int]] = sc.parallelize(data, 2)
@@ -168,7 +168,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
test("cogroup using mutable pairs") {
// Use a local cluster with 2 processes to make sure there are both local and remote blocks
- sc = new SparkContext("local-cluster[2,1,512]", "test", conf)
+ sc = new SparkContext("local-cluster[2,1,1024]", "test", conf)
def p[T1, T2](_1: T1, _2: T2): MutablePair[T1, T2] = MutablePair(_1, _2)
val data1 = Seq(p(1, 1), p(1, 2), p(1, 3), p(2, 1))
val data2 = Seq(p(1, "11"), p(1, "12"), p(2, "22"), p(3, "3"))
@@ -195,7 +195,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
test("subtract mutable pairs") {
// Use a local cluster with 2 processes to make sure there are both local and remote blocks
- sc = new SparkContext("local-cluster[2,1,512]", "test", conf)
+ sc = new SparkContext("local-cluster[2,1,1024]", "test", conf)
def p[T1, T2](_1: T1, _2: T2): MutablePair[T1, T2] = MutablePair(_1, _2)
val data1 = Seq(p(1, 1), p(1, 2), p(1, 3), p(2, 1), p(3, 33))
val data2 = Seq(p(1, "11"), p(1, "12"), p(2, "22"))
@@ -210,7 +210,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
test("sort with Java non serializable class - Kryo") {
// Use a local cluster with 2 processes to make sure there are both local and remote blocks
val myConf = conf.clone().set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
- sc = new SparkContext("local-cluster[2,1,512]", "test", myConf)
+ sc = new SparkContext("local-cluster[2,1,1024]", "test", myConf)
val a = sc.parallelize(1 to 10, 2)
val b = a.map { x =>
(new NonJavaSerializableClass(x), x)
@@ -223,7 +223,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
test("sort with Java non serializable class - Java") {
// Use a local cluster with 2 processes to make sure there are both local and remote blocks
- sc = new SparkContext("local-cluster[2,1,512]", "test", conf)
+ sc = new SparkContext("local-cluster[2,1,1024]", "test", conf)
val a = sc.parallelize(1 to 10, 2)
val b = a.map { x =>
(new NonJavaSerializableClass(x), x)
diff --git a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala
index f89e3d0a49920..e5a14a69ef05f 100644
--- a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark
import org.scalatest.PrivateMethodTester
+import org.apache.spark.util.Utils
import org.apache.spark.scheduler.{SchedulerBackend, TaskScheduler, TaskSchedulerImpl}
import org.apache.spark.scheduler.cluster.{SimrSchedulerBackend, SparkDeploySchedulerBackend}
import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
@@ -122,7 +123,7 @@ class SparkContextSchedulerCreationSuite
}
test("local-cluster") {
- createTaskScheduler("local-cluster[3, 14, 512]").backend match {
+ createTaskScheduler("local-cluster[3, 14, 1024]").backend match {
case s: SparkDeploySchedulerBackend => // OK
case _ => fail()
}
@@ -131,7 +132,7 @@ class SparkContextSchedulerCreationSuite
def testYarn(master: String, expectedClassName: String) {
try {
val sched = createTaskScheduler(master)
- assert(sched.getClass === Class.forName(expectedClassName))
+ assert(sched.getClass === Utils.classForName(expectedClassName))
} catch {
case e: SparkException =>
assert(e.getMessage.contains("YARN mode not available"))
diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
index c054c718075f8..48e74f06f79b1 100644
--- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
+++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
@@ -69,7 +69,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext {
val conf = httpConf.clone
conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
conf.set("spark.broadcast.compress", "true")
- sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", conf)
+ sc = new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test", conf)
val list = List[Int](1, 2, 3, 4)
val broadcast = sc.broadcast(list)
val results = sc.parallelize(1 to numSlaves).map(x => (x, broadcast.value.sum))
@@ -97,7 +97,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext {
val conf = torrentConf.clone
conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
conf.set("spark.broadcast.compress", "true")
- sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", conf)
+ sc = new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test", conf)
val list = List[Int](1, 2, 3, 4)
val broadcast = sc.broadcast(list)
val results = sc.parallelize(1 to numSlaves).map(x => (x, broadcast.value.sum))
@@ -125,7 +125,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext {
test("Test Lazy Broadcast variables with TorrentBroadcast") {
val numSlaves = 2
val conf = torrentConf.clone
- sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", conf)
+ sc = new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test", conf)
val rdd = sc.parallelize(1 to numSlaves)
val results = new DummyBroadcastClass(rdd).doSomething()
@@ -308,7 +308,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext {
sc = if (distributed) {
val _sc =
- new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", broadcastConf)
+ new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test", broadcastConf)
// Wait until all slaves are up
_sc.jobProgressListener.waitUntilExecutorsUp(numSlaves, 10000)
_sc
diff --git a/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala
index ddc92814c0acf..cbd2aee10c0e2 100644
--- a/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala
@@ -33,7 +33,7 @@ class LogUrlsStandaloneSuite extends SparkFunSuite with LocalSparkContext {
private val WAIT_TIMEOUT_MILLIS = 10000
test("verify that correct log urls get propagated from workers") {
- sc = new SparkContext("local-cluster[2,1,512]", "test")
+ sc = new SparkContext("local-cluster[2,1,1024]", "test")
val listener = new SaveExecutorInfo
sc.addSparkListener(listener)
@@ -66,7 +66,7 @@ class LogUrlsStandaloneSuite extends SparkFunSuite with LocalSparkContext {
}
val conf = new MySparkConf().set(
"spark.extraListeners", classOf[SaveExecutorInfo].getName)
- sc = new SparkContext("local-cluster[2,1,512]", "test", conf)
+ sc = new SparkContext("local-cluster[2,1,1024]", "test", conf)
// Trigger a job so that executors get added
sc.parallelize(1 to 100, 4).map(_.toString).count()
diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
index 1b64c329b5d4b..aa78bfe30974c 100644
--- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
@@ -246,7 +246,7 @@ class SparkSubmitSuite
mainClass should be ("org.apache.spark.deploy.Client")
}
classpath should have size 0
- sysProps should have size 8
+ sysProps should have size 9
sysProps.keys should contain ("SPARK_SUBMIT")
sysProps.keys should contain ("spark.master")
sysProps.keys should contain ("spark.app.name")
@@ -255,6 +255,7 @@ class SparkSubmitSuite
sysProps.keys should contain ("spark.driver.cores")
sysProps.keys should contain ("spark.driver.supervise")
sysProps.keys should contain ("spark.shuffle.spill")
+ sysProps.keys should contain ("spark.submit.deployMode")
sysProps("spark.shuffle.spill") should be ("false")
}
@@ -336,7 +337,7 @@ class SparkSubmitSuite
val args = Seq(
"--class", JarCreationTest.getClass.getName.stripSuffix("$"),
"--name", "testApp",
- "--master", "local-cluster[2,1,512]",
+ "--master", "local-cluster[2,1,1024]",
"--jars", jarsString,
unusedJar.toString, "SparkSubmitClassA", "SparkSubmitClassB")
runSparkSubmit(args)
@@ -351,7 +352,7 @@ class SparkSubmitSuite
val args = Seq(
"--class", JarCreationTest.getClass.getName.stripSuffix("$"),
"--name", "testApp",
- "--master", "local-cluster[2,1,512]",
+ "--master", "local-cluster[2,1,1024]",
"--packages", Seq(main, dep).mkString(","),
"--repositories", repo,
"--conf", "spark.ui.enabled=false",
@@ -540,8 +541,8 @@ object JarCreationTest extends Logging {
val result = sc.makeRDD(1 to 100, 10).mapPartitions { x =>
var exception: String = null
try {
- Class.forName(args(0), true, Thread.currentThread().getContextClassLoader)
- Class.forName(args(1), true, Thread.currentThread().getContextClassLoader)
+ Utils.classForName(args(0))
+ Utils.classForName(args(1))
} catch {
case t: Throwable =>
exception = t + "\n" + t.getStackTraceString
diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala
index 2a62450bcdbad..73cff89544dc3 100644
--- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala
@@ -243,13 +243,12 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc
appListAfterRename.size should be (1)
}
- test("apps with multiple attempts") {
+ test("apps with multiple attempts with order") {
val provider = new FsHistoryProvider(createTestConf())
- val attempt1 = newLogFile("app1", Some("attempt1"), inProgress = false)
+ val attempt1 = newLogFile("app1", Some("attempt1"), inProgress = true)
writeFile(attempt1, true, None,
- SparkListenerApplicationStart("app1", Some("app1"), 1L, "test", Some("attempt1")),
- SparkListenerApplicationEnd(2L)
+ SparkListenerApplicationStart("app1", Some("app1"), 1L, "test", Some("attempt1"))
)
updateAndCheck(provider) { list =>
@@ -259,7 +258,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc
val attempt2 = newLogFile("app1", Some("attempt2"), inProgress = true)
writeFile(attempt2, true, None,
- SparkListenerApplicationStart("app1", Some("app1"), 3L, "test", Some("attempt2"))
+ SparkListenerApplicationStart("app1", Some("app1"), 2L, "test", Some("attempt2"))
)
updateAndCheck(provider) { list =>
@@ -268,22 +267,21 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc
list.head.attempts.head.attemptId should be (Some("attempt2"))
}
- val completedAttempt2 = newLogFile("app1", Some("attempt2"), inProgress = false)
- attempt2.delete()
- writeFile(attempt2, true, None,
- SparkListenerApplicationStart("app1", Some("app1"), 3L, "test", Some("attempt2")),
+ val attempt3 = newLogFile("app1", Some("attempt3"), inProgress = false)
+ writeFile(attempt3, true, None,
+ SparkListenerApplicationStart("app1", Some("app1"), 3L, "test", Some("attempt3")),
SparkListenerApplicationEnd(4L)
)
updateAndCheck(provider) { list =>
list should not be (null)
list.size should be (1)
- list.head.attempts.size should be (2)
- list.head.attempts.head.attemptId should be (Some("attempt2"))
+ list.head.attempts.size should be (3)
+ list.head.attempts.head.attemptId should be (Some("attempt3"))
}
val app2Attempt1 = newLogFile("app2", Some("attempt1"), inProgress = false)
- writeFile(attempt2, true, None,
+ writeFile(attempt1, true, None,
SparkListenerApplicationStart("app2", Some("app2"), 5L, "test", Some("attempt1")),
SparkListenerApplicationEnd(6L)
)
@@ -291,7 +289,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc
updateAndCheck(provider) { list =>
list.size should be (2)
list.head.attempts.size should be (1)
- list.last.attempts.size should be (2)
+ list.last.attempts.size should be (3)
list.head.attempts.head.attemptId should be (Some("attempt1"))
list.foreach { case app =>
diff --git a/core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala b/core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala
index f4e56632e426a..8c96b0e71dfdd 100644
--- a/core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala
@@ -19,18 +19,19 @@
// when they are outside of org.apache.spark.
package other.supplier
+import java.nio.ByteBuffer
+
import scala.collection.mutable
import scala.reflect.ClassTag
-import akka.serialization.Serialization
-
import org.apache.spark.SparkConf
import org.apache.spark.deploy.master._
+import org.apache.spark.serializer.Serializer
class CustomRecoveryModeFactory(
conf: SparkConf,
- serialization: Serialization
-) extends StandaloneRecoveryModeFactory(conf, serialization) {
+ serializer: Serializer
+) extends StandaloneRecoveryModeFactory(conf, serializer) {
CustomRecoveryModeFactory.instantiationAttempts += 1
@@ -40,7 +41,7 @@ class CustomRecoveryModeFactory(
*
*/
override def createPersistenceEngine(): PersistenceEngine =
- new CustomPersistenceEngine(serialization)
+ new CustomPersistenceEngine(serializer)
/**
* Create an instance of LeaderAgent that decides who gets elected as master.
@@ -53,7 +54,7 @@ object CustomRecoveryModeFactory {
@volatile var instantiationAttempts = 0
}
-class CustomPersistenceEngine(serialization: Serialization) extends PersistenceEngine {
+class CustomPersistenceEngine(serializer: Serializer) extends PersistenceEngine {
val data = mutable.HashMap[String, Array[Byte]]()
CustomPersistenceEngine.lastInstance = Some(this)
@@ -64,10 +65,10 @@ class CustomPersistenceEngine(serialization: Serialization) extends PersistenceE
*/
override def persist(name: String, obj: Object): Unit = {
CustomPersistenceEngine.persistAttempts += 1
- serialization.serialize(obj) match {
- case util.Success(bytes) => data += name -> bytes
- case util.Failure(cause) => throw new RuntimeException(cause)
- }
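+ // Serialize with the generic Serializer API and copy the resulting ByteBuffer into a byte array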
+ val serialized = serializer.newInstance().serialize(obj)
+ val bytes = new Array[Byte](serialized.remaining())
+ serialized.get(bytes)
+ data += name -> bytes
}
/**
@@ -84,15 +85,9 @@ class CustomPersistenceEngine(serialization: Serialization) extends PersistenceE
*/
override def read[T: ClassTag](prefix: String): Seq[T] = {
CustomPersistenceEngine.readAttempts += 1
- val clazz = implicitly[ClassTag[T]].runtimeClass.asInstanceOf[Class[T]]
val results = for ((name, bytes) <- data; if name.startsWith(prefix))
- yield serialization.deserialize(bytes, clazz)
-
- results.find(_.isFailure).foreach {
- case util.Failure(cause) => throw new RuntimeException(cause)
- }
-
- results.flatMap(_.toOption).toSeq
+ yield serializer.newInstance().deserialize[T](ByteBuffer.wrap(bytes))
+ results.toSeq
}
}
diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala
index 9cb6dd43bac47..a8fbaf1d9da0a 100644
--- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala
@@ -105,7 +105,7 @@ class MasterSuite extends SparkFunSuite with Matchers with Eventually {
persistenceEngine.addDriver(driverToPersist)
persistenceEngine.addWorker(workerToPersist)
- val (apps, drivers, workers) = persistenceEngine.readPersistedData()
+ val (apps, drivers, workers) = persistenceEngine.readPersistedData(rpcEnv)
apps.map(_.id) should contain(appToPersist.id)
drivers.map(_.id) should contain(driverToPersist.id)
diff --git a/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala
new file mode 100644
index 0000000000000..11e87bd1dd8eb
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package org.apache.spark.deploy.master
+
+import java.net.ServerSocket
+
+import org.apache.commons.lang3.RandomUtils
+import org.apache.curator.test.TestingServer
+
+import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
+import org.apache.spark.rpc.{RpcEndpoint, RpcEnv}
+import org.apache.spark.serializer.{Serializer, JavaSerializer}
+import org.apache.spark.util.Utils
+
+class PersistenceEngineSuite extends SparkFunSuite {
+
+ test("FileSystemPersistenceEngine") {
+ val dir = Utils.createTempDir()
+ try {
+ val conf = new SparkConf()
+ testPersistenceEngine(conf, serializer =>
+ new FileSystemPersistenceEngine(dir.getAbsolutePath, serializer)
+ )
+ } finally {
+ Utils.deleteRecursively(dir)
+ }
+ }
+
+ test("ZooKeeperPersistenceEngine") {
+ val conf = new SparkConf()
+ // TestingServer logs a port conflict rather than throwing an exception, so we have to find a
+ // free port ourselves. This cannot guarantee that zkTestServer always starts successfully,
+ // because there is a time gap between finding a free port and starting zkTestServer, but the
+ // probability of failure should be very low.
+ val zkTestServer = new TestingServer(findFreePort(conf))
+ try {
+ testPersistenceEngine(conf, serializer => {
+ conf.set("spark.deploy.zookeeper.url", zkTestServer.getConnectString)
+ new ZooKeeperPersistenceEngine(conf, serializer)
+ })
+ } finally {
+ zkTestServer.stop()
+ }
+ }
+
+ private def testPersistenceEngine(
+ conf: SparkConf, persistenceEngineCreator: Serializer => PersistenceEngine): Unit = {
+ val serializer = new JavaSerializer(conf)
+ val persistenceEngine = persistenceEngineCreator(serializer)
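+ // Exercise the basic persist / read / unpersist round trip with plain String values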
+ persistenceEngine.persist("test_1", "test_1_value")
+ assert(Seq("test_1_value") === persistenceEngine.read[String]("test_"))
+ persistenceEngine.persist("test_2", "test_2_value")
+ assert(Set("test_1_value", "test_2_value") === persistenceEngine.read[String]("test_").toSet)
+ persistenceEngine.unpersist("test_1")
+ assert(Seq("test_2_value") === persistenceEngine.read[String]("test_"))
+ persistenceEngine.unpersist("test_2")
+ assert(persistenceEngine.read[String]("test_").isEmpty)
+
+ // Test deserializing objects that contain RpcEndpointRef
+ val testRpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf))
+ try {
+ // Create a real endpoint so that we can test RpcEndpointRef deserialization
+ val workerEndpoint = testRpcEnv.setupEndpoint("worker", new RpcEndpoint {
+ // Refer to the enclosing env by a distinct name; otherwise this override would shadow it
+ // and initialize itself to null
+ override val rpcEnv: RpcEnv = testRpcEnv
+ })
+
+ val workerToPersist = new WorkerInfo(
+ id = "test_worker",
+ host = "127.0.0.1",
+ port = 10000,
+ cores = 0,
+ memory = 0,
+ endpoint = workerEndpoint,
+ webUiPort = 0,
+ publicAddress = ""
+ )
+
+ persistenceEngine.addWorker(workerToPersist)
+
+ val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData(testRpcEnv)
+
+ assert(storedApps.isEmpty)
+ assert(storedDrivers.isEmpty)
+
+ // Check deserializing WorkerInfo
+ assert(storedWorkers.size == 1)
+ val recoveryWorkerInfo = storedWorkers.head
+ assert(workerToPersist.id === recoveryWorkerInfo.id)
+ assert(workerToPersist.host === recoveryWorkerInfo.host)
+ assert(workerToPersist.port === recoveryWorkerInfo.port)
+ assert(workerToPersist.cores === recoveryWorkerInfo.cores)
+ assert(workerToPersist.memory === recoveryWorkerInfo.memory)
+ assert(workerToPersist.endpoint === recoveryWorkerInfo.endpoint)
+ assert(workerToPersist.webUiPort === recoveryWorkerInfo.webUiPort)
+ assert(workerToPersist.publicAddress === recoveryWorkerInfo.publicAddress)
+ } finally {
+ testRpcEnv.shutdown()
+ testRpcEnv.awaitTermination()
+ }
+ }
+
+ private def findFreePort(conf: SparkConf): Int = {
+ val candidatePort = RandomUtils.nextInt(1024, 65536)
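+ // Briefly bind a ServerSocket to confirm the candidate port is free, then release it;
+ // Utils.startServiceOnPort retries nearby ports if the bind fails.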
+ Utils.startServiceOnPort(candidatePort, (trialPort: Int) => {
+ val socket = new ServerSocket(trialPort)
+ socket.close()
+ (null, trialPort)
+ }, conf)._2
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala
index 08215a2bafc09..05013fbc49b8e 100644
--- a/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala
@@ -22,11 +22,12 @@ import java.sql._
import org.scalatest.BeforeAndAfter
import org.apache.spark.{LocalSparkContext, SparkContext, SparkFunSuite}
+import org.apache.spark.util.Utils
class JdbcRDDSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext {
before {
- Class.forName("org.apache.derby.jdbc.EmbeddedDriver")
+ Utils.classForName("org.apache.derby.jdbc.EmbeddedDriver")
val conn = DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb;create=true")
try {
diff --git a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
index 32f04d54eff94..3e8816a4c65be 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
@@ -175,7 +175,7 @@ class PipedRDDSuite extends SparkFunSuite with SharedSparkContext {
}
val hadoopPart1 = generateFakeHadoopPartition()
val pipedRdd = new PipedRDD(nums, "printenv " + varName)
- val tContext = new TaskContextImpl(0, 0, 0, 0, null)
+ val tContext = new TaskContextImpl(0, 0, 0, 0, null, null)
val rddIter = pipedRdd.compute(hadoopPart1, tContext)
val arr = rddIter.toArray
assert(arr(0) == "/some/path")
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index f6da9f98ad253..5f718ea9f7be1 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -679,7 +679,7 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext {
test("runJob on an invalid partition") {
intercept[IllegalArgumentException] {
- sc.runJob(sc.parallelize(1 to 10, 2), {iter: Iterator[Int] => iter.size}, Seq(0, 1, 2), false)
+ sc.runJob(sc.parallelize(1 to 10, 2), {iter: Iterator[Int] => iter.size}, Seq(0, 1, 2))
}
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala
index 34145691153ce..eef6aafa624ee 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala
@@ -26,7 +26,7 @@ class CoarseGrainedSchedulerBackendSuite extends SparkFunSuite with LocalSparkCo
val conf = new SparkConf
conf.set("spark.akka.frameSize", "1")
conf.set("spark.default.parallelism", "1")
- sc = new SparkContext("local-cluster[2 , 1 , 512]", "test", conf)
+ sc = new SparkContext("local-cluster[2, 1, 1024]", "test", conf)
val frameSize = AkkaUtils.maxFrameSizeBytes(sc.conf)
val buffer = new SerializableBuffer(java.nio.ByteBuffer.allocate(2 * frameSize))
val larger = sc.parallelize(Seq(buffer))
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 6bc45f249f975..86dff8fb577d5 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -101,9 +101,15 @@ class DAGSchedulerSuite
/** Length of time to wait while draining listener events. */
val WAIT_TIMEOUT_MILLIS = 10000
val sparkListener = new SparkListener() {
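+ // Record every submitted StageInfo so tests can count resubmission attempts per stage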
+ val submittedStageInfos = new HashSet[StageInfo]
val successfulStages = new HashSet[Int]
val failedStages = new ArrayBuffer[Int]
val stageByOrderOfExecution = new ArrayBuffer[Int]
+
+ override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
+ submittedStageInfos += stageSubmitted.stageInfo
+ }
+
override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) {
val stageInfo = stageCompleted.stageInfo
stageByOrderOfExecution += stageInfo.stageId
@@ -147,9 +153,8 @@ class DAGSchedulerSuite
}
before {
- // Enable local execution for this test
- val conf = new SparkConf().set("spark.localExecution.enabled", "true")
- sc = new SparkContext("local", "DAGSchedulerSuite", conf)
+ sc = new SparkContext("local", "DAGSchedulerSuite")
+ sparkListener.submittedStageInfos.clear()
sparkListener.successfulStages.clear()
sparkListener.failedStages.clear()
failure = null
@@ -165,12 +170,7 @@ class DAGSchedulerSuite
sc.listenerBus,
mapOutputTracker,
blockManagerMaster,
- sc.env) {
- override def runLocally(job: ActiveJob) {
- // don't bother with the thread while unit testing
- runLocallyWithinThread(job)
- }
- }
+ sc.env)
dagEventProcessLoopTester = new DAGSchedulerEventProcessLoopTester(scheduler)
}
@@ -234,10 +234,9 @@ class DAGSchedulerSuite
rdd: RDD[_],
partitions: Array[Int],
func: (TaskContext, Iterator[_]) => _ = jobComputeFunc,
- allowLocal: Boolean = false,
listener: JobListener = jobListener): Int = {
val jobId = scheduler.nextJobId.getAndIncrement()
- runEvent(JobSubmitted(jobId, rdd, func, partitions, allowLocal, CallSite("", ""), listener))
+ runEvent(JobSubmitted(jobId, rdd, func, partitions, CallSite("", ""), listener))
jobId
}
@@ -277,37 +276,6 @@ class DAGSchedulerSuite
assertDataStructuresEmpty()
}
- test("local job") {
- val rdd = new PairOfIntsRDD(sc, Nil) {
- override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] =
- Array(42 -> 0).iterator
- override def getPartitions: Array[Partition] =
- Array( new Partition { override def index: Int = 0 } )
- override def getPreferredLocations(split: Partition): List[String] = Nil
- override def toString: String = "DAGSchedulerSuite Local RDD"
- }
- val jobId = scheduler.nextJobId.getAndIncrement()
- runEvent(
- JobSubmitted(jobId, rdd, jobComputeFunc, Array(0), true, CallSite("", ""), jobListener))
- assert(results === Map(0 -> 42))
- assertDataStructuresEmpty()
- }
-
- test("local job oom") {
- val rdd = new PairOfIntsRDD(sc, Nil) {
- override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] =
- throw new java.lang.OutOfMemoryError("test local job oom")
- override def getPartitions = Array( new Partition { override def index = 0 } )
- override def getPreferredLocations(split: Partition) = Nil
- override def toString = "DAGSchedulerSuite Local RDD"
- }
- val jobId = scheduler.nextJobId.getAndIncrement()
- runEvent(
- JobSubmitted(jobId, rdd, jobComputeFunc, Array(0), true, CallSite("", ""), jobListener))
- assert(results.size == 0)
- assertDataStructuresEmpty()
- }
-
test("run trivial job w/ dependency") {
val baseRdd = new MyRDD(sc, 1, Nil)
val finalRdd = new MyRDD(sc, 1, List(new OneToOneDependency(baseRdd)))
@@ -445,12 +413,7 @@ class DAGSchedulerSuite
sc.listenerBus,
mapOutputTracker,
blockManagerMaster,
- sc.env) {
- override def runLocally(job: ActiveJob) {
- // don't bother with the thread while unit testing
- runLocallyWithinThread(job)
- }
- }
+ sc.env)
dagEventProcessLoopTester = new DAGSchedulerEventProcessLoopTester(noKillScheduler)
val jobId = submit(new MyRDD(sc, 1, Nil), Array(0))
cancel(jobId)
@@ -476,8 +439,8 @@ class DAGSchedulerSuite
complete(taskSets(0), Seq(
(Success, makeMapStatus("hostA", 1)),
(Success, makeMapStatus("hostB", 1))))
- assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
- Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")))
+ assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet ===
+ HashSet(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")))
complete(taskSets(1), Seq((Success, 42)))
assert(results === Map(0 -> 42))
assertDataStructuresEmpty()
@@ -503,8 +466,8 @@ class DAGSchedulerSuite
// have the 2nd attempt pass
complete(taskSets(2), Seq((Success, makeMapStatus("hostA", reduceRdd.partitions.size))))
// we can see both result blocks now
- assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) ===
- Array("hostA", "hostB"))
+ assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1.host).toSet ===
+ HashSet("hostA", "hostB"))
complete(taskSets(3), Seq((Success, 43)))
assert(results === Map(0 -> 42, 1 -> 43))
assertDataStructuresEmpty()
@@ -520,8 +483,8 @@ class DAGSchedulerSuite
(Success, makeMapStatus("hostA", reduceRdd.partitions.size)),
(Success, makeMapStatus("hostB", reduceRdd.partitions.size))))
// The MapOutputTracker should know about both map output locations.
- assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) ===
- Array("hostA", "hostB"))
+ assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1.host).toSet ===
+ HashSet("hostA", "hostB"))
// The first result task fails, with a fetch failure for the output from the first mapper.
runEvent(CompletionEvent(
@@ -547,6 +510,140 @@ class DAGSchedulerSuite
assert(sparkListener.failedStages.size == 1)
}
+ /**
+ * This tests the case where another FetchFailed comes in while the map stage is getting
+ * re-run.
+ */
+ test("late fetch failures don't cause multiple concurrent attempts for the same map stage") {
+ val shuffleMapRdd = new MyRDD(sc, 2, Nil)
+ val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+ val shuffleId = shuffleDep.shuffleId
+ val reduceRdd = new MyRDD(sc, 2, List(shuffleDep))
+ submit(reduceRdd, Array(0, 1))
+
+ val mapStageId = 0
+ def countSubmittedMapStageAttempts(): Int = {
+ sparkListener.submittedStageInfos.count(_.stageId == mapStageId)
+ }
+
+ // The map stage should have been submitted.
+ assert(countSubmittedMapStageAttempts() === 1)
+
+ complete(taskSets(0), Seq(
+ (Success, makeMapStatus("hostA", 2)),
+ (Success, makeMapStatus("hostB", 2))))
+ // The MapOutputTracker should know about both map output locations.
+ assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1.host).toSet ===
+ HashSet("hostA", "hostB"))
+ assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 1).map(_._1.host).toSet ===
+ HashSet("hostA", "hostB"))
+
+ // The first result task fails, with a fetch failure for the output from the first mapper.
+ runEvent(CompletionEvent(
+ taskSets(1).tasks(0),
+ FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
+ null,
+ Map[Long, Any](),
+ createFakeTaskInfo(),
+ null))
+ sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+ assert(sparkListener.failedStages.contains(1))
+
+ // Trigger resubmission of the failed map stage.
+ runEvent(ResubmitFailedStages)
+ sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+
+ // Another attempt for the map stage should have been submitted, resulting in 2 total attempts.
+ assert(countSubmittedMapStageAttempts() === 2)
+
+ // The second ResultTask fails, with a fetch failure for the output from the second mapper.
+ runEvent(CompletionEvent(
+ taskSets(1).tasks(1),
+ FetchFailed(makeBlockManagerId("hostB"), shuffleId, 1, 1, "ignored"),
+ null,
+ Map[Long, Any](),
+ createFakeTaskInfo(),
+ null))
+
+ // Another ResubmitFailedStages event should not result in another attempt for the map
+ // stage being run concurrently.
+ // NOTE: the actual ResubmitFailedStages may get called at any time during this, but it
+ // shouldn't affect anything -- our calling it just makes *SURE* it gets called between the
+ // desired event and our check.
+ runEvent(ResubmitFailedStages)
+ sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+ assert(countSubmittedMapStageAttempts() === 2)
+
+ }
+
+ /**
+ * This tests the case where a late FetchFailed comes in after the map stage has finished getting
+ * retried and a new reduce stage starts running.
+ */
+ test("extremely late fetch failures don't cause multiple concurrent attempts for " +
+ "the same stage") {
+ val shuffleMapRdd = new MyRDD(sc, 2, Nil)
+ val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+ val shuffleId = shuffleDep.shuffleId
+ val reduceRdd = new MyRDD(sc, 2, List(shuffleDep))
+ submit(reduceRdd, Array(0, 1))
+
+ def countSubmittedReduceStageAttempts(): Int = {
+ sparkListener.submittedStageInfos.count(_.stageId == 1)
+ }
+ def countSubmittedMapStageAttempts(): Int = {
+ sparkListener.submittedStageInfos.count(_.stageId == 0)
+ }
+
+ // The map stage should have been submitted.
+ sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+ assert(countSubmittedMapStageAttempts() === 1)
+
+ // Complete the map stage.
+ complete(taskSets(0), Seq(
+ (Success, makeMapStatus("hostA", 2)),
+ (Success, makeMapStatus("hostB", 2))))
+
+ // The reduce stage should have been submitted.
+ sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+ assert(countSubmittedReduceStageAttempts() === 1)
+
+ // The first result task fails, with a fetch failure for the output from the first mapper.
+ runEvent(CompletionEvent(
+ taskSets(1).tasks(0),
+ FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
+ null,
+ Map[Long, Any](),
+ createFakeTaskInfo(),
+ null))
+
+ // Trigger resubmission of the failed map stage and finish the re-started map task.
+ runEvent(ResubmitFailedStages)
+ complete(taskSets(2), Seq((Success, makeMapStatus("hostA", 1))))
+
+ // Because the map stage finished, another attempt for the reduce stage should have been
+ // submitted, resulting in 2 total attempts for both the map and the reduce stage.
+ sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+ assert(countSubmittedMapStageAttempts() === 2)
+ assert(countSubmittedReduceStageAttempts() === 2)
+
+ // A late FetchFailed arrives from the second task in the original reduce stage.
+ runEvent(CompletionEvent(
+ taskSets(1).tasks(1),
+ FetchFailed(makeBlockManagerId("hostB"), shuffleId, 1, 1, "ignored"),
+ null,
+ Map[Long, Any](),
+ createFakeTaskInfo(),
+ null))
+
+ // Running ResubmitFailedStages shouldn't result in any more attempts for the map stage, because
+ // the FetchFailed should have been ignored.
+ runEvent(ResubmitFailedStages)
+
+ // The FetchFailed from the original reduce stage should be ignored.
+ assert(countSubmittedMapStageAttempts() === 2)
+ }
+
test("ignore late map task completions") {
val shuffleMapRdd = new MyRDD(sc, 2, Nil)
val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
@@ -572,8 +669,8 @@ class DAGSchedulerSuite
taskSet.tasks(1).epoch = newEpoch
runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA",
reduceRdd.partitions.size), null, createFakeTaskInfo(), null))
- assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
- Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA")))
+ assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet ===
+ HashSet(makeBlockManagerId("hostB"), makeBlockManagerId("hostA")))
complete(taskSets(1), Seq((Success, 42), (Success, 43)))
assert(results === Map(0 -> 42, 1 -> 43))
assertDataStructuresEmpty()
@@ -668,8 +765,8 @@ class DAGSchedulerSuite
(Success, makeMapStatus("hostB", 1))))
// have hostC complete the resubmitted task
complete(taskSets(1), Seq((Success, makeMapStatus("hostC", 1))))
- assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
- Array(makeBlockManagerId("hostC"), makeBlockManagerId("hostB")))
+ assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet ===
+ HashSet(makeBlockManagerId("hostC"), makeBlockManagerId("hostB")))
complete(taskSets(2), Seq((Success, 42)))
assert(results === Map(0 -> 42))
assertDataStructuresEmpty()
@@ -748,40 +845,23 @@ class DAGSchedulerSuite
// Run this on executors
sc.parallelize(1 to 10, 2).foreach { item => acc.add(1) }
- // Run this within a local thread
- sc.parallelize(1 to 10, 2).map { item => acc.add(1) }.take(1)
-
- // Make sure we can still run local commands as well as cluster commands.
+ // Make sure we can still run commands
assert(sc.parallelize(1 to 10, 2).count() === 10)
- assert(sc.parallelize(1 to 10, 2).first() === 1)
}
test("misbehaved resultHandler should not crash DAGScheduler and SparkContext") {
- val e1 = intercept[SparkDriverExecutionException] {
- val rdd = sc.parallelize(1 to 10, 2)
- sc.runJob[Int, Int](
- rdd,
- (context: TaskContext, iter: Iterator[Int]) => iter.size,
- Seq(0),
- allowLocal = true,
- (part: Int, result: Int) => throw new DAGSchedulerSuiteDummyException)
- }
- assert(e1.getCause.isInstanceOf[DAGSchedulerSuiteDummyException])
-
- val e2 = intercept[SparkDriverExecutionException] {
+ val e = intercept[SparkDriverExecutionException] {
val rdd = sc.parallelize(1 to 10, 2)
sc.runJob[Int, Int](
rdd,
(context: TaskContext, iter: Iterator[Int]) => iter.size,
Seq(0, 1),
- allowLocal = false,
(part: Int, result: Int) => throw new DAGSchedulerSuiteDummyException)
}
- assert(e2.getCause.isInstanceOf[DAGSchedulerSuiteDummyException])
+ assert(e.getCause.isInstanceOf[DAGSchedulerSuiteDummyException])
- // Make sure we can still run local commands as well as cluster commands.
+ // Make sure we can still run commands
assert(sc.parallelize(1 to 10, 2).count() === 10)
- assert(sc.parallelize(1 to 10, 2).first() === 1)
}
test("getPartitions exceptions should not crash DAGScheduler and SparkContext (SPARK-8606)") {
@@ -794,9 +874,8 @@ class DAGSchedulerSuite
rdd.reduceByKey(_ + _, 1).count()
}
- // Make sure we can still run local commands as well as cluster commands.
+ // Make sure we can still run commands
assert(sc.parallelize(1 to 10, 2).count() === 10)
- assert(sc.parallelize(1 to 10, 2).first() === 1)
}
test("getPreferredLocations errors should not crash DAGScheduler and SparkContext (SPARK-8606)") {
@@ -810,9 +889,8 @@ class DAGSchedulerSuite
}
assert(e1.getMessage.contains(classOf[DAGSchedulerSuiteDummyException].getName))
- // Make sure we can still run local commands as well as cluster commands.
+ // Make sure we can still run commands
assert(sc.parallelize(1 to 10, 2).count() === 10)
- assert(sc.parallelize(1 to 10, 2).first() === 1)
}
test("accumulator not calculated for resubmitted result stage") {
@@ -840,8 +918,8 @@ class DAGSchedulerSuite
submit(reduceRdd, Array(0))
complete(taskSets(0), Seq(
(Success, makeMapStatus("hostA", 1))))
- assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
- Array(makeBlockManagerId("hostA")))
+ assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet ===
+ HashSet(makeBlockManagerId("hostA")))
// Reducer should run on the same host that map task ran
val reduceTaskSet = taskSets(1)
@@ -875,6 +953,21 @@ class DAGSchedulerSuite
assertDataStructuresEmpty
}
+ test("Spark exceptions should include call site in stack trace") {
+ val e = intercept[SparkException] {
+ sc.parallelize(1 to 10, 2).map { _ => throw new RuntimeException("uh-oh!") }.count()
+ }
+
+ // Does not include message, ONLY stack trace.
+ val stackTraceString = e.getStackTraceString
+
+ // should actually include the RDD operation that invoked the method:
+ assert(stackTraceString.contains("org.apache.spark.rdd.RDD.count"))
+
+ // should include the FunSuite setup:
+ assert(stackTraceString.contains("org.scalatest.FunSuite"))
+ }
+
/**
* Assert that the supplied TaskSet has exactly the given hosts as its preferred locations.
* Note that this checks only the host and not the executor ID.
diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala
index f681f21b6205e..5cb2d4225d281 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala
@@ -180,7 +180,7 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit
// into SPARK-6688.
val conf = getLoggingConf(testDirPath, compressionCodec)
.set("spark.hadoop.fs.defaultFS", "unsupported://example.com")
- val sc = new SparkContext("local-cluster[2,2,512]", "test", conf)
+ val sc = new SparkContext("local-cluster[2,2,1024]", "test", conf)
assert(sc.eventLogger.isDefined)
val eventLogger = sc.eventLogger.get
val eventLogPath = eventLogger.logPath
diff --git a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
index 0a7cb69416a08..b3ca150195a5f 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
@@ -19,7 +19,7 @@ package org.apache.spark.scheduler
import org.apache.spark.TaskContext
-class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0) {
+class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0, 0) {
override def runTask(context: TaskContext): Int = 0
override def preferredLocations: Seq[TaskLocation] = prefLocs
@@ -31,12 +31,16 @@ object FakeTask {
* locations for each task (given as varargs) if this sequence is not empty.
*/
def createTaskSet(numTasks: Int, prefLocs: Seq[TaskLocation]*): TaskSet = {
+ createTaskSet(numTasks, 0, prefLocs: _*)
+ }
+
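+ // Variant that lets callers tag the generated TaskSet with a specific stage attempt id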
+ def createTaskSet(numTasks: Int, stageAttemptId: Int, prefLocs: Seq[TaskLocation]*): TaskSet = {
if (prefLocs.size != 0 && prefLocs.size != numTasks) {
throw new IllegalArgumentException("Wrong number of task locations")
}
val tasks = Array.tabulate[Task[_]](numTasks) { i =>
new FakeTask(i, if (prefLocs.size != 0) prefLocs(i) else Nil)
}
- new TaskSet(tasks, 0, 0, 0, null)
+ new TaskSet(tasks, 0, stageAttemptId, 0, null)
}
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala
index 9b92f8de56759..383855caefa2f 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala
@@ -25,7 +25,7 @@ import org.apache.spark.TaskContext
* A Task implementation that fails to serialize.
*/
private[spark] class NotSerializableFakeTask(myId: Int, stageId: Int)
- extends Task[Array[Byte]](stageId, 0) {
+ extends Task[Array[Byte]](stageId, 0, 0) {
override def runTask(context: TaskContext): Array[Byte] = Array.empty[Byte]
override def preferredLocations: Seq[TaskLocation] = Seq[TaskLocation]()
diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala
index a9036da9cc93d..e5ecd4b7c2610 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala
@@ -134,14 +134,14 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter {
test("Only one of two duplicate commit tasks should commit") {
val rdd = sc.parallelize(Seq(1), 1)
sc.runJob(rdd, OutputCommitFunctions(tempDir.getAbsolutePath).commitSuccessfully _,
- 0 until rdd.partitions.size, allowLocal = false)
+ 0 until rdd.partitions.size)
assert(tempDir.list().size === 1)
}
test("If commit fails, if task is retried it should not be locked, and will succeed.") {
val rdd = sc.parallelize(Seq(1), 1)
sc.runJob(rdd, OutputCommitFunctions(tempDir.getAbsolutePath).failFirstCommitAttempt _,
- 0 until rdd.partitions.size, allowLocal = false)
+ 0 until rdd.partitions.size)
assert(tempDir.list().size === 1)
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala
index 4e3defb43a021..103fc19369c97 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala
@@ -102,7 +102,7 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter {
fileSystem.mkdirs(logDirPath)
val conf = EventLoggingListenerSuite.getLoggingConf(logDirPath, codecName)
- val sc = new SparkContext("local-cluster[2,1,512]", "Test replay", conf)
+ val sc = new SparkContext("local-cluster[2,1,1024]", "Test replay", conf)
// Run a few jobs
sc.parallelize(1 to 100, 1).count()
diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
index 651295b7344c5..730535ece7878 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
@@ -188,7 +188,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match
sc.addSparkListener(listener)
val rdd1 = sc.parallelize(1 to 100, 4)
val rdd2 = rdd1.map(_.toString)
- sc.runJob(rdd2, (items: Iterator[String]) => items.size, Seq(0, 1), true)
+ sc.runJob(rdd2, (items: Iterator[String]) => items.size, Seq(0, 1))
sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala
index d97fba00976d2..d1e23ed527ff1 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala
@@ -34,7 +34,7 @@ class SparkListenerWithClusterSuite extends SparkFunSuite with LocalSparkContext
val WAIT_TIMEOUT_MILLIS = 10000
before {
- sc = new SparkContext("local-cluster[2,1,512]", "SparkListenerSuite")
+ sc = new SparkContext("local-cluster[2,1,1024]", "SparkListenerSuite")
}
test("SparkListener sends executor added message") {
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
index 7c1adc1aef1b6..9201d1e1f328b 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -24,11 +24,27 @@ import org.scalatest.BeforeAndAfter
import org.apache.spark._
import org.apache.spark.rdd.RDD
-import org.apache.spark.util.{TaskCompletionListenerException, TaskCompletionListener}
+import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException}
+import org.apache.spark.metrics.source.JvmSource
class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext {
+ test("provide metrics sources") {
+ val filePath = getClass.getClassLoader.getResource("test_metrics_config.properties").getFile
+ val conf = new SparkConf(loadDefaults = false)
+ .set("spark.metrics.conf", filePath)
+ sc = new SparkContext("local", "test", conf)
+ val rdd = sc.makeRDD(1 to 1)
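+ // The task should be able to look up at least one JvmSource under the "jvm" source name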
+ val result = sc.runJob(rdd, (tc: TaskContext, it: Iterator[Int]) => {
+ tc.getMetricsSources("jvm").count {
+ case source: JvmSource => true
+ case _ => false
+ }
+ }).sum
+ assert(result > 0)
+ }
+
test("calls TaskCompletionListener after failure") {
TaskContextSuite.completed = false
sc = new SparkContext("local", "test")
@@ -41,16 +57,16 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
}
val closureSerializer = SparkEnv.get.closureSerializer.newInstance()
val func = (c: TaskContext, i: Iterator[String]) => i.next()
- val task = new ResultTask[String, String](
- 0, sc.broadcast(closureSerializer.serialize((rdd, func)).array), rdd.partitions(0), Seq(), 0)
+ val task = new ResultTask[String, String](0, 0,
+ sc.broadcast(closureSerializer.serialize((rdd, func)).array), rdd.partitions(0), Seq(), 0)
intercept[RuntimeException] {
- task.run(0, 0)
+ task.run(0, 0, null)
}
assert(TaskContextSuite.completed === true)
}
test("all TaskCompletionListeners should be called even if some fail") {
- val context = new TaskContextImpl(0, 0, 0, 0, null)
+ val context = new TaskContextImpl(0, 0, 0, 0, null, null)
val listener = mock(classOf[TaskCompletionListener])
context.addTaskCompletionListener(_ => throw new Exception("blah"))
context.addTaskCompletionListener(listener)
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
index a6d5232feb8de..c2edd4c317d6e 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
@@ -33,7 +33,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L
val taskScheduler = new TaskSchedulerImpl(sc)
taskScheduler.initialize(new FakeSchedulerBackend)
// Need to initialize a DAGScheduler for the taskScheduler to use for callbacks.
- val dagScheduler = new DAGScheduler(sc, taskScheduler) {
+ new DAGScheduler(sc, taskScheduler) {
override def taskStarted(task: Task[_], taskInfo: TaskInfo) {}
override def executorAdded(execId: String, host: String) {}
}
@@ -67,7 +67,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L
val taskScheduler = new TaskSchedulerImpl(sc)
taskScheduler.initialize(new FakeSchedulerBackend)
// Need to initialize a DAGScheduler for the taskScheduler to use for callbacks.
- val dagScheduler = new DAGScheduler(sc, taskScheduler) {
+ new DAGScheduler(sc, taskScheduler) {
override def taskStarted(task: Task[_], taskInfo: TaskInfo) {}
override def executorAdded(execId: String, host: String) {}
}
@@ -128,4 +128,113 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L
assert(taskDescriptions.map(_.executorId) === Seq("executor0"))
}
+ test("refuse to schedule concurrent attempts for the same stage (SPARK-8103)") {
+ sc = new SparkContext("local", "TaskSchedulerImplSuite")
+ val taskScheduler = new TaskSchedulerImpl(sc)
+ taskScheduler.initialize(new FakeSchedulerBackend)
+ // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks.
+ val dagScheduler = new DAGScheduler(sc, taskScheduler) {
+ override def taskStarted(task: Task[_], taskInfo: TaskInfo) {}
+ override def executorAdded(execId: String, host: String) {}
+ }
+ taskScheduler.setDAGScheduler(dagScheduler)
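+ // Two task sets for the same stage but with different stage attempt ids (0 and 1)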
+ val attempt1 = FakeTask.createTaskSet(1, 0)
+ val attempt2 = FakeTask.createTaskSet(1, 1)
+ taskScheduler.submitTasks(attempt1)
+ intercept[IllegalStateException] { taskScheduler.submitTasks(attempt2) }
+
+ // OK to submit multiple if previous attempts are all zombie
+ taskScheduler.taskSetManagerForAttempt(attempt1.stageId, attempt1.stageAttemptId)
+ .get.isZombie = true
+ taskScheduler.submitTasks(attempt2)
+ val attempt3 = FakeTask.createTaskSet(1, 2)
+ intercept[IllegalStateException] { taskScheduler.submitTasks(attempt3) }
+ taskScheduler.taskSetManagerForAttempt(attempt2.stageId, attempt2.stageAttemptId)
+ .get.isZombie = true
+ taskScheduler.submitTasks(attempt3)
+ }
+
+ test("don't schedule more tasks after a taskset is zombie") {
+ sc = new SparkContext("local", "TaskSchedulerImplSuite")
+ val taskScheduler = new TaskSchedulerImpl(sc)
+ taskScheduler.initialize(new FakeSchedulerBackend)
+ // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks.
+ new DAGScheduler(sc, taskScheduler) {
+ override def taskStarted(task: Task[_], taskInfo: TaskInfo) {}
+ override def executorAdded(execId: String, host: String) {}
+ }
+
+ val numFreeCores = 1
+ val workerOffers = Seq(new WorkerOffer("executor0", "host0", numFreeCores))
+ val attempt1 = FakeTask.createTaskSet(10)
+
+ // submit attempt 1, offer some resources, some tasks get scheduled
+ taskScheduler.submitTasks(attempt1)
+ val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten
+ assert(1 === taskDescriptions.length)
+
+ // now mark attempt 1 as a zombie
+ taskScheduler.taskSetManagerForAttempt(attempt1.stageId, attempt1.stageAttemptId)
+ .get.isZombie = true
+
+ // don't schedule anything on another resource offer
+ val taskDescriptions2 = taskScheduler.resourceOffers(workerOffers).flatten
+ assert(0 === taskDescriptions2.length)
+
+ // if we schedule another attempt for the same stage, it should get scheduled
+ val attempt2 = FakeTask.createTaskSet(10, 1)
+
+ // submit attempt 2, offer some resources, some tasks get scheduled
+ taskScheduler.submitTasks(attempt2)
+ val taskDescriptions3 = taskScheduler.resourceOffers(workerOffers).flatten
+ assert(1 === taskDescriptions3.length)
+ val mgr = taskScheduler.taskIdToTaskSetManager.get(taskDescriptions3(0).taskId).get
+ assert(mgr.taskSet.stageAttemptId === 1)
+ }
+
+ test("if a zombie attempt finishes, continue scheduling tasks for non-zombie attempts") {
+ sc = new SparkContext("local", "TaskSchedulerImplSuite")
+ val taskScheduler = new TaskSchedulerImpl(sc)
+ taskScheduler.initialize(new FakeSchedulerBackend)
+ // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks.
+ new DAGScheduler(sc, taskScheduler) {
+ override def taskStarted(task: Task[_], taskInfo: TaskInfo) {}
+ override def executorAdded(execId: String, host: String) {}
+ }
+
+ val numFreeCores = 10
+ val workerOffers = Seq(new WorkerOffer("executor0", "host0", numFreeCores))
+ val attempt1 = FakeTask.createTaskSet(10)
+
+ // submit attempt 1, offer some resources, some tasks get scheduled
+ taskScheduler.submitTasks(attempt1)
+ val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten
+ assert(10 === taskDescriptions.length)
+
+ // now mark attempt 1 as a zombie
+ val mgr1 = taskScheduler.taskSetManagerForAttempt(attempt1.stageId, attempt1.stageAttemptId).get
+ mgr1.isZombie = true
+
+ // don't schedule anything on another resource offer
+ val taskDescriptions2 = taskScheduler.resourceOffers(workerOffers).flatten
+ assert(0 === taskDescriptions2.length)
+
+ // submit attempt 2
+ val attempt2 = FakeTask.createTaskSet(10, 1)
+ taskScheduler.submitTasks(attempt2)
+
+ // attempt 1 finishes (this can happen even though it was marked zombie earlier -- all of its
+ // tasks were already submitted, and they then complete)
+ taskScheduler.taskSetFinished(mgr1)
+
+ // now with another resource offer, we should still schedule all the tasks in attempt2
+ val taskDescriptions3 = taskScheduler.resourceOffers(workerOffers).flatten
+ assert(10 === taskDescriptions3.length)
+
+ taskDescriptions3.foreach { task =>
+ val mgr = taskScheduler.taskIdToTaskSetManager.get(task.taskId).get
+ assert(mgr.taskSet.stageAttemptId === 1)
+ }
+ }
+
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index 0060f3396dcde..3abb99c4b2b54 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -19,12 +19,13 @@ package org.apache.spark.scheduler
import java.util.Random
-import scala.collection.mutable.ArrayBuffer
+import scala.collection.Map
import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
import org.apache.spark._
import org.apache.spark.executor.TaskMetrics
-import org.apache.spark.util.{ManualClock, Utils}
+import org.apache.spark.util.ManualClock
class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler)
extends DAGScheduler(sc) {
@@ -37,7 +38,7 @@ class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler)
task: Task[_],
reason: TaskEndReason,
result: Any,
- accumUpdates: mutable.Map[Long, Any],
+ accumUpdates: Map[Long, Any],
taskInfo: TaskInfo,
taskMetrics: TaskMetrics) {
taskScheduler.endedTasks(taskInfo.index) = reason
@@ -135,7 +136,7 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex
/**
* A Task implementation that results in a large serialized task.
*/
-class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0) {
+class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0, 0) {
val randomBuffer = new Array[Byte](TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024)
val random = new Random(0)
random.nextBytes(randomBuffer)
diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala
index 3f1692917a357..4b504df7b8851 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala
@@ -22,7 +22,7 @@ import java.util.Collections
import org.apache.mesos.Protos.Value.Scalar
import org.apache.mesos.Protos._
-import org.apache.mesos.SchedulerDriver
+import org.apache.mesos.{Protos, Scheduler, SchedulerDriver}
import org.mockito.Matchers._
import org.mockito.Mockito._
import org.mockito.Matchers
@@ -60,7 +60,16 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite
taskScheduler: TaskSchedulerImpl,
driver: SchedulerDriver): CoarseMesosSchedulerBackend = {
val backend = new CoarseMesosSchedulerBackend(taskScheduler, sc, "master") {
- mesosDriver = driver
+ override protected def createSchedulerDriver(
+ masterUrl: String,
+ scheduler: Scheduler,
+ sparkUser: String,
+ appName: String,
+ conf: SparkConf,
+ webuiUrl: Option[String] = None,
+ checkpoint: Option[Boolean] = None,
+ failoverTimeout: Option[Double] = None,
+ frameworkId: Option[String] = None): SchedulerDriver = driver
markRegistered()
}
backend.start()
@@ -80,6 +89,7 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite
test("mesos supports killing and limiting executors") {
val driver = mock[SchedulerDriver]
+ when(driver.start()).thenReturn(Protos.Status.DRIVER_RUNNING)
val taskScheduler = mock[TaskSchedulerImpl]
when(taskScheduler.sc).thenReturn(sc)
@@ -87,7 +97,7 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite
sparkConf.set("spark.driver.port", "1234")
val backend = createSchedulerBackend(taskScheduler, driver)
- val minMem = backend.calculateTotalMemory(sc).toInt
+ val minMem = backend.calculateTotalMemory(sc)
val minCpu = 4
val mesosOffers = new java.util.ArrayList[Offer]
@@ -130,11 +140,12 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite
test("mesos supports killing and relaunching tasks with executors") {
val driver = mock[SchedulerDriver]
+ when(driver.start()).thenReturn(Protos.Status.DRIVER_RUNNING)
val taskScheduler = mock[TaskSchedulerImpl]
when(taskScheduler.sc).thenReturn(sc)
val backend = createSchedulerBackend(taskScheduler, driver)
- val minMem = backend.calculateTotalMemory(sc).toInt + 1024
+ val minMem = backend.calculateTotalMemory(sc) + 1024
val minCpu = 4
val mesosOffers = new java.util.ArrayList[Offer]
diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala
index d01837fe78957..5ed30f64d705f 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala
@@ -21,6 +21,7 @@ import java.nio.ByteBuffer
import java.util
import java.util.Collections
+import scala.collection.JavaConversions._
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
@@ -60,14 +61,17 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi
val mesosSchedulerBackend = new MesosSchedulerBackend(taskScheduler, sc, "master")
+ val resources = List(
+ mesosSchedulerBackend.createResource("cpus", 4),
+ mesosSchedulerBackend.createResource("mem", 1024))
// uri is null.
- val executorInfo = mesosSchedulerBackend.createExecutorInfo("test-id")
+ val (executorInfo, _) = mesosSchedulerBackend.createExecutorInfo(resources, "test-id")
assert(executorInfo.getCommand.getValue ===
s" /mesos-home/bin/spark-class ${classOf[MesosExecutorBackend].getName}")
// uri exists.
conf.set("spark.executor.uri", "hdfs:///test-app-1.0.0.tgz")
- val executorInfo1 = mesosSchedulerBackend.createExecutorInfo("test-id")
+ val (executorInfo1, _) = mesosSchedulerBackend.createExecutorInfo(resources, "test-id")
assert(executorInfo1.getCommand.getValue ===
s"cd test-app-1*; ./bin/spark-class ${classOf[MesosExecutorBackend].getName}")
}
@@ -93,7 +97,8 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi
val backend = new MesosSchedulerBackend(taskScheduler, sc, "master")
- val execInfo = backend.createExecutorInfo("mockExecutor")
+ val (execInfo, _) = backend.createExecutorInfo(
+ List(backend.createResource("cpus", 4)), "mockExecutor")
assert(execInfo.getContainer.getDocker.getImage.equals("spark/mock"))
val portmaps = execInfo.getContainer.getDocker.getPortMappingsList
assert(portmaps.get(0).getHostPort.equals(80))
@@ -194,7 +199,7 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi
)
verify(driver, times(1)).declineOffer(mesosOffers.get(1).getId)
verify(driver, times(1)).declineOffer(mesosOffers.get(2).getId)
- assert(capture.getValue.size() == 1)
+ assert(capture.getValue.size() === 1)
val taskInfo = capture.getValue.iterator().next()
assert(taskInfo.getName.equals("n1"))
val cpus = taskInfo.getResourcesList.get(0)
@@ -214,4 +219,97 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi
backend.resourceOffers(driver, mesosOffers2)
verify(driver, times(1)).declineOffer(mesosOffers2.get(0).getId)
}
+
+ test("can handle multiple roles") {
+ val driver = mock[SchedulerDriver]
+ val taskScheduler = mock[TaskSchedulerImpl]
+
+ val listenerBus = mock[LiveListenerBus]
+ listenerBus.post(
+ SparkListenerExecutorAdded(anyLong, "s1", new ExecutorInfo("host1", 2, Map.empty)))
+
+ val sc = mock[SparkContext]
+ when(sc.executorMemory).thenReturn(100)
+ when(sc.getSparkHome()).thenReturn(Option("/path"))
+ when(sc.executorEnvs).thenReturn(new mutable.HashMap[String, String])
+ when(sc.conf).thenReturn(new SparkConf)
+ when(sc.listenerBus).thenReturn(listenerBus)
+
+ val id = 1
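+ // Build an offer that provides "mem" and "cpus" resources under two roles: "prod" and "dev"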
+ val builder = Offer.newBuilder()
+ builder.addResourcesBuilder()
+ .setName("mem")
+ .setType(Value.Type.SCALAR)
+ .setRole("prod")
+ .setScalar(Scalar.newBuilder().setValue(500))
+ builder.addResourcesBuilder()
+ .setName("cpus")
+ .setRole("prod")
+ .setType(Value.Type.SCALAR)
+ .setScalar(Scalar.newBuilder().setValue(1))
+ builder.addResourcesBuilder()
+ .setName("mem")
+ .setRole("dev")
+ .setType(Value.Type.SCALAR)
+ .setScalar(Scalar.newBuilder().setValue(600))
+ builder.addResourcesBuilder()
+ .setName("cpus")
+ .setRole("dev")
+ .setType(Value.Type.SCALAR)
+ .setScalar(Scalar.newBuilder().setValue(2))
+ val offer = builder.setId(OfferID.newBuilder().setValue(s"o${id.toString}").build())
+ .setFrameworkId(FrameworkID.newBuilder().setValue("f1"))
+ .setSlaveId(SlaveID.newBuilder().setValue(s"s${id.toString}"))
+ .setHostname(s"host${id.toString}").build()
+
+ val mesosOffers = new java.util.ArrayList[Offer]
+ mesosOffers.add(offer)
+
+ val backend = new MesosSchedulerBackend(taskScheduler, sc, "master")
+
+ val expectedWorkerOffers = new ArrayBuffer[WorkerOffer](1)
+ expectedWorkerOffers.append(new WorkerOffer(
+ mesosOffers.get(0).getSlaveId.getValue,
+ mesosOffers.get(0).getHostname,
+ 2 // Deducting 1 CPU for the executor
+ ))
+
+ val taskDesc = new TaskDescription(1L, 0, "s1", "n1", 0, ByteBuffer.wrap(new Array[Byte](0)))
+ when(taskScheduler.resourceOffers(expectedWorkerOffers)).thenReturn(Seq(Seq(taskDesc)))
+ when(taskScheduler.CPUS_PER_TASK).thenReturn(1)
+
+ val capture = ArgumentCaptor.forClass(classOf[util.Collection[TaskInfo]])
+ when(
+ driver.launchTasks(
+ Matchers.eq(Collections.singleton(mesosOffers.get(0).getId)),
+ capture.capture(),
+ any(classOf[Filters])
+ )
+ ).thenReturn(Status.valueOf(1))
+
+ backend.resourceOffers(driver, mesosOffers)
+
+ verify(driver, times(1)).launchTasks(
+ Matchers.eq(Collections.singleton(mesosOffers.get(0).getId)),
+ capture.capture(),
+ any(classOf[Filters])
+ )
+
+ assert(capture.getValue.size() === 1)
+ val taskInfo = capture.getValue.iterator().next()
+ assert(taskInfo.getName.equals("n1"))
+ assert(taskInfo.getResourcesCount === 1)
+ val cpusDev = taskInfo.getResourcesList.get(0)
+ assert(cpusDev.getName.equals("cpus"))
+ assert(cpusDev.getScalar.getValue.equals(1.0))
+ assert(cpusDev.getRole.equals("dev"))
+ val executorResources = taskInfo.getExecutor.getResourcesList
+ assert(executorResources.exists { r =>
+ r.getName.equals("mem") && r.getScalar.getValue.equals(484.0) && r.getRole.equals("prod")
+ })
+ assert(executorResources.exists { r =>
+ r.getName.equals("cpus") && r.getScalar.getValue.equals(1.0) && r.getRole.equals("prod")
+ })
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala
index 63a8480c9b57b..935a091f14f9b 100644
--- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala
+++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala
@@ -35,7 +35,7 @@ class KryoSerializerDistributedSuite extends SparkFunSuite {
val jar = TestUtils.createJarWithClasses(List(AppJarRegistrator.customClassName))
conf.setJars(List(jar.getPath))
- val sc = new SparkContext("local-cluster[2,1,512]", "test", conf)
+ val sc = new SparkContext("local-cluster[2,1,1024]", "test", conf)
val original = Thread.currentThread.getContextClassLoader
val loader = new java.net.URLClassLoader(Array(jar), Utils.getContextOrSparkClassLoader)
SparkEnv.get.serializer.setDefaultClassLoader(loader)
@@ -59,7 +59,9 @@ object KryoDistributedTest {
class AppJarRegistrator extends KryoRegistrator {
override def registerClasses(k: Kryo) {
val classLoader = Thread.currentThread.getContextClassLoader
+ // scalastyle:off classforname
k.register(Class.forName(AppJarRegistrator.customClassName, true, classLoader))
+ // scalastyle:on classforname
}
}
diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala
index 28ca68698e3dc..db718ecabbdb9 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala
@@ -115,11 +115,15 @@ class HashShuffleReaderSuite extends SparkFunSuite with LocalSparkContext {
// Make a mocked MapOutputTracker for the shuffle reader to use to determine what
// shuffle data to read.
val mapOutputTracker = mock(classOf[MapOutputTracker])
- // Test a scenario where all data is local, just to avoid creating a bunch of additional mocks
- // for the code to read data over the network.
- val statuses: Array[(BlockManagerId, Long)] =
- Array.fill(numMaps)((localBlockManagerId, byteOutputStream.size().toLong))
- when(mapOutputTracker.getServerStatuses(shuffleId, reduceId)).thenReturn(statuses)
+ when(mapOutputTracker.getMapSizesByExecutorId(shuffleId, reduceId)).thenReturn {
+ // Test a scenario where all data is local, to avoid creating a bunch of additional mocks
+ // for the code to read data over the network.
+ val shuffleBlockIdsAndSizes = (0 until numMaps).map { mapId =>
+ val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId)
+ (shuffleBlockId, byteOutputStream.size().toLong)
+ }
+ Seq((localBlockManagerId, shuffleBlockIdsAndSizes))
+ }
// Create a mocked shuffle handle to pass into HashShuffleReader.
val shuffleHandle = {
@@ -134,7 +138,7 @@ class HashShuffleReaderSuite extends SparkFunSuite with LocalSparkContext {
shuffleHandle,
reduceId,
reduceId + 1,
- new TaskContextImpl(0, 0, 0, 0, null),
+ new TaskContextImpl(0, 0, 0, 0, null, null),
blockManager,
mapOutputTracker)
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
index 542f8f45125a4..cc7342f1ecd78 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
@@ -68,8 +68,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
any[SerializerInstance],
anyInt(),
any[ShuffleWriteMetrics]
- )).thenAnswer(new Answer[BlockObjectWriter] {
- override def answer(invocation: InvocationOnMock): BlockObjectWriter = {
+ )).thenAnswer(new Answer[DiskBlockObjectWriter] {
+ override def answer(invocation: InvocationOnMock): DiskBlockObjectWriter = {
val args = invocation.getArguments
new DiskBlockObjectWriter(
args(0).asInstanceOf[BlockId],
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockStatusListenerSuite.scala
new file mode 100644
index 0000000000000..d7ffde1e7864e
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/storage/BlockStatusListenerSuite.scala
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.scheduler._
+
+class BlockStatusListenerSuite extends SparkFunSuite {
+
+ test("basic functions") {
+ val blockManagerId = BlockManagerId("0", "localhost", 10000)
+ val listener = new BlockStatusListener()
+
+ // Add a block manager and a new block status
+ listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(0, blockManagerId, 0))
+ listener.onBlockUpdated(SparkListenerBlockUpdated(
+ BlockUpdatedInfo(
+ blockManagerId,
+ StreamBlockId(0, 100),
+ StorageLevel.MEMORY_AND_DISK,
+ memSize = 100,
+ diskSize = 100,
+ externalBlockStoreSize = 0)))
+ // The new block status should be added to the listener
+ val expectedBlock = BlockUIData(
+ StreamBlockId(0, 100),
+ "localhost:10000",
+ StorageLevel.MEMORY_AND_DISK,
+ memSize = 100,
+ diskSize = 100,
+ externalBlockStoreSize = 0
+ )
+ val expectedExecutorStreamBlockStatus = Seq(
+ ExecutorStreamBlockStatus("0", "localhost:10000", Seq(expectedBlock))
+ )
+ assert(listener.allExecutorStreamBlockStatus === expectedExecutorStreamBlockStatus)
+
+ // Add the second block manager
+ val blockManagerId2 = BlockManagerId("1", "localhost", 10001)
+ listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(0, blockManagerId2, 0))
+ // Add a new replica of the same block id from the second block manager
+ listener.onBlockUpdated(SparkListenerBlockUpdated(
+ BlockUpdatedInfo(
+ blockManagerId2,
+ StreamBlockId(0, 100),
+ StorageLevel.MEMORY_AND_DISK,
+ memSize = 100,
+ diskSize = 100,
+ externalBlockStoreSize = 0)))
+ val expectedBlock2 = BlockUIData(
+ StreamBlockId(0, 100),
+ "localhost:10001",
+ StorageLevel.MEMORY_AND_DISK,
+ memSize = 100,
+ diskSize = 100,
+ externalBlockStoreSize = 0
+ )
+ // Each block manager should contain one block
+ val expectedExecutorStreamBlockStatus2 = Set(
+ ExecutorStreamBlockStatus("0", "localhost:10000", Seq(expectedBlock)),
+ ExecutorStreamBlockStatus("1", "localhost:10001", Seq(expectedBlock2))
+ )
+ assert(listener.allExecutorStreamBlockStatus.toSet === expectedExecutorStreamBlockStatus2)
+
+ // Remove one replica of the same block
+ listener.onBlockUpdated(SparkListenerBlockUpdated(
+ BlockUpdatedInfo(
+ blockManagerId2,
+ StreamBlockId(0, 100),
+ StorageLevel.NONE, // StorageLevel.NONE indicates the block was removed
+ memSize = 0,
+ diskSize = 0,
+ externalBlockStoreSize = 0)))
+ // Only the first block manager contains a block
+ val expectedExecutorStreamBlockStatus3 = Set(
+ ExecutorStreamBlockStatus("0", "localhost:10000", Seq(expectedBlock)),
+ ExecutorStreamBlockStatus("1", "localhost:10001", Seq.empty)
+ )
+ assert(listener.allExecutorStreamBlockStatus.toSet === expectedExecutorStreamBlockStatus3)
+
+ // Remove the second block manager first, then report a new block status
+ // from the removed block manager
+ listener.onBlockManagerRemoved(SparkListenerBlockManagerRemoved(0, blockManagerId2))
+ listener.onBlockUpdated(SparkListenerBlockUpdated(
+ BlockUpdatedInfo(
+ blockManagerId2,
+ StreamBlockId(0, 100),
+ StorageLevel.MEMORY_AND_DISK,
+ memSize = 100,
+ diskSize = 100,
+ externalBlockStoreSize = 0)))
+ // The second block manager has been removed, so we should not see the new block
+ val expectedExecutorStreamBlockStatus4 = Seq(
+ ExecutorStreamBlockStatus("0", "localhost:10000", Seq(expectedBlock))
+ )
+ assert(listener.allExecutorStreamBlockStatus === expectedExecutorStreamBlockStatus4)
+
+ // Remove the last block manager
+ listener.onBlockManagerRemoved(SparkListenerBlockManagerRemoved(0, blockManagerId))
+ // No block managers remain, so all executor stream block statuses should be dropped
+ assert(listener.allExecutorStreamBlockStatus.isEmpty)
+ }
+
+}
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala
similarity index 98%
rename from core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala
rename to core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala
index 7bdea724fea58..66af6e1a79740 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.util.Utils
-class BlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
+class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
var tempDir: File = _
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index 9ced4148d7206..cf8bd8ae69625 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -33,6 +33,7 @@ import org.apache.spark.{SparkFunSuite, TaskContextImpl}
import org.apache.spark.network._
import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.network.shuffle.BlockFetchingListener
+import org.apache.spark.shuffle.FetchFailedException
class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodTester {
@@ -94,7 +95,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
)
val iterator = new ShuffleBlockFetcherIterator(
- new TaskContextImpl(0, 0, 0, 0, null),
+ new TaskContextImpl(0, 0, 0, 0, null, null),
transfer,
blockManager,
blocksByAddress,
@@ -106,13 +107,11 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
for (i <- 0 until 5) {
assert(iterator.hasNext, s"iterator should have 5 elements but actually has $i elements")
val (blockId, inputStream) = iterator.next()
- assert(inputStream.isSuccess,
- s"iterator should have 5 elements defined but actually has $i elements")
// Make sure we release buffers when a wrapped input stream is closed.
val mockBuf = localBlocks.getOrElse(blockId, remoteBlocks(blockId))
// Note: ShuffleBlockFetcherIterator wraps input streams in a BufferReleasingInputStream
- val wrappedInputStream = inputStream.get.asInstanceOf[BufferReleasingInputStream]
+ val wrappedInputStream = inputStream.asInstanceOf[BufferReleasingInputStream]
verify(mockBuf, times(0)).release()
val delegateAccess = PrivateMethod[InputStream]('delegate)
@@ -166,7 +165,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
(remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq))
- val taskContext = new TaskContextImpl(0, 0, 0, 0, null)
+ val taskContext = new TaskContextImpl(0, 0, 0, 0, null, null)
val iterator = new ShuffleBlockFetcherIterator(
taskContext,
transfer,
@@ -175,11 +174,11 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
48 * 1024 * 1024)
verify(blocks(ShuffleBlockId(0, 0, 0)), times(0)).release()
- iterator.next()._2.get.close() // close() first block's input stream
+ iterator.next()._2.close() // close() first block's input stream
verify(blocks(ShuffleBlockId(0, 0, 0)), times(1)).release()
// Get the 2nd block but do not exhaust the iterator
- val subIter = iterator.next()._2.get
+ val subIter = iterator.next()._2
// Complete the task; then the 2nd block buffer should be exhausted
verify(blocks(ShuffleBlockId(0, 1, 0)), times(0)).release()
@@ -228,7 +227,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
(remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq))
- val taskContext = new TaskContextImpl(0, 0, 0, 0, null)
+ val taskContext = new TaskContextImpl(0, 0, 0, 0, null, null)
val iterator = new ShuffleBlockFetcherIterator(
taskContext,
transfer,
@@ -239,9 +238,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
// Continue only after the mock calls onBlockFetchFailure
sem.acquire()
- // The first block should be defined, and the last two are not defined (due to failure)
- assert(iterator.next()._2.isSuccess)
- assert(iterator.next()._2.isFailure)
- assert(iterator.next()._2.isFailure)
+ // The first block should be returned without an exception, and the last two should throw
+ // FetchFailedExceptions (due to failure)
+ iterator.next()
+ intercept[FetchFailedException] { iterator.next() }
+ intercept[FetchFailedException] { iterator.next() }
}
}
diff --git a/core/src/test/scala/org/apache/spark/ui/PagedTableSuite.scala b/core/src/test/scala/org/apache/spark/ui/PagedTableSuite.scala
new file mode 100644
index 0000000000000..cc76c141c53cc
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/ui/PagedTableSuite.scala
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui
+
+import scala.xml.Node
+
+import org.apache.spark.SparkFunSuite
+
+class PagedDataSourceSuite extends SparkFunSuite {
+
+ test("basic") {
+ val dataSource1 = new SeqPagedDataSource[Int](1 to 5, pageSize = 2)
+ assert(dataSource1.pageData(1) === PageData(3, (1 to 2)))
+
+ val dataSource2 = new SeqPagedDataSource[Int](1 to 5, pageSize = 2)
+ assert(dataSource2.pageData(2) === PageData(3, (3 to 4)))
+
+ val dataSource3 = new SeqPagedDataSource[Int](1 to 5, pageSize = 2)
+ assert(dataSource3.pageData(3) === PageData(3, Seq(5)))
+
+ val dataSource4 = new SeqPagedDataSource[Int](1 to 5, pageSize = 2)
+ val e1 = intercept[IndexOutOfBoundsException] {
+ dataSource4.pageData(4)
+ }
+ assert(e1.getMessage === "Page 4 is out of range. Please select a page number between 1 and 3.")
+
+ val dataSource5 = new SeqPagedDataSource[Int](1 to 5, pageSize = 2)
+ val e2 = intercept[IndexOutOfBoundsException] {
+ dataSource5.pageData(0)
+ }
+ assert(e2.getMessage === "Page 0 is out of range. Please select a page number between 1 and 3.")
+ }
+}
+
+class PagedTableSuite extends SparkFunSuite {
+ test("pageNavigation") {
+ // Create a fake PagedTable to test pageNavigation
+ val pagedTable = new PagedTable[Int] {
+ override def tableId: String = ""
+
+ override def tableCssClass: String = ""
+
+ override def dataSource: PagedDataSource[Int] = null
+
+ override def pageLink(page: Int): String = page.toString
+
+ override def headers: Seq[Node] = Nil
+
+ override def row(t: Int): Seq[Node] = Nil
+
+ override def goButtonJavascriptFunction: (String, String) = ("", "")
+ }
+
+ assert(pagedTable.pageNavigation(1, 10, 1) === Nil)
+ assert(
+ (pagedTable.pageNavigation(1, 10, 2).head \\ "li").map(_.text.trim) === Seq("1", "2", ">"))
+ assert(
+ (pagedTable.pageNavigation(2, 10, 2).head \\ "li").map(_.text.trim) === Seq("<", "1", "2"))
+
+ assert((pagedTable.pageNavigation(1, 10, 100).head \\ "li").map(_.text.trim) ===
+ (1 to 10).map(_.toString) ++ Seq(">", ">>"))
+ assert((pagedTable.pageNavigation(2, 10, 100).head \\ "li").map(_.text.trim) ===
+ Seq("<") ++ (1 to 10).map(_.toString) ++ Seq(">", ">>"))
+
+ assert((pagedTable.pageNavigation(100, 10, 100).head \\ "li").map(_.text.trim) ===
+ Seq("<<", "<") ++ (91 to 100).map(_.toString))
+ assert((pagedTable.pageNavigation(99, 10, 100).head \\ "li").map(_.text.trim) ===
+ Seq("<<", "<") ++ (91 to 100).map(_.toString) ++ Seq(">"))
+
+ assert((pagedTable.pageNavigation(11, 10, 100).head \\ "li").map(_.text.trim) ===
+ Seq("<<", "<") ++ (11 to 20).map(_.toString) ++ Seq(">", ">>"))
+ assert((pagedTable.pageNavigation(93, 10, 97).head \\ "li").map(_.text.trim) ===
+ Seq("<<", "<") ++ (91 to 97).map(_.toString) ++ Seq(">"))
+ }
+}
+
+private[spark] class SeqPagedDataSource[T](seq: Seq[T], pageSize: Int)
+ extends PagedDataSource[T](pageSize) {
+
+ override protected def dataSize: Int = seq.size
+
+ override protected def sliceData(from: Int, to: Int): Seq[T] = seq.slice(from, to)
+}
diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala
new file mode 100644
index 0000000000000..3dab15a9d4691
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala
@@ -0,0 +1,230 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui.storage
+
+import scala.xml.Utility
+
+import org.mockito.Mockito._
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.storage._
+
+class StoragePageSuite extends SparkFunSuite {
+
+ val storageTab = mock(classOf[StorageTab])
+ when(storageTab.basePath).thenReturn("http://localhost:4040")
+ val storagePage = new StoragePage(storageTab)
+
+ test("rddTable") {
+ val rdd1 = new RDDInfo(1,
+ "rdd1",
+ 10,
+ StorageLevel.MEMORY_ONLY,
+ Seq.empty)
+ rdd1.memSize = 100
+ rdd1.numCachedPartitions = 10
+
+ val rdd2 = new RDDInfo(2,
+ "rdd2",
+ 10,
+ StorageLevel.DISK_ONLY,
+ Seq.empty)
+ rdd2.diskSize = 200
+ rdd2.numCachedPartitions = 5
+
+ val rdd3 = new RDDInfo(3,
+ "rdd3",
+ 10,
+ StorageLevel.MEMORY_AND_DISK_SER,
+ Seq.empty)
+ rdd3.memSize = 400
+ rdd3.diskSize = 500
+ rdd3.numCachedPartitions = 10
+
+ val xmlNodes = storagePage.rddTable(Seq(rdd1, rdd2, rdd3))
+
+ val headers = Seq(
+ "RDD Name",
+ "Storage Level",
+ "Cached Partitions",
+ "Fraction Cached",
+ "Size in Memory",
+ "Size in ExternalBlockStore",
+ "Size on Disk")
+ assert((xmlNodes \\ "th").map(_.text) === headers)
+
+ assert((xmlNodes \\ "tr").size === 3)
+ assert(((xmlNodes \\ "tr")(0) \\ "td").map(_.text.trim) ===
+ Seq("rdd1", "Memory Deserialized 1x Replicated", "10", "100%", "100.0 B", "0.0 B", "0.0 B"))
+ // Check the url
+ assert(((xmlNodes \\ "tr")(0) \\ "td" \ "a")(0).attribute("href").map(_.text) ===
+ Some("http://localhost:4040/storage/rdd?id=1"))
+
+ assert(((xmlNodes \\ "tr")(1) \\ "td").map(_.text.trim) ===
+ Seq("rdd2", "Disk Serialized 1x Replicated", "5", "50%", "0.0 B", "0.0 B", "200.0 B"))
+ // Check the URL
+ assert(((xmlNodes \\ "tr")(1) \\ "td" \ "a")(0).attribute("href").map(_.text) ===
+ Some("http://localhost:4040/storage/rdd?id=2"))
+
+ assert(((xmlNodes \\ "tr")(2) \\ "td").map(_.text.trim) ===
+ Seq("rdd3", "Disk Memory Serialized 1x Replicated", "10", "100%", "400.0 B", "0.0 B",
+ "500.0 B"))
+ // Check the URL
+ assert(((xmlNodes \\ "tr")(2) \\ "td" \ "a")(0).attribute("href").map(_.text) ===
+ Some("http://localhost:4040/storage/rdd?id=3"))
+ }
+
+ test("empty rddTable") {
+ assert(storagePage.rddTable(Seq.empty).isEmpty)
+ }
+
+ test("streamBlockStorageLevelDescriptionAndSize") {
+ val memoryBlock = BlockUIData(StreamBlockId(0, 0),
+ "localhost:1111",
+ StorageLevel.MEMORY_ONLY,
+ memSize = 100,
+ diskSize = 0,
+ externalBlockStoreSize = 0)
+ assert(("Memory", 100) === storagePage.streamBlockStorageLevelDescriptionAndSize(memoryBlock))
+
+ val memorySerializedBlock = BlockUIData(StreamBlockId(0, 0),
+ "localhost:1111",
+ StorageLevel.MEMORY_ONLY_SER,
+ memSize = 100,
+ diskSize = 0,
+ externalBlockStoreSize = 0)
+ assert(("Memory Serialized", 100) ===
+ storagePage.streamBlockStorageLevelDescriptionAndSize(memorySerializedBlock))
+
+ val diskBlock = BlockUIData(StreamBlockId(0, 0),
+ "localhost:1111",
+ StorageLevel.DISK_ONLY,
+ memSize = 0,
+ diskSize = 100,
+ externalBlockStoreSize = 0)
+ assert(("Disk", 100) === storagePage.streamBlockStorageLevelDescriptionAndSize(diskBlock))
+
+ val externalBlock = BlockUIData(StreamBlockId(0, 0),
+ "localhost:1111",
+ StorageLevel.OFF_HEAP,
+ memSize = 0,
+ diskSize = 0,
+ externalBlockStoreSize = 100)
+ assert(("External", 100) ===
+ storagePage.streamBlockStorageLevelDescriptionAndSize(externalBlock))
+ }
+
+ test("receiverBlockTables") {
+ val blocksForExecutor0 = Seq(
+ BlockUIData(StreamBlockId(0, 0),
+ "localhost:10000",
+ StorageLevel.MEMORY_ONLY,
+ memSize = 100,
+ diskSize = 0,
+ externalBlockStoreSize = 0),
+ BlockUIData(StreamBlockId(1, 1),
+ "localhost:10000",
+ StorageLevel.DISK_ONLY,
+ memSize = 0,
+ diskSize = 100,
+ externalBlockStoreSize = 0)
+ )
+ val executor0 = ExecutorStreamBlockStatus("0", "localhost:10000", blocksForExecutor0)
+
+ val blocksForExecutor1 = Seq(
+ BlockUIData(StreamBlockId(0, 0),
+ "localhost:10001",
+ StorageLevel.MEMORY_ONLY,
+ memSize = 100,
+ diskSize = 0,
+ externalBlockStoreSize = 0),
+ BlockUIData(StreamBlockId(2, 2),
+ "localhost:10001",
+ StorageLevel.OFF_HEAP,
+ memSize = 0,
+ diskSize = 0,
+ externalBlockStoreSize = 200),
+ BlockUIData(StreamBlockId(1, 1),
+ "localhost:10001",
+ StorageLevel.MEMORY_ONLY_SER,
+ memSize = 100,
+ diskSize = 0,
+ externalBlockStoreSize = 0)
+ )
+ val executor1 = ExecutorStreamBlockStatus("1", "localhost:10001", blocksForExecutor1)
+ val xmlNodes = storagePage.receiverBlockTables(Seq(executor0, executor1))
+
+ val executorTable = (xmlNodes \\ "table")(0)
+ val executorHeaders = Seq(
+ "Executor ID",
+ "Address",
+ "Total Size in Memory",
+ "Total Size in ExternalBlockStore",
+ "Total Size on Disk",
+ "Stream Blocks")
+ assert((executorTable \\ "th").map(_.text) === executorHeaders)
+
+ assert((executorTable \\ "tr").size === 2)
+ assert(((executorTable \\ "tr")(0) \\ "td").map(_.text.trim) ===
+ Seq("0", "localhost:10000", "100.0 B", "0.0 B", "100.0 B", "2"))
+ assert(((executorTable \\ "tr")(1) \\ "td").map(_.text.trim) ===
+ Seq("1", "localhost:10001", "200.0 B", "200.0 B", "0.0 B", "3"))
+
+ val blockTable = (xmlNodes \\ "table")(1)
+ val blockHeaders = Seq(
+ "Block ID",
+ "Replication Level",
+ "Location",
+ "Storage Level",
+ "Size")
+ assert((blockTable \\ "th").map(_.text) === blockHeaders)
+
+ assert((blockTable \\ "tr").size === 5)
+ assert(((blockTable \\ "tr")(0) \\ "td").map(_.text.trim) ===
+ Seq("input-0-0", "2", "localhost:10000", "Memory", "100.0 B"))
+ // Check "rowspan=2" for the first 2 columns
+ assert(((blockTable \\ "tr")(0) \\ "td")(0).attribute("rowspan").map(_.text) === Some("2"))
+ assert(((blockTable \\ "tr")(0) \\ "td")(1).attribute("rowspan").map(_.text) === Some("2"))
+
+ assert(((blockTable \\ "tr")(1) \\ "td").map(_.text.trim) ===
+ Seq("localhost:10001", "Memory", "100.0 B"))
+
+ assert(((blockTable \\ "tr")(2) \\ "td").map(_.text.trim) ===
+ Seq("input-1-1", "2", "localhost:10000", "Disk", "100.0 B"))
+ // Check "rowspan=2" for the first 2 columns
+ assert(((blockTable \\ "tr")(2) \\ "td")(0).attribute("rowspan").map(_.text) === Some("2"))
+ assert(((blockTable \\ "tr")(2) \\ "td")(1).attribute("rowspan").map(_.text) === Some("2"))
+
+ assert(((blockTable \\ "tr")(3) \\ "td").map(_.text.trim) ===
+ Seq("localhost:10001", "Memory Serialized", "100.0 B"))
+
+ assert(((blockTable \\ "tr")(4) \\ "td").map(_.text.trim) ===
+ Seq("input-2-2", "1", "localhost:10001", "External", "200.0 B"))
+ // Check "rowspan=1" for the first 2 columns
+ assert(((blockTable \\ "tr")(4) \\ "td")(0).attribute("rowspan").map(_.text) === Some("1"))
+ assert(((blockTable \\ "tr")(4) \\ "td")(1).attribute("rowspan").map(_.text) === Some("1"))
+ }
+
+ test("empty receiverBlockTables") {
+ assert(storagePage.receiverBlockTables(Seq.empty).isEmpty)
+
+ val executor0 = ExecutorStreamBlockStatus("0", "localhost:10000", Seq.empty)
+ val executor1 = ExecutorStreamBlockStatus("1", "localhost:10001", Seq.empty)
+ assert(storagePage.receiverBlockTables(Seq(executor0, executor1)).isEmpty)
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala
index 6c40685484ed4..61601016e005e 100644
--- a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.util
+import scala.collection.mutable.ArrayBuffer
+
import java.util.concurrent.TimeoutException
import akka.actor.ActorNotFound
@@ -24,7 +26,7 @@ import akka.actor.ActorNotFound
import org.apache.spark._
import org.apache.spark.rpc.RpcEnv
import org.apache.spark.scheduler.MapStatus
-import org.apache.spark.storage.BlockManagerId
+import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId}
import org.apache.spark.SSLSampleConfigs._
@@ -107,8 +109,9 @@ class AkkaUtilsSuite extends SparkFunSuite with LocalSparkContext with ResetSyst
slaveTracker.updateEpoch(masterTracker.getEpoch)
// this should succeed since security off
- assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
- Seq((BlockManagerId("a", "hostA", 1000), size1000)))
+ assert(slaveTracker.getMapSizesByExecutorId(10, 0).toSeq ===
+ Seq((BlockManagerId("a", "hostA", 1000),
+ ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000)))))
rpcEnv.shutdown()
slaveRpcEnv.shutdown()
@@ -153,8 +156,9 @@ class AkkaUtilsSuite extends SparkFunSuite with LocalSparkContext with ResetSyst
slaveTracker.updateEpoch(masterTracker.getEpoch)
// this should succeed since security on and passwords match
- assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
- Seq((BlockManagerId("a", "hostA", 1000), size1000)))
+ assert(slaveTracker.getMapSizesByExecutorId(10, 0) ===
+ Seq((BlockManagerId("a", "hostA", 1000),
+ ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000)))))
rpcEnv.shutdown()
slaveRpcEnv.shutdown()
@@ -232,8 +236,8 @@ class AkkaUtilsSuite extends SparkFunSuite with LocalSparkContext with ResetSyst
slaveTracker.updateEpoch(masterTracker.getEpoch)
// this should succeed since security off
- assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
- Seq((BlockManagerId("a", "hostA", 1000), size1000)))
+ assert(slaveTracker.getMapSizesByExecutorId(10, 0) ===
+ Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000)))))
rpcEnv.shutdown()
slaveRpcEnv.shutdown()
@@ -278,8 +282,8 @@ class AkkaUtilsSuite extends SparkFunSuite with LocalSparkContext with ResetSyst
masterTracker.incrementEpoch()
slaveTracker.updateEpoch(masterTracker.getEpoch)
- assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
- Seq((BlockManagerId("a", "hostA", 1000), size1000)))
+ assert(slaveTracker.getMapSizesByExecutorId(10, 0) ===
+ Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000)))))
rpcEnv.shutdown()
slaveRpcEnv.shutdown()
diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala
index 3147c937769d2..a829b099025e9 100644
--- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala
+++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala
@@ -120,8 +120,8 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri
// Accessors for private methods
private val _isClosure = PrivateMethod[Boolean]('isClosure)
private val _getInnerClosureClasses = PrivateMethod[List[Class[_]]]('getInnerClosureClasses)
- private val _getOuterClasses = PrivateMethod[List[Class[_]]]('getOuterClasses)
- private val _getOuterObjects = PrivateMethod[List[AnyRef]]('getOuterObjects)
+ private val _getOuterClassesAndObjects =
+ PrivateMethod[(List[Class[_]], List[AnyRef])]('getOuterClassesAndObjects)
private def isClosure(obj: AnyRef): Boolean = {
ClosureCleaner invokePrivate _isClosure(obj)
@@ -131,12 +131,8 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri
ClosureCleaner invokePrivate _getInnerClosureClasses(closure)
}
- private def getOuterClasses(closure: AnyRef): List[Class[_]] = {
- ClosureCleaner invokePrivate _getOuterClasses(closure)
- }
-
- private def getOuterObjects(closure: AnyRef): List[AnyRef] = {
- ClosureCleaner invokePrivate _getOuterObjects(closure)
+ private def getOuterClassesAndObjects(closure: AnyRef): (List[Class[_]], List[AnyRef]) = {
+ ClosureCleaner invokePrivate _getOuterClassesAndObjects(closure)
}
test("get inner closure classes") {
@@ -171,14 +167,11 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri
val closure2 = () => localValue
val closure3 = () => someSerializableValue
val closure4 = () => someSerializableMethod()
- val outerClasses1 = getOuterClasses(closure1)
- val outerClasses2 = getOuterClasses(closure2)
- val outerClasses3 = getOuterClasses(closure3)
- val outerClasses4 = getOuterClasses(closure4)
- val outerObjects1 = getOuterObjects(closure1)
- val outerObjects2 = getOuterObjects(closure2)
- val outerObjects3 = getOuterObjects(closure3)
- val outerObjects4 = getOuterObjects(closure4)
+
+ val (outerClasses1, outerObjects1) = getOuterClassesAndObjects(closure1)
+ val (outerClasses2, outerObjects2) = getOuterClassesAndObjects(closure2)
+ val (outerClasses3, outerObjects3) = getOuterClassesAndObjects(closure3)
+ val (outerClasses4, outerObjects4) = getOuterClassesAndObjects(closure4)
// The classes and objects should have the same size
assert(outerClasses1.size === outerObjects1.size)
@@ -211,10 +204,8 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri
val x = 1
val closure1 = () => 1
val closure2 = () => x
- val outerClasses1 = getOuterClasses(closure1)
- val outerClasses2 = getOuterClasses(closure2)
- val outerObjects1 = getOuterObjects(closure1)
- val outerObjects2 = getOuterObjects(closure2)
+ val (outerClasses1, outerObjects1) = getOuterClassesAndObjects(closure1)
+ val (outerClasses2, outerObjects2) = getOuterClassesAndObjects(closure2)
assert(outerClasses1.size === outerObjects1.size)
assert(outerClasses2.size === outerObjects2.size)
// These inner closures only reference local variables, and so do not have $outer pointers
@@ -227,12 +218,9 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri
val closure1 = () => 1
val closure2 = () => y
val closure3 = () => localValue
- val outerClasses1 = getOuterClasses(closure1)
- val outerClasses2 = getOuterClasses(closure2)
- val outerClasses3 = getOuterClasses(closure3)
- val outerObjects1 = getOuterObjects(closure1)
- val outerObjects2 = getOuterObjects(closure2)
- val outerObjects3 = getOuterObjects(closure3)
+ val (outerClasses1, outerObjects1) = getOuterClassesAndObjects(closure1)
+ val (outerClasses2, outerObjects2) = getOuterClassesAndObjects(closure2)
+ val (outerClasses3, outerObjects3) = getOuterClassesAndObjects(closure3)
assert(outerClasses1.size === outerObjects1.size)
assert(outerClasses2.size === outerObjects2.size)
assert(outerClasses3.size === outerObjects3.size)
@@ -265,9 +253,9 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri
val closure1 = () => 1
val closure2 = () => localValue
val closure3 = () => someSerializableValue
- val outerClasses1 = getOuterClasses(closure1)
- val outerClasses2 = getOuterClasses(closure2)
- val outerClasses3 = getOuterClasses(closure3)
+ val (outerClasses1, _) = getOuterClassesAndObjects(closure1)
+ val (outerClasses2, _) = getOuterClassesAndObjects(closure2)
+ val (outerClasses3, _) = getOuterClassesAndObjects(closure3)
val fields1 = findAccessedFields(closure1, outerClasses1, findTransitively = false)
val fields2 = findAccessedFields(closure2, outerClasses2, findTransitively = false)
@@ -307,10 +295,10 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri
val closure2 = () => a
val closure3 = () => localValue
val closure4 = () => someSerializableValue
- val outerClasses1 = getOuterClasses(closure1)
- val outerClasses2 = getOuterClasses(closure2)
- val outerClasses3 = getOuterClasses(closure3)
- val outerClasses4 = getOuterClasses(closure4)
+ val (outerClasses1, _) = getOuterClassesAndObjects(closure1)
+ val (outerClasses2, _) = getOuterClassesAndObjects(closure2)
+ val (outerClasses3, _) = getOuterClassesAndObjects(closure3)
+ val (outerClasses4, _) = getOuterClassesAndObjects(closure4)
// First, find only fields accessed directly, not transitively, by these closures
val fields1 = findAccessedFields(closure1, outerClasses1, findTransitively = false)
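
The refactor above folds the old `getOuterClasses`/`getOuterObjects` pair into a single `getOuterClassesAndObjects` call that returns both sequences from one traversal, so the size-equality asserts that follow can no longer fail simply because two helpers walked the `$outer` chain differently. A rough, hypothetical sketch of such a combined traversal (illustrative only, not the actual ClosureCleaner code):

{% highlight scala %}
// Hypothetical sketch: walk a closure's $outer chain once and return the
// outer classes and the outer instances together, so the two lists are built
// in lockstep and can never differ in length.
import scala.collection.mutable.ListBuffer

def getOuterClassesAndObjects(closure: AnyRef): (List[Class[_]], List[AnyRef]) = {
  val classes = ListBuffer[Class[_]]()
  val objects = ListBuffer[AnyRef]()
  var current = closure
  var continue = true
  while (continue) {
    current.getClass.getDeclaredFields.find(_.getName == "$outer") match {
      case Some(field) =>
        field.setAccessible(true)
        val outer = field.get(current)
        if (outer == null) {
          continue = false
        } else {
          classes += outer.getClass
          objects += outer
          current = outer
        }
      case None =>
        continue = false
    }
  }
  (classes.toList, objects.toList)
}
{% endhighlight %}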
diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
index e0ef9c70a5fc3..dde95f3778434 100644
--- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
@@ -83,6 +83,9 @@ class JsonProtocolSuite extends SparkFunSuite {
val executorAdded = SparkListenerExecutorAdded(executorAddedTime, "exec1",
new ExecutorInfo("Hostee.awesome.com", 11, logUrlMap))
val executorRemoved = SparkListenerExecutorRemoved(executorRemovedTime, "exec2", "test reason")
+ val executorMetricsUpdate = SparkListenerExecutorMetricsUpdate("exec3", Seq(
+ (1L, 2, 3, makeTaskMetrics(300L, 400L, 500L, 600L, 700, 800,
+ hasHadoopInput = true, hasOutput = true))))
testEvent(stageSubmitted, stageSubmittedJsonString)
testEvent(stageCompleted, stageCompletedJsonString)
@@ -102,6 +105,7 @@ class JsonProtocolSuite extends SparkFunSuite {
testEvent(applicationEnd, applicationEndJsonString)
testEvent(executorAdded, executorAddedJsonString)
testEvent(executorRemoved, executorRemovedJsonString)
+ testEvent(executorMetricsUpdate, executorMetricsUpdateJsonString)
}
test("Dependent Classes") {
@@ -440,10 +444,20 @@ class JsonProtocolSuite extends SparkFunSuite {
case (e1: SparkListenerEnvironmentUpdate, e2: SparkListenerEnvironmentUpdate) =>
assertEquals(e1.environmentDetails, e2.environmentDetails)
case (e1: SparkListenerExecutorAdded, e2: SparkListenerExecutorAdded) =>
- assert(e1.executorId == e1.executorId)
+ assert(e1.executorId === e2.executorId)
assertEquals(e1.executorInfo, e2.executorInfo)
case (e1: SparkListenerExecutorRemoved, e2: SparkListenerExecutorRemoved) =>
- assert(e1.executorId == e1.executorId)
+ assert(e1.executorId === e2.executorId)
+ case (e1: SparkListenerExecutorMetricsUpdate, e2: SparkListenerExecutorMetricsUpdate) =>
+ assert(e1.execId === e2.execId)
+ assertSeqEquals[(Long, Int, Int, TaskMetrics)](e1.taskMetrics, e2.taskMetrics, (a, b) => {
+ val (taskId1, stageId1, stageAttemptId1, metrics1) = a
+ val (taskId2, stageId2, stageAttemptId2, metrics2) = b
+ assert(taskId1 === taskId2)
+ assert(stageId1 === stageId2)
+ assert(stageAttemptId1 === stageAttemptId2)
+ assertEquals(metrics1, metrics2)
+ })
case (e1, e2) =>
assert(e1 === e2)
case _ => fail("Events don't match in types!")
@@ -1598,4 +1612,55 @@ class JsonProtocolSuite extends SparkFunSuite {
| "Removed Reason": "test reason"
|}
"""
+
+ private val executorMetricsUpdateJsonString =
+ s"""
+ |{
+ | "Event": "SparkListenerExecutorMetricsUpdate",
+ | "Executor ID": "exec3",
+ | "Metrics Updated": [
+ | {
+ | "Task ID": 1,
+ | "Stage ID": 2,
+ | "Stage Attempt ID": 3,
+ | "Task Metrics": {
+ | "Host Name": "localhost",
+ | "Executor Deserialize Time": 300,
+ | "Executor Run Time": 400,
+ | "Result Size": 500,
+ | "JVM GC Time": 600,
+ | "Result Serialization Time": 700,
+ | "Memory Bytes Spilled": 800,
+ | "Disk Bytes Spilled": 0,
+ | "Input Metrics": {
+ | "Data Read Method": "Hadoop",
+ | "Bytes Read": 2100,
+ | "Records Read": 21
+ | },
+ | "Output Metrics": {
+ | "Data Write Method": "Hadoop",
+ | "Bytes Written": 1200,
+ | "Records Written": 12
+ | },
+ | "Updated Blocks": [
+ | {
+ | "Block ID": "rdd_0_0",
+ | "Status": {
+ | "Storage Level": {
+ | "Use Disk": true,
+ | "Use Memory": true,
+ | "Use ExternalBlockStore": false,
+ | "Deserialized": false,
+ | "Replication": 2
+ | },
+ | "Memory Size": 0,
+ | "ExternalBlockStore Size": 0,
+ | "Disk Size": 0
+ | }
+ | }
+ | ]
+ | }
+ | }]
+ |}
+ """.stripMargin
}
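
The new `executorMetricsUpdate` case follows the suite's existing pattern: each event is serialized to JSON, parsed back, and compared field by field against both the original event and the expected JSON literal above. A hedged sketch of that round trip using the `JsonProtocol` entry points this suite exercises (method names are as of this patch; treat them as assumptions on other branches):

{% highlight scala %}
// Hedged sketch of the per-event round trip the suite relies on: render the
// event to a JSON string, parse it back, and hand the result to the
// assertEquals helpers shown above.
import org.json4s.jackson.JsonMethods.{compact, parse, render}
import org.apache.spark.scheduler.SparkListenerEvent
import org.apache.spark.util.JsonProtocol

def roundTrip(event: SparkListenerEvent): SparkListenerEvent = {
  val json = compact(render(JsonProtocol.sparkEventToJson(event)))
  JsonProtocol.sparkEventFromJson(parse(json))
}
{% endhighlight %}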
diff --git a/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala b/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala
index 42125547436cb..d3d464e84ffd7 100644
--- a/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala
@@ -84,7 +84,9 @@ class MutableURLClassLoaderSuite extends SparkFunSuite {
try {
sc.makeRDD(1 to 5, 2).mapPartitions { x =>
val loader = Thread.currentThread().getContextClassLoader
+ // scalastyle:off classforname
Class.forName(className, true, loader).newInstance()
+ // scalastyle:on classforname
Seq().iterator
}.count()
}
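
The `scalastyle:off classforname` / `scalastyle:on classforname` markers here (and in the Flume sink change further down) silence a style rule that flags raw `Class.forName` calls, because a bare call picks up whichever class loader happens to be on the caller's stack. In this test the explicit context-class-loader form is exactly what is being verified, so the rule is switched off for that line only. A hypothetical wrapper of the kind the rule nudges code toward:

{% highlight scala %}
// Hypothetical helper: concentrate reflective loading in one place so the
// class loader choice is explicit, and suppress the style rule only here.
// scalastyle:off classforname
def classForName(className: String): Class[_] =
  Class.forName(className, true, Thread.currentThread().getContextClassLoader)
// scalastyle:on classforname
{% endhighlight %}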
diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
index c7638507c88c6..8f7e402d5f2a6 100644
--- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.util
import java.io.{File, ByteArrayOutputStream, ByteArrayInputStream, FileOutputStream}
+import java.lang.{Double => JDouble, Float => JFloat}
import java.net.{BindException, ServerSocket, URI}
import java.nio.{ByteBuffer, ByteOrder}
import java.text.DecimalFormatSymbols
@@ -689,4 +690,34 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging {
// scalastyle:on println
assert(buffer.toString === "t circular test circular\n")
}
+
+ test("nanSafeCompareDoubles") {
+ def shouldMatchDefaultOrder(a: Double, b: Double): Unit = {
+ assert(Utils.nanSafeCompareDoubles(a, b) === JDouble.compare(a, b))
+ assert(Utils.nanSafeCompareDoubles(b, a) === JDouble.compare(b, a))
+ }
+ shouldMatchDefaultOrder(0d, 0d)
+ shouldMatchDefaultOrder(0d, 1d)
+ shouldMatchDefaultOrder(Double.MinValue, Double.MaxValue)
+ assert(Utils.nanSafeCompareDoubles(Double.NaN, Double.NaN) === 0)
+ assert(Utils.nanSafeCompareDoubles(Double.NaN, Double.PositiveInfinity) === 1)
+ assert(Utils.nanSafeCompareDoubles(Double.NaN, Double.NegativeInfinity) === 1)
+ assert(Utils.nanSafeCompareDoubles(Double.PositiveInfinity, Double.NaN) === -1)
+ assert(Utils.nanSafeCompareDoubles(Double.NegativeInfinity, Double.NaN) === -1)
+ }
+
+ test("nanSafeCompareFloats") {
+ def shouldMatchDefaultOrder(a: Float, b: Float): Unit = {
+ assert(Utils.nanSafeCompareFloats(a, b) === JFloat.compare(a, b))
+ assert(Utils.nanSafeCompareFloats(b, a) === JFloat.compare(b, a))
+ }
+ shouldMatchDefaultOrder(0f, 0f)
+ shouldMatchDefaultOrder(1f, 1f)
+ shouldMatchDefaultOrder(Float.MinValue, Float.MaxValue)
+ assert(Utils.nanSafeCompareFloats(Float.NaN, Float.NaN) === 0)
+ assert(Utils.nanSafeCompareFloats(Float.NaN, Float.PositiveInfinity) === 1)
+ assert(Utils.nanSafeCompareFloats(Float.NaN, Float.NegativeInfinity) === 1)
+ assert(Utils.nanSafeCompareFloats(Float.PositiveInfinity, Float.NaN) === -1)
+ assert(Utils.nanSafeCompareFloats(Float.NegativeInfinity, Float.NaN) === -1)
+ }
}
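
The two new tests pin down the NaN ordering Spark wants from these utilities: NaN equals NaN and sorts above every other value, including positive infinity, while all other comparisons follow the boxed `compare`. A minimal sketch with exactly that contract (illustrative, not the `Utils` source):

{% highlight scala %}
// Minimal sketch of the contract the tests assert: NaN == NaN, NaN sorts
// above everything else (even +Infinity), otherwise java.lang.Double.compare
// ordering applies.
def nanSafeCompareDoubles(x: Double, y: Double): Int = {
  val xIsNan = java.lang.Double.isNaN(x)
  val yIsNan = java.lang.Double.isNaN(y)
  if ((xIsNan && yIsNan) || (x == y)) 0
  else if (xIsNan) 1
  else if (yIsNan) -1
  else java.lang.Double.compare(x, y)
}
{% endhighlight %}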
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
index 79eba61a87251..9c362f0de7076 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
@@ -244,7 +244,7 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext {
private def testSimpleSpilling(codec: Option[String] = None): Unit = {
val conf = createSparkConf(loadDefaults = true, codec) // Load defaults for Spark home
conf.set("spark.shuffle.memoryFraction", "0.001")
- sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
+ sc = new SparkContext("local-cluster[1,1,1024]", "test", conf)
// reduceByKey - should spill ~8 times
val rddA = sc.parallelize(0 until 100000).map(i => (i/2, i))
@@ -292,7 +292,7 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext {
test("spilling with hash collisions") {
val conf = createSparkConf(loadDefaults = true)
conf.set("spark.shuffle.memoryFraction", "0.001")
- sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
+ sc = new SparkContext("local-cluster[1,1,1024]", "test", conf)
val map = createExternalMap[String]
val collisionPairs = Seq(
@@ -341,7 +341,7 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext {
test("spilling with many hash collisions") {
val conf = createSparkConf(loadDefaults = true)
conf.set("spark.shuffle.memoryFraction", "0.0001")
- sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
+ sc = new SparkContext("local-cluster[1,1,1024]", "test", conf)
val map = new ExternalAppendOnlyMap[FixedHashObject, Int, Int](_ => 1, _ + _, _ + _)
// Insert 10 copies each of lots of objects whose hash codes are either 0 or 1. This causes
@@ -366,7 +366,7 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext {
test("spilling with hash collisions using the Int.MaxValue key") {
val conf = createSparkConf(loadDefaults = true)
conf.set("spark.shuffle.memoryFraction", "0.001")
- sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
+ sc = new SparkContext("local-cluster[1,1,1024]", "test", conf)
val map = createExternalMap[Int]
(1 to 100000).foreach { i => map.insert(i, i) }
@@ -383,7 +383,7 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext {
test("spilling with null keys and values") {
val conf = createSparkConf(loadDefaults = true)
conf.set("spark.shuffle.memoryFraction", "0.001")
- sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
+ sc = new SparkContext("local-cluster[1,1,1024]", "test", conf)
val map = createExternalMap[Int]
map.insertAll((1 to 100000).iterator.map(i => (i, i)))
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
index 9cefa612f5491..986cd8623d145 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
@@ -176,7 +176,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
def testSpillingInLocalCluster(conf: SparkConf) {
conf.set("spark.shuffle.memoryFraction", "0.001")
conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
- sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
+ sc = new SparkContext("local-cluster[1,1,1024]", "test", conf)
// reduceByKey - should spill ~8 times
val rddA = sc.parallelize(0 until 100000).map(i => (i/2, i))
@@ -254,7 +254,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
def spillingInLocalClusterWithManyReduceTasks(conf: SparkConf) {
conf.set("spark.shuffle.memoryFraction", "0.001")
conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
- sc = new SparkContext("local-cluster[2,1,512]", "test", conf)
+ sc = new SparkContext("local-cluster[2,1,1024]", "test", conf)
// reduceByKey - should spill ~4 times per executor
val rddA = sc.parallelize(0 until 100000).map(i => (i/2, i))
@@ -554,7 +554,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
test("spilling with hash collisions") {
val conf = createSparkConf(true, false)
conf.set("spark.shuffle.memoryFraction", "0.001")
- sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
+ sc = new SparkContext("local-cluster[1,1,1024]", "test", conf)
def createCombiner(i: String): ArrayBuffer[String] = ArrayBuffer[String](i)
def mergeValue(buffer: ArrayBuffer[String], i: String): ArrayBuffer[String] = buffer += i
@@ -611,7 +611,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
test("spilling with many hash collisions") {
val conf = createSparkConf(true, false)
conf.set("spark.shuffle.memoryFraction", "0.0001")
- sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
+ sc = new SparkContext("local-cluster[1,1,1024]", "test", conf)
val agg = new Aggregator[FixedHashObject, Int, Int](_ => 1, _ + _, _ + _)
val sorter = new ExternalSorter[FixedHashObject, Int, Int](Some(agg), None, None, None)
@@ -634,7 +634,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
test("spilling with hash collisions using the Int.MaxValue key") {
val conf = createSparkConf(true, false)
conf.set("spark.shuffle.memoryFraction", "0.001")
- sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
+ sc = new SparkContext("local-cluster[1,1,1024]", "test", conf)
def createCombiner(i: Int): ArrayBuffer[Int] = ArrayBuffer[Int](i)
def mergeValue(buffer: ArrayBuffer[Int], i: Int): ArrayBuffer[Int] = buffer += i
@@ -658,7 +658,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
test("spilling with null keys and values") {
val conf = createSparkConf(true, false)
conf.set("spark.shuffle.memoryFraction", "0.001")
- sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
+ sc = new SparkContext("local-cluster[1,1,1024]", "test", conf)
def createCombiner(i: String): ArrayBuffer[String] = ArrayBuffer[String](i)
def mergeValue(buffer: ArrayBuffer[String], i: String): ArrayBuffer[String] = buffer += i
@@ -695,7 +695,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
def sortWithoutBreakingSortingContracts(conf: SparkConf) {
conf.set("spark.shuffle.memoryFraction", "0.01")
conf.set("spark.shuffle.manager", "sort")
- sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
+ sc = new SparkContext("local-cluster[1,1,1024]", "test", conf)
// Using wrongOrdering to show integer overflow introduced exception.
val rand = new Random(100L)
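
Every `local-cluster` master URL in these two spilling suites moves from 512 to 1024 MB; the bracketed triple encodes `[numWorkers, coresPerWorker, memoryPerWorkerMB]`, so only the per-worker memory grows, presumably to keep pace with the larger executor-memory defaults documented later in this patch. A hedged sketch of the setup pattern the affected tests share:

{% highlight scala %}
// Sketch of the test pattern used above: one worker, one core, 1024 MB per
// worker, with a tiny shuffle memory fraction to force spilling to disk.
import org.apache.spark.{SparkConf, SparkContext}

val conf = new SparkConf().set("spark.shuffle.memoryFraction", "0.001")
val sc = new SparkContext("local-cluster[1,1,1024]", "test", conf)
try {
  // 100000 values over 50000 keys; reduceByKey spills several times.
  val distinctKeys = sc.parallelize(0 until 100000).map(i => (i / 2, i)).reduceByKey(_ + _).count()
  assert(distinctKeys == 50000)
} finally {
  sc.stop()
}
{% endhighlight %}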
diff --git a/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala
index 6d2459d48d326..3b67f6206495a 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala
@@ -17,15 +17,20 @@
package org.apache.spark.util.collection
-import java.io.{ByteArrayInputStream, ByteArrayOutputStream, InputStream}
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import com.google.common.io.ByteStreams
+import org.mockito.Matchers.any
+import org.mockito.Mockito._
+import org.mockito.Mockito.RETURNS_SMART_NULLS
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
import org.scalatest.Matchers._
import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.serializer.KryoSerializer
-import org.apache.spark.storage.{FileSegment, BlockObjectWriter}
+import org.apache.spark.storage.DiskBlockObjectWriter
class PartitionedSerializedPairBufferSuite extends SparkFunSuite {
test("OrderedInputStream single record") {
@@ -79,13 +84,13 @@ class PartitionedSerializedPairBufferSuite extends SparkFunSuite {
val struct = SomeStruct("something", 5)
buffer.insert(4, 10, struct)
val it = buffer.destructiveSortedWritablePartitionedIterator(None)
- val writer = new SimpleBlockObjectWriter
+ val (writer, baos) = createMockWriter()
assert(it.hasNext)
it.nextPartition should be (4)
it.writeNext(writer)
assert(!it.hasNext)
- val stream = serializerInstance.deserializeStream(writer.getInputStream)
+ val stream = serializerInstance.deserializeStream(new ByteArrayInputStream(baos.toByteArray))
stream.readObject[AnyRef]() should be (10)
stream.readObject[AnyRef]() should be (struct)
}
@@ -101,7 +106,7 @@ class PartitionedSerializedPairBufferSuite extends SparkFunSuite {
buffer.insert(5, 3, struct3)
val it = buffer.destructiveSortedWritablePartitionedIterator(None)
- val writer = new SimpleBlockObjectWriter
+ val (writer, baos) = createMockWriter()
assert(it.hasNext)
it.nextPartition should be (4)
it.writeNext(writer)
@@ -113,7 +118,7 @@ class PartitionedSerializedPairBufferSuite extends SparkFunSuite {
it.writeNext(writer)
assert(!it.hasNext)
- val stream = serializerInstance.deserializeStream(writer.getInputStream)
+ val stream = serializerInstance.deserializeStream(new ByteArrayInputStream(baos.toByteArray))
val iter = stream.asIterator
iter.next() should be (2)
iter.next() should be (struct2)
@@ -123,26 +128,21 @@ class PartitionedSerializedPairBufferSuite extends SparkFunSuite {
iter.next() should be (struct1)
assert(!iter.hasNext)
}
-}
-
-case class SomeStruct(val str: String, val num: Int)
-
-class SimpleBlockObjectWriter extends BlockObjectWriter(null) {
- val baos = new ByteArrayOutputStream()
- override def write(bytes: Array[Byte], offs: Int, len: Int): Unit = {
- baos.write(bytes, offs, len)
+ def createMockWriter(): (DiskBlockObjectWriter, ByteArrayOutputStream) = {
+ val writer = mock(classOf[DiskBlockObjectWriter], RETURNS_SMART_NULLS)
+ val baos = new ByteArrayOutputStream()
+ when(writer.write(any(), any(), any())).thenAnswer(new Answer[Unit] {
+ override def answer(invocationOnMock: InvocationOnMock): Unit = {
+ val args = invocationOnMock.getArguments
+ val bytes = args(0).asInstanceOf[Array[Byte]]
+ val offset = args(1).asInstanceOf[Int]
+ val length = args(2).asInstanceOf[Int]
+ baos.write(bytes, offset, length)
+ }
+ })
+ (writer, baos)
}
-
- def getInputStream(): InputStream = new ByteArrayInputStream(baos.toByteArray)
-
- override def open(): BlockObjectWriter = this
- override def close(): Unit = { }
- override def isOpen: Boolean = true
- override def commitAndClose(): Unit = { }
- override def revertPartialWritesAndClose(): Unit = { }
- override def fileSegment(): FileSegment = null
- override def write(key: Any, value: Any): Unit = { }
- override def recordWritten(): Unit = { }
- override def write(b: Int): Unit = { }
}
+
+case class SomeStruct(str: String, num: Int)
diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala
index dd505dfa7d758..dc03e374b51db 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala
@@ -47,4 +47,29 @@ class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks {
forAll (regressionTests) { (s1: String, s2: String) => testPrefixComparison(s1, s2) }
forAll { (s1: String, s2: String) => testPrefixComparison(s1, s2) }
}
+
+ test("float prefix comparator handles NaN properly") {
+ val nan1: Float = java.lang.Float.intBitsToFloat(0x7f800001)
+ val nan2: Float = java.lang.Float.intBitsToFloat(0x7fffffff)
+ assert(nan1.isNaN)
+ assert(nan2.isNaN)
+ val nan1Prefix = PrefixComparators.FLOAT.computePrefix(nan1)
+ val nan2Prefix = PrefixComparators.FLOAT.computePrefix(nan2)
+ assert(nan1Prefix === nan2Prefix)
+ val floatMaxPrefix = PrefixComparators.FLOAT.computePrefix(Float.MaxValue)
+ assert(PrefixComparators.FLOAT.compare(nan1Prefix, floatMaxPrefix) === 1)
+ }
+
+ test("double prefix comparator handles NaNs properly") {
+ val nan1: Double = java.lang.Double.longBitsToDouble(0x7ff0000000000001L)
+ val nan2: Double = java.lang.Double.longBitsToDouble(0x7fffffffffffffffL)
+ assert(nan1.isNaN)
+ assert(nan2.isNaN)
+ val nan1Prefix = PrefixComparators.DOUBLE.computePrefix(nan1)
+ val nan2Prefix = PrefixComparators.DOUBLE.computePrefix(nan2)
+ assert(nan1Prefix === nan2Prefix)
+ val doubleMaxPrefix = PrefixComparators.DOUBLE.computePrefix(Double.MaxValue)
+ assert(PrefixComparators.DOUBLE.compare(nan1Prefix, doubleMaxPrefix) === 1)
+ }
+
}
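
Both new tests assert the same two properties: every NaN bit pattern collapses to a single canonical prefix, and that prefix compares greater than the prefix of the largest finite value. A deliberately naive sketch of the NaN handling alone (it ignores the sign-bit transformation a real sortable prefix needs and is not the `PrefixComparators` source):

{% highlight scala %}
// Naive sketch of the NaN property only: floatToIntBits (unlike
// floatToRawIntBits) canonicalises every NaN payload to 0x7fc00000, which is
// larger than the bit pattern of any finite positive float.
def naiveFloatPrefix(f: Float): Long = java.lang.Float.floatToIntBits(f).toLong

val nan1 = java.lang.Float.intBitsToFloat(0x7f800001)
val nan2 = java.lang.Float.intBitsToFloat(0x7fffffff)
assert(naiveFloatPrefix(nan1) == naiveFloatPrefix(nan2))
assert(naiveFloatPrefix(Float.NaN) > naiveFloatPrefix(Float.MaxValue))
{% endhighlight %}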
diff --git a/data/mllib/sample_naive_bayes_data.txt b/data/mllib/sample_naive_bayes_data.txt
index 981da382d6ac8..bd22bea3a59d6 100644
--- a/data/mllib/sample_naive_bayes_data.txt
+++ b/data/mllib/sample_naive_bayes_data.txt
@@ -1,6 +1,12 @@
0,1 0 0
0,2 0 0
+0,3 0 0
+0,4 0 0
1,0 1 0
1,0 2 0
+1,0 3 0
+1,0 4 0
2,0 0 1
2,0 0 2
+2,0 0 3
+2,0 0 4
\ No newline at end of file
diff --git a/dev/change-scala-version.sh b/dev/change-scala-version.sh
new file mode 100755
index 0000000000000..b81c00c9d6d9d
--- /dev/null
+++ b/dev/change-scala-version.sh
@@ -0,0 +1,66 @@
+#!/usr/bin/env bash
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+set -e
+
+usage() {
+ echo "Usage: $(basename $0) " 1>&2
+ exit 1
+}
+
+if [ $# -ne 1 ]; then
+ usage
+fi
+
+TO_VERSION=$1
+
+VALID_VERSIONS=( 2.10 2.11 )
+
+check_scala_version() {
+ for i in ${VALID_VERSIONS[*]}; do [ $i = "$1" ] && return 0; done
+ echo "Invalid Scala version: $1. Valid versions: ${VALID_VERSIONS[*]}" 1>&2
+ exit 1
+}
+
+check_scala_version "$TO_VERSION"
+
+if [ $TO_VERSION = "2.11" ]; then
+ FROM_VERSION="2.10"
+else
+ FROM_VERSION="2.11"
+fi
+
+sed_i() {
+ sed -e "$1" "$2" > "$2.tmp" && mv "$2.tmp" "$2"
+}
+
+export -f sed_i
+
+BASEDIR=$(dirname $0)/..
+find "$BASEDIR" -name 'pom.xml' -not -path '*target*' -print \
+ -exec bash -c "sed_i 's/\(artifactId.*\)_'$FROM_VERSION'/\1_'$TO_VERSION'/g' {}" \;
+
+# Also update <scala.binary.version> in parent POM
+# Match any scala binary version to ensure idempotency
+sed_i '1,/<scala\.binary\.version>[0-9]*\.[0-9]*</s/<scala\.binary\.version>[0-9]*\.[0-9]*</<scala.binary.version>'$TO_VERSION'</' \
+  "$BASEDIR/pom.xml"
+
+# Update source of scaladocs
+echo "$BASEDIR/docs/_plugins/copy_api_dirs.rb"
+sed_i 's/scala\-'$FROM_VERSION'/scala\-'$TO_VERSION'/' "$BASEDIR/docs/_plugins/copy_api_dirs.rb"
diff --git a/dev/change-version-to-2.10.sh b/dev/change-version-to-2.10.sh
deleted file mode 100755
index c4adb1f96b7d3..0000000000000
--- a/dev/change-version-to-2.10.sh
+++ /dev/null
@@ -1,26 +0,0 @@
-#!/usr/bin/env bash
-
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-# Note that this will not necessarily work as intended with non-GNU sed (e.g. OS X)
-BASEDIR=$(dirname $0)/..
-find $BASEDIR -name 'pom.xml' | grep -v target \
- | xargs -I {} sed -i -e 's/\(artifactId.*\)_2.11/\1_2.10/g' {}
-
-# Also update <scala.binary.version> in parent POM
-sed -i -e '0,/<scala\.binary\.version>2.11</s//<scala.binary.version>2.10</' $BASEDIR/pom.xml
diff --git a/dev/change-version-to-2.11.sh b/dev/change-version-to-2.11.sh
deleted file mode 100755
index d370019dec34d..0000000000000
--- a/dev/change-version-to-2.11.sh
+++ /dev/null
@@ -1,26 +0,0 @@
-#!/usr/bin/env bash
-
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-# Note that this will not necessarily work as intended with non-GNU sed (e.g. OS X)
-BASEDIR=$(dirname $0)/..
-find $BASEDIR -name 'pom.xml' | grep -v target \
- | xargs -I {} sed -i -e 's/\(artifactId.*\)_2.10/\1_2.11/g' {}
-
-# Also update <scala.binary.version> in parent POM
-sed -i -e '0,/<scala\.binary\.version>2.10</s//<scala.binary.version>2.11</' $BASEDIR/pom.xml
diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh
index 30190dcd41ec5..86a7a4068c40e 100755
--- a/dev/create-release/create-release.sh
+++ b/dev/create-release/create-release.sh
@@ -122,13 +122,13 @@ if [[ ! "$@" =~ --skip-publish ]]; then
-Phive-thriftserver -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \
clean install
- ./dev/change-version-to-2.11.sh
+ ./dev/change-scala-version.sh 2.11
build/mvn -DskipTests -Pyarn -Phive -Prelease\
-Dscala-2.11 -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \
clean install
- ./dev/change-version-to-2.10.sh
+ ./dev/change-scala-version.sh 2.10
pushd $SPARK_REPO
@@ -205,7 +205,7 @@ if [[ ! "$@" =~ --skip-package ]]; then
# TODO There should probably be a flag to make-distribution to allow 2.11 support
if [[ $FLAGS == *scala-2.11* ]]; then
- ./dev/change-version-to-2.11.sh
+ ./dev/change-scala-version.sh 2.11
fi
export ZINC_PORT=$ZINC_PORT
diff --git a/dev/create-release/known_translations b/dev/create-release/known_translations
index 5f2671a6e5053..e462302f28423 100644
--- a/dev/create-release/known_translations
+++ b/dev/create-release/known_translations
@@ -129,3 +129,12 @@ yongtang - Yong Tang
ypcat - Pei-Lun Lee
zhichao-li - Zhichao Li
zzcclp - Zhichao Zhang
+979969786 - Yuming Wang
+Rosstin - Rosstin Murphy
+ameyc - Amey Chaugule
+animeshbaranawal - Animesh Baranawal
+cafreeman - Chris Freeman
+lee19 - Lee
+lockwobr - Brian Lockwood
+navis - Navis Ryu
+pparkkin - Paavo Parkkinen
diff --git a/dev/lint-python b/dev/lint-python
index 0c3586462cb37..e02dff220eb87 100755
--- a/dev/lint-python
+++ b/dev/lint-python
@@ -21,12 +21,14 @@ SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )"
SPARK_ROOT_DIR="$(dirname "$SCRIPT_DIR")"
PATHS_TO_CHECK="./python/pyspark/ ./ec2/spark_ec2.py ./examples/src/main/python/ ./dev/sparktestsupport"
PATHS_TO_CHECK="$PATHS_TO_CHECK ./dev/run-tests.py ./python/run-tests.py"
-PYTHON_LINT_REPORT_PATH="$SPARK_ROOT_DIR/dev/python-lint-report.txt"
+PEP8_REPORT_PATH="$SPARK_ROOT_DIR/dev/pep8-report.txt"
+PYLINT_REPORT_PATH="$SPARK_ROOT_DIR/dev/pylint-report.txt"
+PYLINT_INSTALL_INFO="$SPARK_ROOT_DIR/dev/pylint-info.txt"
cd "$SPARK_ROOT_DIR"
# compileall: https://docs.python.org/2/library/compileall.html
-python -B -m compileall -q -l $PATHS_TO_CHECK > "$PYTHON_LINT_REPORT_PATH"
+python -B -m compileall -q -l $PATHS_TO_CHECK > "$PEP8_REPORT_PATH"
compile_status="${PIPESTATUS[0]}"
# Get pep8 at runtime so that we don't rely on it being installed on the build server.
@@ -47,11 +49,36 @@ if [ ! -e "$PEP8_SCRIPT_PATH" ]; then
fi
fi
+# Easy install pylint in /dev/pylint. To easy_install into a directory, the PYTHONPATH should
+# be set to the directory.
+# dev/pylint should be appended to the PATH variable as well.
+# Jenkins by default installs the pylint3 version, so for now this just checks the code quality
+# of python3.
+export "PYTHONPATH=$SPARK_ROOT_DIR/dev/pylint"
+export "PYLINT_HOME=$PYTHONPATH"
+export "PATH=$PYTHONPATH:$PATH"
+
+if [ ! -d "$PYLINT_HOME" ]; then
+ mkdir "$PYLINT_HOME"
+ # Redirect the annoying pylint installation output.
+ easy_install -d "$PYLINT_HOME" pylint==1.4.4 &>> "$PYLINT_INSTALL_INFO"
+ easy_install_status="$?"
+
+ if [ "$easy_install_status" -ne 0 ]; then
+ echo "Unable to install pylint locally in \"$PYTHONPATH\"."
+ cat "$PYLINT_INSTALL_INFO"
+ exit "$easy_install_status"
+ fi
+
+ rm "$PYLINT_INSTALL_INFO"
+
+fi
+
# There is no need to write this output to a file
#+ first, but we do so so that the check status can
#+ be output before the report, like with the
#+ scalastyle and RAT checks.
-python "$PEP8_SCRIPT_PATH" --ignore=E402,E731,E241,W503,E226 $PATHS_TO_CHECK >> "$PYTHON_LINT_REPORT_PATH"
+python "$PEP8_SCRIPT_PATH" --ignore=E402,E731,E241,W503,E226 $PATHS_TO_CHECK >> "$PEP8_REPORT_PATH"
pep8_status="${PIPESTATUS[0]}"
if [ "$compile_status" -eq 0 -a "$pep8_status" -eq 0 ]; then
@@ -61,13 +88,27 @@ else
fi
if [ "$lint_status" -ne 0 ]; then
- echo "Python lint checks failed."
- cat "$PYTHON_LINT_REPORT_PATH"
+ echo "PEP8 checks failed."
+ cat "$PEP8_REPORT_PATH"
+else
+ echo "PEP8 checks passed."
+fi
+
+rm "$PEP8_REPORT_PATH"
+
+for to_be_checked in "$PATHS_TO_CHECK"
+do
+ pylint --rcfile="$SPARK_ROOT_DIR/pylintrc" $to_be_checked >> "$PYLINT_REPORT_PATH"
+done
+
+if [ "${PIPESTATUS[0]}" -ne 0 ]; then
+ lint_status=1
+ echo "Pylint checks failed."
+ cat "$PYLINT_REPORT_PATH"
else
- echo "Python lint checks passed."
+ echo "Pylint checks passed."
fi
-# rm "$PEP8_SCRIPT_PATH"
-rm "$PYTHON_LINT_REPORT_PATH"
+rm "$PYLINT_REPORT_PATH"
exit "$lint_status"
diff --git a/dev/lint-r.R b/dev/lint-r.R
index dcb1a184291e1..48bd6246096ae 100644
--- a/dev/lint-r.R
+++ b/dev/lint-r.R
@@ -15,15 +15,21 @@
# limitations under the License.
#
+argv <- commandArgs(TRUE)
+SPARK_ROOT_DIR <- as.character(argv[1])
+
# Installs lintr from Github.
# NOTE: The CRAN's version is too old to adapt to our rules.
if ("lintr" %in% row.names(installed.packages()) == FALSE) {
devtools::install_github("jimhester/lintr")
}
-library(lintr)
-argv <- commandArgs(TRUE)
-SPARK_ROOT_DIR <- as.character(argv[1])
+library(lintr)
+library(methods)
+library(testthat)
+if (! library(SparkR, lib.loc = file.path(SPARK_ROOT_DIR, "R", "lib"), logical.return = TRUE)) {
+ stop("You should install SparkR in a local directory with `R/install-dev.sh`.")
+}
path.to.package <- file.path(SPARK_ROOT_DIR, "R", "pkg")
lint_package(path.to.package, cache = FALSE)
diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py
index 4a17d48d8171d..ad4b76695c9ff 100755
--- a/dev/merge_spark_pr.py
+++ b/dev/merge_spark_pr.py
@@ -130,7 +130,12 @@ def merge_pr(pr_num, target_ref, title, body, pr_repo_desc):
'--pretty=format:%an <%ae>']).split("\n")
distinct_authors = sorted(set(commit_authors),
key=lambda x: commit_authors.count(x), reverse=True)
- primary_author = distinct_authors[0]
+ primary_author = raw_input(
+ "Enter primary author in the format of \"name \" [%s]: " %
+ distinct_authors[0])
+ if primary_author == "":
+ primary_author = distinct_authors[0]
+
commits = run_cmd(['git', 'log', 'HEAD..%s' % pr_branch_name,
'--pretty=format:%h [%an] %s']).split("\n\n")
@@ -281,7 +286,7 @@ def get_version_json(version_str):
resolve = filter(lambda a: a['name'] == "Resolve Issue", asf_jira.transitions(jira_id))[0]
resolution = filter(lambda r: r.raw['name'] == "Fixed", asf_jira.resolutions())[0]
asf_jira.transition_issue(
- jira_id, resolve["id"], fixVersions = jira_fix_versions,
+ jira_id, resolve["id"], fixVersions = jira_fix_versions,
comment = comment, resolution = {'id': resolution.raw['id']})
print "Successfully resolved %s with fixVersions=%s!" % (jira_id, fix_versions)
@@ -300,7 +305,7 @@ def standardize_jira_ref(text):
"""
Standardize the [SPARK-XXXXX] [MODULE] prefix
Converts "[SPARK-XXX][mllib] Issue", "[MLLib] SPARK-XXX. Issue" or "SPARK XXX [MLLIB]: Issue" to "[SPARK-XXX] [MLLIB] Issue"
-
+
>>> standardize_jira_ref("[SPARK-5821] [SQL] ParquetRelation2 CTAS should check if delete is successful")
'[SPARK-5821] [SQL] ParquetRelation2 CTAS should check if delete is successful'
>>> standardize_jira_ref("[SPARK-4123][Project Infra][WIP]: Show new dependencies added in pull requests")
@@ -322,11 +327,11 @@ def standardize_jira_ref(text):
"""
jira_refs = []
components = []
-
+
# If the string is compliant, no need to process any further
if (re.search(r'^\[SPARK-[0-9]{3,6}\] (\[[A-Z0-9_\s,]+\] )+\S+', text)):
return text
-
+
# Extract JIRA ref(s):
pattern = re.compile(r'(SPARK[-\s]*[0-9]{3,6})+', re.IGNORECASE)
for ref in pattern.findall(text):
@@ -348,18 +353,18 @@ def standardize_jira_ref(text):
# Assemble full text (JIRA ref(s), module(s), remaining text)
clean_text = ' '.join(jira_refs).strip() + " " + ' '.join(components).strip() + " " + text.strip()
-
+
# Replace multiple spaces with a single space, e.g. if no jira refs and/or components were included
clean_text = re.sub(r'\s+', ' ', clean_text.strip())
-
+
return clean_text
def main():
global original_head
-
+
os.chdir(SPARK_HOME)
original_head = run_cmd("git rev-parse HEAD")[:8]
-
+
branches = get_json("%s/branches" % GITHUB_API_BASE)
branch_names = filter(lambda x: x.startswith("branch-"), [x['name'] for x in branches])
# Assumes branch names can be sorted lexicographically
@@ -448,5 +453,5 @@ def main():
(failure_count, test_count) = doctest.testmod()
if failure_count:
exit(-1)
-
+
main()
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 993583e2f4119..3073d489bad4a 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -338,6 +338,7 @@ def contains_file(self, filename):
python_test_goals=[
"pyspark.ml.feature",
"pyspark.ml.classification",
+ "pyspark.ml.clustering",
"pyspark.ml.recommendation",
"pyspark.ml.regression",
"pyspark.ml.tuning",
diff --git a/docker/spark-test/base/Dockerfile b/docker/spark-test/base/Dockerfile
index 5956d59130fbf..5dbdb8b22a44f 100644
--- a/docker/spark-test/base/Dockerfile
+++ b/docker/spark-test/base/Dockerfile
@@ -17,13 +17,13 @@
FROM ubuntu:precise
-RUN echo "deb http://archive.ubuntu.com/ubuntu precise main universe" > /etc/apt/sources.list
-
# Upgrade package index
-RUN apt-get update
-
# install a few other useful packages plus Open Jdk 7
-RUN apt-get install -y less openjdk-7-jre-headless net-tools vim-tiny sudo openssh-server
+# Remove unneeded /var/lib/apt/lists/* after install to reduce the
+# docker image size (by ~30MB)
+RUN apt-get update && \
+ apt-get install -y less openjdk-7-jre-headless net-tools vim-tiny sudo openssh-server && \
+ rm -rf /var/lib/apt/lists/*
ENV SCALA_VERSION 2.10.4
ENV CDH_VERSION cdh4
diff --git a/docs/building-spark.md b/docs/building-spark.md
index 2128fdffecc05..a5da3b39502e2 100644
--- a/docs/building-spark.md
+++ b/docs/building-spark.md
@@ -124,7 +124,7 @@ mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -Phive -Phive-thriftserver -Dskip
# Building for Scala 2.11
To produce a Spark package compiled with Scala 2.11, use the `-Dscala-2.11` property:
- dev/change-version-to-2.11.sh
+ dev/change-scala-version.sh 2.11
mvn -Pyarn -Phadoop-2.4 -Dscala-2.11 -DskipTests clean package
Spark does not yet support its JDBC component for Scala 2.11.
diff --git a/docs/configuration.md b/docs/configuration.md
index 443322e1eadf4..200f3cd212e46 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -31,7 +31,6 @@ which can help detect bugs that only exist when we run in a distributed context.
val conf = new SparkConf()
.setMaster("local[2]")
.setAppName("CountingSheep")
- .set("spark.executor.memory", "1g")
val sc = new SparkContext(conf)
{% endhighlight %}
@@ -84,7 +83,7 @@ Running `./bin/spark-submit --help` will show the entire list of these options.
each line consists of a key and a value separated by whitespace. For example:
spark.master spark://5.6.7.8:7077
- spark.executor.memory 512m
+ spark.executor.memory 4g
spark.eventLog.enabled true
spark.serializer org.apache.spark.serializer.KryoSerializer
@@ -150,10 +149,9 @@ of the most common options to set are:
  <td><code>spark.executor.memory</code></td>
-  <td>512m</td>
+  <td>1g</td>
  <td>
-    Amount of memory to use per executor process, in the same format as JVM memory strings
-    (e.g. <code>512m</code>, <code>2g</code>).
+    Amount of memory to use per executor process (e.g. <code>2g</code>, <code>8g</code>).
  </td>
@@ -665,7 +663,7 @@ Apart from these, the following properties are also available, and may be useful
Initial size of Kryo's serialization buffer. Note that there will be one buffer
per core on each worker. This buffer will grow up to
-  <code>spark.kryoserializer.buffer.max.mb</code> if needed.
+  <code>spark.kryoserializer.buffer.max</code> if needed.
@@ -886,11 +884,11 @@ Apart from these, the following properties are also available, and may be useful
  <td><code>spark.akka.frameSize</code></td>
-  <td>10</td>
+  <td>128</td>
  <td>
-    Maximum message size to allow in "control plane" communication (for serialized tasks and task
-    results), in MB. Increase this if your tasks need to send back large results to the driver
-    (e.g. using collect() on a large dataset).
+    Maximum message size to allow in "control plane" communication; generally only applies to map
+    output size information sent between executors and the driver. Increase this if you are running
+    jobs with many thousands of map and reduce tasks and see messages about the frame size.
  </td>
@@ -1050,15 +1048,6 @@ Apart from these, the following properties are also available, and may be useful
    infinite (all available cores) on Mesos.
  </td>
</tr>
-<tr>
-  <td><code>spark.localExecution.enabled</code></td>
-  <td>false</td>
-  <td>
-    Enables Spark to run certain jobs, such as first() or take() on the driver, without sending
-    tasks to the cluster. This can make certain jobs execute very quickly, but may require
-    shipping a whole partition of data to the driver.
-  </td>
-</tr>
<tr>
  <td><code>spark.locality.wait</code></td>
  <td>3s</td>
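
The configuration hunks above raise the example executor memory (the `spark-defaults.conf` sample moves from 512m to 4g and the documented default from 512m to 1g), bump the default Akka frame size to 128 MB, point the Kryo note at `spark.kryoserializer.buffer.max`, and drop the entry for the now-removed `spark.localExecution.enabled`. A hedged illustration of setting the surviving properties programmatically for a cluster deployment (values are examples, not recommendations):

{% highlight scala %}
// Hedged example of setting the documented properties in code; the values
// mirror the updated examples above.
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .setAppName("ConfiguredApp")
  .set("spark.executor.memory", "4g")            // docs example value
  .set("spark.akka.frameSize", "128")            // MB; new documented default
  .set("spark.kryoserializer.buffer.max", "64m") // renamed from ...buffer.max.mb
{% endhighlight %}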
diff --git a/docs/ml-guide.md b/docs/ml-guide.md
index c74cb1f1ef8ea..8c46adf256a9a 100644
--- a/docs/ml-guide.md
+++ b/docs/ml-guide.md
@@ -3,6 +3,24 @@ layout: global
title: Spark ML Programming Guide
---
+`\[
+\newcommand{\R}{\mathbb{R}}
+\newcommand{\E}{\mathbb{E}}
+\newcommand{\x}{\mathbf{x}}
+\newcommand{\y}{\mathbf{y}}
+\newcommand{\wv}{\mathbf{w}}
+\newcommand{\av}{\mathbf{\alpha}}
+\newcommand{\bv}{\mathbf{b}}
+\newcommand{\N}{\mathbb{N}}
+\newcommand{\id}{\mathbf{I}}
+\newcommand{\ind}{\mathbf{1}}
+\newcommand{\0}{\mathbf{0}}
+\newcommand{\unit}{\mathbf{e}}
+\newcommand{\one}{\mathbf{1}}
+\newcommand{\zero}{\mathbf{0}}
+\]`
+
+
Spark 1.2 introduced a new package called `spark.ml`, which aims to provide a uniform set of
high-level APIs that help users create and tune practical machine learning pipelines.
@@ -154,6 +172,19 @@ Parameters belong to specific instances of `Estimator`s and `Transformer`s.
For example, if we have two `LogisticRegression` instances `lr1` and `lr2`, then we can build a `ParamMap` with both `maxIter` parameters specified: `ParamMap(lr1.maxIter -> 10, lr2.maxIter -> 20)`.
This is useful if there are two algorithms with the `maxIter` parameter in a `Pipeline`.
+# Algorithm Guides
+
+There are now several algorithms in the Pipelines API which are not in the lower-level MLlib API, so we link to documentation for them here. These algorithms are mostly feature transformers, which fit naturally into the `Transformer` abstraction in Pipelines, and ensembles, which fit naturally into the `Estimator` abstraction in Pipelines.
+
+**Pipelines API Algorithm Guides**
+
+* [Feature Extraction, Transformation, and Selection](ml-features.html)
+* [Ensembles](ml-ensembles.html)
+
+**Algorithms in `spark.ml`**
+
+* [Linear methods with elastic net regularization](ml-linear-methods.html)
+
# Code Examples
This section gives code examples illustrating the functionality discussed above.
diff --git a/docs/ml-linear-methods.md b/docs/ml-linear-methods.md
new file mode 100644
index 0000000000000..1ac83d94c9e81
--- /dev/null
+++ b/docs/ml-linear-methods.md
@@ -0,0 +1,129 @@
+---
+layout: global
+title: Linear Methods - ML
+displayTitle: ML - Linear Methods
+---
+
+
+`\[
+\newcommand{\R}{\mathbb{R}}
+\newcommand{\E}{\mathbb{E}}
+\newcommand{\x}{\mathbf{x}}
+\newcommand{\y}{\mathbf{y}}
+\newcommand{\wv}{\mathbf{w}}
+\newcommand{\av}{\mathbf{\alpha}}
+\newcommand{\bv}{\mathbf{b}}
+\newcommand{\N}{\mathbb{N}}
+\newcommand{\id}{\mathbf{I}}
+\newcommand{\ind}{\mathbf{1}}
+\newcommand{\0}{\mathbf{0}}
+\newcommand{\unit}{\mathbf{e}}
+\newcommand{\one}{\mathbf{1}}
+\newcommand{\zero}{\mathbf{0}}
+\]`
+
+
+In MLlib, we implement popular linear methods such as logistic regression and linear least squares with L1 or L2 regularization. Refer to [the linear methods in mllib](mllib-linear-methods.html) for details. In `spark.ml`, we also include Pipelines API for [Elastic net](http://en.wikipedia.org/wiki/Elastic_net_regularization), a hybrid of L1 and L2 regularization proposed in [this paper](http://users.stat.umn.edu/~zouxx019/Papers/elasticnet.pdf). Mathematically it is defined as a linear combination of the L1-norm and the L2-norm:
+`\[
+\alpha \|\wv\|_1 + (1-\alpha) \frac{1}{2}\|\wv\|_2^2, \alpha \in [0, 1].
+\]`
+By setting $\alpha$ properly, it contains both L1 and L2 regularization as special cases. For example, if a [linear regression](https://en.wikipedia.org/wiki/Linear_regression) model is trained with the elastic net parameter $\alpha$ set to $1$, it is equivalent to a [Lasso](http://en.wikipedia.org/wiki/Least_squares#Lasso_method) model. On the other hand, if $\alpha$ is set to $0$, the trained model reduces to a [ridge regression](http://en.wikipedia.org/wiki/Tikhonov_regularization) model. We implement Pipelines API for both linear regression and logistic regression with elastic net regularization.
+
+**Examples**
+
+<div class="codetabs">
+
+<div data-lang="scala" markdown="1">
+
+{% highlight scala %}
+
+import org.apache.spark.ml.classification.LogisticRegression
+import org.apache.spark.mllib.util.MLUtils
+
+// Load training data
+val training = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF()
+
+val lr = new LogisticRegression()
+ .setMaxIter(10)
+ .setRegParam(0.3)
+ .setElasticNetParam(0.8)
+
+// Fit the model
+val lrModel = lr.fit(training)
+
+// Print the weights and intercept for logistic regression
+println(s"Weights: ${lrModel.weights} Intercept: ${lrModel.intercept}")
+
+{% endhighlight %}
+
+</div>
+
+<div data-lang="java" markdown="1">
+
+{% highlight java %}
+
+import org.apache.spark.ml.classification.LogisticRegression;
+import org.apache.spark.ml.classification.LogisticRegressionModel;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.mllib.util.MLUtils;
+import org.apache.spark.SparkConf;
+import org.apache.spark.SparkContext;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.SQLContext;
+
+public class LogisticRegressionWithElasticNetExample {
+ public static void main(String[] args) {
+ SparkConf conf = new SparkConf()
+ .setAppName("Logistic Regression with Elastic Net Example");
+
+ SparkContext sc = new SparkContext(conf);
+ SQLContext sql = new SQLContext(sc);
+ String path = "sample_libsvm_data.txt";
+
+ // Load training data
+ DataFrame training = sql.createDataFrame(MLUtils.loadLibSVMFile(sc, path).toJavaRDD(), LabeledPoint.class);
+
+ LogisticRegression lr = new LogisticRegression()
+ .setMaxIter(10)
+ .setRegParam(0.3)
+  .setElasticNetParam(0.8);
+
+ // Fit the model
+ LogisticRegressionModel lrModel = lr.fit(training);
+
+ // Print the weights and intercept for logistic regression
+ System.out.println("Weights: " + lrModel.weights() + " Intercept: " + lrModel.intercept());
+ }
+}
+{% endhighlight %}
+</div>
+
+<div data-lang="python" markdown="1">
+
+{% highlight python %}
+
+from pyspark.ml.classification import LogisticRegression
+from pyspark.mllib.regression import LabeledPoint
+from pyspark.mllib.util import MLUtils
+
+# Load training data
+training = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF()
+
+lr = LogisticRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8)
+
+# Fit the model
+lrModel = lr.fit(training)
+
+# Print the weights and intercept for logistic regression
+print("Weights: " + str(lrModel.weights))
+print("Intercept: " + str(lrModel.intercept))
+{% endhighlight %}
+
+</div>
+
+</div>
+
+### Optimization
+
+The optimization algorithm underlying the implementation is called [Orthant-Wise Limited-memory Quasi-Newton](http://research-srv.microsoft.com/en-us/um/people/jfgao/paper/icml07scalable.pdf)
+(OWL-QN). It is an extension of L-BFGS that can effectively handle L1 regularization and elastic net.
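
The page only shows elastic-net code for logistic regression even though the text covers linear regression as well; a hedged Scala sketch of the corresponding `LinearRegression` usage, reusing the illustrative parameter values from the example above:

{% highlight scala %}
// Hedged sketch: elastic-net linear regression through the Pipelines API,
// mirroring the logistic regression example earlier on this page.
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.mllib.util.MLUtils

val training = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF()

val lr = new LinearRegression()
  .setMaxIter(10)
  .setRegParam(0.3)
  .setElasticNetParam(0.8)

// Fit the model and print its parameters
val lrModel = lr.fit(training)
println(s"Weights: ${lrModel.weights} Intercept: ${lrModel.intercept}")
{% endhighlight %}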
diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md
index d72dc20a5ad6e..bb875ae2ae6cb 100644
--- a/docs/mllib-clustering.md
+++ b/docs/mllib-clustering.md
@@ -33,6 +33,7 @@ guaranteed to find a globally optimal solution, and when run multiple times on
a given dataset, the algorithm returns the best clustering result).
* *initializationSteps* determines the number of steps in the k-means\|\| algorithm.
* *epsilon* determines the distance threshold within which we consider k-means to have converged.
+* *initialModel* is an optional set of cluster centers used for initialization. If this parameter is supplied, only one run is performed.
**Examples**
@@ -471,7 +472,7 @@ to the algorithm. We then output the topics, represented as probability distribu
{% highlight scala %}
-import org.apache.spark.mllib.clustering.LDA
+import org.apache.spark.mllib.clustering.{LDA, DistributedLDAModel}
import org.apache.spark.mllib.linalg.Vectors
// Load and parse the data
@@ -491,6 +492,11 @@ for (topic <- Range(0, 3)) {
for (word <- Range(0, ldaModel.vocabSize)) { print(" " + topics(word, topic)); }
println()
}
+
+// Save and load model.
+ldaModel.save(sc, "myLDAModel")
+val sameModel = DistributedLDAModel.load(sc, "myLDAModel")
+
{% endhighlight %}
@@ -550,6 +556,9 @@ public class JavaLDAExample {
}
System.out.println();
}
+
+ ldaModel.save(sc.sc(), "myLDAModel");
+ DistributedLDAModel sameModel = DistributedLDAModel.load(sc.sc(), "myLDAModel");
}
}
{% endhighlight %}
diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md
index 3927d65fbf8fb..07655baa414b5 100644
--- a/docs/mllib-linear-methods.md
+++ b/docs/mllib-linear-methods.md
@@ -10,7 +10,7 @@ displayTitle: MLlib - Linear Methods
`\[
\newcommand{\R}{\mathbb{R}}
-\newcommand{\E}{\mathbb{E}}
+\newcommand{\E}{\mathbb{E}}
\newcommand{\x}{\mathbf{x}}
\newcommand{\y}{\mathbf{y}}
\newcommand{\wv}{\mathbf{w}}
@@ -18,10 +18,10 @@ displayTitle: MLlib - Linear Methods
\newcommand{\bv}{\mathbf{b}}
\newcommand{\N}{\mathbb{N}}
\newcommand{\id}{\mathbf{I}}
-\newcommand{\ind}{\mathbf{1}}
-\newcommand{\0}{\mathbf{0}}
-\newcommand{\unit}{\mathbf{e}}
-\newcommand{\one}{\mathbf{1}}
+\newcommand{\ind}{\mathbf{1}}
+\newcommand{\0}{\mathbf{0}}
+\newcommand{\unit}{\mathbf{e}}
+\newcommand{\one}{\mathbf{1}}
\newcommand{\zero}{\mathbf{0}}
\]`
@@ -29,7 +29,7 @@ displayTitle: MLlib - Linear Methods
Many standard *machine learning* methods can be formulated as a convex optimization problem, i.e.
the task of finding a minimizer of a convex function `$f$` that depends on a variable vector
-`$\wv$` (called `weights` in the code), which has `$d$` entries.
+`$\wv$` (called `weights` in the code), which has `$d$` entries.
Formally, we can write this as the optimization problem `$\min_{\wv \in\R^d} \; f(\wv)$`, where
the objective function is of the form
`\begin{equation}
@@ -39,7 +39,7 @@ the objective function is of the form
\ .
\end{equation}`
Here the vectors `$\x_i\in\R^d$` are the training data examples, for `$1\le i\le n$`, and
-`$y_i\in\R$` are their corresponding labels, which we want to predict.
+`$y_i\in\R$` are their corresponding labels, which we want to predict.
We call the method *linear* if $L(\wv; \x, y)$ can be expressed as a function of $\wv^T x$ and $y$.
Several of MLlib's classification and regression algorithms fall into this category,
and are discussed here.
@@ -99,6 +99,9 @@ regularizers in MLlib:
@@ -107,7 +110,7 @@ of `$\wv$`.
L2-regularized problems are generally easier to solve than L1-regularized due to smoothness.
However, L1 regularization can help promote sparsity in weights leading to smaller and more interpretable models, the latter of which can be useful for feature selection.
-It is not recommended to train models without any regularization,
+[Elastic net](http://en.wikipedia.org/wiki/Elastic_net_regularization) is a combination of L1 and L2 regularization. It is not recommended to train models without any regularization,
especially when the number of training examples is small.
### Optimization
@@ -531,7 +534,7 @@ sameModel = LogisticRegressionModel.load(sc, "myModelPath")
### Linear least squares, Lasso, and ridge regression
-Linear least squares is the most common formulation for regression problems.
+Linear least squares is the most common formulation for regression problems.
It is a linear method as described above in equation `$\eqref{eq:regPrimal}$`, with the loss
function in the formulation given by the squared loss:
`\[
@@ -539,8 +542,8 @@ L(\wv;\x,y) := \frac{1}{2} (\wv^T \x - y)^2.
\]`
Various related regression methods are derived by using different types of regularization:
-[*ordinary least squares*](http://en.wikipedia.org/wiki/Ordinary_least_squares) or
-[*linear least squares*](http://en.wikipedia.org/wiki/Linear_least_squares_(mathematics)) uses
+[*ordinary least squares*](http://en.wikipedia.org/wiki/Ordinary_least_squares) or
+[*linear least squares*](http://en.wikipedia.org/wiki/Linear_least_squares_(mathematics)) uses
no regularization; [*ridge regression*](http://en.wikipedia.org/wiki/Ridge_regression) uses L2
regularization; and [*Lasso*](http://en.wikipedia.org/wiki/Lasso_(statistics)) uses L1
regularization. For all of these models, the average loss or training error, $\frac{1}{n} \sum_{i=1}^n (\wv^T x_i - y_i)^2$, is
@@ -552,7 +555,7 @@ known as the [mean squared error](http://en.wikipedia.org/wiki/Mean_squared_erro
The following example demonstrate how to load training data, parse it as an RDD of LabeledPoint.
-The example then uses LinearRegressionWithSGD to build a simple linear model to predict label
+The example then uses LinearRegressionWithSGD to build a simple linear model to predict label
values. We compute the mean squared error at the end to evaluate
[goodness of fit](http://en.wikipedia.org/wiki/Goodness_of_fit).
@@ -614,7 +617,7 @@ public class LinearRegression {
public static void main(String[] args) {
SparkConf conf = new SparkConf().setAppName("Linear Regression Example");
JavaSparkContext sc = new JavaSparkContext(conf);
-
+
// Load and parse the data
String path = "data/mllib/ridge-data/lpsa.data";
JavaRDD data = sc.textFile(path);
@@ -634,7 +637,7 @@ public class LinearRegression {
// Building the model
int numIterations = 100;
- final LinearRegressionModel model =
+ final LinearRegressionModel model =
LinearRegressionWithSGD.train(JavaRDD.toRDD(parsedData), numIterations);
// Evaluate model on training examples and compute training error
@@ -665,7 +668,7 @@ public class LinearRegression {
The following example demonstrate how to load training data, parse it as an RDD of LabeledPoint.
-The example then uses LinearRegressionWithSGD to build a simple linear model to predict label
+The example then uses LinearRegressionWithSGD to build a simple linear model to predict label
values. We compute the mean squared error at the end to evaluate
[goodness of fit](http://en.wikipedia.org/wiki/Goodness_of_fit).
@@ -706,8 +709,8 @@ a dependency.
###Streaming linear regression
-When data arrive in a streaming fashion, it is useful to fit regression models online,
-updating the parameters of the model as new data arrives. MLlib currently supports
+When data arrive in a streaming fashion, it is useful to fit regression models online,
+updating the parameters of the model as new data arrives. MLlib currently supports
streaming linear regression using ordinary least squares. The fitting is similar
to that performed offline, except fitting occurs on each batch of data, so that
the model continually updates to reflect the data from the stream.
@@ -722,7 +725,7 @@ online to the first stream, and make predictions on the second stream.
-First, we import the necessary classes for parsing our input data and creating the model.
+First, we import the necessary classes for parsing our input data and creating the model.
{% highlight scala %}
@@ -734,7 +737,7 @@ import org.apache.spark.mllib.regression.StreamingLinearRegressionWithSGD
Then we make input streams for training and testing data. We assume a StreamingContext `ssc`
has already been created, see [Spark Streaming Programming Guide](streaming-programming-guide.html#initializing)
-for more info. For this example, we use labeled points in training and testing streams,
+for more info. For this example, we use labeled points in training and testing streams,
but in practice you will likely want to use unlabeled vectors for test data.
{% highlight scala %}
@@ -754,7 +757,7 @@ val model = new StreamingLinearRegressionWithSGD()
{% endhighlight %}
-Now we register the streams for training and testing and start the job.
+Now we register the streams for training and testing and start the job.
Printing predictions alongside true labels lets us easily see the result.
{% highlight scala %}
@@ -764,14 +767,14 @@ model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print()
ssc.start()
ssc.awaitTermination()
-
+
{% endhighlight %}
We can now save text files with data to the training or testing folders.
-Each line should be a data point formatted as `(y,[x1,x2,x3])` where `y` is the label
-and `x1,x2,x3` are the features. Anytime a text file is placed in `/training/data/dir`
-the model will update. Anytime a text file is placed in `/testing/data/dir` you will see predictions.
-As you feed more data to the training directory, the predictions
+Each line should be a data point formatted as `(y,[x1,x2,x3])` where `y` is the label
+and `x1,x2,x3` are the features. Anytime a text file is placed in `/training/data/dir`
+the model will update. Anytime a text file is placed in `/testing/data/dir` you will see predictions.
+As you feed more data to the training directory, the predictions
will get better!
diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md
index 1f915d8ea1d73..debdd2adf22d6 100644
--- a/docs/running-on-mesos.md
+++ b/docs/running-on-mesos.md
@@ -306,6 +306,28 @@ See the [configuration page](configuration.html) for information on Spark config
the final overhead will be this value.
  </td>
</tr>
+<tr>
+  <td><code>spark.mesos.principal</code></td>
+  <td>Framework principal to authenticate to Mesos</td>
+  <td>
+    Set the principal with which Spark framework will use to authenticate with Mesos.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.mesos.secret</code></td>
+  <td>Framework secret to authenticate to Mesos</td>
+  <td>
+    Set the secret with which Spark framework will use to authenticate with Mesos.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.mesos.role</code></td>
+  <td>Role for the Spark framework</td>
+  <td>
+    Set the role of this Spark framework for Mesos. Roles are used in Mesos for reservations
+    and resource weight sharing.
+  </td>
+</tr>
<tr>
  <td><code>spark.mesos.constraints</code></td>
  <td>Attribute based constraints to be matched against when accepting resource offers.</td>
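
The three new rows document Mesos framework authentication and role support. A hedged example of wiring them up from application code (the principal and secret are placeholders for whatever the Mesos master is configured to accept):

{% highlight scala %}
// Hedged example; principal and secret values are placeholders.
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .setMaster("mesos://zk://host1:2181/mesos")
  .setAppName("AuthenticatedApp")
  .set("spark.mesos.principal", "spark-framework")
  .set("spark.mesos.secret", "framework-secret")
  .set("spark.mesos.role", "spark") // role used for reservations and weights
{% endhighlight %}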
diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py
index ae4f2ecc5bde7..7c83d68e7993e 100755
--- a/ec2/spark_ec2.py
+++ b/ec2/spark_ec2.py
@@ -793,7 +793,7 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key):
ssh_write(slave_address, opts, ['tar', 'x'], dot_ssh_tar)
modules = ['spark', 'ephemeral-hdfs', 'persistent-hdfs',
- 'mapreduce', 'spark-standalone', 'tachyon']
+ 'mapreduce', 'spark-standalone', 'tachyon', 'rstudio']
if opts.hadoop_major_version == "1":
modules = list(filter(lambda x: x != "mapreduce", modules))
diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/Logging.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/Logging.scala
index 17cbc6707b5ea..d87b86932dd41 100644
--- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/Logging.scala
+++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/Logging.scala
@@ -113,7 +113,9 @@ private[sink] object Logging {
try {
// We use reflection here to handle the case where users remove the
// slf4j-to-jul bridge order to route their logs to JUL.
+ // scalastyle:off classforname
val bridgeClass = Class.forName("org.slf4j.bridge.SLF4JBridgeHandler")
+ // scalastyle:on classforname
bridgeClass.getMethod("removeHandlersForRootLogger").invoke(null)
val installed = bridgeClass.getMethod("isInstalled").invoke(null).asInstanceOf[Boolean]
if (!installed) {
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala
index c5cd2154772ac..1a9d78c0d4f59 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala
@@ -98,8 +98,7 @@ class KafkaRDD[
val res = context.runJob(
this,
(tc: TaskContext, it: Iterator[R]) => it.take(parts(tc.partitionId)).toArray,
- parts.keys.toArray,
- allowLocal = true)
+ parts.keys.toArray)
res.foreach(buf ++= _)
buf.toArray
}
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala
new file mode 100644
index 0000000000000..f6bf552e6bb8e
--- /dev/null
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala
@@ -0,0 +1,197 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.kinesis
+
+import java.nio.ByteBuffer
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Random, Success, Try}
+
+import com.amazonaws.auth.{AWSCredentials, DefaultAWSCredentialsProviderChain}
+import com.amazonaws.regions.RegionUtils
+import com.amazonaws.services.dynamodbv2.AmazonDynamoDBClient
+import com.amazonaws.services.dynamodbv2.document.DynamoDB
+import com.amazonaws.services.kinesis.AmazonKinesisClient
+import com.amazonaws.services.kinesis.model._
+
+import org.apache.spark.Logging
+
+/**
+ * Shared utility methods for performing Kinesis tests that actually transfer data
+ */
+private class KinesisTestUtils(
+ val endpointUrl: String = "https://kinesis.us-west-2.amazonaws.com",
+ _regionName: String = "") extends Logging {
+
+ val regionName = if (_regionName.length == 0) {
+ RegionUtils.getRegionByEndpoint(endpointUrl).getName()
+ } else {
+ RegionUtils.getRegion(_regionName).getName()
+ }
+
+ val streamShardCount = 2
+
+ private val createStreamTimeoutSeconds = 300
+ private val describeStreamPollTimeSeconds = 1
+
+ @volatile
+ private var streamCreated = false
+ private var _streamName: String = _
+
+ private lazy val kinesisClient = {
+ val client = new AmazonKinesisClient(KinesisTestUtils.getAWSCredentials())
+ client.setEndpoint(endpointUrl)
+ client
+ }
+
+ private lazy val dynamoDB = {
+ val dynamoDBClient = new AmazonDynamoDBClient(new DefaultAWSCredentialsProviderChain())
+ dynamoDBClient.setRegion(RegionUtils.getRegion(regionName))
+ new DynamoDB(dynamoDBClient)
+ }
+
+ def streamName: String = {
+ require(streamCreated, "Stream not yet created, call createStream() to create one")
+ _streamName
+ }
+
+ def createStream(): Unit = {
+ logInfo("Creating stream")
+ require(!streamCreated, "Stream already created")
+ _streamName = findNonExistentStreamName()
+
+ // Create a stream. The number of shards determines the provisioned throughput.
+ val createStreamRequest = new CreateStreamRequest()
+ createStreamRequest.setStreamName(_streamName)
+    createStreamRequest.setShardCount(streamShardCount)
+ kinesisClient.createStream(createStreamRequest)
+
+ // The stream is now being created. Wait for it to become active.
+ waitForStreamToBeActive(_streamName)
+ streamCreated = true
+ logInfo("Created stream")
+ }
+
+ /**
+ * Push data to Kinesis stream and return a map of
+ * shardId -> seq of (data, seq number) pushed to corresponding shard
+ */
+ def pushData(testData: Seq[Int]): Map[String, Seq[(Int, String)]] = {
+ require(streamCreated, "Stream not yet created, call createStream() to create one")
+ val shardIdToSeqNumbers = new mutable.HashMap[String, ArrayBuffer[(Int, String)]]()
+
+ testData.foreach { num =>
+ val str = num.toString
+ val putRecordRequest = new PutRecordRequest().withStreamName(streamName)
+ .withData(ByteBuffer.wrap(str.getBytes()))
+ .withPartitionKey(str)
+
+ val putRecordResult = kinesisClient.putRecord(putRecordRequest)
+ val shardId = putRecordResult.getShardId
+ val seqNumber = putRecordResult.getSequenceNumber()
+ val sentSeqNumbers = shardIdToSeqNumbers.getOrElseUpdate(shardId,
+ new ArrayBuffer[(Int, String)]())
+ sentSeqNumbers += ((num, seqNumber))
+ }
+
+ logInfo(s"Pushed $testData:\n\t ${shardIdToSeqNumbers.mkString("\n\t")}")
+ shardIdToSeqNumbers.toMap
+ }
+
+ def describeStream(streamNameToDescribe: String = streamName): Option[StreamDescription] = {
+ try {
+ val describeStreamRequest = new DescribeStreamRequest().withStreamName(streamNameToDescribe)
+ val desc = kinesisClient.describeStream(describeStreamRequest).getStreamDescription()
+ Some(desc)
+ } catch {
+ case rnfe: ResourceNotFoundException =>
+ None
+ }
+ }
+
+ def deleteStream(): Unit = {
+ try {
+ if (describeStream().nonEmpty) {
+        kinesisClient.deleteStream(streamName)
+ }
+ } catch {
+ case e: Exception =>
+ logWarning(s"Could not delete stream $streamName")
+ }
+ }
+
+ def deleteDynamoDBTable(tableName: String): Unit = {
+ try {
+ val table = dynamoDB.getTable(tableName)
+ table.delete()
+ table.waitForDelete()
+ } catch {
+ case e: Exception =>
+ logWarning(s"Could not delete DynamoDB table $tableName")
+ }
+ }
+
+ private def findNonExistentStreamName(): String = {
+ var testStreamName: String = null
+ do {
+ Thread.sleep(TimeUnit.SECONDS.toMillis(describeStreamPollTimeSeconds))
+ testStreamName = s"KinesisTestUtils-${math.abs(Random.nextLong())}"
+ } while (describeStream(testStreamName).nonEmpty)
+ testStreamName
+ }
+
+ private def waitForStreamToBeActive(streamNameToWaitFor: String): Unit = {
+ val startTime = System.currentTimeMillis()
+ val endTime = startTime + TimeUnit.SECONDS.toMillis(createStreamTimeoutSeconds)
+ while (System.currentTimeMillis() < endTime) {
+ Thread.sleep(TimeUnit.SECONDS.toMillis(describeStreamPollTimeSeconds))
+ describeStream(streamNameToWaitFor).foreach { description =>
+ val streamStatus = description.getStreamStatus()
+ logDebug(s"\t- current state: $streamStatus\n")
+ if ("ACTIVE".equals(streamStatus)) {
+ return
+ }
+ }
+ }
+    require(false, s"Stream $streamNameToWaitFor never became active")
+ }
+}
+
+private[kinesis] object KinesisTestUtils {
+
+ val envVarName = "RUN_KINESIS_TESTS"
+
+ val shouldRunTests = sys.env.get(envVarName) == Some("1")
+
+ def isAWSCredentialsPresent: Boolean = {
+ Try { new DefaultAWSCredentialsProviderChain().getCredentials() }.isSuccess
+ }
+
+ def getAWSCredentials(): AWSCredentials = {
+ assert(shouldRunTests,
+ "Kinesis test not enabled, should not attempt to get AWS credentials")
+ Try { new DefaultAWSCredentialsProviderChain().getCredentials() } match {
+ case Success(cred) => cred
+ case Failure(e) =>
+        throw new Exception("Kinesis tests enabled, but could not get AWS credentials")
+ }
+ }
+}
diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala
new file mode 100644
index 0000000000000..6d011f295e7f7
--- /dev/null
+++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.kinesis
+
+import org.apache.spark.SparkFunSuite
+
+/**
+ * Helper trait that runs the Kinesis real data transfer tests, or ignores them,
+ * depending on whether the corresponding environment variable is set.
+ */
+trait KinesisSuiteHelper { self: SparkFunSuite =>
+ import KinesisTestUtils._
+
+  /** Run the test if the environment variable is set, otherwise register it as ignored */
+ def testOrIgnore(testName: String)(testBody: => Unit) {
+ if (shouldRunTests) {
+ test(testName)(testBody)
+ } else {
+ ignore(s"$testName [enable by setting env var $envVarName=1]")(testBody)
+ }
+ }
+}
diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala
index 2103dca6b766f..98f2c7c4f1bfb 100644
--- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala
+++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala
@@ -73,23 +73,6 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft
checkpointStateMock, currentClockMock)
}
- test("KinesisUtils API") {
- val ssc = new StreamingContext(master, framework, batchDuration)
- // Tests the API, does not actually test data receiving
- val kinesisStream1 = KinesisUtils.createStream(ssc, "mySparkStream",
- "https://kinesis.us-west-2.amazonaws.com", Seconds(2),
- InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2)
- val kinesisStream2 = KinesisUtils.createStream(ssc, "myAppNam", "mySparkStream",
- "https://kinesis.us-west-2.amazonaws.com", "us-west-2",
- InitialPositionInStream.LATEST, Seconds(2), StorageLevel.MEMORY_AND_DISK_2)
- val kinesisStream3 = KinesisUtils.createStream(ssc, "myAppNam", "mySparkStream",
- "https://kinesis.us-west-2.amazonaws.com", "us-west-2",
- InitialPositionInStream.LATEST, Seconds(2), StorageLevel.MEMORY_AND_DISK_2,
- "awsAccessKey", "awsSecretKey")
-
- ssc.stop()
- }
-
test("check serializability of SerializableAWSCredentials") {
Utils.deserialize[SerializableAWSCredentials](
Utils.serialize(new SerializableAWSCredentials("x", "y")))
diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala
new file mode 100644
index 0000000000000..50f71413abf37
--- /dev/null
+++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.kinesis
+
+import scala.collection.mutable
+import scala.concurrent.duration._
+import scala.language.postfixOps
+import scala.util.Random
+
+import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream
+import org.scalatest.concurrent.Eventually
+import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
+
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.streaming._
+import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
+
+class KinesisStreamSuite extends SparkFunSuite with KinesisSuiteHelper
+ with Eventually with BeforeAndAfter with BeforeAndAfterAll {
+
+ // This is the name that KCL uses to save metadata to DynamoDB
+ private val kinesisAppName = s"KinesisStreamSuite-${math.abs(Random.nextLong())}"
+
+ private var ssc: StreamingContext = _
+ private var sc: SparkContext = _
+
+ override def beforeAll(): Unit = {
+ val conf = new SparkConf()
+ .setMaster("local[4]")
+ .setAppName("KinesisStreamSuite") // Setting Spark app name to Kinesis app name
+ sc = new SparkContext(conf)
+ }
+
+ override def afterAll(): Unit = {
+ sc.stop()
+    // The Kinesis stream and the DynamoDB table created by the Kinesis Client Library
+    // are cleaned up by the individual tests that create them.
+ }
+
+ after {
+ if (ssc != null) {
+ ssc.stop(stopSparkContext = false)
+ ssc = null
+ }
+ }
+
+ test("KinesisUtils API") {
+ ssc = new StreamingContext(sc, Seconds(1))
+ // Tests the API, does not actually test data receiving
+ val kinesisStream1 = KinesisUtils.createStream(ssc, "mySparkStream",
+ "https://kinesis.us-west-2.amazonaws.com", Seconds(2),
+ InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2)
+ val kinesisStream2 = KinesisUtils.createStream(ssc, "myAppNam", "mySparkStream",
+ "https://kinesis.us-west-2.amazonaws.com", "us-west-2",
+ InitialPositionInStream.LATEST, Seconds(2), StorageLevel.MEMORY_AND_DISK_2)
+ val kinesisStream3 = KinesisUtils.createStream(ssc, "myAppNam", "mySparkStream",
+ "https://kinesis.us-west-2.amazonaws.com", "us-west-2",
+ InitialPositionInStream.LATEST, Seconds(2), StorageLevel.MEMORY_AND_DISK_2,
+ "awsAccessKey", "awsSecretKey")
+ }
+
+
+ /**
+ * Test the stream by sending data to a Kinesis stream and receiving from it.
+ * This test is not run by default as it requires AWS credentials that the test
+   * environment may not have. Even if AWS credentials are available, the user
+   * may not want to run these tests in order to avoid Kinesis costs. To enable this test,
+   * you must have AWS credentials available through the default AWS provider chain,
+   * and you have to set the environment variable RUN_KINESIS_TESTS=1.
+ */
+ testOrIgnore("basic operation") {
+ val kinesisTestUtils = new KinesisTestUtils()
+ try {
+ kinesisTestUtils.createStream()
+ ssc = new StreamingContext(sc, Seconds(1))
+ val aWSCredentials = KinesisTestUtils.getAWSCredentials()
+ val stream = KinesisUtils.createStream(ssc, kinesisAppName, kinesisTestUtils.streamName,
+ kinesisTestUtils.endpointUrl, kinesisTestUtils.regionName, InitialPositionInStream.LATEST,
+ Seconds(10), StorageLevel.MEMORY_ONLY,
+ aWSCredentials.getAWSAccessKeyId, aWSCredentials.getAWSSecretKey)
+
+ val collected = new mutable.HashSet[Int] with mutable.SynchronizedSet[Int]
+ stream.map { bytes => new String(bytes).toInt }.foreachRDD { rdd =>
+ collected ++= rdd.collect()
+ logInfo("Collected = " + rdd.collect().toSeq.mkString(", "))
+ }
+ ssc.start()
+
+ val testData = 1 to 10
+ eventually(timeout(120 seconds), interval(10 second)) {
+ kinesisTestUtils.pushData(testData)
+ assert(collected === testData.toSet, "\nData received does not match data sent")
+ }
+ ssc.stop()
+ } finally {
+ kinesisTestUtils.deleteStream()
+ kinesisTestUtils.deleteDynamoDBTable(kinesisAppName)
+ }
+ }
+}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/PartitionStrategy.scala b/graphx/src/main/scala/org/apache/spark/graphx/PartitionStrategy.scala
index 7372dfbd9fe98..70a7592da8ae3 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/PartitionStrategy.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/PartitionStrategy.scala
@@ -32,7 +32,7 @@ trait PartitionStrategy extends Serializable {
object PartitionStrategy {
/**
* Assigns edges to partitions using a 2D partitioning of the sparse edge adjacency matrix,
- * guaranteeing a `2 * sqrt(numParts) - 1` bound on vertex replication.
+ * guaranteeing a `2 * sqrt(numParts)` bound on vertex replication.
*
* Suppose we have a graph with 12 vertices that we want to partition
* over 9 machines. We can use the following sparse matrix representation:
@@ -61,26 +61,36 @@ object PartitionStrategy {
* that edges adjacent to `v11` can only be in the first column of blocks `(P0, P3,
* P6)` or the last
* row of blocks `(P6, P7, P8)`. As a consequence we can guarantee that `v11` will need to be
- * replicated to at most `2 * sqrt(numParts) - 1` machines.
+ * replicated to at most `2 * sqrt(numParts)` machines.
*
* Notice that `P0` has many edges and as a consequence this partitioning would lead to poor work
* balance. To improve balance we first multiply each vertex id by a large prime to shuffle the
* vertex locations.
*
- * One of the limitations of this approach is that the number of machines must either be a
- * perfect square. We partially address this limitation by computing the machine assignment to
- * the next
- * largest perfect square and then mapping back down to the actual number of machines.
- * Unfortunately, this can also lead to work imbalance and so it is suggested that a perfect
- * square is used.
+   * When the number of partitions requested is not a perfect square, we use a slightly different
+   * method where the last column can have a different number of rows than the others, while still
+   * maintaining the same size per block.
*/
case object EdgePartition2D extends PartitionStrategy {
override def getPartition(src: VertexId, dst: VertexId, numParts: PartitionID): PartitionID = {
val ceilSqrtNumParts: PartitionID = math.ceil(math.sqrt(numParts)).toInt
val mixingPrime: VertexId = 1125899906842597L
- val col: PartitionID = (math.abs(src * mixingPrime) % ceilSqrtNumParts).toInt
- val row: PartitionID = (math.abs(dst * mixingPrime) % ceilSqrtNumParts).toInt
- (col * ceilSqrtNumParts + row) % numParts
+ if (numParts == ceilSqrtNumParts * ceilSqrtNumParts) {
+      // Use the old method for perfect squares to ensure we get the same results
+ val col: PartitionID = (math.abs(src * mixingPrime) % ceilSqrtNumParts).toInt
+ val row: PartitionID = (math.abs(dst * mixingPrime) % ceilSqrtNumParts).toInt
+ (col * ceilSqrtNumParts + row) % numParts
+
+ } else {
+ // Otherwise use new method
+ val cols = ceilSqrtNumParts
+ val rows = (numParts + cols - 1) / cols
+ val lastColRows = numParts - rows * (cols - 1)
+ val col = (math.abs(src * mixingPrime) % numParts / rows).toInt
+ val row = (math.abs(dst * mixingPrime) % (if (col < cols - 1) rows else lastColRows)).toInt
+ col * rows + row
+
+ }
}
}
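
For intuition, a standalone sketch that mirrors the arithmetic above for a non-perfect-square partition count (numParts = 6, so 3 columns of 2 rows each); the object and method names are made up for illustration:

```scala
// Reproduces the EdgePartition2D arithmetic outside Spark, for inspection only.
object EdgePartition2DSketch {
  val mixingPrime: Long = 1125899906842597L

  def partitionFor(src: Long, dst: Long, numParts: Int): Int = {
    val ceilSqrt = math.ceil(math.sqrt(numParts)).toInt
    if (numParts == ceilSqrt * ceilSqrt) {
      // Perfect-square case: classic 2D grid assignment.
      val col = (math.abs(src * mixingPrime) % ceilSqrt).toInt
      val row = (math.abs(dst * mixingPrime) % ceilSqrt).toInt
      (col * ceilSqrt + row) % numParts
    } else {
      // Non-square case: the last column may hold a different number of rows.
      val cols = ceilSqrt
      val rows = (numParts + cols - 1) / cols
      val lastColRows = numParts - rows * (cols - 1)
      val col = (math.abs(src * mixingPrime) % numParts / rows).toInt
      val row = (math.abs(dst * mixingPrime) % (if (col < cols - 1) rows else lastColRows)).toInt
      col * rows + row
    }
  }

  def main(args: Array[String]): Unit = {
    // Every edge of a toy graph lands in a partition id in [0, 6).
    for (src <- 0L until 4L; dst <- 0L until 4L) {
      println(s"edge ($src, $dst) -> partition ${partitionFor(src, dst, 6)}")
    }
  }
}
```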
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala
index 90a74d23a26cc..da95314440d86 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala
@@ -332,9 +332,9 @@ object GraphImpl {
edgeStorageLevel: StorageLevel,
vertexStorageLevel: StorageLevel): GraphImpl[VD, ED] = {
val edgeRDD = EdgeRDD.fromEdges(edges)(classTag[ED], classTag[VD])
- .withTargetStorageLevel(edgeStorageLevel).cache()
+ .withTargetStorageLevel(edgeStorageLevel)
val vertexRDD = VertexRDD(vertices, edgeRDD, defaultVertexAttr)
- .withTargetStorageLevel(vertexStorageLevel).cache()
+ .withTargetStorageLevel(vertexStorageLevel)
GraphImpl(vertexRDD, edgeRDD)
}
@@ -346,9 +346,14 @@ object GraphImpl {
def apply[VD: ClassTag, ED: ClassTag](
vertices: VertexRDD[VD],
edges: EdgeRDD[ED]): GraphImpl[VD, ED] = {
+
+ vertices.cache()
+
// Convert the vertex partitions in edges to the correct type
val newEdges = edges.asInstanceOf[EdgeRDDImpl[ED, _]]
.mapEdgePartitions((pid, part) => part.withoutVertexAttributes[VD])
+ .cache()
+
GraphImpl.fromExistingRDDs(vertices, newEdges)
}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala
index 5c07b415cd796..74a7de18d4161 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala
@@ -121,7 +121,7 @@ private[graphx] object BytecodeUtils {
override def visitMethodInsn(op: Int, owner: String, name: String, desc: String) {
if (op == INVOKEVIRTUAL || op == INVOKESPECIAL || op == INVOKESTATIC) {
if (!skipClass(owner)) {
- methodsInvoked.add((Class.forName(owner.replace("/", ".")), name))
+ methodsInvoked.add((Utils.classForName(owner.replace("/", ".")), name))
}
}
}
diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java
index d4cfeacb6ef18..c0f89c9230692 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java
@@ -25,11 +25,12 @@
import static org.apache.spark.launcher.CommandBuilderUtils.*;
-/**
+/**
* Launcher for Spark applications.
- *
+ *
* Use this class to start Spark applications programmatically. The class uses a builder pattern
* to allow clients to configure the Spark application and launch it as a child process.
+ *
*/
public class SparkLauncher {
diff --git a/launcher/src/main/java/org/apache/spark/launcher/package-info.java b/launcher/src/main/java/org/apache/spark/launcher/package-info.java
index 7ed756f4b8591..7c97dba511b28 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/package-info.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/package-info.java
@@ -17,13 +17,17 @@
/**
* Library for launching Spark applications.
- *
+ *
+ *
* This library allows applications to launch Spark programmatically. There's only one entry
* point to the library - the {@link org.apache.spark.launcher.SparkLauncher} class.
- *
+ *
+ *
+ *
* To launch a Spark application, just instantiate a {@link org.apache.spark.launcher.SparkLauncher}
* and configure the application to run. For example:
- *
+ *
+ *
*
* {@code
* import org.apache.spark.launcher.SparkLauncher;
diff --git a/make-distribution.sh b/make-distribution.sh
index 9f063da3a16c0..cac7032bb2e87 100755
--- a/make-distribution.sh
+++ b/make-distribution.sh
@@ -219,6 +219,7 @@ cp -r "$SPARK_HOME/ec2" "$DISTDIR"
if [ -d "$SPARK_HOME"/R/lib/SparkR ]; then
mkdir -p "$DISTDIR"/R/lib
cp -r "$SPARK_HOME/R/lib/SparkR" "$DISTDIR"/R/lib
+ cp "$SPARK_HOME/R/lib/sparkr.zip" "$DISTDIR"/R/lib
fi
# Download and copy in tachyon, if requested
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
index 333b42711ec52..19fe039b8fd03 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -169,10 +169,7 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType,
override def transform(dataset: DataFrame): DataFrame = {
transformSchema(dataset.schema, logging = true)
if ($(predictionCol).nonEmpty) {
- val predictUDF = udf { (features: Any) =>
- predict(features.asInstanceOf[FeaturesType])
- }
- dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
+ transformImpl(dataset)
} else {
this.logWarning(s"$uid: Predictor.transform() was called as NOOP" +
" since no output columns were set.")
@@ -180,6 +177,13 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType,
}
}
+ protected def transformImpl(dataset: DataFrame): DataFrame = {
+ val predictUDF = udf { (features: Any) =>
+ predict(features.asInstanceOf[FeaturesType])
+ }
+ dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
+ }
+
/**
* Predict label for the given features.
* This internal method is used to implement [[transform()]] and output [[predictionCol]].
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 2dc1824964a42..36fe1bd40469c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -21,10 +21,10 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeClassifierParams}
+import org.apache.spark.ml.tree.impl.RandomForest
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
import org.apache.spark.rdd.RDD
@@ -75,8 +75,9 @@ final class DecisionTreeClassifier(override val uid: String)
}
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
val strategy = getOldStrategy(categoricalFeatures, numClasses)
- val oldModel = OldDecisionTree.train(oldDataset, strategy)
- DecisionTreeClassificationModel.fromOld(oldModel, this, categoricalFeatures)
+ val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
+ seed = 0L, parentUID = Some(uid))
+ trees.head.asInstanceOf[DecisionTreeClassificationModel]
}
/** (private[ml]) Create a Strategy instance to use with the old API. */
@@ -112,6 +113,12 @@ final class DecisionTreeClassificationModel private[ml] (
require(rootNode != null,
"DecisionTreeClassificationModel given null rootNode, but it requires a non-null rootNode.")
+ /**
+ * Construct a decision tree classification model.
+ * @param rootNode Root node of tree, with other nodes attached.
+ */
+ def this(rootNode: Node) = this(Identifiable.randomUID("dtc"), rootNode)
+
override protected def predict(features: Vector): Double = {
rootNode.predict(features)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index 554e3b8e052b2..eb0b1a0a405fc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -34,6 +34,8 @@ import org.apache.spark.mllib.tree.loss.{LogLoss => OldLogLoss, Loss => OldLoss}
import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.DoubleType
/**
* :: Experimental ::
@@ -177,8 +179,15 @@ final class GBTClassificationModel(
override def treeWeights: Array[Double] = _treeWeights
+ override protected def transformImpl(dataset: DataFrame): DataFrame = {
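+    // Broadcasting the model ships it to each executor once, instead of serializing
+    // all trees into every task closure created by the prediction UDF.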
+ val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
+ val predictUDF = udf { (features: Any) =>
+ bcastModel.value.predict(features.asInstanceOf[Vector])
+ }
+ dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
+ }
+
override protected def predict(features: Vector): Double = {
- // TODO: Override transform() to broadcast model: SPARK-7127
// TODO: When we add a generic Boosting class, handle transform there? SPARK-7129
// Classifies by thresholding sum of weighted tree predictions
val treePredictions = _trees.map(_.rootNode.predict(features))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
new file mode 100644
index 0000000000000..1f547e4a98af7
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -0,0 +1,178 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import org.apache.spark.SparkException
+import org.apache.spark.ml.{PredictorParams, PredictionModel, Predictor}
+import org.apache.spark.ml.param.{ParamMap, ParamValidators, Param, DoubleParam}
+import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.mllib.classification.{NaiveBayes => OldNaiveBayes}
+import org.apache.spark.mllib.classification.{NaiveBayesModel => OldNaiveBayesModel}
+import org.apache.spark.mllib.linalg._
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
+
+/**
+ * Params for Naive Bayes Classifiers.
+ */
+private[ml] trait NaiveBayesParams extends PredictorParams {
+
+ /**
+ * The smoothing parameter.
+ * (default = 1.0).
+ * @group param
+ */
+ final val lambda: DoubleParam = new DoubleParam(this, "lambda", "The smoothing parameter.",
+ ParamValidators.gtEq(0))
+
+ /** @group getParam */
+ final def getLambda: Double = $(lambda)
+
+ /**
+ * The model type which is a string (case-sensitive).
+ * Supported options: "multinomial" and "bernoulli".
+ * (default = multinomial)
+ * @group param
+ */
+ final val modelType: Param[String] = new Param[String](this, "modelType", "The model type " +
+ "which is a string (case-sensitive). Supported options: multinomial (default) and bernoulli.",
+ ParamValidators.inArray[String](OldNaiveBayes.supportedModelTypes.toArray))
+
+ /** @group getParam */
+ final def getModelType: String = $(modelType)
+}
+
+/**
+ * Naive Bayes classifier.
+ * It supports Multinomial NB
+ * ([[http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html]]),
+ * which can handle finitely supported discrete data. For example, by converting documents into
+ * TF-IDF vectors, it can be used for document classification. By making every vector a
+ * binary (0/1) vector, it can also be used as Bernoulli NB
+ * ([[http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html]]).
+ * The input feature values must be nonnegative.
+ */
+class NaiveBayes(override val uid: String)
+ extends Predictor[Vector, NaiveBayes, NaiveBayesModel]
+ with NaiveBayesParams {
+
+ def this() = this(Identifiable.randomUID("nb"))
+
+ /**
+ * Set the smoothing parameter.
+ * Default is 1.0.
+ * @group setParam
+ */
+ def setLambda(value: Double): this.type = set(lambda, value)
+ setDefault(lambda -> 1.0)
+
+ /**
+ * Set the model type using a string (case-sensitive).
+ * Supported options: "multinomial" and "bernoulli".
+ * Default is "multinomial"
+ */
+ def setModelType(value: String): this.type = set(modelType, value)
+ setDefault(modelType -> OldNaiveBayes.Multinomial)
+
+ override protected def train(dataset: DataFrame): NaiveBayesModel = {
+ val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
+ val oldModel = OldNaiveBayes.train(oldDataset, $(lambda), $(modelType))
+ NaiveBayesModel.fromOld(oldModel, this)
+ }
+
+ override def copy(extra: ParamMap): NaiveBayes = defaultCopy(extra)
+}
+
+/**
+ * Model produced by [[NaiveBayes]]
+ */
+class NaiveBayesModel private[ml] (
+ override val uid: String,
+ val pi: Vector,
+ val theta: Matrix)
+ extends PredictionModel[Vector, NaiveBayesModel] with NaiveBayesParams {
+
+ import OldNaiveBayes.{Bernoulli, Multinomial}
+
+ /**
+ * Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0.
+ * This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra
+ * application of this condition (in predict function).
+ */
+ private lazy val (thetaMinusNegTheta, negThetaSum) = $(modelType) match {
+ case Multinomial => (None, None)
+ case Bernoulli =>
+ val negTheta = theta.map(value => math.log(1.0 - math.exp(value)))
+ val ones = new DenseVector(Array.fill(theta.numCols){1.0})
+ val thetaMinusNegTheta = theta.map { value =>
+ value - math.log(1.0 - math.exp(value))
+ }
+ (Option(thetaMinusNegTheta), Option(negTheta.multiply(ones)))
+ case _ =>
+ // This should never happen.
+ throw new UnknownError(s"Invalid modelType: ${$(modelType)}.")
+ }
+
+ override protected def predict(features: Vector): Double = {
+ $(modelType) match {
+ case Multinomial =>
+ val prob = theta.multiply(features)
+ BLAS.axpy(1.0, pi, prob)
+ prob.argmax
+ case Bernoulli =>
+ features.foreachActive{ (index, value) =>
+ if (value != 0.0 && value != 1.0) {
+ throw new SparkException(
+ s"Bernoulli naive Bayes requires 0 or 1 feature values but found $features")
+ }
+ }
+ val prob = thetaMinusNegTheta.get.multiply(features)
+ BLAS.axpy(1.0, pi, prob)
+ BLAS.axpy(1.0, negThetaSum.get, prob)
+ prob.argmax
+ case _ =>
+ // This should never happen.
+ throw new UnknownError(s"Invalid modelType: ${$(modelType)}.")
+ }
+ }
+
+ override def copy(extra: ParamMap): NaiveBayesModel = {
+ copyValues(new NaiveBayesModel(uid, pi, theta).setParent(this.parent), extra)
+ }
+
+ override def toString: String = {
+ s"NaiveBayesModel with ${pi.size} classes"
+ }
+
+}
+
+private[ml] object NaiveBayesModel {
+
+ /** Convert a model from the old API */
+ def fromOld(
+ oldModel: OldNaiveBayesModel,
+ parent: NaiveBayes): NaiveBayesModel = {
+ val uid = if (parent != null) parent.uid else Identifiable.randomUID("nb")
+ val labels = Vectors.dense(oldModel.labels)
+ val pi = Vectors.dense(oldModel.pi)
+ val theta = new DenseMatrix(oldModel.labels.length, oldModel.theta(0).length,
+ oldModel.theta.flatten, true)
+ new NaiveBayesModel(uid, pi, theta)
+ }
+}
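
A hedged usage sketch for the estimator added above; the local-mode setup, toy data, and app name are illustrative:

```scala
import org.apache.spark.ml.classification.NaiveBayes
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql.SQLContext
import org.apache.spark.{SparkConf, SparkContext}

object NaiveBayesSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("nb-sketch"))
    val sqlContext = new SQLContext(sc)

    // Tiny toy dataset with nonnegative, term-count style features.
    val training = sqlContext.createDataFrame(sc.parallelize(Seq(
      LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)),
      LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 1.0)),
      LabeledPoint(1.0, Vectors.dense(0.0, 3.0, 1.0)),
      LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 2.0)))))

    val model = new NaiveBayes()
      .setLambda(1.0)                // additive smoothing
      .setModelType("multinomial")
      .fit(training)                 // uses the default "features"/"label" columns

    model.transform(training).select("label", "prediction").show()
    sc.stop()
  }
}
```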
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index d3c67494a31e4..fc0693f67cc2e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -20,17 +20,19 @@ package org.apache.spark.ml.classification
import scala.collection.mutable
import org.apache.spark.annotation.Experimental
+import org.apache.spark.ml.tree.impl.RandomForest
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeClassifierParams, TreeEnsembleModel}
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.{RandomForest => OldRandomForest}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.DoubleType
/**
* :: Experimental ::
@@ -93,9 +95,10 @@ final class RandomForestClassifier(override val uid: String)
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
val strategy =
super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity)
- val oldModel = OldRandomForest.trainClassifier(
- oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed.toInt)
- RandomForestClassificationModel.fromOld(oldModel, this, categoricalFeatures)
+ val trees =
+ RandomForest.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed)
+ .map(_.asInstanceOf[DecisionTreeClassificationModel])
+ new RandomForestClassificationModel(trees)
}
override def copy(extra: ParamMap): RandomForestClassifier = defaultCopy(extra)
@@ -128,6 +131,13 @@ final class RandomForestClassificationModel private[ml] (
require(numTrees > 0, "RandomForestClassificationModel requires at least 1 tree.")
+ /**
+ * Construct a random forest classification model, with all trees weighted equally.
+ * @param trees Component trees
+ */
+ def this(trees: Array[DecisionTreeClassificationModel]) =
+ this(Identifiable.randomUID("rfc"), trees)
+
override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
// Note: We may add support for weights (based on tree performance) later on.
@@ -135,8 +145,15 @@ final class RandomForestClassificationModel private[ml] (
override def treeWeights: Array[Double] = _treeWeights
+ override protected def transformImpl(dataset: DataFrame): DataFrame = {
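+    // Broadcasting the model ships it to each executor once, instead of serializing
+    // all trees into every task closure created by the prediction UDF.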
+ val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
+ val predictUDF = udf { (features: Any) =>
+ bcastModel.value.predict(features.asInstanceOf[Vector])
+ }
+ dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
+ }
+
override protected def predict(features: Vector): Double = {
- // TODO: Override transform() to broadcast model. SPARK-7127
// TODO: When we add a generic Bagging class, handle transform there: SPARK-7128
// Classifies using majority votes.
// Ignore the weights since all are 1.0 for now.
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
new file mode 100644
index 0000000000000..dc192add6ca13
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -0,0 +1,205 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.clustering
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.ml.param.{Param, Params, IntParam, DoubleParam, ParamMap}
+import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasMaxIter, HasPredictionCol, HasSeed}
+import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel}
+import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
+import org.apache.spark.sql.functions.{col, udf}
+import org.apache.spark.sql.types.{IntegerType, StructType}
+import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.util.Utils
+
+
+/**
+ * Common params for KMeans and KMeansModel
+ */
+private[clustering] trait KMeansParams
+ extends Params with HasMaxIter with HasFeaturesCol with HasSeed with HasPredictionCol {
+
+ /**
+ * Set the number of clusters to create (k). Must be > 1. Default: 2.
+ * @group param
+ */
+ final val k = new IntParam(this, "k", "number of clusters to create", (x: Int) => x > 1)
+
+ /** @group getParam */
+ def getK: Int = $(k)
+
+ /**
+   * Param for the number of runs of the algorithm to execute in parallel. We initialize the algorithm
+ * this many times with random starting conditions (configured by the initialization mode), then
+ * return the best clustering found over any run. Must be >= 1. Default: 1.
+ * @group param
+ */
+ final val runs = new IntParam(this, "runs",
+ "number of runs of the algorithm to execute in parallel", (value: Int) => value >= 1)
+
+ /** @group getParam */
+ def getRuns: Int = $(runs)
+
+ /**
+   * Param for the distance threshold within which we consider centers to have converged.
+ * If all centers move less than this Euclidean distance, we stop iterating one run.
+ * Must be >= 0.0. Default: 1e-4
+ * @group param
+ */
+ final val epsilon = new DoubleParam(this, "epsilon",
+ "distance threshold within which we've consider centers to have converge",
+ (value: Double) => value >= 0.0)
+
+ /** @group getParam */
+ def getEpsilon: Double = $(epsilon)
+
+ /**
+ * Param for the initialization algorithm. This can be either "random" to choose random points as
+ * initial cluster centers, or "k-means||" to use a parallel variant of k-means++
+ * (Bahmani et al., Scalable K-Means++, VLDB 2012). Default: k-means||.
+ * @group expertParam
+ */
+ final val initMode = new Param[String](this, "initMode", "initialization algorithm",
+ (value: String) => MLlibKMeans.validateInitMode(value))
+
+ /** @group expertGetParam */
+ def getInitMode: String = $(initMode)
+
+ /**
+ * Param for the number of steps for the k-means|| initialization mode. This is an advanced
+ * setting -- the default of 5 is almost always enough. Must be > 0. Default: 5.
+ * @group expertParam
+ */
+ final val initSteps = new IntParam(this, "initSteps", "number of steps for k-means||",
+ (value: Int) => value > 0)
+
+ /** @group expertGetParam */
+ def getInitSteps: Int = $(initSteps)
+
+ /**
+ * Validates and transforms the input schema.
+ * @param schema input schema
+ * @return output schema
+ */
+ protected def validateAndTransformSchema(schema: StructType): StructType = {
+ SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
+ SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType)
+ }
+}
+
+/**
+ * :: Experimental ::
+ * Model fitted by KMeans.
+ *
+ * @param parentModel a model trained by spark.mllib.clustering.KMeans.
+ */
+@Experimental
+class KMeansModel private[ml] (
+ override val uid: String,
+ private val parentModel: MLlibKMeansModel) extends Model[KMeansModel] with KMeansParams {
+
+ override def copy(extra: ParamMap): KMeansModel = {
+ val copied = new KMeansModel(uid, parentModel)
+ copyValues(copied, extra)
+ }
+
+ override def transform(dataset: DataFrame): DataFrame = {
+ val predictUDF = udf((vector: Vector) => predict(vector))
+ dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
+ }
+
+ override def transformSchema(schema: StructType): StructType = {
+ validateAndTransformSchema(schema)
+ }
+
+ private[clustering] def predict(features: Vector): Int = parentModel.predict(features)
+
+ def clusterCenters: Array[Vector] = parentModel.clusterCenters
+}
+
+/**
+ * :: Experimental ::
+ * K-means clustering with support for multiple parallel runs and a k-means++ like initialization
+ * mode (the k-means|| algorithm by Bahmani et al). When multiple concurrent runs are requested,
+ * they are executed together with joint passes over the data for efficiency.
+ */
+@Experimental
+class KMeans(override val uid: String) extends Estimator[KMeansModel] with KMeansParams {
+
+ setDefault(
+ k -> 2,
+ maxIter -> 20,
+ runs -> 1,
+ initMode -> MLlibKMeans.K_MEANS_PARALLEL,
+ initSteps -> 5,
+ epsilon -> 1e-4)
+
+ override def copy(extra: ParamMap): KMeans = defaultCopy(extra)
+
+ def this() = this(Identifiable.randomUID("kmeans"))
+
+ /** @group setParam */
+ def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+
+ /** @group setParam */
+ def setPredictionCol(value: String): this.type = set(predictionCol, value)
+
+ /** @group setParam */
+ def setK(value: Int): this.type = set(k, value)
+
+ /** @group expertSetParam */
+ def setInitMode(value: String): this.type = set(initMode, value)
+
+ /** @group expertSetParam */
+ def setInitSteps(value: Int): this.type = set(initSteps, value)
+
+ /** @group setParam */
+ def setMaxIter(value: Int): this.type = set(maxIter, value)
+
+ /** @group setParam */
+ def setRuns(value: Int): this.type = set(runs, value)
+
+ /** @group setParam */
+ def setEpsilon(value: Double): this.type = set(epsilon, value)
+
+ /** @group setParam */
+ def setSeed(value: Long): this.type = set(seed, value)
+
+ override def fit(dataset: DataFrame): KMeansModel = {
+ val rdd = dataset.select(col($(featuresCol))).map { case Row(point: Vector) => point }
+
+ val algo = new MLlibKMeans()
+ .setK($(k))
+ .setInitializationMode($(initMode))
+ .setInitializationSteps($(initSteps))
+ .setMaxIterations($(maxIter))
+ .setSeed($(seed))
+ .setEpsilon($(epsilon))
+ .setRuns($(runs))
+ val parentModel = algo.run(rdd)
+ val model = new KMeansModel(uid, parentModel)
+ copyValues(model)
+ }
+
+ override def transformSchema(schema: StructType): StructType = {
+ validateAndTransformSchema(schema)
+ }
+}
+
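
A hedged usage sketch for the spark.ml KMeans estimator added above; the toy points, app name, and local-mode setup are illustrative:

```scala
import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.sql.SQLContext
import org.apache.spark.{SparkConf, SparkContext}

object KMeansSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("kmeans-sketch"))
    val sqlContext = new SQLContext(sc)

    // Two obvious clusters around (0,0) and (9,9); the column must match featuresCol ("features").
    val points: Seq[Tuple1[Vector]] = Seq(
      Vectors.dense(0.0, 0.0), Vectors.dense(0.1, 0.1),
      Vectors.dense(9.0, 9.0), Vectors.dense(9.1, 9.1)).map(Tuple1.apply)
    val dataset = sqlContext.createDataFrame(sc.parallelize(points)).toDF("features")

    val model = new KMeans().setK(2).setMaxIter(10).setSeed(1L).fit(dataset)
    model.clusterCenters.foreach(println)   // the two learned centers
    model.transform(dataset).show()         // adds the integer "prediction" column
    sc.stop()
  }
}
```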
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
new file mode 100644
index 0000000000000..f7b46efa10e90
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import scala.util.parsing.combinator.RegexParsers
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.ml.Transformer
+import org.apache.spark.ml.param.{Param, ParamMap}
+import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol}
+import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
+
+/**
+ * :: Experimental ::
+ * Implements the transforms required for fitting a dataset against an R model formula. Currently
+ * we support a limited subset of the R operators, including '~' and '+'. Also see the R formula
+ * docs here: http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html
+ */
+@Experimental
+class RFormula(override val uid: String)
+ extends Transformer with HasFeaturesCol with HasLabelCol {
+
+ def this() = this(Identifiable.randomUID("rFormula"))
+
+ /**
+ * R formula parameter. The formula is provided in string form.
+ * @group param
+ */
+ val formula: Param[String] = new Param(this, "formula", "R model formula")
+
+ private var parsedFormula: Option[ParsedRFormula] = None
+
+ /**
+ * Sets the formula to use for this transformer. Must be called before use.
+ * @group setParam
+ * @param value an R formula in string form (e.g. "y ~ x + z")
+ */
+ def setFormula(value: String): this.type = {
+ parsedFormula = Some(RFormulaParser.parse(value))
+ set(formula, value)
+ this
+ }
+
+ /** @group getParam */
+ def getFormula: String = $(formula)
+
+  /** @group setParam */
+  def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+
+  /** @group setParam */
+ def setLabelCol(value: String): this.type = set(labelCol, value)
+
+ override def transformSchema(schema: StructType): StructType = {
+ checkCanTransform(schema)
+ val withFeatures = transformFeatures.transformSchema(schema)
+ if (hasLabelCol(schema)) {
+ withFeatures
+ } else if (schema.exists(_.name == parsedFormula.get.label)) {
+ val nullable = schema(parsedFormula.get.label).dataType match {
+ case _: NumericType | BooleanType => false
+ case _ => true
+ }
+ StructType(withFeatures.fields :+ StructField($(labelCol), DoubleType, nullable))
+ } else {
+ // Ignore the label field. This is a hack so that this transformer can also work on test
+ // datasets in a Pipeline.
+ withFeatures
+ }
+ }
+
+ override def transform(dataset: DataFrame): DataFrame = {
+ checkCanTransform(dataset.schema)
+ transformLabel(transformFeatures.transform(dataset))
+ }
+
+ override def copy(extra: ParamMap): RFormula = defaultCopy(extra)
+
+ override def toString: String = s"RFormula(${get(formula)})"
+
+ private def transformLabel(dataset: DataFrame): DataFrame = {
+ val labelName = parsedFormula.get.label
+ if (hasLabelCol(dataset.schema)) {
+ dataset
+ } else if (dataset.schema.exists(_.name == labelName)) {
+ dataset.schema(labelName).dataType match {
+ case _: NumericType | BooleanType =>
+ dataset.withColumn($(labelCol), dataset(labelName).cast(DoubleType))
+ // TODO(ekl) add support for string-type labels
+ case other =>
+ throw new IllegalArgumentException("Unsupported type for label: " + other)
+ }
+ } else {
+ // Ignore the label field. This is a hack so that this transformer can also work on test
+ // datasets in a Pipeline.
+ dataset
+ }
+ }
+
+ private def transformFeatures: Transformer = {
+ // TODO(ekl) add support for non-numeric features and feature interactions
+ new VectorAssembler(uid)
+ .setInputCols(parsedFormula.get.terms.toArray)
+ .setOutputCol($(featuresCol))
+ }
+
+ private def checkCanTransform(schema: StructType) {
+ require(parsedFormula.isDefined, "Must call setFormula() first.")
+ val columnNames = schema.map(_.name)
+ require(!columnNames.contains($(featuresCol)), "Features column already exists.")
+ require(
+ !columnNames.contains($(labelCol)) || schema($(labelCol)).dataType == DoubleType,
+ "Label column already exists and is not of type DoubleType.")
+ }
+
+ private def hasLabelCol(schema: StructType): Boolean = {
+ schema.map(_.name).contains($(labelCol))
+ }
+}
+
+/**
+ * Represents a parsed R formula.
+ */
+private[ml] case class ParsedRFormula(label: String, terms: Seq[String])
+
+/**
+ * Limited implementation of R formula parsing. Currently supports: '~', '+'.
+ */
+private[ml] object RFormulaParser extends RegexParsers {
+ def term: Parser[String] = "([a-zA-Z]|\\.[a-zA-Z_])[a-zA-Z0-9._]*".r
+
+ def expr: Parser[List[String]] = term ~ rep("+" ~> term) ^^ { case a ~ list => a :: list }
+
+ def formula: Parser[ParsedRFormula] =
+ (term ~ "~" ~ expr) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t) }
+
+ def parse(value: String): ParsedRFormula = parseAll(formula, value) match {
+ case Success(result, _) => result
+ case failure: NoSuccess => throw new IllegalArgumentException(
+ "Could not parse formula: " + value)
+ }
+}
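
A minimal usage sketch for the transformer added above; the data frame and column names are illustrative:

```scala
import org.apache.spark.ml.feature.RFormula
import org.apache.spark.sql.SQLContext
import org.apache.spark.{SparkConf, SparkContext}

object RFormulaSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("rformula-sketch"))
    val sqlContext = new SQLContext(sc)

    val df = sqlContext.createDataFrame(Seq(
      (1.0, 2.0, 3.0),
      (0.0, 4.0, 5.0))).toDF("y", "x", "z")

    // "y ~ x + z": assemble x and z into the features vector and copy y into the label column.
    val output = new RFormula().setFormula("y ~ x + z").transform(df)
    output.select("features", "label").show()
    sc.stop()
  }
}
```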
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
index 5f9f57a2ebcfa..0b3af4747e693 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
@@ -42,7 +42,7 @@ class Tokenizer(override val uid: String) extends UnaryTransformer[String, Seq[S
require(inputType == StringType, s"Input type must be string type but got $inputType.")
}
- override protected def outputDataType: DataType = new ArrayType(StringType, false)
+ override protected def outputDataType: DataType = new ArrayType(StringType, true)
override def copy(extra: ParamMap): Tokenizer = defaultCopy(extra)
}
@@ -113,7 +113,7 @@ class RegexTokenizer(override val uid: String)
require(inputType == StringType, s"Input type must be string type but got $inputType.")
}
- override protected def outputDataType: DataType = new ArrayType(StringType, false)
+ override protected def outputDataType: DataType = new ArrayType(StringType, true)
override def copy(extra: ParamMap): RegexTokenizer = defaultCopy(extra)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
index 9f83c2ee16178..086917fa680f8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
@@ -116,7 +116,7 @@ class VectorAssembler(override val uid: String)
if (schema.fieldNames.contains(outputColName)) {
throw new IllegalArgumentException(s"Output column $outputColName already exists.")
}
- StructType(schema.fields :+ new StructField(outputColName, new VectorUDT, false))
+ StructType(schema.fields :+ new StructField(outputColName, new VectorUDT, true))
}
override def copy(extra: ParamMap): VectorAssembler = defaultCopy(extra)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index d034d7ec6b60e..954aa17e26a02 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -295,6 +295,22 @@ class DoubleArrayParam(parent: Params, name: String, doc: String, isValid: Array
w(value.asScala.map(_.asInstanceOf[Double]).toArray)
}
+/**
+ * :: DeveloperApi ::
+ * Specialized version of [[Param[Array[Int]]]] for Java.
+ */
+@DeveloperApi
+class IntArrayParam(parent: Params, name: String, doc: String, isValid: Array[Int] => Boolean)
+ extends Param[Array[Int]](parent, name, doc, isValid) {
+
+ def this(parent: Params, name: String, doc: String) =
+ this(parent, name, doc, ParamValidators.alwaysTrue)
+
+ /** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */
+ def w(value: java.util.List[java.lang.Integer]): ParamPair[Array[Int]] =
+ w(value.asScala.map(_.asInstanceOf[Int]).toArray)
+}
+
/**
* :: Experimental ::
* A param and its value.
@@ -460,11 +476,14 @@ trait Params extends Identifiable with Serializable {
/**
* Sets default values for a list of params.
*
+ * Note: Java developers should use the single-parameter [[setDefault()]].
+ * Annotating this with varargs can cause compilation failures due to a Scala compiler bug.
+ * See SPARK-9268.
+ *
* @param paramPairs a list of param pairs that specify params and their default values to set
* respectively. Make sure that the params are initialized before this method
* gets called.
*/
- @varargs
protected final def setDefault(paramPairs: ParamPair[_]*): this.type = {
paramPairs.foreach { p =>
setDefault(p.param.asInstanceOf[Param[Any]], p.value)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index 66b751a1b02ee..f7ae1de522e01 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -134,7 +134,7 @@ private[shared] object SharedParamsCodeGen {
s"""
|/**
- | * (private[ml]) Trait for shared param $name$defaultValueDoc.
+ | * Trait for shared param $name$defaultValueDoc.
| */
|private[ml] trait Has$Name extends Params {
|
@@ -173,7 +173,6 @@ private[shared] object SharedParamsCodeGen {
|package org.apache.spark.ml.param.shared
|
|import org.apache.spark.ml.param._
- |import org.apache.spark.util.Utils
|
|// DO NOT MODIFY THIS FILE! It was generated by SharedParamsCodeGen.
|
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index f81bd76c22376..65e48e4ee5083 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -18,14 +18,13 @@
package org.apache.spark.ml.param.shared
import org.apache.spark.ml.param._
-import org.apache.spark.util.Utils
// DO NOT MODIFY THIS FILE! It was generated by SharedParamsCodeGen.
// scalastyle:off
/**
- * (private[ml]) Trait for shared param regParam.
+ * Trait for shared param regParam.
*/
private[ml] trait HasRegParam extends Params {
@@ -40,7 +39,7 @@ private[ml] trait HasRegParam extends Params {
}
/**
- * (private[ml]) Trait for shared param maxIter.
+ * Trait for shared param maxIter.
*/
private[ml] trait HasMaxIter extends Params {
@@ -55,7 +54,7 @@ private[ml] trait HasMaxIter extends Params {
}
/**
- * (private[ml]) Trait for shared param featuresCol (default: "features").
+ * Trait for shared param featuresCol (default: "features").
*/
private[ml] trait HasFeaturesCol extends Params {
@@ -72,7 +71,7 @@ private[ml] trait HasFeaturesCol extends Params {
}
/**
- * (private[ml]) Trait for shared param labelCol (default: "label").
+ * Trait for shared param labelCol (default: "label").
*/
private[ml] trait HasLabelCol extends Params {
@@ -89,7 +88,7 @@ private[ml] trait HasLabelCol extends Params {
}
/**
- * (private[ml]) Trait for shared param predictionCol (default: "prediction").
+ * Trait for shared param predictionCol (default: "prediction").
*/
private[ml] trait HasPredictionCol extends Params {
@@ -106,7 +105,7 @@ private[ml] trait HasPredictionCol extends Params {
}
/**
- * (private[ml]) Trait for shared param rawPredictionCol (default: "rawPrediction").
+ * Trait for shared param rawPredictionCol (default: "rawPrediction").
*/
private[ml] trait HasRawPredictionCol extends Params {
@@ -123,7 +122,7 @@ private[ml] trait HasRawPredictionCol extends Params {
}
/**
- * (private[ml]) Trait for shared param probabilityCol (default: "probability").
+ * Trait for shared param probabilityCol (default: "probability").
*/
private[ml] trait HasProbabilityCol extends Params {
@@ -140,7 +139,7 @@ private[ml] trait HasProbabilityCol extends Params {
}
/**
- * (private[ml]) Trait for shared param threshold.
+ * Trait for shared param threshold.
*/
private[ml] trait HasThreshold extends Params {
@@ -155,7 +154,7 @@ private[ml] trait HasThreshold extends Params {
}
/**
- * (private[ml]) Trait for shared param inputCol.
+ * Trait for shared param inputCol.
*/
private[ml] trait HasInputCol extends Params {
@@ -170,7 +169,7 @@ private[ml] trait HasInputCol extends Params {
}
/**
- * (private[ml]) Trait for shared param inputCols.
+ * Trait for shared param inputCols.
*/
private[ml] trait HasInputCols extends Params {
@@ -185,7 +184,7 @@ private[ml] trait HasInputCols extends Params {
}
/**
- * (private[ml]) Trait for shared param outputCol (default: uid + "__output").
+ * Trait for shared param outputCol (default: uid + "__output").
*/
private[ml] trait HasOutputCol extends Params {
@@ -202,7 +201,7 @@ private[ml] trait HasOutputCol extends Params {
}
/**
- * (private[ml]) Trait for shared param checkpointInterval.
+ * Trait for shared param checkpointInterval.
*/
private[ml] trait HasCheckpointInterval extends Params {
@@ -217,7 +216,7 @@ private[ml] trait HasCheckpointInterval extends Params {
}
/**
- * (private[ml]) Trait for shared param fitIntercept (default: true).
+ * Trait for shared param fitIntercept (default: true).
*/
private[ml] trait HasFitIntercept extends Params {
@@ -234,7 +233,7 @@ private[ml] trait HasFitIntercept extends Params {
}
/**
- * (private[ml]) Trait for shared param standardization (default: true).
+ * Trait for shared param standardization (default: true).
*/
private[ml] trait HasStandardization extends Params {
@@ -251,7 +250,7 @@ private[ml] trait HasStandardization extends Params {
}
/**
- * (private[ml]) Trait for shared param seed (default: this.getClass.getName.hashCode.toLong).
+ * Trait for shared param seed (default: this.getClass.getName.hashCode.toLong).
*/
private[ml] trait HasSeed extends Params {
@@ -268,7 +267,7 @@ private[ml] trait HasSeed extends Params {
}
/**
- * (private[ml]) Trait for shared param elasticNetParam.
+ * Trait for shared param elasticNetParam.
*/
private[ml] trait HasElasticNetParam extends Params {
@@ -283,7 +282,7 @@ private[ml] trait HasElasticNetParam extends Params {
}
/**
- * (private[ml]) Trait for shared param tol.
+ * Trait for shared param tol.
*/
private[ml] trait HasTol extends Params {
@@ -298,7 +297,7 @@ private[ml] trait HasTol extends Params {
}
/**
- * (private[ml]) Trait for shared param stepSize.
+ * Trait for shared param stepSize.
*/
private[ml] trait HasStepSize extends Params {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
new file mode 100644
index 0000000000000..1ee080641e3e3
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.api.r
+
+import org.apache.spark.ml.feature.RFormula
+import org.apache.spark.ml.classification.LogisticRegression
+import org.apache.spark.ml.regression.LinearRegression
+import org.apache.spark.ml.{Pipeline, PipelineModel}
+import org.apache.spark.sql.DataFrame
+
+private[r] object SparkRWrappers {
+ def fitRModelFormula(
+ value: String,
+ df: DataFrame,
+ family: String,
+ lambda: Double,
+ alpha: Double): PipelineModel = {
+ val formula = new RFormula().setFormula(value)
+ val estimator = family match {
+ case "gaussian" => new LinearRegression().setRegParam(lambda).setElasticNetParam(alpha)
+ case "binomial" => new LogisticRegression().setRegParam(lambda).setElasticNetParam(alpha)
+ }
+ val pipeline = new Pipeline().setStages(Array(formula, estimator))
+ pipeline.fit(df)
+ }
+}
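As a rough usage sketch (not from this patch), the Scala call that SparkR's glm() ultimately reaches through the R backend looks like the following; the formula string and the training DataFrame are hypothetical, and the wrapper object is placed in the same package because SparkRWrappers is private[r]:

    package org.apache.spark.ml.api.r

    import org.apache.spark.ml.PipelineModel
    import org.apache.spark.sql.DataFrame

    object FitExample {
      // Fits RFormula followed by LinearRegression, i.e. the family = "gaussian" branch above.
      def fitGaussian(training: DataFrame): PipelineModel =
        SparkRWrappers.fitRModelFormula(
          value = "Sepal_Length ~ Sepal_Width + Species",
          df = training,
          family = "gaussian",
          lambda = 0.0,
          alpha = 0.0)
    }

Note that a family other than "gaussian" or "binomial" falls through the match above and fails with a MatchError.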
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index be1f8063d41d8..6f3340c2f02be 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -21,10 +21,10 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeRegressorParams}
+import org.apache.spark.ml.tree.impl.RandomForest
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
import org.apache.spark.rdd.RDD
@@ -67,8 +67,9 @@ final class DecisionTreeRegressor(override val uid: String)
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
val strategy = getOldStrategy(categoricalFeatures)
- val oldModel = OldDecisionTree.train(oldDataset, strategy)
- DecisionTreeRegressionModel.fromOld(oldModel, this, categoricalFeatures)
+ val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
+ seed = 0L, parentUID = Some(uid))
+ trees.head.asInstanceOf[DecisionTreeRegressionModel]
}
/** (private[ml]) Create a Strategy instance to use with the old API. */
@@ -102,6 +103,12 @@ final class DecisionTreeRegressionModel private[ml] (
require(rootNode != null,
"DecisionTreeClassificationModel given null rootNode, but it requires a non-null rootNode.")
+ /**
+ * Construct a decision tree regression model.
+ * @param rootNode Root node of tree, with other nodes attached.
+ */
+ def this(rootNode: Node) = this(Identifiable.randomUID("dtr"), rootNode)
+
override protected def predict(features: Vector): Double = {
rootNode.predict(features)
}
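For reference, a minimal usage sketch (not from the patch; it assumes a training DataFrame with the default "label" and "features" columns) of the public estimator whose train() now delegates to the new ml-side RandomForest.run with numTrees = 1:

    import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor}
    import org.apache.spark.sql.DataFrame

    def fitSingleTree(training: DataFrame): DecisionTreeRegressionModel = {
      // fit() calls train(), which now grows the single tree via RandomForest.run.
      new DecisionTreeRegressor()
        .setMaxDepth(5)
        .fit(training)
    }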
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index 47c110d027d67..e38dc73ee0ba7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -33,6 +33,8 @@ import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, Loss
import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.DoubleType
/**
* :: Experimental ::
@@ -167,8 +169,15 @@ final class GBTRegressionModel(
override def treeWeights: Array[Double] = _treeWeights
+ override protected def transformImpl(dataset: DataFrame): DataFrame = {
+ val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
+ val predictUDF = udf { (features: Any) =>
+ bcastModel.value.predict(features.asInstanceOf[Vector])
+ }
+ dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
+ }
+
override protected def predict(features: Vector): Double = {
- // TODO: Override transform() to broadcast model. SPARK-7127
// TODO: When we add a generic Boosting class, handle transform there? SPARK-7129
// Classifies by thresholding sum of weighted tree predictions
val treePredictions = _trees.map(_.rootNode.predict(features))
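The transformImpl above resolves the SPARK-7127 TODO by broadcasting the model once and referring to the broadcast inside a prediction UDF. A standalone sketch of the same pattern (not from the patch; MyModel, the column names, and the dataset are hypothetical stand-ins, since GBTRegressionModel.predict itself is protected):

    import org.apache.spark.mllib.linalg.Vector
    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.functions.{col, udf}

    // Hypothetical serializable model with a simple dot-product predict().
    case class MyModel(weights: Array[Double]) {
      def predict(features: Vector): Double =
        features.toArray.zip(weights).map { case (x, w) => x * w }.sum
    }

    def withPredictions(model: MyModel, dataset: DataFrame): DataFrame = {
      // Broadcast once, so executors fetch the model a single time instead of
      // having it serialized into every task closure.
      val bcastModel = dataset.sqlContext.sparkContext.broadcast(model)
      val predictUDF = udf { (features: Any) =>
        bcastModel.value.predict(features.asInstanceOf[Vector])
      }
      dataset.withColumn("prediction", predictUDF(col("features")))
    }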
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 8fc986056657d..89718e0f3e15a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -355,9 +355,9 @@ class LinearRegressionSummary private[regression] (
*/
val r2: Double = metrics.r2
- /** Residuals (predicted value - label value) */
+ /** Residuals (label - predicted value) */
@transient lazy val residuals: DataFrame = {
- val t = udf { (pred: Double, label: Double) => pred - label}
+ val t = udf { (pred: Double, label: Double) => label - pred }
predictions.select(t(col(predictionCol), col(labelCol)).as("residuals"))
}
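A quick worked example of the corrected sign convention (values hypothetical): the residual is now reported as observed minus fitted.

    val label = 3.0
    val prediction = 2.5
    val residual = label - prediction  // 0.5; previously reported as prediction - label == -0.5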
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index 21c59061a02fa..506a878c2553b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -21,14 +21,16 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeEnsembleModel, TreeRegressorParams}
+import org.apache.spark.ml.tree.impl.RandomForest
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.{RandomForest => OldRandomForest}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.DoubleType
/**
* :: Experimental ::
@@ -82,9 +84,10 @@ final class RandomForestRegressor(override val uid: String)
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
val strategy =
super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity)
- val oldModel = OldRandomForest.trainRegressor(
- oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed.toInt)
- RandomForestRegressionModel.fromOld(oldModel, this, categoricalFeatures)
+ val trees =
+ RandomForest.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed)
+ .map(_.asInstanceOf[DecisionTreeRegressionModel])
+ new RandomForestRegressionModel(trees)
}
override def copy(extra: ParamMap): RandomForestRegressor = defaultCopy(extra)
@@ -115,6 +118,12 @@ final class RandomForestRegressionModel private[ml] (
require(numTrees > 0, "RandomForestRegressionModel requires at least 1 tree.")
+ /**
+ * Construct a random forest regression model, with all trees weighted equally.
+ * @param trees Component trees
+ */
+ def this(trees: Array[DecisionTreeRegressionModel]) = this(Identifiable.randomUID("rfr"), trees)
+
override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
// Note: We may add support for weights (based on tree performance) later on.
@@ -122,8 +131,15 @@ final class RandomForestRegressionModel private[ml] (
override def treeWeights: Array[Double] = _treeWeights
+ override protected def transformImpl(dataset: DataFrame): DataFrame = {
+ val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
+ val predictUDF = udf { (features: Any) =>
+ bcastModel.value.predict(features.asInstanceOf[Vector])
+ }
+ dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
+ }
+
override protected def predict(features: Vector): Double = {
- // TODO: Override transform() to broadcast model. SPARK-7127
// TODO: When we add a generic Bagging class, handle transform there. SPARK-7128
// Predict average of tree predictions.
// Ignore the weights since all are 1.0 for now.
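Since all tree weights are 1.0, the model's prediction reduces to a plain average of the component trees' predictions; a standalone illustration with hypothetical per-tree values:

    // Per-tree predictions for one input row (hypothetical values).
    val treePredictions = Array(1.0, 2.5, 2.5)
    // Equal weights, so the forest prediction is the simple mean.
    val forestPrediction = treePredictions.sum / treePredictions.length  // 2.0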
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
index 4242154be14ce..bbc2427ca7d3d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
@@ -209,3 +209,132 @@ private object InternalNode {
}
}
}
+
+/**
+ * Version of a node used in learning. This uses vars so that we can modify nodes as we split the
+ * tree by adding children, etc.
+ *
+ * For now, we use node IDs. These will be kept internal since we hope to remove node IDs
+ * in the future, or at least change the indexing (so that we can support much deeper trees).
+ *
+ * This node can either be:
+ * - a leaf node, with leftChild, rightChild, split set to null, or
+ * - an internal node, with all values set
+ *
+ * @param id We currently use the same indexing as the old implementation in
+ * [[org.apache.spark.mllib.tree.model.Node]], but this will change later.
+ * @param predictionStats Predicted label + class probability (for classification).
+ * We will later modify this to store aggregate statistics for labels
+ * to provide all class probabilities (for classification) and maybe a
+ * distribution (for regression).
+ * @param isLeaf Indicates whether this node will definitely be a leaf in the learned tree,
+ * so that we do not need to consider splitting it further.
+ * @param stats Old structure for storing stats about information gain, prediction, etc.
+ * This is legacy and will be modified in the future.
+ */
+private[tree] class LearningNode(
+ var id: Int,
+ var predictionStats: OldPredict,
+ var impurity: Double,
+ var leftChild: Option[LearningNode],
+ var rightChild: Option[LearningNode],
+ var split: Option[Split],
+ var isLeaf: Boolean,
+ var stats: Option[OldInformationGainStats]) extends Serializable {
+
+ /**
+ * Convert this [[LearningNode]] to a regular [[Node]], and recurse on any children.
+ */
+ def toNode: Node = {
+ if (leftChild.nonEmpty) {
+ assert(rightChild.nonEmpty && split.nonEmpty && stats.nonEmpty,
+ "Unknown error during Decision Tree learning. Could not convert LearningNode to Node.")
+ new InternalNode(predictionStats.predict, impurity, stats.get.gain,
+ leftChild.get.toNode, rightChild.get.toNode, split.get)
+ } else {
+ new LeafNode(predictionStats.predict, impurity)
+ }
+ }
+
+}
+
+private[tree] object LearningNode {
+
+ /** Create a node with some of its fields set. */
+ def apply(
+ id: Int,
+ predictionStats: OldPredict,
+ impurity: Double,
+ isLeaf: Boolean): LearningNode = {
+ new LearningNode(id, predictionStats, impurity, None, None, None, isLeaf, None)
+ }
+
+ /** Create an empty node with the given node index. Values must be set later on. */
+ def emptyNode(nodeIndex: Int): LearningNode = {
+ new LearningNode(nodeIndex, new OldPredict(Double.NaN, Double.NaN), Double.NaN,
+ None, None, None, false, None)
+ }
+
+ // The below indexing methods were copied from spark.mllib.tree.model.Node
+
+ /**
+ * Return the index of the left child of this node.
+ */
+ def leftChildIndex(nodeIndex: Int): Int = nodeIndex << 1
+
+ /**
+ * Return the index of the right child of this node.
+ */
+ def rightChildIndex(nodeIndex: Int): Int = (nodeIndex << 1) + 1
+
+ /**
+ * Get the parent index of the given node, or 0 if it is the root.
+ */
+ def parentIndex(nodeIndex: Int): Int = nodeIndex >> 1
+
+ /**
+ * Return the level of the tree that the given node is in.
+ */
+ def indexToLevel(nodeIndex: Int): Int = if (nodeIndex == 0) {
+ throw new IllegalArgumentException(s"0 is not a valid node index.")
+ } else {
+ java.lang.Integer.numberOfTrailingZeros(java.lang.Integer.highestOneBit(nodeIndex))
+ }
+
+ /**
+ * Returns true if this is a left child.
+ * Note: Returns false for the root.
+ */
+ def isLeftChild(nodeIndex: Int): Boolean = nodeIndex > 1 && nodeIndex % 2 == 0
+
+ /**
+ * Return the maximum number of nodes which can be in the given level of the tree.
+ * @param level Level of tree (0 = root).
+ */
+ def maxNodesInLevel(level: Int): Int = 1 << level
+
+ /**
+ * Return the index of the first node in the given level.
+ * @param level Level of tree (0 = root).
+ */
+ def startIndexInLevel(level: Int): Int = 1 << level
+
+ /**
+ * Traces down from a root node to get the node with the given node index.
+ * This assumes the node exists.
+ */
+ def getNode(nodeIndex: Int, rootNode: LearningNode): LearningNode = {
+ var tmpNode: LearningNode = rootNode
+ var levelsToGo = indexToLevel(nodeIndex)
+ while (levelsToGo > 0) {
+ if ((nodeIndex & (1 << levelsToGo - 1)) == 0) {
+ tmpNode = tmpNode.leftChild.get
+ } else {
+ tmpNode = tmpNode.rightChild.get
+ }
+ levelsToGo -= 1
+ }
+ tmpNode
+ }
+
+}
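A standalone check of the 1-based indexing scheme implemented by the helpers above (root = 1, left child = 2 * i, right child = 2 * i + 1, level = floor(log2(i))); the functions below simply mirror that arithmetic for illustration:

    def leftChildIndex(i: Int): Int = i << 1
    def rightChildIndex(i: Int): Int = (i << 1) + 1
    def parentIndex(i: Int): Int = i >> 1
    def indexToLevel(i: Int): Int =
      java.lang.Integer.numberOfTrailingZeros(java.lang.Integer.highestOneBit(i))

    assert(leftChildIndex(1) == 2 && rightChildIndex(1) == 3)  // children of the root
    assert(parentIndex(5) == 2)                                // node 5 is node 2's right child
    assert(indexToLevel(1) == 0)                               // the root sits at level 0
    assert(indexToLevel(4) == 2 && indexToLevel(7) == 2)       // level 2 spans indices 4..7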
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
index 7acdeeee72d23..78199cc2df582 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
@@ -34,9 +34,19 @@ sealed trait Split extends Serializable {
/** Index of feature which this split tests */
def featureIndex: Int
- /** Return true (split to left) or false (split to right) */
+ /**
+ * Return true (split to left) or false (split to right).
+ * @param features Vector of features (original values, not binned).
+ */
private[ml] def shouldGoLeft(features: Vector): Boolean
+ /**
+ * Return true (split to left) or false (split to right).
+ * @param binnedFeature Binned feature value.
+ * @param splits All splits for the given feature.
+ */
+ private[tree] def shouldGoLeft(binnedFeature: Int, splits: Array[Split]): Boolean
+
/** Convert to old Split format */
private[tree] def toOld: OldSplit
}
@@ -94,6 +104,14 @@ final class CategoricalSplit private[ml] (
}
}
+ override private[tree] def shouldGoLeft(binnedFeature: Int, splits: Array[Split]): Boolean = {
+ if (isLeft) {
+ categories.contains(binnedFeature.toDouble)
+ } else {
+ !categories.contains(binnedFeature.toDouble)
+ }
+ }
+
override def equals(o: Any): Boolean = {
o match {
case other: CategoricalSplit => featureIndex == other.featureIndex &&
@@ -144,6 +162,16 @@ final class ContinuousSplit private[ml] (override val featureIndex: Int, val thr
features(featureIndex) <= threshold
}
+ override private[tree] def shouldGoLeft(binnedFeature: Int, splits: Array[Split]): Boolean = {
+ if (binnedFeature == splits.length) {
+ // > last split, so split right
+ false
+ } else {
+ val featureValueUpperBound = splits(binnedFeature).asInstanceOf[ContinuousSplit].threshold
+ featureValueUpperBound <= threshold
+ }
+ }
+
override def equals(o: Any): Boolean = {
o match {
case other: ContinuousSplit =>
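A standalone illustration of the binned ContinuousSplit decision above (the real classes have private[ml] constructors, so plain values are used here; the thresholds are hypothetical). A binned value i means the raw feature value is at or below thresholds(i), so comparing that upper bound with the split's own threshold picks the branch:

    val thresholds = Array(0.5, 1.5, 2.5)  // candidate split thresholds for one feature
    val splitThreshold = 1.5               // threshold of the split being evaluated

    def shouldGoLeft(binnedFeature: Int): Boolean =
      if (binnedFeature == thresholds.length) {
        false  // the raw value exceeds every threshold, so it cannot go left
      } else {
        thresholds(binnedFeature) <= splitThreshold  // bin's upper bound vs. this split
      }

    assert(shouldGoLeft(0) && shouldGoLeft(1))    // bins at or below the 1.5 threshold go left
    assert(!shouldGoLeft(2) && !shouldGoLeft(3))  // bins above it go right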
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala
new file mode 100644
index 0000000000000..488e8e4fb5dcd
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala
@@ -0,0 +1,194 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree.impl
+
+import java.io.IOException
+
+import scala.collection.mutable
+
+import org.apache.hadoop.fs.{Path, FileSystem}
+
+import org.apache.spark.Logging
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.ml.tree.{LearningNode, Split}
+import org.apache.spark.mllib.tree.impl.BaggedPoint
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
+
+
+/**
+ * This is used by the node id cache to find the id of the child node that a data point belongs to.
+ * @param split Split information.
+ * @param nodeIndex The current node index of a data point that this will update.
+ */
+private[tree] case class NodeIndexUpdater(split: Split, nodeIndex: Int) {
+
+ /**
+ * Determine a child node index based on the feature value and the split.
+ * @param binnedFeature Binned feature value.
+ * @param splits Split information to convert the bin indices to approximate feature values.
+ * @return Child node index to update to.
+ */
+ def updateNodeIndex(binnedFeature: Int, splits: Array[Split]): Int = {
+ if (split.shouldGoLeft(binnedFeature, splits)) {
+ LearningNode.leftChildIndex(nodeIndex)
+ } else {
+ LearningNode.rightChildIndex(nodeIndex)
+ }
+ }
+}
+
+/**
+ * Each TreePoint belongs to a particular node per tree.
+ * Each row in the nodeIdsForInstances RDD is an array over trees, giving the node index
+ * in each tree. Initially, all values are 1, since every row starts at the root node.
+ * The nodeIdsForInstances RDD needs to be updated at each iteration.
+ * @param nodeIdsForInstances The initial values in the cache
+ * (should be an array of all 1's, i.e., every row starts at the root node).
+ * @param checkpointInterval The checkpointing interval
+ * (how often the cache should be checkpointed).
+ */
+private[spark] class NodeIdCache(
+ var nodeIdsForInstances: RDD[Array[Int]],
+ val checkpointInterval: Int) extends Logging {
+
+ // Keep a reference to a previous node Ids for instances.
+ // Because we will keep on re-persisting updated node Ids,
+ // we want to unpersist the previous RDD.
+ private var prevNodeIdsForInstances: RDD[Array[Int]] = null
+
+ // To keep track of the past checkpointed RDDs.
+ private val checkpointQueue = mutable.Queue[RDD[Array[Int]]]()
+ private var rddUpdateCount = 0
+
+ // Indicates whether we can checkpoint
+ private val canCheckpoint = nodeIdsForInstances.sparkContext.getCheckpointDir.nonEmpty
+
+ // FileSystem instance for deleting checkpoints as needed
+ private val fs = FileSystem.get(nodeIdsForInstances.sparkContext.hadoopConfiguration)
+
+ /**
+ * Update the node index values in the cache.
+ * This updates the RDD and its lineage.
+ * TODO: Passing bin information to executors seems unnecessary and costly.
+ * @param data The RDD of training rows.
+ * @param nodeIdUpdaters A map of node index updaters.
+ * The keys are the indices of the nodes that we want to update.
+ * @param splits Split information needed to find child node indices.
+ */
+ def updateNodeIndices(
+ data: RDD[BaggedPoint[TreePoint]],
+ nodeIdUpdaters: Array[mutable.Map[Int, NodeIndexUpdater]],
+ splits: Array[Array[Split]]): Unit = {
+ if (prevNodeIdsForInstances != null) {
+ // Unpersist the previous one if one exists.
+ prevNodeIdsForInstances.unpersist()
+ }
+
+ prevNodeIdsForInstances = nodeIdsForInstances
+ nodeIdsForInstances = data.zip(nodeIdsForInstances).map { case (point, ids) =>
+ var treeId = 0
+ while (treeId < nodeIdUpdaters.length) {
+ val nodeIdUpdater = nodeIdUpdaters(treeId).getOrElse(ids(treeId), null)
+ if (nodeIdUpdater != null) {
+ val featureIndex = nodeIdUpdater.split.featureIndex
+ val newNodeIndex = nodeIdUpdater.updateNodeIndex(
+ binnedFeature = point.datum.binnedFeatures(featureIndex),
+ splits = splits(featureIndex))
+ ids(treeId) = newNodeIndex
+ }
+ treeId += 1
+ }
+ ids
+ }
+
+ // Keep on persisting new ones.
+ nodeIdsForInstances.persist(StorageLevel.MEMORY_AND_DISK)
+ rddUpdateCount += 1
+
+ // Handle checkpointing if the directory is not None.
+ if (canCheckpoint && (rddUpdateCount % checkpointInterval) == 0) {
+ // Let's see if we can delete previous checkpoints.
+ var canDelete = true
+ while (checkpointQueue.size > 1 && canDelete) {
+ // We can delete the oldest checkpoint iff
+ // the next checkpoint actually exists in the file system.
+ if (checkpointQueue(1).getCheckpointFile.isDefined) {
+ val old = checkpointQueue.dequeue()
+ // Since the old checkpoint is not deleted by Spark, we'll manually delete it here.
+ try {
+ fs.delete(new Path(old.getCheckpointFile.get), true)
+ } catch {
+ case e: IOException =>
+ logError("Decision Tree learning using cacheNodeIds failed to remove checkpoint" +
+ s" file: ${old.getCheckpointFile.get}")
+ }
+ } else {
+ canDelete = false
+ }
+ }
+
+ nodeIdsForInstances.checkpoint()
+ checkpointQueue.enqueue(nodeIdsForInstances)
+ }
+ }
+
+ /**
+ * Call this after training is finished to delete any remaining checkpoints.
+ */
+ def deleteAllCheckpoints(): Unit = {
+ while (checkpointQueue.nonEmpty) {
+ val old = checkpointQueue.dequeue()
+ if (old.getCheckpointFile.isDefined) {
+ try {
+ fs.delete(new Path(old.getCheckpointFile.get), true)
+ } catch {
+ case e: IOException =>
+ logError("Decision Tree learning using cacheNodeIds failed to remove checkpoint" +
+ s" file: ${old.getCheckpointFile.get}")
+ }
+ }
+ }
+ if (prevNodeIdsForInstances != null) {
+ // Unpersist the previous one if one exists.
+ prevNodeIdsForInstances.unpersist()
+ }
+ }
+}
+
+@DeveloperApi
+private[spark] object NodeIdCache {
+ /**
+ * Initialize the node Id cache with initial node Id values.
+ * @param data The RDD of training rows.
+ * @param numTrees The number of trees that we want to create the cache for.
+ * @param checkpointInterval The checkpointing interval
+ * (how often the cache should be checkpointed).
+ * @param initVal The initial values in the cache.
+ * @return A node Id cache containing an RDD of initial root node indices.
+ */
+ def init(
+ data: RDD[BaggedPoint[TreePoint]],
+ numTrees: Int,
+ checkpointInterval: Int,
+ initVal: Int = 1): NodeIdCache = {
+ new NodeIdCache(
+ data.map(_ => Array.fill[Int](numTrees)(initVal)),
+ checkpointInterval)
+ }
+}
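A standalone sketch of the bookkeeping NodeIndexUpdater performs for one tree (the real class takes a Split and delegates the branch decision to shouldGoLeft; goesLeft below is a hypothetical stand-in for that result): the cached node id simply moves to the matching child index.

    // Move a row's cached node id from its current node to the chosen child.
    def updateNodeIndex(currentNodeIndex: Int, goesLeft: Boolean): Int =
      if (goesLeft) currentNodeIndex << 1 else (currentNodeIndex << 1) + 1

    assert(updateNodeIndex(1, goesLeft = true) == 2)   // root -> left child
    assert(updateNodeIndex(3, goesLeft = false) == 7)  // node 3 -> its right child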
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
new file mode 100644
index 0000000000000..15b56bd844bad
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -0,0 +1,1132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree.impl
+
+import java.io.IOException
+
+import scala.collection.mutable
+import scala.util.Random
+
+import org.apache.spark.Logging
+import org.apache.spark.ml.classification.DecisionTreeClassificationModel
+import org.apache.spark.ml.regression.DecisionTreeRegressionModel
+import org.apache.spark.ml.tree._
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
+import org.apache.spark.mllib.tree.impl.{BaggedPoint, DTStatsAggregator, DecisionTreeMetadata,
+ TimeTracker}
+import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
+import org.apache.spark.mllib.tree.model.{InformationGainStats, Predict}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom}
+
+
+private[ml] object RandomForest extends Logging {
+
+ /**
+ * Train a random forest.
+ * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
+ * @return an unweighted set of trees
+ */
+ def run(
+ input: RDD[LabeledPoint],
+ strategy: OldStrategy,
+ numTrees: Int,
+ featureSubsetStrategy: String,
+ seed: Long,
+ parentUID: Option[String] = None): Array[DecisionTreeModel] = {
+
+ val timer = new TimeTracker()
+
+ timer.start("total")
+
+ timer.start("init")
+
+ val retaggedInput = input.retag(classOf[LabeledPoint])
+ val metadata =
+ DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy)
+ logDebug("algo = " + strategy.algo)
+ logDebug("numTrees = " + numTrees)
+ logDebug("seed = " + seed)
+ logDebug("maxBins = " + metadata.maxBins)
+ logDebug("featureSubsetStrategy = " + featureSubsetStrategy)
+ logDebug("numFeaturesPerNode = " + metadata.numFeaturesPerNode)
+ logDebug("subsamplingRate = " + strategy.subsamplingRate)
+
+ // Find the splits and the corresponding bins (interval between the splits) using a sample
+ // of the input data.
+ timer.start("findSplitsBins")
+ val splits = findSplits(retaggedInput, metadata)
+ timer.stop("findSplitsBins")
+ logDebug("numBins: feature: number of bins")
+ logDebug(Range(0, metadata.numFeatures).map { featureIndex =>
+ s"\t$featureIndex\t${metadata.numBins(featureIndex)}"
+ }.mkString("\n"))
+
+ // Bin feature values (TreePoint representation).
+ // Cache input RDD for speedup during multiple passes.
+ val treeInput = TreePoint.convertToTreeRDD(retaggedInput, splits, metadata)
+
+ val withReplacement = numTrees > 1
+
+ val baggedInput = BaggedPoint
+ .convertToBaggedRDD(treeInput, strategy.subsamplingRate, numTrees, withReplacement, seed)
+ .persist(StorageLevel.MEMORY_AND_DISK)
+
+ // depth of the decision tree
+ val maxDepth = strategy.maxDepth
+ require(maxDepth <= 30,
+ s"DecisionTree currently only supports maxDepth <= 30, but was given maxDepth = $maxDepth.")
+
+ // Max memory usage for aggregates
+ // TODO: Calculate memory usage more precisely.
+ val maxMemoryUsage: Long = strategy.maxMemoryInMB * 1024L * 1024L
+ logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.")
+ val maxMemoryPerNode = {
+ val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
+ // Find numFeaturesPerNode largest bins to get an upper bound on memory usage.
+ Some(metadata.numBins.zipWithIndex.sortBy(- _._1)
+ .take(metadata.numFeaturesPerNode).map(_._2))
+ } else {
+ None
+ }
+ RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L
+ }
+ require(maxMemoryPerNode <= maxMemoryUsage,
+ s"RandomForest/DecisionTree given maxMemoryInMB = ${strategy.maxMemoryInMB}," +
+ " which is too small for the given features." +
+ s" Minimum value = ${maxMemoryPerNode / (1024L * 1024L)}")
+
+ timer.stop("init")
+
+ /*
+ * The main idea here is to perform group-wise training of the decision tree nodes thus
+ * reducing the passes over the data from (# nodes) to (# nodes / maxNumberOfNodesPerGroup).
+ * Each data sample is handled by a particular node (or it reaches a leaf and is not used
+ * in lower levels).
+ */
+
+ // Create an RDD of node Id cache.
+ // At first, all the rows belong to the root nodes (node Id == 1).
+ val nodeIdCache = if (strategy.useNodeIdCache) {
+ Some(NodeIdCache.init(
+ data = baggedInput,
+ numTrees = numTrees,
+ checkpointInterval = strategy.checkpointInterval,
+ initVal = 1))
+ } else {
+ None
+ }
+
+ // FIFO queue of nodes to train: (treeIndex, node)
+ val nodeQueue = new mutable.Queue[(Int, LearningNode)]()
+
+ val rng = new Random()
+ rng.setSeed(seed)
+
+ // Allocate and queue root nodes.
+ val topNodes = Array.fill[LearningNode](numTrees)(LearningNode.emptyNode(nodeIndex = 1))
+ Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex))))
+
+ while (nodeQueue.nonEmpty) {
+ // Collect some nodes to split, and choose features for each node (if subsampling).
+ // Each group of nodes may come from one or multiple trees, and at multiple levels.
+ val (nodesForGroup, treeToNodeToIndexInfo) =
+ RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng)
+ // Sanity check (should never occur):
+ assert(nodesForGroup.nonEmpty,
+ s"RandomForest selected empty nodesForGroup. Error for unknown reason.")
+
+ // Choose node splits, and enqueue new nodes as needed.
+ timer.start("findBestSplits")
+ RandomForest.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup,
+ treeToNodeToIndexInfo, splits, nodeQueue, timer, nodeIdCache)
+ timer.stop("findBestSplits")
+ }
+
+ baggedInput.unpersist()
+
+ timer.stop("total")
+
+ logInfo("Internal timing for DecisionTree:")
+ logInfo(s"$timer")
+
+ // Delete any remaining checkpoints used for node Id cache.
+ if (nodeIdCache.nonEmpty) {
+ try {
+ nodeIdCache.get.deleteAllCheckpoints()
+ } catch {
+ case e: IOException =>
+ logWarning(s"delete all checkpoints failed. Error reason: ${e.getMessage}")
+ }
+ }
+
+ parentUID match {
+ case Some(uid) =>
+ if (strategy.algo == OldAlgo.Classification) {
+ topNodes.map(rootNode => new DecisionTreeClassificationModel(uid, rootNode.toNode))
+ } else {
+ topNodes.map(rootNode => new DecisionTreeRegressionModel(uid, rootNode.toNode))
+ }
+ case None =>
+ if (strategy.algo == OldAlgo.Classification) {
+ topNodes.map(rootNode => new DecisionTreeClassificationModel(rootNode.toNode))
+ } else {
+ topNodes.map(rootNode => new DecisionTreeRegressionModel(rootNode.toNode))
+ }
+ }
+ }
+
+ /**
+ * Get the node index corresponding to this data point.
+ * This function mimics prediction, passing an example from the root node down to a leaf
+ * or unsplit node; that node's index is returned.
+ *
+ * @param node Node in tree from which to classify the given data point.
+ * @param binnedFeatures Binned feature vector for data point.
+ * @param splits possible splits for all features, indexed (numFeatures)(numSplits)
+ * @return Leaf index if the data point reaches a leaf.
+ * Otherwise, last node reachable in tree matching this example.
+ * Note: This is the global node index, i.e., the index used in the tree.
+ * This index is different from the index used during training a particular
+ * group of nodes on one call to [[findBestSplits()]].
+ */
+ private def predictNodeIndex(
+ node: LearningNode,
+ binnedFeatures: Array[Int],
+ splits: Array[Array[Split]]): Int = {
+ if (node.isLeaf || node.split.isEmpty) {
+ node.id
+ } else {
+ val split = node.split.get
+ val featureIndex = split.featureIndex
+ val splitLeft = split.shouldGoLeft(binnedFeatures(featureIndex), splits(featureIndex))
+ if (node.leftChild.isEmpty) {
+ // Not yet split. Return index from next layer of nodes to train
+ if (splitLeft) {
+ LearningNode.leftChildIndex(node.id)
+ } else {
+ LearningNode.rightChildIndex(node.id)
+ }
+ } else {
+ if (splitLeft) {
+ predictNodeIndex(node.leftChild.get, binnedFeatures, splits)
+ } else {
+ predictNodeIndex(node.rightChild.get, binnedFeatures, splits)
+ }
+ }
+ }
+ }
+
+ /**
+ * Helper for binSeqOp, for data which can contain a mix of ordered and unordered features.
+ *
+ * For ordered features, a single bin is updated.
+ * For unordered features, bins correspond to subsets of categories; either the left or right bin
+ * for each subset is updated.
+ *
+ * @param agg Array storing aggregate calculation, with a set of sufficient statistics for
+ * each (feature, bin).
+ * @param treePoint Data point being aggregated.
+ * @param splits possible splits indexed (numFeatures)(numSplits)
+ * @param unorderedFeatures Set of indices of unordered features.
+ * @param instanceWeight Weight (importance) of instance in dataset.
+ */
+ private def mixedBinSeqOp(
+ agg: DTStatsAggregator,
+ treePoint: TreePoint,
+ splits: Array[Array[Split]],
+ unorderedFeatures: Set[Int],
+ instanceWeight: Double,
+ featuresForNode: Option[Array[Int]]): Unit = {
+ val numFeaturesPerNode = if (featuresForNode.nonEmpty) {
+ // Use subsampled features
+ featuresForNode.get.length
+ } else {
+ // Use all features
+ agg.metadata.numFeatures
+ }
+ // Iterate over features.
+ var featureIndexIdx = 0
+ while (featureIndexIdx < numFeaturesPerNode) {
+ val featureIndex = if (featuresForNode.nonEmpty) {
+ featuresForNode.get.apply(featureIndexIdx)
+ } else {
+ featureIndexIdx
+ }
+ if (unorderedFeatures.contains(featureIndex)) {
+ // Unordered feature
+ val featureValue = treePoint.binnedFeatures(featureIndex)
+ val (leftNodeFeatureOffset, rightNodeFeatureOffset) =
+ agg.getLeftRightFeatureOffsets(featureIndexIdx)
+ // Update the left or right bin for each split.
+ val numSplits = agg.metadata.numSplits(featureIndex)
+ val featureSplits = splits(featureIndex)
+ var splitIndex = 0
+ while (splitIndex < numSplits) {
+ if (featureSplits(splitIndex).shouldGoLeft(featureValue, featureSplits)) {
+ agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight)
+ } else {
+ agg.featureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight)
+ }
+ splitIndex += 1
+ }
+ } else {
+ // Ordered feature
+ val binIndex = treePoint.binnedFeatures(featureIndex)
+ agg.update(featureIndexIdx, binIndex, treePoint.label, instanceWeight)
+ }
+ featureIndexIdx += 1
+ }
+ }
+
+ /**
+ * Helper for binSeqOp, for regression and for classification with only ordered features.
+ *
+ * For each feature, the sufficient statistics of one bin are updated.
+ *
+ * @param agg Array storing aggregate calculation, with a set of sufficient statistics for
+ * each (feature, bin).
+ * @param treePoint Data point being aggregated.
+ * @param instanceWeight Weight (importance) of instance in dataset.
+ */
+ private def orderedBinSeqOp(
+ agg: DTStatsAggregator,
+ treePoint: TreePoint,
+ instanceWeight: Double,
+ featuresForNode: Option[Array[Int]]): Unit = {
+ val label = treePoint.label
+
+ // Iterate over features.
+ if (featuresForNode.nonEmpty) {
+ // Use subsampled features
+ var featureIndexIdx = 0
+ while (featureIndexIdx < featuresForNode.get.length) {
+ val binIndex = treePoint.binnedFeatures(featuresForNode.get.apply(featureIndexIdx))
+ agg.update(featureIndexIdx, binIndex, label, instanceWeight)
+ featureIndexIdx += 1
+ }
+ } else {
+ // Use all features
+ val numFeatures = agg.metadata.numFeatures
+ var featureIndex = 0
+ while (featureIndex < numFeatures) {
+ val binIndex = treePoint.binnedFeatures(featureIndex)
+ agg.update(featureIndex, binIndex, label, instanceWeight)
+ featureIndex += 1
+ }
+ }
+ }
+
+ /**
+ * Given a group of nodes, this finds the best split for each node.
+ *
+ * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]]
+ * @param metadata Learning and dataset metadata
+ * @param topNodes Root node for each tree. Used for matching instances with nodes.
+ * @param nodesForGroup Mapping: treeIndex --> nodes to be split in tree
+ * @param treeToNodeToIndexInfo Mapping: treeIndex --> nodeIndex --> nodeIndexInfo,
+ * where nodeIndexInfo stores the index in the group and the
+ * feature subsets (if using feature subsets).
+ * @param splits possible splits for all features, indexed (numFeatures)(numSplits)
+ * @param nodeQueue Queue of nodes to split, with values (treeIndex, node).
+ * Updated with new non-leaf nodes which are created.
+ * @param nodeIdCache Node Id cache containing an RDD of Array[Int] where
+ * each value in the array is the data point's node Id
+ * for a corresponding tree. This is used to prevent the need
+ * to pass the entire tree to the executors during
+ * the node stat aggregation phase.
+ */
+ private[tree] def findBestSplits(
+ input: RDD[BaggedPoint[TreePoint]],
+ metadata: DecisionTreeMetadata,
+ topNodes: Array[LearningNode],
+ nodesForGroup: Map[Int, Array[LearningNode]],
+ treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]],
+ splits: Array[Array[Split]],
+ nodeQueue: mutable.Queue[(Int, LearningNode)],
+ timer: TimeTracker = new TimeTracker,
+ nodeIdCache: Option[NodeIdCache] = None): Unit = {
+
+ /*
+ * The high-level descriptions of the best split optimizations are noted here.
+ *
+ * *Group-wise training*
+ * We perform bin calculations for groups of nodes to reduce the number of
+ * passes over the data. Each iteration requires more computation and storage,
+ * but saves several iterations over the data.
+ *
+ * *Bin-wise computation*
+ * We use a bin-wise best split computation strategy instead of a straightforward best split
+ * computation strategy. Instead of analyzing each sample for contribution to the left/right
+ * child node impurity of every split, we first categorize each feature of a sample into a
+ * bin. We exploit this structure to calculate aggregates for bins and then use these aggregates
+ * to calculate information gain for each split.
+ *
+ * *Aggregation over partitions*
+ * Instead of performing a flatMap/reduceByKey operation, we exploit the fact that we know
+ * the number of splits in advance. Thus, we store the aggregates (at the appropriate
+ * indices) in a single array for all bins and rely upon the RDD aggregate method to
+ * drastically reduce the communication overhead.
+ */
+
+ // numNodes: Number of nodes in this group
+ val numNodes = nodesForGroup.values.map(_.length).sum
+ logDebug("numNodes = " + numNodes)
+ logDebug("numFeatures = " + metadata.numFeatures)
+ logDebug("numClasses = " + metadata.numClasses)
+ logDebug("isMulticlass = " + metadata.isMulticlass)
+ logDebug("isMulticlassWithCategoricalFeatures = " +
+ metadata.isMulticlassWithCategoricalFeatures)
+ logDebug("using nodeIdCache = " + nodeIdCache.nonEmpty.toString)
+
+ /**
+ * Performs a sequential aggregation over a partition for a particular tree and node.
+ *
+ * For each feature, the aggregate sufficient statistics are updated for the relevant
+ * bins.
+ *
+ * @param treeIndex Index of the tree that we want to perform aggregation for.
+ * @param nodeInfo The node info for the tree node.
+ * @param agg Array storing aggregate calculation, with a set of sufficient statistics
+ * for each (node, feature, bin).
+ * @param baggedPoint Data point being aggregated.
+ */
+ def nodeBinSeqOp(
+ treeIndex: Int,
+ nodeInfo: NodeIndexInfo,
+ agg: Array[DTStatsAggregator],
+ baggedPoint: BaggedPoint[TreePoint]): Unit = {
+ if (nodeInfo != null) {
+ val aggNodeIndex = nodeInfo.nodeIndexInGroup
+ val featuresForNode = nodeInfo.featureSubset
+ val instanceWeight = baggedPoint.subsampleWeights(treeIndex)
+ if (metadata.unorderedFeatures.isEmpty) {
+ orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode)
+ } else {
+ mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits,
+ metadata.unorderedFeatures, instanceWeight, featuresForNode)
+ }
+ }
+ }
+
+ /**
+ * Performs a sequential aggregation over a partition.
+ *
+ * Each data point contributes to one node. For each feature,
+ * the aggregate sufficient statistics are updated for the relevant bins.
+ *
+ * @param agg Array storing aggregate calculation, with a set of sufficient statistics for
+ * each (node, feature, bin).
+ * @param baggedPoint Data point being aggregated.
+ * @return agg
+ */
+ def binSeqOp(
+ agg: Array[DTStatsAggregator],
+ baggedPoint: BaggedPoint[TreePoint]): Array[DTStatsAggregator] = {
+ treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
+ val nodeIndex =
+ predictNodeIndex(topNodes(treeIndex), baggedPoint.datum.binnedFeatures, splits)
+ nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint)
+ }
+ agg
+ }
+
+ /**
+ * Do the same thing as binSeqOp, but with nodeIdCache.
+ */
+ def binSeqOpWithNodeIdCache(
+ agg: Array[DTStatsAggregator],
+ dataPoint: (BaggedPoint[TreePoint], Array[Int])): Array[DTStatsAggregator] = {
+ treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
+ val baggedPoint = dataPoint._1
+ val nodeIdCache = dataPoint._2
+ val nodeIndex = nodeIdCache(treeIndex)
+ nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint)
+ }
+
+ agg
+ }
+
+ /**
+ * Get the (node index in group) --> (feature indices) map,
+ * which is a shortcut for finding the feature indices of a node given its node index in the group.
+ */
+ def getNodeToFeatures(
+ treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]]): Option[Map[Int, Array[Int]]] = {
+ if (!metadata.subsamplingFeatures) {
+ None
+ } else {
+ val mutableNodeToFeatures = new mutable.HashMap[Int, Array[Int]]()
+ treeToNodeToIndexInfo.values.foreach { nodeIdToNodeInfo =>
+ nodeIdToNodeInfo.values.foreach { nodeIndexInfo =>
+ assert(nodeIndexInfo.featureSubset.isDefined)
+ mutableNodeToFeatures(nodeIndexInfo.nodeIndexInGroup) = nodeIndexInfo.featureSubset.get
+ }
+ }
+ Some(mutableNodeToFeatures.toMap)
+ }
+ }
+
+ // array of nodes to train indexed by node index in group
+ val nodes = new Array[LearningNode](numNodes)
+ nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
+ nodesForTree.foreach { node =>
+ nodes(treeToNodeToIndexInfo(treeIndex)(node.id).nodeIndexInGroup) = node
+ }
+ }
+
+ // Calculate best splits for all nodes in the group
+ timer.start("chooseSplits")
+
+ // In each partition, iterate over all instances and compute aggregate stats for each node,
+ // yielding a (nodeIndex, nodeAggregateStats) pair for each node.
+ // After a `reduceByKey` operation,
+ // the stats of a node are shuffled to a particular partition and combined together,
+ // and the best splits for the nodes are found there.
+ // Finally, only the best Splits for the nodes are collected to the driver to construct the decision tree.
+ val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo)
+ val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures)
+
+ val partitionAggregates : RDD[(Int, DTStatsAggregator)] = if (nodeIdCache.nonEmpty) {
+ input.zip(nodeIdCache.get.nodeIdsForInstances).mapPartitions { points =>
+ // Construct a nodeStatsAggregators array to hold node aggregate stats,
+ // each node will have a nodeStatsAggregator
+ val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
+ val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
+ Some(nodeToFeatures(nodeIndex))
+ }
+ new DTStatsAggregator(metadata, featuresForNode)
+ }
+
+ // iterate over all instances in the current partition and update aggregate stats
+ points.foreach(binSeqOpWithNodeIdCache(nodeStatsAggregators, _))
+
+ // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
+ // which can be combined with those from other partitions using `reduceByKey`
+ nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
+ }
+ } else {
+ input.mapPartitions { points =>
+ // Construct a nodeStatsAggregators array to hold node aggregate stats,
+ // each node will have a nodeStatsAggregator
+ val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
+ val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
+ Some(nodeToFeatures(nodeIndex))
+ }
+ new DTStatsAggregator(metadata, featuresForNode)
+ }
+
+ // iterate over all instances in the current partition and update aggregate stats
+ points.foreach(binSeqOp(nodeStatsAggregators, _))
+
+ // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
+ // which can be combined with those from other partitions using `reduceByKey`
+ nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
+ }
+ }
+
+ val nodeToBestSplits = partitionAggregates.reduceByKey((a, b) => a.merge(b)).map {
+ case (nodeIndex, aggStats) =>
+ val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
+ Some(nodeToFeatures(nodeIndex))
+ }
+
+ // find best split for each node
+ val (split: Split, stats: InformationGainStats, predict: Predict) =
+ binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex))
+ (nodeIndex, (split, stats, predict))
+ }.collectAsMap()
+
+ timer.stop("chooseSplits")
+
+ val nodeIdUpdaters = if (nodeIdCache.nonEmpty) {
+ Array.fill[mutable.Map[Int, NodeIndexUpdater]](
+ metadata.numTrees)(mutable.Map[Int, NodeIndexUpdater]())
+ } else {
+ null
+ }
+ // Iterate over all nodes in this group.
+ nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
+ nodesForTree.foreach { node =>
+ val nodeIndex = node.id
+ val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex)
+ val aggNodeIndex = nodeInfo.nodeIndexInGroup
+ val (split: Split, stats: InformationGainStats, predict: Predict) =
+ nodeToBestSplits(aggNodeIndex)
+ logDebug("best split = " + split)
+
+ // Extract info for this node. Create children if not leaf.
+ val isLeaf =
+ (stats.gain <= 0) || (LearningNode.indexToLevel(nodeIndex) == metadata.maxDepth)
+ node.predictionStats = predict
+ node.isLeaf = isLeaf
+ node.stats = Some(stats)
+ node.impurity = stats.impurity
+ logDebug("Node = " + node)
+
+ if (!isLeaf) {
+ node.split = Some(split)
+ val childIsLeaf = (LearningNode.indexToLevel(nodeIndex) + 1) == metadata.maxDepth
+ val leftChildIsLeaf = childIsLeaf || (stats.leftImpurity == 0.0)
+ val rightChildIsLeaf = childIsLeaf || (stats.rightImpurity == 0.0)
+ node.leftChild = Some(LearningNode(LearningNode.leftChildIndex(nodeIndex),
+ stats.leftPredict, stats.leftImpurity, leftChildIsLeaf))
+ node.rightChild = Some(LearningNode(LearningNode.rightChildIndex(nodeIndex),
+ stats.rightPredict, stats.rightImpurity, rightChildIsLeaf))
+
+ if (nodeIdCache.nonEmpty) {
+ val nodeIndexUpdater = NodeIndexUpdater(
+ split = split,
+ nodeIndex = nodeIndex)
+ nodeIdUpdaters(treeIndex).put(nodeIndex, nodeIndexUpdater)
+ }
+
+ // enqueue left child and right child if they are not leaves
+ if (!leftChildIsLeaf) {
+ nodeQueue.enqueue((treeIndex, node.leftChild.get))
+ }
+ if (!rightChildIsLeaf) {
+ nodeQueue.enqueue((treeIndex, node.rightChild.get))
+ }
+
+ logDebug("leftChildIndex = " + node.leftChild.get.id +
+ ", impurity = " + stats.leftImpurity)
+ logDebug("rightChildIndex = " + node.rightChild.get.id +
+ ", impurity = " + stats.rightImpurity)
+ }
+ }
+ }
+
+ if (nodeIdCache.nonEmpty) {
+ // Update the cache if needed.
+ nodeIdCache.get.updateNodeIndices(input, nodeIdUpdaters, splits)
+ }
+ }
+
+ /**
+ * Calculate the information gain for a given (feature, split) based upon left/right aggregates.
+ * @param leftImpurityCalculator left node aggregates for this (feature, split)
+ * @param rightImpurityCalculator right node aggregate for this (feature, split)
+ * @return information gain and statistics for split
+ */
+ private def calculateGainForSplit(
+ leftImpurityCalculator: ImpurityCalculator,
+ rightImpurityCalculator: ImpurityCalculator,
+ metadata: DecisionTreeMetadata,
+ impurity: Double): InformationGainStats = {
+ val leftCount = leftImpurityCalculator.count
+ val rightCount = rightImpurityCalculator.count
+
+ // If left child or right child doesn't satisfy minimum instances per node,
+ // then this split is invalid, return invalid information gain stats.
+ if ((leftCount < metadata.minInstancesPerNode) ||
+ (rightCount < metadata.minInstancesPerNode)) {
+ return InformationGainStats.invalidInformationGainStats
+ }
+
+ val totalCount = leftCount + rightCount
+
+ val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0
+ val rightImpurity = rightImpurityCalculator.calculate()
+
+ val leftWeight = leftCount / totalCount.toDouble
+ val rightWeight = rightCount / totalCount.toDouble
+
+ val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
+
+ // if information gain doesn't satisfy minimum information gain,
+ // then this split is invalid, return invalid information gain stats.
+ if (gain < metadata.minInfoGain) {
+ return InformationGainStats.invalidInformationGainStats
+ }
+
+ // calculate left and right predict
+ val leftPredict = calculatePredict(leftImpurityCalculator)
+ val rightPredict = calculatePredict(rightImpurityCalculator)
+
+ new InformationGainStats(gain, impurity, leftImpurity, rightImpurity,
+ leftPredict, rightPredict)
+ }
+
+ private def calculatePredict(impurityCalculator: ImpurityCalculator): Predict = {
+ val predict = impurityCalculator.predict
+ val prob = impurityCalculator.prob(predict)
+ new Predict(predict, prob)
+ }
+
+ /**
+ * Calculate predict value for current node, given stats of any split.
+ * Note that this function is called only once for each node.
+ * @param leftImpurityCalculator left node aggregates for a split
+ * @param rightImpurityCalculator right node aggregates for a split
+ * @return predict value and impurity for current node
+ */
+ private def calculatePredictImpurity(
+ leftImpurityCalculator: ImpurityCalculator,
+ rightImpurityCalculator: ImpurityCalculator): (Predict, Double) = {
+ val parentNodeAgg = leftImpurityCalculator.copy
+ parentNodeAgg.add(rightImpurityCalculator)
+ val predict = calculatePredict(parentNodeAgg)
+ val impurity = parentNodeAgg.calculate()
+
+ (predict, impurity)
+ }
+
+ /**
+ * Find the best split for a node.
+ * @param binAggregates Bin statistics.
+ * @return tuple for best split: (Split, information gain, prediction at node)
+ */
+ private def binsToBestSplit(
+ binAggregates: DTStatsAggregator,
+ splits: Array[Array[Split]],
+ featuresForNode: Option[Array[Int]],
+ node: LearningNode): (Split, InformationGainStats, Predict) = {
+
+ // For the top node (level 0), the prediction and impurity are not known yet; they are
+ // computed lazily from the first split's aggregates below. Other nodes reuse their stored stats.
+ val level = LearningNode.indexToLevel(node.id)
+ var predictionAndImpurity: Option[(Predict, Double)] = if (level == 0) {
+ None
+ } else {
+ Some((node.predictionStats, node.impurity))
+ }
+
+ // For each (feature, split), calculate the gain, and select the best (feature, split).
+ val (bestSplit, bestSplitStats) =
+ Range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx =>
+ val featureIndex = if (featuresForNode.nonEmpty) {
+ featuresForNode.get.apply(featureIndexIdx)
+ } else {
+ featureIndexIdx
+ }
+ val numSplits = binAggregates.metadata.numSplits(featureIndex)
+ if (binAggregates.metadata.isContinuous(featureIndex)) {
+ // Cumulative sum (scanLeft) of bin statistics.
+ // Afterwards, binAggregates for a bin is the sum of aggregates for
+ // that bin + all preceding bins.
+ val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
+ var splitIndex = 0
+ while (splitIndex < numSplits) {
+ binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
+ splitIndex += 1
+ }
+ // Find best split.
+ val (bestFeatureSplitIndex, bestFeatureGainStats) =
+ Range(0, numSplits).map { case splitIdx =>
+ val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
+ val rightChildStats =
+ binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
+ rightChildStats.subtract(leftChildStats)
+ predictionAndImpurity = Some(predictionAndImpurity.getOrElse(
+ calculatePredictImpurity(leftChildStats, rightChildStats)))
+ val gainStats = calculateGainForSplit(leftChildStats,
+ rightChildStats, binAggregates.metadata, predictionAndImpurity.get._2)
+ (splitIdx, gainStats)
+ }.maxBy(_._2.gain)
+ (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
+ } else if (binAggregates.metadata.isUnordered(featureIndex)) {
+ // Unordered categorical feature
+ val (leftChildOffset, rightChildOffset) =
+ binAggregates.getLeftRightFeatureOffsets(featureIndexIdx)
+ val (bestFeatureSplitIndex, bestFeatureGainStats) =
+ Range(0, numSplits).map { splitIndex =>
+ val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
+ val rightChildStats =
+ binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
+ predictionAndImpurity = Some(predictionAndImpurity.getOrElse(
+ calculatePredictImpurity(leftChildStats, rightChildStats)))
+ val gainStats = calculateGainForSplit(leftChildStats,
+ rightChildStats, binAggregates.metadata, predictionAndImpurity.get._2)
+ (splitIndex, gainStats)
+ }.maxBy(_._2.gain)
+ (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
+ } else {
+ // Ordered categorical feature
+ val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
+ val numCategories = binAggregates.metadata.numBins(featureIndex)
+
+ /* Each bin is one category (feature value).
+ * The bins are ordered based on centroidForCategories, and this ordering determines which
+ * splits are considered. (With K categories, we consider K - 1 possible splits.)
+ *
+ * centroidForCategories is a list: (category, centroid)
+ */
+ val centroidForCategories = if (binAggregates.metadata.isMulticlass) {
+ // For categorical variables in multiclass classification,
+ // the bins are ordered by the impurity of their corresponding labels.
+ Range(0, numCategories).map { case featureValue =>
+ val categoryStats =
+ binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
+ val centroid = if (categoryStats.count != 0) {
+ categoryStats.calculate()
+ } else {
+ Double.MaxValue
+ }
+ (featureValue, centroid)
+ }
+ } else { // regression or binary classification
+ // For categorical variables in regression and binary classification,
+ // the bins are ordered by the centroid of their corresponding labels.
+ Range(0, numCategories).map { case featureValue =>
+ val categoryStats =
+ binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
+ val centroid = if (categoryStats.count != 0) {
+ categoryStats.predict
+ } else {
+ Double.MaxValue
+ }
+ (featureValue, centroid)
+ }
+ }
+
+ logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(","))
+
+ // bins sorted by centroids
+ val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2)
+
+ logDebug("Sorted centroids for categorical variable = " +
+ categoriesSortedByCentroid.mkString(","))
+
+ // Cumulative sum (scanLeft) of bin statistics.
+ // Afterwards, binAggregates for a bin is the sum of aggregates for
+ // that bin + all preceding bins.
+ var splitIndex = 0
+ while (splitIndex < numSplits) {
+ val currentCategory = categoriesSortedByCentroid(splitIndex)._1
+ val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1
+ binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory)
+ splitIndex += 1
+ }
+ // lastCategory = index of bin with total aggregates for this (node, feature)
+ val lastCategory = categoriesSortedByCentroid.last._1
+ // Find best split.
+ val (bestFeatureSplitIndex, bestFeatureGainStats) =
+ Range(0, numSplits).map { splitIndex =>
+ val featureValue = categoriesSortedByCentroid(splitIndex)._1
+ val leftChildStats =
+ binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
+ val rightChildStats =
+ binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
+ rightChildStats.subtract(leftChildStats)
+ predictionAndImpurity = Some(predictionAndImpurity.getOrElse(
+ calculatePredictImpurity(leftChildStats, rightChildStats)))
+ val gainStats = calculateGainForSplit(leftChildStats,
+ rightChildStats, binAggregates.metadata, predictionAndImpurity.get._2)
+ (splitIndex, gainStats)
+ }.maxBy(_._2.gain)
+ val categoriesForSplit =
+ categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1)
+ val bestFeatureSplit =
+ new CategoricalSplit(featureIndex, categoriesForSplit.toArray, numCategories)
+ (bestFeatureSplit, bestFeatureGainStats)
+ }
+ }.maxBy(_._2.gain)
+
+ (bestSplit, bestSplitStats, predictionAndImpurity.get._1)
+ }
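+
+  // A rough sketch of the continuous-feature case above: after the cumulative-sum merge, the
+  // aggregates at split index s describe the left child (bins 0..s), and subtracting them from
+  // the totals gives the right child, so each candidate threshold is scored in constant time.
+  // The label counts and the Gini helper below are illustrative only, not part of this API.
+  //
+  //   // per-bin (class0, class1) label counts for one continuous feature
+  //   val binCounts = Array(Array(4.0, 0.0), Array(3.0, 1.0), Array(0.0, 2.0), Array(1.0, 3.0))
+  //   val prefix = binCounts.scanLeft(Array(0.0, 0.0)) { (acc, c) =>
+  //     Array(acc(0) + c(0), acc(1) + c(1))
+  //   }.tail
+  //   val total = prefix.last
+  //   def gini(c: Array[Double]): Double = {
+  //     val n = c.sum
+  //     if (n == 0.0) 0.0 else 1.0 - c.map(x => (x / n) * (x / n)).sum
+  //   }
+  //   val bestSplitIdx = (0 until binCounts.length - 1).maxBy { s =>
+  //     val left = prefix(s)
+  //     val right = Array(total(0) - left(0), total(1) - left(1))
+  //     val (nL, nR) = (left.sum, right.sum)
+  //     gini(total) - (nL / (nL + nR)) * gini(left) - (nR / (nL + nR)) * gini(right)
+  //   }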
+
+ /**
+   * Returns splits for decision tree calculation.
+ * Continuous and categorical features are handled differently.
+ *
+ * Continuous features:
+ * For each feature, there are numBins - 1 possible splits representing the possible binary
+ * decisions at each node in the tree.
+ * This finds locations (feature values) for splits using a subsample of the data.
+ *
+ * Categorical features:
+ * For each feature, there is 1 bin per split.
+ * Splits and bins are handled in 2 ways:
+ * (a) "unordered features"
+ * For multiclass classification with a low-arity feature
+ * (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits),
+ * the feature is split based on subsets of categories.
+ * (b) "ordered features"
+ * For regression and binary classification,
+ * and for multiclass classification with a high-arity feature,
+ * there is one bin per category.
+ *
+ * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
+ * @param metadata Learning and dataset metadata
+   * @return Splits, an Array of [[org.apache.spark.ml.tree.Split]]
+   *         of size (numFeatures, numSplits).
+ */
+ protected[tree] def findSplits(
+ input: RDD[LabeledPoint],
+ metadata: DecisionTreeMetadata): Array[Array[Split]] = {
+
+ logDebug("isMulticlass = " + metadata.isMulticlass)
+
+ val numFeatures = metadata.numFeatures
+
+ // Sample the input only if there are continuous features.
+ val hasContinuousFeatures = Range(0, numFeatures).exists(metadata.isContinuous)
+ val sampledInput = if (hasContinuousFeatures) {
+ // Calculate the number of samples for approximate quantile calculation.
+ val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000)
+ val fraction = if (requiredSamples < metadata.numExamples) {
+ requiredSamples.toDouble / metadata.numExamples
+ } else {
+ 1.0
+ }
+ logDebug("fraction of data used for calculating quantiles = " + fraction)
+ input.sample(withReplacement = false, fraction, new XORShiftRandom(1).nextInt()).collect()
+ } else {
+ new Array[LabeledPoint](0)
+ }
+
+ val splits = new Array[Array[Split]](numFeatures)
+
+ // Find all splits.
+ // Iterate over all features.
+ var featureIndex = 0
+ while (featureIndex < numFeatures) {
+ if (metadata.isContinuous(featureIndex)) {
+ val featureSamples = sampledInput.map(_.features(featureIndex))
+ val featureSplits = findSplitsForContinuousFeature(featureSamples, metadata, featureIndex)
+
+ val numSplits = featureSplits.length
+ logDebug(s"featureIndex = $featureIndex, numSplits = $numSplits")
+ splits(featureIndex) = new Array[Split](numSplits)
+
+ var splitIndex = 0
+ while (splitIndex < numSplits) {
+ val threshold = featureSplits(splitIndex)
+ splits(featureIndex)(splitIndex) = new ContinuousSplit(featureIndex, threshold)
+ splitIndex += 1
+ }
+ } else {
+ // Categorical feature
+ if (metadata.isUnordered(featureIndex)) {
+ val numSplits = metadata.numSplits(featureIndex)
+ val featureArity = metadata.featureArity(featureIndex)
+ // TODO: Use an implicit representation mapping each category to a subset of indices.
+ // I.e., track indices such that we can calculate the set of bins for which
+ // feature value x splits to the left.
+ // Unordered features
+ // 2^(maxFeatureValue - 1) - 1 combinations
+ splits(featureIndex) = new Array[Split](numSplits)
+ var splitIndex = 0
+ while (splitIndex < numSplits) {
+ val categories: List[Double] =
+ extractMultiClassCategories(splitIndex + 1, featureArity)
+ splits(featureIndex)(splitIndex) =
+ new CategoricalSplit(featureIndex, categories.toArray, featureArity)
+ splitIndex += 1
+ }
+ } else {
+ // Ordered features
+ // Bins correspond to feature values, so we do not need to compute splits or bins
+ // beforehand. Splits are constructed as needed during training.
+ splits(featureIndex) = new Array[Split](0)
+ }
+ }
+ featureIndex += 1
+ }
+ splits
+ }
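+
+  // The number of candidate splits per feature described above can be summarized with a small
+  // helper (illustrative only, not part of this API): continuous features get numBins - 1
+  // thresholds, unordered categorical features get 2^(arity - 1) - 1 category subsets, and
+  // ordered categorical features are scored over arity - 1 splits of the centroid-sorted
+  // categories during training.
+  //
+  //   def numCandidateSplits(isContinuous: Boolean, isUnordered: Boolean, size: Int): Int =
+  //     if (isContinuous) size - 1                  // size = numBins
+  //     else if (isUnordered) (1 << (size - 1)) - 1 // size = category arity
+  //     else size - 1                               // size = category arity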
+
+ /**
+   * Extracts a list of eligible categories given an index. It reads the positions of ones in
+   * the binary representation of the input. For example, if the binary representation of a
+   * number is 01101 (13), the output list should be (3.0, 2.0, 0.0). The maxFeatureValue
+   * parameter gives the number of rightmost digits that will be tested for ones.
+ */
+ private[tree] def extractMultiClassCategories(
+ input: Int,
+ maxFeatureValue: Int): List[Double] = {
+ var categories = List[Double]()
+ var j = 0
+ var bitShiftedInput = input
+ while (j < maxFeatureValue) {
+ if (bitShiftedInput % 2 != 0) {
+ // updating the list of categories.
+ categories = j.toDouble :: categories
+ }
+ // Right shift by one
+ bitShiftedInput = bitShiftedInput >> 1
+ j += 1
+ }
+ categories
+ }
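+
+  // Worked example of the bit extraction above (standalone sketch; the names are illustrative):
+  // for input = 13 (binary 01101) and maxFeatureValue = 5, the bits set at positions 0, 2 and 3
+  // are prepended as they are found, so the result is List(3.0, 2.0, 0.0).
+  //
+  //   val example: List[Double] =
+  //     (0 until 5).foldLeft(List.empty[Double]) { (cats, j) =>
+  //       if (((13 >> j) & 1) == 1) j.toDouble :: cats else cats
+  //     }
+  //   assert(example == List(3.0, 2.0, 0.0))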
+
+ /**
+ * Find splits for a continuous feature
+   * NOTE: The returned number of splits is based on `featureSamples` and
+   *       may differ from the specified `numSplits`.
+   *       The `numSplits` attribute in the `DecisionTreeMetadata` class will be set accordingly.
+ * @param featureSamples feature values of each sample
+ * @param metadata decision tree metadata
+   *                 NOTE: `metadata.numBins` will be updated accordingly
+   *                       if fewer splits than requested can be found
+ * @param featureIndex feature index to find splits
+ * @return array of splits
+ */
+ private[tree] def findSplitsForContinuousFeature(
+ featureSamples: Array[Double],
+ metadata: DecisionTreeMetadata,
+ featureIndex: Int): Array[Double] = {
+ require(metadata.isContinuous(featureIndex),
+ "findSplitsForContinuousFeature can only be used to find splits for a continuous feature.")
+
+ val splits = {
+ val numSplits = metadata.numSplits(featureIndex)
+
+ // get count for each distinct value
+ val valueCountMap = featureSamples.foldLeft(Map.empty[Double, Int]) { (m, x) =>
+ m + ((x, m.getOrElse(x, 0) + 1))
+ }
+ // sort distinct values
+ val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray
+
+      // if there are no more distinct values than numSplits, use every distinct value as a split
+ val possibleSplits = valueCounts.length
+ if (possibleSplits <= numSplits) {
+ valueCounts.map(_._1)
+ } else {
+ // stride between splits
+ val stride: Double = featureSamples.length.toDouble / (numSplits + 1)
+ logDebug("stride = " + stride)
+
+        // iterate over `valueCounts` to find splits
+ val splitsBuilder = mutable.ArrayBuilder.make[Double]
+ var index = 1
+ // currentCount: sum of counts of values that have been visited
+ var currentCount = valueCounts(0)._2
+        // targetCount: target value for `currentCount`.
+        //   If `currentCount` is the value closest to `targetCount`,
+        //   then the current value is a split threshold.
+        //   After finding a split threshold, `targetCount` is increased by the stride.
+ var targetCount = stride
+ while (index < valueCounts.length) {
+ val previousCount = currentCount
+ currentCount += valueCounts(index)._2
+ val previousGap = math.abs(previousCount - targetCount)
+ val currentGap = math.abs(currentCount - targetCount)
+ // If adding count of current value to currentCount
+ // makes the gap between currentCount and targetCount smaller,
+ // previous value is a split threshold.
+ if (previousGap < currentGap) {
+ splitsBuilder += valueCounts(index - 1)._1
+ targetCount += stride
+ }
+ index += 1
+ }
+
+ splitsBuilder.result()
+ }
+ }
+
+ // TODO: Do not fail; just ignore the useless feature.
+ assert(splits.length > 0,
+ s"DecisionTree could not handle feature $featureIndex since it had only 1 unique value." +
+ " Please remove this feature and then try again.")
+ // set number of splits accordingly
+ metadata.setNumSplits(featureIndex, splits.length)
+
+ splits
+ }
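+
+  // Toy walk-through of the stride-based selection above (all values are illustrative):
+  // with featureSamples = [1, 1, 2, 2, 2, 3, 4, 4] and numSplits = 3, the sorted value counts are
+  // [(1, 2), (2, 3), (3, 1), (4, 2)] and stride = 8 / 4 = 2.0; thresholds are taken where the
+  // cumulative count stays closest to successive multiples of the stride, yielding 1.0, 2.0, 3.0.
+  //
+  //   val featureSamples = Array(1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 4.0, 4.0)
+  //   val numSplits = 3
+  //   val valueCounts = featureSamples.groupBy(identity).mapValues(_.length).toSeq.sortBy(_._1)
+  //   val stride = featureSamples.length.toDouble / (numSplits + 1)
+  //   var (current, target) = (valueCounts.head._2.toDouble, stride)
+  //   val thresholds = scala.collection.mutable.ArrayBuilder.make[Double]
+  //   for (i <- 1 until valueCounts.length) {
+  //     val previous = current
+  //     current += valueCounts(i)._2
+  //     if (math.abs(previous - target) < math.abs(current - target)) {
+  //       thresholds += valueCounts(i - 1)._1
+  //       target += stride
+  //     }
+  //   }
+  //   // thresholds.result() == Array(1.0, 2.0, 3.0)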
+
+ private[tree] class NodeIndexInfo(
+ val nodeIndexInGroup: Int,
+ val featureSubset: Option[Array[Int]]) extends Serializable
+
+ /**
+ * Pull nodes off of the queue, and collect a group of nodes to be split on this iteration.
+ * This tracks the memory usage for aggregates and stops adding nodes when too much memory
+ * will be needed; this allows an adaptive number of nodes since different nodes may require
+ * different amounts of memory (if featureSubsetStrategy is not "all").
+ *
+ * @param nodeQueue Queue of nodes to split.
+ * @param maxMemoryUsage Bound on size of aggregate statistics.
+ * @return (nodesForGroup, treeToNodeToIndexInfo).
+ * nodesForGroup holds the nodes to split: treeIndex --> nodes in tree.
+ *
+   *          treeToNodeToIndexInfo holds the selected feature indices for each node:
+ * treeIndex --> (global) node index --> (node index in group, feature indices).
+ * The (global) node index is the index in the tree; the node index in group is the
+ * index in [0, numNodesInGroup) of the node in this group.
+ * The feature indices are None if not subsampling features.
+ */
+ private[tree] def selectNodesToSplit(
+ nodeQueue: mutable.Queue[(Int, LearningNode)],
+ maxMemoryUsage: Long,
+ metadata: DecisionTreeMetadata,
+ rng: Random): (Map[Int, Array[LearningNode]], Map[Int, Map[Int, NodeIndexInfo]]) = {
+ // Collect some nodes to split:
+ // nodesForGroup(treeIndex) = nodes to split
+ val mutableNodesForGroup = new mutable.HashMap[Int, mutable.ArrayBuffer[LearningNode]]()
+ val mutableTreeToNodeToIndexInfo =
+ new mutable.HashMap[Int, mutable.HashMap[Int, NodeIndexInfo]]()
+ var memUsage: Long = 0L
+ var numNodesInGroup = 0
+ while (nodeQueue.nonEmpty && memUsage < maxMemoryUsage) {
+ val (treeIndex, node) = nodeQueue.head
+ // Choose subset of features for node (if subsampling).
+ val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
+ Some(SamplingUtils.reservoirSampleAndCount(Range(0,
+ metadata.numFeatures).iterator, metadata.numFeaturesPerNode, rng.nextLong())._1)
+ } else {
+ None
+ }
+ // Check if enough memory remains to add this node to the group.
+ val nodeMemUsage = RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L
+ if (memUsage + nodeMemUsage <= maxMemoryUsage) {
+ nodeQueue.dequeue()
+ mutableNodesForGroup.getOrElseUpdate(treeIndex, new mutable.ArrayBuffer[LearningNode]()) +=
+ node
+ mutableTreeToNodeToIndexInfo
+ .getOrElseUpdate(treeIndex, new mutable.HashMap[Int, NodeIndexInfo]())(node.id)
+ = new NodeIndexInfo(numNodesInGroup, featureSubset)
+ }
+ numNodesInGroup += 1
+ memUsage += nodeMemUsage
+ }
+ // Convert mutable maps to immutable ones.
+ val nodesForGroup: Map[Int, Array[LearningNode]] =
+ mutableNodesForGroup.mapValues(_.toArray).toMap
+ val treeToNodeToIndexInfo = mutableTreeToNodeToIndexInfo.mapValues(_.toMap).toMap
+ (nodesForGroup, treeToNodeToIndexInfo)
+ }
+
+ /**
+ * Get the number of values to be stored for this node in the bin aggregates.
+ * @param featureSubset Indices of features which may be split at this node.
+ * If None, then use all features.
+ */
+ private def aggregateSizeForNode(
+ metadata: DecisionTreeMetadata,
+ featureSubset: Option[Array[Int]]): Long = {
+ val totalBins = if (featureSubset.nonEmpty) {
+ featureSubset.get.map(featureIndex => metadata.numBins(featureIndex).toLong).sum
+ } else {
+ metadata.numBins.map(_.toLong).sum
+ }
+ if (metadata.isClassification) {
+ metadata.numClasses * totalBins
+ } else {
+ 3 * totalBins
+ }
+ }
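+
+  // Rough sizing sketch for the memory check in selectNodesToSplit (illustrative numbers):
+  // for 3-class classification with 10 features of 32 bins each and no feature subsampling,
+  //   totalBins       = 10 * 32           = 320
+  //   aggregate size  = numClasses * 320  = 960 values
+  //   per-node memory = 960 * 8 bytes     = 7680 bytes (~7.5 KB)
+  // so a 256 MB maxMemoryUsage admits tens of thousands of such nodes in a single group.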
+
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala
new file mode 100644
index 0000000000000..9fa27e5e1f721
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree.impl
+
+import org.apache.spark.ml.tree.{ContinuousSplit, Split}
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.impl.DecisionTreeMetadata
+import org.apache.spark.rdd.RDD
+
+
+/**
+ * Internal representation of LabeledPoint for DecisionTree.
+ * It bins feature values based on a subsample of the data as follows:
+ * (a) Continuous features are binned into ranges.
+ * (b) Unordered categorical features are binned based on subsets of feature values.
+ * "Unordered categorical features" are categorical features with low arity used in
+ * multiclass classification.
+ * (c) Ordered categorical features are binned based on feature values.
+ * "Ordered categorical features" are categorical features with high arity,
+ * or any categorical feature used in regression or binary classification.
+ *
+ * @param label Label from LabeledPoint
+ * @param binnedFeatures Binned feature values.
+ * Same length as LabeledPoint.features, but values are bin indices.
+ */
+private[spark] class TreePoint(val label: Double, val binnedFeatures: Array[Int])
+ extends Serializable {
+}
+
+private[spark] object TreePoint {
+
+ /**
+ * Convert an input dataset into its TreePoint representation,
+ * binning feature values in preparation for DecisionTree training.
+ * @param input Input dataset.
+ * @param splits Splits for features, of size (numFeatures, numSplits).
+ * @param metadata Learning and dataset metadata
+ * @return TreePoint dataset representation
+ */
+ def convertToTreeRDD(
+ input: RDD[LabeledPoint],
+ splits: Array[Array[Split]],
+ metadata: DecisionTreeMetadata): RDD[TreePoint] = {
+ // Construct arrays for featureArity for efficiency in the inner loop.
+ val featureArity: Array[Int] = new Array[Int](metadata.numFeatures)
+ var featureIndex = 0
+ while (featureIndex < metadata.numFeatures) {
+ featureArity(featureIndex) = metadata.featureArity.getOrElse(featureIndex, 0)
+ featureIndex += 1
+ }
+ val thresholds: Array[Array[Double]] = featureArity.zipWithIndex.map { case (arity, idx) =>
+ if (arity == 0) {
+ splits(idx).map(_.asInstanceOf[ContinuousSplit].threshold)
+ } else {
+ Array.empty[Double]
+ }
+ }
+ input.map { x =>
+ TreePoint.labeledPointToTreePoint(x, thresholds, featureArity)
+ }
+ }
+
+ /**
+ * Convert one LabeledPoint into its TreePoint representation.
+ * @param thresholds For each feature, split thresholds for continuous features,
+ * empty for categorical features.
+ * @param featureArity Array indexed by feature, with value 0 for continuous and numCategories
+ * for categorical features.
+ */
+ private def labeledPointToTreePoint(
+ labeledPoint: LabeledPoint,
+ thresholds: Array[Array[Double]],
+ featureArity: Array[Int]): TreePoint = {
+ val numFeatures = labeledPoint.features.size
+ val arr = new Array[Int](numFeatures)
+ var featureIndex = 0
+ while (featureIndex < numFeatures) {
+ arr(featureIndex) =
+ findBin(featureIndex, labeledPoint, featureArity(featureIndex), thresholds(featureIndex))
+ featureIndex += 1
+ }
+ new TreePoint(labeledPoint.label, arr)
+ }
+
+ /**
+ * Find discretized value for one (labeledPoint, feature).
+ *
+ * NOTE: We cannot use Bucketizer since it handles split thresholds differently than the old
+ * (mllib) tree API. We want to maintain the same behavior as the old tree API.
+ *
+ * @param featureArity 0 for continuous features; number of categories for categorical features.
+ */
+ private def findBin(
+ featureIndex: Int,
+ labeledPoint: LabeledPoint,
+ featureArity: Int,
+ thresholds: Array[Double]): Int = {
+ val featureValue = labeledPoint.features(featureIndex)
+
+ if (featureArity == 0) {
+ val idx = java.util.Arrays.binarySearch(thresholds, featureValue)
+ if (idx >= 0) {
+ idx
+ } else {
+ -idx - 1
+ }
+ } else {
+ // Categorical feature bins are indexed by feature values.
+ if (featureValue < 0 || featureValue >= featureArity) {
+ throw new IllegalArgumentException(
+ s"DecisionTree given invalid data:" +
+ s" Feature $featureIndex is categorical with values in {0,...,${featureArity - 1}," +
+ s" but a data point gives it value $featureValue.\n" +
+ " Bad data point: " + labeledPoint.toString)
+ }
+ featureValue.toInt
+ }
+ }
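+
+  // Binning sketch for a continuous feature (the thresholds below are illustrative): the
+  // insertion point returned by binarySearch becomes the bin index, and a value exactly equal
+  // to a threshold falls into the lower bin.
+  //
+  //   val thresholds = Array(0.5, 1.5, 2.5)          // 3 splits => 4 bins
+  //   def bin(v: Double): Int = {
+  //     val idx = java.util.Arrays.binarySearch(thresholds, v)
+  //     if (idx >= 0) idx else -idx - 1
+  //   }
+  //   // bin(0.2) == 0, bin(0.5) == 0, bin(1.0) == 1, bin(3.0) == 3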
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index e2444ab65b43b..f979319cc4b58 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -32,38 +32,7 @@ import org.apache.spark.sql.types.StructType
/**
* Params for [[CrossValidator]] and [[CrossValidatorModel]].
*/
-private[ml] trait CrossValidatorParams extends Params {
-
- /**
- * param for the estimator to be cross-validated
- * @group param
- */
- val estimator: Param[Estimator[_]] = new Param(this, "estimator", "estimator for selection")
-
- /** @group getParam */
- def getEstimator: Estimator[_] = $(estimator)
-
- /**
- * param for estimator param maps
- * @group param
- */
- val estimatorParamMaps: Param[Array[ParamMap]] =
- new Param(this, "estimatorParamMaps", "param maps for the estimator")
-
- /** @group getParam */
- def getEstimatorParamMaps: Array[ParamMap] = $(estimatorParamMaps)
-
- /**
- * param for the evaluator used to select hyper-parameters that maximize the cross-validated
- * metric
- * @group param
- */
- val evaluator: Param[Evaluator] = new Param(this, "evaluator",
- "evaluator used to select hyper-parameters that maximize the cross-validated metric")
-
- /** @group getParam */
- def getEvaluator: Evaluator = $(evaluator)
-
+private[ml] trait CrossValidatorParams extends ValidatorParams {
/**
* Param for number of folds for cross validation. Must be >= 2.
* Default: 3
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
new file mode 100644
index 0000000000000..c0edc730b6fd6
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -0,0 +1,168 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tuning
+
+import org.apache.spark.Logging
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.ml.evaluation.Evaluator
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators}
+import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.types.StructType
+
+/**
+ * Params for [[TrainValidationSplit]] and [[TrainValidationSplitModel]].
+ */
+private[ml] trait TrainValidationSplitParams extends ValidatorParams {
+ /**
+ * Param for ratio between train and validation data. Must be between 0 and 1.
+ * Default: 0.75
+ * @group param
+ */
+ val trainRatio: DoubleParam = new DoubleParam(this, "trainRatio",
+ "ratio between training set and validation set (>= 0 && <= 1)", ParamValidators.inRange(0, 1))
+
+ /** @group getParam */
+ def getTrainRatio: Double = $(trainRatio)
+
+ setDefault(trainRatio -> 0.75)
+}
+
+/**
+ * :: Experimental ::
+ * Validation for hyper-parameter tuning.
+ * Randomly splits the input dataset into train and validation sets,
+ * and uses the evaluation metric on the validation set to select the best model.
+ * Similar to [[CrossValidator]], but only splits the set once.
+ */
+@Experimental
+class TrainValidationSplit(override val uid: String) extends Estimator[TrainValidationSplitModel]
+ with TrainValidationSplitParams with Logging {
+
+ def this() = this(Identifiable.randomUID("tvs"))
+
+ /** @group setParam */
+ def setEstimator(value: Estimator[_]): this.type = set(estimator, value)
+
+ /** @group setParam */
+ def setEstimatorParamMaps(value: Array[ParamMap]): this.type = set(estimatorParamMaps, value)
+
+ /** @group setParam */
+ def setEvaluator(value: Evaluator): this.type = set(evaluator, value)
+
+ /** @group setParam */
+ def setTrainRatio(value: Double): this.type = set(trainRatio, value)
+
+ override def fit(dataset: DataFrame): TrainValidationSplitModel = {
+ val schema = dataset.schema
+ transformSchema(schema, logging = true)
+ val sqlCtx = dataset.sqlContext
+ val est = $(estimator)
+ val eval = $(evaluator)
+ val epm = $(estimatorParamMaps)
+ val numModels = epm.length
+ val metrics = new Array[Double](epm.length)
+
+ val Array(training, validation) =
+ dataset.rdd.randomSplit(Array($(trainRatio), 1 - $(trainRatio)))
+ val trainingDataset = sqlCtx.createDataFrame(training, schema).cache()
+ val validationDataset = sqlCtx.createDataFrame(validation, schema).cache()
+
+ // multi-model training
+ logDebug(s"Train split with multiple sets of parameters.")
+ val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]]
+ trainingDataset.unpersist()
+ var i = 0
+ while (i < numModels) {
+ // TODO: duplicate evaluator to take extra params from input
+ val metric = eval.evaluate(models(i).transform(validationDataset, epm(i)))
+ logDebug(s"Got metric $metric for model trained with ${epm(i)}.")
+ metrics(i) += metric
+ i += 1
+ }
+ validationDataset.unpersist()
+
+ logInfo(s"Train validation split metrics: ${metrics.toSeq}")
+ val (bestMetric, bestIndex) = metrics.zipWithIndex.maxBy(_._1)
+ logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
+ logInfo(s"Best train validation split metric: $bestMetric.")
+ val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
+ copyValues(new TrainValidationSplitModel(uid, bestModel, metrics).setParent(this))
+ }
+
+ override def transformSchema(schema: StructType): StructType = {
+ $(estimator).transformSchema(schema)
+ }
+
+ override def validateParams(): Unit = {
+ super.validateParams()
+ val est = $(estimator)
+ for (paramMap <- $(estimatorParamMaps)) {
+ est.copy(paramMap).validateParams()
+ }
+ }
+
+ override def copy(extra: ParamMap): TrainValidationSplit = {
+ val copied = defaultCopy(extra).asInstanceOf[TrainValidationSplit]
+ if (copied.isDefined(estimator)) {
+ copied.setEstimator(copied.getEstimator.copy(extra))
+ }
+ if (copied.isDefined(evaluator)) {
+ copied.setEvaluator(copied.getEvaluator.copy(extra))
+ }
+ copied
+ }
+}
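+
+// Usage sketch (illustrative): `training` is an existing DataFrame, `lr` an Estimator,
+// `evaluator` a matching Evaluator, and `paramGrid` an Array[ParamMap]; all of these names
+// are placeholders.
+//
+//   val tvs = new TrainValidationSplit()
+//     .setEstimator(lr)
+//     .setEvaluator(evaluator)
+//     .setEstimatorParamMaps(paramGrid)
+//     .setTrainRatio(0.75)            // 75% train, 25% validation
+//   val model = tvs.fit(training)     // TrainValidationSplitModel wrapping the best model
+//   val predictions = model.transform(training)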
+
+/**
+ * :: Experimental ::
+ * Model from train validation split.
+ *
+ * @param uid Id.
+ * @param bestModel The best model determined by the estimator.
+ * @param validationMetrics Evaluated validation metrics.
+ */
+@Experimental
+class TrainValidationSplitModel private[ml] (
+ override val uid: String,
+ val bestModel: Model[_],
+ val validationMetrics: Array[Double])
+ extends Model[TrainValidationSplitModel] with TrainValidationSplitParams {
+
+ override def validateParams(): Unit = {
+ bestModel.validateParams()
+ }
+
+ override def transform(dataset: DataFrame): DataFrame = {
+ transformSchema(dataset.schema, logging = true)
+ bestModel.transform(dataset)
+ }
+
+ override def transformSchema(schema: StructType): StructType = {
+ bestModel.transformSchema(schema)
+ }
+
+ override def copy(extra: ParamMap): TrainValidationSplitModel = {
+ val copied = new TrainValidationSplitModel (
+ uid,
+ bestModel.copy(extra).asInstanceOf[Model[_]],
+ validationMetrics.clone())
+ copyValues(copied, extra)
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
new file mode 100644
index 0000000000000..8897ab0825acd
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tuning
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.ml.Estimator
+import org.apache.spark.ml.evaluation.Evaluator
+import org.apache.spark.ml.param.{ParamMap, Param, Params}
+
+/**
+ * :: DeveloperApi ::
+ * Common params for [[TrainValidationSplitParams]] and [[CrossValidatorParams]].
+ */
+@DeveloperApi
+private[ml] trait ValidatorParams extends Params {
+
+ /**
+ * param for the estimator to be validated
+ * @group param
+ */
+ val estimator: Param[Estimator[_]] = new Param(this, "estimator", "estimator for selection")
+
+ /** @group getParam */
+ def getEstimator: Estimator[_] = $(estimator)
+
+ /**
+ * param for estimator param maps
+ * @group param
+ */
+ val estimatorParamMaps: Param[Array[ParamMap]] =
+ new Param(this, "estimatorParamMaps", "param maps for the estimator")
+
+ /** @group getParam */
+ def getEstimatorParamMaps: Array[ParamMap] = $(estimatorParamMaps)
+
+ /**
+ * param for the evaluator used to select hyper-parameters that maximize the validated metric
+ * @group param
+ */
+ val evaluator: Param[Evaluator] = new Param(this, "evaluator",
+ "evaluator used to select hyper-parameters that maximize the validated metric")
+
+ /** @group getParam */
+ def getEvaluator: Evaluator = $(evaluator)
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala b/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala
new file mode 100644
index 0000000000000..8d4174124b5c4
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala
@@ -0,0 +1,153 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.util
+
+import scala.collection.mutable
+
+import org.apache.spark.{Accumulator, SparkContext}
+
+/**
+ * Abstract class for stopwatches.
+ */
+private[spark] abstract class Stopwatch extends Serializable {
+
+ @transient private var running: Boolean = false
+ private var startTime: Long = _
+
+ /**
+ * Name of the stopwatch.
+ */
+ val name: String
+
+ /**
+ * Starts the stopwatch.
+ * Throws an exception if the stopwatch is already running.
+ */
+ def start(): Unit = {
+ assume(!running, "start() called but the stopwatch is already running.")
+ running = true
+ startTime = now
+ }
+
+ /**
+ * Stops the stopwatch and returns the duration of the last session in milliseconds.
+ * Throws an exception if the stopwatch is not running.
+ */
+ def stop(): Long = {
+ assume(running, "stop() called but the stopwatch is not running.")
+ val duration = now - startTime
+ add(duration)
+ running = false
+ duration
+ }
+
+ /**
+ * Checks whether the stopwatch is running.
+ */
+ def isRunning: Boolean = running
+
+ /**
+ * Returns total elapsed time in milliseconds, not counting the current session if the stopwatch
+ * is running.
+ */
+ def elapsed(): Long
+
+ override def toString: String = s"$name: ${elapsed()}ms"
+
+ /**
+ * Gets the current time in milliseconds.
+ */
+ protected def now: Long = System.currentTimeMillis()
+
+ /**
+ * Adds input duration to total elapsed time.
+ */
+ protected def add(duration: Long): Unit
+}
+
+/**
+ * A local [[Stopwatch]].
+ */
+private[spark] class LocalStopwatch(override val name: String) extends Stopwatch {
+
+ private var elapsedTime: Long = 0L
+
+ override def elapsed(): Long = elapsedTime
+
+ override protected def add(duration: Long): Unit = {
+ elapsedTime += duration
+ }
+}
+
+/**
+ * A distributed [[Stopwatch]] using Spark accumulator.
+ * @param sc SparkContext
+ */
+private[spark] class DistributedStopwatch(
+ sc: SparkContext,
+ override val name: String) extends Stopwatch {
+
+ private val elapsedTime: Accumulator[Long] = sc.accumulator(0L, s"DistributedStopwatch($name)")
+
+ override def elapsed(): Long = elapsedTime.value
+
+ override protected def add(duration: Long): Unit = {
+ elapsedTime += duration
+ }
+}
+
+/**
+ * A container for multiple stopwatches, both local and distributed.
+ * @param sc SparkContext
+ */
+private[spark] class MultiStopwatch(@transient private val sc: SparkContext) extends Serializable {
+
+ private val stopwatches: mutable.Map[String, Stopwatch] = mutable.Map.empty
+
+ /**
+ * Adds a local stopwatch.
+ * @param name stopwatch name
+ */
+ def addLocal(name: String): this.type = {
+ require(!stopwatches.contains(name), s"Stopwatch with name $name already exists.")
+ stopwatches(name) = new LocalStopwatch(name)
+ this
+ }
+
+ /**
+ * Adds a distributed stopwatch.
+ * @param name stopwatch name
+ */
+ def addDistributed(name: String): this.type = {
+ require(!stopwatches.contains(name), s"Stopwatch with name $name already exists.")
+ stopwatches(name) = new DistributedStopwatch(sc, name)
+ this
+ }
+
+ /**
+ * Gets a stopwatch.
+ * @param name stopwatch name
+ */
+ def apply(name: String): Stopwatch = stopwatches(name)
+
+ override def toString: String = {
+ stopwatches.values.toArray.sortBy(_.name)
+ .map(c => s" $c")
+ .mkString("{\n", ",\n", "\n}")
+ }
+}
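+
+// Usage sketch (illustrative; assumes an active SparkContext `sc`):
+//
+//   val timers = new MultiStopwatch(sc)
+//     .addLocal("init")            // timed on the driver only
+//     .addDistributed("train")     // aggregated across executors via an accumulator
+//   timers("init").start()
+//   // ... driver-side setup ...
+//   timers("init").stop()          // returns the duration of this session in ms
+//   println(timers)                // each stopwatch prints as "name: <elapsed>ms"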
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index e628059c4af8e..fda8d5a0b048f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -43,7 +43,7 @@ import org.apache.spark.mllib.recommendation._
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.stat.correlation.CorrelationNames
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
-import org.apache.spark.mllib.stat.test.ChiSqTestResult
+import org.apache.spark.mllib.stat.test.{ChiSqTestResult, KolmogorovSmirnovTestResult}
import org.apache.spark.mllib.stat.{
KernelDensity, MultivariateStatisticalSummary, Statistics}
import org.apache.spark.mllib.tree.configuration.{Algo, BoostingStrategy, Strategy}
@@ -502,6 +502,39 @@ private[python] class PythonMLLibAPI extends Serializable {
new MatrixFactorizationModelWrapper(model)
}
+ /**
+ * Java stub for Python mllib LDA.run()
+ */
+ def trainLDAModel(
+ data: JavaRDD[java.util.List[Any]],
+ k: Int,
+ maxIterations: Int,
+ docConcentration: Double,
+ topicConcentration: Double,
+ seed: java.lang.Long,
+ checkpointInterval: Int,
+ optimizer: String): LDAModel = {
+ val algo = new LDA()
+ .setK(k)
+ .setMaxIterations(maxIterations)
+ .setDocConcentration(docConcentration)
+ .setTopicConcentration(topicConcentration)
+ .setCheckpointInterval(checkpointInterval)
+ .setOptimizer(optimizer)
+
+ if (seed != null) algo.setSeed(seed)
+
+ val documents = data.rdd.map(_.asScala.toArray).map { r =>
+ r(0) match {
+ case i: java.lang.Integer => (i.toLong, r(1).asInstanceOf[Vector])
+ case i: java.lang.Long => (i.toLong, r(1).asInstanceOf[Vector])
+        case _ => throw new IllegalArgumentException("input values contain an invalid type.")
+ }
+ }
+ algo.run(documents)
+ }
+
+
/**
* Java stub for Python mllib FPGrowth.train(). This stub returns a handle
* to the Java object instead of the content of the Java object. Extra care
@@ -1060,6 +1093,18 @@ private[python] class PythonMLLibAPI extends Serializable {
LinearDataGenerator.generateLinearRDD(
sc, nexamples, nfeatures, eps, nparts, intercept)
}
+
+ /**
+ * Java stub for Statistics.kolmogorovSmirnovTest()
+ */
+ def kolmogorovSmirnovTest(
+ data: JavaRDD[Double],
+ distName: String,
+ params: JList[Double]): KolmogorovSmirnovTestResult = {
+ val paramsSeq = params.asScala.toSeq
+ Statistics.kolmogorovSmirnovTest(data, distName, paramsSeq: _*)
+ }
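+
+  // Scala-side sketch of the call this stub forwards to (illustrative; `sc` is an active
+  // SparkContext):
+  //
+  //   import org.apache.spark.mllib.stat.Statistics
+  //   val data = sc.parallelize(Seq(0.1, -0.4, 1.2, 0.7, -0.3))
+  //   // one-sample, two-sided KS test against a standard normal (mean 0.0, stddev 1.0)
+  //   val result = Statistics.kolmogorovSmirnovTest(data, "norm", 0.0, 1.0)
+  //   println(s"statistic = ${result.statistic}, p-value = ${result.pValue}")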
+
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala
index 35a0db76f3a8c..ba73024e3c04d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala
@@ -36,6 +36,7 @@ trait ClassificationModel extends Serializable {
*
* @param testData RDD representing data points to be predicted
* @return an RDD[Double] where each entry contains the corresponding prediction
+ * @since 0.8.0
*/
def predict(testData: RDD[Vector]): RDD[Double]
@@ -44,6 +45,7 @@ trait ClassificationModel extends Serializable {
*
* @param testData array representing a single data point
* @return predicted category from the trained model
+ * @since 0.8.0
*/
def predict(testData: Vector): Double
@@ -51,6 +53,7 @@ trait ClassificationModel extends Serializable {
* Predict values for examples stored in a JavaRDD.
* @param testData JavaRDD representing data points to be predicted
* @return a JavaRDD[java.lang.Double] where each entry contains the corresponding prediction
+ * @since 0.8.0
*/
def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] =
predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]]
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
index 2df4d21e8cd55..268642ac6a2f6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
@@ -85,6 +85,7 @@ class LogisticRegressionModel (
* in Binary Logistic Regression. An example with prediction score greater than or equal to
* this threshold is identified as an positive, and negative otherwise. The default value is 0.5.
* It is only used for binary classification.
+ * @since 1.0.0
*/
@Experimental
def setThreshold(threshold: Double): this.type = {
@@ -96,6 +97,7 @@ class LogisticRegressionModel (
* :: Experimental ::
* Returns the threshold (if any) used for converting raw prediction scores into 0/1 predictions.
* It is only used for binary classification.
+ * @since 1.3.0
*/
@Experimental
def getThreshold: Option[Double] = threshold
@@ -104,6 +106,7 @@ class LogisticRegressionModel (
* :: Experimental ::
* Clears the threshold so that `predict` will output raw prediction scores.
* It is only used for binary classification.
+ * @since 1.0.0
*/
@Experimental
def clearThreshold(): this.type = {
@@ -155,6 +158,9 @@ class LogisticRegressionModel (
}
}
+ /**
+ * @since 1.3.0
+ */
override def save(sc: SparkContext, path: String): Unit = {
GLMClassificationModel.SaveLoadV1_0.save(sc, path, this.getClass.getName,
numFeatures, numClasses, weights, intercept, threshold)
@@ -162,6 +168,9 @@ class LogisticRegressionModel (
override protected def formatVersion: String = "1.0"
+ /**
+ * @since 1.4.0
+ */
override def toString: String = {
s"${super.toString}, numClasses = ${numClasses}, threshold = ${threshold.getOrElse("None")}"
}
@@ -169,6 +178,9 @@ class LogisticRegressionModel (
object LogisticRegressionModel extends Loader[LogisticRegressionModel] {
+ /**
+ * @since 1.3.0
+ */
override def load(sc: SparkContext, path: String): LogisticRegressionModel = {
val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
// Hard-code class name string in case it changes in the future
@@ -249,6 +261,7 @@ object LogisticRegressionWithSGD {
* @param miniBatchFraction Fraction of data to be used per iteration.
* @param initialWeights Initial set of weights to be used. Array should be equal in size to
* the number of features in the data.
+ * @since 1.0.0
*/
def train(
input: RDD[LabeledPoint],
@@ -271,6 +284,7 @@ object LogisticRegressionWithSGD {
* @param stepSize Step size to be used for each iteration of gradient descent.
* @param miniBatchFraction Fraction of data to be used per iteration.
+ * @since 1.0.0
*/
def train(
input: RDD[LabeledPoint],
@@ -292,6 +306,7 @@ object LogisticRegressionWithSGD {
* @param numIterations Number of iterations of gradient descent to run.
* @return a LogisticRegressionModel which has the weights and offset from training.
+ * @since 1.0.0
*/
def train(
input: RDD[LabeledPoint],
@@ -309,6 +324,7 @@ object LogisticRegressionWithSGD {
* @param input RDD of (label, array of features) pairs.
* @param numIterations Number of iterations of gradient descent to run.
* @return a LogisticRegressionModel which has the weights and offset from training.
+ * @since 1.0.0
*/
def train(
input: RDD[LabeledPoint],
@@ -345,6 +361,7 @@ class LogisticRegressionWithLBFGS
* Set the number of possible outcomes for k classes classification problem in
* Multinomial Logistic Regression.
* By default, it is binary logistic regression so k will be set to 2.
+ * @since 1.3.0
*/
@Experimental
def setNumClasses(numClasses: Int): this.type = {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index f51ee36d0dfcb..2df91c09421e9 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -40,7 +40,7 @@ import org.apache.spark.sql.{DataFrame, SQLContext}
* where D is number of features
* @param modelType The type of NB model to fit can be "multinomial" or "bernoulli"
*/
-class NaiveBayesModel private[mllib] (
+class NaiveBayesModel private[spark] (
val labels: Array[Double],
val pi: Array[Double],
val theta: Array[Array[Double]],
@@ -93,26 +93,70 @@ class NaiveBayesModel private[mllib] (
override def predict(testData: Vector): Double = {
modelType match {
case Multinomial =>
- val prob = thetaMatrix.multiply(testData)
- BLAS.axpy(1.0, piVector, prob)
- labels(prob.argmax)
+ labels(multinomialCalculation(testData).argmax)
case Bernoulli =>
- testData.foreachActive { (index, value) =>
- if (value != 0.0 && value != 1.0) {
- throw new SparkException(
- s"Bernoulli naive Bayes requires 0 or 1 feature values but found $testData.")
- }
- }
- val prob = thetaMinusNegTheta.get.multiply(testData)
- BLAS.axpy(1.0, piVector, prob)
- BLAS.axpy(1.0, negThetaSum.get, prob)
- labels(prob.argmax)
- case _ =>
- // This should never happen.
- throw new UnknownError(s"Invalid modelType: $modelType.")
+ labels(bernoulliCalculation(testData).argmax)
+ }
+ }
+
+ /**
+ * Predict values for the given data set using the model trained.
+ *
+ * @param testData RDD representing data points to be predicted
+ * @return an RDD[Vector] where each entry contains the predicted posterior class probabilities,
+ * in the same order as class labels
+ */
+ def predictProbabilities(testData: RDD[Vector]): RDD[Vector] = {
+ val bcModel = testData.context.broadcast(this)
+ testData.mapPartitions { iter =>
+ val model = bcModel.value
+ iter.map(model.predictProbabilities)
}
}
+ /**
+ * Predict posterior class probabilities for a single data point using the model trained.
+ *
+ * @param testData array representing a single data point
+ * @return predicted posterior class probabilities from the trained model,
+ * in the same order as class labels
+ */
+ def predictProbabilities(testData: Vector): Vector = {
+ modelType match {
+ case Multinomial =>
+ posteriorProbabilities(multinomialCalculation(testData))
+ case Bernoulli =>
+ posteriorProbabilities(bernoulliCalculation(testData))
+ }
+ }
+
+ private def multinomialCalculation(testData: Vector) = {
+ val prob = thetaMatrix.multiply(testData)
+ BLAS.axpy(1.0, piVector, prob)
+ prob
+ }
+
+ private def bernoulliCalculation(testData: Vector) = {
+ testData.foreachActive((_, value) =>
+ if (value != 0.0 && value != 1.0) {
+ throw new SparkException(
+ s"Bernoulli naive Bayes requires 0 or 1 feature values but found $testData.")
+ }
+ )
+ val prob = thetaMinusNegTheta.get.multiply(testData)
+ BLAS.axpy(1.0, piVector, prob)
+ BLAS.axpy(1.0, negThetaSum.get, prob)
+ prob
+ }
+
+ private def posteriorProbabilities(logProb: DenseVector) = {
+ val logProbArray = logProb.toArray
+ val maxLog = logProbArray.max
+ val scaledProbs = logProbArray.map(lp => math.exp(lp - maxLog))
+ val probSum = scaledProbs.sum
+ new DenseVector(scaledProbs.map(_ / probSum))
+ }
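+
+  // Subtracting maxLog above is the standard trick for exponentiating log-probabilities without
+  // underflow. A standalone sketch with toy values (illustrative only):
+  //
+  //   val logProb = Array(-1001.0, -1000.0, -1003.0)      // math.exp alone would underflow to 0.0
+  //   val maxLog = logProb.max
+  //   val scaled = logProb.map(lp => math.exp(lp - maxLog))
+  //   val posterior = scaled.map(_ / scaled.sum)           // sums to 1.0; ordering is preserved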
+
override def save(sc: SparkContext, path: String): Unit = {
val data = NaiveBayesModel.SaveLoadV2_0.Data(labels, pi, theta, modelType)
NaiveBayesModel.SaveLoadV2_0.save(sc, path, data)
@@ -338,7 +382,7 @@ class NaiveBayes private (
BLAS.axpy(1.0, c2._2, c1._2)
(c1._1 + c2._1, c1._2)
}
- ).collect()
+ ).collect().sortBy(_._1)
val numLabels = aggregated.length
var numDocuments = 0L
@@ -381,13 +425,13 @@ class NaiveBayes private (
object NaiveBayes {
/** String name for multinomial model type. */
- private[classification] val Multinomial: String = "multinomial"
+ private[spark] val Multinomial: String = "multinomial"
/** String name for Bernoulli model type. */
- private[classification] val Bernoulli: String = "bernoulli"
+ private[spark] val Bernoulli: String = "bernoulli"
/* Set of modelTypes that NaiveBayes supports */
- private[classification] val supportedModelTypes = Set(Multinomial, Bernoulli)
+ private[spark] val supportedModelTypes = Set(Multinomial, Bernoulli)
/**
* Trains a Naive Bayes model given an RDD of `(label, features)` pairs.
@@ -400,6 +444,7 @@ object NaiveBayes {
*
* @param input RDD of `(label, array of features)` pairs. Every vector should be a frequency
* vector or a count vector.
+ * @since 0.9.0
*/
def train(input: RDD[LabeledPoint]): NaiveBayesModel = {
new NaiveBayes().run(input)
@@ -415,6 +460,7 @@ object NaiveBayes {
* @param input RDD of `(label, array of features)` pairs. Every vector should be a frequency
* vector or a count vector.
* @param lambda The smoothing parameter
+ * @since 0.9.0
*/
def train(input: RDD[LabeledPoint], lambda: Double): NaiveBayesModel = {
new NaiveBayes(lambda, Multinomial).run(input)
@@ -437,6 +483,7 @@ object NaiveBayes {
*
* @param modelType The type of NB model to fit from the enumeration NaiveBayesModels, can be
* multinomial or bernoulli
+ * @since 0.9.0
*/
def train(input: RDD[LabeledPoint], lambda: Double, modelType: String): NaiveBayesModel = {
require(supportedModelTypes.contains(modelType),
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
index 348485560713e..5b54feeb10467 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
@@ -46,6 +46,7 @@ class SVMModel (
* Sets the threshold that separates positive predictions from negative predictions. An example
* with prediction score greater than or equal to this threshold is identified as an positive,
* and negative otherwise. The default value is 0.0.
+ * @since 1.3.0
*/
@Experimental
def setThreshold(threshold: Double): this.type = {
@@ -56,6 +57,7 @@ class SVMModel (
/**
* :: Experimental ::
* Returns the threshold (if any) used for converting raw prediction scores into 0/1 predictions.
+ * @since 1.3.0
*/
@Experimental
def getThreshold: Option[Double] = threshold
@@ -63,6 +65,7 @@ class SVMModel (
/**
* :: Experimental ::
* Clears the threshold so that `predict` will output raw prediction scores.
+ * @since 1.0.0
*/
@Experimental
def clearThreshold(): this.type = {
@@ -81,6 +84,9 @@ class SVMModel (
}
}
+ /**
+ * @since 1.3.0
+ */
override def save(sc: SparkContext, path: String): Unit = {
GLMClassificationModel.SaveLoadV1_0.save(sc, path, this.getClass.getName,
numFeatures = weights.size, numClasses = 2, weights, intercept, threshold)
@@ -88,6 +94,9 @@ class SVMModel (
override protected def formatVersion: String = "1.0"
+ /**
+ * @since 1.4.0
+ */
override def toString: String = {
s"${super.toString}, numClasses = 2, threshold = ${threshold.getOrElse("None")}"
}
@@ -95,6 +104,9 @@ class SVMModel (
object SVMModel extends Loader[SVMModel] {
+ /**
+ * @since 1.3.0
+ */
override def load(sc: SparkContext, path: String): SVMModel = {
val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
// Hard-code class name string in case it changes in the future
@@ -173,6 +185,7 @@ object SVMWithSGD {
* @param miniBatchFraction Fraction of data to be used per iteration.
* @param initialWeights Initial set of weights to be used. Array should be equal in size to
* the number of features in the data.
+ * @since 0.8.0
*/
def train(
input: RDD[LabeledPoint],
@@ -196,6 +209,7 @@ object SVMWithSGD {
* @param stepSize Step size to be used for each iteration of gradient descent.
* @param regParam Regularization parameter.
* @param miniBatchFraction Fraction of data to be used per iteration.
+ * @since 0.8.0
*/
def train(
input: RDD[LabeledPoint],
@@ -217,6 +231,7 @@ object SVMWithSGD {
* @param regParam Regularization parameter.
* @param numIterations Number of iterations of gradient descent to run.
* @return a SVMModel which has the weights and offset from training.
+ * @since 0.8.0
*/
def train(
input: RDD[LabeledPoint],
@@ -235,6 +250,7 @@ object SVMWithSGD {
* @param input RDD of (label, array of features) pairs.
* @param numIterations Number of iterations of gradient descent to run.
* @return a SVMModel which has the weights and offset from training.
+ * @since 0.8.0
*/
def train(input: RDD[LabeledPoint], numIterations: Int): SVMModel = {
train(input, numIterations, 1.0, 0.01, 1.0)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index 0f8d6a399682d..0a65403f4ec95 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -85,9 +85,7 @@ class KMeans private (
* (Bahmani et al., Scalable K-Means++, VLDB 2012). Default: k-means||.
*/
def setInitializationMode(initializationMode: String): this.type = {
- if (initializationMode != KMeans.RANDOM && initializationMode != KMeans.K_MEANS_PARALLEL) {
- throw new IllegalArgumentException("Invalid initialization mode: " + initializationMode)
- }
+ KMeans.validateInitMode(initializationMode)
this.initializationMode = initializationMode
this
}
@@ -156,6 +154,21 @@ class KMeans private (
this
}
+ // Initial cluster centers can be provided as a KMeansModel object rather than using the
+ // random or k-means|| initializationMode
+ private var initialModel: Option[KMeansModel] = None
+
+ /**
+   * Set the initial starting point, bypassing the random initialization or k-means||.
+   * The condition model.k == this.k must be met; failure results
+   * in an IllegalArgumentException.
+ */
+ def setInitialModel(model: KMeansModel): this.type = {
+ require(model.k == k, "mismatched cluster count")
+ initialModel = Some(model)
+ this
+ }
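+
+  // Warm-start sketch (illustrative; `data` is an RDD[Vector] and `warmStart` a previously
+  // trained KMeansModel, e.g. from an earlier batch):
+  //
+  //   val model = new KMeans()
+  //     .setK(warmStart.k)              // k must match the initial model
+  //     .setInitialModel(warmStart)     // skips random / k-means|| initialization
+  //     .run(data)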
+
/**
* Train a K-means model on the given set of points; `data` should be cached for high
* performance, because this is an iterative algorithm.
@@ -193,20 +206,34 @@ class KMeans private (
val initStartTime = System.nanoTime()
- val centers = if (initializationMode == KMeans.RANDOM) {
- initRandom(data)
+ // Only one run is allowed when initialModel is given
+ val numRuns = if (initialModel.nonEmpty) {
+ if (runs > 1) logWarning("Ignoring runs; one run is allowed when initialModel is given.")
+ 1
} else {
- initKMeansParallel(data)
+ runs
}
+ val centers = initialModel match {
+ case Some(kMeansCenters) => {
+ Array(kMeansCenters.clusterCenters.map(s => new VectorWithNorm(s)))
+ }
+ case None => {
+ if (initializationMode == KMeans.RANDOM) {
+ initRandom(data)
+ } else {
+ initKMeansParallel(data)
+ }
+ }
+ }
val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9
logInfo(s"Initialization with $initializationMode took " + "%.3f".format(initTimeInSeconds) +
" seconds.")
- val active = Array.fill(runs)(true)
- val costs = Array.fill(runs)(0.0)
+ val active = Array.fill(numRuns)(true)
+ val costs = Array.fill(numRuns)(0.0)
- var activeRuns = new ArrayBuffer[Int] ++ (0 until runs)
+ var activeRuns = new ArrayBuffer[Int] ++ (0 until numRuns)
var iteration = 0
val iterationStartTime = System.nanoTime()
@@ -521,6 +548,14 @@ object KMeans {
v2: VectorWithNorm): Double = {
MLUtils.fastSquaredDistance(v1.vector, v1.norm, v2.vector, v2.norm)
}
+
+ private[spark] def validateInitMode(initMode: String): Boolean = {
+ initMode match {
+ case KMeans.RANDOM => true
+ case KMeans.K_MEANS_PARALLEL => true
+ case _ => false
+ }
+ }
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
index a410547a72fda..ab124e6d77c5e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
@@ -23,11 +23,10 @@ import org.apache.spark.Logging
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.JavaPairRDD
import org.apache.spark.graphx._
-import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils
-
/**
* :: Experimental ::
*
@@ -49,14 +48,15 @@ import org.apache.spark.util.Utils
class LDA private (
private var k: Int,
private var maxIterations: Int,
- private var docConcentration: Double,
+ private var docConcentration: Vector,
private var topicConcentration: Double,
private var seed: Long,
private var checkpointInterval: Int,
private var ldaOptimizer: LDAOptimizer) extends Logging {
- def this() = this(k = 10, maxIterations = 20, docConcentration = -1, topicConcentration = -1,
- seed = Utils.random.nextLong(), checkpointInterval = 10, ldaOptimizer = new EMLDAOptimizer)
+ def this() = this(k = 10, maxIterations = 20, docConcentration = Vectors.dense(-1),
+ topicConcentration = -1, seed = Utils.random.nextLong(), checkpointInterval = 10,
+ ldaOptimizer = new EMLDAOptimizer)
/**
* Number of topics to infer. I.e., the number of soft cluster centers.
@@ -77,37 +77,50 @@ class LDA private (
* Concentration parameter (commonly named "alpha") for the prior placed on documents'
* distributions over topics ("theta").
*
- * This is the parameter to a symmetric Dirichlet distribution.
+ * This is the parameter to a Dirichlet distribution.
*/
- def getDocConcentration: Double = this.docConcentration
+ def getDocConcentration: Vector = this.docConcentration
/**
* Concentration parameter (commonly named "alpha") for the prior placed on documents'
* distributions over topics ("theta").
*
- * This is the parameter to a symmetric Dirichlet distribution, where larger values
- * mean more smoothing (more regularization).
+ * This is the parameter to a Dirichlet distribution, where larger values mean more smoothing
+ * (more regularization).
*
- * If set to -1, then docConcentration is set automatically.
- * (default = -1 = automatic)
+ * If set to a singleton vector Vector(-1), then docConcentration is set automatically. If set to
+ * a singleton vector Vector(t) where t != -1, then t is replicated to a vector of length k during
+ * [[LDAOptimizer.initialize()]]. Otherwise, the [[docConcentration]] vector must have length k.
+ * (default = Vector(-1) = automatic)
*
* Optimizer-specific parameter settings:
* - EM
- * - Value should be > 1.0
- * - default = (50 / k) + 1, where 50/k is common in LDA libraries and +1 follows
- * Asuncion et al. (2009), who recommend a +1 adjustment for EM.
+ * - Currently only supports symmetric distributions, so all values in the vector should be
+ * the same.
+ * - Values should be > 1.0
+ * - default = uniformly (50 / k) + 1, where 50/k is common in LDA libraries and +1 follows
+ * from Asuncion et al. (2009), who recommend a +1 adjustment for EM.
* - Online
- * - Value should be >= 0
- * - default = (1.0 / k), following the implementation from
+ * - Values should be >= 0
+ * - default = uniformly (1.0 / k), following the implementation from
* [[https://github.com/Blei-Lab/onlineldavb]].
*/
- def setDocConcentration(docConcentration: Double): this.type = {
+ def setDocConcentration(docConcentration: Vector): this.type = {
this.docConcentration = docConcentration
this
}
+ /** Replicates a [[Double]] docConcentration to create a symmetric prior. */
+ def setDocConcentration(docConcentration: Double): this.type = {
+ this.docConcentration = Vectors.dense(docConcentration)
+ this
+ }
+
/** Alias for [[getDocConcentration]] */
- def getAlpha: Double = getDocConcentration
+ def getAlpha: Vector = getDocConcentration
+
+ /** Alias for [[setDocConcentration()]] */
+ def setAlpha(alpha: Vector): this.type = setDocConcentration(alpha)
/** Alias for [[setDocConcentration()]] */
def setAlpha(alpha: Double): this.type = setDocConcentration(alpha)
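A short sketch of the vector-valued docConcentration added above (corpus setup omitted; the EM optimizer still requires a symmetric prior, as enforced in LDAOptimizer further down in this patch):

  import org.apache.spark.mllib.clustering.{LDA, OnlineLDAOptimizer}
  import org.apache.spark.mllib.linalg.Vectors

  val lda = new LDA()
    .setK(3)
    .setOptimizer(new OnlineLDAOptimizer())
    .setDocConcentration(Vectors.dense(0.5, 0.3, 0.2))  // asymmetric prior, length must equal k
  // the scalar setter still works and is expanded to a symmetric prior during initialize():
  // lda.setDocConcentration(1.1)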
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
index 974b26924dfb8..920b57756b625 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
@@ -17,15 +17,25 @@
package org.apache.spark.mllib.clustering
-import breeze.linalg.{DenseMatrix => BDM, normalize, sum => brzSum}
+import breeze.linalg.{DenseMatrix => BDM, normalize, sum => brzSum, DenseVector => BDV}
+import org.apache.hadoop.fs.Path
+
+import org.json4s.DefaultFormats
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.SparkContext
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaPairRDD
-import org.apache.spark.graphx.{VertexId, EdgeContext, Graph}
-import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix}
+import org.apache.spark.graphx.{VertexId, Edge, EdgeContext, Graph}
+import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix, DenseVector}
+import org.apache.spark.mllib.util.{Saveable, Loader}
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{SQLContext, Row}
import org.apache.spark.util.BoundedPriorityQueue
+
/**
* :: Experimental ::
*
@@ -35,7 +45,7 @@ import org.apache.spark.util.BoundedPriorityQueue
* including local and distributed data structures.
*/
@Experimental
-abstract class LDAModel private[clustering] {
+abstract class LDAModel private[clustering] extends Saveable {
/** Number of topics */
def k: Int
@@ -176,6 +186,11 @@ class LocalLDAModel private[clustering] (
}.toArray
}
+ override protected def formatVersion = "1.0"
+
+ override def save(sc: SparkContext, path: String): Unit = {
+ LocalLDAModel.SaveLoadV1_0.save(sc, path, topicsMatrix)
+ }
// TODO
// override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ???
@@ -184,6 +199,80 @@ class LocalLDAModel private[clustering] (
}
+@Experimental
+object LocalLDAModel extends Loader[LocalLDAModel] {
+
+ private object SaveLoadV1_0 {
+
+ val thisFormatVersion = "1.0"
+
+ val thisClassName = "org.apache.spark.mllib.clustering.LocalLDAModel"
+
+ // Store the distribution of terms of each topic and the column index in topicsMatrix
+ // as a Row in data.
+ case class Data(topic: Vector, index: Int)
+
+ def save(sc: SparkContext, path: String, topicsMatrix: Matrix): Unit = {
+ val sqlContext = SQLContext.getOrCreate(sc)
+ import sqlContext.implicits._
+
+ val k = topicsMatrix.numCols
+ val metadata = compact(render
+ (("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
+ ("k" -> k) ~ ("vocabSize" -> topicsMatrix.numRows)))
+ sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
+
+ val topicsDenseMatrix = topicsMatrix.toBreeze.toDenseMatrix
+ val topics = Range(0, k).map { topicInd =>
+ Data(Vectors.dense((topicsDenseMatrix(::, topicInd).toArray)), topicInd)
+ }.toSeq
+ sc.parallelize(topics, 1).toDF().write.parquet(Loader.dataPath(path))
+ }
+
+ def load(sc: SparkContext, path: String): LocalLDAModel = {
+ val dataPath = Loader.dataPath(path)
+ val sqlContext = SQLContext.getOrCreate(sc)
+ val dataFrame = sqlContext.read.parquet(dataPath)
+
+ Loader.checkSchema[Data](dataFrame.schema)
+ val topics = dataFrame.collect()
+ val vocabSize = topics(0).getAs[Vector](0).size
+ val k = topics.size
+
+ val brzTopics = BDM.zeros[Double](vocabSize, k)
+ topics.foreach { case Row(vec: Vector, ind: Int) =>
+ brzTopics(::, ind) := vec.toBreeze
+ }
+ new LocalLDAModel(Matrices.fromBreeze(brzTopics))
+ }
+ }
+
+ override def load(sc: SparkContext, path: String): LocalLDAModel = {
+ val (loadedClassName, loadedVersion, metadata) = Loader.loadMetadata(sc, path)
+ implicit val formats = DefaultFormats
+ val expectedK = (metadata \ "k").extract[Int]
+ val expectedVocabSize = (metadata \ "vocabSize").extract[Int]
+ val classNameV1_0 = SaveLoadV1_0.thisClassName
+
+ val model = (loadedClassName, loadedVersion) match {
+ case (className, "1.0") if className == classNameV1_0 =>
+ SaveLoadV1_0.load(sc, path)
+ case _ => throw new Exception(
+ s"LocalLDAModel.load did not recognize model with (className, format version):" +
+ s"($loadedClassName, $loadedVersion). Supported:\n" +
+ s" ($classNameV1_0, 1.0)")
+ }
+
+ val topicsMatrix = model.topicsMatrix
+ require(expectedK == topicsMatrix.numCols,
+ s"LocalLDAModel requires $expectedK topics, got ${topicsMatrix.numCols} topics")
+ require(expectedVocabSize == topicsMatrix.numRows,
+ s"LocalLDAModel requires $expectedVocabSize terms for each topic, " +
+ s"but got ${topicsMatrix.numRows}")
+ model
+ }
+}
+
/**
* :: Experimental ::
*
@@ -354,4 +443,135 @@ class DistributedLDAModel private (
// TODO:
// override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ???
+ override protected def formatVersion = "1.0"
+
+ override def save(sc: SparkContext, path: String): Unit = {
+ DistributedLDAModel.SaveLoadV1_0.save(
+ sc, path, graph, globalTopicTotals, k, vocabSize, docConcentration, topicConcentration,
+ iterationTimes)
+ }
+}
+
+
+@Experimental
+object DistributedLDAModel extends Loader[DistributedLDAModel] {
+
+ private object SaveLoadV1_0 {
+
+ val thisFormatVersion = "1.0"
+
+ val classNameV1_0 = "org.apache.spark.mllib.clustering.DistributedLDAModel"
+
+ // Store globalTopicTotals as a Vector.
+ case class Data(globalTopicTotals: Vector)
+
+ // Store each term and document vertex with an id and the topicWeights.
+ case class VertexData(id: Long, topicWeights: Vector)
+
+ // Store each edge with the source id, destination id and tokenCounts.
+ case class EdgeData(srcId: Long, dstId: Long, tokenCounts: Double)
+
+ def save(
+ sc: SparkContext,
+ path: String,
+ graph: Graph[LDA.TopicCounts, LDA.TokenCount],
+ globalTopicTotals: LDA.TopicCounts,
+ k: Int,
+ vocabSize: Int,
+ docConcentration: Double,
+ topicConcentration: Double,
+ iterationTimes: Array[Double]): Unit = {
+ val sqlContext = SQLContext.getOrCreate(sc)
+ import sqlContext.implicits._
+
+ val metadata = compact(render
+ (("class" -> classNameV1_0) ~ ("version" -> thisFormatVersion) ~
+ ("k" -> k) ~ ("vocabSize" -> vocabSize) ~ ("docConcentration" -> docConcentration) ~
+ ("topicConcentration" -> topicConcentration) ~
+ ("iterationTimes" -> iterationTimes.toSeq)))
+ sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
+
+ val newPath = new Path(Loader.dataPath(path), "globalTopicTotals").toUri.toString
+ sc.parallelize(Seq(Data(Vectors.fromBreeze(globalTopicTotals)))).toDF()
+ .write.parquet(newPath)
+
+ val verticesPath = new Path(Loader.dataPath(path), "topicCounts").toUri.toString
+ graph.vertices.map { case (ind, vertex) =>
+ VertexData(ind, Vectors.fromBreeze(vertex))
+ }.toDF().write.parquet(verticesPath)
+
+ val edgesPath = new Path(Loader.dataPath(path), "tokenCounts").toUri.toString
+ graph.edges.map { case Edge(srcId, dstId, prop) =>
+ EdgeData(srcId, dstId, prop)
+ }.toDF().write.parquet(edgesPath)
+ }
+
+ def load(
+ sc: SparkContext,
+ path: String,
+ vocabSize: Int,
+ docConcentration: Double,
+ topicConcentration: Double,
+ iterationTimes: Array[Double]): DistributedLDAModel = {
+ val dataPath = new Path(Loader.dataPath(path), "globalTopicTotals").toUri.toString
+ val vertexDataPath = new Path(Loader.dataPath(path), "topicCounts").toUri.toString
+ val edgeDataPath = new Path(Loader.dataPath(path), "tokenCounts").toUri.toString
+ val sqlContext = SQLContext.getOrCreate(sc)
+ val dataFrame = sqlContext.read.parquet(dataPath)
+ val vertexDataFrame = sqlContext.read.parquet(vertexDataPath)
+ val edgeDataFrame = sqlContext.read.parquet(edgeDataPath)
+
+ Loader.checkSchema[Data](dataFrame.schema)
+ Loader.checkSchema[VertexData](vertexDataFrame.schema)
+ Loader.checkSchema[EdgeData](edgeDataFrame.schema)
+ val globalTopicTotals: LDA.TopicCounts =
+ dataFrame.first().getAs[Vector](0).toBreeze.toDenseVector
+ val vertices: RDD[(VertexId, LDA.TopicCounts)] = vertexDataFrame.map {
+ case Row(ind: Long, vec: Vector) => (ind, vec.toBreeze.toDenseVector)
+ }
+
+ val edges: RDD[Edge[LDA.TokenCount]] = edgeDataFrame.map {
+ case Row(srcId: Long, dstId: Long, prop: Double) => Edge(srcId, dstId, prop)
+ }
+ val graph: Graph[LDA.TopicCounts, LDA.TokenCount] = Graph(vertices, edges)
+
+ new DistributedLDAModel(graph, globalTopicTotals, globalTopicTotals.length, vocabSize,
+ docConcentration, topicConcentration, iterationTimes)
+ }
+
+ }
+
+ override def load(sc: SparkContext, path: String): DistributedLDAModel = {
+ val (loadedClassName, loadedVersion, metadata) = Loader.loadMetadata(sc, path)
+ implicit val formats = DefaultFormats
+ val expectedK = (metadata \ "k").extract[Int]
+ val vocabSize = (metadata \ "vocabSize").extract[Int]
+ val docConcentration = (metadata \ "docConcentration").extract[Double]
+ val topicConcentration = (metadata \ "topicConcentration").extract[Double]
+ val iterationTimes = (metadata \ "iterationTimes").extract[Seq[Double]]
+ val classNameV1_0 = SaveLoadV1_0.classNameV1_0
+
+ val model = (loadedClassName, loadedVersion) match {
+ case (className, "1.0") if className == classNameV1_0 => {
+ DistributedLDAModel.SaveLoadV1_0.load(
+ sc, path, vocabSize, docConcentration, topicConcentration, iterationTimes.toArray)
+ }
+ case _ => throw new Exception(
+ s"DistributedLDAModel.load did not recognize model with (className, format version):" +
+ s"($loadedClassName, $loadedVersion). Supported: ($classNameV1_0, 1.0)")
+ }
+
+ require(model.vocabSize == vocabSize,
+ s"DistributedLDAModel requires $vocabSize vocabSize, got ${model.vocabSize} vocabSize")
+ require(model.docConcentration == docConcentration,
+ s"DistributedLDAModel requires $docConcentration docConcentration, " +
+ s"got ${model.docConcentration} docConcentration")
+ require(model.topicConcentration == topicConcentration,
+ s"DistributedLDAModel requires $topicConcentration docConcentration, " +
+ s"got ${model.topicConcentration} docConcentration")
+ require(expectedK == model.k,
+ s"DistributedLDAModel requires $expectedK topics, got ${model.k} topics")
+ model
+ }
+
}
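A minimal round trip through the new Saveable/Loader support above, assuming corpus is an RDD[(Long, Vector)] of (document id, term-count vector) and sc is an existing SparkContext; the path is a placeholder:

  import org.apache.spark.mllib.clustering.{DistributedLDAModel, LDA}

  val ldaModel = new LDA().setK(5).run(corpus)   // default EM optimizer => DistributedLDAModel
  ldaModel.save(sc, "/tmp/ldaModel")             // JSON metadata plus Parquet vertex/edge data
  val restored = DistributedLDAModel.load(sc, "/tmp/ldaModel")
  // LocalLDAModel.save / LocalLDAModel.load follow the same pattern for the topics matrix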
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
index 8e5154b902d1d..f4170a3d98dd8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
@@ -19,15 +19,15 @@ package org.apache.spark.mllib.clustering
import java.util.Random
-import breeze.linalg.{DenseVector => BDV, DenseMatrix => BDM, sum, normalize, kron}
-import breeze.numerics.{digamma, exp, abs}
+import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, normalize, sum}
+import breeze.numerics.{abs, digamma, exp}
import breeze.stats.distributions.{Gamma, RandBasis}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.graphx._
import org.apache.spark.graphx.impl.GraphImpl
import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer
-import org.apache.spark.mllib.linalg.{Matrices, SparseVector, DenseVector, Vector}
+import org.apache.spark.mllib.linalg.{DenseVector, Matrices, SparseVector, Vector, Vectors}
import org.apache.spark.rdd.RDD
/**
@@ -95,8 +95,11 @@ final class EMLDAOptimizer extends LDAOptimizer {
* Compute bipartite term/doc graph.
*/
override private[clustering] def initialize(docs: RDD[(Long, Vector)], lda: LDA): LDAOptimizer = {
+ val docConcentration = lda.getDocConcentration(0)
+ require({
+ lda.getDocConcentration.toArray.forall(_ == docConcentration)
+ }, "EMLDAOptimizer currently only supports symmetric document-topic priors")
- val docConcentration = lda.getDocConcentration
val topicConcentration = lda.getTopicConcentration
val k = lda.getK
@@ -229,10 +232,10 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
private var vocabSize: Int = 0
/** alias for docConcentration */
- private var alpha: Double = 0
+ private var alpha: Vector = Vectors.dense(0)
/** (private[clustering] for debugging) Get docConcentration */
- private[clustering] def getAlpha: Double = alpha
+ private[clustering] def getAlpha: Vector = alpha
/** alias for topicConcentration */
private var eta: Double = 0
@@ -343,7 +346,19 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
this.k = lda.getK
this.corpusSize = docs.count()
this.vocabSize = docs.first()._2.size
- this.alpha = if (lda.getDocConcentration == -1) 1.0 / k else lda.getDocConcentration
+ this.alpha = if (lda.getDocConcentration.size == 1) {
+ if (lda.getDocConcentration(0) == -1) Vectors.dense(Array.fill(k)(1.0 / k))
+ else {
+ require(lda.getDocConcentration(0) >= 0, s"all entries in alpha must be >=0, got: $alpha")
+ Vectors.dense(Array.fill(k)(lda.getDocConcentration(0)))
+ }
+ } else {
+ require(lda.getDocConcentration.size == k, s"alpha must have length k, got: $alpha")
+ lda.getDocConcentration.foreachActive { case (_, x) =>
+ require(x >= 0, s"all entries in alpha must be >= 0, got: $alpha")
+ }
+ lda.getDocConcentration
+ }
this.eta = if (lda.getTopicConcentration == -1) 1.0 / k else lda.getTopicConcentration
this.randomGenerator = new Random(lda.getSeed)
@@ -370,9 +385,9 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
iteration += 1
val k = this.k
val vocabSize = this.vocabSize
- val Elogbeta = dirichletExpectation(lambda)
+ val Elogbeta = dirichletExpectation(lambda).t
val expElogbeta = exp(Elogbeta)
- val alpha = this.alpha
+ val alpha = this.alpha.toBreeze
val gammaShape = this.gammaShape
val stats: RDD[BDM[Double]] = batch.mapPartitions { docs =>
@@ -385,41 +400,36 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
case v => throw new IllegalArgumentException("Online LDA does not support vector type "
+ v.getClass)
}
+ if (!ids.isEmpty) {
+
+ // Initialize the variational distribution q(theta|gamma) for the mini-batch
+ val gammad: BDV[Double] =
+ new Gamma(gammaShape, 1.0 / gammaShape).samplesVector(k) // K
+ val expElogthetad: BDV[Double] = exp(digamma(gammad) - digamma(sum(gammad))) // K
+ val expElogbetad: BDM[Double] = expElogbeta(ids, ::).toDenseMatrix // ids * K
+
+ val phinorm: BDV[Double] = expElogbetad * expElogthetad :+ 1e-100 // ids
+ var meanchange = 1D
+ val ctsVector = new BDV[Double](cts) // ids
+
+ // Iterate between gamma and phi until convergence
+ while (meanchange > 1e-3) {
+ val lastgamma = gammad.copy
+ // K K * ids ids
+ gammad := (expElogthetad :* (expElogbetad.t * (ctsVector :/ phinorm))) :+ alpha
+ expElogthetad := exp(digamma(gammad) - digamma(sum(gammad)))
+ phinorm := expElogbetad * expElogthetad :+ 1e-100
+ meanchange = sum(abs(gammad - lastgamma)) / k
+ }
- // Initialize the variational distribution q(theta|gamma) for the mini-batch
- var gammad = new Gamma(gammaShape, 1.0 / gammaShape).samplesVector(k).t // 1 * K
- var Elogthetad = digamma(gammad) - digamma(sum(gammad)) // 1 * K
- var expElogthetad = exp(Elogthetad) // 1 * K
- val expElogbetad = expElogbeta(::, ids).toDenseMatrix // K * ids
-
- var phinorm = expElogthetad * expElogbetad + 1e-100 // 1 * ids
- var meanchange = 1D
- val ctsVector = new BDV[Double](cts).t // 1 * ids
-
- // Iterate between gamma and phi until convergence
- while (meanchange > 1e-3) {
- val lastgamma = gammad
- // 1*K 1 * ids ids * k
- gammad = (expElogthetad :* ((ctsVector / phinorm) * expElogbetad.t)) + alpha
- Elogthetad = digamma(gammad) - digamma(sum(gammad))
- expElogthetad = exp(Elogthetad)
- phinorm = expElogthetad * expElogbetad + 1e-100
- meanchange = sum(abs(gammad - lastgamma)) / k
- }
-
- val m1 = expElogthetad.t
- val m2 = (ctsVector / phinorm).t.toDenseVector
- var i = 0
- while (i < ids.size) {
- stat(::, ids(i)) := stat(::, ids(i)) + m1 * m2(i)
- i += 1
+ stat(::, ids) := expElogthetad.asDenseMatrix.t * (ctsVector :/ phinorm).asDenseMatrix
}
}
Iterator(stat)
}
val statsSum: BDM[Double] = stats.reduce(_ += _)
- val batchResult = statsSum :* expElogbeta
+ val batchResult = statsSum :* expElogbeta.t
// Note that this is an optimization to avoid batch.count
update(batchResult, iteration, (miniBatchFraction * corpusSize).ceil.toInt)
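The per-document update rewritten above is easier to follow in isolation. A minimal Breeze sketch of the same gamma/phi iteration with an asymmetric alpha, using random stand-ins for the exp(E[log beta]) rows and Breeze 0.11-style operators as in this patch:

  import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, sum}
  import breeze.numerics.{digamma, exp}

  val cts = BDV(2.0, 1.0, 3.0)            // counts of the document's 3 distinct term ids
  val expElogbetad = BDM.rand(3, 2)       // ids x K stand-in for exp(E[log beta]) rows
  val alpha = BDV(0.7, 0.3)               // asymmetric document-topic prior, length K
  val gammad = BDV.fill(2)(1.0)
  val expElogthetad = exp(digamma(gammad) - digamma(sum(gammad)))
  val phinorm = expElogbetad * expElogthetad :+ 1e-100
  for (_ <- 0 until 10) {
    gammad := (expElogthetad :* (expElogbetad.t * (cts :/ phinorm))) :+ alpha
    expElogthetad := exp(digamma(gammad) - digamma(sum(gammad)))
    phinorm := expElogbetad * expElogthetad :+ 1e-100
  }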
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
index e7a243f854e33..407e43a024a2e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
@@ -153,6 +153,27 @@ class PowerIterationClustering private[clustering] (
this
}
+ /**
+ * Run the PIC algorithm on Graph.
+ *
+ * @param graph an affinity matrix represented as graph, which is the matrix A in the PIC paper.
+ * The similarity s,,ij,, represented as the edge between vertices (i, j) must
+ * be nonnegative. This is a symmetric matrix and hence s,,ij,, = s,,ji,,. For
+ * any (i, j) with nonzero similarity, there should be either (i, j, s,,ij,,)
+ * or (j, i, s,,ji,,) in the input. Tuples with i = j are ignored, because we
+ * assume s,,ij,, = 0.0.
+ *
+ * @return a [[PowerIterationClusteringModel]] that contains the clustering result
+ */
+ def run(graph: Graph[Double, Double]): PowerIterationClusteringModel = {
+ val w = normalize(graph)
+ val w0 = initMode match {
+ case "random" => randomInit(w)
+ case "degree" => initDegreeVector(w)
+ }
+ pic(w0)
+ }
+
/**
* Run the PIC algorithm.
*
@@ -212,6 +233,31 @@ object PowerIterationClustering extends Logging {
@Experimental
case class Assignment(id: Long, cluster: Int)
+ /**
+ * Normalizes the affinity graph (A) and returns the normalized affinity matrix (W).
+ */
+ private[clustering]
+ def normalize(graph: Graph[Double, Double]): Graph[Double, Double] = {
+ val vD = graph.aggregateMessages[Double](
+ sendMsg = ctx => {
+ val i = ctx.srcId
+ val j = ctx.dstId
+ val s = ctx.attr
+ if (s < 0.0) {
+ throw new SparkException(s"Similarity must be nonnegative but found s($i, $j) = $s.")
+ }
+ if (s > 0.0) {
+ ctx.sendToSrc(s)
+ }
+ },
+ mergeMsg = _ + _,
+ TripletFields.EdgeOnly)
+ GraphImpl.fromExistingRDDs(vD, graph.edges)
+ .mapTriplets(
+ e => e.attr / math.max(e.srcAttr, MLUtils.EPSILON),
+ TripletFields.Src)
+ }
+
/**
* Normalizes the affinity matrix (A) by row sums and returns the normalized affinity matrix (W).
*/
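A sketch of the new graph entry point (sc is an existing SparkContext; vertex attributes on the input graph are ignored, edge attributes are the nonnegative similarities):

  import org.apache.spark.graphx.{Edge, Graph}
  import org.apache.spark.mllib.clustering.PowerIterationClustering

  val edges = sc.parallelize(Seq(
    Edge(0L, 1L, 1.0), Edge(1L, 2L, 1.0), Edge(3L, 4L, 1.0)))
  val vertices = sc.parallelize((0L to 4L).map(id => (id, 1.0)))
  val affinities: Graph[Double, Double] = Graph(vertices, edges)

  val model = new PowerIterationClustering()
    .setK(2)
    .setMaxIterations(20)
    .run(affinities)
  model.assignments.collect().foreach(a => println(s"${a.id} -> ${a.cluster}"))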
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala
index e577bf87f885e..408847afa800d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala
@@ -53,14 +53,22 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend
)
summary
}
+ private lazy val SSerr = math.pow(summary.normL2(1), 2)
+ private lazy val SStot = summary.variance(0) * (summary.count - 1)
+ private lazy val SSreg = {
+ val yMean = summary.mean(0)
+ predictionAndObservations.map {
+ case (prediction, _) => math.pow(prediction - yMean, 2)
+ }.sum()
+ }
/**
- * Returns the explained variance regression score.
- * explainedVariance = 1 - variance(y - \hat{y}) / variance(y)
- * Reference: [[http://en.wikipedia.org/wiki/Explained_variation]]
+ * Returns the variance explained by regression.
+ * explainedVariance = \sum_i (\hat{y_i} - \bar{y})^2 / n
+ * @see [[https://en.wikipedia.org/wiki/Fraction_of_variance_unexplained]]
*/
def explainedVariance: Double = {
- 1 - summary.variance(1) / summary.variance(0)
+ SSreg / summary.count
}
/**
@@ -76,8 +84,7 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend
* expected value of the squared error loss or quadratic loss.
*/
def meanSquaredError: Double = {
- val rmse = summary.normL2(1) / math.sqrt(summary.count)
- rmse * rmse
+ SSerr / summary.count
}
/**
@@ -85,14 +92,14 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend
* the mean squared error.
*/
def rootMeanSquaredError: Double = {
- summary.normL2(1) / math.sqrt(summary.count)
+ math.sqrt(this.meanSquaredError)
}
/**
- * Returns R^2^, the coefficient of determination.
- * Reference: [[http://en.wikipedia.org/wiki/Coefficient_of_determination]]
+ * Returns R^2^, the unadjusted coefficient of determination.
+ * @see [[http://en.wikipedia.org/wiki/Coefficient_of_determination]]
*/
def r2: Double = {
- 1 - math.pow(summary.normL2(1), 2) / (summary.variance(0) * (summary.count - 1))
+ 1 - SSerr / SStot
}
}
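The reworked metrics in use on a handful of (prediction, observation) pairs (sc is an existing SparkContext):

  import org.apache.spark.mllib.evaluation.RegressionMetrics

  val predObs = sc.parallelize(Seq((2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)))
  val metrics = new RegressionMetrics(predObs)
  println(s"MSE                = ${metrics.meanSquaredError}")
  println(s"RMSE               = ${metrics.rootMeanSquaredError}")
  println(s"R^2                = ${metrics.r2}")
  println(s"explained variance = ${metrics.explainedVariance}")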
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
index 39c48b084e550..7ead6327486cc 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
@@ -17,58 +17,49 @@
package org.apache.spark.mllib.fpm
+import scala.collection.mutable
+
import org.apache.spark.Logging
-import org.apache.spark.annotation.Experimental
/**
- *
- * :: Experimental ::
- *
* Calculate all patterns of a projected database in local.
*/
-@Experimental
private[fpm] object LocalPrefixSpan extends Logging with Serializable {
/**
* Calculate all patterns of a projected database.
* @param minCount minimum count
* @param maxPatternLength maximum pattern length
- * @param prefix prefix
- * @param projectedDatabase the projected dabase
+ * @param prefixes prefixes in reversed order
+ * @param database the projected database
* @return a set of sequential pattern pairs,
- * the key of pair is sequential pattern (a list of items),
+ * the key of pair is sequential pattern (a list of items in reversed order),
* the value of pair is the pattern's count.
*/
def run(
minCount: Long,
maxPatternLength: Int,
- prefix: Array[Int],
- projectedDatabase: Array[Array[Int]]): Array[(Array[Int], Long)] = {
- val frequentPrefixAndCounts = getFreqItemAndCounts(minCount, projectedDatabase)
- val frequentPatternAndCounts = frequentPrefixAndCounts
- .map(x => (prefix ++ Array(x._1), x._2))
- val prefixProjectedDatabases = getPatternAndProjectedDatabase(
- prefix, frequentPrefixAndCounts.map(_._1), projectedDatabase)
-
- val continueProcess = prefixProjectedDatabases.nonEmpty && prefix.length + 1 < maxPatternLength
- if (continueProcess) {
- val nextPatterns = prefixProjectedDatabases
- .map(x => run(minCount, maxPatternLength, x._1, x._2))
- .reduce(_ ++ _)
- frequentPatternAndCounts ++ nextPatterns
- } else {
- frequentPatternAndCounts
+ prefixes: List[Int],
+ database: Array[Array[Int]]): Iterator[(List[Int], Long)] = {
+ if (prefixes.length == maxPatternLength || database.isEmpty) return Iterator.empty
+ val frequentItemAndCounts = getFreqItemAndCounts(minCount, database)
+ val filteredDatabase = database.map(x => x.filter(frequentItemAndCounts.contains))
+ frequentItemAndCounts.iterator.flatMap { case (item, count) =>
+ val newPrefixes = item :: prefixes
+ val newProjected = project(filteredDatabase, item)
+ Iterator.single((newPrefixes, count)) ++
+ run(minCount, maxPatternLength, newPrefixes, newProjected)
}
}
/**
- * calculate suffix sequence following a prefix in a sequence
- * @param prefix prefix
- * @param sequence sequence
+ * Calculate suffix sequence immediately after the first occurrence of an item.
+ * @param item item to get suffix after
+ * @param sequence sequence to extract suffix from
* @return suffix sequence
*/
- def getSuffix(prefix: Int, sequence: Array[Int]): Array[Int] = {
- val index = sequence.indexOf(prefix)
+ def getSuffix(item: Int, sequence: Array[Int]): Array[Int] = {
+ val index = sequence.indexOf(item)
if (index == -1) {
Array()
} else {
@@ -76,38 +67,28 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
}
}
+ def project(database: Array[Array[Int]], prefix: Int): Array[Array[Int]] = {
+ database
+ .map(getSuffix(prefix, _))
+ .filter(_.nonEmpty)
+ }
+
/**
* Generates frequent items by filtering the input data using minimal count level.
- * @param minCount the absolute minimum count
- * @param sequences sequences data
- * @return array of item and count pair
+ * @param minCount the minimum count for an item to be frequent
+ * @param database database of sequences
+ * @return freq item to count map
*/
private def getFreqItemAndCounts(
minCount: Long,
- sequences: Array[Array[Int]]): Array[(Int, Long)] = {
- sequences.flatMap(_.distinct)
- .groupBy(x => x)
- .mapValues(_.length.toLong)
- .filter(_._2 >= minCount)
- .toArray
- }
-
- /**
- * Get the frequent prefixes' projected database.
- * @param prePrefix the frequent prefixes' prefix
- * @param frequentPrefixes frequent prefixes
- * @param sequences sequences data
- * @return prefixes and projected database
- */
- private def getPatternAndProjectedDatabase(
- prePrefix: Array[Int],
- frequentPrefixes: Array[Int],
- sequences: Array[Array[Int]]): Array[(Array[Int], Array[Array[Int]])] = {
- val filteredProjectedDatabase = sequences
- .map(x => x.filter(frequentPrefixes.contains(_)))
- frequentPrefixes.map { x =>
- val sub = filteredProjectedDatabase.map(y => getSuffix(x, y)).filter(_.nonEmpty)
- (prePrefix ++ Array(x), sub)
- }.filter(x => x._2.nonEmpty)
+ database: Array[Array[Int]]): mutable.Map[Int, Long] = {
+ // TODO: use PrimitiveKeyOpenHashMap
+ val counts = mutable.Map[Int, Long]().withDefaultValue(0L)
+ database.foreach { sequence =>
+ sequence.distinct.foreach { item =>
+ counts(item) += 1L
+ }
+ }
+ counts.filter(_._2 >= minCount)
}
}
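How the prefix-projection step behaves on a toy database, mirroring getSuffix/project above. The body of getSuffix is assumed to be sequence.drop(index + 1), which the truncated hunk does not show:

  val database = Array(Array(1, 2, 3), Array(1, 3, 2), Array(3, 1, 2))

  def getSuffix(item: Int, sequence: Array[Int]): Array[Int] = {
    val index = sequence.indexOf(item)
    if (index == -1) Array() else sequence.drop(index + 1)
  }

  // projecting on item 1 keeps only the suffixes after its first occurrence
  val projectedOn1 = database.map(getSuffix(1, _)).filter(_.nonEmpty)
  // Array(Array(2, 3), Array(3, 2), Array(2))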
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
index 9d8c60ef0fc45..6f52db7b073ae 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
@@ -150,8 +150,9 @@ class PrefixSpan private (
private def getPatternsInLocal(
minCount: Long,
data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(Array[Int], Long)] = {
- data.flatMap { x =>
- LocalPrefixSpan.run(minCount, maxPatternLength, x._1, x._2)
+ data.flatMap { case (prefix, projDB) =>
+ LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList, projDB)
+ .map { case (pattern: List[Int], count: Long) => (pattern.toArray.reverse, count) }
}
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
index 3523f1804325d..9029093e0fa08 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
@@ -303,8 +303,8 @@ private[spark] object BLAS extends Serializable with Logging {
C: DenseMatrix): Unit = {
require(!C.isTransposed,
"The matrix C cannot be the product of a transpose() call. C.isTransposed must be false.")
- if (alpha == 0.0) {
- logDebug("gemm: alpha is equal to 0. Returning C.")
+ if (alpha == 0.0 && beta == 1.0) {
+ logDebug("gemm: alpha is equal to 0 and beta is equal to 1. Returning C.")
} else {
A match {
case sparse: SparseMatrix => gemm(alpha, sparse, B, beta, C)
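Why the extra beta check matters: gemm computes C := alpha * A * B + beta * C, so skipping the update is only correct when the result is exactly C, i.e. alpha == 0 and beta == 1. A quick Breeze check of that identity (not Spark's private BLAS wrapper):

  import breeze.linalg.DenseMatrix

  val A = DenseMatrix((1.0, 2.0), (3.0, 4.0))
  val B = DenseMatrix((5.0, 6.0), (7.0, 8.0))
  val C = DenseMatrix((1.0, 1.0), (1.0, 1.0))
  val (alpha, beta) = (0.0, 0.5)
  val result = A * B * alpha + C * beta
  // result != C: with alpha == 0 but beta != 1, C must still be scaled by beta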
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
index 0df07663405a3..55da0e094d132 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -98,7 +98,7 @@ sealed trait Matrix extends Serializable {
/** Map the values of this matrix using a function. Generates a new matrix. Performs the
* function on only the backing array. For example, an operation such as addition or
* subtraction will only be performed on the non-zero values in a `SparseMatrix`. */
- private[mllib] def map(f: Double => Double): Matrix
+ private[spark] def map(f: Double => Double): Matrix
/** Update all the values of this matrix using the function f. Performed in-place on the
* backing array. For example, an operation such as addition or subtraction will only be
@@ -289,7 +289,7 @@ class DenseMatrix(
override def copy: DenseMatrix = new DenseMatrix(numRows, numCols, values.clone())
- private[mllib] def map(f: Double => Double) = new DenseMatrix(numRows, numCols, values.map(f),
+ private[spark] def map(f: Double => Double) = new DenseMatrix(numRows, numCols, values.map(f),
isTransposed)
private[mllib] def update(f: Double => Double): DenseMatrix = {
@@ -555,7 +555,7 @@ class SparseMatrix(
new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values.clone())
}
- private[mllib] def map(f: Double => Double) =
+ private[spark] def map(f: Double => Double) =
new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values.map(f), isTransposed)
private[mllib] def update(f: Double => Double): SparseMatrix = {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index e048b01d92462..9067b3ba9a7bb 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -150,6 +150,12 @@ sealed trait Vector extends Serializable {
toDense
}
}
+
+ /**
+ * Find the index of a maximal element. Returns the first maximal element in case of a tie.
+ * Returns -1 if vector has length 0.
+ */
+ def argmax: Int
}
/**
@@ -588,11 +594,7 @@ class DenseVector(val values: Array[Double]) extends Vector {
new SparseVector(size, ii, vv)
}
- /**
- * Find the index of a maximal element. Returns the first maximal element in case of a tie.
- * Returns -1 if vector has length 0.
- */
- private[spark] def argmax: Int = {
+ override def argmax: Int = {
if (size == 0) {
-1
} else {
@@ -717,6 +719,51 @@ class SparseVector(
new SparseVector(size, ii, vv)
}
}
+
+ override def argmax: Int = {
+ if (size == 0) {
+ -1
+ } else {
+ // Find the max active entry.
+ var maxIdx = indices(0)
+ var maxValue = values(0)
+ var maxJ = 0
+ var j = 1
+ val na = numActives
+ while (j < na) {
+ val v = values(j)
+ if (v > maxValue) {
+ maxValue = v
+ maxIdx = indices(j)
+ maxJ = j
+ }
+ j += 1
+ }
+
+ // If the max active entry is nonpositive and inactive entries exist, find the first implicit zero.
+ if (maxValue <= 0.0 && na < size) {
+ if (maxValue == 0.0) {
+ // If there exists an inactive entry before maxIdx, find it and return its index.
+ if (maxJ < maxIdx) {
+ var k = 0
+ while (k < maxJ && indices(k) == k) {
+ k += 1
+ }
+ maxIdx = k
+ }
+ } else {
+ // If the max active value is negative, find and return the first inactive index.
+ var k = 0
+ while (k < na && indices(k) == k) {
+ k += 1
+ }
+ maxIdx = k
+ }
+ }
+
+ maxIdx
+ }
+ }
}
object SparseVector {
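argmax is now part of the public Vector interface, with the sparse path accounting for implicit zeros. A small sketch of the expected results (values shown in comments):

  import org.apache.spark.mllib.linalg.Vectors

  println(Vectors.dense(1.0, 3.0, 2.0).argmax)                      // 1
  println(Vectors.sparse(5, Array(1, 3), Array(-1.0, -2.0)).argmax) // 0: all actives negative, first implicit zero wins
  println(Vectors.dense(Array.empty[Double]).argmax)                // -1 for a zero-length vector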
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala
index 35e81fcb3de0d..1facf83d806d0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala
@@ -72,7 +72,7 @@ class SlidingRDD[T: ClassTag](@transient val parent: RDD[T], val windowSize: Int
val w1 = windowSize - 1
// Get the first w1 items of each partition, starting from the second partition.
val nextHeads =
- parent.context.runJob(parent, (iter: Iterator[T]) => iter.take(w1).toArray, 1 until n, true)
+ parent.context.runJob(parent, (iter: Iterator[T]) => iter.take(w1).toArray, 1 until n)
val partitions = mutable.ArrayBuffer[SlidingRDDPartition[T]]()
var i = 0
var partitionIndex = 0
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala
index d89b0059d83f3..2b3ed6df486c9 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala
@@ -20,7 +20,7 @@ package org.apache.spark.mllib.stat.test
import scala.annotation.varargs
import org.apache.commons.math3.distribution.{NormalDistribution, RealDistribution}
-import org.apache.commons.math3.stat.inference.KolmogorovSmirnovTest
+import org.apache.commons.math3.stat.inference.{KolmogorovSmirnovTest => CommonMathKolmogorovSmirnovTest}
import org.apache.spark.Logging
import org.apache.spark.rdd.RDD
@@ -187,7 +187,7 @@ private[stat] object KolmogorovSmirnovTest extends Logging {
}
private def evalOneSampleP(ksStat: Double, n: Long): KolmogorovSmirnovTestResult = {
- val pval = 1 - new KolmogorovSmirnovTest().cdf(ksStat, n.toInt)
+ val pval = 1 - new CommonMathKolmogorovSmirnovTest().cdf(ksStat, n.toInt)
new KolmogorovSmirnovTestResult(pval, ksStat, NullHypothesis.OneSampleTwoSided.toString)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala
index 089010c81ffb6..572815df0bc4a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala
@@ -38,10 +38,10 @@ import org.apache.spark.util.random.XORShiftRandom
* TODO: This does not currently support (Double) weighted instances. Once MLlib has weighted
* dataset support, update. (We store subsampleWeights as Double for this future extension.)
*/
-private[tree] class BaggedPoint[Datum](val datum: Datum, val subsampleWeights: Array[Double])
+private[spark] class BaggedPoint[Datum](val datum: Datum, val subsampleWeights: Array[Double])
extends Serializable
-private[tree] object BaggedPoint {
+private[spark] object BaggedPoint {
/**
* Convert an input dataset into its BaggedPoint representation,
@@ -60,7 +60,7 @@ private[tree] object BaggedPoint {
subsamplingRate: Double,
numSubsamples: Int,
withReplacement: Boolean,
- seed: Int = Utils.random.nextInt()): RDD[BaggedPoint[Datum]] = {
+ seed: Long = Utils.random.nextLong()): RDD[BaggedPoint[Datum]] = {
if (withReplacement) {
convertToBaggedRDDSamplingWithReplacement(input, subsamplingRate, numSubsamples, seed)
} else {
@@ -76,7 +76,7 @@ private[tree] object BaggedPoint {
input: RDD[Datum],
subsamplingRate: Double,
numSubsamples: Int,
- seed: Int): RDD[BaggedPoint[Datum]] = {
+ seed: Long): RDD[BaggedPoint[Datum]] = {
input.mapPartitionsWithIndex { (partitionIndex, instances) =>
// Use random seed = seed + partitionIndex + 1 to make generation reproducible.
val rng = new XORShiftRandom
@@ -100,7 +100,7 @@ private[tree] object BaggedPoint {
input: RDD[Datum],
subsample: Double,
numSubsamples: Int,
- seed: Int): RDD[BaggedPoint[Datum]] = {
+ seed: Long): RDD[BaggedPoint[Datum]] = {
input.mapPartitionsWithIndex { (partitionIndex, instances) =>
// Use random seed = seed + partitionIndex + 1 to make generation reproducible.
val poisson = new PoissonDistribution(subsample)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
index ce8825cc03229..7985ed4b4c0fa 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
@@ -27,7 +27,7 @@ import org.apache.spark.mllib.tree.impurity._
* and helps with indexing.
* This class is abstract to support learning with and without feature subsampling.
*/
-private[tree] class DTStatsAggregator(
+private[spark] class DTStatsAggregator(
val metadata: DecisionTreeMetadata,
featureSubset: Option[Array[Int]]) extends Serializable {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
index f73896e37c05e..380291ac22bd3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
@@ -37,7 +37,7 @@ import org.apache.spark.rdd.RDD
* I.e., the feature takes values in {0, ..., arity - 1}.
* @param numBins Number of bins for each feature.
*/
-private[tree] class DecisionTreeMetadata(
+private[spark] class DecisionTreeMetadata(
val numFeatures: Int,
val numExamples: Long,
val numClasses: Int,
@@ -94,7 +94,7 @@ private[tree] class DecisionTreeMetadata(
}
-private[tree] object DecisionTreeMetadata extends Logging {
+private[spark] object DecisionTreeMetadata extends Logging {
/**
* Construct a [[DecisionTreeMetadata]] instance for this dataset and parameters.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala
index bdd0f576b048d..8f9eb24b57b55 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala
@@ -75,7 +75,7 @@ private[tree] case class NodeIndexUpdater(
* (how often should the cache be checkpointed.).
*/
@DeveloperApi
-private[tree] class NodeIdCache(
+private[spark] class NodeIdCache(
var nodeIdsForInstances: RDD[Array[Int]],
val checkpointInterval: Int) {
@@ -170,7 +170,7 @@ private[tree] class NodeIdCache(
}
@DeveloperApi
-private[tree] object NodeIdCache {
+private[spark] object NodeIdCache {
/**
* Initialize the node Id cache with initial node Id values.
* @param data The RDD of training rows.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala
index d215d68c4279e..aac84243d5ce1 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala
@@ -25,7 +25,7 @@ import org.apache.spark.annotation.Experimental
* Time tracker implementation which holds labeled timers.
*/
@Experimental
-private[tree] class TimeTracker extends Serializable {
+private[spark] class TimeTracker extends Serializable {
private val starts: MutableHashMap[String, Long] = new MutableHashMap[String, Long]()
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
index 50b292e71b067..21919d69a38a3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
@@ -37,11 +37,11 @@ import org.apache.spark.rdd.RDD
* @param binnedFeatures Binned feature values.
* Same length as LabeledPoint.features, but values are bin indices.
*/
-private[tree] class TreePoint(val label: Double, val binnedFeatures: Array[Int])
+private[spark] class TreePoint(val label: Double, val binnedFeatures: Array[Int])
extends Serializable {
}
-private[tree] object TreePoint {
+private[spark] object TreePoint {
/**
* Convert an input dataset into its TreePoint representation,
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
index 72eb24c49264a..578749d85a4e6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
@@ -57,7 +57,7 @@ trait Impurity extends Serializable {
* Note: Instances of this class do not hold the data; they operate on views of the data.
* @param statsSize Length of the vector of sufficient statistics for one bin.
*/
-private[tree] abstract class ImpurityAggregator(val statsSize: Int) extends Serializable {
+private[spark] abstract class ImpurityAggregator(val statsSize: Int) extends Serializable {
/**
* Merge the stats from one bin into another.
@@ -95,7 +95,7 @@ private[tree] abstract class ImpurityAggregator(val statsSize: Int) extends Seri
* (node, feature, bin).
* @param stats Array of sufficient statistics for a (node, feature, bin).
*/
-private[tree] abstract class ImpurityCalculator(val stats: Array[Double]) {
+private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) {
/**
* Make a deep copy of this [[ImpurityCalculator]].
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
index a5582d3ef3324..011a5d57422f7 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
@@ -42,11 +42,11 @@ object SquaredError extends Loss {
* @return Loss gradient
*/
override def gradient(prediction: Double, label: Double): Double = {
- 2.0 * (prediction - label)
+ - 2.0 * (label - prediction)
}
override private[mllib] def computeError(prediction: Double, label: Double): Double = {
- val err = prediction - label
+ val err = label - prediction
err * err
}
}
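The rewritten gradient is algebraically identical; only the sign convention of the error term changes. A one-line sanity check:

  val (label, prediction) = (3.0, 2.5)
  assert(-2.0 * (label - prediction) == 2.0 * (prediction - label))             // gradient unchanged: -1.0
  assert(math.pow(label - prediction, 2) == math.pow(prediction - label, 2))    // error unchanged: 0.25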
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
index 2d087c967f679..dc9e0f9f51ffb 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
@@ -67,7 +67,7 @@ class InformationGainStats(
}
-private[tree] object InformationGainStats {
+private[spark] object InformationGainStats {
/**
* An [[org.apache.spark.mllib.tree.model.InformationGainStats]] object to
* denote that current split doesn't satisfies minimum info gain or
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
new file mode 100644
index 0000000000000..09a9fba0c19cf
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification;
+
+import java.io.Serializable;
+
+import com.google.common.collect.Lists;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.linalg.VectorUDT;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.Metadata;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+
+public class JavaNaiveBayesSuite implements Serializable {
+
+ private transient JavaSparkContext jsc;
+ private transient SQLContext jsql;
+
+ @Before
+ public void setUp() {
+ jsc = new JavaSparkContext("local", "JavaNaiveBayesSuite");
+ jsql = new SQLContext(jsc);
+ }
+
+ @After
+ public void tearDown() {
+ jsc.stop();
+ jsc = null;
+ }
+
+ public void validatePrediction(DataFrame predictionAndLabels) {
+ for (Row r : predictionAndLabels.collect()) {
+ double prediction = r.getAs(0);
+ double label = r.getAs(1);
+ assert(prediction == label);
+ }
+ }
+
+ @Test
+ public void naiveBayesDefaultParams() {
+ NaiveBayes nb = new NaiveBayes();
+ assert(nb.getLabelCol() == "label");
+ assert(nb.getFeaturesCol() == "features");
+ assert(nb.getPredictionCol() == "prediction");
+ assert(nb.getLambda() == 1.0);
+ assert(nb.getModelType() == "multinomial");
+ }
+
+ @Test
+ public void testNaiveBayes() {
+ JavaRDD<Row> jrdd = jsc.parallelize(Lists.newArrayList(
+ RowFactory.create(0.0, Vectors.dense(1.0, 0.0, 0.0)),
+ RowFactory.create(0.0, Vectors.dense(2.0, 0.0, 0.0)),
+ RowFactory.create(1.0, Vectors.dense(0.0, 1.0, 0.0)),
+ RowFactory.create(1.0, Vectors.dense(0.0, 2.0, 0.0)),
+ RowFactory.create(2.0, Vectors.dense(0.0, 0.0, 1.0)),
+ RowFactory.create(2.0, Vectors.dense(0.0, 0.0, 2.0))
+ ));
+
+ StructType schema = new StructType(new StructField[]{
+ new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
+ new StructField("features", new VectorUDT(), false, Metadata.empty())
+ });
+
+ DataFrame dataset = jsql.createDataFrame(jrdd, schema);
+ NaiveBayes nb = new NaiveBayes().setLambda(0.5).setModelType("multinomial");
+ NaiveBayesModel model = nb.fit(dataset);
+
+ DataFrame predictionAndLabels = model.transform(dataset).select("prediction", "label");
+ validatePrediction(predictionAndLabels);
+ }
+}
diff --git a/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java b/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java
new file mode 100644
index 0000000000000..d09fa7fd5637c
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.clustering;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.SQLContext;
+
+public class JavaKMeansSuite implements Serializable {
+
+ private transient int k = 5;
+ private transient JavaSparkContext sc;
+ private transient DataFrame dataset;
+ private transient SQLContext sql;
+
+ @Before
+ public void setUp() {
+ sc = new JavaSparkContext("local", "JavaKMeansSuite");
+ sql = new SQLContext(sc);
+
+ dataset = KMeansSuite.generateKMeansData(sql, 50, 3, k);
+ }
+
+ @After
+ public void tearDown() {
+ sc.stop();
+ sc = null;
+ }
+
+ @Test
+ public void fitAndTransform() {
+ KMeans kmeans = new KMeans().setK(k).setSeed(1);
+ KMeansModel model = kmeans.fit(dataset);
+
+ Vector[] centers = model.clusterCenters();
+ assertEquals(k, centers.length);
+
+ DataFrame transformed = model.transform(dataset);
+ List<String> columns = Arrays.asList(transformed.columns());
+ List<String> expectedColumns = Arrays.asList("features", "prediction");
+ for (String column: expectedColumns) {
+ assertTrue(columns.contains(column));
+ }
+ }
+}
diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
index 3ae09d39ef500..dc6ce8061f62b 100644
--- a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
+++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
@@ -96,11 +96,8 @@ private void init() {
new DoubleArrayParam(this, "myDoubleArrayParam", "this is a double param");
setDefault(myIntParam(), 1);
- setDefault(myIntParam().w(1));
setDefault(myDoubleParam(), 0.5);
- setDefault(myIntParam().w(1), myDoubleParam().w(0.5));
setDefault(myDoubleArrayParam(), new double[] {1.0, 2.0});
- setDefault(myDoubleArrayParam().w(new double[] {1.0, 2.0}));
}
@Override
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
index 71b041818d7ee..ebe800e749e05 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
@@ -57,7 +57,7 @@ public void runDT() {
JavaRDD<LabeledPoint> data = sc.parallelize(
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
Map<Integer, Integer> categoricalFeatures = new HashMap<Integer, Integer>();
- DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
+ DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);
// This tests setters. Training with various options is tested in Scala.
DecisionTreeRegressor dt = new DecisionTreeRegressor()
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
new file mode 100644
index 0000000000000..76381a2741296
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.mllib.linalg._
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.mllib.classification.NaiveBayesSuite._
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.Row
+
+class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+ def validatePrediction(predictionAndLabels: DataFrame): Unit = {
+ val numOfErrorPredictions = predictionAndLabels.collect().count {
+ case Row(prediction: Double, label: Double) =>
+ prediction != label
+ }
+ // At least 80% of the predictions should be on.
+ assert(numOfErrorPredictions < predictionAndLabels.count() / 5)
+ }
+
+ def validateModelFit(
+ piData: Vector,
+ thetaData: Matrix,
+ model: NaiveBayesModel): Unit = {
+ assert(Vectors.dense(model.pi.toArray.map(math.exp)) ~==
+ Vectors.dense(piData.toArray.map(math.exp)) absTol 0.05, "pi mismatch")
+ assert(model.theta.map(math.exp) ~== thetaData.map(math.exp) absTol 0.05, "theta mismatch")
+ }
+
+ test("params") {
+ ParamsSuite.checkParams(new NaiveBayes)
+ val model = new NaiveBayesModel("nb", pi = Vectors.dense(Array(0.2, 0.8)),
+ theta = new DenseMatrix(2, 3, Array(0.1, 0.2, 0.3, 0.4, 0.6, 0.4)))
+ ParamsSuite.checkParams(model)
+ }
+
+ test("naive bayes: default params") {
+ val nb = new NaiveBayes
+ assert(nb.getLabelCol === "label")
+ assert(nb.getFeaturesCol === "features")
+ assert(nb.getPredictionCol === "prediction")
+ assert(nb.getLambda === 1.0)
+ assert(nb.getModelType === "multinomial")
+ }
+
+ test("Naive Bayes Multinomial") {
+ val nPoints = 1000
+ val piArray = Array(0.5, 0.1, 0.4).map(math.log)
+ val thetaArray = Array(
+ Array(0.70, 0.10, 0.10, 0.10), // label 0
+ Array(0.10, 0.70, 0.10, 0.10), // label 1
+ Array(0.10, 0.10, 0.70, 0.10) // label 2
+ ).map(_.map(math.log))
+ val pi = Vectors.dense(piArray)
+ val theta = new DenseMatrix(3, 4, thetaArray.flatten, true)
+
+ val testDataset = sqlContext.createDataFrame(generateNaiveBayesInput(
+ piArray, thetaArray, nPoints, 42, "multinomial"))
+ val nb = new NaiveBayes().setLambda(1.0).setModelType("multinomial")
+ val model = nb.fit(testDataset)
+
+ validateModelFit(pi, theta, model)
+ assert(model.hasParent)
+
+ val validationDataset = sqlContext.createDataFrame(generateNaiveBayesInput(
+ piArray, thetaArray, nPoints, 17, "multinomial"))
+ val predictionAndLabels = model.transform(validationDataset).select("prediction", "label")
+
+ validatePrediction(predictionAndLabels)
+ }
+
+ test("Naive Bayes Bernoulli") {
+ val nPoints = 10000
+ val piArray = Array(0.5, 0.3, 0.2).map(math.log)
+ val thetaArray = Array(
+ Array(0.50, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.40), // label 0
+ Array(0.02, 0.70, 0.10, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02), // label 1
+ Array(0.02, 0.02, 0.60, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.30) // label 2
+ ).map(_.map(math.log))
+ val pi = Vectors.dense(piArray)
+ val theta = new DenseMatrix(3, 12, thetaArray.flatten, true)
+
+ val testDataset = sqlContext.createDataFrame(generateNaiveBayesInput(
+ piArray, thetaArray, nPoints, 45, "bernoulli"))
+ val nb = new NaiveBayes().setLambda(1.0).setModelType("bernoulli")
+ val model = nb.fit(testDataset)
+
+ validateModelFit(pi, theta, model)
+ assert(model.hasParent)
+
+ val validationDataset = sqlContext.createDataFrame(generateNaiveBayesInput(
+ piArray, thetaArray, nPoints, 20, "bernoulli"))
+ val predictionAndLabels = model.transform(validationDataset).select("prediction", "label")
+
+ validatePrediction(predictionAndLabels)
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
new file mode 100644
index 0000000000000..1f15ac02f4008
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.clustering
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans}
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.{DataFrame, SQLContext}
+
+private[clustering] case class TestRow(features: Vector)
+
+object KMeansSuite {
+ def generateKMeansData(sql: SQLContext, rows: Int, dim: Int, k: Int): DataFrame = {
+ val sc = sql.sparkContext
+ val rdd = sc.parallelize(1 to rows).map(i => Vectors.dense(Array.fill(dim)((i % k).toDouble)))
+ .map(v => new TestRow(v))
+ sql.createDataFrame(rdd)
+ }
+}
+
+class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+ final val k = 5
+ @transient var dataset: DataFrame = _
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+
+ dataset = KMeansSuite.generateKMeansData(sqlContext, 50, 3, k)
+ }
+
+ test("default parameters") {
+ val kmeans = new KMeans()
+
+ assert(kmeans.getK === 2)
+ assert(kmeans.getFeaturesCol === "features")
+ assert(kmeans.getPredictionCol === "prediction")
+ assert(kmeans.getMaxIter === 20)
+ assert(kmeans.getRuns === 1)
+ assert(kmeans.getInitMode === MLlibKMeans.K_MEANS_PARALLEL)
+ assert(kmeans.getInitSteps === 5)
+ assert(kmeans.getEpsilon === 1e-4)
+ }
+
+ test("set parameters") {
+ val kmeans = new KMeans()
+ .setK(9)
+ .setFeaturesCol("test_feature")
+ .setPredictionCol("test_prediction")
+ .setMaxIter(33)
+ .setRuns(7)
+ .setInitMode(MLlibKMeans.RANDOM)
+ .setInitSteps(3)
+ .setSeed(123)
+ .setEpsilon(1e-3)
+
+ assert(kmeans.getK === 9)
+ assert(kmeans.getFeaturesCol === "test_feature")
+ assert(kmeans.getPredictionCol === "test_prediction")
+ assert(kmeans.getMaxIter === 33)
+ assert(kmeans.getRuns === 7)
+ assert(kmeans.getInitMode === MLlibKMeans.RANDOM)
+ assert(kmeans.getInitSteps === 3)
+ assert(kmeans.getSeed === 123)
+ assert(kmeans.getEpsilon === 1e-3)
+ }
+
+ test("parameters validation") {
+ intercept[IllegalArgumentException] {
+ new KMeans().setK(1)
+ }
+ intercept[IllegalArgumentException] {
+ new KMeans().setInitMode("no_such_a_mode")
+ }
+ intercept[IllegalArgumentException] {
+ new KMeans().setInitSteps(0)
+ }
+ intercept[IllegalArgumentException] {
+ new KMeans().setRuns(0)
+ }
+ }
+
+ test("fit & transform") {
+ val predictionColName = "kmeans_prediction"
+ val kmeans = new KMeans().setK(k).setPredictionCol(predictionColName).setSeed(1)
+ val model = kmeans.fit(dataset)
+ assert(model.clusterCenters.length === k)
+
+ val transformed = model.transform(dataset)
+ val expectedColumns = Array("features", predictionColName)
+ expectedColumns.foreach { column =>
+ assert(transformed.columns.contains(column))
+ }
+ val clusters = transformed.select(predictionColName).map(_.getInt(0)).distinct().collect().toSet
+ assert(clusters.size === k)
+ assert(clusters === Set(0, 1, 2, 3, 4))
+ }
+}
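
For reference, a minimal sketch of the spark.ml KMeans API exercised above, assuming a SQLContext is in scope and a hypothetical DataFrame `points` with a "features" vector column:

    import org.apache.spark.ml.clustering.KMeans

    // `points` is an assumed DataFrame with a "features" vector column.
    val kmeans = new KMeans()
      .setK(5)
      .setMaxIter(20)
      .setSeed(1)
    val model = kmeans.fit(points)

    model.clusterCenters.foreach(println)    // k learned centers
    val clustered = model.transform(points)  // adds a "prediction" column
      .select("features", "prediction")
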
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala
new file mode 100644
index 0000000000000..c8d065f37a605
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.SparkFunSuite
+
+class RFormulaParserSuite extends SparkFunSuite {
+ private def checkParse(formula: String, label: String, terms: Seq[String]) {
+ val parsed = RFormulaParser.parse(formula)
+ assert(parsed.label == label)
+ assert(parsed.terms == terms)
+ }
+
+ test("parse simple formulas") {
+ checkParse("y ~ x", "y", Seq("x"))
+ checkParse("y ~ ._foo ", "y", Seq("._foo"))
+ checkParse("resp ~ A_VAR + B + c123", "resp", Seq("A_VAR", "B", "c123"))
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
new file mode 100644
index 0000000000000..79c4ccf02d4e0
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+
+class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext {
+ test("params") {
+ ParamsSuite.checkParams(new RFormula())
+ }
+
+ test("transform numeric data") {
+ val formula = new RFormula().setFormula("id ~ v1 + v2")
+ val original = sqlContext.createDataFrame(
+ Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2")
+ val result = formula.transform(original)
+ val resultSchema = formula.transformSchema(original.schema)
+ val expected = sqlContext.createDataFrame(
+ Seq(
+ (0, 1.0, 3.0, Vectors.dense(Array(1.0, 3.0)), 0.0),
+ (2, 2.0, 5.0, Vectors.dense(Array(2.0, 5.0)), 2.0))
+ ).toDF("id", "v1", "v2", "features", "label")
+ // TODO(ekl) make schema comparisons ignore metadata, to avoid .toString
+ assert(result.schema.toString == resultSchema.toString)
+ assert(resultSchema == expected.schema)
+ assert(result.collect().toSeq == expected.collect().toSeq)
+ }
+
+ test("features column already exists") {
+ val formula = new RFormula().setFormula("y ~ x").setFeaturesCol("x")
+ val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y")
+ intercept[IllegalArgumentException] {
+ formula.transformSchema(original.schema)
+ }
+ intercept[IllegalArgumentException] {
+ formula.transform(original)
+ }
+ }
+
+ test("label column already exists") {
+ val formula = new RFormula().setFormula("y ~ x").setLabelCol("y")
+ val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y")
+ val resultSchema = formula.transformSchema(original.schema)
+ assert(resultSchema.length == 3)
+ assert(resultSchema.toString == formula.transform(original).schema.toString)
+ }
+
+ test("label column already exists but is not double type") {
+ val formula = new RFormula().setFormula("y ~ x").setLabelCol("y")
+ val original = sqlContext.createDataFrame(Seq((0, 1), (2, 2))).toDF("x", "y")
+ intercept[IllegalArgumentException] {
+ formula.transformSchema(original.schema)
+ }
+ intercept[IllegalArgumentException] {
+ formula.transform(original)
+ }
+ }
+
+ test("allow missing label column for test datasets") {
+ val formula = new RFormula().setFormula("y ~ x").setLabelCol("label")
+ val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "_not_y")
+ val resultSchema = formula.transformSchema(original.schema)
+ assert(resultSchema.length == 3)
+ assert(!resultSchema.exists(_.name == "label"))
+ assert(resultSchema.toString == formula.transform(original).schema.toString)
+ }
+
+// TODO(ekl) enable after we implement string label support
+// test("transform string label") {
+// val formula = new RFormula().setFormula("name ~ id")
+// val original = sqlContext.createDataFrame(
+// Seq((1, "foo"), (2, "bar"), (3, "bar"))).toDF("id", "name")
+// val result = formula.transform(original)
+// val resultSchema = formula.transformSchema(original.schema)
+// val expected = sqlContext.createDataFrame(
+// Seq(
+// (1, "foo", Vectors.dense(Array(1.0)), 1.0),
+// (2, "bar", Vectors.dense(Array(2.0)), 0.0),
+// (3, "bar", Vectors.dense(Array(3.0)), 0.0))
+// ).toDF("id", "name", "features", "label")
+// assert(result.schema.toString == resultSchema.toString)
+// assert(result.collect().toSeq == expected.collect().toSeq)
+// }
+}
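
For reference, a minimal sketch of the RFormula transformer covered by this suite, assuming a SQLContext is in scope and a hypothetical DataFrame `df` with numeric columns "id", "v1" and "v2":

    import org.apache.spark.ml.feature.RFormula

    // `df` is an assumed DataFrame with numeric columns "id", "v1" and "v2".
    val formula = new RFormula().setFormula("id ~ v1 + v2")

    // Produces a "features" vector assembled from v1 and v2, plus a "label"
    // column copied from the response (id); string labels are not supported yet.
    val prepared = formula.transform(df)
    prepared.select("features", "label").show()
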
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index cf120cf2a4b47..7cdda3db88ad1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.ml.regression
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.linalg.{DenseVector, Vectors}
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
import org.apache.spark.mllib.util.TestingUtils._
@@ -55,6 +56,30 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
}
+ test("params") {
+ ParamsSuite.checkParams(new LinearRegression)
+ val model = new LinearRegressionModel("linearReg", Vectors.dense(0.0), 0.0)
+ ParamsSuite.checkParams(model)
+ }
+
+ test("linear regression: default params") {
+ val lir = new LinearRegression
+ assert(lir.getLabelCol === "label")
+ assert(lir.getFeaturesCol === "features")
+ assert(lir.getPredictionCol === "prediction")
+ assert(lir.getRegParam === 0.0)
+ assert(lir.getElasticNetParam === 0.0)
+ assert(lir.getFitIntercept)
+ val model = lir.fit(dataset)
+ model.transform(dataset)
+ .select("label", "prediction")
+ .collect()
+ assert(model.getFeaturesCol === "features")
+ assert(model.getPredictionCol === "prediction")
+ assert(model.intercept !== 0.0)
+ assert(model.hasParent)
+ }
+
test("linear regression with intercept without regularization") {
val trainer = new LinearRegression
val model = trainer.fit(dataset)
@@ -302,7 +327,7 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
.map { case Row(features: DenseVector, label: Double) =>
val prediction =
features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
- prediction - label
+ label - prediction
}
.zip(model.summary.residuals.map(_.getDouble(0)))
.collect()
@@ -314,7 +339,7 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
Use the following R code to generate model training results.
predictions <- predict(fit, newx=features)
- residuals <- predictions - label
+ residuals <- label - predictions
> mean(residuals^2) # MSE
[1] 0.009720325
> mean(abs(residuals)) # MAD
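
Both sign flips in this file (and the matching one in EnsembleTestHelper further below) align the expected residuals with the convention used by the model summary and by R:

    \[ r_i \;=\; \mathrm{label}_i - \mathrm{prediction}_i \;=\; y_i - \hat{y}_i \]
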
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
new file mode 100644
index 0000000000000..c8e58f216cceb
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tuning
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.classification.LogisticRegression
+import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator}
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.param.shared.HasInputCol
+import org.apache.spark.ml.regression.LinearRegression
+import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
+import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.types.StructType
+
+class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext {
+ test("train validation with logistic regression") {
+ val dataset = sqlContext.createDataFrame(
+ sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2))
+
+ val lr = new LogisticRegression
+ val lrParamMaps = new ParamGridBuilder()
+ .addGrid(lr.regParam, Array(0.001, 1000.0))
+ .addGrid(lr.maxIter, Array(0, 10))
+ .build()
+ val eval = new BinaryClassificationEvaluator
+ val cv = new TrainValidationSplit()
+ .setEstimator(lr)
+ .setEstimatorParamMaps(lrParamMaps)
+ .setEvaluator(eval)
+ .setTrainRatio(0.5)
+ val cvModel = cv.fit(dataset)
+ val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression]
+ assert(cv.getTrainRatio === 0.5)
+ assert(parent.getRegParam === 0.001)
+ assert(parent.getMaxIter === 10)
+ assert(cvModel.validationMetrics.length === lrParamMaps.length)
+ }
+
+ test("train validation with linear regression") {
+ val dataset = sqlContext.createDataFrame(
+ sc.parallelize(LinearDataGenerator.generateLinearInput(
+ 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2))
+
+ val trainer = new LinearRegression
+ val lrParamMaps = new ParamGridBuilder()
+ .addGrid(trainer.regParam, Array(1000.0, 0.001))
+ .addGrid(trainer.maxIter, Array(0, 10))
+ .build()
+ val eval = new RegressionEvaluator()
+ val cv = new TrainValidationSplit()
+ .setEstimator(trainer)
+ .setEstimatorParamMaps(lrParamMaps)
+ .setEvaluator(eval)
+ .setTrainRatio(0.5)
+ val cvModel = cv.fit(dataset)
+ val parent = cvModel.bestModel.parent.asInstanceOf[LinearRegression]
+ assert(parent.getRegParam === 0.001)
+ assert(parent.getMaxIter === 10)
+ assert(cvModel.validationMetrics.length === lrParamMaps.length)
+
+ eval.setMetricName("r2")
+ val cvModel2 = cv.fit(dataset)
+ val parent2 = cvModel2.bestModel.parent.asInstanceOf[LinearRegression]
+ assert(parent2.getRegParam === 0.001)
+ assert(parent2.getMaxIter === 10)
+ assert(cvModel2.validationMetrics.length === lrParamMaps.length)
+ }
+
+ test("validateParams should check estimatorParamMaps") {
+ import TrainValidationSplitSuite._
+
+ val est = new MyEstimator("est")
+ val eval = new MyEvaluator
+ val paramMaps = new ParamGridBuilder()
+ .addGrid(est.inputCol, Array("input1", "input2"))
+ .build()
+
+ val cv = new TrainValidationSplit()
+ .setEstimator(est)
+ .setEstimatorParamMaps(paramMaps)
+ .setEvaluator(eval)
+ .setTrainRatio(0.5)
+ cv.validateParams() // This should pass.
+
+ val invalidParamMaps = paramMaps :+ ParamMap(est.inputCol -> "")
+ cv.setEstimatorParamMaps(invalidParamMaps)
+ intercept[IllegalArgumentException] {
+ cv.validateParams()
+ }
+ }
+}
+
+object TrainValidationSplitSuite {
+
+ abstract class MyModel extends Model[MyModel]
+
+ class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol {
+
+ override def validateParams(): Unit = require($(inputCol).nonEmpty)
+
+ override def fit(dataset: DataFrame): MyModel = {
+ throw new UnsupportedOperationException
+ }
+
+ override def transformSchema(schema: StructType): StructType = {
+ throw new UnsupportedOperationException
+ }
+
+ override def copy(extra: ParamMap): MyEstimator = defaultCopy(extra)
+ }
+
+ class MyEvaluator extends Evaluator {
+
+ override def evaluate(dataset: DataFrame): Double = {
+ throw new UnsupportedOperationException
+ }
+
+ override val uid: String = "eval"
+
+ override def copy(extra: ParamMap): MyEvaluator = defaultCopy(extra)
+ }
+}
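
For reference, a minimal sketch of TrainValidationSplit as exercised above, assuming a SQLContext is in scope and a hypothetical DataFrame `data` with "label" and "features" columns:

    import org.apache.spark.ml.evaluation.RegressionEvaluator
    import org.apache.spark.ml.regression.LinearRegression
    import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}

    // `data` is an assumed DataFrame with "label" and "features" columns.
    val lr = new LinearRegression
    val grid = new ParamGridBuilder()
      .addGrid(lr.regParam, Array(0.001, 0.01, 0.1))
      .build()

    // Unlike CrossValidator, the data is split once: 50% train, 50% validation.
    val tvs = new TrainValidationSplit()
      .setEstimator(lr)
      .setEstimatorParamMaps(grid)
      .setEvaluator(new RegressionEvaluator())
      .setTrainRatio(0.5)

    val tvsModel = tvs.fit(data)
    println(tvsModel.validationMetrics.mkString(", "))  // one metric per ParamMap
    val best = tvsModel.bestModel
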
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala
new file mode 100644
index 0000000000000..9e6bc7193c13b
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.util
+
+import java.util.Random
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+
+class StopwatchSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+ import StopwatchSuite._
+
+ private def testStopwatchOnDriver(sw: Stopwatch): Unit = {
+ assert(sw.name === "sw")
+ assert(sw.elapsed() === 0L)
+ assert(!sw.isRunning)
+ intercept[AssertionError] {
+ sw.stop()
+ }
+ val duration = checkStopwatch(sw)
+ val elapsed = sw.elapsed()
+ assert(elapsed === duration)
+ val duration2 = checkStopwatch(sw)
+ val elapsed2 = sw.elapsed()
+ assert(elapsed2 === duration + duration2)
+ assert(sw.toString === s"sw: ${elapsed2}ms")
+ sw.start()
+ assert(sw.isRunning)
+ intercept[AssertionError] {
+ sw.start()
+ }
+ }
+
+ test("LocalStopwatch") {
+ val sw = new LocalStopwatch("sw")
+ testStopwatchOnDriver(sw)
+ }
+
+ test("DistributedStopwatch on driver") {
+ val sw = new DistributedStopwatch(sc, "sw")
+ testStopwatchOnDriver(sw)
+ }
+
+ test("DistributedStopwatch on executors") {
+ val sw = new DistributedStopwatch(sc, "sw")
+ val rdd = sc.parallelize(0 until 4, 4)
+ val acc = sc.accumulator(0L)
+ rdd.foreach { i =>
+ acc += checkStopwatch(sw)
+ }
+ assert(!sw.isRunning)
+ val elapsed = sw.elapsed()
+ assert(elapsed === acc.value)
+ }
+
+ test("MultiStopwatch") {
+ val sw = new MultiStopwatch(sc)
+ .addLocal("local")
+ .addDistributed("spark")
+ assert(sw("local").name === "local")
+ assert(sw("spark").name === "spark")
+ intercept[NoSuchElementException] {
+ sw("some")
+ }
+ assert(sw.toString === "{\n local: 0ms,\n spark: 0ms\n}")
+ val localDuration = checkStopwatch(sw("local"))
+ val sparkDuration = checkStopwatch(sw("spark"))
+ val localElapsed = sw("local").elapsed()
+ val sparkElapsed = sw("spark").elapsed()
+ assert(localElapsed === localDuration)
+ assert(sparkElapsed === sparkDuration)
+ assert(sw.toString ===
+ s"{\n local: ${localElapsed}ms,\n spark: ${sparkElapsed}ms\n}")
+ val rdd = sc.parallelize(0 until 4, 4)
+ val acc = sc.accumulator(0L)
+ rdd.foreach { i =>
+ sw("local").start()
+ val duration = checkStopwatch(sw("spark"))
+ sw("local").stop()
+ acc += duration
+ }
+ val localElapsed2 = sw("local").elapsed()
+ assert(localElapsed2 === localElapsed)
+ val sparkElapsed2 = sw("spark").elapsed()
+ assert(sparkElapsed2 === sparkElapsed + acc.value)
+ }
+}
+
+private object StopwatchSuite extends SparkFunSuite {
+
+ /**
+ * Checks the input stopwatch on a task that takes a random time (<10ms) to finish. Validates and
+ * returns the duration reported by the stopwatch.
+ */
+ def checkStopwatch(sw: Stopwatch): Long = {
+ val ubStart = now
+ sw.start()
+ val lbStart = now
+ Thread.sleep(new Random().nextInt(10))
+ val lb = now - lbStart
+ val duration = sw.stop()
+ val ub = now - ubStart
+ assert(duration >= lb && duration <= ub)
+ duration
+ }
+
+ /** The current time in milliseconds. */
+ private def now: Long = System.currentTimeMillis()
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
index f7fc8730606af..cffa1ab700f80 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
@@ -19,13 +19,14 @@ package org.apache.spark.mllib.classification
import scala.util.Random
-import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}
+import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, Vector => BV}
import breeze.stats.distributions.{Multinomial => BrzMultinomial}
import org.apache.spark.{SparkException, SparkFunSuite}
-import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
+import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.util.Utils
object NaiveBayesSuite {
@@ -154,6 +155,29 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
// Test prediction on Array.
validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
+
+ // Test posteriors
+ validationData.map(_.features).foreach { features =>
+ val predicted = model.predictProbabilities(features).toArray
+ assert(predicted.sum ~== 1.0 relTol 1.0e-10)
+ val expected = expectedMultinomialProbabilities(model, features)
+ expected.zip(predicted).foreach { case (e, p) => assert(e ~== p relTol 1.0e-10) }
+ }
+ }
+
+ /**
+ * @param model Multinomial Naive Bayes model
+ * @param testData input to compute posterior probabilities for
+ * @return posterior class probabilities (in order of labels) for input
+ */
+ private def expectedMultinomialProbabilities(model: NaiveBayesModel, testData: Vector) = {
+ val piVector = new BDV(model.pi)
+ // model.theta is row-major; treat it as col-major representation of transpose, and transpose:
+ val thetaMatrix = new BDM(model.theta(0).length, model.theta.length, model.theta.flatten).t
+ val logClassProbs: BV[Double] = piVector + (thetaMatrix * testData.toBreeze)
+ val classProbs = logClassProbs.toArray.map(math.exp)
+ val classProbsSum = classProbs.sum
+ classProbs.map(_ / classProbsSum)
}
test("Naive Bayes Bernoulli") {
@@ -182,6 +206,33 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
// Test prediction on Array.
validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
+
+ // Test posteriors
+ validationData.map(_.features).foreach { features =>
+ val predicted = model.predictProbabilities(features).toArray
+ assert(predicted.sum ~== 1.0 relTol 1.0e-10)
+ val expected = expectedBernoulliProbabilities(model, features)
+ expected.zip(predicted).foreach { case (e, p) => assert(e ~== p relTol 1.0e-10) }
+ }
+ }
+
+ /**
+ * @param model Bernoulli Naive Bayes model
+ * @param testData input to compute posterior probabilities for
+ * @return posterior class probabilities (in order of labels) for input
+ */
+ private def expectedBernoulliProbabilities(model: NaiveBayesModel, testData: Vector) = {
+ val piVector = new BDV(model.pi)
+ val thetaMatrix = new BDM(model.theta(0).length, model.theta.length, model.theta.flatten).t
+ val negThetaMatrix = new BDM(model.theta(0).length, model.theta.length,
+ model.theta.flatten.map(v => math.log(1.0 - math.exp(v)))).t
+ val testBreeze = testData.toBreeze
+ val negTestBreeze = new BDV(Array.fill(testBreeze.size)(1.0)) - testBreeze
+ val piTheta: BV[Double] = piVector + (thetaMatrix * testBreeze)
+ val logClassProbs: BV[Double] = piTheta + (negThetaMatrix * negTestBreeze)
+ val classProbs = logClassProbs.toArray.map(math.exp)
+ val classProbsSum = classProbs.sum
+ classProbs.map(_ / classProbsSum)
}
test("detect negative values") {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
index 0dbbd7127444f..3003c62d9876c 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
@@ -278,6 +278,28 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
}
}
}
+
+ test("Initialize using given cluster centers") {
+ val points = Seq(
+ Vectors.dense(0.0, 0.0),
+ Vectors.dense(1.0, 0.0),
+ Vectors.dense(0.0, 1.0),
+ Vectors.dense(1.0, 1.0)
+ )
+ val rdd = sc.parallelize(points, 3)
+ // creating an initial model
+ val initialModel = new KMeansModel(Array(points(0), points(2)))
+
+ val returnModel = new KMeans()
+ .setK(2)
+ .setMaxIterations(0)
+ .setInitialModel(initialModel)
+ .run(rdd)
+ // comparing the returned model and the initial model
+ assert(returnModel.clusterCenters(0) === initialModel.clusterCenters(0))
+ assert(returnModel.clusterCenters(1) === initialModel.clusterCenters(1))
+ }
+
}
object KMeansSuite extends SparkFunSuite {
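
For reference, a minimal sketch of the setInitialModel hook verified above, assuming a SparkContext is in scope and a hypothetical RDD[Vector] named `data`:

    import org.apache.spark.mllib.clustering.{KMeans, KMeansModel}
    import org.apache.spark.mllib.linalg.Vectors

    // `data` is an assumed RDD[Vector].
    val seeds = Array(Vectors.dense(0.0, 0.0), Vectors.dense(1.0, 1.0))
    val initialModel = new KMeansModel(seeds)

    val model = new KMeans()
      .setK(2)
      .setInitialModel(initialModel)  // start from the given centers instead of random / k-means||
      .setMaxIterations(10)           // with 0 iterations the seeds come back unchanged, as tested
      .run(data)
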
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
index 03a8a2538b464..da70d9bd7c790 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
@@ -20,9 +20,10 @@ package org.apache.spark.mllib.clustering
import breeze.linalg.{DenseMatrix => BDM}
import org.apache.spark.SparkFunSuite
-import org.apache.spark.mllib.linalg.{Vector, DenseMatrix, Matrix, Vectors}
+import org.apache.spark.mllib.linalg.{DenseMatrix, Matrix, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.util.Utils
class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
@@ -131,22 +132,38 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
test("setter alias") {
val lda = new LDA().setAlpha(2.0).setBeta(3.0)
- assert(lda.getAlpha === 2.0)
- assert(lda.getDocConcentration === 2.0)
+ assert(lda.getAlpha.toArray.forall(_ === 2.0))
+ assert(lda.getDocConcentration.toArray.forall(_ === 2.0))
assert(lda.getBeta === 3.0)
assert(lda.getTopicConcentration === 3.0)
}
+ test("initializing with alpha length != k or 1 fails") {
+ intercept[IllegalArgumentException] {
+ val lda = new LDA().setK(2).setAlpha(Vectors.dense(1, 2, 3, 4))
+ val corpus = sc.parallelize(tinyCorpus, 2)
+ lda.run(corpus)
+ }
+ }
+
+ test("initializing with elements in alpha < 0 fails") {
+ intercept[IllegalArgumentException] {
+ val lda = new LDA().setK(4).setAlpha(Vectors.dense(-1, 2, 3, 4))
+ val corpus = sc.parallelize(tinyCorpus, 2)
+ lda.run(corpus)
+ }
+ }
+
test("OnlineLDAOptimizer initialization") {
val lda = new LDA().setK(2)
val corpus = sc.parallelize(tinyCorpus, 2)
val op = new OnlineLDAOptimizer().initialize(corpus, lda)
op.setKappa(0.9876).setMiniBatchFraction(0.123).setTau0(567)
- assert(op.getAlpha == 0.5) // default 1.0 / k
- assert(op.getEta == 0.5) // default 1.0 / k
- assert(op.getKappa == 0.9876)
- assert(op.getMiniBatchFraction == 0.123)
- assert(op.getTau0 == 567)
+ assert(op.getAlpha.toArray.forall(_ === 0.5)) // default 1.0 / k
+ assert(op.getEta === 0.5) // default 1.0 / k
+ assert(op.getKappa === 0.9876)
+ assert(op.getMiniBatchFraction === 0.123)
+ assert(op.getTau0 === 567)
}
test("OnlineLDAOptimizer one iteration") {
@@ -217,6 +234,96 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
}
}
+ test("OnlineLDAOptimizer with asymmetric prior") {
+ def toydata: Array[(Long, Vector)] = Array(
+ Vectors.sparse(6, Array(0, 1), Array(1, 1)),
+ Vectors.sparse(6, Array(1, 2), Array(1, 1)),
+ Vectors.sparse(6, Array(0, 2), Array(1, 1)),
+ Vectors.sparse(6, Array(3, 4), Array(1, 1)),
+ Vectors.sparse(6, Array(3, 5), Array(1, 1)),
+ Vectors.sparse(6, Array(4, 5), Array(1, 1))
+ ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) }
+
+ val docs = sc.parallelize(toydata)
+ val op = new OnlineLDAOptimizer().setMiniBatchFraction(1).setTau0(1024).setKappa(0.51)
+ .setGammaShape(1e10)
+ val lda = new LDA().setK(2)
+ .setDocConcentration(Vectors.dense(0.00001, 0.1))
+ .setTopicConcentration(0.01)
+ .setMaxIterations(100)
+ .setOptimizer(op)
+ .setSeed(12345)
+
+ val ldaModel = lda.run(docs)
+ val topicIndices = ldaModel.describeTopics(maxTermsPerTopic = 10)
+ val topics = topicIndices.map { case (terms, termWeights) =>
+ terms.zip(termWeights)
+ }
+
+ /* Verify results with Python:
+
+ import numpy as np
+ from gensim import models
+ corpus = [
+ [(0, 1.0), (1, 1.0)],
+ [(1, 1.0), (2, 1.0)],
+ [(0, 1.0), (2, 1.0)],
+ [(3, 1.0), (4, 1.0)],
+ [(3, 1.0), (5, 1.0)],
+ [(4, 1.0), (5, 1.0)]]
+ np.random.seed(10)
+ lda = models.ldamodel.LdaModel(
+ corpus=corpus, alpha=np.array([0.00001, 0.1]), num_topics=2, update_every=0, passes=100)
+ lda.print_topics()
+
+ > ['0.167*0 + 0.167*1 + 0.167*2 + 0.167*3 + 0.167*4 + 0.167*5',
+ '0.167*0 + 0.167*1 + 0.167*2 + 0.167*4 + 0.167*3 + 0.167*5']
+ */
+ topics.foreach { topic =>
+ assert(topic.forall { case (_, p) => p ~= 0.167 absTol 0.05 })
+ }
+ }
+
+ test("model save/load") {
+ // Test for LocalLDAModel.
+ val localModel = new LocalLDAModel(tinyTopics)
+ val tempDir1 = Utils.createTempDir()
+ val path1 = tempDir1.toURI.toString
+
+ // Test for DistributedLDAModel.
+ val k = 3
+ val docConcentration = 1.2
+ val topicConcentration = 1.5
+ val lda = new LDA()
+ lda.setK(k)
+ .setDocConcentration(docConcentration)
+ .setTopicConcentration(topicConcentration)
+ .setMaxIterations(5)
+ .setSeed(12345)
+ val corpus = sc.parallelize(tinyCorpus, 2)
+ val distributedModel: DistributedLDAModel = lda.run(corpus).asInstanceOf[DistributedLDAModel]
+ val tempDir2 = Utils.createTempDir()
+ val path2 = tempDir2.toURI.toString
+
+ try {
+ localModel.save(sc, path1)
+ distributedModel.save(sc, path2)
+ val samelocalModel = LocalLDAModel.load(sc, path1)
+ assert(samelocalModel.topicsMatrix === localModel.topicsMatrix)
+ assert(samelocalModel.k === localModel.k)
+ assert(samelocalModel.vocabSize === localModel.vocabSize)
+
+ val sameDistributedModel = DistributedLDAModel.load(sc, path2)
+ assert(distributedModel.topicsMatrix === sameDistributedModel.topicsMatrix)
+ assert(distributedModel.k === sameDistributedModel.k)
+ assert(distributedModel.vocabSize === sameDistributedModel.vocabSize)
+ assert(distributedModel.iterationTimes === sameDistributedModel.iterationTimes)
+ } finally {
+ Utils.deleteRecursively(tempDir1)
+ Utils.deleteRecursively(tempDir2)
+ }
+ }
+
}
private[clustering] object LDASuite {
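
For reference, a minimal sketch of the vector-valued (asymmetric) docConcentration accepted by LDA and validated above, assuming a SparkContext is in scope and a hypothetical RDD[(Long, Vector)] named `corpus` of (docId, termCounts) pairs:

    import org.apache.spark.mllib.clustering.{LDA, OnlineLDAOptimizer}
    import org.apache.spark.mllib.linalg.Vectors

    // `corpus` is an assumed RDD[(Long, Vector)] of (docId, termCounts) pairs.
    val lda = new LDA()
      .setK(2)
      .setDocConcentration(Vectors.dense(1e-5, 0.1))  // one alpha per topic; length must be k (or 1)
      .setTopicConcentration(0.01)
      .setOptimizer(new OnlineLDAOptimizer().setMiniBatchFraction(1.0))
      .setMaxIterations(100)
      .setSeed(12345)
    val ldaModel = lda.run(corpus)
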
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala
index 19e65f1b53ab5..189000512155f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala
@@ -68,6 +68,54 @@ class PowerIterationClusteringSuite extends SparkFunSuite with MLlibTestSparkCon
assert(predictions2.toSet == Set((0 to 3).toSet, (4 to 15).toSet))
}
+ test("power iteration clustering on graph") {
+ /*
+ We use the following graph to test PIC. All edges are assigned similarity 1.0 except 0.1 for
+ edge (3, 4).
+
+ 15-14 -13 -12
+ | |
+ 4 . 3 - 2 11
+ | | x | |
+ 5 0 - 1 10
+ | |
+ 6 - 7 - 8 - 9
+ */
+
+ val similarities = Seq[(Long, Long, Double)]((0, 1, 1.0), (0, 2, 1.0), (0, 3, 1.0), (1, 2, 1.0),
+ (1, 3, 1.0), (2, 3, 1.0), (3, 4, 0.1), // (3, 4) is a weak edge
+ (4, 5, 1.0), (4, 15, 1.0), (5, 6, 1.0), (6, 7, 1.0), (7, 8, 1.0), (8, 9, 1.0), (9, 10, 1.0),
+ (10, 11, 1.0), (11, 12, 1.0), (12, 13, 1.0), (13, 14, 1.0), (14, 15, 1.0))
+
+ val edges = similarities.flatMap { case (i, j, s) =>
+ if (i != j) {
+ Seq(Edge(i, j, s), Edge(j, i, s))
+ } else {
+ None
+ }
+ }
+ val graph = Graph.fromEdges(sc.parallelize(edges, 2), 0.0)
+
+ val model = new PowerIterationClustering()
+ .setK(2)
+ .run(graph)
+ val predictions = Array.fill(2)(mutable.Set.empty[Long])
+ model.assignments.collect().foreach { a =>
+ predictions(a.cluster) += a.id
+ }
+ assert(predictions.toSet == Set((0 to 3).toSet, (4 to 15).toSet))
+
+ val model2 = new PowerIterationClustering()
+ .setK(2)
+ .setInitializationMode("degree")
+ .run(sc.parallelize(similarities, 2))
+ val predictions2 = Array.fill(2)(mutable.Set.empty[Long])
+ model2.assignments.collect().foreach { a =>
+ predictions2(a.cluster) += a.id
+ }
+ assert(predictions2.toSet == Set((0 to 3).toSet, (4 to 15).toSet))
+ }
+
test("normalize and powerIter") {
/*
Test normalize() with the following graph:
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala
index 9de2bdb6d7246..4b7f1be58f99b 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala
@@ -23,24 +23,85 @@ import org.apache.spark.mllib.util.TestingUtils._
class RegressionMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
- test("regression metrics") {
+ test("regression metrics for unbiased (includes intercept term) predictor") {
+ /* Verify results in R:
+ preds = c(2.25, -0.25, 1.75, 7.75)
+ obs = c(3.0, -0.5, 2.0, 7.0)
+
+ SStot = sum((obs - mean(obs))^2)
+ SSreg = sum((preds - mean(obs))^2)
+ SSerr = sum((obs - preds)^2)
+
+ explainedVariance = SSreg / length(obs)
+ explainedVariance
+ > [1] 8.796875
+ meanAbsoluteError = mean(abs(preds - obs))
+ meanAbsoluteError
+ > [1] 0.5
+ meanSquaredError = mean((preds - obs)^2)
+ meanSquaredError
+ > [1] 0.3125
+ rmse = sqrt(meanSquaredError)
+ rmse
+ > [1] 0.559017
+ r2 = 1 - SSerr / SStot
+ r2
+ > [1] 0.9571734
+ */
+ val predictionAndObservations = sc.parallelize(
+ Seq((2.25, 3.0), (-0.25, -0.5), (1.75, 2.0), (7.75, 7.0)), 2)
+ val metrics = new RegressionMetrics(predictionAndObservations)
+ assert(metrics.explainedVariance ~== 8.79687 absTol 1E-5,
+ "explained variance regression score mismatch")
+ assert(metrics.meanAbsoluteError ~== 0.5 absTol 1E-5, "mean absolute error mismatch")
+ assert(metrics.meanSquaredError ~== 0.3125 absTol 1E-5, "mean squared error mismatch")
+ assert(metrics.rootMeanSquaredError ~== 0.55901 absTol 1E-5,
+ "root mean squared error mismatch")
+ assert(metrics.r2 ~== 0.95717 absTol 1E-5, "r2 score mismatch")
+ }
+
+ test("regression metrics for biased (no intercept term) predictor") {
+ /* Verify results in R:
+ preds = c(2.5, 0.0, 2.0, 8.0)
+ obs = c(3.0, -0.5, 2.0, 7.0)
+
+ SStot = sum((obs - mean(obs))^2)
+ SSreg = sum((preds - mean(obs))^2)
+ SSerr = sum((obs - preds)^2)
+
+ explainedVariance = SSreg / length(obs)
+ explainedVariance
+ > [1] 8.859375
+ meanAbsoluteError = mean(abs(preds - obs))
+ meanAbsoluteError
+ > [1] 0.5
+ meanSquaredError = mean((preds - obs)^2)
+ meanSquaredError
+ > [1] 0.375
+ rmse = sqrt(meanSquaredError)
+ rmse
+ > [1] 0.6123724
+ r2 = 1 - SSerr / SStot
+ r2
+ > [1] 0.9486081
+ */
val predictionAndObservations = sc.parallelize(
Seq((2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)), 2)
val metrics = new RegressionMetrics(predictionAndObservations)
- assert(metrics.explainedVariance ~== 0.95717 absTol 1E-5,
+ assert(metrics.explainedVariance ~== 8.85937 absTol 1E-5,
"explained variance regression score mismatch")
assert(metrics.meanAbsoluteError ~== 0.5 absTol 1E-5, "mean absolute error mismatch")
assert(metrics.meanSquaredError ~== 0.375 absTol 1E-5, "mean squared error mismatch")
assert(metrics.rootMeanSquaredError ~== 0.61237 absTol 1E-5,
"root mean squared error mismatch")
- assert(metrics.r2 ~== 0.94861 absTol 1E-5, "r2 score mismatch")
+ assert(metrics.r2 ~== 0.94860 absTol 1E-5, "r2 score mismatch")
}
test("regression metrics with complete fitting") {
val predictionAndObservations = sc.parallelize(
Seq((3.0, 3.0), (0.0, 0.0), (2.0, 2.0), (8.0, 8.0)), 2)
val metrics = new RegressionMetrics(predictionAndObservations)
- assert(metrics.explainedVariance ~== 1.0 absTol 1E-5,
+ assert(metrics.explainedVariance ~== 8.6875 absTol 1E-5,
"explained variance regression score mismatch")
assert(metrics.meanAbsoluteError ~== 0.0 absTol 1E-5, "mean absolute error mismatch")
assert(metrics.meanSquaredError ~== 0.0 absTol 1E-5, "mean squared error mismatch")
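
The revised expectations in this file follow from treating explainedVariance as the regression sum of squares averaged over the number of points; in the notation of the R snippets above:

    \[ \mathrm{explainedVariance} \;=\; \frac{\mathrm{SSreg}}{n} \;=\; \frac{1}{n}\sum_i(\hat{y}_i - \bar{y})^2, \qquad R^2 \;=\; 1 - \frac{\mathrm{SSerr}}{\mathrm{SStot}} \;=\; 1 - \frac{\sum_i (y_i - \hat{y}_i)^2}{\sum_i (y_i - \bar{y})^2}. \]
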
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
index 413436d3db85f..9f107c89f6d80 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
@@ -18,9 +18,8 @@ package org.apache.spark.mllib.fpm
import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.rdd.RDD
-class PrefixspanSuite extends SparkFunSuite with MLlibTestSparkContext {
+class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
test("PrefixSpan using Integer type") {
@@ -48,15 +47,8 @@ class PrefixspanSuite extends SparkFunSuite with MLlibTestSparkContext {
def compareResult(
expectedValue: Array[(Array[Int], Long)],
actualValue: Array[(Array[Int], Long)]): Boolean = {
- val sortedExpectedValue = expectedValue.sortWith{ (x, y) =>
- x._1.mkString(",") + ":" + x._2 < y._1.mkString(",") + ":" + y._2
- }
- val sortedActualValue = actualValue.sortWith{ (x, y) =>
- x._1.mkString(",") + ":" + x._2 < y._1.mkString(",") + ":" + y._2
- }
- sortedExpectedValue.zip(sortedActualValue)
- .map(x => x._1._1.mkString(",") == x._2._1.mkString(",") && x._1._2 == x._2._2)
- .reduce(_&&_)
+ expectedValue.map(x => (x._1.toSeq, x._2)).toSet ==
+ actualValue.map(x => (x._1.toSeq, x._2)).toSet
}
val prefixspan = new PrefixSpan()
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
index b0f3f71113c57..d119e0b50a393 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
@@ -200,8 +200,14 @@ class BLASSuite extends SparkFunSuite {
val C10 = C1.copy
val C11 = C1.copy
val C12 = C1.copy
+ val C13 = C1.copy
+ val C14 = C1.copy
+ val C15 = C1.copy
+ val C16 = C1.copy
val expected2 = new DenseMatrix(4, 2, Array(2.0, 1.0, 4.0, 2.0, 4.0, 0.0, 4.0, 3.0))
val expected3 = new DenseMatrix(4, 2, Array(2.0, 2.0, 4.0, 2.0, 8.0, 0.0, 6.0, 6.0))
+ val expected4 = new DenseMatrix(4, 2, Array(5.0, 0.0, 10.0, 5.0, 0.0, 0.0, 5.0, 0.0))
+ val expected5 = C1.copy
gemm(1.0, dA, B, 2.0, C1)
gemm(1.0, sA, B, 2.0, C2)
@@ -248,6 +254,16 @@ class BLASSuite extends SparkFunSuite {
assert(C10 ~== expected2 absTol 1e-15)
assert(C11 ~== expected3 absTol 1e-15)
assert(C12 ~== expected3 absTol 1e-15)
+
+ gemm(0, dA, B, 5, C13)
+ gemm(0, sA, B, 5, C14)
+ gemm(0, dA, B, 1, C15)
+ gemm(0, sA, B, 1, C16)
+ assert(C13 ~== expected4 absTol 1e-15)
+ assert(C14 ~== expected4 absTol 1e-15)
+ assert(C15 ~== expected5 absTol 1e-15)
+ assert(C16 ~== expected5 absTol 1e-15)
+
}
test("gemv") {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
index 178d95a7b94ec..03be4119bdaca 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
@@ -62,11 +62,50 @@ class VectorsSuite extends SparkFunSuite with Logging {
assert(vec.toArray.eq(arr))
}
+ test("dense argmax") {
+ val vec = Vectors.dense(Array.empty[Double]).asInstanceOf[DenseVector]
+ assert(vec.argmax === -1)
+
+ val vec2 = Vectors.dense(arr).asInstanceOf[DenseVector]
+ assert(vec2.argmax === 3)
+
+ val vec3 = Vectors.dense(Array(-1.0, 0.0, -2.0, 1.0)).asInstanceOf[DenseVector]
+ assert(vec3.argmax === 3)
+ }
+
test("sparse to array") {
val vec = Vectors.sparse(n, indices, values).asInstanceOf[SparseVector]
assert(vec.toArray === arr)
}
+ test("sparse argmax") {
+ val vec = Vectors.sparse(0, Array.empty[Int], Array.empty[Double]).asInstanceOf[SparseVector]
+ assert(vec.argmax === -1)
+
+ val vec2 = Vectors.sparse(n, indices, values).asInstanceOf[SparseVector]
+ assert(vec2.argmax === 3)
+
+ val vec3 = Vectors.sparse(5, Array(2, 3, 4), Array(1.0, 0.0, -.7))
+ assert(vec3.argmax === 2)
+
+ // check the case where the sparse vector is created with only negative
+ // explicit values, i.e. {0.0, 0.0, -1.0, -0.7, 0.0}
+ val vec4 = Vectors.sparse(5, Array(2, 3), Array(-1.0, -.7))
+ assert(vec4.argmax === 0)
+
+ val vec5 = Vectors.sparse(11, Array(0, 3, 10), Array(-1.0, -.7, 0.0))
+ assert(vec5.argmax === 1)
+
+ val vec6 = Vectors.sparse(11, Array(0, 1, 2), Array(-1.0, -.7, 0.0))
+ assert(vec6.argmax === 2)
+
+ val vec7 = Vectors.sparse(5, Array(0, 1, 3), Array(-1.0, 0.0, -.7))
+ assert(vec7.argmax === 1)
+
+ val vec8 = Vectors.sparse(5, Array(1, 2), Array(0.0, -1.0))
+ assert(vec8.argmax === 0)
+ }
+
test("vector equals") {
val dv1 = Vectors.dense(arr.clone())
val dv2 = Vectors.dense(arr.clone())
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala
index 8972c229b7ecb..334bf3790fc7a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala
@@ -70,7 +70,7 @@ object EnsembleTestHelper {
metricName: String = "mse") {
val predictions = input.map(x => model.predict(x.features))
val errors = predictions.zip(input.map(_.label)).map { case (prediction, label) =>
- prediction - label
+ label - prediction
}
val metric = metricName match {
case "mse" =>
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala
index 5e9101cdd3804..525ab68c7921a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala
@@ -26,7 +26,7 @@ trait LocalClusterSparkContext extends BeforeAndAfterAll { self: Suite =>
override def beforeAll() {
val conf = new SparkConf()
- .setMaster("local-cluster[2, 1, 512]")
+ .setMaster("local-cluster[2, 1, 1024]")
.setAppName("test-cluster")
.set("spark.akka.frameSize", "1") // set to 1MB to detect direct serialization of data
sc = new SparkContext(conf)
diff --git a/pom.xml b/pom.xml
index c2ebc1a11e770..1f44dc8abe1d4 100644
--- a/pom.xml
+++ b/pom.xml
@@ -128,7 +128,7 @@
${hadoop.version}0.98.7-hadoop2hbase
- 1.4.0
+ 1.6.03.4.52.4.0org.spark-project.hive
@@ -144,7 +144,7 @@
0.5.02.4.02.0.8
- 3.1.0
+ 3.1.21.7.7hadoop20.7.1
@@ -152,7 +152,6 @@
1.2.14.3.23.4.1
- ${project.build.directory}/spark-test-classpath.txt2.10.42.10${scala.version}
@@ -283,18 +282,6 @@
false
-
-
- spark-1.4-staging
- Spark 1.4 RC4 Staging Repository
- https://repository.apache.org/content/repositories/orgapachespark-1112
-
- true
-
-
- false
-
-
@@ -318,17 +305,6 @@
unused1.0.0
-
-
- org.codehaus.groovy
- groovy-all
- 2.3.7
- provided
- com.fasterxml.jackson.module
- jackson-module-scala_2.10
+ jackson-module-scala_${scala.binary.version}${fasterxml.jackson.version}
@@ -748,6 +724,12 @@
curator-framework${curator.version}
+
+ org.apache.curator
+ curator-test
+ ${curator.version}
+ test
+ org.apache.hadoophadoop-client
@@ -1406,6 +1388,58 @@
maven-deploy-plugin2.8.2
+
+
+
+ org.eclipse.m2e
+ lifecycle-mapping
+ 1.0.0
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-dependency-plugin
+ [2.8,)
+
+ build-classpath
+
+
+
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-jar-plugin
+ [2.6,)
+
+ test-jar
+
+
+
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-antrun-plugin
+ [1.8,)
+
+ run
+
+
+
+
+
+
+
+
+
+
@@ -1423,34 +1457,12 @@
test
- ${test_classpath_file}
+ test_classpath
-
-
- org.codehaus.gmavenplus
- gmavenplus-plugin
- 1.5
-
-
- process-test-classes
-
- execute
-
-
-
-
-
-
-
-
-
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
index 79d55b36dab01..2f7e84a7f59e2 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
@@ -19,11 +19,9 @@
import java.util.Iterator;
-import scala.Function1;
-
import org.apache.spark.sql.catalyst.InternalRow;
-import org.apache.spark.sql.catalyst.util.ObjectPool;
-import org.apache.spark.sql.catalyst.util.UniqueObjectPool;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.unsafe.map.BytesToBytesMap;
import org.apache.spark.unsafe.memory.MemoryLocation;
@@ -40,48 +38,26 @@ public final class UnsafeFixedWidthAggregationMap {
* An empty aggregation buffer, encoded in UnsafeRow format. When inserting a new key into the
* map, we copy this buffer and use it as the value.
*/
- private final byte[] emptyBuffer;
+ private final byte[] emptyAggregationBuffer;
- /**
- * An empty row used by `initProjection`
- */
- private static final InternalRow emptyRow = new GenericInternalRow();
+ private final StructType aggregationBufferSchema;
- /**
- * Whether can the empty aggregation buffer be reuse without calling `initProjection` or not.
- */
- private final boolean reuseEmptyBuffer;
+ private final StructType groupingKeySchema;
/**
- * The projection used to initialize the emptyBuffer
+ * Encodes grouping keys as UnsafeRows.
*/
- private final Function1<InternalRow, InternalRow> initProjection;
-
- /**
- * Encodes grouping keys or buffers as UnsafeRows.
- */
- private final UnsafeRowConverter keyConverter;
- private final UnsafeRowConverter bufferConverter;
+ private final UnsafeRowConverter groupingKeyToUnsafeRowConverter;
/**
* A hashmap which maps from opaque bytearray keys to bytearray values.
*/
private final BytesToBytesMap map;
- /**
- * An object pool for objects that are used in grouping keys.
- */
- private final UniqueObjectPool keyPool;
-
- /**
- * An object pool for objects that are used in aggregation buffers.
- */
- private final ObjectPool bufferPool;
-
/**
* Re-used pointer to the current aggregation buffer
*/
- private final UnsafeRow currentBuffer = new UnsafeRow();
+ private final UnsafeRow currentAggregationBuffer = new UnsafeRow();
/**
* Scratch space that is used when encoding grouping keys into UnsafeRow format.
@@ -93,41 +69,69 @@ public final class UnsafeFixedWidthAggregationMap {
private final boolean enablePerfMetrics;
+ /**
+ * @return true if UnsafeFixedWidthAggregationMap supports grouping keys with the given schema,
+ * false otherwise.
+ */
+ public static boolean supportsGroupKeySchema(StructType schema) {
+ for (StructField field: schema.fields()) {
+ if (!UnsafeRow.readableFieldTypes.contains(field.dataType())) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ /**
+ * @return true if UnsafeFixedWidthAggregationMap supports aggregation buffers with the given
+ * schema, false otherwise.
+ */
+ public static boolean supportsAggregationBufferSchema(StructType schema) {
+ for (StructField field: schema.fields()) {
+ if (!UnsafeRow.settableFieldTypes.contains(field.dataType())) {
+ return false;
+ }
+ }
+ return true;
+ }
+
/**
* Create a new UnsafeFixedWidthAggregationMap.
*
- * @param initProjection the default value for new keys (a "zero" of the agg. function)
- * @param keyConverter the converter of the grouping key, used for row conversion.
- * @param bufferConverter the converter of the aggregation buffer, used for row conversion.
+ * @param emptyAggregationBuffer the default value for new keys (a "zero" of the agg. function)
+ * @param aggregationBufferSchema the schema of the aggregation buffer, used for row conversion.
+ * @param groupingKeySchema the schema of the grouping key, used for row conversion.
* @param memoryManager the memory manager used to allocate our Unsafe memory structures.
* @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing).
* @param enablePerfMetrics if true, performance metrics will be recorded (has minor perf impact)
*/
public UnsafeFixedWidthAggregationMap(
- Function1<InternalRow, InternalRow> initProjection,
- UnsafeRowConverter keyConverter,
- UnsafeRowConverter bufferConverter,
+ InternalRow emptyAggregationBuffer,
+ StructType aggregationBufferSchema,
+ StructType groupingKeySchema,
TaskMemoryManager memoryManager,
int initialCapacity,
boolean enablePerfMetrics) {
- this.initProjection = initProjection;
- this.keyConverter = keyConverter;
- this.bufferConverter = bufferConverter;
- this.enablePerfMetrics = enablePerfMetrics;
-
+ this.emptyAggregationBuffer =
+ convertToUnsafeRow(emptyAggregationBuffer, aggregationBufferSchema);
+ this.aggregationBufferSchema = aggregationBufferSchema;
+ this.groupingKeyToUnsafeRowConverter = new UnsafeRowConverter(groupingKeySchema);
+ this.groupingKeySchema = groupingKeySchema;
this.map = new BytesToBytesMap(memoryManager, initialCapacity, enablePerfMetrics);
- this.keyPool = new UniqueObjectPool(100);
- this.bufferPool = new ObjectPool(initialCapacity);
+ this.enablePerfMetrics = enablePerfMetrics;
+ }
- InternalRow initRow = initProjection.apply(emptyRow);
- int emptyBufferSize = bufferConverter.getSizeRequirement(initRow);
- this.emptyBuffer = new byte[emptyBufferSize];
- int writtenLength = bufferConverter.writeRow(
- initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, emptyBufferSize,
- bufferPool);
- assert (writtenLength == emptyBuffer.length): "Size requirement calculation was wrong!";
- // re-use the empty buffer only when there is no object saved in pool.
- reuseEmptyBuffer = bufferPool.size() == 0;
+ /**
+ * Convert a Java object row into an UnsafeRow, allocating it into a new byte array.
+ */
+ private static byte[] convertToUnsafeRow(InternalRow javaRow, StructType schema) {
+ final UnsafeRowConverter converter = new UnsafeRowConverter(schema);
+ final int size = converter.getSizeRequirement(javaRow);
+ final byte[] unsafeRow = new byte[size];
+ final int writtenLength =
+ converter.writeRow(javaRow, unsafeRow, PlatformDependent.BYTE_ARRAY_OFFSET, size);
+ assert (writtenLength == unsafeRow.length): "Size requirement calculation was wrong!";
+ return unsafeRow;
}
/**
@@ -135,17 +139,16 @@ public UnsafeFixedWidthAggregationMap(
* return the same object.
*/
public UnsafeRow getAggregationBuffer(InternalRow groupingKey) {
- final int groupingKeySize = keyConverter.getSizeRequirement(groupingKey);
+ final int groupingKeySize = groupingKeyToUnsafeRowConverter.getSizeRequirement(groupingKey);
// Make sure that the buffer is large enough to hold the key. If it's not, grow it:
if (groupingKeySize > groupingKeyConversionScratchSpace.length) {
groupingKeyConversionScratchSpace = new byte[groupingKeySize];
}
- final int actualGroupingKeySize = keyConverter.writeRow(
+ final int actualGroupingKeySize = groupingKeyToUnsafeRowConverter.writeRow(
groupingKey,
groupingKeyConversionScratchSpace,
PlatformDependent.BYTE_ARRAY_OFFSET,
- groupingKeySize,
- keyPool);
+ groupingKeySize);
assert (groupingKeySize == actualGroupingKeySize) : "Size requirement calculation was wrong!";
// Probe our map using the serialized key
@@ -156,32 +159,25 @@ public UnsafeRow getAggregationBuffer(InternalRow groupingKey) {
if (!loc.isDefined()) {
// This is the first time that we've seen this grouping key, so we'll insert a copy of the
// empty aggregation buffer into the map:
- if (!reuseEmptyBuffer) {
- // There is some objects referenced by emptyBuffer, so generate a new one
- InternalRow initRow = initProjection.apply(emptyRow);
- bufferConverter.writeRow(initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET,
- groupingKeySize, bufferPool);
- }
loc.putNewKey(
groupingKeyConversionScratchSpace,
PlatformDependent.BYTE_ARRAY_OFFSET,
groupingKeySize,
- emptyBuffer,
+ emptyAggregationBuffer,
PlatformDependent.BYTE_ARRAY_OFFSET,
- emptyBuffer.length
+ emptyAggregationBuffer.length
);
}
// Reset the pointer to point to the value that we just stored or looked up:
final MemoryLocation address = loc.getValueAddress();
- currentBuffer.pointTo(
+ currentAggregationBuffer.pointTo(
address.getBaseObject(),
address.getBaseOffset(),
- bufferConverter.numFields(),
- loc.getValueLength(),
- bufferPool
+ aggregationBufferSchema.length(),
+ loc.getValueLength()
);
- return currentBuffer;
+ return currentAggregationBuffer;
}
/**
@@ -217,16 +213,14 @@ public MapEntry next() {
entry.key.pointTo(
keyAddress.getBaseObject(),
keyAddress.getBaseOffset(),
- keyConverter.numFields(),
- loc.getKeyLength(),
- keyPool
+ groupingKeySchema.length(),
+ loc.getKeyLength()
);
entry.value.pointTo(
valueAddress.getBaseObject(),
valueAddress.getBaseOffset(),
- bufferConverter.numFields(),
- loc.getValueLength(),
- bufferPool
+ aggregationBufferSchema.length(),
+ loc.getValueLength()
);
return entry;
}
@@ -254,8 +248,6 @@ public void printPerfMetrics() {
System.out.println("Number of hash collisions: " + map.getNumHashCollisions());
System.out.println("Time spent resizing (ns): " + map.getTimeSpentResizingNs());
System.out.println("Total memory consumption (bytes): " + map.getTotalMemoryConsumption());
- System.out.println("Number of unique objects in keys: " + keyPool.size());
- System.out.println("Number of objects in buffers: " + bufferPool.size());
}
}
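
A brief sketch (from the Scala side, with hypothetical schemas) of the two new static checks, which gate which schemas the unsafe map can handle: aggregation buffers must contain only in-place-settable fixed-width types, while grouping keys may additionally use the read-only String/Binary types:

    import org.apache.spark.sql.catalyst.expressions.UnsafeFixedWidthAggregationMap
    import org.apache.spark.sql.types._

    // Hypothetical schemas, for illustration only.
    val bufferSchema = StructType(Seq(StructField("sum", LongType)))
    val keySchema = StructType(Seq(StructField("word", StringType)))

    UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(bufferSchema)  // true: LongType is settable
    UnsafeFixedWidthAggregationMap.supportsGroupKeySchema(keySchema)              // true: StringType is readable
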
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 4b99030d1046f..fa1216b455a9e 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -17,12 +17,21 @@
package org.apache.spark.sql.catalyst.expressions;
-import org.apache.spark.sql.catalyst.InternalRow;
-import org.apache.spark.sql.catalyst.util.ObjectPool;
+import java.io.IOException;
+import java.io.OutputStream;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Set;
+
+import org.apache.spark.sql.types.DataType;
import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.bitset.BitSetMethods;
+import org.apache.spark.unsafe.hash.Murmur3_x86_32;
import org.apache.spark.unsafe.types.UTF8String;
+import static org.apache.spark.sql.types.DataTypes.*;
/**
* An Unsafe implementation of Row which is backed by raw memory instead of Java objects.
@@ -36,20 +45,7 @@
* primitive types, such as long, double, or int, we store the value directly in the word. For
* fields with non-primitive or variable-length values, we store a relative offset (w.r.t. the
* base address of the row) that points to the beginning of the variable-length field, and length
- * (they are combined into a long). For other objects, they are stored in a pool, the indexes of
- * them are hold in the the word.
- *
- * In order to support fast hashing and equality checks for UnsafeRows that contain objects
- * when used as grouping key in BytesToBytesMap, we put the objects in an UniqueObjectPool to make
- * sure all the key have the same index for same object, then we can hash/compare the objects by
- * hash/compare the index.
- *
- * For non-primitive types, the word of a field could be:
- * UNION {
- * [1] [offset: 31bits] [length: 31bits] // StringType
- * [0] [offset: 31bits] [length: 31bits] // BinaryType
- * - [index: 63bits] // StringType, Binary, index to object in pool
- * }
+ * (they are combined into a long).
*
* Instances of `UnsafeRow` act as pointers to row data stored in this format.
*/
@@ -58,13 +54,9 @@ public final class UnsafeRow extends MutableRow {
private Object baseObject;
private long baseOffset;
- /** A pool to hold non-primitive objects */
- private ObjectPool pool;
-
public Object getBaseObject() { return baseObject; }
public long getBaseOffset() { return baseOffset; }
public int getSizeInBytes() { return sizeInBytes; }
- public ObjectPool getPool() { return pool; }
/** The number of fields in this row, used for calculating the bitset width (and in assertions) */
private int numFields;
@@ -85,7 +77,42 @@ public static int calculateBitSetWidthInBytes(int numFields) {
return ((numFields / 64) + (numFields % 64 == 0 ? 0 : 1)) * 8;
}
- public static final long OFFSET_BITS = 31L;
+ /**
+ * Field types that can be updated in place in UnsafeRows (e.g. we support set() for these types)
+ */
+  public static final Set<DataType> settableFieldTypes;
+
+ /**
+   * Field types that can be read (but not necessarily set; set() throws UnsupportedOperationException
+   * for types outside settableFieldTypes).
+   */
+  public static final Set<DataType> readableFieldTypes;
+
+ // TODO: support DecimalType
+ static {
+ settableFieldTypes = Collections.unmodifiableSet(
+ new HashSet<>(
+ Arrays.asList(new DataType[] {
+ NullType,
+ BooleanType,
+ ByteType,
+ ShortType,
+ IntegerType,
+ LongType,
+ FloatType,
+ DoubleType,
+ DateType,
+ TimestampType
+ })));
+
+ // We support get() on a superset of the types for which we support set():
+    final Set<DataType> _readableFieldTypes = new HashSet<>(
+ Arrays.asList(new DataType[]{
+ StringType,
+ BinaryType
+ }));
+ _readableFieldTypes.addAll(settableFieldTypes);
+ readableFieldTypes = Collections.unmodifiableSet(_readableFieldTypes);
+ }
/**
* Construct a new UnsafeRow. The resulting row won't be usable until `pointTo()` has been called,
@@ -100,17 +127,14 @@ public UnsafeRow() { }
* @param baseOffset the offset within the base object
* @param numFields the number of fields in this row
* @param sizeInBytes the size of this row's backing data, in bytes
- * @param pool the object pool to hold arbitrary objects
*/
- public void pointTo(
- Object baseObject, long baseOffset, int numFields, int sizeInBytes, ObjectPool pool) {
+ public void pointTo(Object baseObject, long baseOffset, int numFields, int sizeInBytes) {
assert numFields >= 0 : "numFields should >= 0";
this.bitSetWidthInBytes = calculateBitSetWidthInBytes(numFields);
this.baseObject = baseObject;
this.baseOffset = baseOffset;
this.numFields = numFields;
this.sizeInBytes = sizeInBytes;
- this.pool = pool;
}
private void assertIndexIsValid(int index) {
@@ -133,68 +157,9 @@ private void setNotNullAt(int i) {
BitSetMethods.unset(baseObject, baseOffset, i);
}
- /**
- * Updates the column `i` as Object `value`, which cannot be primitive types.
- */
@Override
- public void update(int i, Object value) {
- if (value == null) {
- if (!isNullAt(i)) {
- // remove the old value from pool
- long idx = getLong(i);
- if (idx <= 0) {
- // this is the index of old value in pool, remove it
- pool.replace((int)-idx, null);
- } else {
- // there will be some garbage left (UTF8String or byte[])
- }
- setNullAt(i);
- }
- return;
- }
-
- if (isNullAt(i)) {
- // there is not an old value, put the new value into pool
- int idx = pool.put(value);
- setLong(i, (long)-idx);
- } else {
- // there is an old value, check the type, then replace it or update it
- long v = getLong(i);
- if (v <= 0) {
- // it's the index in the pool, replace old value with new one
- int idx = (int)-v;
- pool.replace(idx, value);
- } else {
- // old value is UTF8String or byte[], try to reuse the space
- boolean isString;
- byte[] newBytes;
- if (value instanceof UTF8String) {
- newBytes = ((UTF8String) value).getBytes();
- isString = true;
- } else {
- newBytes = (byte[]) value;
- isString = false;
- }
- int offset = (int) ((v >> OFFSET_BITS) & Integer.MAX_VALUE);
- int oldLength = (int) (v & Integer.MAX_VALUE);
- if (newBytes.length <= oldLength) {
- // the new value can fit in the old buffer, re-use it
- PlatformDependent.copyMemory(
- newBytes,
- PlatformDependent.BYTE_ARRAY_OFFSET,
- baseObject,
- baseOffset + offset,
- newBytes.length);
- long flag = isString ? 1L << (OFFSET_BITS * 2) : 0L;
- setLong(i, flag | (((long) offset) << OFFSET_BITS) | (long) newBytes.length);
- } else {
- // Cannot fit in the buffer
- int idx = pool.put(value);
- setLong(i, (long) -idx);
- }
- }
- }
- setNotNullAt(i);
+ public void update(int ordinal, Object value) {
+ throw new UnsupportedOperationException();
}
@Override
@@ -215,6 +180,9 @@ public void setLong(int ordinal, long value) {
public void setDouble(int ordinal, double value) {
assertIndexIsValid(ordinal);
setNotNullAt(ordinal);
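+    // Canonicalize NaN so every NaN is stored with the same bit pattern and stays consistent
+    // under the byte-based equals()/hashCode() below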
+ if (Double.isNaN(value)) {
+ value = Double.NaN;
+ }
PlatformDependent.UNSAFE.putDouble(baseObject, getFieldOffset(ordinal), value);
}
@@ -243,6 +211,9 @@ public void setByte(int ordinal, byte value) {
public void setFloat(int ordinal, float value) {
assertIndexIsValid(ordinal);
setNotNullAt(ordinal);
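+    // Canonicalize NaN so every float NaN is stored with the same bit pattern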
+ if (Float.isNaN(value)) {
+ value = Float.NaN;
+ }
PlatformDependent.UNSAFE.putFloat(baseObject, getFieldOffset(ordinal), value);
}
@@ -251,40 +222,9 @@ public int size() {
return numFields;
}
- /**
- * Returns the object for column `i`, which should not be primitive type.
- */
@Override
public Object get(int i) {
- assertIndexIsValid(i);
- if (isNullAt(i)) {
- return null;
- }
- long v = PlatformDependent.UNSAFE.getLong(baseObject, getFieldOffset(i));
- if (v <= 0) {
- // It's an index to object in the pool.
- int idx = (int)-v;
- return pool.get(idx);
- } else {
- // The column could be StingType or BinaryType
- boolean isString = (v >> (OFFSET_BITS * 2)) > 0;
- int offset = (int) ((v >> OFFSET_BITS) & Integer.MAX_VALUE);
- int size = (int) (v & Integer.MAX_VALUE);
- final byte[] bytes = new byte[size];
- // TODO(davies): Avoid the copy once we can manage the life cycle of Row well.
- PlatformDependent.copyMemory(
- baseObject,
- baseOffset + offset,
- bytes,
- PlatformDependent.BYTE_ARRAY_OFFSET,
- size
- );
- if (isString) {
- return UTF8String.fromBytes(bytes);
- } else {
- return bytes;
- }
- }
+ throw new UnsupportedOperationException();
}
@Override
@@ -343,6 +283,38 @@ public double getDouble(int i) {
}
}
+ @Override
+ public UTF8String getUTF8String(int i) {
+ assertIndexIsValid(i);
+ return isNullAt(i) ? null : UTF8String.fromBytes(getBinary(i));
+ }
+
+ @Override
+ public byte[] getBinary(int i) {
+ if (isNullAt(i)) {
+ return null;
+ } else {
+ assertIndexIsValid(i);
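+      // The stored long packs the field's data offset in the high 32 bits and its byte length
+      // in the low 32 bits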
+ final long offsetAndSize = getLong(i);
+ final int offset = (int) (offsetAndSize >> 32);
+ final int size = (int) (offsetAndSize & ((1L << 32) - 1));
+ final byte[] bytes = new byte[size];
+ PlatformDependent.copyMemory(
+ baseObject,
+ baseOffset + offset,
+ bytes,
+ PlatformDependent.BYTE_ARRAY_OFFSET,
+ size
+ );
+ return bytes;
+ }
+ }
+
+ @Override
+ public String getString(int i) {
+ return getUTF8String(i).toString();
+ }
+
/**
* Copies this row, returning a self-contained UnsafeRow that stores its data in an internal
* byte array rather than referencing data stored in a data page.
@@ -350,28 +322,95 @@ public double getDouble(int i) {
-   * This method is only supported on UnsafeRows that do not use ObjectPools.
*/
@Override
- public InternalRow copy() {
- if (pool != null) {
- throw new UnsupportedOperationException(
- "Copy is not supported for UnsafeRows that use object pools");
+ public UnsafeRow copy() {
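+    // Copy the backing bytes into a fresh on-heap byte array so the returned row is fully
+    // self-contained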
+ UnsafeRow rowCopy = new UnsafeRow();
+ final byte[] rowDataCopy = new byte[sizeInBytes];
+ PlatformDependent.copyMemory(
+ baseObject,
+ baseOffset,
+ rowDataCopy,
+ PlatformDependent.BYTE_ARRAY_OFFSET,
+ sizeInBytes
+ );
+ rowCopy.pointTo(rowDataCopy, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, sizeInBytes);
+ return rowCopy;
+ }
+
+ /**
+ * Write this UnsafeRow's underlying bytes to the given OutputStream.
+ *
+ * @param out the stream to write to.
+ * @param writeBuffer a byte array for buffering chunks of off-heap data while writing to the
+ * output stream. If this row is backed by an on-heap byte array, then this
+ * buffer will not be used and may be null.
+ */
+ public void writeToStream(OutputStream out, byte[] writeBuffer) throws IOException {
+ if (baseObject instanceof byte[]) {
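+      // On-heap case: the row is backed by a byte[] and can be written to the stream directly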
+      int offsetInByteArray = (int) (baseOffset - PlatformDependent.BYTE_ARRAY_OFFSET);
+ out.write((byte[]) baseObject, offsetInByteArray, sizeInBytes);
} else {
- UnsafeRow rowCopy = new UnsafeRow();
- final byte[] rowDataCopy = new byte[sizeInBytes];
- PlatformDependent.copyMemory(
- baseObject,
- baseOffset,
- rowDataCopy,
- PlatformDependent.BYTE_ARRAY_OFFSET,
- sizeInBytes
- );
- rowCopy.pointTo(
- rowDataCopy, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, sizeInBytes, null);
- return rowCopy;
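+      // Off-heap case: copy the row through writeBuffer in chunks of at most writeBuffer.length
+      // bytes and write each chunk to the stream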
+ int dataRemaining = sizeInBytes;
+ long rowReadPosition = baseOffset;
+ while (dataRemaining > 0) {
+ int toTransfer = Math.min(writeBuffer.length, dataRemaining);
+ PlatformDependent.copyMemory(
+ baseObject,
+ rowReadPosition,
+ writeBuffer,
+ PlatformDependent.BYTE_ARRAY_OFFSET,
+ toTransfer);
+ out.write(writeBuffer, 0, toTransfer);
+ rowReadPosition += toTransfer;
+ dataRemaining -= toTransfer;
+ }
+ }
+ }
+
+ @Override
+ public int hashCode() {
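+    // Hash the row's raw backing bytes with a fixed seed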
+ return Murmur3_x86_32.hashUnsafeWords(baseObject, baseOffset, sizeInBytes, 42);
+ }
+
+ @Override
+ public boolean equals(Object other) {
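+    // Rows are equal only to other UnsafeRows with the same size and identical backing bytes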
+ if (other instanceof UnsafeRow) {
+ UnsafeRow o = (UnsafeRow) other;
+ return (sizeInBytes == o.sizeInBytes) &&
+ ByteArrayMethods.arrayEquals(baseObject, baseOffset, o.baseObject, o.baseOffset,
+ sizeInBytes);
+ }
+ return false;
+ }
+
+ /**
+ * Returns the underlying bytes for this UnsafeRow.
+ */
+ public byte[] getBytes() {
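+    // Return the backing array as-is when it exactly spans this row; otherwise copy the bytes out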
+ if (baseObject instanceof byte[] && baseOffset == PlatformDependent.BYTE_ARRAY_OFFSET
+ && (((byte[]) baseObject).length == sizeInBytes)) {
+ return (byte[]) baseObject;
+ } else {
+ byte[] bytes = new byte[sizeInBytes];
+ PlatformDependent.copyMemory(baseObject, baseOffset, bytes,
+ PlatformDependent.BYTE_ARRAY_OFFSET, sizeInBytes);
+ return bytes;
+ }
+ }
+
+ // This is for debugging
+ @Override
+ public String toString() {
+ StringBuilder build = new StringBuilder("[");
+ for (int i = 0; i < sizeInBytes; i += 8) {
+ build.append(PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + i));
+ build.append(',');
}
+ build.append(']');
+ return build.toString();
}
@Override
public boolean anyNull() {
- return BitSetMethods.anySet(baseObject, baseOffset, bitSetWidthInBytes);
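+    // bitSetWidthInBytes / 8 is the number of 64-bit words occupied by the null-tracking bitset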
+ return BitSetMethods.anySet(baseObject, baseOffset, bitSetWidthInBytes / 8);
}
}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/ObjectPool.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/ObjectPool.java
deleted file mode 100644
index 97f89a7d0b758..0000000000000
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/ObjectPool.java
+++ /dev/null
@@ -1,78 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.util;
-
-/**
- * A object pool stores a collection of objects in array, then they can be referenced by the
- * pool plus an index.
- */
-public class ObjectPool {
-
- /**
- * An array to hold objects, which will grow as needed.
- */
- private Object[] objects;
-
- /**
- * How many objects in the pool.
- */
- private int numObj;
-
- public ObjectPool(int capacity) {
- objects = new Object[capacity];
- numObj = 0;
- }
-
- /**
- * Returns how many objects in the pool.
- */
- public int size() {
- return numObj;
- }
-
- /**
- * Returns the object at position `idx` in the array.
- */
- public Object get(int idx) {
- assert (idx < numObj);
- return objects[idx];
- }
-
- /**
- * Puts an object `obj` at the end of array, returns the index of it.
- *
- * The array will grow as needed.
- */
- public int put(Object obj) {
- if (numObj >= objects.length) {
- Object[] tmp = new Object[objects.length * 2];
- System.arraycopy(objects, 0, tmp, 0, objects.length);
- objects = tmp;
- }
- objects[numObj++] = obj;
- return numObj - 1;
- }
-
- /**
- * Replaces the object at `idx` with new one `obj`.
- */
- public void replace(int idx, Object obj) {
- assert (idx < numObj);
- objects[idx] = obj;
- }
-}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/UniqueObjectPool.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/UniqueObjectPool.java
deleted file mode 100644
index d512392dcaacc..0000000000000
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/UniqueObjectPool.java
+++ /dev/null
@@ -1,59 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.util;
-
-import java.util.HashMap;
-
-/**
- * An unique object pool stores a collection of unique objects in it.
- */
-public class UniqueObjectPool extends ObjectPool {
-
- /**
- * A hash map from objects to their indexes in the array.
- */
- private HashMap