diff --git a/R/install-dev.bat b/R/install-dev.bat index 008a5c668bc45..f32670b67de96 100644 --- a/R/install-dev.bat +++ b/R/install-dev.bat @@ -25,3 +25,8 @@ set SPARK_HOME=%~dp0.. MKDIR %SPARK_HOME%\R\lib R.exe CMD INSTALL --library="%SPARK_HOME%\R\lib" %SPARK_HOME%\R\pkg\ + +rem Zip the SparkR package so that it can be distributed to worker nodes on YARN +pushd %SPARK_HOME%\R\lib +%JAVA_HOME%\bin\jar.exe cfM "%SPARK_HOME%\R\lib\sparkr.zip" SparkR +popd diff --git a/R/install-dev.sh b/R/install-dev.sh index 1edd551f8d243..4972bb9217072 100755 --- a/R/install-dev.sh +++ b/R/install-dev.sh @@ -34,7 +34,7 @@ LIB_DIR="$FWDIR/lib" mkdir -p $LIB_DIR -pushd $FWDIR +pushd $FWDIR > /dev/null # Generate Rd files if devtools is installed Rscript -e ' if("devtools" %in% rownames(installed.packages())) { library(devtools); devtools::document(pkg="./pkg", roclets=c("rd")) }' @@ -42,4 +42,8 @@ Rscript -e ' if("devtools" %in% rownames(installed.packages())) { library(devtoo # Install SparkR to $LIB_DIR R CMD INSTALL --library=$LIB_DIR $FWDIR/pkg/ -popd +# Zip the SparkR package so that it can be distributed to worker nodes on YARN +cd $LIB_DIR +jar cfM "$LIB_DIR/sparkr.zip" SparkR + +popd > /dev/null diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index efc85bbc4b316..4949d86d20c91 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -29,7 +29,7 @@ Collate: 'client.R' 'context.R' 'deserialize.R' + 'mllib.R' 'serialize.R' 'sparkR.R' 'utils.R' - 'zzz.R' diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 7f857222452d4..7f7a8a2e4de24 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -10,6 +10,10 @@ export("sparkR.init") export("sparkR.stop") export("print.jobj") +# MLlib integration +exportMethods("glm", + "predict") + # Job group lifecycle management methods export("setJobGroup", "clearJobGroup", @@ -22,6 +26,7 @@ exportMethods("arrange", "collect", "columns", "count", + "crosstab", "describe", "distinct", "dropna", @@ -77,6 +82,7 @@ exportMethods("abs", "atan", "atan2", "avg", + "between", "cast", "cbrt", "ceiling", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 60702824acb46..06dd6b75dff3d 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1314,7 +1314,7 @@ setMethod("except", #' write.df(df, "myfile", "parquet", "overwrite") #' } setMethod("write.df", - signature(df = "DataFrame", path = 'character'), + signature(df = "DataFrame", path = "character"), function(df, path, source = NULL, mode = "append", ...){ if (is.null(source)) { sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv) @@ -1328,7 +1328,7 @@ setMethod("write.df", jmode <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "saveMode", mode) options <- varargsToEnv(...) if (!is.null(path)) { - options[['path']] = path + options[["path"]] <- path } callJMethod(df@sdf, "save", source, jmode, options) }) @@ -1337,7 +1337,7 @@ setMethod("write.df", #' @aliases saveDF #' @export setMethod("saveDF", - signature(df = "DataFrame", path = 'character'), + signature(df = "DataFrame", path = "character"), function(df, path, source = NULL, mode = "append", ...){ write.df(df, path, source, mode, ...) 
}) @@ -1375,8 +1375,8 @@ setMethod("saveDF", #' saveAsTable(df, "myfile") #' } setMethod("saveAsTable", - signature(df = "DataFrame", tableName = 'character', source = 'character', - mode = 'character'), + signature(df = "DataFrame", tableName = "character", source = "character", + mode = "character"), function(df, tableName, source = NULL, mode="append", ...){ if (is.null(source)) { sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv) @@ -1554,3 +1554,31 @@ setMethod("fillna", } dataFrame(sdf) }) + +#' crosstab +#' +#' Computes a pair-wise frequency table of the given columns. Also known as a contingency +#' table. The number of distinct values for each column should be less than 1e4. At most 1e6 +#' non-zero pair frequencies will be returned. +#' +#' @param col1 name of the first column. Distinct items will make the first item of each row. +#' @param col2 name of the second column. Distinct items will make the column names of the output. +#' @return a local R data.frame representing the contingency table. The first column of each row +#' will be the distinct values of `col1` and the column names will be the distinct values +#' of `col2`. The name of the first column will be `$col1_$col2`. Pairs that have no +#' occurrences will have `null` as their counts. +#' +#' @rdname statfunctions +#' @export +#' @examples +#' \dontrun{ +#' df <- jsonFile(sqlCtx, "/path/to/file.json") +#' ct = crosstab(df, "title", "gender") +#' } +setMethod("crosstab", + signature(x = "DataFrame", col1 = "character", col2 = "character"), + function(x, col1, col2) { + statFunctions <- callJMethod(x@sdf, "stat") + sct <- callJMethod(statFunctions, "crosstab", col1, col2) + collect(dataFrame(sct)) + }) diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R index 89511141d3ef7..d2d096709245d 100644 --- a/R/pkg/R/RDD.R +++ b/R/pkg/R/RDD.R @@ -165,7 +165,6 @@ setMethod("getJRDD", signature(rdd = "PipelinedRDD"), serializedFuncArr, rdd@env$prev_serializedMode, packageNamesArr, - as.character(.sparkREnv[["libname"]]), broadcastArr, callJMethod(prev_jrdd, "classTag")) } else { @@ -175,7 +174,6 @@ setMethod("getJRDD", signature(rdd = "PipelinedRDD"), rdd@env$prev_serializedMode, serializedMode, packageNamesArr, - as.character(.sparkREnv[["libname"]]), broadcastArr, callJMethod(prev_jrdd, "classTag")) } diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 30978bb50d339..110117a18ccbc 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -457,7 +457,7 @@ dropTempTable <- function(sqlContext, tableName) { read.df <- function(sqlContext, path = NULL, source = NULL, schema = NULL, ...) { options <- varargsToEnv(...) if (!is.null(path)) { - options[['path']] <- path + options[["path"]] <- path } if (is.null(source)) { sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv) @@ -506,7 +506,7 @@ loadDF <- function(sqlContext, path = NULL, source = NULL, schema = NULL, ...) { createExternalTable <- function(sqlContext, tableName, path = NULL, source = NULL, ...) { options <- varargsToEnv(...) 
if (!is.null(path)) { - options[['path']] <- path + options[["path"]] <- path } sdf <- callJMethod(sqlContext, "createExternalTable", tableName, source, options) dataFrame(sdf) diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R index 78c7a3037ffac..6f772158ddfe8 100644 --- a/R/pkg/R/client.R +++ b/R/pkg/R/client.R @@ -36,9 +36,9 @@ connectBackend <- function(hostname, port, timeout = 6000) { determineSparkSubmitBin <- function() { if (.Platform$OS.type == "unix") { - sparkSubmitBinName = "spark-submit" + sparkSubmitBinName <- "spark-submit" } else { - sparkSubmitBinName = "spark-submit.cmd" + sparkSubmitBinName <- "spark-submit.cmd" } sparkSubmitBinName } diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index 8e4b0f5bf1c4d..2892e1416cc65 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -187,6 +187,23 @@ setMethod("substr", signature(x = "Column"), column(jc) }) +#' between +#' +#' Test if the column is between the lower bound and upper bound, inclusive. +#' +#' @rdname column +#' +#' @param bounds lower and upper bounds +setMethod("between", signature(x = "Column"), + function(x, bounds) { + if (is.vector(bounds) && length(bounds) == 2) { + jc <- callJMethod(x@jc, "between", bounds[1], bounds[2]) + column(jc) + } else { + stop("bounds should be a vector of lower and upper bounds") + } + }) + #' Casts the column to a different data type. #' #' @rdname column diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R index d961bbc383688..7d1f6b0819ed0 100644 --- a/R/pkg/R/deserialize.R +++ b/R/pkg/R/deserialize.R @@ -23,6 +23,7 @@ # Int -> integer # String -> character # Boolean -> logical +# Float -> double # Double -> double # Long -> double # Array[Byte] -> raw diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index fad9d71158c51..836e0175c391f 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -59,6 +59,10 @@ setGeneric("count", function(x) { standardGeneric("count") }) # @export setGeneric("countByValue", function(x) { standardGeneric("countByValue") }) +# @rdname statfunctions +# @export +setGeneric("crosstab", function(x, col1, col2) { standardGeneric("crosstab") }) + # @rdname distinct # @export setGeneric("distinct", function(x, numPartitions = 1) { standardGeneric("distinct") }) @@ -567,6 +571,10 @@ setGeneric("asc", function(x) { standardGeneric("asc") }) #' @export setGeneric("avg", function(x, ...) { standardGeneric("avg") }) +#' @rdname column +#' @export +setGeneric("between", function(x, bounds) { standardGeneric("between") }) + #' @rdname column #' @export setGeneric("cast", function(x, dataType) { standardGeneric("cast") }) @@ -657,3 +665,7 @@ setGeneric("toRadians", function(x) { standardGeneric("toRadians") }) #' @rdname column #' @export setGeneric("upper", function(x) { standardGeneric("upper") }) + +#' @rdname glm +#' @export +setGeneric("glm") diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R index 8f1c68f7c4d28..576ac72f40fc0 100644 --- a/R/pkg/R/group.R +++ b/R/pkg/R/group.R @@ -87,7 +87,7 @@ setMethod("count", setMethod("agg", signature(x = "GroupedData"), function(x, ...) { - cols = list(...) + cols <- list(...) stopifnot(length(cols) > 0) if (is.character(cols[[1]])) { cols <- varargsToEnv(...) 
@@ -97,7 +97,7 @@ setMethod("agg", if (!is.null(ns)) { for (n in ns) { if (n != "") { - cols[[n]] = alias(cols[[n]], n) + cols[[n]] <- alias(cols[[n]], n) } } } diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R new file mode 100644 index 0000000000000..258e354081fc1 --- /dev/null +++ b/R/pkg/R/mllib.R @@ -0,0 +1,73 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# mllib.R: Provides methods for MLlib integration + +#' @title S4 class that represents a PipelineModel +#' @param model A Java object reference to the backing Scala PipelineModel +#' @export +setClass("PipelineModel", representation(model = "jobj")) + +#' Fits a generalized linear model +#' +#' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package. +#' +#' @param formula A symbolic description of the model to be fitted. Currently only a few formula +#' operators are supported, including '~' and '+'. +#' @param data DataFrame for training +#' @param family Error distribution. "gaussian" -> linear regression, "binomial" -> logistic reg. +#' @param lambda Regularization parameter +#' @param alpha Elastic-net mixing parameter (see glmnet's documentation for details) +#' @return a fitted MLlib model +#' @rdname glm +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' data(iris) +#' df <- createDataFrame(sqlContext, iris) +#' model <- glm(Sepal_Length ~ Sepal_Width, df) +#'} +setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFrame"), + function(formula, family = c("gaussian", "binomial"), data, lambda = 0, alpha = 0) { + family <- match.arg(family) + model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", + "fitRModelFormula", deparse(formula), data@sdf, family, lambda, + alpha) + return(new("PipelineModel", model = model)) + }) + +#' Make predictions from a model +#' +#' Makes predictions from a model produced by glm(), similarly to R's predict(). 
+#' +#' @param model A fitted MLlib model +#' @param newData DataFrame for testing +#' @return DataFrame containing predicted values +#' @rdname glm +#' @export +#' @examples +#'\dontrun{ +#' model <- glm(y ~ x, trainingData) +#' predicted <- predict(model, testData) +#' showDF(predicted) +#'} +setMethod("predict", signature(object = "PipelineModel"), + function(object, newData) { + return(dataFrame(callJMethod(object@model, "transform", newData@sdf))) + }) diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R index 0f1179e0aa51a..ebc6ff65e9d0f 100644 --- a/R/pkg/R/pairRDD.R +++ b/R/pkg/R/pairRDD.R @@ -215,7 +215,6 @@ setMethod("partitionBy", serializedHashFuncBytes, getSerializedMode(x), packageNamesArr, - as.character(.sparkREnv$libname), broadcastArr, callJMethod(jrdd, "classTag")) diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R index 15e2bdbd55d79..79c744ef29c23 100644 --- a/R/pkg/R/schema.R +++ b/R/pkg/R/schema.R @@ -69,11 +69,14 @@ structType.structField <- function(x, ...) { #' @param ... further arguments passed to or from other methods print.structType <- function(x, ...) { cat("StructType\n", - sapply(x$fields(), function(field) { paste("|-", "name = \"", field$name(), - "\", type = \"", field$dataType.toString(), - "\", nullable = ", field$nullable(), "\n", - sep = "") }) - , sep = "") + sapply(x$fields(), + function(field) { + paste("|-", "name = \"", field$name(), + "\", type = \"", field$dataType.toString(), + "\", nullable = ", field$nullable(), "\n", + sep = "") + }), + sep = "") } #' structField @@ -123,6 +126,7 @@ structField.character <- function(x, type, nullable = TRUE) { } options <- c("byte", "integer", + "float", "double", "numeric", "character", diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R index 78535eff0d2f6..311021e5d8473 100644 --- a/R/pkg/R/serialize.R +++ b/R/pkg/R/serialize.R @@ -140,8 +140,8 @@ writeType <- function(con, class) { jobj = "j", environment = "e", Date = "D", - POSIXlt = 't', - POSIXct = 't', + POSIXlt = "t", + POSIXct = "t", stop(paste("Unsupported type for serialization", class))) writeBin(charToRaw(type), con) } diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 048eb8ed541e4..79b79d70943cb 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -17,10 +17,6 @@ .sparkREnv <- new.env() -sparkR.onLoad <- function(libname, pkgname) { - .sparkREnv$libname <- libname -} - # Utility function that returns TRUE if we have an active connection to the # backend and FALSE otherwise connExists <- function(env) { @@ -80,7 +76,6 @@ sparkR.stop <- function() { #' @param sparkEnvir Named list of environment variables to set on worker nodes. #' @param sparkExecutorEnv Named list of environment variables to be used when launching executors. #' @param sparkJars Character string vector of jar files to pass to the worker nodes. -#' @param sparkRLibDir The path where R is installed on the worker nodes. 
#' @param sparkPackages Character string vector of packages from spark-packages.org #' @export #' @examples @@ -101,7 +96,6 @@ sparkR.init <- function( sparkEnvir = list(), sparkExecutorEnv = list(), sparkJars = "", - sparkRLibDir = "", sparkPackages = "") { if (exists(".sparkRjsc", envir = .sparkREnv)) { @@ -146,7 +140,7 @@ sparkR.init <- function( if (!file.exists(path)) { stop("JVM is not ready after 10 seconds") } - f <- file(path, open='rb') + f <- file(path, open="rb") backendPort <- readInt(f) monitorPort <- readInt(f) close(f) @@ -170,10 +164,6 @@ sparkR.init <- function( sparkHome <- normalizePath(sparkHome) } - if (nchar(sparkRLibDir) != 0) { - .sparkREnv$libname <- sparkRLibDir - } - sparkEnvirMap <- new.env() for (varname in names(sparkEnvir)) { sparkEnvirMap[[varname]] <- sparkEnvir[[varname]] diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index ea629a64f7158..3f45589a50443 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -41,8 +41,8 @@ convertJListToRList <- function(jList, flatten, logicalUpperBound = NULL, if (isInstanceOf(obj, "scala.Tuple2")) { # JavaPairRDD[Array[Byte], Array[Byte]]. - keyBytes = callJMethod(obj, "_1") - valBytes = callJMethod(obj, "_2") + keyBytes <- callJMethod(obj, "_1") + valBytes <- callJMethod(obj, "_2") res <- list(unserialize(keyBytes), unserialize(valBytes)) } else { @@ -390,14 +390,17 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) { for (i in 1:nodeLen) { processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) } - } else { # if node[[1]] is length of 1, check for some R special functions. + } else { + # if node[[1]] is length of 1, check for some R special functions. nodeChar <- as.character(node[[1]]) - if (nodeChar == "{" || nodeChar == "(") { # Skip start symbol. + if (nodeChar == "{" || nodeChar == "(") { + # Skip start symbol. for (i in 2:nodeLen) { processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) } } else if (nodeChar == "<-" || nodeChar == "=" || - nodeChar == "<<-") { # Assignment Ops. + nodeChar == "<<-") { + # Assignment Ops. defVar <- node[[2]] if (length(defVar) == 1 && typeof(defVar) == "symbol") { # Add the defined variable name into defVars. @@ -408,14 +411,16 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) { for (i in 3:nodeLen) { processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) } - } else if (nodeChar == "function") { # Function definition. + } else if (nodeChar == "function") { + # Function definition. # Add parameter names. newArgs <- names(node[[2]]) lapply(newArgs, function(arg) { addItemToAccumulator(defVars, arg) }) for (i in 3:nodeLen) { processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) } - } else if (nodeChar == "$") { # Skip the field. + } else if (nodeChar == "$") { + # Skip the field. processClosure(node[[2]], oldEnv, defVars, checkedFuncs, newEnv) } else if (nodeChar == "::" || nodeChar == ":::") { processClosure(node[[3]], oldEnv, defVars, checkedFuncs, newEnv) @@ -429,7 +434,8 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) { (typeof(node) == "symbol" || typeof(node) == "language")) { # Base case: current AST node is a leaf node and a symbol or a function call. nodeChar <- as.character(node) - if (!nodeChar %in% defVars$data) { # Not a function parameter or local variable. + if (!nodeChar %in% defVars$data) { + # Not a function parameter or local variable. 
func.env <- oldEnv topEnv <- parent.env(.GlobalEnv) # Search in function environment, and function's enclosing environments @@ -439,20 +445,24 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) { while (!identical(func.env, topEnv)) { # Namespaces other than "SparkR" will not be searched. if (!isNamespace(func.env) || - (getNamespaceName(func.env) == "SparkR" && - !(nodeChar %in% getNamespaceExports("SparkR")))) { # Only include SparkR internals. + (getNamespaceName(func.env) == "SparkR" && + !(nodeChar %in% getNamespaceExports("SparkR")))) { + # Only include SparkR internals. + # Set parameter 'inherits' to FALSE since we do not need to search in # attached package environments. if (tryCatch(exists(nodeChar, envir = func.env, inherits = FALSE), error = function(e) { FALSE })) { obj <- get(nodeChar, envir = func.env, inherits = FALSE) - if (is.function(obj)) { # If the node is a function call. + if (is.function(obj)) { + # If the node is a function call. funcList <- mget(nodeChar, envir = checkedFuncs, inherits = F, ifnotfound = list(list(NULL)))[[1]] found <- sapply(funcList, function(func) { ifelse(identical(func, obj), TRUE, FALSE) }) - if (sum(found) > 0) { # If function has been examined, ignore. + if (sum(found) > 0) { + # If function has been examined, ignore. break } # Function has not been examined, record it and recursively clean its closure. @@ -495,7 +505,8 @@ cleanClosure <- function(func, checkedFuncs = new.env()) { # environment. First, function's arguments are added to defVars. defVars <- initAccumulator() argNames <- names(as.list(args(func))) - for (i in 1:(length(argNames) - 1)) { # Remove the ending NULL in pairlist. + for (i in 1:(length(argNames) - 1)) { + # Remove the ending NULL in pairlist. addItemToAccumulator(defVars, argNames[i]) } # Recursively examine variables in the function body. diff --git a/R/pkg/R/zzz.R b/R/pkg/R/zzz.R deleted file mode 100644 index 301feade65fa3..0000000000000 --- a/R/pkg/R/zzz.R +++ /dev/null @@ -1,20 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -.onLoad <- function(libname, pkgname) { - sparkR.onLoad(libname, pkgname) -} diff --git a/R/pkg/inst/profile/general.R b/R/pkg/inst/profile/general.R index 8fe711b622086..2a8a8213d0849 100644 --- a/R/pkg/inst/profile/general.R +++ b/R/pkg/inst/profile/general.R @@ -16,7 +16,7 @@ # .First <- function() { - home <- Sys.getenv("SPARK_HOME") - .libPaths(c(file.path(home, "R", "lib"), .libPaths())) + packageDir <- Sys.getenv("SPARKR_PACKAGE_DIR") + .libPaths(c(packageDir, .libPaths())) Sys.setenv(NOAWT=1) } diff --git a/R/pkg/inst/tests/test_binaryFile.R b/R/pkg/inst/tests/test_binaryFile.R index ccaea18ecab2a..f2452ed97d2ea 100644 --- a/R/pkg/inst/tests/test_binaryFile.R +++ b/R/pkg/inst/tests/test_binaryFile.R @@ -20,7 +20,7 @@ context("functions on binary files") # JavaSparkContext handle sc <- sparkR.init() -mockFile = c("Spark is pretty.", "Spark is awesome.") +mockFile <- c("Spark is pretty.", "Spark is awesome.") test_that("saveAsObjectFile()/objectFile() following textFile() works", { fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") diff --git a/R/pkg/inst/tests/test_binary_function.R b/R/pkg/inst/tests/test_binary_function.R index 3be8c65a6c1a0..dca0657c57e0d 100644 --- a/R/pkg/inst/tests/test_binary_function.R +++ b/R/pkg/inst/tests/test_binary_function.R @@ -76,7 +76,7 @@ test_that("zipPartitions() on RDDs", { expect_equal(actual, list(list(1, c(1,2), c(1,2,3)), list(2, c(3,4), c(4,5,6)))) - mockFile = c("Spark is pretty.", "Spark is awesome.") + mockFile <- c("Spark is pretty.", "Spark is awesome.") fileName <- tempfile(pattern="spark-test", fileext=".tmp") writeLines(mockFile, fileName) diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R new file mode 100644 index 0000000000000..a492763344ae6 --- /dev/null +++ b/R/pkg/inst/tests/test_mllib.R @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +library(testthat) + +context("MLlib functions") + +# Tests for MLlib functions in SparkR + +sc <- sparkR.init() + +sqlContext <- sparkRSQL.init(sc) + +test_that("glm and predict", { + training <- createDataFrame(sqlContext, iris) + test <- select(training, "Sepal_Length") + model <- glm(Sepal_Width ~ Sepal_Length, training, family = "gaussian") + prediction <- predict(model, test) + expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") +}) + +test_that("predictions match with native glm", { + training <- createDataFrame(sqlContext, iris) + model <- glm(Sepal_Width ~ Sepal_Length, data = training) + vals <- collect(select(predict(model, training), "prediction")) + rVals <- predict(glm(Sepal.Width ~ Sepal.Length, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-9), rVals - vals) +}) diff --git a/R/pkg/inst/tests/test_rdd.R b/R/pkg/inst/tests/test_rdd.R index b79692873cec3..6c3aaab8c711e 100644 --- a/R/pkg/inst/tests/test_rdd.R +++ b/R/pkg/inst/tests/test_rdd.R @@ -447,7 +447,7 @@ test_that("zipRDD() on RDDs", { expect_equal(actual, list(list(0, 1000), list(1, 1001), list(2, 1002), list(3, 1003), list(4, 1004))) - mockFile = c("Spark is pretty.", "Spark is awesome.") + mockFile <- c("Spark is pretty.", "Spark is awesome.") fileName <- tempfile(pattern="spark-test", fileext=".tmp") writeLines(mockFile, fileName) @@ -483,7 +483,7 @@ test_that("cartesian() on RDDs", { actual <- collect(cartesian(rdd, emptyRdd)) expect_equal(actual, list()) - mockFile = c("Spark is pretty.", "Spark is awesome.") + mockFile <- c("Spark is pretty.", "Spark is awesome.") fileName <- tempfile(pattern="spark-test", fileext=".tmp") writeLines(mockFile, fileName) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index b0ea38854304e..62fe48a5d6c7b 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -57,9 +57,9 @@ test_that("infer types", { expect_equal(infer_type(as.Date("2015-03-11")), "date") expect_equal(infer_type(as.POSIXlt("2015-03-11 12:13:04.043")), "timestamp") expect_equal(infer_type(c(1L, 2L)), - list(type = 'array', elementType = "integer", containsNull = TRUE)) + list(type = "array", elementType = "integer", containsNull = TRUE)) expect_equal(infer_type(list(1L, 2L)), - list(type = 'array', elementType = "integer", containsNull = TRUE)) + list(type = "array", elementType = "integer", containsNull = TRUE)) testStruct <- infer_type(list(a = 1L, b = "2")) expect_equal(class(testStruct), "structType") checkStructField(testStruct$fields()[[1]], "a", "IntegerType", TRUE) @@ -108,6 +108,32 @@ test_that("create DataFrame from RDD", { expect_equal(count(df), 10) expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + + df <- jsonFile(sqlContext, jsonPathNa) + hiveCtx <- tryCatch({ + newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc) + }, error = function(err) { + skip("Hive is not build with SparkSQL, skipped") + }) + sql(hiveCtx, "CREATE TABLE people (name string, age double, height float)") + insertInto(df, "people") + expect_equal(sql(hiveCtx, "SELECT age from people WHERE name = 'Bob'"), c(16)) + expect_equal(sql(hiveCtx, "SELECT height from people WHERE name ='Bob'"), c(176.5)) + + schema <- structType(structField("name", "string"), structField("age", "integer"), + structField("height", "float")) + df2 <- createDataFrame(sqlContext, df.toRDD, schema) + expect_equal(columns(df2), c("name", "age", "height")) + expect_equal(dtypes(df2), 
list(c("name", "string"), c("age", "int"), c("height", "float"))) + expect_equal(collect(where(df2, df2$name == "Bob")), c("Bob", 16, 176.5)) + + localDF <- data.frame(name=c("John", "Smith", "Sarah"), age=c(19, 23, 18), height=c(164.10, 181.4, 173.7)) + df <- createDataFrame(sqlContext, localDF, schema) + expect_is(df, "DataFrame") + expect_equal(count(df), 3) + expect_equal(columns(df), c("name", "age", "height")) + expect_equal(dtypes(df), list(c("name", "string"), c("age", "int"), c("height", "float"))) + expect_equal(collect(where(df, df$name == "John")), c("John", 19, 164.10)) }) test_that("convert NAs to null type in DataFrames", { @@ -612,6 +638,18 @@ test_that("column functions", { c7 <- floor(c) + log(c) + log10(c) + log1p(c) + rint(c) c8 <- sign(c) + sin(c) + sinh(c) + tan(c) + tanh(c) c9 <- toDegrees(c) + toRadians(c) + + df <- jsonFile(sqlContext, jsonPath) + df2 <- select(df, between(df$age, c(20, 30)), between(df$age, c(10, 20))) + expect_equal(collect(df2)[[2, 1]], TRUE) + expect_equal(collect(df2)[[2, 2]], FALSE) + expect_equal(collect(df2)[[3, 1]], FALSE) + expect_equal(collect(df2)[[3, 2]], TRUE) + + df3 <- select(df, between(df$name, c("Apache", "Spark"))) + expect_equal(collect(df3)[[1, 1]], TRUE) + expect_equal(collect(df3)[[2, 1]], FALSE) + expect_equal(collect(df3)[[3, 1]], TRUE) }) test_that("column binary mathfunctions", { @@ -949,6 +987,19 @@ test_that("fillna() on a DataFrame", { expect_identical(expected, actual) }) +test_that("crosstab() on a DataFrame", { + rdd <- lapply(parallelize(sc, 0:3), function(x) { + list(paste0("a", x %% 3), paste0("b", x %% 2)) + }) + df <- toDF(rdd, list("a", "b")) + ct <- crosstab(df, "a", "b") + ordered <- ct[order(ct$a_b),] + row.names(ordered) <- NULL + expected <- data.frame("a_b" = c("a0", "a1", "a2"), "b0" = c(1, 0, 1), "b1" = c(1, 1, 0), + stringsAsFactors = FALSE, row.names = NULL) + expect_identical(expected, ordered) +}) + unlink(parquetPath) unlink(jsonPath) unlink(jsonPathNa) diff --git a/R/pkg/inst/tests/test_textFile.R b/R/pkg/inst/tests/test_textFile.R index 58318dfef71ab..a9cf83dbdbdb1 100644 --- a/R/pkg/inst/tests/test_textFile.R +++ b/R/pkg/inst/tests/test_textFile.R @@ -20,7 +20,7 @@ context("the textFile() function") # JavaSparkContext handle sc <- sparkR.init() -mockFile = c("Spark is pretty.", "Spark is awesome.") +mockFile <- c("Spark is pretty.", "Spark is awesome.") test_that("textFile() on a local file returns an RDD", { fileName <- tempfile(pattern="spark-test", fileext=".tmp") diff --git a/R/pkg/inst/tests/test_utils.R b/R/pkg/inst/tests/test_utils.R index aa0d2a66b9082..12df4cf4f65b7 100644 --- a/R/pkg/inst/tests/test_utils.R +++ b/R/pkg/inst/tests/test_utils.R @@ -119,7 +119,7 @@ test_that("cleanClosure on R functions", { # Test for overriding variables in base namespace (Issue: SparkR-196). nums <- as.list(1:10) rdd <- parallelize(sc, nums, 2L) - t = 4 # Override base::t in .GlobalEnv. + t <- 4 # Override base::t in .GlobalEnv. f <- function(x) { x > t } newF <- cleanClosure(f) env <- environment(newF) diff --git a/bin/spark-shell b/bin/spark-shell index a6dc863d83fc6..00ab7afd118b5 100755 --- a/bin/spark-shell +++ b/bin/spark-shell @@ -47,11 +47,11 @@ function main() { # (see https://github.com/sbt/sbt/issues/562). 
stty -icanon min 1 -echo > /dev/null 2>&1 export SPARK_SUBMIT_OPTS="$SPARK_SUBMIT_OPTS -Djline.terminal=unix" - "$FWDIR"/bin/spark-submit --class org.apache.spark.repl.Main "$@" + "$FWDIR"/bin/spark-submit --class org.apache.spark.repl.Main --name "Spark shell" "$@" stty icanon echo > /dev/null 2>&1 else export SPARK_SUBMIT_OPTS - "$FWDIR"/bin/spark-submit --class org.apache.spark.repl.Main "$@" + "$FWDIR"/bin/spark-submit --class org.apache.spark.repl.Main --name "Spark shell" "$@" fi } diff --git a/bin/spark-shell2.cmd b/bin/spark-shell2.cmd index 251309d67f860..b9b0f510d7f5d 100644 --- a/bin/spark-shell2.cmd +++ b/bin/spark-shell2.cmd @@ -32,4 +32,4 @@ if "x%SPARK_SUBMIT_OPTS%"=="x" ( set SPARK_SUBMIT_OPTS="%SPARK_SUBMIT_OPTS% -Dscala.usejavacp=true" :run_shell -%SPARK_HOME%\bin\spark-submit2.cmd --class org.apache.spark.repl.Main %* +%SPARK_HOME%\bin\spark-submit2.cmd --class org.apache.spark.repl.Main --name "Spark shell" %* diff --git a/build/mvn b/build/mvn index e8364181e8230..f62f61ee1c416 100755 --- a/build/mvn +++ b/build/mvn @@ -112,10 +112,17 @@ install_scala() { # the environment ZINC_PORT=${ZINC_PORT:-"3030"} +# Check for the `--force` flag dictating that `mvn` should be downloaded +# regardless of whether the system already has a `mvn` install +if [ "$1" == "--force" ]; then + FORCE_MVN=1 + shift +fi + # Install Maven if necessary MVN_BIN="$(command -v mvn)" -if [ ! "$MVN_BIN" ]; then +if [ ! "$MVN_BIN" -o -n "$FORCE_MVN" ]; then install_mvn fi @@ -139,5 +146,7 @@ fi # Set any `mvn` options if not already present export MAVEN_OPTS=${MAVEN_OPTS:-"$_COMPILE_JVM_OPTS"} +echo "Using \`mvn\` from path: $MVN_BIN" + # Last, call the `mvn` command as usual ${MVN_BIN} "$@" diff --git a/build/sbt-launch-lib.bash b/build/sbt-launch-lib.bash index 504be48b358fa..7930a38b9674a 100755 --- a/build/sbt-launch-lib.bash +++ b/build/sbt-launch-lib.bash @@ -51,9 +51,13 @@ acquire_sbt_jar () { printf "Attempting to fetch sbt\n" JAR_DL="${JAR}.part" if [ $(command -v curl) ]; then - (curl --silent ${URL1} > "${JAR_DL}" || curl --silent ${URL2} > "${JAR_DL}") && mv "${JAR_DL}" "${JAR}" + (curl --fail --location --silent ${URL1} > "${JAR_DL}" ||\ + (rm -f "${JAR_DL}" && curl --fail --location --silent ${URL2} > "${JAR_DL}")) &&\ + mv "${JAR_DL}" "${JAR}" elif [ $(command -v wget) ]; then - (wget --quiet ${URL1} -O "${JAR_DL}" || wget --quiet ${URL2} -O "${JAR_DL}") && mv "${JAR_DL}" "${JAR}" + (wget --quiet ${URL1} -O "${JAR_DL}" ||\ + (rm -f "${JAR_DL}" && wget --quiet ${URL2} -O "${JAR_DL}")) &&\ + mv "${JAR_DL}" "${JAR}" else printf "You do not have curl or wget installed, please install sbt manually from http://www.scala-sbt.org/\n" exit -1 diff --git a/core/pom.xml b/core/pom.xml index 558cc3fb9f2f3..95f36eb348698 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -261,7 +261,7 @@ com.fasterxml.jackson.module - jackson-module-scala_2.10 + jackson-module-scala_${scala.binary.version} org.apache.derby @@ -372,6 +372,11 @@ junit-interface test + + org.apache.curator + curator-test + test + net.razorvine pyrolite diff --git a/core/src/main/java/org/apache/spark/JavaSparkListener.java b/core/src/main/java/org/apache/spark/JavaSparkListener.java index 646496f313507..fa9acf0a15b88 100644 --- a/core/src/main/java/org/apache/spark/JavaSparkListener.java +++ b/core/src/main/java/org/apache/spark/JavaSparkListener.java @@ -17,23 +17,7 @@ package org.apache.spark; -import org.apache.spark.scheduler.SparkListener; -import org.apache.spark.scheduler.SparkListenerApplicationEnd; -import 
org.apache.spark.scheduler.SparkListenerApplicationStart; -import org.apache.spark.scheduler.SparkListenerBlockManagerAdded; -import org.apache.spark.scheduler.SparkListenerBlockManagerRemoved; -import org.apache.spark.scheduler.SparkListenerEnvironmentUpdate; -import org.apache.spark.scheduler.SparkListenerExecutorAdded; -import org.apache.spark.scheduler.SparkListenerExecutorMetricsUpdate; -import org.apache.spark.scheduler.SparkListenerExecutorRemoved; -import org.apache.spark.scheduler.SparkListenerJobEnd; -import org.apache.spark.scheduler.SparkListenerJobStart; -import org.apache.spark.scheduler.SparkListenerStageCompleted; -import org.apache.spark.scheduler.SparkListenerStageSubmitted; -import org.apache.spark.scheduler.SparkListenerTaskEnd; -import org.apache.spark.scheduler.SparkListenerTaskGettingResult; -import org.apache.spark.scheduler.SparkListenerTaskStart; -import org.apache.spark.scheduler.SparkListenerUnpersistRDD; +import org.apache.spark.scheduler.*; /** * Java clients should extend this class instead of implementing @@ -94,4 +78,8 @@ public void onExecutorAdded(SparkListenerExecutorAdded executorAdded) { } @Override public void onExecutorRemoved(SparkListenerExecutorRemoved executorRemoved) { } + + @Override + public void onBlockUpdated(SparkListenerBlockUpdated blockUpdated) { } + } diff --git a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java index fbc5666959055..1214d05ba6063 100644 --- a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java +++ b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java @@ -112,4 +112,10 @@ public final void onExecutorAdded(SparkListenerExecutorAdded executorAdded) { public final void onExecutorRemoved(SparkListenerExecutorRemoved executorRemoved) { onEvent(executorRemoved); } + + @Override + public void onBlockUpdated(SparkListenerBlockUpdated blockUpdated) { + onEvent(blockUpdated); + } + } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index d3d6280284beb..0b8b604e18494 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -75,7 +75,7 @@ final class BypassMergeSortShuffleWriter implements SortShuffleFileWriter< private final Serializer serializer; /** Array of file writers, one for each partition */ - private BlockObjectWriter[] partitionWriters; + private DiskBlockObjectWriter[] partitionWriters; public BypassMergeSortShuffleWriter( SparkConf conf, @@ -101,7 +101,7 @@ public void insertAll(Iterator> records) throws IOException { } final SerializerInstance serInstance = serializer.newInstance(); final long openStartTime = System.nanoTime(); - partitionWriters = new BlockObjectWriter[numPartitions]; + partitionWriters = new DiskBlockObjectWriter[numPartitions]; for (int i = 0; i < numPartitions; i++) { final Tuple2 tempShuffleBlockIdPlusFile = blockManager.diskBlockManager().createTempShuffleBlock(); @@ -121,7 +121,7 @@ public void insertAll(Iterator> records) throws IOException { partitionWriters[partitioner.getPartition(key)].write(key, record._2()); } - for (BlockObjectWriter writer : partitionWriters) { + for (DiskBlockObjectWriter writer : partitionWriters) { writer.commitAndClose(); } } @@ -169,7 +169,7 @@ public void stop() throws IOException { if (partitionWriters != 
null) { try { final DiskBlockManager diskBlockManager = blockManager.diskBlockManager(); - for (BlockObjectWriter writer : partitionWriters) { + for (DiskBlockObjectWriter writer : partitionWriters) { // This method explicitly does _not_ throw exceptions: writer.revertPartialWritesAndClose(); if (!diskBlockManager.getFile(writer.blockId()).delete()) { diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java index 56289573209fb..1d460432be9ff 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java @@ -157,7 +157,7 @@ private void writeSortedFile(boolean isLastFile) throws IOException { // Currently, we need to open a new DiskBlockObjectWriter for each partition; we can avoid this // after SPARK-5581 is fixed. - BlockObjectWriter writer; + DiskBlockObjectWriter writer; // Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to // be an API to directly transfer bytes from managed memory to the disk writer, we buffer diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java index 438742565c51d..bf1bc5dffba78 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java @@ -23,6 +23,7 @@ import org.apache.spark.annotation.Private; import org.apache.spark.unsafe.types.UTF8String; +import org.apache.spark.util.Utils; @Private public class PrefixComparators { @@ -82,7 +83,7 @@ public static final class FloatPrefixComparator extends PrefixComparator { public int compare(long aPrefix, long bPrefix) { float a = Float.intBitsToFloat((int) aPrefix); float b = Float.intBitsToFloat((int) bPrefix); - return (a < b) ? -1 : (a > b) ? 1 : 0; + return Utils.nanSafeCompareFloats(a, b); } public long computePrefix(float value) { @@ -97,7 +98,7 @@ public static final class DoublePrefixComparator extends PrefixComparator { public int compare(long aPrefix, long bPrefix) { double a = Double.longBitsToDouble(aPrefix); double b = Double.longBitsToDouble(bPrefix); - return (a < b) ? -1 : (a > b) ? 
1 : 0; + return Utils.nanSafeCompareDoubles(a, b); } public long computePrefix(double value) { diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java index b8d66659804ad..71eed29563d4a 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java @@ -26,7 +26,7 @@ import org.apache.spark.serializer.DummySerializerInstance; import org.apache.spark.storage.BlockId; import org.apache.spark.storage.BlockManager; -import org.apache.spark.storage.BlockObjectWriter; +import org.apache.spark.storage.DiskBlockObjectWriter; import org.apache.spark.storage.TempLocalBlockId; import org.apache.spark.unsafe.PlatformDependent; @@ -47,7 +47,7 @@ final class UnsafeSorterSpillWriter { private final File file; private final BlockId blockId; private final int numRecordsToWrite; - private BlockObjectWriter writer; + private DiskBlockObjectWriter writer; private int numRecordsSpilled = 0; public UnsafeSorterSpillWriter( diff --git a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js index 0b450dc76bc38..3c8ddddf07b1e 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js +++ b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js @@ -19,6 +19,9 @@ * to be registered after the page loads. */ $(function() { $("span.expand-additional-metrics").click(function(){ + var status = window.localStorage.getItem("expand-additional-metrics") == "true"; + status = !status; + // Expand the list of additional metrics. var additionalMetricsDiv = $(this).parent().find('.additional-metrics'); $(additionalMetricsDiv).toggleClass('collapsed'); @@ -26,17 +29,31 @@ $(function() { // Switch the class of the arrow from open to closed. $(this).find('.expand-additional-metrics-arrow').toggleClass('arrow-open'); $(this).find('.expand-additional-metrics-arrow').toggleClass('arrow-closed'); + + window.localStorage.setItem("expand-additional-metrics", "" + status); }); + if (window.localStorage.getItem("expand-additional-metrics") == "true") { + // Set it to false so that the click function can revert it + window.localStorage.setItem("expand-additional-metrics", "false"); + $("span.expand-additional-metrics").trigger("click"); + } + stripeSummaryTable(); $('input[type="checkbox"]').click(function() { - var column = "table ." + $(this).attr("name"); + var name = $(this).attr("name") + var column = "table ." + name; + var status = window.localStorage.getItem(name) == "true"; + status = !status; $(column).toggle(); stripeSummaryTable(); + window.localStorage.setItem(name, "" + status); }); $("#select-all-metrics").click(function() { + var status = window.localStorage.getItem("select-all-metrics") == "true"; + status = !status; if (this.checked) { // Toggle all un-checked options. $('input[type="checkbox"]:not(:checked)').trigger('click'); @@ -44,6 +61,21 @@ $(function() { // Toggle all checked options. 
$('input[type="checkbox"]:checked').trigger('click'); } + window.localStorage.setItem("select-all-metrics", "" + status); + }); + + if (window.localStorage.getItem("select-all-metrics") == "true") { + $("#select-all-metrics").attr('checked', status); + } + + $("span.additional-metric-title").parent().find('input[type="checkbox"]').each(function() { + var name = $(this).attr("name") + // If name is undefined, then skip it because it's the "select-all-metrics" checkbox + if (name && window.localStorage.getItem(name) == "true") { + // Set it to false so that the click function can revert it + window.localStorage.setItem(name, "false"); + $(this).trigger("click") + } }); // Trigger a click on the checkbox if a user clicks the label next to it. diff --git a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js index 9fa53baaf4212..4a893bc0189aa 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js +++ b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js @@ -72,6 +72,14 @@ var StagePageVizConstants = { rankSep: 40 }; +/* + * Return "expand-dag-viz-arrow-job" if forJob is true. + * Otherwise, return "expand-dag-viz-arrow-stage". + */ +function expandDagVizArrowKey(forJob) { + return forJob ? "expand-dag-viz-arrow-job" : "expand-dag-viz-arrow-stage"; +} + /* * Show or hide the RDD DAG visualization. * @@ -79,6 +87,9 @@ var StagePageVizConstants = { * This is the narrow interface called from the Scala UI code. */ function toggleDagViz(forJob) { + var status = window.localStorage.getItem(expandDagVizArrowKey(forJob)) == "true"; + status = !status; + var arrowSelector = ".expand-dag-viz-arrow"; $(arrowSelector).toggleClass('arrow-closed'); $(arrowSelector).toggleClass('arrow-open'); @@ -93,8 +104,24 @@ function toggleDagViz(forJob) { // Save the graph for later so we don't have to render it again graphContainer().style("display", "none"); } + + window.localStorage.setItem(expandDagVizArrowKey(forJob), "" + status); } +$(function (){ + if (window.localStorage.getItem(expandDagVizArrowKey(false)) == "true") { + // Set it to false so that the click function can revert it + window.localStorage.setItem(expandDagVizArrowKey(false), "false"); + toggleDagViz(false); + } + + if (window.localStorage.getItem(expandDagVizArrowKey(true)) == "true") { + // Set it to false so that the click function can revert it + window.localStorage.setItem(expandDagVizArrowKey(true), "false"); + toggleDagViz(true); + } +}); + /* * Render the RDD DAG visualization. * diff --git a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js index ca74ef9d7e94e..f4453c71df1ea 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js +++ b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js @@ -66,14 +66,27 @@ function drawApplicationTimeline(groupArray, eventObjArray, startTime) { setupJobEventAction(); $("span.expand-application-timeline").click(function() { + var status = window.localStorage.getItem("expand-application-timeline") == "true"; + status = !status; + $("#application-timeline").toggleClass('collapsed'); // Switch the class of the arrow from open to closed. 
$(this).find('.expand-application-timeline-arrow').toggleClass('arrow-open'); $(this).find('.expand-application-timeline-arrow').toggleClass('arrow-closed'); + + window.localStorage.setItem("expand-application-timeline", "" + status); }); } +$(function (){ + if (window.localStorage.getItem("expand-application-timeline") == "true") { + // Set it to false so that the click function can revert it + window.localStorage.setItem("expand-application-timeline", "false"); + $("span.expand-application-timeline").trigger('click'); + } +}); + function drawJobTimeline(groupArray, eventObjArray, startTime) { var groups = new vis.DataSet(groupArray); var items = new vis.DataSet(eventObjArray); @@ -125,14 +138,27 @@ function drawJobTimeline(groupArray, eventObjArray, startTime) { setupStageEventAction(); $("span.expand-job-timeline").click(function() { + var status = window.localStorage.getItem("expand-job-timeline") == "true"; + status = !status; + $("#job-timeline").toggleClass('collapsed'); // Switch the class of the arrow from open to closed. $(this).find('.expand-job-timeline-arrow').toggleClass('arrow-open'); $(this).find('.expand-job-timeline-arrow').toggleClass('arrow-closed'); + + window.localStorage.setItem("expand-job-timeline", "" + status); }); } +$(function (){ + if (window.localStorage.getItem("expand-job-timeline") == "true") { + // Set it to false so that the click function can revert it + window.localStorage.setItem("expand-job-timeline", "false"); + $("span.expand-job-timeline").trigger('click'); + } +}); + function drawTaskAssignmentTimeline(groupArray, eventObjArray, minLaunchTime, maxFinishTime) { var groups = new vis.DataSet(groupArray); var items = new vis.DataSet(eventObjArray); @@ -176,14 +202,27 @@ function drawTaskAssignmentTimeline(groupArray, eventObjArray, minLaunchTime, ma setupZoomable("#task-assignment-timeline-zoom-lock", taskTimeline); $("span.expand-task-assignment-timeline").click(function() { + var status = window.localStorage.getItem("expand-task-assignment-timeline") == "true"; + status = !status; + $("#task-assignment-timeline").toggleClass("collapsed"); // Switch the class of the arrow from open to closed. 
$(this).find(".expand-task-assignment-timeline-arrow").toggleClass("arrow-open"); $(this).find(".expand-task-assignment-timeline-arrow").toggleClass("arrow-closed"); + + window.localStorage.setItem("expand-task-assignment-timeline", "" + status); }); } +$(function (){ + if (window.localStorage.getItem("expand-task-assignment-timeline") == "true") { + // Set it to false so that the click function can revert it + window.localStorage.setItem("expand-task-assignment-timeline", "false"); + $("span.expand-task-assignment-timeline").trigger('click'); + } +}); + function setupExecutorEventAction() { $(".item.box.executor").each(function () { $(this).hover( diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala index 5a8d17bd99933..2f4fcac890eef 100644 --- a/core/src/main/scala/org/apache/spark/Accumulators.scala +++ b/core/src/main/scala/org/apache/spark/Accumulators.scala @@ -20,7 +20,8 @@ package org.apache.spark import java.io.{ObjectInputStream, Serializable} import scala.collection.generic.Growable -import scala.collection.mutable.Map +import scala.collection.Map +import scala.collection.mutable import scala.ref.WeakReference import scala.reflect.ClassTag @@ -39,25 +40,44 @@ import org.apache.spark.util.Utils * @param initialValue initial value of accumulator * @param param helper object defining how to add elements of type `R` and `T` * @param name human-readable name for use in Spark's web UI + * @param internal if this [[Accumulable]] is internal. Internal [[Accumulable]]s will be reported + * to the driver via heartbeats. For internal [[Accumulable]]s, `R` must be + * thread safe so that they can be reported correctly. * @tparam R the full accumulated data (result type) * @tparam T partial data that can be added in */ -class Accumulable[R, T] ( +class Accumulable[R, T] private[spark] ( @transient initialValue: R, param: AccumulableParam[R, T], - val name: Option[String]) + val name: Option[String], + internal: Boolean) extends Serializable { + private[spark] def this( + @transient initialValue: R, param: AccumulableParam[R, T], internal: Boolean) = { + this(initialValue, param, None, internal) + } + + def this(@transient initialValue: R, param: AccumulableParam[R, T], name: Option[String]) = + this(initialValue, param, name, false) + def this(@transient initialValue: R, param: AccumulableParam[R, T]) = this(initialValue, param, None) val id: Long = Accumulators.newId - @transient private var value_ = initialValue // Current value on master + @volatile @transient private var value_ : R = initialValue // Current value on master val zero = param.zero(initialValue) // Zero value to be passed to workers private var deserialized = false - Accumulators.register(this, true) + Accumulators.register(this) + + /** + * If this [[Accumulable]] is internal. Internal [[Accumulable]]s will be reported to the driver + * via heartbeats. For internal [[Accumulable]]s, `R` must be thread safe so that they can be + * reported correctly. 
+ */ + private[spark] def isInternal: Boolean = internal /** * Add more data to this accumulator / accumulable @@ -132,7 +152,8 @@ class Accumulable[R, T] ( in.defaultReadObject() value_ = zero deserialized = true - Accumulators.register(this, false) + val taskContext = TaskContext.get() + taskContext.registerAccumulator(this) } override def toString: String = if (value_ == null) "null" else value_.toString @@ -284,16 +305,7 @@ private[spark] object Accumulators extends Logging { * It keeps weak references to these objects so that accumulators can be garbage-collected * once the RDDs and user-code that reference them are cleaned up. */ - val originals = Map[Long, WeakReference[Accumulable[_, _]]]() - - /** - * This thread-local map holds per-task copies of accumulators; it is used to collect the set - * of accumulator updates to send back to the driver when tasks complete. After tasks complete, - * this map is cleared by `Accumulators.clear()` (see Executor.scala). - */ - private val localAccums = new ThreadLocal[Map[Long, Accumulable[_, _]]]() { - override protected def initialValue() = Map[Long, Accumulable[_, _]]() - } + val originals = mutable.Map[Long, WeakReference[Accumulable[_, _]]]() private var lastId: Long = 0 @@ -302,19 +314,8 @@ private[spark] object Accumulators extends Logging { lastId } - def register(a: Accumulable[_, _], original: Boolean): Unit = synchronized { - if (original) { - originals(a.id) = new WeakReference[Accumulable[_, _]](a) - } else { - localAccums.get()(a.id) = a - } - } - - // Clear the local (non-original) accumulators for the current thread - def clear() { - synchronized { - localAccums.get.clear() - } + def register(a: Accumulable[_, _]): Unit = synchronized { + originals(a.id) = new WeakReference[Accumulable[_, _]](a) } def remove(accId: Long) { @@ -323,15 +324,6 @@ private[spark] object Accumulators extends Logging { } } - // Get the values of the local accumulators for the current thread (by ID) - def values: Map[Long, Any] = synchronized { - val ret = Map[Long, Any]() - for ((id, accum) <- localAccums.get) { - ret(id) = accum.localValue - } - return ret - } - // Add values to the original accumulators with some given IDs def add(values: Map[Long, Any]): Unit = synchronized { for ((id, value) <- values) { diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 0c50b4002cf7b..648bcfe28cad2 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -20,6 +20,7 @@ package org.apache.spark import java.util.concurrent.TimeUnit import scala.collection.mutable +import scala.util.control.ControlThrowable import com.codahale.metrics.{Gauge, MetricRegistry} @@ -211,7 +212,16 @@ private[spark] class ExecutorAllocationManager( listenerBus.addListener(listener) val scheduleTask = new Runnable() { - override def run(): Unit = Utils.logUncaughtExceptions(schedule()) + override def run(): Unit = { + try { + schedule() + } catch { + case ct: ControlThrowable => + throw ct + case t: Throwable => + logWarning(s"Uncaught exception in thread ${Thread.currentThread().getName}", t) + } + } } executor.scheduleAtFixedRate(scheduleTask, 0, intervalMillis, TimeUnit.MILLISECONDS) } diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index 221b1dab43278..43dd4a170731d 100644 --- 
a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -181,7 +181,9 @@ private[spark] class HeartbeatReceiver(sc: SparkContext, clock: Clock) // Asynchronously kill the executor to avoid blocking the current thread killExecutorThread.submit(new Runnable { override def run(): Unit = Utils.tryLogNonFatalError { - sc.killExecutor(executorId) + // Note: we want to get an executor back after expiring this one, + // so do not simply call `sc.killExecutor` here (SPARK-8119) + sc.killAndReplaceExecutor(executorId) } }) } diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala index 87ab099267b2f..f0598816d6c07 100644 --- a/core/src/main/scala/org/apache/spark/Logging.scala +++ b/core/src/main/scala/org/apache/spark/Logging.scala @@ -159,7 +159,7 @@ private object Logging { try { // We use reflection here to handle the case where users remove the // slf4j-to-jul bridge order to route their logs to JUL. - val bridgeClass = Class.forName("org.slf4j.bridge.SLF4JBridgeHandler") + val bridgeClass = Utils.classForName("org.slf4j.bridge.SLF4JBridgeHandler") bridgeClass.getMethod("removeHandlersForRootLogger").invoke(null) val installed = bridgeClass.getMethod("isInstalled").invoke(null).asInstanceOf[Boolean] if (!installed) { diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 862ffe868f58f..92218832d256f 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -21,14 +21,14 @@ import java.io._ import java.util.concurrent.ConcurrentHashMap import java.util.zip.{GZIPInputStream, GZIPOutputStream} -import scala.collection.mutable.{HashMap, HashSet, Map} +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} import scala.collection.JavaConversions._ import scala.reflect.ClassTag import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, RpcCallContext, RpcEndpoint} import org.apache.spark.scheduler.MapStatus import org.apache.spark.shuffle.MetadataFetchFailedException -import org.apache.spark.storage.BlockManagerId +import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId} import org.apache.spark.util._ private[spark] sealed trait MapOutputTrackerMessage @@ -124,10 +124,18 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging } /** - * Called from executors to get the server URIs and output sizes of the map outputs of - * a given shuffle. + * Called from executors to get the server URIs and output sizes for each shuffle block that + * needs to be read from a given reduce task. + * + * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId, + * and the second item is a sequence of (shuffle block id, shuffle block size) tuples + * describing the shuffle blocks that are stored at that block manager. 
*/ - def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = { + def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int) + : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + logDebug(s"Fetching outputs for shuffle $shuffleId, reduce $reduceId") + val startTime = System.currentTimeMillis + val statuses = mapStatuses.get(shuffleId).orNull if (statuses == null) { logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them") @@ -167,6 +175,9 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging } } } + logDebug(s"Fetching map output location for shuffle $shuffleId, reduce $reduceId took " + + s"${System.currentTimeMillis - startTime} ms") + if (fetchedStatuses != null) { fetchedStatuses.synchronized { return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses) @@ -421,23 +432,38 @@ private[spark] object MapOutputTracker extends Logging { } } - // Convert an array of MapStatuses to locations and sizes for a given reduce ID. If - // any of the statuses is null (indicating a missing location due to a failed mapper), - // throw a FetchFailedException. + /** + * Converts an array of MapStatuses for a given reduce ID to a sequence that, for each block + * manager ID, lists the shuffle block ids and corresponding shuffle block sizes stored at that + * block manager. + * + * If any of the statuses is null (indicating a missing location due to a failed mapper), + * throws a FetchFailedException. + * + * @param shuffleId Identifier for the shuffle + * @param reduceId Identifier for the reduce task + * @param statuses List of map statuses, indexed by map ID. + * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId, + * and the second item is a sequence of (shuffle block id, shuffle block size) tuples + * describing the shuffle blocks that are stored at that block manager. + */ private def convertMapStatuses( shuffleId: Int, reduceId: Int, - statuses: Array[MapStatus]): Array[(BlockManagerId, Long)] = { + statuses: Array[MapStatus]): Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { assert (statuses != null) - statuses.map { - status => - if (status == null) { - logError("Missing an output location for shuffle " + shuffleId) - throw new MetadataFetchFailedException( - shuffleId, reduceId, "Missing an output location for shuffle " + shuffleId) - } else { - (status.location, status.getSizeForBlock(reduceId)) - } + val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(BlockId, Long)]] + for ((status, mapId) <- statuses.zipWithIndex) { + if (status == null) { + val errorMessage = s"Missing an output location for shuffle $shuffleId" + logError(errorMessage) + throw new MetadataFetchFailedException(shuffleId, reduceId, errorMessage) + } else { + splitsByAddress.getOrElseUpdate(status.location, ArrayBuffer()) += + ((ShuffleBlockId(shuffleId, mapId, reduceId), status.getSizeForBlock(reduceId))) + } } + + splitsByAddress.toSeq } } diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index 82889bcd30988..ad68512dccb79 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -76,6 +76,8 @@ object Partitioner { * produce an unexpected or incorrect result. 
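As a quick illustration of the stricter contract the HashPartitioner change that follows introduces (a sketch, not part of the patch): a negative partition count now fails fast at construction instead of surfacing later as a confusing modulo result.

    import org.apache.spark.HashPartitioner

    val partitioner = new HashPartitioner(4)
    partitioner.getPartition("some key")   // a stable value in [0, 4)

    // Hypothetical failure case: this would now throw IllegalArgumentException
    // ("Number of partitions (-1) cannot be negative.") at construction time.
    // new HashPartitioner(-1)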
*/ class HashPartitioner(partitions: Int) extends Partitioner { + require(partitions >= 0, s"Number of partitions ($partitions) cannot be negative.") + def numPartitions: Int = partitions def getPartition(key: Any): Int = key match { diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 82704b1ab2189..6a6b94a271cfc 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -471,7 +471,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli .orElse(Option(System.getenv("SPARK_MEM")) .map(warnSparkMem)) .map(Utils.memoryStringToMb) - .getOrElse(512) + .getOrElse(1024) // Convert java options to env vars as a work around // since we can't set env vars directly in sbt. @@ -1419,6 +1419,12 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli /** * :: DeveloperApi :: * Request that the cluster manager kill the specified executors. + * + * Note: This is an indication to the cluster manager that the application wishes to adjust + * its resource usage downwards. If the application wishes to replace the executors it kills + * through this method with new ones, it should follow up explicitly with a call to + * {{SparkContext#requestExecutors}}. + * * This is currently only supported in YARN mode. Return whether the request is received. */ @DeveloperApi @@ -1436,12 +1442,42 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli /** * :: DeveloperApi :: - * Request that cluster manager the kill the specified executor. - * This is currently only supported in Yarn mode. Return whether the request is received. + * Request that the cluster manager kill the specified executor. + * + * Note: This is an indication to the cluster manager that the application wishes to adjust + * its resource usage downwards. If the application wishes to replace the executor it kills + * through this method with a new one, it should follow up explicitly with a call to + * {{SparkContext#requestExecutors}}. + * + * This is currently only supported in YARN mode. Return whether the request is received. */ @DeveloperApi override def killExecutor(executorId: String): Boolean = super.killExecutor(executorId) + /** + * Request that the cluster manager kill the specified executor without adjusting the + * application resource requirements. + * + * The effect is that a new executor will be launched in place of the one killed by + * this request. This assumes the cluster manager will automatically and eventually + * fulfill all missing application resource requests. + * + * Note: The replace is by no means guaranteed; another application on the same cluster + * can steal the window of opportunity and acquire this application's resources in the + * mean time. + * + * This is currently only supported in YARN mode. Return whether the request is received. + */ + private[spark] def killAndReplaceExecutor(executorId: String): Boolean = { + schedulerBackend match { + case b: CoarseGrainedSchedulerBackend => + b.killExecutors(Seq(executorId), replace = true) + case _ => + logWarning("Killing executors is only supported in coarse-grained mode") + false + } + } + /** The version of Spark on which this application is running. 
*/ def version: String = SPARK_VERSION @@ -1722,16 +1758,13 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli /** * Run a function on a given set of partitions in an RDD and pass the results to the given - * handler function. This is the main entry point for all actions in Spark. The allowLocal - * flag specifies whether the scheduler can run the computation on the driver rather than - * shipping it out to the cluster, for short actions like first(). + * handler function. This is the main entry point for all actions in Spark. */ def runJob[T, U: ClassTag]( rdd: RDD[T], func: (TaskContext, Iterator[T]) => U, partitions: Seq[Int], - allowLocal: Boolean, - resultHandler: (Int, U) => Unit) { + resultHandler: (Int, U) => Unit): Unit = { if (stopped.get()) { throw new IllegalStateException("SparkContext has been shutdown") } @@ -1741,54 +1774,104 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli if (conf.getBoolean("spark.logLineage", false)) { logInfo("RDD's recursive dependencies:\n" + rdd.toDebugString) } - dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal, - resultHandler, localProperties.get) + dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, resultHandler, localProperties.get) progressBar.foreach(_.finishAll()) rdd.doCheckpoint() } /** - * Run a function on a given set of partitions in an RDD and return the results as an array. The - * allowLocal flag specifies whether the scheduler can run the computation on the driver rather - * than shipping it out to the cluster, for short actions like first(). + * Run a function on a given set of partitions in an RDD and return the results as an array. + */ + def runJob[T, U: ClassTag]( + rdd: RDD[T], + func: (TaskContext, Iterator[T]) => U, + partitions: Seq[Int]): Array[U] = { + val results = new Array[U](partitions.size) + runJob[T, U](rdd, func, partitions, (index, res) => results(index) = res) + results + } + + /** + * Run a job on a given set of partitions of an RDD, but take a function of type + * `Iterator[T] => U` instead of `(TaskContext, Iterator[T]) => U`. + */ + def runJob[T, U: ClassTag]( + rdd: RDD[T], + func: Iterator[T] => U, + partitions: Seq[Int]): Array[U] = { + val cleanedFunc = clean(func) + runJob(rdd, (ctx: TaskContext, it: Iterator[T]) => cleanedFunc(it), partitions) + } + + + /** + * Run a function on a given set of partitions in an RDD and pass the results to the given + * handler function. This is the main entry point for all actions in Spark. + * + * The allowLocal flag is deprecated as of Spark 1.5.0+. + */ + @deprecated("use the version of runJob without the allowLocal parameter", "1.5.0") + def runJob[T, U: ClassTag]( + rdd: RDD[T], + func: (TaskContext, Iterator[T]) => U, + partitions: Seq[Int], + allowLocal: Boolean, + resultHandler: (Int, U) => Unit): Unit = { + if (allowLocal) { + logWarning("sc.runJob with allowLocal=true is deprecated in Spark 1.5.0+") + } + runJob(rdd, func, partitions, resultHandler) + } + + /** + * Run a function on a given set of partitions in an RDD and return the results as an array. + * + * The allowLocal flag is deprecated as of Spark 1.5.0+. 
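For reference, a hedged sketch (assuming an existing SparkContext named sc; not taken from the patch) of how callers migrate from the deprecated allowLocal overloads to the new ones introduced above:

    // Previously: sc.runJob(rdd, func, partitions, allowLocal = false, handler)
    val rdd = sc.parallelize(1 to 100, 4)

    // New primary overloads: no allowLocal flag.
    val partialSums: Array[Int] =
      sc.runJob(rdd, (iter: Iterator[Int]) => iter.sum, Seq(0, 1))   // two partitions only
    val allSums: Array[Int] =
      sc.runJob(rdd, (iter: Iterator[Int]) => iter.sum)              // every partition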
*/ + @deprecated("use the version of runJob without the allowLocal parameter", "1.5.0") def runJob[T, U: ClassTag]( rdd: RDD[T], func: (TaskContext, Iterator[T]) => U, partitions: Seq[Int], allowLocal: Boolean ): Array[U] = { - val results = new Array[U](partitions.size) - runJob[T, U](rdd, func, partitions, allowLocal, (index, res) => results(index) = res) - results + if (allowLocal) { + logWarning("sc.runJob with allowLocal=true is deprecated in Spark 1.5.0+") + } + runJob(rdd, func, partitions) } /** * Run a job on a given set of partitions of an RDD, but take a function of type * `Iterator[T] => U` instead of `(TaskContext, Iterator[T]) => U`. + * + * The allowLocal argument is deprecated as of Spark 1.5.0+. */ + @deprecated("use the version of runJob without the allowLocal parameter", "1.5.0") def runJob[T, U: ClassTag]( rdd: RDD[T], func: Iterator[T] => U, partitions: Seq[Int], allowLocal: Boolean ): Array[U] = { - val cleanedFunc = clean(func) - runJob(rdd, (ctx: TaskContext, it: Iterator[T]) => cleanedFunc(it), partitions, allowLocal) + if (allowLocal) { + logWarning("sc.runJob with allowLocal=true is deprecated in Spark 1.5.0+") + } + runJob(rdd, func, partitions) } /** * Run a job on all partitions in an RDD and return the results in an array. */ def runJob[T, U: ClassTag](rdd: RDD[T], func: (TaskContext, Iterator[T]) => U): Array[U] = { - runJob(rdd, func, 0 until rdd.partitions.size, false) + runJob(rdd, func, 0 until rdd.partitions.length) } /** * Run a job on all partitions in an RDD and return the results in an array. */ def runJob[T, U: ClassTag](rdd: RDD[T], func: Iterator[T] => U): Array[U] = { - runJob(rdd, func, 0 until rdd.partitions.size, false) + runJob(rdd, func, 0 until rdd.partitions.length) } /** @@ -1799,7 +1882,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli processPartition: (TaskContext, Iterator[T]) => U, resultHandler: (Int, U) => Unit) { - runJob[T, U](rdd, processPartition, 0 until rdd.partitions.size, false, resultHandler) + runJob[T, U](rdd, processPartition, 0 until rdd.partitions.length, resultHandler) } /** @@ -1811,7 +1894,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli resultHandler: (Int, U) => Unit) { val processFunc = (context: TaskContext, iter: Iterator[T]) => processPartition(iter) - runJob[T, U](rdd, processFunc, 0 until rdd.partitions.size, false, resultHandler) + runJob[T, U](rdd, processFunc, 0 until rdd.partitions.length, resultHandler) } /** @@ -1856,7 +1939,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli (context: TaskContext, iter: Iterator[T]) => cleanF(iter), partitions, callSite, - allowLocal = false, resultHandler, localProperties.get) new SimpleFutureAction(waiter, resultFunc) @@ -1968,7 +2050,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli for (className <- listenerClassNames) { // Use reflection to find the right constructor val constructors = { - val listenerClass = Class.forName(className) + val listenerClass = Utils.classForName(className) listenerClass.getConstructors.asInstanceOf[Array[Constructor[_ <: SparkListener]]] } val constructorTakingSparkConf = constructors.find { c => @@ -2503,7 +2585,7 @@ object SparkContext extends Logging { "\"yarn-standalone\" is deprecated as of Spark 1.0. 
Use \"yarn-cluster\" instead.") } val scheduler = try { - val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClusterScheduler") + val clazz = Utils.classForName("org.apache.spark.scheduler.cluster.YarnClusterScheduler") val cons = clazz.getConstructor(classOf[SparkContext]) cons.newInstance(sc).asInstanceOf[TaskSchedulerImpl] } catch { @@ -2515,7 +2597,7 @@ object SparkContext extends Logging { } val backend = try { val clazz = - Class.forName("org.apache.spark.scheduler.cluster.YarnClusterSchedulerBackend") + Utils.classForName("org.apache.spark.scheduler.cluster.YarnClusterSchedulerBackend") val cons = clazz.getConstructor(classOf[TaskSchedulerImpl], classOf[SparkContext]) cons.newInstance(scheduler, sc).asInstanceOf[CoarseGrainedSchedulerBackend] } catch { @@ -2528,8 +2610,7 @@ object SparkContext extends Logging { case "yarn-client" => val scheduler = try { - val clazz = - Class.forName("org.apache.spark.scheduler.cluster.YarnScheduler") + val clazz = Utils.classForName("org.apache.spark.scheduler.cluster.YarnScheduler") val cons = clazz.getConstructor(classOf[SparkContext]) cons.newInstance(sc).asInstanceOf[TaskSchedulerImpl] @@ -2541,7 +2622,7 @@ object SparkContext extends Logging { val backend = try { val clazz = - Class.forName("org.apache.spark.scheduler.cluster.YarnClientSchedulerBackend") + Utils.classForName("org.apache.spark.scheduler.cluster.YarnClientSchedulerBackend") val cons = clazz.getConstructor(classOf[TaskSchedulerImpl], classOf[SparkContext]) cons.newInstance(scheduler, sc).asInstanceOf[CoarseGrainedSchedulerBackend] } catch { diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index d18fc599e9890..adfece4d6e7c0 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -261,7 +261,7 @@ object SparkEnv extends Logging { // Create an instance of the class with the given name, possibly initializing it with our conf def instantiateClass[T](className: String): T = { - val cls = Class.forName(className, true, Utils.getContextOrSparkClassLoader) + val cls = Utils.classForName(className) // Look for a constructor taking a SparkConf and a boolean isDriver, then one taking just // SparkConf, then one taking no arguments try { diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index d09e17dea0911..b48836d5c8897 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -21,6 +21,7 @@ import java.io.Serializable import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics +import org.apache.spark.metrics.source.Source import org.apache.spark.unsafe.memory.TaskMemoryManager import org.apache.spark.util.TaskCompletionListener @@ -32,7 +33,20 @@ object TaskContext { */ def get(): TaskContext = taskContext.get - private val taskContext: ThreadLocal[TaskContext] = new ThreadLocal[TaskContext] + /** + * Returns the partition id of currently active TaskContext. It will return 0 + * if there is no active TaskContext for cases like local execution. 
+ */ + def getPartitionId(): Int = { + val tc = taskContext.get() + if (tc eq null) { + 0 + } else { + tc.partitionId() + } + } + + private[this] val taskContext: ThreadLocal[TaskContext] = new ThreadLocal[TaskContext] // Note: protected[spark] instead of private[spark] to prevent the following two from // showing up in JavaDoc. @@ -135,8 +149,34 @@ abstract class TaskContext extends Serializable { @DeveloperApi def taskMetrics(): TaskMetrics + /** + * ::DeveloperApi:: + * Returns all metrics sources with the given name which are associated with the instance + * which runs the task. For more information see [[org.apache.spark.metrics.MetricsSystem!]]. + */ + @DeveloperApi + def getMetricsSources(sourceName: String): Seq[Source] + /** * Returns the manager for this task's managed memory. */ private[spark] def taskMemoryManager(): TaskMemoryManager + + /** + * Register an accumulator that belongs to this task. Accumulators must call this method when + * deserializing in executors. + */ + private[spark] def registerAccumulator(a: Accumulable[_, _]): Unit + + /** + * Return the local values of internal accumulators that belong to this task. The key of the Map + * is the accumulator id and the value of the Map is the latest accumulator local value. + */ + private[spark] def collectInternalAccumulators(): Map[Long, Any] + + /** + * Return the local values of accumulators that belong to this task. The key of the Map is the + * accumulator id and the value of the Map is the latest accumulator local value. + */ + private[spark] def collectAccumulators(): Map[Long, Any] } diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index b4d572cb52313..9ee168ae016f8 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -17,18 +17,21 @@ package org.apache.spark +import scala.collection.mutable.{ArrayBuffer, HashMap} + import org.apache.spark.executor.TaskMetrics +import org.apache.spark.metrics.MetricsSystem +import org.apache.spark.metrics.source.Source import org.apache.spark.unsafe.memory.TaskMemoryManager import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException} -import scala.collection.mutable.ArrayBuffer - private[spark] class TaskContextImpl( val stageId: Int, val partitionId: Int, override val taskAttemptId: Long, override val attemptNumber: Int, override val taskMemoryManager: TaskMemoryManager, + @transient private val metricsSystem: MetricsSystem, val runningLocally: Boolean = false, val taskMetrics: TaskMetrics = TaskMetrics.empty) extends TaskContext @@ -94,5 +97,21 @@ private[spark] class TaskContextImpl( override def isRunningLocally(): Boolean = runningLocally override def isInterrupted(): Boolean = interrupted -} + override def getMetricsSources(sourceName: String): Seq[Source] = + metricsSystem.getSourcesByName(sourceName) + + @transient private val accumulators = new HashMap[Long, Accumulable[_, _]] + + private[spark] override def registerAccumulator(a: Accumulable[_, _]): Unit = synchronized { + accumulators(a.id) = a + } + + private[spark] override def collectInternalAccumulators(): Map[Long, Any] = synchronized { + accumulators.filter(_._2.isInternal).mapValues(_.localValue).toMap + } + + private[spark] override def collectAccumulators(): Map[Long, Any] = synchronized { + accumulators.mapValues(_.localValue).toMap + } +} diff --git 
a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index c95615a5a9307..829fae1d1d9bf 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -364,7 +364,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { // This is useful for implementing `take` from other language frontends // like Python where the data is serialized. import scala.collection.JavaConversions._ - val res = context.runJob(rdd, (it: Iterator[T]) => it.toArray, partitionIds, true) + val res = context.runJob(rdd, (it: Iterator[T]) => it.toArray, partitionIds) res.map(x => new java.util.ArrayList(x.toSeq)).toArray } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index dc9f62f39e6d5..598953ac3bcc8 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -358,12 +358,11 @@ private[spark] object PythonRDD extends Logging { def runJob( sc: SparkContext, rdd: JavaRDD[Array[Byte]], - partitions: JArrayList[Int], - allowLocal: Boolean): Int = { + partitions: JArrayList[Int]): Int = { type ByteArray = Array[Byte] type UnrolledPartition = Array[ByteArray] val allPartitions: Array[UnrolledPartition] = - sc.runJob(rdd, (x: Iterator[ByteArray]) => x.toArray, partitions, allowLocal) + sc.runJob(rdd, (x: Iterator[ByteArray]) => x.toArray, partitions) val flattenedPartition: UnrolledPartition = Array.concat(allPartitions: _*) serveIterator(flattenedPartition.iterator, s"serve RDD ${rdd.id} with partitions ${partitions.mkString(",")}") diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala index 4b8f7fe9242e0..a5de10fe89c42 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala @@ -20,12 +20,14 @@ package org.apache.spark.api.r import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} import scala.collection.mutable.HashMap +import scala.language.existentials import io.netty.channel.ChannelHandler.Sharable import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} import org.apache.spark.Logging import org.apache.spark.api.r.SerDe._ +import org.apache.spark.util.Utils /** * Handler for RBackend @@ -88,21 +90,6 @@ private[r] class RBackendHandler(server: RBackend) ctx.close() } - // Looks up a class given a class name. This function first checks the - // current class loader and if a class is not found, it looks up the class - // in the context class loader. 
Address [SPARK-5185] - def getStaticClass(objId: String): Class[_] = { - try { - val clsCurrent = Class.forName(objId) - clsCurrent - } catch { - // Use contextLoader if we can't find the JAR in the system class loader - case e: ClassNotFoundException => - val clsContext = Class.forName(objId, true, Thread.currentThread().getContextClassLoader) - clsContext - } - } - def handleMethodCall( isStatic: Boolean, objId: String, @@ -113,7 +100,7 @@ private[r] class RBackendHandler(server: RBackend) var obj: Object = null try { val cls = if (isStatic) { - getStaticClass(objId) + Utils.classForName(objId) } else { JVMObjectTracker.get(objId) match { case None => throw new IllegalArgumentException("Object not found " + objId) diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala index ff1702f7dea48..23a470d6afcae 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala @@ -39,7 +39,6 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( deserializer: String, serializer: String, packageNames: Array[Byte], - rLibDir: String, broadcastVars: Array[Broadcast[Object]]) extends RDD[U](parent) with Logging { protected var dataStream: DataInputStream = _ @@ -60,7 +59,7 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( // The stdout/stderr is shared by multiple tasks, because we use one daemon // to launch child process as worker. - val errThread = RRDD.createRWorker(rLibDir, listenPort) + val errThread = RRDD.createRWorker(listenPort) // We use two sockets to separate input and output, then it's easy to manage // the lifecycle of them to avoid deadlock. @@ -235,11 +234,10 @@ private class PairwiseRRDD[T: ClassTag]( hashFunc: Array[Byte], deserializer: String, packageNames: Array[Byte], - rLibDir: String, broadcastVars: Array[Object]) extends BaseRRDD[T, (Int, Array[Byte])]( parent, numPartitions, hashFunc, deserializer, - SerializationFormats.BYTE, packageNames, rLibDir, + SerializationFormats.BYTE, packageNames, broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { override protected def readData(length: Int): (Int, Array[Byte]) = { @@ -266,10 +264,9 @@ private class RRDD[T: ClassTag]( deserializer: String, serializer: String, packageNames: Array[Byte], - rLibDir: String, broadcastVars: Array[Object]) extends BaseRRDD[T, Array[Byte]]( - parent, -1, func, deserializer, serializer, packageNames, rLibDir, + parent, -1, func, deserializer, serializer, packageNames, broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { override protected def readData(length: Int): Array[Byte] = { @@ -293,10 +290,9 @@ private class StringRRDD[T: ClassTag]( func: Array[Byte], deserializer: String, packageNames: Array[Byte], - rLibDir: String, broadcastVars: Array[Object]) extends BaseRRDD[T, String]( - parent, -1, func, deserializer, SerializationFormats.STRING, packageNames, rLibDir, + parent, -1, func, deserializer, SerializationFormats.STRING, packageNames, broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { override protected def readData(length: Int): String = { @@ -392,9 +388,10 @@ private[r] object RRDD { thread } - private def createRProcess(rLibDir: String, port: Int, script: String): BufferedStreamThread = { + private def createRProcess(port: Int, script: String): BufferedStreamThread = { val rCommand = SparkEnv.get.conf.get("spark.sparkr.r.command", "Rscript") val rOptions = "--vanilla" + val rLibDir = RUtils.sparkRPackagePath(isDriver = 
false) val rExecScript = rLibDir + "/SparkR/worker/" + script val pb = new ProcessBuilder(List(rCommand, rOptions, rExecScript)) // Unset the R_TESTS environment variable for workers. @@ -413,7 +410,7 @@ private[r] object RRDD { /** * ProcessBuilder used to launch worker R processes. */ - def createRWorker(rLibDir: String, port: Int): BufferedStreamThread = { + def createRWorker(port: Int): BufferedStreamThread = { val useDaemon = SparkEnv.get.conf.getBoolean("spark.sparkr.use.daemon", true) if (!Utils.isWindows && useDaemon) { synchronized { @@ -421,7 +418,7 @@ private[r] object RRDD { // we expect one connections val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost")) val daemonPort = serverSocket.getLocalPort - errThread = createRProcess(rLibDir, daemonPort, "daemon.R") + errThread = createRProcess(daemonPort, "daemon.R") // the socket used to send out the input of task serverSocket.setSoTimeout(10000) val sock = serverSocket.accept() @@ -443,7 +440,7 @@ private[r] object RRDD { errThread } } else { - createRProcess(rLibDir, port, "worker.R") + createRProcess(port, "worker.R") } } diff --git a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala new file mode 100644 index 0000000000000..d53abd3408c55 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.r + +import java.io.File + +import org.apache.spark.{SparkEnv, SparkException} + +private[spark] object RUtils { + /** + * Get the SparkR package path in the local spark distribution. + */ + def localSparkRPackagePath: Option[String] = { + val sparkHome = sys.env.get("SPARK_HOME") + sparkHome.map( + Seq(_, "R", "lib").mkString(File.separator) + ) + } + + /** + * Get the SparkR package path in various deployment modes. + * This assumes that Spark properties `spark.master` and `spark.submit.deployMode` + * and environment variable `SPARK_HOME` are set. + */ + def sparkRPackagePath(isDriver: Boolean): String = { + val (master, deployMode) = + if (isDriver) { + (sys.props("spark.master"), sys.props("spark.submit.deployMode")) + } else { + val sparkConf = SparkEnv.get.conf + (sparkConf.get("spark.master"), sparkConf.get("spark.submit.deployMode")) + } + + val isYarnCluster = master.contains("yarn") && deployMode == "cluster" + val isYarnClient = master.contains("yarn") && deployMode == "client" + + // In YARN mode, the SparkR package is distributed as an archive symbolically + // linked to the "sparkr" file in the current directory. Note that this does not apply + // to the driver in client mode because it is run outside of the cluster. 
+ if (isYarnCluster || (isYarnClient && !isDriver)) { + new File("sparkr").getAbsolutePath + } else { + // Otherwise, assume the package is local + // TODO: support this for Mesos + localSparkRPackagePath.getOrElse { + throw new SparkException("SPARK_HOME not set. Can't locate SparkR package.") + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala index 56adc857d4ce0..d5b4260bf4529 100644 --- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala +++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala @@ -179,6 +179,7 @@ private[spark] object SerDe { // Int -> integer // String -> character // Boolean -> logical + // Float -> double // Double -> double // Long -> double // Array[Byte] -> raw @@ -215,6 +216,9 @@ private[spark] object SerDe { case "long" | "java.lang.Long" => writeType(dos, "double") writeDouble(dos, value.asInstanceOf[Long].toDouble) + case "float" | "java.lang.Float" => + writeType(dos, "double") + writeDouble(dos, value.asInstanceOf[Float].toDouble) case "double" | "java.lang.Double" => writeType(dos, "double") writeDouble(dos, value.asInstanceOf[Double]) diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala index 685313ac009ba..fac6666bb3410 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala @@ -22,6 +22,7 @@ import java.util.concurrent.atomic.AtomicLong import scala.reflect.ClassTag import org.apache.spark._ +import org.apache.spark.util.Utils private[spark] class BroadcastManager( val isDriver: Boolean, @@ -42,7 +43,7 @@ private[spark] class BroadcastManager( conf.get("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory") broadcastFactory = - Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory] + Utils.classForName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory] // Initialize appropriate BroadcastFactory and BroadcastObject broadcastFactory.initialize(isDriver, conf, securityManager) diff --git a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala index 4165740312e03..c0cab22fa8252 100644 --- a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala @@ -24,7 +24,7 @@ import scala.collection.JavaConversions._ import org.apache.hadoop.fs.Path -import org.apache.spark.api.r.RBackend +import org.apache.spark.api.r.{RBackend, RUtils} import org.apache.spark.util.RedirectThread /** @@ -71,9 +71,10 @@ object RRunner { val builder = new ProcessBuilder(Seq(rCommand, rFileNormalized) ++ otherArgs) val env = builder.environment() env.put("EXISTING_SPARKR_BACKEND_PORT", sparkRBackendPort.toString) - val sparkHome = System.getenv("SPARK_HOME") + val rPackageDir = RUtils.sparkRPackagePath(isDriver = true) + env.put("SPARKR_PACKAGE_DIR", rPackageDir) env.put("R_PROFILE_USER", - Seq(sparkHome, "R", "lib", "SparkR", "profile", "general.R").mkString(File.separator)) + Seq(rPackageDir, "SparkR", "profile", "general.R").mkString(File.separator)) builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize val process = builder.start() diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala 
b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 6d14590a1d192..e06b06e06fb4a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -25,6 +25,7 @@ import java.util.{Arrays, Comparator} import scala.collection.JavaConversions._ import scala.concurrent.duration._ import scala.language.postfixOps +import scala.util.control.NonFatal import com.google.common.primitives.Longs import org.apache.hadoop.conf.Configuration @@ -178,7 +179,7 @@ class SparkHadoopUtil extends Logging { private def getFileSystemThreadStatisticsMethod(methodName: String): Method = { val statisticsDataClass = - Class.forName("org.apache.hadoop.fs.FileSystem$Statistics$StatisticsData") + Utils.classForName("org.apache.hadoop.fs.FileSystem$Statistics$StatisticsData") statisticsDataClass.getDeclaredMethod(methodName) } @@ -238,6 +239,14 @@ class SparkHadoopUtil extends Logging { }.getOrElse(Seq.empty[Path]) } + def globPathIfNecessary(pattern: Path): Seq[Path] = { + if (pattern.toString.exists("{}[]*?\\".toSet.contains)) { + globPath(pattern) + } else { + Seq(pattern) + } + } + /** * Lists all the files in a directory with the specified prefix, and does not end with the * given suffix. The returned {{FileStatus}} instances are sorted by the modification times of @@ -248,19 +257,25 @@ class SparkHadoopUtil extends Logging { dir: Path, prefix: String, exclusionSuffix: String): Array[FileStatus] = { - val fileStatuses = remoteFs.listStatus(dir, - new PathFilter { - override def accept(path: Path): Boolean = { - val name = path.getName - name.startsWith(prefix) && !name.endsWith(exclusionSuffix) + try { + val fileStatuses = remoteFs.listStatus(dir, + new PathFilter { + override def accept(path: Path): Boolean = { + val name = path.getName + name.startsWith(prefix) && !name.endsWith(exclusionSuffix) + } + }) + Arrays.sort(fileStatuses, new Comparator[FileStatus] { + override def compare(o1: FileStatus, o2: FileStatus): Int = { + Longs.compare(o1.getModificationTime, o2.getModificationTime) } }) - Arrays.sort(fileStatuses, new Comparator[FileStatus] { - override def compare(o1: FileStatus, o2: FileStatus): Int = { - Longs.compare(o1.getModificationTime, o2.getModificationTime) - } - }) - fileStatuses + fileStatuses + } catch { + case NonFatal(e) => + logWarning("Error while attempting to list files from application staging dir", e) + Array.empty + } } /** @@ -356,7 +371,7 @@ object SparkHadoopUtil { System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE"))) if (yarnMode) { try { - Class.forName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil") + Utils.classForName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil") .newInstance() .asInstanceOf[SparkHadoopUtil] } catch { diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 4cec9017b8adb..0b39ee8fe3ba0 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -37,6 +37,7 @@ import org.apache.ivy.core.settings.IvySettings import org.apache.ivy.plugins.matcher.GlobPatternMatcher import org.apache.ivy.plugins.repository.file.FileRepository import org.apache.ivy.plugins.resolver.{FileSystemResolver, ChainResolver, IBiblioResolver} +import org.apache.spark.api.r.RUtils import org.apache.spark.SPARK_VERSION import org.apache.spark.deploy.rest._ import 
org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils} @@ -79,6 +80,7 @@ object SparkSubmit { private val SPARK_SHELL = "spark-shell" private val PYSPARK_SHELL = "pyspark-shell" private val SPARKR_SHELL = "sparkr-shell" + private val SPARKR_PACKAGE_ARCHIVE = "sparkr.zip" private val CLASS_NOT_FOUND_EXIT_STATUS = 101 @@ -262,6 +264,12 @@ object SparkSubmit { } } + // Update args.deployMode if it is null. It will be passed down as a Spark property later. + (args.deployMode, deployMode) match { + case (null, CLIENT) => args.deployMode = "client" + case (null, CLUSTER) => args.deployMode = "cluster" + case _ => + } val isYarnCluster = clusterManager == YARN && deployMode == CLUSTER val isMesosCluster = clusterManager == MESOS && deployMode == CLUSTER @@ -347,6 +355,23 @@ object SparkSubmit { } } + // In YARN mode for an R app, add the SparkR package archive to archives + // that can be distributed with the job + if (args.isR && clusterManager == YARN) { + val rPackagePath = RUtils.localSparkRPackagePath + if (rPackagePath.isEmpty) { + printErrorAndExit("SPARK_HOME does not exist for R application in YARN mode.") + } + val rPackageFile = new File(rPackagePath.get, SPARKR_PACKAGE_ARCHIVE) + if (!rPackageFile.exists()) { + printErrorAndExit(s"$SPARKR_PACKAGE_ARCHIVE does not exist for R application in YARN mode.") + } + val localURI = Utils.resolveURI(rPackageFile.getAbsolutePath) + + // Assigns a symbol link name "sparkr" to the shipped package. + args.archives = mergeFileLists(args.archives, localURI.toString + "#sparkr") + } + // If we're running a R app, set the main class to our specific R runner if (args.isR && deployMode == CLIENT) { if (args.primaryResource == SPARKR_SHELL) { @@ -375,6 +400,8 @@ object SparkSubmit { // All cluster managers OptionAssigner(args.master, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.master"), + OptionAssigner(args.deployMode, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, + sysProp = "spark.submit.deployMode"), OptionAssigner(args.name, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.app.name"), OptionAssigner(args.jars, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.jars"), OptionAssigner(args.ivyRepoPath, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.jars.ivy"), @@ -481,8 +508,14 @@ object SparkSubmit { } // Let YARN know it's a pyspark app, so it distributes needed libraries. 
- if (clusterManager == YARN && args.isPython) { - sysProps.put("spark.yarn.isPython", "true") + if (clusterManager == YARN) { + if (args.isPython) { + sysProps.put("spark.yarn.isPython", "true") + } + if (args.principal != null) { + require(args.keytab != null, "Keytab must be specified when the principal is specified") + UserGroupInformation.loginUserFromKeytab(args.principal, args.keytab) + } } // In yarn-cluster mode, use yarn.Client as a wrapper around the user class @@ -597,7 +630,7 @@ object SparkSubmit { var mainClass: Class[_] = null try { - mainClass = Class.forName(childMainClass, true, loader) + mainClass = Utils.classForName(childMainClass) } catch { case e: ClassNotFoundException => e.printStackTrace(printStream) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index ebb39c354dff1..b3710073e330c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -576,7 +576,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S System.setSecurityManager(sm) try { - Class.forName(mainClass).getMethod("main", classOf[Array[String]]) + Utils.classForName(mainClass).getMethod("main", classOf[Array[String]]) .invoke(null, Array(HELP)) } catch { case e: InvocationTargetException => diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 2cc465e55fceb..e3060ac3fa1a9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -407,8 +407,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) /** * Comparison function that defines the sort order for application attempts within the same - * application. Order is: running attempts before complete attempts, running attempts sorted - * by start time, completed attempts sorted by end time. + * application. Order is: attempts are sorted by descending start time. + * The state of the most recent attempt reflects the current state of the application. * * Normally applications should have a single running attempt; but failure to call sc.stop() * may cause multiple running attempts to show up.
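To illustrate the simplified ordering described above (a sketch with assumed names, not code from the patch): given attempts: Seq[FsApplicationAttemptInfo], the provider now effectively keeps the most recently started attempt first, whether or not it has completed.

    // Newest attempt first; the head of the sorted sequence is treated as the
    // application's current state. A strict comparison keeps sortWith's contract happy.
    val newestFirst = attempts.sortWith((a1, a2) => a1.startTime > a2.startTime)
    val currentAttempt = newestFirst.headOption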
@@ -418,11 +418,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) private def compareAttemptInfo( a1: FsApplicationAttemptInfo, a2: FsApplicationAttemptInfo): Boolean = { - if (a1.completed == a2.completed) { - if (a1.completed) a1.endTime >= a2.endTime else a1.startTime >= a2.startTime - } else { - !a1.completed - } + a1.startTime >= a2.startTime } /** diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 10638afb74900..a076a9c3f984d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -228,7 +228,7 @@ object HistoryServer extends Logging { val providerName = conf.getOption("spark.history.provider") .getOrElse(classOf[FsHistoryProvider].getName()) - val provider = Class.forName(providerName) + val provider = Utils.classForName(providerName) .getConstructor(classOf[SparkConf]) .newInstance(conf) .asInstanceOf[ApplicationHistoryProvider] diff --git a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala index f459ed5b3a1a1..aa379d4cd61e7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala @@ -21,9 +21,8 @@ import java.io._ import scala.reflect.ClassTag -import akka.serialization.Serialization - import org.apache.spark.Logging +import org.apache.spark.serializer.{DeserializationStream, SerializationStream, Serializer} import org.apache.spark.util.Utils @@ -32,11 +31,11 @@ import org.apache.spark.util.Utils * Files are deleted when applications and workers are removed. * * @param dir Directory to store files. Created if non-existent (but not recursively). - * @param serialization Used to serialize our objects. + * @param serializer Used to serialize our objects. 
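Because this engine (and the ZooKeeper engine further down) now takes a generic org.apache.spark.serializer.Serializer instead of Akka's Serialization, here is a minimal, self-contained sketch of the stream-based round trip the engines rely on. It is illustrative only; JavaSerializer is used as the concrete implementation, matching the Master's choice later in this patch.

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
    import org.apache.spark.SparkConf
    import org.apache.spark.serializer.JavaSerializer

    val serializer = new JavaSerializer(new SparkConf())

    // Serialize a value through a SerializationStream.
    val byteStream = new ByteArrayOutputStream()
    val out = serializer.newInstance().serializeStream(byteStream)
    out.writeObject("some recovery state")
    out.close()

    // Read it back through a DeserializationStream.
    val in = serializer.newInstance()
      .deserializeStream(new ByteArrayInputStream(byteStream.toByteArray))
    val restored = in.readObject[String]()
    in.close()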
*/ private[master] class FileSystemPersistenceEngine( val dir: String, - val serialization: Serialization) + val serializer: Serializer) extends PersistenceEngine with Logging { new File(dir).mkdir() @@ -57,27 +56,31 @@ private[master] class FileSystemPersistenceEngine( private def serializeIntoFile(file: File, value: AnyRef) { val created = file.createNewFile() if (!created) { throw new IllegalStateException("Could not create file: " + file) } - val serializer = serialization.findSerializerFor(value) - val serialized = serializer.toBinary(value) - val out = new FileOutputStream(file) + val fileOut = new FileOutputStream(file) + var out: SerializationStream = null Utils.tryWithSafeFinally { - out.write(serialized) + out = serializer.newInstance().serializeStream(fileOut) + out.writeObject(value) } { - out.close() + fileOut.close() + if (out != null) { + out.close() + } } } private def deserializeFromFile[T](file: File)(implicit m: ClassTag[T]): T = { - val fileData = new Array[Byte](file.length().asInstanceOf[Int]) - val dis = new DataInputStream(new FileInputStream(file)) + val fileIn = new FileInputStream(file) + var in: DeserializationStream = null try { - dis.readFully(fileData) + in = serializer.newInstance().deserializeStream(fileIn) + in.readObject[T]() } finally { - dis.close() + fileIn.close() + if (in != null) { + in.close() + } } - val clazz = m.runtimeClass.asInstanceOf[Class[T]] - val serializer = serialization.serializerFor(clazz) - serializer.fromBinary(fileData).asInstanceOf[T] } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 48070768f6edb..4615febf17d24 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -27,11 +27,8 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import scala.language.postfixOps import scala.util.Random -import akka.serialization.Serialization -import akka.serialization.SerializationExtension import org.apache.hadoop.fs.Path -import org.apache.spark.rpc.akka.AkkaRpcEnv import org.apache.spark.rpc._ import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.{ApplicationDescription, DriverDescription, @@ -44,6 +41,7 @@ import org.apache.spark.deploy.master.ui.MasterWebUI import org.apache.spark.deploy.rest.StandaloneRestServer import org.apache.spark.metrics.MetricsSystem import org.apache.spark.scheduler.{EventLoggingListener, ReplayListenerBus} +import org.apache.spark.serializer.{JavaSerializer, Serializer} import org.apache.spark.ui.SparkUI import org.apache.spark.util.{ThreadUtils, SignalLogger, Utils} @@ -58,9 +56,6 @@ private[master] class Master( private val forwardMessageThread = ThreadUtils.newDaemonSingleThreadScheduledExecutor("master-forward-message-thread") - // TODO Remove it once we don't use akka.serialization.Serialization - private val actorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem - private val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs @@ -161,20 +156,21 @@ private[master] class Master( masterMetricsSystem.getServletHandlers.foreach(webUi.attachHandler) applicationMetricsSystem.getServletHandlers.foreach(webUi.attachHandler) + val serializer = new JavaSerializer(conf) val (persistenceEngine_, leaderElectionAgent_) = RECOVERY_MODE match { case "ZOOKEEPER" 
=> logInfo("Persisting recovery state to ZooKeeper") val zkFactory = - new ZooKeeperRecoveryModeFactory(conf, SerializationExtension(actorSystem)) + new ZooKeeperRecoveryModeFactory(conf, serializer) (zkFactory.createPersistenceEngine(), zkFactory.createLeaderElectionAgent(this)) case "FILESYSTEM" => val fsFactory = - new FileSystemRecoveryModeFactory(conf, SerializationExtension(actorSystem)) + new FileSystemRecoveryModeFactory(conf, serializer) (fsFactory.createPersistenceEngine(), fsFactory.createLeaderElectionAgent(this)) case "CUSTOM" => - val clazz = Class.forName(conf.get("spark.deploy.recoveryMode.factory")) - val factory = clazz.getConstructor(classOf[SparkConf], classOf[Serialization]) - .newInstance(conf, SerializationExtension(actorSystem)) + val clazz = Utils.classForName(conf.get("spark.deploy.recoveryMode.factory")) + val factory = clazz.getConstructor(classOf[SparkConf], classOf[Serializer]) + .newInstance(conf, serializer) .asInstanceOf[StandaloneRecoveryModeFactory] (factory.createPersistenceEngine(), factory.createLeaderElectionAgent(this)) case _ => @@ -213,7 +209,7 @@ private[master] class Master( override def receive: PartialFunction[Any, Unit] = { case ElectedLeader => { - val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData() + val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData(rpcEnv) state = if (storedApps.isEmpty && storedDrivers.isEmpty && storedWorkers.isEmpty) { RecoveryState.ALIVE } else { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala index a03d460509e03..58a00bceee6af 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala @@ -18,6 +18,7 @@ package org.apache.spark.deploy.master import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rpc.RpcEnv import scala.reflect.ClassTag @@ -80,8 +81,11 @@ abstract class PersistenceEngine { * Returns the persisted data sorted by their respective ids (which implies that they're * sorted by time of creation). 
*/ - final def readPersistedData(): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) = { - (read[ApplicationInfo]("app_"), read[DriverInfo]("driver_"), read[WorkerInfo]("worker_")) + final def readPersistedData( + rpcEnv: RpcEnv): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) = { + rpcEnv.deserialize { () => + (read[ApplicationInfo]("app_"), read[DriverInfo]("driver_"), read[WorkerInfo]("worker_")) + } } def close() {} diff --git a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala index 351db8fab2041..c4c3283fb73f7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala @@ -17,10 +17,9 @@ package org.apache.spark.deploy.master -import akka.serialization.Serialization - import org.apache.spark.{Logging, SparkConf} import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.serializer.Serializer /** * ::DeveloperApi:: @@ -30,7 +29,7 @@ import org.apache.spark.annotation.DeveloperApi * */ @DeveloperApi -abstract class StandaloneRecoveryModeFactory(conf: SparkConf, serializer: Serialization) { +abstract class StandaloneRecoveryModeFactory(conf: SparkConf, serializer: Serializer) { /** * PersistenceEngine defines how the persistent data(Information about worker, driver etc..) @@ -49,7 +48,7 @@ abstract class StandaloneRecoveryModeFactory(conf: SparkConf, serializer: Serial * LeaderAgent in this case is a no-op. Since leader is forever leader as the actual * recovery is made by restoring from filesystem. */ -private[master] class FileSystemRecoveryModeFactory(conf: SparkConf, serializer: Serialization) +private[master] class FileSystemRecoveryModeFactory(conf: SparkConf, serializer: Serializer) extends StandaloneRecoveryModeFactory(conf, serializer) with Logging { val RECOVERY_DIR = conf.get("spark.deploy.recoveryDirectory", "") @@ -64,7 +63,7 @@ private[master] class FileSystemRecoveryModeFactory(conf: SparkConf, serializer: } } -private[master] class ZooKeeperRecoveryModeFactory(conf: SparkConf, serializer: Serialization) +private[master] class ZooKeeperRecoveryModeFactory(conf: SparkConf, serializer: Serializer) extends StandaloneRecoveryModeFactory(conf, serializer) { def createPersistenceEngine(): PersistenceEngine = { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala index 328d95a7a0c68..563831cc6b8dd 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy.master -import akka.serialization.Serialization +import java.nio.ByteBuffer import scala.collection.JavaConversions._ import scala.reflect.ClassTag @@ -27,9 +27,10 @@ import org.apache.zookeeper.CreateMode import org.apache.spark.{Logging, SparkConf} import org.apache.spark.deploy.SparkCuratorUtil +import org.apache.spark.serializer.Serializer -private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serialization: Serialization) +private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serializer: Serializer) extends PersistenceEngine with Logging { @@ -57,17 +58,16 @@ private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serializat } 
private def serializeIntoFile(path: String, value: AnyRef) { - val serializer = serialization.findSerializerFor(value) - val serialized = serializer.toBinary(value) - zk.create().withMode(CreateMode.PERSISTENT).forPath(path, serialized) + val serialized = serializer.newInstance().serialize(value) + val bytes = new Array[Byte](serialized.remaining()) + serialized.get(bytes) + zk.create().withMode(CreateMode.PERSISTENT).forPath(path, bytes) } private def deserializeFromFile[T](filename: String)(implicit m: ClassTag[T]): Option[T] = { val fileData = zk.getData().forPath(WORKING_DIR + "/" + filename) - val clazz = m.runtimeClass.asInstanceOf[Class[T]] - val serializer = serialization.serializerFor(clazz) try { - Some(serializer.fromBinary(fileData).asInstanceOf[T]) + Some(serializer.newInstance().deserialize[T](ByteBuffer.wrap(fileData))) } catch { case e: Exception => { logWarning("Exception while reading persisted file, deleting", e) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala index e6615a3174ce1..ef5a7e35ad562 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala @@ -128,7 +128,7 @@ private[spark] object SubmitRestProtocolMessage { */ def fromJson(json: String): SubmitRestProtocolMessage = { val className = parseAction(json) - val clazz = Class.forName(packagePrefix + "." + className) + val clazz = Utils.classForName(packagePrefix + "." + className) .asSubclass[SubmitRestProtocolMessage](classOf[SubmitRestProtocolMessage]) fromJson(json, clazz) } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala index 2d6be3042c905..6799f78ec0c19 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala @@ -53,7 +53,7 @@ object DriverWrapper { Thread.currentThread.setContextClassLoader(loader) // Delegate to supplied main class - val clazz = Class.forName(mainClass, true, loader) + val clazz = Utils.classForName(mainClass) val mainMethod = clazz.getMethod("main", classOf[Array[String]]) mainMethod.invoke(null, extraArgs.toArray[String]) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala index e89d076802215..5181142c5f80e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala @@ -149,6 +149,7 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) { val ibmVendor = System.getProperty("java.vendor").contains("IBM") var totalMb = 0 try { + // scalastyle:off classforname val bean = ManagementFactory.getOperatingSystemMXBean() if (ibmVendor) { val beanClass = Class.forName("com.ibm.lang.management.OperatingSystemMXBean") @@ -159,6 +160,7 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) { val method = beanClass.getDeclaredMethod("getTotalPhysicalMemorySize") totalMb = (method.invoke(bean).asInstanceOf[Long] / 1024 / 1024).toInt } + // scalastyle:on classforname } catch { case e: Exception => { totalMb = 2*1024 diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala 
b/core/src/main/scala/org/apache/spark/executor/Executor.scala index f7ef92bc80f91..e76664f1bd7b0 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -209,15 +209,19 @@ private[spark] class Executor( // Run the actual task and measure its runtime. taskStart = System.currentTimeMillis() - val value = try { - task.run(taskAttemptId = taskId, attemptNumber = attemptNumber) + var threwException = true + val (value, accumUpdates) = try { + val res = task.run( + taskAttemptId = taskId, + attemptNumber = attemptNumber, + metricsSystem = env.metricsSystem) + threwException = false + res } finally { - // Note: this memory freeing logic is duplicated in DAGScheduler.runLocallyWithinThread; - // when changing this, make sure to update both copies. val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory() if (freedMemory > 0) { val errMsg = s"Managed memory leak detected; size = $freedMemory bytes, TID = $taskId" - if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) { + if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false) && !threwException) { throw new SparkException(errMsg) } else { logError(errMsg) @@ -247,7 +251,6 @@ private[spark] class Executor( m.setResultSerializationTime(afterSerialization - beforeSerialization) } - val accumUpdates = Accumulators.values val directResult = new DirectTaskResult(valueBytes, accumUpdates, task.metrics.orNull) val serializedDirectResult = ser.serialize(directResult) val resultSize = serializedDirectResult.limit @@ -314,8 +317,6 @@ private[spark] class Executor( env.shuffleMemoryManager.releaseMemoryForThisThread() // Release memory used by this thread for unrolling blocks env.blockManager.memoryStore.releaseUnrollMemoryForThisThread() - // Release memory used by this thread for accumulators - Accumulators.clear() runningTasks.remove(taskId) } } @@ -356,7 +357,7 @@ private[spark] class Executor( logInfo("Using REPL class URI: " + classUri) try { val _userClassPathFirst: java.lang.Boolean = userClassPathFirst - val klass = Class.forName("org.apache.spark.repl.ExecutorClassLoader") + val klass = Utils.classForName("org.apache.spark.repl.ExecutorClassLoader") .asInstanceOf[Class[_ <: ClassLoader]] val constructor = klass.getConstructor(classOf[SparkConf], classOf[String], classOf[ClassLoader], classOf[Boolean]) @@ -424,6 +425,7 @@ private[spark] class Executor( metrics.updateShuffleReadMetrics() metrics.updateInputMetrics() metrics.setJvmGCTime(curGCTime - taskRunner.startGCTime) + metrics.updateAccumulators() if (isLocal) { // JobProgressListener will hold an reference of it during diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index a3b4561b07e7f..42207a9553592 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -17,11 +17,15 @@ package org.apache.spark.executor +import java.io.{IOException, ObjectInputStream} +import java.util.concurrent.ConcurrentHashMap + import scala.collection.mutable.ArrayBuffer import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.DataReadMethod.DataReadMethod import org.apache.spark.storage.{BlockId, BlockStatus} +import org.apache.spark.util.Utils /** * :: DeveloperApi :: @@ -210,10 +214,42 @@ class TaskMetrics extends Serializable { private[spark] def updateInputMetrics(): Unit = synchronized { 
inputMetrics.foreach(_.updateBytesRead()) } + + @throws(classOf[IOException]) + private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException { + in.defaultReadObject() + // Get the hostname from cached data, since hostname is the order of number of nodes in + // cluster, so using cached hostname will decrease the object number and alleviate the GC + // overhead. + _hostname = TaskMetrics.getCachedHostName(_hostname) + } + + private var _accumulatorUpdates: Map[Long, Any] = Map.empty + @transient private var _accumulatorsUpdater: () => Map[Long, Any] = null + + private[spark] def updateAccumulators(): Unit = synchronized { + _accumulatorUpdates = _accumulatorsUpdater() + } + + /** + * Return the latest updates of accumulators in this task. + */ + def accumulatorUpdates(): Map[Long, Any] = _accumulatorUpdates + + private[spark] def setAccumulatorsUpdater(accumulatorsUpdater: () => Map[Long, Any]): Unit = { + _accumulatorsUpdater = accumulatorsUpdater + } } private[spark] object TaskMetrics { + private val hostNameCache = new ConcurrentHashMap[String, String]() + def empty: TaskMetrics = new TaskMetrics + + def getCachedHostName(host: String): String = { + val canonicalHost = hostNameCache.putIfAbsent(host, host) + if (canonicalHost != null) canonicalHost else host + } } /** diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index 0d8ac1f80a9f4..607d5a321efca 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -63,8 +63,7 @@ private[spark] object CompressionCodec { def createCodec(conf: SparkConf, codecName: String): CompressionCodec = { val codecClass = shortCompressionCodecNames.getOrElse(codecName.toLowerCase, codecName) val codec = try { - val ctor = Class.forName(codecClass, true, Utils.getContextOrSparkClassLoader) - .getConstructor(classOf[SparkConf]) + val ctor = Utils.classForName(codecClass).getConstructor(classOf[SparkConf]) Some(ctor.newInstance(conf).asInstanceOf[CompressionCodec]) } catch { case e: ClassNotFoundException => None diff --git a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala index 818f7a4c8d422..87df42748be44 100644 --- a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala +++ b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala @@ -26,6 +26,7 @@ import org.apache.hadoop.mapreduce.{OutputCommitter => MapReduceOutputCommitter} import org.apache.spark.executor.CommitDeniedException import org.apache.spark.{Logging, SparkEnv, TaskContext} +import org.apache.spark.util.{Utils => SparkUtils} private[spark] trait SparkHadoopMapRedUtil { @@ -64,10 +65,10 @@ trait SparkHadoopMapRedUtil { private def firstAvailableClass(first: String, second: String): Class[_] = { try { - Class.forName(first) + SparkUtils.classForName(first) } catch { case e: ClassNotFoundException => - Class.forName(second) + SparkUtils.classForName(second) } } } diff --git a/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala b/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala index 390d148bc97f9..943ebcb7bd0a1 100644 --- a/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala +++ b/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala @@ -21,6 +21,7 @@ import 
java.lang.{Boolean => JBoolean, Integer => JInteger} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.mapreduce.{JobContext, JobID, TaskAttemptContext, TaskAttemptID} +import org.apache.spark.util.Utils private[spark] trait SparkHadoopMapReduceUtil { @@ -46,7 +47,7 @@ trait SparkHadoopMapReduceUtil { isMap: Boolean, taskId: Int, attemptId: Int): TaskAttemptID = { - val klass = Class.forName("org.apache.hadoop.mapreduce.TaskAttemptID") + val klass = Utils.classForName("org.apache.hadoop.mapreduce.TaskAttemptID") try { // First, attempt to use the old-style constructor that takes a boolean isMap // (not available in YARN) @@ -57,7 +58,7 @@ trait SparkHadoopMapReduceUtil { } catch { case exc: NoSuchMethodException => { // If that failed, look for the new constructor that takes a TaskType (not available in 1.x) - val taskTypeClass = Class.forName("org.apache.hadoop.mapreduce.TaskType") + val taskTypeClass = Utils.classForName("org.apache.hadoop.mapreduce.TaskType") .asInstanceOf[Class[Enum[_]]] val taskType = taskTypeClass.getMethod("valueOf", classOf[String]).invoke( taskTypeClass, if (isMap) "MAP" else "REDUCE") @@ -71,10 +72,10 @@ trait SparkHadoopMapReduceUtil { private def firstAvailableClass(first: String, second: String): Class[_] = { try { - Class.forName(first) + Utils.classForName(first) } catch { case e: ClassNotFoundException => - Class.forName(second) + Utils.classForName(second) } } } diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala index ed5131c79fdc5..4517f465ebd3b 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -20,6 +20,8 @@ package org.apache.spark.metrics import java.util.Properties import java.util.concurrent.TimeUnit +import org.apache.spark.util.Utils + import scala.collection.mutable import com.codahale.metrics.{Metric, MetricFilter, MetricRegistry} @@ -140,6 +142,9 @@ private[spark] class MetricsSystem private ( } else { defaultName } } + def getSourcesByName(sourceName: String): Seq[Source] = + sources.filter(_.sourceName == sourceName) + def registerSource(source: Source) { sources += source try { @@ -166,7 +171,7 @@ private[spark] class MetricsSystem private ( sourceConfigs.foreach { kv => val classPath = kv._2.getProperty("class") try { - val source = Class.forName(classPath).newInstance() + val source = Utils.classForName(classPath).newInstance() registerSource(source.asInstanceOf[Source]) } catch { case e: Exception => logError("Source class " + classPath + " cannot be instantiated", e) @@ -182,7 +187,7 @@ private[spark] class MetricsSystem private ( val classPath = kv._2.getProperty("class") if (null != classPath) { try { - val sink = Class.forName(classPath) + val sink = Utils.classForName(classPath) .getConstructor(classOf[Properties], classOf[MetricRegistry], classOf[SecurityManager]) .newInstance(kv._2, registry, securityMgr) if (kv._1 == "servlet") { diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 658e8c8b89318..130b58882d8ee 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -94,13 +94,14 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: } override def getDependencies: Seq[Dependency[_]] = { - rdds.map { rdd: RDD[_ <: 
Product2[K, _]] => + rdds.map { rdd: RDD[_] => if (rdd.partitioner == Some(part)) { logDebug("Adding one-to-one dependency with " + rdd) new OneToOneDependency(rdd) } else { logDebug("Adding shuffle dependency with " + rdd) - new ShuffleDependency[K, Any, CoGroupCombiner](rdd, part, serializer) + new ShuffleDependency[K, Any, CoGroupCombiner]( + rdd.asInstanceOf[RDD[_ <: Product2[K, _]]], part, serializer) } } } @@ -133,7 +134,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: // A list of (rdd iterator, dependency number) pairs val rddIterators = new ArrayBuffer[(Iterator[Product2[K, Any]], Int)] for ((dep, depNum) <- dependencies.zipWithIndex) dep match { - case oneToOneDependency: OneToOneDependency[Product2[K, Any]] => + case oneToOneDependency: OneToOneDependency[Product2[K, Any]] @unchecked => val dependencyPartition = split.narrowDeps(depNum).get.split // Read them from the parent val it = oneToOneDependency.rdd.iterator(dependencyPartition, context) diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala index 663eebb8e4191..90d9735cb3f69 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala @@ -69,7 +69,7 @@ private[spark] case class CoalescedRDDPartition( * the preferred location of each new partition overlaps with as many preferred locations of its * parent partitions * @param prev RDD to be coalesced - * @param maxPartitions number of desired partitions in the coalesced RDD + * @param maxPartitions number of desired partitions in the coalesced RDD (must be positive) * @param balanceSlack used to trade-off balance and locality. 1.0 is all locality, 0 is all balance */ private[spark] class CoalescedRDD[T: ClassTag]( @@ -78,6 +78,9 @@ private[spark] class CoalescedRDD[T: ClassTag]( balanceSlack: Double = 0.10) extends RDD[T](prev.context, Nil) { // Nil since we implement getDependencies + require(maxPartitions > 0 || maxPartitions == prev.partitions.length, + s"Number of partitions ($maxPartitions) must be positive.") + override def getPartitions: Array[Partition] = { val pc = new PartitionCoalescer(maxPartitions, prev, balanceSlack) diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index bee59a437f120..f1c17369cb48c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -383,11 +383,11 @@ private[spark] object HadoopRDD extends Logging { private[spark] class SplitInfoReflections { val inputSplitWithLocationInfo = - Class.forName("org.apache.hadoop.mapred.InputSplitWithLocationInfo") + Utils.classForName("org.apache.hadoop.mapred.InputSplitWithLocationInfo") val getLocationInfo = inputSplitWithLocationInfo.getMethod("getLocationInfo") - val newInputSplit = Class.forName("org.apache.hadoop.mapreduce.InputSplit") + val newInputSplit = Utils.classForName("org.apache.hadoop.mapreduce.InputSplit") val newGetLocationInfo = newInputSplit.getMethod("getLocationInfo") - val splitLocationInfo = Class.forName("org.apache.hadoop.mapred.SplitLocationInfo") + val splitLocationInfo = Utils.classForName("org.apache.hadoop.mapred.SplitLocationInfo") val isInMemory = splitLocationInfo.getMethod("isInMemory") val getLocation = splitLocationInfo.getMethod("getLocation") } diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala 
b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 91a6a2d039852..326fafb230a40 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -881,7 +881,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) } buf } : Seq[V] - val res = self.context.runJob(self, process, Array(index), false) + val res = self.context.runJob(self, process, Array(index)) res(0) case None => self.filter(_._1 == key).map(_._2).collect() diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 9f7ebae3e9af3..6d61d227382d7 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -897,7 +897,7 @@ abstract class RDD[T: ClassTag]( */ def toLocalIterator: Iterator[T] = withScope { def collectPartition(p: Int): Array[T] = { - sc.runJob(this, (iter: Iterator[T]) => iter.toArray, Seq(p), allowLocal = false).head + sc.runJob(this, (iter: Iterator[T]) => iter.toArray, Seq(p)).head } (0 until partitions.length).iterator.flatMap(i => collectPartition(i)) } @@ -1082,7 +1082,9 @@ abstract class RDD[T: ClassTag]( val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2) // If creating an extra level doesn't help reduce // the wall-clock time, we stop tree aggregation. - while (numPartitions > scale + numPartitions / scale) { + + // Don't trigger TreeAggregation when it doesn't save wall-clock time + while (numPartitions > scale + math.ceil(numPartitions.toDouble / scale)) { numPartitions /= scale val curNumPartitions = numPartitions partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex { @@ -1273,7 +1275,7 @@ abstract class RDD[T: ClassTag]( val left = num - buf.size val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts) - val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, p, allowLocal = true) + val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, p) res.foreach(buf ++= _.take(num - buf.size)) partsScanned += numPartsToTry diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala index 523aaf2b860b5..e277ae28d588f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala @@ -50,8 +50,7 @@ class ZippedWithIndexRDD[T: ClassTag](@transient prev: RDD[T]) extends RDD[(T, L prev.context.runJob( prev, Utils.getIteratorSize _, - 0 until n - 1, // do not need to count the last partition - allowLocal = false + 0 until n - 1 // do not need to count the last partition ).scanLeft(0L)(_ + _) } } diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 1709bdf560b6f..29debe8081308 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -39,8 +39,7 @@ private[spark] object RpcEnv { val rpcEnvNames = Map("akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnvFactory") val rpcEnvName = conf.get("spark.rpc", "akka") val rpcEnvFactoryClassName = rpcEnvNames.getOrElse(rpcEnvName.toLowerCase, rpcEnvName) - Class.forName(rpcEnvFactoryClassName, true, Utils.getContextOrSparkClassLoader). 
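The revised `treeAggregate` loop above only adds another aggregation level while doing so actually reduces the number of partial results the driver has to merge, i.e. while `scale + ceil(n / scale) < n`. A small numeric sketch of how the level sizes evolve under that rule (standalone, not Spark code):

```scala
// Print the partition counts at each tree-aggregation level for a given
// initial partition count and branching factor ("scale").
object TreeAggregateLevels {
  def levelSizes(initialPartitions: Int, scale: Int): Seq[Int] = {
    var n = initialPartitions
    val sizes = scala.collection.mutable.ArrayBuffer(n)
    // Stop once an extra level would no longer save work on the driver.
    while (n > scale + math.ceil(n.toDouble / scale)) {
      n /= scale
      sizes += n
    }
    sizes.toSeq
  }

  def main(args: Array[String]): Unit = {
    println(levelSizes(100, 4).mkString(", ")) // 100, 25, 6
  }
}
```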
- newInstance().asInstanceOf[RpcEnvFactory] + Utils.classForName(rpcEnvFactoryClassName).newInstance().asInstanceOf[RpcEnvFactory] } def create( @@ -140,6 +139,12 @@ private[spark] abstract class RpcEnv(conf: SparkConf) { * creating it manually because different [[RpcEnv]] may have different formats. */ def uriOf(systemName: String, address: RpcAddress, endpointName: String): String + + /** + * [[RpcEndpointRef]] cannot be deserialized without [[RpcEnv]]. So when deserializing any object + * that contains [[RpcEndpointRef]]s, the deserialization codes should be wrapped by this method. + */ + def deserialize[T](deserializationAction: () => T): T } diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index f2d87f68341af..fc17542abf81d 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -28,7 +28,7 @@ import akka.actor.{ActorSystem, ExtendedActorSystem, Actor, ActorRef, Props, Add import akka.event.Logging.Error import akka.pattern.{ask => akkaAsk} import akka.remote.{AssociationEvent, AssociatedEvent, DisassociatedEvent, AssociationErrorEvent} -import com.google.common.util.concurrent.MoreExecutors +import akka.serialization.JavaSerializer import org.apache.spark.{SparkException, Logging, SparkConf} import org.apache.spark.rpc._ @@ -239,6 +239,12 @@ private[spark] class AkkaRpcEnv private[akka] ( } override def toString: String = s"${getClass.getSimpleName}($actorSystem)" + + override def deserialize[T](deserializationAction: () => T): T = { + JavaSerializer.currentSystem.withValue(actorSystem.asInstanceOf[ExtendedActorSystem]) { + deserializationAction() + } + } } private[spark] class AkkaRpcEnvFactory extends RpcEnvFactory { @@ -315,6 +321,12 @@ private[akka] class AkkaRpcEndpointRef( override def toString: String = s"${getClass.getSimpleName}($actorRef)" + final override def equals(that: Any): Boolean = that match { + case other: AkkaRpcEndpointRef => actorRef == other.actorRef + case _ => false + } + + final override def hashCode(): Int = if (actorRef == null) 0 else actorRef.hashCode() } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index f3d87ee5c4fd1..552dabcfa5139 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -22,7 +22,8 @@ import java.util.Properties import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicInteger -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map, Stack} +import scala.collection.Map +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Stack} import scala.concurrent.duration._ import scala.language.existentials import scala.language.postfixOps @@ -37,7 +38,6 @@ import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator import org.apache.spark.rdd.RDD import org.apache.spark.rpc.RpcTimeout import org.apache.spark.storage._ -import org.apache.spark.unsafe.memory.TaskMemoryManager import org.apache.spark.util._ import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat @@ -127,10 +127,6 @@ class DAGScheduler( // This is only safe because DAGScheduler runs in a single thread. 
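The `equals`/`hashCode` pair added to `AkkaRpcEndpointRef` above delegates both methods to the wrapped `actorRef`, so two refs pointing at the same actor compare equal and hash consistently. A minimal sketch of that delegation, using a `String` as a stand-in for the actor reference:

```scala
// A wrapper is equal to another wrapper exactly when the underlying reference
// is equal; hashCode follows the same field (0 when it is null).
class EndpointRef(val underlying: String) {
  final override def equals(that: Any): Boolean = that match {
    case other: EndpointRef => underlying == other.underlying
    case _ => false
  }
  final override def hashCode(): Int =
    if (underlying == null) 0 else underlying.hashCode
}

object EndpointRefDemo {
  def main(args: Array[String]): Unit = {
    val a = new EndpointRef("actor://system/endpoint")
    val b = new EndpointRef("actor://system/endpoint")
    println(a == b)         // true
    println(Set(a, b).size) // 1: hashCode is consistent with equals
  }
}
```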
private val closureSerializer = SparkEnv.get.closureSerializer.newInstance() - - /** If enabled, we may run certain actions like take() and first() locally. */ - private val localExecutionEnabled = sc.getConf.getBoolean("spark.localExecution.enabled", false) - /** If enabled, FetchFailed will not cause stage retry, in order to surface the problem. */ private val disallowStageRetryForTest = sc.getConf.getBoolean("spark.test.noStageRetry", false) @@ -514,7 +510,6 @@ class DAGScheduler( func: (TaskContext, Iterator[T]) => U, partitions: Seq[Int], callSite: CallSite, - allowLocal: Boolean, resultHandler: (Int, U) => Unit, properties: Properties): JobWaiter[U] = { // Check to make sure we are not launching a task on a partition that does not exist. @@ -534,7 +529,7 @@ class DAGScheduler( val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler) eventProcessLoop.post(JobSubmitted( - jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter, + jobId, rdd, func2, partitions.toArray, callSite, waiter, SerializationUtils.clone(properties))) waiter } @@ -544,11 +539,10 @@ class DAGScheduler( func: (TaskContext, Iterator[T]) => U, partitions: Seq[Int], callSite: CallSite, - allowLocal: Boolean, resultHandler: (Int, U) => Unit, properties: Properties): Unit = { val start = System.nanoTime - val waiter = submitJob(rdd, func, partitions, callSite, allowLocal, resultHandler, properties) + val waiter = submitJob(rdd, func, partitions, callSite, resultHandler, properties) waiter.awaitResult() match { case JobSucceeded => logInfo("Job %d finished: %s, took %f s".format @@ -556,6 +550,9 @@ class DAGScheduler( case JobFailed(exception: Exception) => logInfo("Job %d failed: %s, took %f s".format (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9)) + // SPARK-8644: Include user stack trace in exceptions coming from DAGScheduler. + val callerStackTrace = Thread.currentThread().getStackTrace.tail + exception.setStackTrace(exception.getStackTrace ++ callerStackTrace) throw exception } } @@ -572,8 +569,7 @@ class DAGScheduler( val partitions = (0 until rdd.partitions.size).toArray val jobId = nextJobId.getAndIncrement() eventProcessLoop.post(JobSubmitted( - jobId, rdd, func2, partitions, allowLocal = false, callSite, listener, - SerializationUtils.clone(properties))) + jobId, rdd, func2, partitions, callSite, listener, SerializationUtils.clone(properties))) listener.awaitResult() // Will throw an exception if the job fails } @@ -650,73 +646,6 @@ class DAGScheduler( } } - /** - * Run a job on an RDD locally, assuming it has only a single partition and no dependencies. - * We run the operation in a separate thread just in case it takes a bunch of time, so that we - * don't block the DAGScheduler event loop or other concurrent jobs. - */ - protected def runLocally(job: ActiveJob) { - logInfo("Computing the requested partition locally") - new Thread("Local computation of job " + job.jobId) { - override def run() { - runLocallyWithinThread(job) - } - }.start() - } - - // Broken out for easier testing in DAGSchedulerSuite. 
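The SPARK-8644 change above appends the calling thread's stack trace to an exception that originated elsewhere (e.g. on an executor) before rethrowing it, so users can see where the job was submitted, not only where it failed. A minimal standalone sketch of that technique:

```scala
// Merge the caller's frames into a failure produced on another thread
// before rethrowing it.
object StackTraceMerge {
  def rethrowWithCallerTrace(exception: Exception): Nothing = {
    val callerStackTrace = Thread.currentThread().getStackTrace.tail
    exception.setStackTrace(exception.getStackTrace ++ callerStackTrace)
    throw exception
  }

  def main(args: Array[String]): Unit = {
    val remoteFailure = new RuntimeException("task failed on an executor")
    try rethrowWithCallerTrace(remoteFailure)
    catch {
      case e: RuntimeException =>
        // The trace now contains frames from both the original failure
        // and this calling thread.
        println(e.getStackTrace.length)
    }
  }
}
```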
- protected def runLocallyWithinThread(job: ActiveJob) { - var jobResult: JobResult = JobSucceeded - try { - val rdd = job.finalStage.rdd - val split = rdd.partitions(job.partitions(0)) - val taskMemoryManager = new TaskMemoryManager(env.executorMemoryManager) - val taskContext = - new TaskContextImpl( - job.finalStage.id, - job.partitions(0), - taskAttemptId = 0, - attemptNumber = 0, - taskMemoryManager = taskMemoryManager, - runningLocally = true) - TaskContext.setTaskContext(taskContext) - try { - val result = job.func(taskContext, rdd.iterator(split, taskContext)) - job.listener.taskSucceeded(0, result) - } finally { - taskContext.markTaskCompleted() - TaskContext.unset() - // Note: this memory freeing logic is duplicated in Executor.run(); when changing this, - // make sure to update both copies. - val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory() - if (freedMemory > 0) { - if (sc.getConf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) { - throw new SparkException(s"Managed memory leak detected; size = $freedMemory bytes") - } else { - logError(s"Managed memory leak detected; size = $freedMemory bytes") - } - } - } - } catch { - case e: Exception => - val exception = new SparkDriverExecutionException(e) - jobResult = JobFailed(exception) - job.listener.jobFailed(exception) - case oom: OutOfMemoryError => - val exception = new SparkException("Local job aborted due to out of memory error", oom) - jobResult = JobFailed(exception) - job.listener.jobFailed(exception) - } finally { - val s = job.finalStage - // clean up data structures that were populated for a local job, - // but that won't get cleaned up via the normal paths through - // completion events or stage abort - stageIdToStage -= s.id - jobIdToStageIds -= job.jobId - listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), jobResult)) - } - } - /** Finds the earliest-created active job that needs the stage */ // TODO: Probably should actually find among the active jobs that need this // stage the one with the highest priority (highest-priority pool, earliest created). @@ -779,7 +708,6 @@ class DAGScheduler( finalRDD: RDD[_], func: (TaskContext, Iterator[_]) => _, partitions: Array[Int], - allowLocal: Boolean, callSite: CallSite, listener: JobListener, properties: Properties) { @@ -797,29 +725,20 @@ class DAGScheduler( if (finalStage != null) { val job = new ActiveJob(jobId, finalStage, func, partitions, callSite, listener, properties) clearCacheLocs() - logInfo("Got job %s (%s) with %d output partitions (allowLocal=%s)".format( - job.jobId, callSite.shortForm, partitions.length, allowLocal)) + logInfo("Got job %s (%s) with %d output partitions".format( + job.jobId, callSite.shortForm, partitions.length)) logInfo("Final stage: " + finalStage + "(" + finalStage.name + ")") logInfo("Parents of final stage: " + finalStage.parents) logInfo("Missing parents: " + getMissingParentStages(finalStage)) - val shouldRunLocally = - localExecutionEnabled && allowLocal && finalStage.parents.isEmpty && partitions.length == 1 val jobSubmissionTime = clock.getTimeMillis() - if (shouldRunLocally) { - // Compute very short actions like first() or take() with no parent stages locally. 
- listenerBus.post( - SparkListenerJobStart(job.jobId, jobSubmissionTime, Seq.empty, properties)) - runLocally(job) - } else { - jobIdToActiveJob(jobId) = job - activeJobs += job - finalStage.resultOfJob = Some(job) - val stageIds = jobIdToStageIds(jobId).toArray - val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo)) - listenerBus.post( - SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, properties)) - submitStage(finalStage) - } + jobIdToActiveJob(jobId) = job + activeJobs += job + finalStage.resultOfJob = Some(job) + val stageIds = jobIdToStageIds(jobId).toArray + val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo)) + listenerBus.post( + SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, properties)) + submitStage(finalStage) } submitWaitingStages() } @@ -853,7 +772,6 @@ class DAGScheduler( // Get our pending tasks and remember them in our pendingTasks entry stage.pendingTasks.clear() - // First figure out the indexes of partition ids to compute. val partitionsToCompute: Seq[Int] = { stage match { @@ -914,7 +832,7 @@ class DAGScheduler( partitionsToCompute.map { id => val locs = getPreferredLocs(stage.rdd, id) val part = stage.rdd.partitions(id) - new ShuffleMapTask(stage.id, taskBinary, part, locs) + new ShuffleMapTask(stage.id, stage.latestInfo.attemptId, taskBinary, part, locs) } case stage: ResultStage => @@ -923,7 +841,7 @@ class DAGScheduler( val p: Int = job.partitions(id) val part = stage.rdd.partitions(p) val locs = getPreferredLocs(stage.rdd, p) - new ResultTask(stage.id, taskBinary, part, locs, id) + new ResultTask(stage.id, stage.latestInfo.attemptId, taskBinary, part, locs, id) } } } catch { @@ -1065,10 +983,11 @@ class DAGScheduler( val execId = status.location.executorId logDebug("ShuffleMapTask finished on " + execId) if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) { - logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + execId) + logInfo(s"Ignoring possibly bogus $smt completion from executor $execId") } else { shuffleStage.addOutputLoc(smt.partitionId, status) } + if (runningStages.contains(shuffleStage) && shuffleStage.pendingTasks.isEmpty) { markStageAsFinished(shuffleStage) logInfo("looking for newly runnable stages") @@ -1128,38 +1047,48 @@ class DAGScheduler( val failedStage = stageIdToStage(task.stageId) val mapStage = shuffleToMapStage(shuffleId) - // It is likely that we receive multiple FetchFailed for a single stage (because we have - // multiple tasks running concurrently on different executors). In that case, it is possible - // the fetch failure has already been handled by the scheduler. - if (runningStages.contains(failedStage)) { - logInfo(s"Marking $failedStage (${failedStage.name}) as failed " + - s"due to a fetch failure from $mapStage (${mapStage.name})") - markStageAsFinished(failedStage, Some(failureMessage)) - } + if (failedStage.latestInfo.attemptId != task.stageAttemptId) { + logInfo(s"Ignoring fetch failure from $task as it's from $failedStage attempt" + + s" ${task.stageAttemptId} and there is a more recent attempt for that stage " + + s"(attempt ID ${failedStage.latestInfo.attemptId}) running") + } else { - if (disallowStageRetryForTest) { - abortStage(failedStage, "Fetch failure will not retry stage due to testing config") - } else if (failedStages.isEmpty) { - // Don't schedule an event to resubmit failed stages if failed isn't empty, because - // in that case the event will already have been scheduled. 
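The fetch-failure handling above now compares the failing task's `stageAttemptId` against the stage's latest attempt and ignores failures from stale attempts. A standalone sketch of that check (the `FetchFailure` case class and the map of latest attempts are illustrative, not Spark types):

```scala
// Act on a fetch failure only when it belongs to the stage attempt that is
// currently the latest; anything older is considered stale and skipped.
object StaleAttemptFilter {
  final case class FetchFailure(stageId: Int, stageAttemptId: Int)

  def shouldHandle(failure: FetchFailure, latestAttemptForStage: Map[Int, Int]): Boolean =
    latestAttemptForStage.get(failure.stageId).contains(failure.stageAttemptId)

  def main(args: Array[String]): Unit = {
    val latest = Map(3 -> 2) // stage 3 is currently on attempt 2
    println(shouldHandle(FetchFailure(3, 2), latest)) // true: current attempt
    println(shouldHandle(FetchFailure(3, 1), latest)) // false: stale attempt
  }
}
```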
- // TODO: Cancel running tasks in the stage - logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " + - s"$failedStage (${failedStage.name}) due to fetch failure") - messageScheduler.schedule(new Runnable { - override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages) - }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS) - } - failedStages += failedStage - failedStages += mapStage - // Mark the map whose fetch failed as broken in the map stage - if (mapId != -1) { - mapStage.removeOutputLoc(mapId, bmAddress) - mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress) - } + // It is likely that we receive multiple FetchFailed for a single stage (because we have + // multiple tasks running concurrently on different executors). In that case, it is + // possible the fetch failure has already been handled by the scheduler. + if (runningStages.contains(failedStage)) { + logInfo(s"Marking $failedStage (${failedStage.name}) as failed " + + s"due to a fetch failure from $mapStage (${mapStage.name})") + markStageAsFinished(failedStage, Some(failureMessage)) + } else { + logDebug(s"Received fetch failure from $task, but its from $failedStage which is no " + + s"longer running") + } - // TODO: mark the executor as failed only if there were lots of fetch failures on it - if (bmAddress != null) { - handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch)) + if (disallowStageRetryForTest) { + abortStage(failedStage, "Fetch failure will not retry stage due to testing config") + } else if (failedStages.isEmpty) { + // Don't schedule an event to resubmit failed stages if failed isn't empty, because + // in that case the event will already have been scheduled. + // TODO: Cancel running tasks in the stage + logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " + + s"$failedStage (${failedStage.name}) due to fetch failure") + messageScheduler.schedule(new Runnable { + override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages) + }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS) + } + failedStages += failedStage + failedStages += mapStage + // Mark the map whose fetch failed as broken in the map stage + if (mapId != -1) { + mapStage.removeOutputLoc(mapId, bmAddress) + mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress) + } + + // TODO: mark the executor as failed only if there were lots of fetch failures on it + if (bmAddress != null) { + handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch)) + } } case commitDenied: TaskCommitDenied => @@ -1471,9 +1400,8 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler } private def doOnReceive(event: DAGSchedulerEvent): Unit = event match { - case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) => - dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, - listener, properties) + case JobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties) => + dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties) case StageCancelled(stageId) => dagScheduler.handleStageCancellation(stageId) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index 2b6f7e4205c32..a213d419cf033 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ 
b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -19,7 +19,7 @@ package org.apache.spark.scheduler import java.util.Properties -import scala.collection.mutable.Map +import scala.collection.Map import scala.language.existentials import org.apache.spark._ @@ -40,7 +40,6 @@ private[scheduler] case class JobSubmitted( finalRDD: RDD[_], func: (TaskContext, Iterator[_]) => _, partitions: Array[Int], - allowLocal: Boolean, callSite: CallSite, listener: JobListener, properties: Properties = null) diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 62b05033a9281..5a06ef02f5c57 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -199,6 +199,9 @@ private[spark] class EventLoggingListener( logEvent(event, flushLogger = true) } + // No-op because logging every update would be overkill + override def onBlockUpdated(event: SparkListenerBlockUpdated): Unit = {} + // No-op because logging every update would be overkill override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate): Unit = { } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index c9a124113961f..9c2606e278c54 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -41,11 +41,12 @@ import org.apache.spark.rdd.RDD */ private[spark] class ResultTask[T, U]( stageId: Int, + stageAttemptId: Int, taskBinary: Broadcast[Array[Byte]], partition: Partition, @transient locs: Seq[TaskLocation], val outputId: Int) - extends Task[U](stageId, partition.index) with Serializable { + extends Task[U](stageId, stageAttemptId, partition.index) with Serializable { @transient private[this] val preferredLocs: Seq[TaskLocation] = { if (locs == null) Nil else locs.toSet.toSeq diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index bd3dd23dfe1ac..14c8c00961487 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -40,14 +40,15 @@ import org.apache.spark.shuffle.ShuffleWriter */ private[spark] class ShuffleMapTask( stageId: Int, + stageAttemptId: Int, taskBinary: Broadcast[Array[Byte]], partition: Partition, @transient private var locs: Seq[TaskLocation]) - extends Task[MapStatus](stageId, partition.index) with Logging { + extends Task[MapStatus](stageId, stageAttemptId, partition.index) with Logging { /** A constructor used only in test suites. This does not require passing in an RDD. 
*/ def this(partitionId: Int) { - this(0, null, new Partition { override def index: Int = 0 }, null) + this(0, 0, null, new Partition { override def index: Int = 0 }, null) } @transient private val preferredLocs: Seq[TaskLocation] = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 9620915f495ab..896f1743332f1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -26,7 +26,7 @@ import org.apache.spark.{Logging, TaskEndReason} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler.cluster.ExecutorInfo -import org.apache.spark.storage.BlockManagerId +import org.apache.spark.storage.{BlockManagerId, BlockUpdatedInfo} import org.apache.spark.util.{Distribution, Utils} @DeveloperApi @@ -98,6 +98,9 @@ case class SparkListenerExecutorAdded(time: Long, executorId: String, executorIn case class SparkListenerExecutorRemoved(time: Long, executorId: String, reason: String) extends SparkListenerEvent +@DeveloperApi +case class SparkListenerBlockUpdated(blockUpdatedInfo: BlockUpdatedInfo) extends SparkListenerEvent + /** * Periodic updates from executors. * @param execId executor id @@ -215,6 +218,11 @@ trait SparkListener { * Called when the driver removes an executor. */ def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved) { } + + /** + * Called when the driver receives a block update info. + */ + def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated) { } } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index 61e69ecc08387..04afde33f5aad 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -58,6 +58,8 @@ private[spark] trait SparkListenerBus extends ListenerBus[SparkListener, SparkLi listener.onExecutorAdded(executorAdded) case executorRemoved: SparkListenerExecutorRemoved => listener.onExecutorRemoved(executorRemoved) + case blockUpdated: SparkListenerBlockUpdated => + listener.onBlockUpdated(blockUpdated) case logStart: SparkListenerLogStart => // ignore event log metadata } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 15101c64f0503..d11a00956a9a9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -22,6 +22,7 @@ import java.nio.ByteBuffer import scala.collection.mutable.HashMap +import org.apache.spark.metrics.MetricsSystem import org.apache.spark.{TaskContextImpl, TaskContext} import org.apache.spark.executor.TaskMetrics import org.apache.spark.serializer.SerializerInstance @@ -43,31 +44,46 @@ import org.apache.spark.util.Utils * @param stageId id of the stage this task belongs to * @param partitionId index of the number in the RDD */ -private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable { +private[spark] abstract class Task[T]( + val stageId: Int, + val stageAttemptId: Int, + var partitionId: Int) extends Serializable { + + /** + * The key of the Map is the accumulator id and the value of the Map is the latest accumulator + * local value. 
+ */ + type AccumulatorUpdates = Map[Long, Any] /** * Called by [[Executor]] to run this task. * * @param taskAttemptId an identifier for this task attempt that is unique within a SparkContext. * @param attemptNumber how many times this task has been attempted (0 for the first attempt) - * @return the result of the task + * @return the result of the task along with updates of Accumulators. */ - final def run(taskAttemptId: Long, attemptNumber: Int): T = { + final def run( + taskAttemptId: Long, + attemptNumber: Int, + metricsSystem: MetricsSystem) + : (T, AccumulatorUpdates) = { context = new TaskContextImpl( stageId = stageId, partitionId = partitionId, taskAttemptId = taskAttemptId, attemptNumber = attemptNumber, taskMemoryManager = taskMemoryManager, + metricsSystem = metricsSystem, runningLocally = false) TaskContext.setTaskContext(context) context.taskMetrics.setHostname(Utils.localHostName()) + context.taskMetrics.setAccumulatorsUpdater(context.collectInternalAccumulators) taskThread = Thread.currentThread() if (_killed) { kill(interruptThread = false) } try { - runTask(context) + (runTask(context), context.collectAccumulators()) } finally { context.markTaskCompleted() TaskContext.unset() diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala index 8b2a742b96988..b82c7f3fa54f8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala @@ -20,7 +20,8 @@ package org.apache.spark.scheduler import java.io._ import java.nio.ByteBuffer -import scala.collection.mutable.Map +import scala.collection.Map +import scala.collection.mutable import org.apache.spark.SparkEnv import org.apache.spark.executor.TaskMetrics @@ -69,10 +70,11 @@ class DirectTaskResult[T](var valueBytes: ByteBuffer, var accumUpdates: Map[Long if (numUpdates == 0) { accumUpdates = null } else { - accumUpdates = Map() + val _accumUpdates = mutable.Map[Long, Any]() for (i <- 0 until numUpdates) { - accumUpdates(in.readLong()) = in.readObject() + _accumUpdates(in.readLong()) = in.readObject() } + accumUpdates = _accumUpdates } metrics = in.readObject().asInstanceOf[TaskMetrics] valueObjectDeserialized = false diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index ed3dde0fc3055..1705e7f962de2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -75,9 +75,9 @@ private[spark] class TaskSchedulerImpl( // TaskSetManagers are not thread safe, so any access to one should be synchronized // on this class. 
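With the change above, `Task.run` returns the task result together with a snapshot of accumulator values keyed by accumulator id, rather than the executor reading them from shared state afterwards. A minimal sketch of that contract; the `collectAccumulators` callback and accumulator id are assumptions for illustration:

```scala
// Run a task body and pair its result with the latest accumulator values,
// mirroring the (T, AccumulatorUpdates) tuple introduced above.
object TaskRunSketch {
  type AccumulatorUpdates = Map[Long, Any]

  def run[T](body: => T, collectAccumulators: () => AccumulatorUpdates): (T, AccumulatorUpdates) = {
    val result = body
    (result, collectAccumulators())
  }

  def main(args: Array[String]): Unit = {
    var recordsRead = 0L
    val (value, updates) = run(
      { recordsRead += 3; "done" },
      () => Map(1L -> recordsRead))
    println(value)   // done
    println(updates) // Map(1 -> 3)
  }
}
```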
- val activeTaskSets = new HashMap[String, TaskSetManager] + private val taskSetsByStageIdAndAttempt = new HashMap[Int, HashMap[Int, TaskSetManager]] - val taskIdToTaskSetId = new HashMap[Long, String] + private[scheduler] val taskIdToTaskSetManager = new HashMap[Long, TaskSetManager] val taskIdToExecutorId = new HashMap[Long, String] @volatile private var hasReceivedTask = false @@ -162,7 +162,17 @@ private[spark] class TaskSchedulerImpl( logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks") this.synchronized { val manager = createTaskSetManager(taskSet, maxTaskFailures) - activeTaskSets(taskSet.id) = manager + val stage = taskSet.stageId + val stageTaskSets = + taskSetsByStageIdAndAttempt.getOrElseUpdate(stage, new HashMap[Int, TaskSetManager]) + stageTaskSets(taskSet.stageAttemptId) = manager + val conflictingTaskSet = stageTaskSets.exists { case (_, ts) => + ts.taskSet != taskSet && !ts.isZombie + } + if (conflictingTaskSet) { + throw new IllegalStateException(s"more than one active taskSet for stage $stage:" + + s" ${stageTaskSets.toSeq.map{_._2.taskSet.id}.mkString(",")}") + } schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties) if (!isLocal && !hasReceivedTask) { @@ -192,19 +202,21 @@ private[spark] class TaskSchedulerImpl( override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized { logInfo("Cancelling stage " + stageId) - activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) => - // There are two possible cases here: - // 1. The task set manager has been created and some tasks have been scheduled. - // In this case, send a kill signal to the executors to kill the task and then abort - // the stage. - // 2. The task set manager has been created but no tasks has been scheduled. In this case, - // simply abort the stage. - tsm.runningTasksSet.foreach { tid => - val execId = taskIdToExecutorId(tid) - backend.killTask(tid, execId, interruptThread) + taskSetsByStageIdAndAttempt.get(stageId).foreach { attempts => + attempts.foreach { case (_, tsm) => + // There are two possible cases here: + // 1. The task set manager has been created and some tasks have been scheduled. + // In this case, send a kill signal to the executors to kill the task and then abort + // the stage. + // 2. The task set manager has been created but no tasks has been scheduled. In this case, + // simply abort the stage. + tsm.runningTasksSet.foreach { tid => + val execId = taskIdToExecutorId(tid) + backend.killTask(tid, execId, interruptThread) + } + tsm.abort("Stage %s cancelled".format(stageId)) + logInfo("Stage %d was cancelled".format(stageId)) } - tsm.abort("Stage %s cancelled".format(stageId)) - logInfo("Stage %d was cancelled".format(stageId)) } } @@ -214,7 +226,12 @@ private[spark] class TaskSchedulerImpl( * cleaned up. 
*/ def taskSetFinished(manager: TaskSetManager): Unit = synchronized { - activeTaskSets -= manager.taskSet.id + taskSetsByStageIdAndAttempt.get(manager.taskSet.stageId).foreach { taskSetsForStage => + taskSetsForStage -= manager.taskSet.stageAttemptId + if (taskSetsForStage.isEmpty) { + taskSetsByStageIdAndAttempt -= manager.taskSet.stageId + } + } manager.parent.removeSchedulable(manager) logInfo("Removed TaskSet %s, whose tasks have all completed, from pool %s" .format(manager.taskSet.id, manager.parent.name)) @@ -235,7 +252,7 @@ private[spark] class TaskSchedulerImpl( for (task <- taskSet.resourceOffer(execId, host, maxLocality)) { tasks(i) += task val tid = task.taskId - taskIdToTaskSetId(tid) = taskSet.taskSet.id + taskIdToTaskSetManager(tid) = taskSet taskIdToExecutorId(tid) = execId executorsByHost(host) += execId availableCpus(i) -= CPUS_PER_TASK @@ -319,26 +336,24 @@ private[spark] class TaskSchedulerImpl( failedExecutor = Some(execId) } } - taskIdToTaskSetId.get(tid) match { - case Some(taskSetId) => + taskIdToTaskSetManager.get(tid) match { + case Some(taskSet) => if (TaskState.isFinished(state)) { - taskIdToTaskSetId.remove(tid) + taskIdToTaskSetManager.remove(tid) taskIdToExecutorId.remove(tid) } - activeTaskSets.get(taskSetId).foreach { taskSet => - if (state == TaskState.FINISHED) { - taskSet.removeRunningTask(tid) - taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData) - } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) { - taskSet.removeRunningTask(tid) - taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData) - } + if (state == TaskState.FINISHED) { + taskSet.removeRunningTask(tid) + taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData) + } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) { + taskSet.removeRunningTask(tid) + taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData) } case None => logError( ("Ignoring update with state %s for TID %s because its task set is gone (this is " + - "likely the result of receiving duplicate task finished status updates)") - .format(state, tid)) + "likely the result of receiving duplicate task finished status updates)") + .format(state, tid)) } } catch { case e: Exception => logError("Exception in statusUpdate", e) @@ -363,9 +378,9 @@ private[spark] class TaskSchedulerImpl( val metricsWithStageIds: Array[(Long, Int, Int, TaskMetrics)] = synchronized { taskMetrics.flatMap { case (id, metrics) => - taskIdToTaskSetId.get(id) - .flatMap(activeTaskSets.get) - .map(taskSetMgr => (id, taskSetMgr.stageId, taskSetMgr.taskSet.attempt, metrics)) + taskIdToTaskSetManager.get(id).map { taskSetMgr => + (id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId, metrics) + } } } dagScheduler.executorHeartbeatReceived(execId, metricsWithStageIds, blockManagerId) @@ -397,9 +412,12 @@ private[spark] class TaskSchedulerImpl( def error(message: String) { synchronized { - if (activeTaskSets.nonEmpty) { + if (taskSetsByStageIdAndAttempt.nonEmpty) { // Have each task set throw a SparkException with the error - for ((taskSetId, manager) <- activeTaskSets) { + for { + attempts <- taskSetsByStageIdAndAttempt.values + manager <- attempts.values + } { try { manager.abort(message) } catch { @@ -520,6 +538,17 @@ private[spark] class TaskSchedulerImpl( override def applicationAttemptId(): Option[String] = backend.applicationAttemptId() + private[scheduler] def taskSetManagerForAttempt( + stageId: Int, + stageAttemptId: Int): 
Option[TaskSetManager] = { + for { + attempts <- taskSetsByStageIdAndAttempt.get(stageId) + manager <- attempts.get(stageAttemptId) + } yield { + manager + } + } + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala index c3ad325156f53..be8526ba9b94f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala @@ -26,10 +26,10 @@ import java.util.Properties private[spark] class TaskSet( val tasks: Array[Task[_]], val stageId: Int, - val attempt: Int, + val stageAttemptId: Int, val priority: Int, val properties: Properties) { - val id: String = stageId + "." + attempt + val id: String = stageId + "." + stageAttemptId override def toString: String = "TaskSet " + id } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 7c7f70d8a193b..c65b3e517773e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -169,9 +169,12 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // Make fake resource offers on all executors private def makeOffers() { - launchTasks(scheduler.resourceOffers(executorDataMap.map { case (id, executorData) => + // Filter out executors under killing + val activeExecutors = executorDataMap.filterKeys(!executorsPendingToRemove.contains(_)) + val workOffers = activeExecutors.map { case (id, executorData) => new WorkerOffer(id, executorData.executorHost, executorData.freeCores) - }.toSeq)) + }.toSeq + launchTasks(scheduler.resourceOffers(workOffers)) } override def onDisconnected(remoteAddress: RpcAddress): Unit = { @@ -181,9 +184,13 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // Make fake resource offers on just one executor private def makeOffers(executorId: String) { - val executorData = executorDataMap(executorId) - launchTasks(scheduler.resourceOffers( - Seq(new WorkerOffer(executorId, executorData.executorHost, executorData.freeCores)))) + // Filter out executors under killing + if (!executorsPendingToRemove.contains(executorId)) { + val executorData = executorDataMap(executorId) + val workOffers = Seq( + new WorkerOffer(executorId, executorData.executorHost, executorData.freeCores)) + launchTasks(scheduler.resourceOffers(workOffers)) + } } // Launch tasks returned by a set of resource offers @@ -191,15 +198,14 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp for (task <- tasks.flatten) { val serializedTask = ser.serialize(task) if (serializedTask.limit >= akkaFrameSize - AkkaUtils.reservedSizeBytes) { - val taskSetId = scheduler.taskIdToTaskSetId(task.taskId) - scheduler.activeTaskSets.get(taskSetId).foreach { taskSet => + scheduler.taskIdToTaskSetManager.get(task.taskId).foreach { taskSetMgr => try { var msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " + "spark.akka.frameSize (%d bytes) - reserved (%d bytes). Consider increasing " + "spark.akka.frameSize or using broadcast variables for large values." 
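The `taskSetManagerForAttempt` lookup above walks a map keyed by stage id and then by stage attempt id using a for-comprehension over `Option`s, yielding `None` as soon as either level is missing. A standalone sketch of the same lookup shape, with plain immutable maps standing in for the scheduler's `HashMap`s:

```scala
// Resolve (stageId, attemptId) to a manager, or None if either key is absent.
object AttemptLookup {
  type Managers = Map[Int, Map[Int, String]]

  def managerForAttempt(m: Managers, stageId: Int, attemptId: Int): Option[String] =
    for {
      attempts <- m.get(stageId)
      manager <- attempts.get(attemptId)
    } yield manager

  def main(args: Array[String]): Unit = {
    val managers: Managers = Map(1 -> Map(0 -> "TaskSetManager(1.0)"))
    println(managerForAttempt(managers, 1, 0)) // Some(TaskSetManager(1.0))
    println(managerForAttempt(managers, 1, 3)) // None
    println(managerForAttempt(managers, 9, 0)) // None
  }
}
```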
msg = msg.format(task.taskId, task.index, serializedTask.limit, akkaFrameSize, AkkaUtils.reservedSizeBytes) - taskSet.abort(msg) + taskSetMgr.abort(msg) } catch { case e: Exception => logError("Exception in error callback", e) } @@ -371,26 +377,36 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp /** * Request that the cluster manager kill the specified executors. - * Return whether the kill request is acknowledged. + * @return whether the kill request is acknowledged. */ final override def killExecutors(executorIds: Seq[String]): Boolean = synchronized { + killExecutors(executorIds, replace = false) + } + + /** + * Request that the cluster manager kill the specified executors. + * + * @param executorIds identifiers of executors to kill + * @param replace whether to replace the killed executors with new ones + * @return whether the kill request is acknowledged. + */ + final def killExecutors(executorIds: Seq[String], replace: Boolean): Boolean = synchronized { logInfo(s"Requesting to kill executor(s) ${executorIds.mkString(", ")}") - val filteredExecutorIds = new ArrayBuffer[String] - executorIds.foreach { id => - if (executorDataMap.contains(id)) { - filteredExecutorIds += id - } else { - logWarning(s"Executor to kill $id does not exist!") - } + val (knownExecutors, unknownExecutors) = executorIds.partition(executorDataMap.contains) + unknownExecutors.foreach { id => + logWarning(s"Executor to kill $id does not exist!") + } + + // If we do not wish to replace the executors we kill, sync the target number of executors + // with the cluster manager to avoid allocating new ones. When computing the new target, + // take into account executors that are pending to be added or removed. + if (!replace) { + doRequestTotalExecutors(numExistingExecutors + numPendingExecutors + - executorsPendingToRemove.size - knownExecutors.size) } - // Killing executors means effectively that we want less executors than before, so also update - // the target number of executors to avoid having the backend allocate new ones. 
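The `killExecutors(executorIds, replace)` overload above splits the requested ids into known and unknown executors and only lowers the target executor count when the killed executors are not going to be replaced. A pure-function sketch of that bookkeeping; all parameter names here are illustrative stand-ins for the backend's fields:

```scala
// Decide which executors to kill and what new target count (if any) to sync
// with the cluster manager.
object KillExecutorsSketch {
  def killExecutors(
      executorIds: Seq[String],
      known: Set[String],
      numExisting: Int,
      numPending: Int,
      pendingToRemove: Int,
      replace: Boolean): (Seq[String], Option[Int]) = {
    val (knownExecutors, unknownExecutors) = executorIds.partition(known.contains)
    unknownExecutors.foreach(id => println(s"Executor to kill $id does not exist!"))
    // Only shrink the target when the killed executors should not be replaced;
    // otherwise the cluster manager would simply allocate new ones.
    val newTarget =
      if (replace) None
      else Some(numExisting + numPending - pendingToRemove - knownExecutors.size)
    (knownExecutors, newTarget)
  }

  def main(args: Array[String]): Unit = {
    val (toKill, target) =
      killExecutors(Seq("1", "2", "9"), Set("1", "2", "3"), numExisting = 3,
        numPending = 0, pendingToRemove = 0, replace = false)
    println(toKill) // List(1, 2)
    println(target) // Some(1)
  }
}
```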
- val newTotal = (numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size - - filteredExecutorIds.size) - doRequestTotalExecutors(newTotal) - executorsPendingToRemove ++= filteredExecutorIds - doKillExecutors(filteredExecutorIds) + executorsPendingToRemove ++= knownExecutors + doKillExecutors(knownExecutors) } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index cbade131494bc..b7fde0d9b3265 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -18,8 +18,8 @@ package org.apache.spark.scheduler.cluster.mesos import java.io.File -import java.util.{List => JList, Collections} import java.util.concurrent.locks.ReentrantLock +import java.util.{Collections, List => JList} import scala.collection.JavaConversions._ import scala.collection.mutable.{HashMap, HashSet} @@ -27,12 +27,11 @@ import scala.collection.mutable.{HashMap, HashSet} import com.google.common.collect.HashBiMap import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} import org.apache.mesos.{Scheduler => MScheduler, _} -import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} -import org.apache.spark.{SparkContext, SparkEnv, SparkException, TaskState} import org.apache.spark.rpc.RpcAddress import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import org.apache.spark.util.Utils +import org.apache.spark.{SparkContext, SparkEnv, SparkException, TaskState} /** * A SchedulerBackend that runs tasks on Mesos, but uses "coarse-grained" tasks, where it holds @@ -69,7 +68,7 @@ private[spark] class CoarseMesosSchedulerBackend( /** * The total number of executors we aim to have. Undefined when not using dynamic allocation - * and before the ExecutorAllocatorManager calls [[doRequesTotalExecutors]]. + * and before the ExecutorAllocatorManager calls [[doRequestTotalExecutors]]. */ private var executorLimitOption: Option[Int] = None @@ -103,8 +102,9 @@ private[spark] class CoarseMesosSchedulerBackend( override def start() { super.start() - val fwInfo = FrameworkInfo.newBuilder().setUser(sc.sparkUser).setName(sc.appName).build() - startScheduler(master, CoarseMesosSchedulerBackend.this, fwInfo) + val driver = createSchedulerDriver( + master, CoarseMesosSchedulerBackend.this, sc.sparkUser, sc.appName, sc.conf) + startScheduler(driver) } def createCommand(offer: Offer, numCores: Int, taskId: Int): CommandInfo = { @@ -224,24 +224,29 @@ private[spark] class CoarseMesosSchedulerBackend( taskIdToSlaveId(taskId) = slaveId slaveIdsWithExecutors += slaveId coresByTaskId(taskId) = cpusToUse - val task = MesosTaskInfo.newBuilder() + // Gather cpu resources from the available resources and use them in the task. 
+ val (remainingResources, cpuResourcesToUse) = + partitionResources(offer.getResourcesList, "cpus", cpusToUse) + val (_, memResourcesToUse) = + partitionResources(remainingResources, "mem", calculateTotalMemory(sc)) + val taskBuilder = MesosTaskInfo.newBuilder() .setTaskId(TaskID.newBuilder().setValue(taskId.toString).build()) .setSlaveId(offer.getSlaveId) .setCommand(createCommand(offer, cpusToUse + extraCoresPerSlave, taskId)) .setName("Task " + taskId) - .addResources(createResource("cpus", cpusToUse)) - .addResources(createResource("mem", calculateTotalMemory(sc))) + .addAllResources(cpuResourcesToUse) + .addAllResources(memResourcesToUse) sc.conf.getOption("spark.mesos.executor.docker.image").foreach { image => MesosSchedulerBackendUtil - .setupContainerBuilderDockerInfo(image, sc.conf, task.getContainerBuilder) + .setupContainerBuilderDockerInfo(image, sc.conf, taskBuilder.getContainerBuilder()) } // accept the offer and launch the task logDebug(s"Accepting offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") d.launchTasks( Collections.singleton(offer.getId), - Collections.singleton(task.build()), filters) + Collections.singleton(taskBuilder.build()), filters) } else { // Decline the offer logDebug(s"Declining offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") @@ -255,7 +260,7 @@ private[spark] class CoarseMesosSchedulerBackend( override def statusUpdate(d: SchedulerDriver, status: TaskStatus) { val taskId = status.getTaskId.getValue.toInt val state = status.getState - logInfo("Mesos task " + taskId + " is now " + state) + logInfo(s"Mesos task $taskId is now $state") stateLock.synchronized { if (TaskState.isFinished(TaskState.fromMesos(state))) { val slaveId = taskIdToSlaveId(taskId) @@ -270,7 +275,7 @@ private[spark] class CoarseMesosSchedulerBackend( if (TaskState.isFailed(TaskState.fromMesos(state))) { failuresBySlaveId(slaveId) = failuresBySlaveId.getOrElse(slaveId, 0) + 1 if (failuresBySlaveId(slaveId) >= MAX_SLAVE_FAILURES) { - logInfo("Blacklisting Mesos slave " + slaveId + " due to too many failures; " + + logInfo(s"Blacklisting Mesos slave $slaveId due to too many failures; " + "is Spark installed on it?") } } @@ -282,7 +287,7 @@ private[spark] class CoarseMesosSchedulerBackend( } override def error(d: SchedulerDriver, message: String) { - logError("Mesos error: " + message) + logError(s"Mesos error: $message") scheduler.error(message) } @@ -323,7 +328,7 @@ private[spark] class CoarseMesosSchedulerBackend( } override def slaveLost(d: SchedulerDriver, slaveId: SlaveID): Unit = { - logInfo("Mesos slave lost: " + slaveId.getValue) + logInfo(s"Mesos slave lost: ${slaveId.getValue}") executorTerminated(d, slaveId.getValue, "Mesos slave lost: " + slaveId.getValue) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index d3a20f822176e..f078547e71352 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -295,20 +295,24 @@ private[spark] class MesosClusterScheduler( def start(): Unit = { // TODO: Implement leader election to make sure only one framework running in the cluster. 
val fwId = schedulerState.fetch[String]("frameworkId") - val builder = FrameworkInfo.newBuilder() - .setUser(Utils.getCurrentUserName()) - .setName(appName) - .setWebuiUrl(frameworkUrl) - .setCheckpoint(true) - .setFailoverTimeout(Integer.MAX_VALUE) // Setting to max so tasks keep running on crash fwId.foreach { id => - builder.setId(FrameworkID.newBuilder().setValue(id).build()) frameworkId = id } recoverState() metricsSystem.registerSource(new MesosClusterSchedulerSource(this)) metricsSystem.start() - startScheduler(master, MesosClusterScheduler.this, builder.build()) + val driver = createSchedulerDriver( + master, + MesosClusterScheduler.this, + Utils.getCurrentUserName(), + appName, + conf, + Some(frameworkUrl), + Some(true), + Some(Integer.MAX_VALUE), + fwId) + + startScheduler(driver) ready = true } @@ -449,12 +453,8 @@ private[spark] class MesosClusterScheduler( offer.cpu -= driverCpu offer.mem -= driverMem val taskId = TaskID.newBuilder().setValue(submission.submissionId).build() - val cpuResource = Resource.newBuilder() - .setName("cpus").setType(Value.Type.SCALAR) - .setScalar(Value.Scalar.newBuilder().setValue(driverCpu)).build() - val memResource = Resource.newBuilder() - .setName("mem").setType(Value.Type.SCALAR) - .setScalar(Value.Scalar.newBuilder().setValue(driverMem)).build() + val cpuResource = createResource("cpus", driverCpu) + val memResource = createResource("mem", driverMem) val commandInfo = buildDriverCommand(submission) val appName = submission.schedulerProperties("spark.app.name") val taskInfo = TaskInfo.newBuilder() diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index d72e2af456e15..3f63ec1c5832f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -32,6 +32,7 @@ import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.util.Utils + /** * A SchedulerBackend for running fine-grained tasks on Mesos. Each Spark task is mapped to a * separate Mesos task, allowing multiple applications to share cluster nodes both in space (tasks @@ -45,8 +46,8 @@ private[spark] class MesosSchedulerBackend( with MScheduler with MesosSchedulerUtils { - // Which slave IDs we have executors on - val slaveIdsWithExecutors = new HashSet[String] + // Stores the slave ids that has launched a Mesos executor. + val slaveIdToExecutorInfo = new HashMap[String, MesosExecutorInfo] val taskIdToSlaveId = new HashMap[Long, String] // An ExecutorInfo for our tasks @@ -66,12 +67,21 @@ private[spark] class MesosSchedulerBackend( @volatile var appId: String = _ override def start() { - val fwInfo = FrameworkInfo.newBuilder().setUser(sc.sparkUser).setName(sc.appName).build() classLoader = Thread.currentThread.getContextClassLoader - startScheduler(master, MesosSchedulerBackend.this, fwInfo) + val driver = createSchedulerDriver( + master, MesosSchedulerBackend.this, sc.sparkUser, sc.appName, sc.conf) + startScheduler(driver) } - def createExecutorInfo(execId: String): MesosExecutorInfo = { + /** + * Creates a MesosExecutorInfo that is used to launch a Mesos executor. + * @param availableResources Available resources that is offered by Mesos + * @param execId The executor id to assign to this new executor. 
+ * @return A tuple of the new mesos executor info and the remaining available resources. + */ + def createExecutorInfo( + availableResources: JList[Resource], + execId: String): (MesosExecutorInfo, JList[Resource]) = { val executorSparkHome = sc.conf.getOption("spark.mesos.executor.home") .orElse(sc.getSparkHome()) // Fall back to driver Spark home for backward compatibility .getOrElse { @@ -115,32 +125,25 @@ private[spark] class MesosSchedulerBackend( command.setValue(s"cd ${basename}*; $prefixEnv ./bin/spark-class $executorBackendName") command.addUris(CommandInfo.URI.newBuilder().setValue(uri.get)) } - val cpus = Resource.newBuilder() - .setName("cpus") - .setType(Value.Type.SCALAR) - .setScalar(Value.Scalar.newBuilder() - .setValue(mesosExecutorCores).build()) - .build() - val memory = Resource.newBuilder() - .setName("mem") - .setType(Value.Type.SCALAR) - .setScalar( - Value.Scalar.newBuilder() - .setValue(calculateTotalMemory(sc)).build()) - .build() - val executorInfo = MesosExecutorInfo.newBuilder() + val builder = MesosExecutorInfo.newBuilder() + val (resourcesAfterCpu, usedCpuResources) = + partitionResources(availableResources, "cpus", scheduler.CPUS_PER_TASK) + val (resourcesAfterMem, usedMemResources) = + partitionResources(resourcesAfterCpu, "mem", calculateTotalMemory(sc)) + + builder.addAllResources(usedCpuResources) + builder.addAllResources(usedMemResources) + val executorInfo = builder .setExecutorId(ExecutorID.newBuilder().setValue(execId).build()) .setCommand(command) .setData(ByteString.copyFrom(createExecArg())) - .addResources(cpus) - .addResources(memory) sc.conf.getOption("spark.mesos.executor.docker.image").foreach { image => MesosSchedulerBackendUtil .setupContainerBuilderDockerInfo(image, sc.conf, executorInfo.getContainerBuilder()) } - executorInfo.build() + (executorInfo.build(), resourcesAfterMem) } /** @@ -183,6 +186,18 @@ private[spark] class MesosSchedulerBackend( override def reregistered(d: SchedulerDriver, masterInfo: MasterInfo) {} + private def getTasksSummary(tasks: JArrayList[MesosTaskInfo]): String = { + val builder = new StringBuilder + tasks.foreach { t => + builder.append("Task id: ").append(t.getTaskId.getValue).append("\n") + .append("Slave id: ").append(t.getSlaveId.getValue).append("\n") + .append("Task resources: ").append(t.getResourcesList).append("\n") + .append("Executor resources: ").append(t.getExecutor.getResourcesList) + .append("---------------------------------------------\n") + } + builder.toString() + } + /** * Method called by Mesos to offer resources on slaves. We respond by asking our active task sets * for tasks in order of priority. 
We fill each node with tasks in a round-robin manner so that @@ -207,7 +222,7 @@ private[spark] class MesosSchedulerBackend( val meetsRequirements = (meetsConstraints && meetsMemoryRequirements && meetsCPURequirements) || - (slaveIdsWithExecutors.contains(slaveId) && cpus >= scheduler.CPUS_PER_TASK) + (slaveIdToExecutorInfo.contains(slaveId) && cpus >= scheduler.CPUS_PER_TASK) // add some debug messaging val debugstr = if (meetsRequirements) "Accepting" else "Declining" @@ -221,7 +236,7 @@ private[spark] class MesosSchedulerBackend( unUsableOffers.foreach(o => d.declineOffer(o.getId)) val workerOffers = usableOffers.map { o => - val cpus = if (slaveIdsWithExecutors.contains(o.getSlaveId.getValue)) { + val cpus = if (slaveIdToExecutorInfo.contains(o.getSlaveId.getValue)) { getResource(o.getResourcesList, "cpus").toInt } else { // If the Mesos executor has not been started on this slave yet, set aside a few @@ -236,6 +251,10 @@ private[spark] class MesosSchedulerBackend( val slaveIdToOffer = usableOffers.map(o => o.getSlaveId.getValue -> o).toMap val slaveIdToWorkerOffer = workerOffers.map(o => o.executorId -> o).toMap + val slaveIdToResources = new HashMap[String, JList[Resource]]() + usableOffers.foreach { o => + slaveIdToResources(o.getSlaveId.getValue) = o.getResourcesList + } val mesosTasks = new HashMap[String, JArrayList[MesosTaskInfo]] @@ -245,15 +264,19 @@ private[spark] class MesosSchedulerBackend( val acceptedOffers = scheduler.resourceOffers(workerOffers).filter(!_.isEmpty) acceptedOffers .foreach { offer => - offer.foreach { taskDesc => - val slaveId = taskDesc.executorId - slaveIdsWithExecutors += slaveId - slavesIdsOfAcceptedOffers += slaveId - taskIdToSlaveId(taskDesc.taskId) = slaveId - mesosTasks.getOrElseUpdate(slaveId, new JArrayList[MesosTaskInfo]) - .add(createMesosTask(taskDesc, slaveId)) + offer.foreach { taskDesc => + val slaveId = taskDesc.executorId + slavesIdsOfAcceptedOffers += slaveId + taskIdToSlaveId(taskDesc.taskId) = slaveId + val (mesosTask, remainingResources) = createMesosTask( + taskDesc, + slaveIdToResources(slaveId), + slaveId) + mesosTasks.getOrElseUpdate(slaveId, new JArrayList[MesosTaskInfo]) + .add(mesosTask) + slaveIdToResources(slaveId) = remainingResources + } } - } // Reply to the offers val filters = Filters.newBuilder().setRefuseSeconds(1).build() // TODO: lower timeout? 
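Editor's note: the resourceOffers/createMesosTask rework above charges each launched task against the concrete resources remaining in its offer instead of synthesizing a bare scalar "cpus" value. The sketch below is illustrative only: it models offers with a hypothetical SimpleResource case class rather than the Mesos protobuf Resource type, but it mirrors the role-aware split that partitionResources (added to MesosSchedulerUtils later in this patch) performs, so the bookkeeping is easier to follow in isolation.

    // Illustrative only: a self-contained model of the role-aware resource accounting used when
    // launching Mesos tasks. SimpleResource is a stand-in for org.apache.mesos.Protos.Resource.
    object ResourcePartitionSketch {
      final case class SimpleResource(name: String, role: String, amount: Double)

      /** Split `resources` into (remaining, used), taking up to `amountToUse` of `name`. */
      def partition(
          resources: Seq[SimpleResource],
          name: String,
          amountToUse: Double): (Seq[SimpleResource], Seq[SimpleResource]) = {
        var remain = amountToUse
        val used = Seq.newBuilder[SimpleResource]
        val remaining = resources.map { r =>
          if (remain > 0 && r.name == name && r.amount > 0) {
            val usage = math.min(remain, r.amount)
            used += r.copy(amount = usage)
            remain -= usage
            r.copy(amount = r.amount - usage)
          } else {
            r
          }
        }
        // Drop entries that have been fully consumed, like the real helper does.
        (remaining.filter(_.amount > 0), used.result())
      }

      def main(args: Array[String]): Unit = {
        // An offer advertising cpus under a dedicated role and the wildcard role.
        val offer = Seq(
          SimpleResource("cpus", "spark", 2.0),
          SimpleResource("cpus", "*", 6.0),
          SimpleResource("mem", "*", 4096.0))

        // First carve out what the executor needs, then a task's share from what is left,
        // mirroring createExecutorInfo followed by createMesosTask in this patch.
        val (afterExecutor, executorCpus) = partition(offer, "cpus", 1.0)
        val (afterTask, taskCpus) = partition(afterExecutor, "cpus", 4.0)

        println(s"executor cpus: $executorCpus") // takes 1.0 from role "spark" first
        println(s"task cpus:     $taskCpus")     // 1.0 more from "spark", then 3.0 from "*"
        println(s"left over:     $afterTask")
      }
    }

Resources offered under a dedicated role are consumed before the wildcard role, which is why the per-role remainders have to be threaded through each launch instead of keeping a single running total.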
@@ -264,6 +287,7 @@ private[spark] class MesosSchedulerBackend( // TODO: Add support for log urls for Mesos new ExecutorInfo(o.host, o.cores, Map.empty))) ) + logTrace(s"Launching Mesos tasks on slave '$slaveId', tasks:\n${getTasksSummary(tasks)}") d.launchTasks(Collections.singleton(slaveIdToOffer(slaveId).getId), tasks, filters) } @@ -272,26 +296,32 @@ private[spark] class MesosSchedulerBackend( for (o <- usableOffers if !slavesIdsOfAcceptedOffers.contains(o.getSlaveId.getValue)) { d.declineOffer(o.getId) } - } } - /** Turn a Spark TaskDescription into a Mesos task */ - def createMesosTask(task: TaskDescription, slaveId: String): MesosTaskInfo = { + /** Turn a Spark TaskDescription into a Mesos task and also resources unused by the task */ + def createMesosTask( + task: TaskDescription, + resources: JList[Resource], + slaveId: String): (MesosTaskInfo, JList[Resource]) = { val taskId = TaskID.newBuilder().setValue(task.taskId.toString).build() - val cpuResource = Resource.newBuilder() - .setName("cpus") - .setType(Value.Type.SCALAR) - .setScalar(Value.Scalar.newBuilder().setValue(scheduler.CPUS_PER_TASK).build()) - .build() - MesosTaskInfo.newBuilder() + val (executorInfo, remainingResources) = if (slaveIdToExecutorInfo.contains(slaveId)) { + (slaveIdToExecutorInfo(slaveId), resources) + } else { + createExecutorInfo(resources, slaveId) + } + slaveIdToExecutorInfo(slaveId) = executorInfo + val (finalResources, cpuResources) = + partitionResources(remainingResources, "cpus", scheduler.CPUS_PER_TASK) + val taskInfo = MesosTaskInfo.newBuilder() .setTaskId(taskId) .setSlaveId(SlaveID.newBuilder().setValue(slaveId).build()) - .setExecutor(createExecutorInfo(slaveId)) + .setExecutor(executorInfo) .setName(task.name) - .addResources(cpuResource) + .addAllResources(cpuResources) .setData(MesosTaskLaunchData(task.serializedTask, task.attemptNumber).toByteString) .build() + (taskInfo, finalResources) } override def statusUpdate(d: SchedulerDriver, status: TaskStatus) { @@ -337,7 +367,7 @@ private[spark] class MesosSchedulerBackend( private def removeExecutor(slaveId: String, reason: String) = { synchronized { listenerBus.post(SparkListenerExecutorRemoved(System.currentTimeMillis(), slaveId, reason)) - slaveIdsWithExecutors -= slaveId + slaveIdToExecutorInfo -= slaveId } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index 925702e63afd3..c04920e4f5873 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -21,15 +21,17 @@ import java.util.{List => JList} import java.util.concurrent.CountDownLatch import scala.collection.JavaConversions._ +import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal import com.google.common.base.Splitter import org.apache.mesos.{MesosSchedulerDriver, SchedulerDriver, Scheduler, Protos} import org.apache.mesos.Protos._ -import org.apache.mesos.protobuf.GeneratedMessage -import org.apache.spark.{Logging, SparkContext} +import org.apache.mesos.protobuf.{ByteString, GeneratedMessage} +import org.apache.spark.{SparkException, SparkConf, Logging, SparkContext} import org.apache.spark.util.Utils + /** * Shared trait for implementing a Mesos Scheduler. This holds common state and helper * methods and Mesos scheduler will use. 
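Editor's note: the next hunk replaces the FrameworkInfo parameter of startScheduler with a createSchedulerDriver helper that also honours the new spark.mesos.principal, spark.mesos.secret and spark.mesos.role settings. A minimal sketch of driving those options from an application follows; the master URL and credential values are placeholders and assume a Mesos master configured with matching credentials.

    import org.apache.spark.{SparkConf, SparkContext}

    // Illustrative configuration only; "mesos://master:5050" and the credential strings are placeholders.
    object MesosAuthExample {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf()
          .setAppName("mesos-auth-example")
          .setMaster("mesos://master:5050")
          .set("spark.mesos.principal", "spark-principal") // becomes FrameworkInfo.principal
          .set("spark.mesos.secret", "spark-secret")       // sent to the master as the framework Credential
          .set("spark.mesos.role", "spark")                // resources are requested under this role

        // createSchedulerDriver (added below) fails fast if a secret is set without a principal.
        val sc = new SparkContext(conf)
        try {
          println(sc.parallelize(1 to 100).count())
        } finally {
          sc.stop()
        }
      }
    }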
@@ -42,13 +44,63 @@ private[mesos] trait MesosSchedulerUtils extends Logging { protected var mesosDriver: SchedulerDriver = null /** - * Starts the MesosSchedulerDriver with the provided information. This method returns - * only after the scheduler has registered with Mesos. - * @param masterUrl Mesos master connection URL - * @param scheduler Scheduler object - * @param fwInfo FrameworkInfo to pass to the Mesos master + * Creates a new MesosSchedulerDriver that communicates to the Mesos master. + * @param masterUrl The url to connect to Mesos master + * @param scheduler the scheduler class to receive scheduler callbacks + * @param sparkUser User to impersonate with when running tasks + * @param appName The framework name to display on the Mesos UI + * @param conf Spark configuration + * @param webuiUrl The WebUI url to link from Mesos UI + * @param checkpoint Option to checkpoint tasks for failover + * @param failoverTimeout Duration Mesos master expect scheduler to reconnect on disconnect + * @param frameworkId The id of the new framework */ - def startScheduler(masterUrl: String, scheduler: Scheduler, fwInfo: FrameworkInfo): Unit = { + protected def createSchedulerDriver( + masterUrl: String, + scheduler: Scheduler, + sparkUser: String, + appName: String, + conf: SparkConf, + webuiUrl: Option[String] = None, + checkpoint: Option[Boolean] = None, + failoverTimeout: Option[Double] = None, + frameworkId: Option[String] = None): SchedulerDriver = { + val fwInfoBuilder = FrameworkInfo.newBuilder().setUser(sparkUser).setName(appName) + val credBuilder = Credential.newBuilder() + webuiUrl.foreach { url => fwInfoBuilder.setWebuiUrl(url) } + checkpoint.foreach { checkpoint => fwInfoBuilder.setCheckpoint(checkpoint) } + failoverTimeout.foreach { timeout => fwInfoBuilder.setFailoverTimeout(timeout) } + frameworkId.foreach { id => + fwInfoBuilder.setId(FrameworkID.newBuilder().setValue(id).build()) + } + conf.getOption("spark.mesos.principal").foreach { principal => + fwInfoBuilder.setPrincipal(principal) + credBuilder.setPrincipal(principal) + } + conf.getOption("spark.mesos.secret").foreach { secret => + credBuilder.setSecret(ByteString.copyFromUtf8(secret)) + } + if (credBuilder.hasSecret && !fwInfoBuilder.hasPrincipal) { + throw new SparkException( + "spark.mesos.principal must be configured when spark.mesos.secret is set") + } + conf.getOption("spark.mesos.role").foreach { role => + fwInfoBuilder.setRole(role) + } + if (credBuilder.hasPrincipal) { + new MesosSchedulerDriver( + scheduler, fwInfoBuilder.build(), masterUrl, credBuilder.build()) + } else { + new MesosSchedulerDriver(scheduler, fwInfoBuilder.build(), masterUrl) + } + } + + /** + * Starts the MesosSchedulerDriver and stores the current running driver to this new instance. + * This driver is expected to not be running. + * This method returns only after the scheduler has registered with Mesos. 
+ */ + def startScheduler(newDriver: SchedulerDriver): Unit = { synchronized { if (mesosDriver != null) { registerLatch.await() @@ -59,11 +111,11 @@ private[mesos] trait MesosSchedulerUtils extends Logging { setDaemon(true) override def run() { - mesosDriver = new MesosSchedulerDriver(scheduler, fwInfo, masterUrl) + mesosDriver = newDriver try { val ret = mesosDriver.run() logInfo("driver.run() returned with code " + ret) - if (ret.equals(Status.DRIVER_ABORTED)) { + if (ret != null && ret.equals(Status.DRIVER_ABORTED)) { System.exit(1) } } catch { @@ -82,18 +134,62 @@ private[mesos] trait MesosSchedulerUtils extends Logging { /** * Signal that the scheduler has registered with Mesos. */ + protected def getResource(res: JList[Resource], name: String): Double = { + // A resource can have multiple values in the offer since it can either be from + // a specific role or wildcard. + res.filter(_.getName == name).map(_.getScalar.getValue).sum + } + protected def markRegistered(): Unit = { registerLatch.countDown() } + def createResource(name: String, amount: Double, role: Option[String] = None): Resource = { + val builder = Resource.newBuilder() + .setName(name) + .setType(Value.Type.SCALAR) + .setScalar(Value.Scalar.newBuilder().setValue(amount).build()) + + role.foreach { r => builder.setRole(r) } + + builder.build() + } + /** - * Get the amount of resources for the specified type from the resource list + * Partition the existing set of resources into two groups, those remaining to be + * scheduled and those requested to be used for a new task. + * @param resources The full list of available resources + * @param resourceName The name of the resource to take from the available resources + * @param amountToUse The amount of resources to take from the available resources + * @return The remaining resources list and the used resources list. */ - protected def getResource(res: JList[Resource], name: String): Double = { - for (r <- res if r.getName == name) { - return r.getScalar.getValue + def partitionResources( + resources: JList[Resource], + resourceName: String, + amountToUse: Double): (List[Resource], List[Resource]) = { + var remain = amountToUse + var requestedResources = new ArrayBuffer[Resource] + val remainingResources = resources.map { + case r => { + if (remain > 0 && + r.getType == Value.Type.SCALAR && + r.getScalar.getValue > 0.0 && + r.getName == resourceName) { + val usage = Math.min(remain, r.getScalar.getValue) + requestedResources += createResource(resourceName, usage, Some(r.getRole)) + remain -= usage + createResource(resourceName, r.getScalar.getValue - usage, Some(r.getRole)) + } else { + r + } + } } - 0.0 + + // Filter any resource that has depleted. 
+ val filteredResources = + remainingResources.filter(r => r.getType != Value.Type.SCALAR || r.getScalar.getValue > 0.0) + + (filteredResources.toList, requestedResources.toList) } /** Helper method to get the key,value-set pair for a Mesos Attribute protobuf */ diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index 776e5d330e3c7..4d48fcfea44e7 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -25,7 +25,8 @@ import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv, TaskState} import org.apache.spark.TaskState.TaskState import org.apache.spark.executor.{Executor, ExecutorBackend} import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} -import org.apache.spark.scheduler.{SchedulerBackend, TaskSchedulerImpl, WorkerOffer} +import org.apache.spark.scheduler._ +import org.apache.spark.scheduler.cluster.ExecutorInfo private case class ReviveOffers() @@ -50,8 +51,8 @@ private[spark] class LocalEndpoint( private var freeCores = totalCores - private val localExecutorId = SparkContext.DRIVER_IDENTIFIER - private val localExecutorHostname = "localhost" + val localExecutorId = SparkContext.DRIVER_IDENTIFIER + val localExecutorHostname = "localhost" private val executor = new Executor( localExecutorId, localExecutorHostname, SparkEnv.get, userClassPath, isLocal = true) @@ -99,8 +100,9 @@ private[spark] class LocalBackend( extends SchedulerBackend with ExecutorBackend with Logging { private val appId = "local-" + System.currentTimeMillis - var localEndpoint: RpcEndpointRef = null + private var localEndpoint: RpcEndpointRef = null private val userClassPath = getUserClasspath(conf) + private val listenerBus = scheduler.sc.listenerBus /** * Returns a list of URLs representing the user classpath. 
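Editor's note: with the LocalBackend changes above (and the SparkListenerExecutorAdded post in the next hunk), local mode now reports its single driver-side executor on the listener bus just like a cluster backend. A minimal sketch of observing that event is below; it registers the listener through spark.extraListeners so it is attached before the one-off startup event is delivered, and the listener class name is invented for the example.

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.scheduler.{SparkListener, SparkListenerExecutorAdded}

    // Illustrative listener: logs every executor registration, including the local one added by this patch.
    class ExecutorLogger extends SparkListener {
      override def onExecutorAdded(event: SparkListenerExecutorAdded): Unit = {
        println(s"Executor ${event.executorId} added on ${event.executorInfo.executorHost} " +
          s"with ${event.executorInfo.totalCores} cores")
      }
    }

    object ExecutorLoggerExample {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf()
          .setMaster("local[2]")
          .setAppName("executor-listener-example")
          // Register during SparkContext startup so the local executor event is not missed.
          .set("spark.extraListeners", classOf[ExecutorLogger].getName)
        val sc = new SparkContext(conf)
        sc.parallelize(1 to 10).count()
        sc.stop()
      }
    }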
@@ -113,9 +115,13 @@ private[spark] class LocalBackend( } override def start() { - localEndpoint = SparkEnv.get.rpcEnv.setupEndpoint( - "LocalBackendEndpoint", - new LocalEndpoint(SparkEnv.get.rpcEnv, userClassPath, scheduler, this, totalCores)) + val rpcEnv = SparkEnv.get.rpcEnv + val executorEndpoint = new LocalEndpoint(rpcEnv, userClassPath, scheduler, this, totalCores) + localEndpoint = rpcEnv.setupEndpoint("LocalBackendEndpoint", executorEndpoint) + listenerBus.post(SparkListenerExecutorAdded( + System.currentTimeMillis, + executorEndpoint.localExecutorId, + new ExecutorInfo(executorEndpoint.localExecutorHostname, totalCores, Map.empty))) } override def stop() { diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index 698d1384d580d..4a5274b46b7a0 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -62,8 +62,11 @@ private[spark] class JavaDeserializationStream(in: InputStream, loader: ClassLoa extends DeserializationStream { private val objIn = new ObjectInputStream(in) { - override def resolveClass(desc: ObjectStreamClass): Class[_] = + override def resolveClass(desc: ObjectStreamClass): Class[_] = { + // scalastyle:off classforname Class.forName(desc.getName, false, loader) + // scalastyle:on classforname + } } def readObject[T: ClassTag](): T = objIn.readObject().asInstanceOf[T] diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index ed35cffe968f8..7cb6e080533ad 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -102,6 +102,7 @@ class KryoSerializer(conf: SparkConf) kryo.register(classOf[PythonBroadcast], new KryoJavaSerializer()) try { + // scalastyle:off classforname // Use the default classloader when calling the user registrator. Thread.currentThread.setContextClassLoader(classLoader) // Register classes given through spark.kryo.classesToRegister. 
@@ -111,6 +112,7 @@ class KryoSerializer(conf: SparkConf) userRegistrator .map(Class.forName(_, true, classLoader).newInstance().asInstanceOf[KryoRegistrator]) .foreach { reg => reg.registerClasses(kryo) } + // scalastyle:on classforname } catch { case e: Exception => throw new SparkException(s"Failed to register classes with Kryo", e) diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala index cc2f0506817d3..a1b1e1631eafb 100644 --- a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala +++ b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala @@ -407,7 +407,9 @@ private[spark] object SerializationDebugger extends Logging { /** ObjectStreamClass$ClassDataSlot.desc field */ val DescField: Field = { + // scalastyle:off classforname val f = Class.forName("java.io.ObjectStreamClass$ClassDataSlot").getDeclaredField("desc") + // scalastyle:on classforname f.setAccessible(true) f } diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala index 6c3b3080d2605..f6a96d81e7aa9 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala @@ -35,7 +35,7 @@ import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVecto /** A group of writers for a ShuffleMapTask, one writer per reducer. */ private[spark] trait ShuffleWriterGroup { - val writers: Array[BlockObjectWriter] + val writers: Array[DiskBlockObjectWriter] /** @param success Indicates all writes were successful. If false, no blocks will be recorded. */ def releaseWriters(success: Boolean) @@ -113,15 +113,15 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) val openStartTime = System.nanoTime val serializerInstance = serializer.newInstance() - val writers: Array[BlockObjectWriter] = if (consolidateShuffleFiles) { + val writers: Array[DiskBlockObjectWriter] = if (consolidateShuffleFiles) { fileGroup = getUnusedFileGroup() - Array.tabulate[BlockObjectWriter](numBuckets) { bucketId => + Array.tabulate[DiskBlockObjectWriter](numBuckets) { bucketId => val blockId = ShuffleBlockId(shuffleId, mapId, bucketId) blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializerInstance, bufferSize, writeMetrics) } } else { - Array.tabulate[BlockObjectWriter](numBuckets) { bucketId => + Array.tabulate[DiskBlockObjectWriter](numBuckets) { bucketId => val blockId = ShuffleBlockId(shuffleId, mapId, bucketId) val blockFile = blockManager.diskBlockManager.getFile(blockId) // Because of previous failures, the shuffle file may already exist on this machine. diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index d9c63b6e7bbb9..fae69551e7330 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -114,7 +114,7 @@ private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends ShuffleB } private[spark] object IndexShuffleBlockResolver { - // No-op reduce ID used in interactions with disk store and BlockObjectWriter. + // No-op reduce ID used in interactions with disk store and DiskBlockObjectWriter. 
// The disk store currently expects puts to relate to a (map, reduce) pair, but in the sort // shuffle outputs for several reduces are glommed into a single file. // TODO: Avoid this entirely by having the DiskBlockObjectWriter not require a BlockId. diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala deleted file mode 100644 index 9d8e7e9f03aea..0000000000000 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.shuffle.hash - -import java.io.InputStream - -import scala.collection.mutable.{ArrayBuffer, HashMap} -import scala.util.{Failure, Success} - -import org.apache.spark._ -import org.apache.spark.shuffle.FetchFailedException -import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockFetcherIterator, - ShuffleBlockId} - -private[hash] object BlockStoreShuffleFetcher extends Logging { - def fetchBlockStreams( - shuffleId: Int, - reduceId: Int, - context: TaskContext, - blockManager: BlockManager, - mapOutputTracker: MapOutputTracker) - : Iterator[(BlockId, InputStream)] = - { - logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId)) - - val startTime = System.currentTimeMillis - val statuses = mapOutputTracker.getServerStatuses(shuffleId, reduceId) - logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format( - shuffleId, reduceId, System.currentTimeMillis - startTime)) - - val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]] - for (((address, size), index) <- statuses.zipWithIndex) { - splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size)) - } - - val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])] = splitsByAddress.toSeq.map { - case (address, splits) => - (address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2))) - } - - val blockFetcherItr = new ShuffleBlockFetcherIterator( - context, - blockManager.shuffleClient, - blockManager, - blocksByAddress, - // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility - SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024) - - // Make sure that fetch failures are wrapped inside a FetchFailedException for the scheduler - blockFetcherItr.map { blockPair => - val blockId = blockPair._1 - val blockOption = blockPair._2 - blockOption match { - case Success(inputStream) => { - (blockId, inputStream) - } - case Failure(e) => { - blockId match { - case ShuffleBlockId(shufId, mapId, _) => - val address = statuses(mapId.toInt)._1 - throw 
new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, e) - case _ => - throw new SparkException( - "Failed to get block " + blockId + ", which is not a shuffle block", e) - } - } - } - } - } -} diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala index d5c9880659dd3..de79fa56f017b 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -17,10 +17,10 @@ package org.apache.spark.shuffle.hash -import org.apache.spark.{InterruptibleIterator, MapOutputTracker, SparkEnv, TaskContext} +import org.apache.spark.{InterruptibleIterator, Logging, MapOutputTracker, SparkEnv, TaskContext} import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader} -import org.apache.spark.storage.BlockManager +import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator} import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter @@ -31,8 +31,8 @@ private[spark] class HashShuffleReader[K, C]( context: TaskContext, blockManager: BlockManager = SparkEnv.get.blockManager, mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker) - extends ShuffleReader[K, C] -{ + extends ShuffleReader[K, C] with Logging { + require(endPartition == startPartition + 1, "Hash shuffle currently only supports fetching one partition") @@ -40,11 +40,16 @@ private[spark] class HashShuffleReader[K, C]( /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { - val blockStreams = BlockStoreShuffleFetcher.fetchBlockStreams( - handle.shuffleId, startPartition, context, blockManager, mapOutputTracker) + val blockFetcherItr = new ShuffleBlockFetcherIterator( + context, + blockManager.shuffleClient, + blockManager, + mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition), + // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility + SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024) // Wrap the streams for compression based on configuration - val wrappedStreams = blockStreams.map { case (blockId, inputStream) => + val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) => blockManager.wrapForCompression(blockId, inputStream) } diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala index eb87cee15903c..41df70c602c30 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala @@ -22,7 +22,7 @@ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.scheduler.MapStatus import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle._ -import org.apache.spark.storage.BlockObjectWriter +import org.apache.spark.storage.DiskBlockObjectWriter private[spark] class HashShuffleWriter[K, V]( shuffleBlockResolver: FileShuffleBlockResolver, @@ -102,7 +102,7 @@ private[spark] class HashShuffleWriter[K, V]( private def commitWritesAndBuildStatus(): MapStatus = { // Commit the writes. Get the size of each bucket block (total block size). 
- val sizes: Array[Long] = shuffle.writers.map { writer: BlockObjectWriter => + val sizes: Array[Long] = shuffle.writers.map { writer: DiskBlockObjectWriter => writer.commitAndClose() writer.fileSegment().length } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 1beafa1771448..86493673d958d 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -648,7 +648,7 @@ private[spark] class BlockManager( file: File, serializerInstance: SerializerInstance, bufferSize: Int, - writeMetrics: ShuffleWriteMetrics): BlockObjectWriter = { + writeMetrics: ShuffleWriteMetrics): DiskBlockObjectWriter = { val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _) val syncWrites = conf.getBoolean("spark.shuffle.sync", false) new DiskBlockObjectWriter(blockId, file, serializerInstance, bufferSize, compressStream, diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 68ed9096731c5..5dc0c537cbb62 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -60,10 +60,11 @@ class BlockManagerMasterEndpoint( register(blockManagerId, maxMemSize, slaveEndpoint) context.reply(true) - case UpdateBlockInfo( + case _updateBlockInfo @ UpdateBlockInfo( blockManagerId, blockId, storageLevel, deserializedSize, size, externalBlockStoreSize) => context.reply(updateBlockInfo( blockManagerId, blockId, storageLevel, deserializedSize, size, externalBlockStoreSize)) + listenerBus.post(SparkListenerBlockUpdated(BlockUpdatedInfo(_updateBlockInfo))) case GetLocations(blockId) => context.reply(getLocations(blockId)) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockStatusListener.scala b/core/src/main/scala/org/apache/spark/storage/BlockStatusListener.scala new file mode 100644 index 0000000000000..2789e25b8d3ab --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/BlockStatusListener.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.storage + +import scala.collection.mutable + +import org.apache.spark.scheduler._ + +private[spark] case class BlockUIData( + blockId: BlockId, + location: String, + storageLevel: StorageLevel, + memSize: Long, + diskSize: Long, + externalBlockStoreSize: Long) + +/** + * The aggregated status of stream blocks in an executor + */ +private[spark] case class ExecutorStreamBlockStatus( + executorId: String, + location: String, + blocks: Seq[BlockUIData]) { + + def totalMemSize: Long = blocks.map(_.memSize).sum + + def totalDiskSize: Long = blocks.map(_.diskSize).sum + + def totalExternalBlockStoreSize: Long = blocks.map(_.externalBlockStoreSize).sum + + def numStreamBlocks: Int = blocks.size + +} + +private[spark] class BlockStatusListener extends SparkListener { + + private val blockManagers = + new mutable.HashMap[BlockManagerId, mutable.HashMap[BlockId, BlockUIData]] + + override def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated): Unit = { + val blockId = blockUpdated.blockUpdatedInfo.blockId + if (!blockId.isInstanceOf[StreamBlockId]) { + // Now we only monitor StreamBlocks + return + } + val blockManagerId = blockUpdated.blockUpdatedInfo.blockManagerId + val storageLevel = blockUpdated.blockUpdatedInfo.storageLevel + val memSize = blockUpdated.blockUpdatedInfo.memSize + val diskSize = blockUpdated.blockUpdatedInfo.diskSize + val externalBlockStoreSize = blockUpdated.blockUpdatedInfo.externalBlockStoreSize + + synchronized { + // Drop the update info if the block manager is not registered + blockManagers.get(blockManagerId).foreach { blocksInBlockManager => + if (storageLevel.isValid) { + blocksInBlockManager.put(blockId, + BlockUIData( + blockId, + blockManagerId.hostPort, + storageLevel, + memSize, + diskSize, + externalBlockStoreSize) + ) + } else { + // If isValid is not true, it means we should drop the block. + blocksInBlockManager -= blockId + } + } + } + } + + override def onBlockManagerAdded(blockManagerAdded: SparkListenerBlockManagerAdded): Unit = { + synchronized { + blockManagers.put(blockManagerAdded.blockManagerId, mutable.HashMap()) + } + } + + override def onBlockManagerRemoved( + blockManagerRemoved: SparkListenerBlockManagerRemoved): Unit = synchronized { + blockManagers -= blockManagerRemoved.blockManagerId + } + + def allExecutorStreamBlockStatus: Seq[ExecutorStreamBlockStatus] = synchronized { + blockManagers.map { case (blockManagerId, blocks) => + ExecutorStreamBlockStatus( + blockManagerId.executorId, blockManagerId.hostPort, blocks.values.toSeq) + }.toSeq + } +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockUpdatedInfo.scala b/core/src/main/scala/org/apache/spark/storage/BlockUpdatedInfo.scala new file mode 100644 index 0000000000000..a5790e4454a89 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/BlockUpdatedInfo.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.storage.BlockManagerMessages.UpdateBlockInfo + +/** + * :: DeveloperApi :: + * Stores information about a block status in a block manager. + */ +@DeveloperApi +case class BlockUpdatedInfo( + blockManagerId: BlockManagerId, + blockId: BlockId, + storageLevel: StorageLevel, + memSize: Long, + diskSize: Long, + externalBlockStoreSize: Long) + +private[spark] object BlockUpdatedInfo { + + private[spark] def apply(updateBlockInfo: UpdateBlockInfo): BlockUpdatedInfo = { + BlockUpdatedInfo( + updateBlockInfo.blockManagerId, + updateBlockInfo.blockId, + updateBlockInfo.storageLevel, + updateBlockInfo.memSize, + updateBlockInfo.diskSize, + updateBlockInfo.externalBlockStoreSize) + } +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala similarity index 83% rename from core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala rename to core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala index 7eeabd1e0489c..49d9154f95a5b 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala @@ -26,66 +26,25 @@ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.util.Utils /** - * An interface for writing JVM objects to some underlying storage. This interface allows - * appending data to an existing block, and can guarantee atomicity in the case of faults - * as it allows the caller to revert partial writes. + * A class for writing JVM objects directly to a file on disk. This class allows data to be appended + * to an existing block and can guarantee atomicity in the case of faults as it allows the caller to + * revert partial writes. * - * This interface does not support concurrent writes. Also, once the writer has - * been opened, it cannot be reopened again. - */ -private[spark] abstract class BlockObjectWriter(val blockId: BlockId) extends OutputStream { - - def open(): BlockObjectWriter - - def close() - - def isOpen: Boolean - - /** - * Flush the partial writes and commit them as a single atomic block. - */ - def commitAndClose(): Unit - - /** - * Reverts writes that haven't been flushed yet. Callers should invoke this function - * when there are runtime exceptions. This method will not throw, though it may be - * unsuccessful in truncating written data. - */ - def revertPartialWritesAndClose() - - /** - * Writes a key-value pair. - */ - def write(key: Any, value: Any) - - /** - * Notify the writer that a record worth of bytes has been written with OutputStream#write. - */ - def recordWritten() - - /** - * Returns the file segment of committed data that this Writer has written. - * This is only valid after commitAndClose() has been called. - */ - def fileSegment(): FileSegment -} - -/** - * BlockObjectWriter which writes directly to a file on disk. Appends to the given file. 
+ * This class does not support concurrent writes. Also, once the writer has been opened it cannot be + * reopened again. */ private[spark] class DiskBlockObjectWriter( - blockId: BlockId, + val blockId: BlockId, file: File, serializerInstance: SerializerInstance, bufferSize: Int, compressStream: OutputStream => OutputStream, syncWrites: Boolean, - // These write metrics concurrently shared with other active BlockObjectWriter's who + // These write metrics concurrently shared with other active DiskBlockObjectWriters who // are themselves performing writes. All updates must be relative. writeMetrics: ShuffleWriteMetrics) - extends BlockObjectWriter(blockId) - with Logging -{ + extends OutputStream + with Logging { /** The file channel, used for repositioning / truncating the file. */ private var channel: FileChannel = null @@ -122,7 +81,7 @@ private[spark] class DiskBlockObjectWriter( */ private var numRecordsWritten = 0 - override def open(): BlockObjectWriter = { + def open(): DiskBlockObjectWriter = { if (hasBeenClosed) { throw new IllegalStateException("Writer already closed. Cannot be reopened.") } @@ -159,9 +118,12 @@ private[spark] class DiskBlockObjectWriter( } } - override def isOpen: Boolean = objOut != null + def isOpen: Boolean = objOut != null - override def commitAndClose(): Unit = { + /** + * Flush the partial writes and commit them as a single atomic block. + */ + def commitAndClose(): Unit = { if (initialized) { // NOTE: Because Kryo doesn't flush the underlying stream we explicitly flush both the // serializer stream and the lower level stream. @@ -177,9 +139,15 @@ private[spark] class DiskBlockObjectWriter( commitAndCloseHasBeenCalled = true } - // Discard current writes. We do this by flushing the outstanding writes and then - // truncating the file to its initial position. - override def revertPartialWritesAndClose() { + + /** + * Reverts writes that haven't been flushed yet. Callers should invoke this function + * when there are runtime exceptions. This method will not throw, though it may be + * unsuccessful in truncating written data. + */ + def revertPartialWritesAndClose() { + // Discard current writes. We do this by flushing the outstanding writes and then + // truncating the file to its initial position. try { if (initialized) { writeMetrics.decShuffleBytesWritten(reportedPosition - initialPosition) @@ -201,7 +169,10 @@ private[spark] class DiskBlockObjectWriter( } } - override def write(key: Any, value: Any) { + /** + * Writes a key-value pair. + */ + def write(key: Any, value: Any) { if (!initialized) { open() } @@ -221,7 +192,10 @@ private[spark] class DiskBlockObjectWriter( bs.write(kvBytes, offs, len) } - override def recordWritten(): Unit = { + /** + * Notify the writer that a record worth of bytes has been written with OutputStream#write. + */ + def recordWritten(): Unit = { numRecordsWritten += 1 writeMetrics.incShuffleRecordsWritten(1) @@ -230,7 +204,11 @@ private[spark] class DiskBlockObjectWriter( } } - override def fileSegment(): FileSegment = { + /** + * Returns the file segment of committed data that this Writer has written. + * This is only valid after commitAndClose() has been called. 
+ */ + def fileSegment(): FileSegment = { if (!commitAndCloseHasBeenCalled) { throw new IllegalStateException( "fileSegment() is only valid after commitAndClose() has been called") diff --git a/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala b/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala index 291394ed34816..db965d54bafd6 100644 --- a/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala @@ -192,7 +192,7 @@ private[spark] class ExternalBlockStore(blockManager: BlockManager, executorId: .getOrElse(ExternalBlockStore.DEFAULT_BLOCK_MANAGER_NAME) try { - val instance = Class.forName(clsName) + val instance = Utils.classForName(clsName) .newInstance() .asInstanceOf[ExternalBlockManager] instance.init(blockManager, executorId) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index e49e39679e940..a759ceb96ec1e 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -21,18 +21,19 @@ import java.io.InputStream import java.util.concurrent.LinkedBlockingQueue import scala.collection.mutable.{ArrayBuffer, HashSet, Queue} -import scala.util.{Failure, Try} +import scala.util.control.NonFatal -import org.apache.spark.{Logging, TaskContext} +import org.apache.spark.{Logging, SparkException, TaskContext} import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient} +import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.util.Utils /** * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block * manager. For remote blocks, it fetches them using the provided BlockTransferService. * - * This creates an iterator of (BlockID, Try[InputStream]) tuples so the caller can handle blocks + * This creates an iterator of (BlockID, InputStream) tuples so the caller can handle blocks * in a pipelined fashion as they are received. * * The implementation throttles the remote fetches to they don't exceed maxBytesInFlight to avoid @@ -53,7 +54,7 @@ final class ShuffleBlockFetcherIterator( blockManager: BlockManager, blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], maxBytesInFlight: Long) - extends Iterator[(BlockId, Try[InputStream])] with Logging { + extends Iterator[(BlockId, InputStream)] with Logging { import ShuffleBlockFetcherIterator._ @@ -115,7 +116,7 @@ final class ShuffleBlockFetcherIterator( private[storage] def releaseCurrentResultBuffer(): Unit = { // Release the current buffer if necessary currentResult match { - case SuccessFetchResult(_, _, buf) => buf.release() + case SuccessFetchResult(_, _, _, buf) => buf.release() case _ => } currentResult = null @@ -132,7 +133,7 @@ final class ShuffleBlockFetcherIterator( while (iter.hasNext) { val result = iter.next() result match { - case SuccessFetchResult(_, _, buf) => buf.release() + case SuccessFetchResult(_, _, _, buf) => buf.release() case _ => } } @@ -157,7 +158,7 @@ final class ShuffleBlockFetcherIterator( // Increment the ref count because we need to pass this to a different thread. // This needs to be released after use. 
buf.retain() - results.put(new SuccessFetchResult(BlockId(blockId), sizeMap(blockId), buf)) + results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf)) shuffleMetrics.incRemoteBytesRead(buf.size) shuffleMetrics.incRemoteBlocksFetched(1) } @@ -166,7 +167,7 @@ final class ShuffleBlockFetcherIterator( override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = { logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e) - results.put(new FailureFetchResult(BlockId(blockId), e)) + results.put(new FailureFetchResult(BlockId(blockId), address, e)) } } ) @@ -238,12 +239,12 @@ final class ShuffleBlockFetcherIterator( shuffleMetrics.incLocalBlocksFetched(1) shuffleMetrics.incLocalBytesRead(buf.size) buf.retain() - results.put(new SuccessFetchResult(blockId, 0, buf)) + results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf)) } catch { case e: Exception => // If we see an exception, stop immediately. logError(s"Error occurred while fetching local blocks", e) - results.put(new FailureFetchResult(blockId, e)) + results.put(new FailureFetchResult(blockId, blockManager.blockManagerId, e)) return } } @@ -275,12 +276,14 @@ final class ShuffleBlockFetcherIterator( override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch /** - * Fetches the next (BlockId, Try[InputStream]). If a task fails, the ManagedBuffers + * Fetches the next (BlockId, InputStream). If a task fails, the ManagedBuffers * underlying each InputStream will be freed by the cleanup() method registered with the * TaskCompletionListener. However, callers should close() these InputStreams * as soon as they are no longer needed, in order to release memory as early as possible. + * + * Throws a FetchFailedException if the next block could not be fetched. */ - override def next(): (BlockId, Try[InputStream]) = { + override def next(): (BlockId, InputStream) = { numBlocksProcessed += 1 val startFetchWait = System.currentTimeMillis() currentResult = results.take() @@ -289,7 +292,7 @@ final class ShuffleBlockFetcherIterator( shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait) result match { - case SuccessFetchResult(_, size, _) => bytesInFlight -= size + case SuccessFetchResult(_, _, size, _) => bytesInFlight -= size case _ => } // Send fetch requests up to maxBytesInFlight @@ -298,19 +301,28 @@ final class ShuffleBlockFetcherIterator( sendRequest(fetchRequests.dequeue()) } - val iteratorTry: Try[InputStream] = result match { - case FailureFetchResult(_, e) => - Failure(e) - case SuccessFetchResult(blockId, _, buf) => - // There is a chance that createInputStream can fail (e.g. fetching a local file that does - // not exist, SPARK-4085). In that case, we should propagate the right exception so - // the scheduler gets a FetchFailedException. 
- Try(buf.createInputStream()).map { inputStream => - new BufferReleasingInputStream(inputStream, this) + result match { + case FailureFetchResult(blockId, address, e) => + throwFetchFailedException(blockId, address, e) + + case SuccessFetchResult(blockId, address, _, buf) => + try { + (result.blockId, new BufferReleasingInputStream(buf.createInputStream(), this)) + } catch { + case NonFatal(t) => + throwFetchFailedException(blockId, address, t) } } + } - (result.blockId, iteratorTry) + private def throwFetchFailedException(blockId: BlockId, address: BlockManagerId, e: Throwable) = { + blockId match { + case ShuffleBlockId(shufId, mapId, reduceId) => + throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, e) + case _ => + throw new SparkException( + "Failed to get block " + blockId + ", which is not a shuffle block", e) + } } } @@ -366,16 +378,22 @@ object ShuffleBlockFetcherIterator { */ private[storage] sealed trait FetchResult { val blockId: BlockId + val address: BlockManagerId } /** * Result of a fetch from a remote block successfully. * @param blockId block id + * @param address BlockManager that the block was fetched from. * @param size estimated size of the block, used to calculate bytesInFlight. * Note that this is NOT the exact bytes. * @param buf [[ManagedBuffer]] for the content. */ - private[storage] case class SuccessFetchResult(blockId: BlockId, size: Long, buf: ManagedBuffer) + private[storage] case class SuccessFetchResult( + blockId: BlockId, + address: BlockManagerId, + size: Long, + buf: ManagedBuffer) extends FetchResult { require(buf != null) require(size >= 0) @@ -384,8 +402,12 @@ object ShuffleBlockFetcherIterator { /** * Result of a fetch from a remote block unsuccessfully. * @param blockId block id + * @param address BlockManager that the block was attempted to be fetched from * @param e the failure exception */ - private[storage] case class FailureFetchResult(blockId: BlockId, e: Throwable) + private[storage] case class FailureFetchResult( + blockId: BlockId, + address: BlockManagerId, + e: Throwable) extends FetchResult } diff --git a/core/src/main/scala/org/apache/spark/ui/PagedTable.scala b/core/src/main/scala/org/apache/spark/ui/PagedTable.scala new file mode 100644 index 0000000000000..17d7b39c2d951 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/PagedTable.scala @@ -0,0 +1,246 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui + +import scala.xml.{Node, Unparsed} + +/** + * A data source that provides data for a page. 
+ * + * @param pageSize the number of rows in a page + */ +private[ui] abstract class PagedDataSource[T](val pageSize: Int) { + + if (pageSize <= 0) { + throw new IllegalArgumentException("Page size must be positive") + } + + /** + * Return the size of all data. + */ + protected def dataSize: Int + + /** + * Slice a range of data. + */ + protected def sliceData(from: Int, to: Int): Seq[T] + + /** + * Slice the data for this page + */ + def pageData(page: Int): PageData[T] = { + val totalPages = (dataSize + pageSize - 1) / pageSize + if (page <= 0 || page > totalPages) { + throw new IndexOutOfBoundsException( + s"Page $page is out of range. Please select a page number between 1 and $totalPages.") + } + val from = (page - 1) * pageSize + val to = dataSize.min(page * pageSize) + PageData(totalPages, sliceData(from, to)) + } + +} + +/** + * The data returned by `PagedDataSource.pageData`, including the page number, the number of total + * pages and the data in this page. + */ +private[ui] case class PageData[T](totalPage: Int, data: Seq[T]) + +/** + * A paged table that will generate a HTML table for a specified page and also the page navigation. + */ +private[ui] trait PagedTable[T] { + + def tableId: String + + def tableCssClass: String + + def dataSource: PagedDataSource[T] + + def headers: Seq[Node] + + def row(t: T): Seq[Node] + + def table(page: Int): Seq[Node] = { + val _dataSource = dataSource + try { + val PageData(totalPages, data) = _dataSource.pageData(page) +
+      <div>
+        {pageNavigation(page, _dataSource.pageSize, totalPages)}
+        <table class={tableCssClass} id={tableId}>
+          {headers}
+          <tbody>
+            {data.map(row)}
+          </tbody>
+        </table>
+      </div>
+    } catch {
+      case e: IndexOutOfBoundsException =>
+        val PageData(totalPages, _) = _dataSource.pageData(1)
+        <div>
+          {pageNavigation(1, _dataSource.pageSize, totalPages)}
+          <div class="alert alert-error">{e.getMessage}</div>
+        </div>
+    }
+  }
+
+  /**
+   * Return a page navigation.
+   * <ul>
+   *   <li>If the totalPages is 1, the page navigation will be empty</li>
+   *   <li>
+   *     If the totalPages is more than 1, it will create a page navigation including a group of
+   *     page numbers and a form to submit the page number.
+   *   </li>
+   * </ul>
+   *
+   * Here are some examples of the page navigation:
+   * {{{
+   * << < 11 12 13* 14 15 16 17 18 19 20 > >>
+   *
+   * This is the first group, so "<<" is hidden.
+   * < 1 2* 3 4 5 6 7 8 9 10 > >>
+   *
+   * This is the first group and the first page, so "<<" and "<" are hidden.
+   * 1* 2 3 4 5 6 7 8 9 10 > >>
+   *
+   * Assume totalPages is 19. This is the last group, so ">>" is hidden.
+   * << < 11 12 13* 14 15 16 17 18 19 >
+   *
+   * Assume totalPages is 19. This is the last group and the last page, so ">>" and ">" are hidden.
+   * << < 11 12 13 14 15 16 17 18 19*
+   *
+   * * means the current page number
+   * << means jumping to the first page of the previous group.
+   * < means jumping to the previous page.
+   * >> means jumping to the first page of the next group.
+   * > means jumping to the next page.
+   * }}}
+   */
+  private[ui] def pageNavigation(page: Int, pageSize: Int, totalPages: Int): Seq[Node] = {
+    if (totalPages == 1) {
+      Nil
+    } else {
+      // A group contains the page numbers that will be shown in the page navigation.
+      // A group size of 10 means 10 page numbers are shown at a time.
+      // The first group is 1 to 10, the second is 11 to 20, and so on.
+      val groupSize = 10
+      val firstGroup = 0
+      val lastGroup = (totalPages - 1) / groupSize
+      val currentGroup = (page - 1) / groupSize
+      val startPage = currentGroup * groupSize + 1
+      val endPage = totalPages.min(startPage + groupSize - 1)
+      val pageTags = (startPage to endPage).map { p =>
+        if (p == page) {
+          // The current page should be disabled so that it cannot be clicked.
<li class="disabled"><a href="#">{p}</a></li>
+          } else {
+            <li><a href={pageLink(p)}>{p}</a></li>
+          }
+        }
+        val (goButtonJsFuncName, goButtonJsFunc) = goButtonJavascriptFunction
+        // When clicking the "Go" button, it will call this javascript method and then call
+        // "goButtonJsFuncName"
+        val formJs =
+          s"""$$(function(){
+            |  $$( "#form-task-page" ).submit(function(event) {
+            |    var page = $$("#form-task-page-no").val()
+            |    var pageSize = $$("#form-task-page-size").val()
+            |    pageSize = pageSize ? pageSize: 100;
+            |    if (page != "") {
+            |      ${goButtonJsFuncName}(page, pageSize);
+            |    }
+            |    event.preventDefault();
+            |  });
+            |});
+          """.stripMargin
    +
    +
    + + + + + + +
    +
    + + +
    + } + } + + /** + * Return a link to jump to a page. + */ + def pageLink(page: Int): String + + /** + * Only the implementation knows how to create the url with a page number and the page size, so we + * leave this one to the implementation. The implementation should create a JavaScript method that + * accepts a page number along with the page size and jumps to the page. The return value is this + * method name and its JavaScript codes. + */ + def goButtonJavascriptFunction: (String, String) +} diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 7898039519201..718aea7e1dc22 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -27,7 +27,7 @@ import org.apache.spark.ui.scope.RDDOperationGraph /** Utility functions for generating XML pages with spark content. */ private[spark] object UIUtils extends Logging { - val TABLE_CLASS_NOT_STRIPED = "table table-bordered table-condensed sortable" + val TABLE_CLASS_NOT_STRIPED = "table table-bordered table-condensed" val TABLE_CLASS_STRIPED = TABLE_CLASS_NOT_STRIPED + " table-striped" // SimpleDateFormat is not thread-safe. Don't expose it to avoid improper use. @@ -267,9 +267,17 @@ private[spark] object UIUtils extends Logging { fixedWidth: Boolean = false, id: Option[String] = None, headerClasses: Seq[String] = Seq.empty, - stripeRowsWithCss: Boolean = true): Seq[Node] = { + stripeRowsWithCss: Boolean = true, + sortable: Boolean = true): Seq[Node] = { - val listingTableClass = if (stripeRowsWithCss) TABLE_CLASS_STRIPED else TABLE_CLASS_NOT_STRIPED + val listingTableClass = { + val _tableClass = if (stripeRowsWithCss) TABLE_CLASS_STRIPED else TABLE_CLASS_NOT_STRIPED + if (sortable) { + _tableClass + " sortable" + } else { + _tableClass + } + } val colWidth = 100.toDouble / headers.size val colWidthAttr = if (fixedWidth) colWidth + "%" else "" diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 60e3c6343122c..cf04b5e59239b 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -17,6 +17,7 @@ package org.apache.spark.ui.jobs +import java.net.URLEncoder import java.util.Date import javax.servlet.http.HttpServletRequest @@ -27,13 +28,14 @@ import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo} -import org.apache.spark.ui.{ToolTips, WebUIPage, UIUtils} +import org.apache.spark.ui._ import org.apache.spark.ui.jobs.UIData._ -import org.apache.spark.ui.scope.RDDOperationGraph import org.apache.spark.util.{Utils, Distribution} /** Page showing statistics and task list for a given stage */ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { + import StagePage._ + private val progressListener = parent.progressListener private val operationGraphListener = parent.operationGraphListener @@ -74,6 +76,16 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val parameterAttempt = request.getParameter("attempt") require(parameterAttempt != null && parameterAttempt.nonEmpty, "Missing attempt parameter") + val parameterTaskPage = request.getParameter("task.page") + val parameterTaskSortColumn = request.getParameter("task.sort") + val parameterTaskSortDesc = request.getParameter("task.desc") 
+ val parameterTaskPageSize = request.getParameter("task.pageSize") + + val taskPage = Option(parameterTaskPage).map(_.toInt).getOrElse(1) + val taskSortColumn = Option(parameterTaskSortColumn).getOrElse("Index") + val taskSortDesc = Option(parameterTaskSortDesc).map(_.toBoolean).getOrElse(false) + val taskPageSize = Option(parameterTaskPageSize).map(_.toInt).getOrElse(100) + // If this is set, expand the dag visualization by default val expandDagVizParam = request.getParameter("expandDagViz") val expandDagViz = expandDagVizParam != null && expandDagVizParam.toBoolean @@ -231,52 +243,47 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { accumulableRow, accumulables.values.toSeq) - val taskHeadersAndCssClasses: Seq[(String, String)] = - Seq( - ("Index", ""), ("ID", ""), ("Attempt", ""), ("Status", ""), ("Locality Level", ""), - ("Executor ID / Host", ""), ("Launch Time", ""), ("Duration", ""), - ("Scheduler Delay", TaskDetailsClassNames.SCHEDULER_DELAY), - ("Task Deserialization Time", TaskDetailsClassNames.TASK_DESERIALIZATION_TIME), - ("GC Time", ""), - ("Result Serialization Time", TaskDetailsClassNames.RESULT_SERIALIZATION_TIME), - ("Getting Result Time", TaskDetailsClassNames.GETTING_RESULT_TIME)) ++ - {if (hasAccumulators) Seq(("Accumulators", "")) else Nil} ++ - {if (stageData.hasInput) Seq(("Input Size / Records", "")) else Nil} ++ - {if (stageData.hasOutput) Seq(("Output Size / Records", "")) else Nil} ++ - {if (stageData.hasShuffleRead) { - Seq(("Shuffle Read Blocked Time", TaskDetailsClassNames.SHUFFLE_READ_BLOCKED_TIME), - ("Shuffle Read Size / Records", ""), - ("Shuffle Remote Reads", TaskDetailsClassNames.SHUFFLE_READ_REMOTE_SIZE)) - } else { - Nil - }} ++ - {if (stageData.hasShuffleWrite) { - Seq(("Write Time", ""), ("Shuffle Write Size / Records", "")) - } else { - Nil - }} ++ - {if (stageData.hasBytesSpilled) { - Seq(("Shuffle Spill (Memory)", ""), ("Shuffle Spill (Disk)", "")) - } else { - Nil - }} ++ - Seq(("Errors", "")) - - val unzipped = taskHeadersAndCssClasses.unzip - val currentTime = System.currentTimeMillis() - val taskTable = UIUtils.listingTable( - unzipped._1, - taskRow( + val (taskTable, taskTableHTML) = try { + val _taskTable = new TaskPagedTable( + UIUtils.prependBaseUri(parent.basePath) + + s"/stages/stage?id=${stageId}&attempt=${stageAttemptId}", + tasks, hasAccumulators, stageData.hasInput, stageData.hasOutput, stageData.hasShuffleRead, stageData.hasShuffleWrite, stageData.hasBytesSpilled, - currentTime), - tasks, - headerClasses = unzipped._2) + currentTime, + pageSize = taskPageSize, + sortColumn = taskSortColumn, + desc = taskSortDesc + ) + (_taskTable, _taskTable.table(taskPage)) + } catch { + case e @ (_ : IllegalArgumentException | _ : IndexOutOfBoundsException) => + (null,
<div class="alert alert-error">{e.getMessage}</div>
    ) + } + + val jsForScrollingDownToTaskTable = + + + val taskIdsInPage = if (taskTable == null) Set.empty[Long] + else taskTable.dataSource.slicedTaskIds + // Excludes tasks which failed and have incomplete metrics val validTasks = tasks.filter(t => t.taskInfo.status == "SUCCESS" && t.taskMetrics.isDefined) @@ -332,7 +339,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { +: getFormattedTimeQuantiles(serializationTimes) val gettingResultTimes = validTasks.map { case TaskUIData(info, _, _) => - getGettingResultTime(info).toDouble + getGettingResultTime(info, currentTime).toDouble } val gettingResultQuantiles = @@ -346,7 +353,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { // machine and to send back the result (but not the time to fetch the task result, // if it needed to be fetched from the block manager on the worker). val schedulerDelays = validTasks.map { case TaskUIData(info, metrics, _) => - getSchedulerDelay(info, metrics.get).toDouble + getSchedulerDelay(info, metrics.get, currentTime).toDouble } val schedulerDelayTitle = Scheduler Delay @@ -499,12 +506,15 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { dagViz ++ maybeExpandDagViz ++ showAdditionalMetrics ++ - makeTimeline(stageData.taskData.values.toSeq, currentTime) ++ + makeTimeline( + // Only show the tasks in the table + stageData.taskData.values.toSeq.filter(t => taskIdsInPage.contains(t.taskInfo.taskId)), + currentTime) ++

      <h4>Summary Metrics for {numCompleted} Completed Tasks</h4> ++
      <div>{summaryTable.getOrElse("No tasks have reported metrics yet.")}</div> ++
      <h4>Aggregated Metrics by Executor</h4> ++ executorTable.toNodeSeq ++
      maybeAccumulableTable ++
-      <h4>Tasks</h4> ++ taskTable
+      <h4 id="tasks-section">Tasks</h4>

    ++ taskTableHTML ++ jsForScrollingDownToTaskTable UIUtils.headerSparkPage(stageHeader, content, parent, showVisualization = true) } } @@ -537,20 +547,27 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { (metricsOpt.flatMap(_.shuffleWriteMetrics .map(_.shuffleWriteTime)).getOrElse(0L) / 1e6).toLong val shuffleWriteTimeProportion = toProportion(shuffleWriteTime) - val executorComputingTime = metricsOpt.map(_.executorRunTime).getOrElse(0L) - - shuffleReadTime - shuffleWriteTime - val executorComputingTimeProportion = toProportion(executorComputingTime) + val serializationTime = metricsOpt.map(_.resultSerializationTime).getOrElse(0L) val serializationTimeProportion = toProportion(serializationTime) val deserializationTime = metricsOpt.map(_.executorDeserializeTime).getOrElse(0L) val deserializationTimeProportion = toProportion(deserializationTime) - val gettingResultTime = getGettingResultTime(taskUIData.taskInfo) + val gettingResultTime = getGettingResultTime(taskUIData.taskInfo, currentTime) val gettingResultTimeProportion = toProportion(gettingResultTime) - val schedulerDelay = totalExecutionTime - - (executorComputingTime + shuffleReadTime + shuffleWriteTime + - serializationTime + deserializationTime + gettingResultTime) - val schedulerDelayProportion = - (100 - executorComputingTimeProportion - shuffleReadTimeProportion - + val schedulerDelay = + metricsOpt.map(getSchedulerDelay(taskInfo, _, currentTime)).getOrElse(0L) + val schedulerDelayProportion = toProportion(schedulerDelay) + + val executorOverhead = serializationTime + deserializationTime + val executorRunTime = if (taskInfo.running) { + totalExecutionTime - executorOverhead - gettingResultTime + } else { + metricsOpt.map(_.executorRunTime).getOrElse( + totalExecutionTime - executorOverhead - gettingResultTime) + } + val executorComputingTime = executorRunTime - shuffleReadTime - shuffleWriteTime + val executorComputingTimeProportion = + (100 - schedulerDelayProportion - shuffleReadTimeProportion - shuffleWriteTimeProportion - serializationTimeProportion - deserializationTimeProportion - gettingResultTimeProportion) @@ -672,162 +689,619 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { } - def taskRow( - hasAccumulators: Boolean, - hasInput: Boolean, - hasOutput: Boolean, - hasShuffleRead: Boolean, - hasShuffleWrite: Boolean, - hasBytesSpilled: Boolean, - currentTime: Long)(taskData: TaskUIData): Seq[Node] = { - taskData match { case TaskUIData(info, metrics, errorMessage) => - val duration = if (info.status == "RUNNING") info.timeRunning(currentTime) - else metrics.map(_.executorRunTime).getOrElse(1L) - val formatDuration = if (info.status == "RUNNING") UIUtils.formatDuration(duration) - else metrics.map(m => UIUtils.formatDuration(m.executorRunTime)).getOrElse("") - val schedulerDelay = metrics.map(getSchedulerDelay(info, _)).getOrElse(0L) - val gcTime = metrics.map(_.jvmGCTime).getOrElse(0L) - val taskDeserializationTime = metrics.map(_.executorDeserializeTime).getOrElse(0L) - val serializationTime = metrics.map(_.resultSerializationTime).getOrElse(0L) - val gettingResultTime = getGettingResultTime(info) - - val maybeAccumulators = info.accumulables - val accumulatorsReadable = maybeAccumulators.map{acc => s"${acc.name}: ${acc.update.get}"} - - val maybeInput = metrics.flatMap(_.inputMetrics) - val inputSortable = maybeInput.map(_.bytesRead.toString).getOrElse("") - val inputReadable = maybeInput - .map(m => s"${Utils.bytesToString(m.bytesRead)} 
(${m.readMethod.toString.toLowerCase()})") - .getOrElse("") - val inputRecords = maybeInput.map(_.recordsRead.toString).getOrElse("") - - val maybeOutput = metrics.flatMap(_.outputMetrics) - val outputSortable = maybeOutput.map(_.bytesWritten.toString).getOrElse("") - val outputReadable = maybeOutput - .map(m => s"${Utils.bytesToString(m.bytesWritten)}") - .getOrElse("") - val outputRecords = maybeOutput.map(_.recordsWritten.toString).getOrElse("") - - val maybeShuffleRead = metrics.flatMap(_.shuffleReadMetrics) - val shuffleReadBlockedTimeSortable = maybeShuffleRead - .map(_.fetchWaitTime.toString).getOrElse("") - val shuffleReadBlockedTimeReadable = - maybeShuffleRead.map(ms => UIUtils.formatDuration(ms.fetchWaitTime)).getOrElse("") - - val totalShuffleBytes = maybeShuffleRead.map(_.totalBytesRead) - val shuffleReadSortable = totalShuffleBytes.map(_.toString).getOrElse("") - val shuffleReadReadable = totalShuffleBytes.map(Utils.bytesToString).getOrElse("") - val shuffleReadRecords = maybeShuffleRead.map(_.recordsRead.toString).getOrElse("") - - val remoteShuffleBytes = maybeShuffleRead.map(_.remoteBytesRead) - val shuffleReadRemoteSortable = remoteShuffleBytes.map(_.toString).getOrElse("") - val shuffleReadRemoteReadable = remoteShuffleBytes.map(Utils.bytesToString).getOrElse("") - - val maybeShuffleWrite = metrics.flatMap(_.shuffleWriteMetrics) - val shuffleWriteSortable = maybeShuffleWrite.map(_.shuffleBytesWritten.toString).getOrElse("") - val shuffleWriteReadable = maybeShuffleWrite - .map(m => s"${Utils.bytesToString(m.shuffleBytesWritten)}").getOrElse("") - val shuffleWriteRecords = maybeShuffleWrite - .map(_.shuffleRecordsWritten.toString).getOrElse("") - - val maybeWriteTime = metrics.flatMap(_.shuffleWriteMetrics).map(_.shuffleWriteTime) - val writeTimeSortable = maybeWriteTime.map(_.toString).getOrElse("") - val writeTimeReadable = maybeWriteTime.map(t => t / (1000 * 1000)).map { ms => - if (ms == 0) "" else UIUtils.formatDuration(ms) - }.getOrElse("") - - val maybeMemoryBytesSpilled = metrics.map(_.memoryBytesSpilled) - val memoryBytesSpilledSortable = maybeMemoryBytesSpilled.map(_.toString).getOrElse("") - val memoryBytesSpilledReadable = - maybeMemoryBytesSpilled.map(Utils.bytesToString).getOrElse("") - - val maybeDiskBytesSpilled = metrics.map(_.diskBytesSpilled) - val diskBytesSpilledSortable = maybeDiskBytesSpilled.map(_.toString).getOrElse("") - val diskBytesSpilledReadable = maybeDiskBytesSpilled.map(Utils.bytesToString).getOrElse("") - - - {info.index} - {info.taskId} - { - if (info.speculative) s"${info.attempt} (speculative)" else info.attempt.toString - } - {info.status} - {info.taskLocality} - {info.executorId} / {info.host} - {UIUtils.formatDate(new Date(info.launchTime))} - - {formatDuration} - - - {UIUtils.formatDuration(schedulerDelay.toLong)} - - - {UIUtils.formatDuration(taskDeserializationTime.toLong)} - - - {if (gcTime > 0) UIUtils.formatDuration(gcTime) else ""} - - - {UIUtils.formatDuration(serializationTime)} - - - {UIUtils.formatDuration(gettingResultTime)} - - {if (hasAccumulators) { - - {Unparsed(accumulatorsReadable.mkString("
    "))} - - }} - {if (hasInput) { - - {s"$inputReadable / $inputRecords"} - - }} - {if (hasOutput) { - - {s"$outputReadable / $outputRecords"} - - }} +} + +private[ui] object StagePage { + private[ui] def getGettingResultTime(info: TaskInfo, currentTime: Long): Long = { + if (info.gettingResult) { + if (info.finished) { + info.finishTime - info.gettingResultTime + } else { + // The task is still fetching the result. + currentTime - info.gettingResultTime + } + } else { + 0L + } + } + + private[ui] def getSchedulerDelay( + info: TaskInfo, metrics: TaskMetrics, currentTime: Long): Long = { + if (info.finished) { + val totalExecutionTime = info.finishTime - info.launchTime + val executorOverhead = (metrics.executorDeserializeTime + + metrics.resultSerializationTime) + math.max( + 0, + totalExecutionTime - metrics.executorRunTime - executorOverhead - + getGettingResultTime(info, currentTime)) + } else { + // The task is still running and the metrics like executorRunTime are not available. + 0L + } + } +} + +private[ui] case class TaskTableRowInputData(inputSortable: Long, inputReadable: String) + +private[ui] case class TaskTableRowOutputData(outputSortable: Long, outputReadable: String) + +private[ui] case class TaskTableRowShuffleReadData( + shuffleReadBlockedTimeSortable: Long, + shuffleReadBlockedTimeReadable: String, + shuffleReadSortable: Long, + shuffleReadReadable: String, + shuffleReadRemoteSortable: Long, + shuffleReadRemoteReadable: String) + +private[ui] case class TaskTableRowShuffleWriteData( + writeTimeSortable: Long, + writeTimeReadable: String, + shuffleWriteSortable: Long, + shuffleWriteReadable: String) + +private[ui] case class TaskTableRowBytesSpilledData( + memoryBytesSpilledSortable: Long, + memoryBytesSpilledReadable: String, + diskBytesSpilledSortable: Long, + diskBytesSpilledReadable: String) + +/** + * Contains all data that needs for sorting and generating HTML. Using this one rather than + * TaskUIData to avoid creating duplicate contents during sorting the data. 
+ */ +private[ui] case class TaskTableRowData( + index: Int, + taskId: Long, + attempt: Int, + speculative: Boolean, + status: String, + taskLocality: String, + executorIdAndHost: String, + launchTime: Long, + duration: Long, + formatDuration: String, + schedulerDelay: Long, + taskDeserializationTime: Long, + gcTime: Long, + serializationTime: Long, + gettingResultTime: Long, + accumulators: Option[String], // HTML + input: Option[TaskTableRowInputData], + output: Option[TaskTableRowOutputData], + shuffleRead: Option[TaskTableRowShuffleReadData], + shuffleWrite: Option[TaskTableRowShuffleWriteData], + bytesSpilled: Option[TaskTableRowBytesSpilledData], + error: String) + +private[ui] class TaskDataSource( + tasks: Seq[TaskUIData], + hasAccumulators: Boolean, + hasInput: Boolean, + hasOutput: Boolean, + hasShuffleRead: Boolean, + hasShuffleWrite: Boolean, + hasBytesSpilled: Boolean, + currentTime: Long, + pageSize: Int, + sortColumn: String, + desc: Boolean) extends PagedDataSource[TaskTableRowData](pageSize) { + import StagePage._ + + // Convert TaskUIData to TaskTableRowData which contains the final contents to show in the table + // so that we can avoid creating duplicate contents during sorting the data + private val data = tasks.map(taskRow).sorted(ordering(sortColumn, desc)) + + private var _slicedTaskIds: Set[Long] = null + + override def dataSize: Int = data.size + + override def sliceData(from: Int, to: Int): Seq[TaskTableRowData] = { + val r = data.slice(from, to) + _slicedTaskIds = r.map(_.taskId).toSet + r + } + + def slicedTaskIds: Set[Long] = _slicedTaskIds + + private def taskRow(taskData: TaskUIData): TaskTableRowData = { + val TaskUIData(info, metrics, errorMessage) = taskData + val duration = if (info.status == "RUNNING") info.timeRunning(currentTime) + else metrics.map(_.executorRunTime).getOrElse(1L) + val formatDuration = if (info.status == "RUNNING") UIUtils.formatDuration(duration) + else metrics.map(m => UIUtils.formatDuration(m.executorRunTime)).getOrElse("") + val schedulerDelay = metrics.map(getSchedulerDelay(info, _, currentTime)).getOrElse(0L) + val gcTime = metrics.map(_.jvmGCTime).getOrElse(0L) + val taskDeserializationTime = metrics.map(_.executorDeserializeTime).getOrElse(0L) + val serializationTime = metrics.map(_.resultSerializationTime).getOrElse(0L) + val gettingResultTime = getGettingResultTime(info, currentTime) + + val maybeAccumulators = info.accumulables + val accumulatorsReadable = maybeAccumulators.map { acc => + StringEscapeUtils.escapeHtml4(s"${acc.name}: ${acc.update.get}") + } + + val maybeInput = metrics.flatMap(_.inputMetrics) + val inputSortable = maybeInput.map(_.bytesRead).getOrElse(0L) + val inputReadable = maybeInput + .map(m => s"${Utils.bytesToString(m.bytesRead)} (${m.readMethod.toString.toLowerCase()})") + .getOrElse("") + val inputRecords = maybeInput.map(_.recordsRead.toString).getOrElse("") + + val maybeOutput = metrics.flatMap(_.outputMetrics) + val outputSortable = maybeOutput.map(_.bytesWritten).getOrElse(0L) + val outputReadable = maybeOutput + .map(m => s"${Utils.bytesToString(m.bytesWritten)}") + .getOrElse("") + val outputRecords = maybeOutput.map(_.recordsWritten.toString).getOrElse("") + + val maybeShuffleRead = metrics.flatMap(_.shuffleReadMetrics) + val shuffleReadBlockedTimeSortable = maybeShuffleRead.map(_.fetchWaitTime).getOrElse(0L) + val shuffleReadBlockedTimeReadable = + maybeShuffleRead.map(ms => UIUtils.formatDuration(ms.fetchWaitTime)).getOrElse("") + + val totalShuffleBytes = 
maybeShuffleRead.map(_.totalBytesRead) + val shuffleReadSortable = totalShuffleBytes.getOrElse(0L) + val shuffleReadReadable = totalShuffleBytes.map(Utils.bytesToString).getOrElse("") + val shuffleReadRecords = maybeShuffleRead.map(_.recordsRead.toString).getOrElse("") + + val remoteShuffleBytes = maybeShuffleRead.map(_.remoteBytesRead) + val shuffleReadRemoteSortable = remoteShuffleBytes.getOrElse(0L) + val shuffleReadRemoteReadable = remoteShuffleBytes.map(Utils.bytesToString).getOrElse("") + + val maybeShuffleWrite = metrics.flatMap(_.shuffleWriteMetrics) + val shuffleWriteSortable = maybeShuffleWrite.map(_.shuffleBytesWritten).getOrElse(0L) + val shuffleWriteReadable = maybeShuffleWrite + .map(m => s"${Utils.bytesToString(m.shuffleBytesWritten)}").getOrElse("") + val shuffleWriteRecords = maybeShuffleWrite + .map(_.shuffleRecordsWritten.toString).getOrElse("") + + val maybeWriteTime = metrics.flatMap(_.shuffleWriteMetrics).map(_.shuffleWriteTime) + val writeTimeSortable = maybeWriteTime.getOrElse(0L) + val writeTimeReadable = maybeWriteTime.map(t => t / (1000 * 1000)).map { ms => + if (ms == 0) "" else UIUtils.formatDuration(ms) + }.getOrElse("") + + val maybeMemoryBytesSpilled = metrics.map(_.memoryBytesSpilled) + val memoryBytesSpilledSortable = maybeMemoryBytesSpilled.getOrElse(0L) + val memoryBytesSpilledReadable = + maybeMemoryBytesSpilled.map(Utils.bytesToString).getOrElse("") + + val maybeDiskBytesSpilled = metrics.map(_.diskBytesSpilled) + val diskBytesSpilledSortable = maybeDiskBytesSpilled.getOrElse(0L) + val diskBytesSpilledReadable = maybeDiskBytesSpilled.map(Utils.bytesToString).getOrElse("") + + val input = + if (hasInput) { + Some(TaskTableRowInputData(inputSortable, s"$inputReadable / $inputRecords")) + } else { + None + } + + val output = + if (hasOutput) { + Some(TaskTableRowOutputData(outputSortable, s"$outputReadable / $outputRecords")) + } else { + None + } + + val shuffleRead = + if (hasShuffleRead) { + Some(TaskTableRowShuffleReadData( + shuffleReadBlockedTimeSortable, + shuffleReadBlockedTimeReadable, + shuffleReadSortable, + s"$shuffleReadReadable / $shuffleReadRecords", + shuffleReadRemoteSortable, + shuffleReadRemoteReadable + )) + } else { + None + } + + val shuffleWrite = + if (hasShuffleWrite) { + Some(TaskTableRowShuffleWriteData( + writeTimeSortable, + writeTimeReadable, + shuffleWriteSortable, + s"$shuffleWriteReadable / $shuffleWriteRecords" + )) + } else { + None + } + + val bytesSpilled = + if (hasBytesSpilled) { + Some(TaskTableRowBytesSpilledData( + memoryBytesSpilledSortable, + memoryBytesSpilledReadable, + diskBytesSpilledSortable, + diskBytesSpilledReadable + )) + } else { + None + } + + TaskTableRowData( + info.index, + info.taskId, + info.attempt, + info.speculative, + info.status, + info.taskLocality.toString, + s"${info.executorId} / ${info.host}", + info.launchTime, + duration, + formatDuration, + schedulerDelay, + taskDeserializationTime, + gcTime, + serializationTime, + gettingResultTime, + if (hasAccumulators) Some(accumulatorsReadable.mkString("
    ")) else None, + input, + output, + shuffleRead, + shuffleWrite, + bytesSpilled, + errorMessage.getOrElse("") + ) + } + + /** + * Return Ordering according to sortColumn and desc + */ + private def ordering(sortColumn: String, desc: Boolean): Ordering[TaskTableRowData] = { + val ordering = sortColumn match { + case "Index" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Int.compare(x.index, y.index) + } + case "ID" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.taskId, y.taskId) + } + case "Attempt" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Int.compare(x.attempt, y.attempt) + } + case "Status" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.String.compare(x.status, y.status) + } + case "Locality Level" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.String.compare(x.taskLocality, y.taskLocality) + } + case "Executor ID / Host" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.String.compare(x.executorIdAndHost, y.executorIdAndHost) + } + case "Launch Time" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.launchTime, y.launchTime) + } + case "Duration" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.duration, y.duration) + } + case "Scheduler Delay" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.schedulerDelay, y.schedulerDelay) + } + case "Task Deserialization Time" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.taskDeserializationTime, y.taskDeserializationTime) + } + case "GC Time" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.gcTime, y.gcTime) + } + case "Result Serialization Time" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.serializationTime, y.serializationTime) + } + case "Getting Result Time" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.gettingResultTime, y.gettingResultTime) + } + case "Accumulators" => + if (hasAccumulators) { + new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.String.compare(x.accumulators.get, y.accumulators.get) + } + } else { + throw new IllegalArgumentException( + "Cannot sort by Accumulators because of no accumulators") + } + case "Input Size / Records" => + if (hasInput) { + new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.input.get.inputSortable, y.input.get.inputSortable) + } + } else { + throw new IllegalArgumentException( + "Cannot sort by Input Size / Records because of no inputs") + } + case "Output Size / Records" => + if (hasOutput) { + new 
Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.output.get.outputSortable, y.output.get.outputSortable) + } + } else { + throw new IllegalArgumentException( + "Cannot sort by Output Size / Records because of no outputs") + } + // ShuffleRead + case "Shuffle Read Blocked Time" => + if (hasShuffleRead) { + new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.shuffleRead.get.shuffleReadBlockedTimeSortable, + y.shuffleRead.get.shuffleReadBlockedTimeSortable) + } + } else { + throw new IllegalArgumentException( + "Cannot sort by Shuffle Read Blocked Time because of no shuffle reads") + } + case "Shuffle Read Size / Records" => + if (hasShuffleRead) { + new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.shuffleRead.get.shuffleReadSortable, + y.shuffleRead.get.shuffleReadSortable) + } + } else { + throw new IllegalArgumentException( + "Cannot sort by Shuffle Read Size / Records because of no shuffle reads") + } + case "Shuffle Remote Reads" => + if (hasShuffleRead) { + new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.shuffleRead.get.shuffleReadRemoteSortable, + y.shuffleRead.get.shuffleReadRemoteSortable) + } + } else { + throw new IllegalArgumentException( + "Cannot sort by Shuffle Remote Reads because of no shuffle reads") + } + // ShuffleWrite + case "Write Time" => + if (hasShuffleWrite) { + new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.shuffleWrite.get.writeTimeSortable, + y.shuffleWrite.get.writeTimeSortable) + } + } else { + throw new IllegalArgumentException( + "Cannot sort by Write Time because of no shuffle writes") + } + case "Shuffle Write Size / Records" => + if (hasShuffleWrite) { + new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.shuffleWrite.get.shuffleWriteSortable, + y.shuffleWrite.get.shuffleWriteSortable) + } + } else { + throw new IllegalArgumentException( + "Cannot sort by Shuffle Write Size / Records because of no shuffle writes") + } + // BytesSpilled + case "Shuffle Spill (Memory)" => + if (hasBytesSpilled) { + new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.bytesSpilled.get.memoryBytesSpilledSortable, + y.bytesSpilled.get.memoryBytesSpilledSortable) + } + } else { + throw new IllegalArgumentException( + "Cannot sort by Shuffle Spill (Memory) because of no spills") + } + case "Shuffle Spill (Disk)" => + if (hasBytesSpilled) { + new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.bytesSpilled.get.diskBytesSpilledSortable, + y.bytesSpilled.get.diskBytesSpilledSortable) + } + } else { + throw new IllegalArgumentException( + "Cannot sort by Shuffle Spill (Disk) because of no spills") + } + case "Errors" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.String.compare(x.error, y.error) + } + case unknownColumn => throw new IllegalArgumentException(s"Unknown column: $unknownColumn") + } + if (desc) { + ordering.reverse + } else { + ordering + } + } + +} + +private[ui] class 
TaskPagedTable( + basePath: String, + data: Seq[TaskUIData], + hasAccumulators: Boolean, + hasInput: Boolean, + hasOutput: Boolean, + hasShuffleRead: Boolean, + hasShuffleWrite: Boolean, + hasBytesSpilled: Boolean, + currentTime: Long, + pageSize: Int, + sortColumn: String, + desc: Boolean) extends PagedTable[TaskTableRowData]{ + + override def tableId: String = "" + + override def tableCssClass: String = "table table-bordered table-condensed table-striped" + + override val dataSource: TaskDataSource = new TaskDataSource( + data, + hasAccumulators, + hasInput, + hasOutput, + hasShuffleRead, + hasShuffleWrite, + hasBytesSpilled, + currentTime, + pageSize, + sortColumn, + desc + ) + + override def pageLink(page: Int): String = { + val encodedSortColumn = URLEncoder.encode(sortColumn, "UTF-8") + s"${basePath}&task.page=$page&task.sort=${encodedSortColumn}&task.desc=${desc}" + + s"&task.pageSize=${pageSize}" + } + + override def goButtonJavascriptFunction: (String, String) = { + val jsFuncName = "goToTaskPage" + val encodedSortColumn = URLEncoder.encode(sortColumn, "UTF-8") + val jsFunc = s""" + |currentTaskPageSize = ${pageSize} + |function goToTaskPage(page, pageSize) { + | // Set page to 1 if the page size changes + | page = pageSize == currentTaskPageSize ? page : 1; + | var url = "${basePath}&task.sort=${encodedSortColumn}&task.desc=${desc}" + + | "&task.page=" + page + "&task.pageSize=" + pageSize; + | window.location.href = url; + |} + """.stripMargin + (jsFuncName, jsFunc) + } + + def headers: Seq[Node] = { + val taskHeadersAndCssClasses: Seq[(String, String)] = + Seq( + ("Index", ""), ("ID", ""), ("Attempt", ""), ("Status", ""), ("Locality Level", ""), + ("Executor ID / Host", ""), ("Launch Time", ""), ("Duration", ""), + ("Scheduler Delay", TaskDetailsClassNames.SCHEDULER_DELAY), + ("Task Deserialization Time", TaskDetailsClassNames.TASK_DESERIALIZATION_TIME), + ("GC Time", ""), + ("Result Serialization Time", TaskDetailsClassNames.RESULT_SERIALIZATION_TIME), + ("Getting Result Time", TaskDetailsClassNames.GETTING_RESULT_TIME)) ++ + {if (hasAccumulators) Seq(("Accumulators", "")) else Nil} ++ + {if (hasInput) Seq(("Input Size / Records", "")) else Nil} ++ + {if (hasOutput) Seq(("Output Size / Records", "")) else Nil} ++ {if (hasShuffleRead) { - - {shuffleReadBlockedTimeReadable} - - - {s"$shuffleReadReadable / $shuffleReadRecords"} - - - {shuffleReadRemoteReadable} - - }} + Seq(("Shuffle Read Blocked Time", TaskDetailsClassNames.SHUFFLE_READ_BLOCKED_TIME), + ("Shuffle Read Size / Records", ""), + ("Shuffle Remote Reads", TaskDetailsClassNames.SHUFFLE_READ_REMOTE_SIZE)) + } else { + Nil + }} ++ {if (hasShuffleWrite) { - - {writeTimeReadable} - - - {s"$shuffleWriteReadable / $shuffleWriteRecords"} - - }} + Seq(("Write Time", ""), ("Shuffle Write Size / Records", "")) + } else { + Nil + }} ++ {if (hasBytesSpilled) { - - {memoryBytesSpilledReadable} - - - {diskBytesSpilledReadable} - - }} - {errorMessageCell(errorMessage)} - + Seq(("Shuffle Spill (Memory)", ""), ("Shuffle Spill (Disk)", "")) + } else { + Nil + }} ++ + Seq(("Errors", "")) + + if (!taskHeadersAndCssClasses.map(_._1).contains(sortColumn)) { + new IllegalArgumentException(s"Unknown column: $sortColumn") + } + + val headerRow: Seq[Node] = { + taskHeadersAndCssClasses.map { case (header, cssClass) => + if (header == sortColumn) { + val headerLink = + s"$basePath&task.sort=${URLEncoder.encode(header, "UTF-8")}&task.desc=${!desc}" + + s"&task.pageSize=${pageSize}" + val js = Unparsed(s"window.location.href='${headerLink}'") + 
val arrow = if (desc) "▾" else "▴" // UP or DOWN + + {header} +  {Unparsed(arrow)} + + } else { + val headerLink = + s"$basePath&task.sort=${URLEncoder.encode(header, "UTF-8")}&task.pageSize=${pageSize}" + val js = Unparsed(s"window.location.href='${headerLink}'") + + {header} + + } + } } + {headerRow} } - private def errorMessageCell(errorMessage: Option[String]): Seq[Node] = { - val error = errorMessage.getOrElse("") + def row(task: TaskTableRowData): Seq[Node] = { + + {task.index} + {task.taskId} + {if (task.speculative) s"${task.attempt} (speculative)" else task.attempt.toString} + {task.status} + {task.taskLocality} + {task.executorIdAndHost} + {UIUtils.formatDate(new Date(task.launchTime))} + {task.formatDuration} + + {UIUtils.formatDuration(task.schedulerDelay)} + + + {UIUtils.formatDuration(task.taskDeserializationTime)} + + + {if (task.gcTime > 0) UIUtils.formatDuration(task.gcTime) else ""} + + + {UIUtils.formatDuration(task.serializationTime)} + + + {UIUtils.formatDuration(task.gettingResultTime)} + + {if (task.accumulators.nonEmpty) { + {Unparsed(task.accumulators.get)} + }} + {if (task.input.nonEmpty) { + {task.input.get.inputReadable} + }} + {if (task.output.nonEmpty) { + {task.output.get.outputReadable} + }} + {if (task.shuffleRead.nonEmpty) { + + {task.shuffleRead.get.shuffleReadBlockedTimeReadable} + + {task.shuffleRead.get.shuffleReadReadable} + + {task.shuffleRead.get.shuffleReadRemoteReadable} + + }} + {if (task.shuffleWrite.nonEmpty) { + {task.shuffleWrite.get.writeTimeReadable} + {task.shuffleWrite.get.shuffleWriteReadable} + }} + {if (task.bytesSpilled.nonEmpty) { + {task.bytesSpilled.get.memoryBytesSpilledReadable} + {task.bytesSpilled.get.diskBytesSpilledReadable} + }} + {errorMessageCell(task.error)} + + } + + private def errorMessageCell(error: String): Seq[Node] = { val isMultiline = error.indexOf('\n') >= 0 // Display the first line by default val errorSummary = StringEscapeUtils.escapeHtml4( @@ -851,33 +1325,4 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { } {errorSummary}{details} } - - private def getGettingResultTime(info: TaskInfo): Long = { - if (info.gettingResultTime > 0) { - if (info.finishTime > 0) { - info.finishTime - info.gettingResultTime - } else { - // The task is still fetching the result. 
- System.currentTimeMillis - info.gettingResultTime - } - } else { - 0L - } - } - - private def getSchedulerDelay(info: TaskInfo, metrics: TaskMetrics): Long = { - val totalExecutionTime = - if (info.gettingResult) { - info.gettingResultTime - info.launchTime - } else if (info.finished) { - info.finishTime - info.launchTime - } else { - 0 - } - val executorOverhead = (metrics.executorDeserializeTime + - metrics.resultSerializationTime) - math.max( - 0, - totalExecutionTime - metrics.executorRunTime - executorOverhead - getGettingResultTime(info)) - } } diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala index 07db783c572cf..04f584621e71e 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala @@ -21,7 +21,7 @@ import javax.servlet.http.HttpServletRequest import scala.xml.Node -import org.apache.spark.storage.RDDInfo +import org.apache.spark.storage._ import org.apache.spark.ui.{UIUtils, WebUIPage} import org.apache.spark.util.Utils @@ -30,13 +30,25 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { private val listener = parent.listener def render(request: HttpServletRequest): Seq[Node] = { - val rdds = listener.rddInfoList - val content = UIUtils.listingTable(rddHeader, rddRow, rdds, id = Some("storage-by-rdd-table")) + val content = rddTable(listener.rddInfoList) ++ + receiverBlockTables(listener.allExecutorStreamBlockStatus.sortBy(_.executorId)) UIUtils.headerSparkPage("Storage", content, parent) } + private[storage] def rddTable(rdds: Seq[RDDInfo]): Seq[Node] = { + if (rdds.isEmpty) { + // Don't show the rdd table if there is no RDD persisted. + Nil + } else { +
<div>
+        <h4>RDDs</h4>
+        {UIUtils.listingTable(rddHeader, rddRow, rdds, id = Some("storage-by-rdd-table"))}
+      </div>
    + } + } + /** Header fields for the RDD table */ - private def rddHeader = Seq( + private val rddHeader = Seq( "RDD Name", "Storage Level", "Cached Partitions", @@ -56,7 +68,7 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { {rdd.storageLevel.description} - {rdd.numCachedPartitions} + {rdd.numCachedPartitions.toString} {"%.0f%%".format(rdd.numCachedPartitions * 100.0 / rdd.numPartitions)} {Utils.bytesToString(rdd.memSize)} {Utils.bytesToString(rdd.externalBlockStoreSize)} @@ -64,4 +76,130 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { // scalastyle:on } + + private[storage] def receiverBlockTables(statuses: Seq[ExecutorStreamBlockStatus]): Seq[Node] = { + if (statuses.map(_.numStreamBlocks).sum == 0) { + // Don't show the tables if there is no stream block + Nil + } else { + val blocks = statuses.flatMap(_.blocks).groupBy(_.blockId).toSeq.sortBy(_._1.toString) + +
<div>
+        <h4>Receiver Blocks</h4>
+        {executorMetricsTable(statuses)}
+        {streamBlockTable(blocks)}
+      </div>
    + } + } + + private def executorMetricsTable(statuses: Seq[ExecutorStreamBlockStatus]): Seq[Node] = { +
<div>
+      <h5>Aggregated Block Metrics by Executor</h5>
+      {UIUtils.listingTable(executorMetricsTableHeader, executorMetricsTableRow, statuses,
+        id = Some("storage-by-executor-stream-blocks"))}
+    </div>
    + } + + private val executorMetricsTableHeader = Seq( + "Executor ID", + "Address", + "Total Size in Memory", + "Total Size in ExternalBlockStore", + "Total Size on Disk", + "Stream Blocks") + + private def executorMetricsTableRow(status: ExecutorStreamBlockStatus): Seq[Node] = { + + + {status.executorId} + + + {status.location} + + + {Utils.bytesToString(status.totalMemSize)} + + + {Utils.bytesToString(status.totalExternalBlockStoreSize)} + + + {Utils.bytesToString(status.totalDiskSize)} + + + {status.numStreamBlocks.toString} + + + } + + private def streamBlockTable(blocks: Seq[(BlockId, Seq[BlockUIData])]): Seq[Node] = { + if (blocks.isEmpty) { + Nil + } else { +
<div>
+        <h5>Blocks</h5>
+        {UIUtils.listingTable(
+          streamBlockTableHeader,
+          streamBlockTableRow,
+          blocks,
+          id = Some("storage-by-block-table"),
+          sortable = false)}
+      </div>
    + } + } + + private val streamBlockTableHeader = Seq( + "Block ID", + "Replication Level", + "Location", + "Storage Level", + "Size") + + /** Render a stream block */ + private def streamBlockTableRow(block: (BlockId, Seq[BlockUIData])): Seq[Node] = { + val replications = block._2 + assert(replications.size > 0) // This must be true because it's the result of "groupBy" + if (replications.size == 1) { + streamBlockTableSubrow(block._1, replications.head, replications.size, true) + } else { + streamBlockTableSubrow(block._1, replications.head, replications.size, true) ++ + replications.tail.map(streamBlockTableSubrow(block._1, _, replications.size, false)).flatten + } + } + + private def streamBlockTableSubrow( + blockId: BlockId, block: BlockUIData, replication: Int, firstSubrow: Boolean): Seq[Node] = { + val (storageLevel, size) = streamBlockStorageLevelDescriptionAndSize(block) + + + { + if (firstSubrow) { + + {block.blockId.toString} + + + {replication.toString} + + } + } + {block.location} + {storageLevel} + {Utils.bytesToString(size)} + + } + + private[storage] def streamBlockStorageLevelDescriptionAndSize( + block: BlockUIData): (String, Long) = { + if (block.storageLevel.useDisk) { + ("Disk", block.diskSize) + } else if (block.storageLevel.useMemory && block.storageLevel.deserialized) { + ("Memory", block.memSize) + } else if (block.storageLevel.useMemory && !block.storageLevel.deserialized) { + ("Memory Serialized", block.memSize) + } else if (block.storageLevel.useOffHeap) { + ("External", block.externalBlockStoreSize) + } else { + throw new IllegalStateException(s"Invalid Storage Level: ${block.storageLevel}") + } + } + } diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala index 0351749700962..22e2993b3b5bd 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala @@ -39,7 +39,8 @@ private[ui] class StorageTab(parent: SparkUI) extends SparkUITab(parent, "storag * This class is thread-safe (unlike JobProgressListener) */ @DeveloperApi -class StorageListener(storageStatusListener: StorageStatusListener) extends SparkListener { +class StorageListener(storageStatusListener: StorageStatusListener) extends BlockStatusListener { + private[ui] val _rddInfoMap = mutable.Map[Int, RDDInfo]() // exposed for testing def storageStatusList: Seq[StorageStatus] = storageStatusListener.storageStatusList diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index c179833e5b06a..78e7ddc27d1c7 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -128,7 +128,7 @@ private[spark] object AkkaUtils extends Logging { /** Returns the configured max frame size for Akka messages in bytes. 
*/ def maxFrameSizeBytes(conf: SparkConf): Int = { - val frameSizeInMB = conf.getInt("spark.akka.frameSize", 10) + val frameSizeInMB = conf.getInt("spark.akka.frameSize", 128) if (frameSizeInMB > AKKA_MAX_FRAME_SIZE_IN_MB) { throw new IllegalArgumentException( s"spark.akka.frameSize should not be greater than $AKKA_MAX_FRAME_SIZE_IN_MB MB") diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index 305de4c75539d..ebead830c6466 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -49,45 +49,28 @@ private[spark] object ClosureCleaner extends Logging { cls.getName.contains("$anonfun$") } - // Get a list of the classes of the outer objects of a given closure object, obj; + // Get a list of the outer objects and their classes of a given closure object, obj; // the outer objects are defined as any closures that obj is nested within, plus // possibly the class that the outermost closure is in, if any. We stop searching // for outer objects beyond that because cloning the user's object is probably // not a good idea (whereas we can clone closure objects just fine since we // understand how all their fields are used). - private def getOuterClasses(obj: AnyRef): List[Class[_]] = { + private def getOuterClassesAndObjects(obj: AnyRef): (List[Class[_]], List[AnyRef]) = { for (f <- obj.getClass.getDeclaredFields if f.getName == "$outer") { f.setAccessible(true) val outer = f.get(obj) // The outer pointer may be null if we have cleaned this closure before if (outer != null) { if (isClosure(f.getType)) { - return f.getType :: getOuterClasses(outer) + val recurRet = getOuterClassesAndObjects(outer) + return (f.getType :: recurRet._1, outer :: recurRet._2) } else { - return f.getType :: Nil // Stop at the first $outer that is not a closure + return (f.getType :: Nil, outer :: Nil) // Stop at the first $outer that is not a closure } } } - Nil + (Nil, Nil) } - - // Get a list of the outer objects for a given closure object. - private def getOuterObjects(obj: AnyRef): List[AnyRef] = { - for (f <- obj.getClass.getDeclaredFields if f.getName == "$outer") { - f.setAccessible(true) - val outer = f.get(obj) - // The outer pointer may be null if we have cleaned this closure before - if (outer != null) { - if (isClosure(f.getType)) { - return outer :: getOuterObjects(outer) - } else { - return outer :: Nil // Stop at the first $outer that is not a closure - } - } - } - Nil - } - /** * Return a list of classes that represent closures enclosed in the given closure object. */ @@ -205,8 +188,7 @@ private[spark] object ClosureCleaner extends Logging { // A list of enclosing objects and their respective classes, from innermost to outermost // An outer object at a given index is of type outer class at the same index - val outerClasses = getOuterClasses(func) - val outerObjects = getOuterObjects(func) + val (outerClasses, outerObjects) = getOuterClassesAndObjects(func) // For logging purposes only val declaredFields = func.getClass.getDeclaredFields @@ -448,10 +430,12 @@ private class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM if (op == INVOKESPECIAL && name == "" && argTypes.length > 0 && argTypes(0).toString.startsWith("L") // is it an object? 
&& argTypes(0).getInternalName == myName) { + // scalastyle:off classforname output += Class.forName( owner.replace('/', '.'), false, Thread.currentThread.getContextClassLoader) + // scalastyle:on classforname } } } diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index adf69a4e78e71..c600319d9ddb4 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -92,8 +92,10 @@ private[spark] object JsonProtocol { executorRemovedToJson(executorRemoved) case logStart: SparkListenerLogStart => logStartToJson(logStart) - // These aren't used, but keeps compiler happy - case SparkListenerExecutorMetricsUpdate(_, _) => JNothing + case metricsUpdate: SparkListenerExecutorMetricsUpdate => + executorMetricsUpdateToJson(metricsUpdate) + case blockUpdated: SparkListenerBlockUpdated => + throw new MatchError(blockUpdated) // TODO(ekl) implement this } } @@ -224,6 +226,19 @@ private[spark] object JsonProtocol { ("Spark Version" -> SPARK_VERSION) } + def executorMetricsUpdateToJson(metricsUpdate: SparkListenerExecutorMetricsUpdate): JValue = { + val execId = metricsUpdate.execId + val taskMetrics = metricsUpdate.taskMetrics + ("Event" -> Utils.getFormattedClassName(metricsUpdate)) ~ + ("Executor ID" -> execId) ~ + ("Metrics Updated" -> taskMetrics.map { case (taskId, stageId, stageAttemptId, metrics) => + ("Task ID" -> taskId) ~ + ("Stage ID" -> stageId) ~ + ("Stage Attempt ID" -> stageAttemptId) ~ + ("Task Metrics" -> taskMetricsToJson(metrics)) + }) + } + /** ------------------------------------------------------------------- * * JSON serialization methods for classes SparkListenerEvents depend on | * -------------------------------------------------------------------- */ @@ -463,6 +478,7 @@ private[spark] object JsonProtocol { val executorAdded = Utils.getFormattedClassName(SparkListenerExecutorAdded) val executorRemoved = Utils.getFormattedClassName(SparkListenerExecutorRemoved) val logStart = Utils.getFormattedClassName(SparkListenerLogStart) + val metricsUpdate = Utils.getFormattedClassName(SparkListenerExecutorMetricsUpdate) (json \ "Event").extract[String] match { case `stageSubmitted` => stageSubmittedFromJson(json) @@ -481,6 +497,7 @@ private[spark] object JsonProtocol { case `executorAdded` => executorAddedFromJson(json) case `executorRemoved` => executorRemovedFromJson(json) case `logStart` => logStartFromJson(json) + case `metricsUpdate` => executorMetricsUpdateFromJson(json) } } @@ -598,6 +615,18 @@ private[spark] object JsonProtocol { SparkListenerLogStart(sparkVersion) } + def executorMetricsUpdateFromJson(json: JValue): SparkListenerExecutorMetricsUpdate = { + val execInfo = (json \ "Executor ID").extract[String] + val taskMetrics = (json \ "Metrics Updated").extract[List[JValue]].map { json => + val taskId = (json \ "Task ID").extract[Long] + val stageId = (json \ "Stage ID").extract[Int] + val stageAttemptId = (json \ "Stage Attempt ID").extract[Int] + val metrics = taskMetricsFromJson(json \ "Task Metrics") + (taskId, stageId, stageAttemptId, metrics) + } + SparkListenerExecutorMetricsUpdate(execInfo, taskMetrics) + } + /** --------------------------------------------------------------------- * * JSON deserialization methods for classes SparkListenerEvents depend on | * ---------------------------------------------------------------------- */ diff --git 
a/core/src/main/scala/org/apache/spark/util/SerializableConfiguration.scala b/core/src/main/scala/org/apache/spark/util/SerializableConfiguration.scala index 30bcf1d2f24d5..3354a923273ff 100644 --- a/core/src/main/scala/org/apache/spark/util/SerializableConfiguration.scala +++ b/core/src/main/scala/org/apache/spark/util/SerializableConfiguration.scala @@ -20,8 +20,6 @@ import java.io.{ObjectInputStream, ObjectOutputStream} import org.apache.hadoop.conf.Configuration -import org.apache.spark.util.Utils - private[spark] class SerializableConfiguration(@transient var value: Configuration) extends Serializable { private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException { diff --git a/core/src/main/scala/org/apache/spark/util/SerializableJobConf.scala b/core/src/main/scala/org/apache/spark/util/SerializableJobConf.scala index afbcc6efc850c..cadae472b3f85 100644 --- a/core/src/main/scala/org/apache/spark/util/SerializableJobConf.scala +++ b/core/src/main/scala/org/apache/spark/util/SerializableJobConf.scala @@ -21,8 +21,6 @@ import java.io.{ObjectInputStream, ObjectOutputStream} import org.apache.hadoop.mapred.JobConf -import org.apache.spark.util.Utils - private[spark] class SerializableJobConf(@transient var value: JobConf) extends Serializable { private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException { diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala index 0180399c9dad5..7d84468f62ab1 100644 --- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala +++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala @@ -124,9 +124,11 @@ object SizeEstimator extends Logging { val server = ManagementFactory.getPlatformMBeanServer() // NOTE: This should throw an exception in non-Sun JVMs + // scalastyle:off classforname val hotSpotMBeanClass = Class.forName("com.sun.management.HotSpotDiagnosticMXBean") val getVMMethod = hotSpotMBeanClass.getDeclaredMethod("getVMOption", Class.forName("java.lang.String")) + // scalastyle:on classforname val bean = ManagementFactory.newPlatformMXBeanProxy(server, hotSpotMBeanName, hotSpotMBeanClass) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index b6b932104a94d..c5816949cd360 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -113,8 +113,11 @@ private[spark] object Utils extends Logging { def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = { val bis = new ByteArrayInputStream(bytes) val ois = new ObjectInputStream(bis) { - override def resolveClass(desc: ObjectStreamClass): Class[_] = + override def resolveClass(desc: ObjectStreamClass): Class[_] = { + // scalastyle:off classforname Class.forName(desc.getName, false, loader) + // scalastyle:on classforname + } } ois.readObject.asInstanceOf[T] } @@ -177,12 +180,16 @@ private[spark] object Utils extends Logging { /** Determines whether the provided class is loadable in the current thread. 
*/ def classIsLoadable(clazz: String): Boolean = { + // scalastyle:off classforname Try { Class.forName(clazz, false, getContextOrSparkClassLoader) }.isSuccess + // scalastyle:on classforname } + // scalastyle:off classforname /** Preferred alternative to Class.forName(className) */ def classForName(className: String): Class[_] = { Class.forName(className, true, getContextOrSparkClassLoader) + // scalastyle:on classforname } /** @@ -1579,6 +1586,34 @@ private[spark] object Utils extends Logging { hashAbs } + /** + * NaN-safe version of [[java.lang.Double.compare()]] which allows NaN values to be compared + * according to semantics where NaN == NaN and NaN > any non-NaN double. + */ + def nanSafeCompareDoubles(x: Double, y: Double): Int = { + val xIsNan: Boolean = java.lang.Double.isNaN(x) + val yIsNan: Boolean = java.lang.Double.isNaN(y) + if ((xIsNan && yIsNan) || (x == y)) 0 + else if (xIsNan) 1 + else if (yIsNan) -1 + else if (x > y) 1 + else -1 + } + + /** + * NaN-safe version of [[java.lang.Float.compare()]] which allows NaN values to be compared + * according to semantics where NaN == NaN and NaN > any non-NaN float. + */ + def nanSafeCompareFloats(x: Float, y: Float): Int = { + val xIsNan: Boolean = java.lang.Float.isNaN(x) + val yIsNan: Boolean = java.lang.Float.isNaN(y) + if ((xIsNan && yIsNan) || (x == y)) 0 + else if (xIsNan) 1 + else if (yIsNan) -1 + else if (x > y) 1 + else -1 + } + /** Returns the system properties map that is thread-safe to iterator over. It gets the * properties which have been set explicitly, as well as those for which only a default value * has been defined. */ @@ -2266,7 +2301,7 @@ private [util] class SparkShutdownHookManager { val hookTask = new Runnable() { override def run(): Unit = runAll() } - Try(Class.forName("org.apache.hadoop.util.ShutdownHookManager")) match { + Try(Utils.classForName("org.apache.hadoop.util.ShutdownHookManager")) match { case Success(shmClass) => val fsPriority = classOf[FileSystem].getField("SHUTDOWN_HOOK_PRIORITY").get() .asInstanceOf[Int] diff --git a/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala index 516aaa44d03fc..ae60f3b0cb555 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala @@ -37,7 +37,7 @@ private[spark] class ChainedBuffer(chunkSize: Int) { private var _size: Long = 0 /** - * Feed bytes from this buffer into a BlockObjectWriter. + * Feed bytes from this buffer into a DiskBlockObjectWriter. * * @param pos Offset in the buffer to read from. * @param os OutputStream to read into. 
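To make the intended ordering concrete, here is a small self-contained sketch that mirrors the Utils.nanSafeCompareDoubles logic added in the hunk above: NaN compares equal to NaN and greater than any non-NaN double. The object name NanSafeCompareSketch and its asserts are illustrative only and are not part of the patch.

  object NanSafeCompareSketch {
    // Same logic as Utils.nanSafeCompareDoubles in the hunk above:
    // NaN equals NaN, and NaN sorts above every non-NaN double.
    def nanSafeCompareDoubles(x: Double, y: Double): Int = {
      val xIsNan = java.lang.Double.isNaN(x)
      val yIsNan = java.lang.Double.isNaN(y)
      if ((xIsNan && yIsNan) || (x == y)) 0
      else if (xIsNan) 1
      else if (yIsNan) -1
      else if (x > y) 1
      else -1
    }

    def main(args: Array[String]): Unit = {
      assert(nanSafeCompareDoubles(Double.NaN, Double.NaN) == 0)               // NaN == NaN
      assert(nanSafeCompareDoubles(Double.NaN, Double.PositiveInfinity) == 1)  // NaN > +Inf
      assert(nanSafeCompareDoubles(0.0, Double.NaN) == -1)                     // non-NaN < NaN
      assert(nanSafeCompareDoubles(1.0, 2.0) == -1)                            // usual ordering
    }
  }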
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 1e4531ef395ae..d166037351c31 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -26,7 +26,7 @@ import scala.collection.mutable.ArrayBuffer import com.google.common.io.ByteStreams -import org.apache.spark.{Logging, SparkEnv} +import org.apache.spark.{Logging, SparkEnv, TaskContext} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.serializer.{DeserializationStream, Serializer} import org.apache.spark.storage.{BlockId, BlockManager} @@ -470,14 +470,27 @@ class ExternalAppendOnlyMap[K, V, C]( item } - // TODO: Ensure this gets called even if the iterator isn't drained. private def cleanup() { batchIndex = batchOffsets.length // Prevent reading any other batch val ds = deserializeStream - deserializeStream = null - fileStream = null - ds.close() - file.delete() + if (ds != null) { + ds.close() + deserializeStream = null + } + if (fileStream != null) { + fileStream.close() + fileStream = null + } + if (file.exists()) { + file.delete() + } + } + + val context = TaskContext.get() + // context is null in some tests of ExternalAppendOnlyMapSuite because these tests don't run in + // a TaskContext. + if (context != null) { + context.addTaskCompletionListener(context => cleanup()) } } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 757dec66c203b..ba7ec834d622d 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -30,7 +30,7 @@ import org.apache.spark._ import org.apache.spark.serializer._ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.shuffle.sort.{SortShuffleFileWriter, SortShuffleWriter} -import org.apache.spark.storage.{BlockId, BlockObjectWriter} +import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter} /** * Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner @@ -250,7 +250,7 @@ private[spark] class ExternalSorter[K, V, C]( // These variables are reset after each flush var objectsWritten: Long = 0 var spillMetrics: ShuffleWriteMetrics = null - var writer: BlockObjectWriter = null + var writer: DiskBlockObjectWriter = null def openWriter(): Unit = { assert (writer == null && spillMetrics == null) spillMetrics = new ShuffleWriteMetrics diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala index 04bb7fc78c13b..f5844d5353be7 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala @@ -19,7 +19,6 @@ package org.apache.spark.util.collection import java.util.Comparator -import org.apache.spark.storage.BlockObjectWriter import org.apache.spark.util.collection.WritablePartitionedPairCollection._ /** diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala index 
ae9a48729e201..87a786b02d651 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala @@ -21,9 +21,8 @@ import java.io.InputStream import java.nio.IntBuffer import java.util.Comparator -import org.apache.spark.SparkEnv import org.apache.spark.serializer.{JavaSerializerInstance, SerializerInstance} -import org.apache.spark.storage.BlockObjectWriter +import org.apache.spark.storage.DiskBlockObjectWriter import org.apache.spark.util.collection.PartitionedSerializedPairBuffer._ /** @@ -136,7 +135,7 @@ private[spark] class PartitionedSerializedPairBuffer[K, V]( // current position in the meta buffer in ints var pos = 0 - def writeNext(writer: BlockObjectWriter): Unit = { + def writeNext(writer: DiskBlockObjectWriter): Unit = { val keyStart = getKeyStartPos(metaBuffer, pos) val keyValLen = metaBuffer.get(pos + KEY_VAL_LEN) pos += RECORD_SIZE diff --git a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala index 7bc59898658e4..38848e9018c6c 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala @@ -19,7 +19,7 @@ package org.apache.spark.util.collection import java.util.Comparator -import org.apache.spark.storage.BlockObjectWriter +import org.apache.spark.storage.DiskBlockObjectWriter /** * A common interface for size-tracking collections of key-value pairs that @@ -51,7 +51,7 @@ private[spark] trait WritablePartitionedPairCollection[K, V] { new WritablePartitionedIterator { private[this] var cur = if (it.hasNext) it.next() else null - def writeNext(writer: BlockObjectWriter): Unit = { + def writeNext(writer: DiskBlockObjectWriter): Unit = { writer.write(cur._1._2, cur._2) cur = if (it.hasNext) it.next() else null } @@ -91,11 +91,11 @@ private[spark] object WritablePartitionedPairCollection { } /** - * Iterator that writes elements to a BlockObjectWriter instead of returning them. Each element + * Iterator that writes elements to a DiskBlockObjectWriter instead of returning them. Each element * has an associated partition. */ private[spark] trait WritablePartitionedIterator { - def writeNext(writer: BlockObjectWriter): Unit + def writeNext(writer: DiskBlockObjectWriter): Unit def hasNext(): Boolean diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index dfd86d3e51e7d..e948ca33471a4 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -1011,7 +1011,7 @@ public void persist() { @Test public void iterator() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2); - TaskContext context = new TaskContextImpl(0, 0, 0L, 0, null, false, new TaskMetrics()); + TaskContext context = new TaskContextImpl(0, 0, 0L, 0, null, null, false, new TaskMetrics()); Assert.assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue()); } @@ -1783,7 +1783,7 @@ public void testGuavaOptional() { // Stop the context created in setUp() and start a local-cluster one, to force usage of the // assembly. 
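The ExternalAppendOnlyMap change above stops relying on the spill iterator being fully drained before its temporary file is deleted; instead it registers a task-completion listener so cleanup runs whenever the task finishes. A minimal sketch of that pattern for a user-defined per-partition resource (the RDD, scratch file, and names are illustrative, not taken from this patch):

    import java.io.File

    import org.apache.spark.{SparkConf, SparkContext, TaskContext}

    object TaskCompletionCleanupSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("cleanup-sketch"))
        try {
          val firstFew = sc.parallelize(1 to 100, 4).mapPartitions { nums =>
            // Per-partition scratch file, analogous to a spill file.
            val scratch = File.createTempFile("sketch-spill", ".tmp")
            // Delete it when the task completes, even if the returned iterator is only
            // partially consumed (e.g. by take()), mirroring the cleanup() hook above.
            TaskContext.get().addTaskCompletionListener { _: TaskContext =>
              scratch.delete()
            }
            nums.map(_ * 2)
          }.take(3)
          println(firstFew.toSeq)
        } finally {
          sc.stop()
        }
      }
    }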
sc.stop(); - JavaSparkContext localCluster = new JavaSparkContext("local-cluster[1,1,512]", "JavaAPISuite"); + JavaSparkContext localCluster = new JavaSparkContext("local-cluster[1,1,1024]", "JavaAPISuite"); try { JavaRDD rdd1 = localCluster.parallelize(Arrays.asList(1, 2, null), 3); JavaRDD> rdd2 = rdd1.map( diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala index af81e46a657d3..618a5fb24710f 100644 --- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala @@ -65,7 +65,7 @@ class CacheManagerSuite extends SparkFunSuite with LocalSparkContext with Before // in blockManager.put is a losing battle. You have been warned. blockManager = sc.env.blockManager cacheManager = sc.env.cacheManager - val context = new TaskContextImpl(0, 0, 0, 0, null) + val context = new TaskContextImpl(0, 0, 0, 0, null, null) val computeValue = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) val getValue = blockManager.get(RDDBlockId(rdd.id, split.index)) assert(computeValue.toList === List(1, 2, 3, 4)) @@ -77,7 +77,7 @@ class CacheManagerSuite extends SparkFunSuite with LocalSparkContext with Before val result = new BlockResult(Array(5, 6, 7).iterator, DataReadMethod.Memory, 12) when(blockManager.get(RDDBlockId(0, 0))).thenReturn(Some(result)) - val context = new TaskContextImpl(0, 0, 0, 0, null) + val context = new TaskContextImpl(0, 0, 0, 0, null, null) val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) assert(value.toList === List(5, 6, 7)) } @@ -86,14 +86,14 @@ class CacheManagerSuite extends SparkFunSuite with LocalSparkContext with Before // Local computation should not persist the resulting value, so don't expect a put(). 
when(blockManager.get(RDDBlockId(0, 0))).thenReturn(None) - val context = new TaskContextImpl(0, 0, 0, 0, null, true) + val context = new TaskContextImpl(0, 0, 0, 0, null, null, true) val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) assert(value.toList === List(1, 2, 3, 4)) } test("verify task metrics updated correctly") { cacheManager = sc.env.cacheManager - val context = new TaskContextImpl(0, 0, 0, 0, null) + val context = new TaskContextImpl(0, 0, 0, 0, null, null) cacheManager.getOrCompute(rdd3, split, context, StorageLevel.MEMORY_ONLY) assert(context.taskMetrics.updatedBlocks.getOrElse(Seq()).size === 2) } diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index 501fe186bfd7c..26858ef2774fc 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -292,7 +292,7 @@ class ContextCleanerSuite extends ContextCleanerSuiteBase { sc.stop() val conf2 = new SparkConf() - .setMaster("local-cluster[2, 1, 512]") + .setMaster("local-cluster[2, 1, 1024]") .setAppName("ContextCleanerSuite") .set("spark.cleaner.referenceTracking.blocking", "true") .set("spark.cleaner.referenceTracking.blocking.shuffle", "true") @@ -370,7 +370,7 @@ class SortShuffleContextCleanerSuite extends ContextCleanerSuiteBase(classOf[Sor sc.stop() val conf2 = new SparkConf() - .setMaster("local-cluster[2, 1, 512]") + .setMaster("local-cluster[2, 1, 1024]") .setAppName("ContextCleanerSuite") .set("spark.cleaner.referenceTracking.blocking", "true") .set("spark.cleaner.referenceTracking.blocking.shuffle", "true") diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 2300bcff4f118..600c1403b0344 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -29,7 +29,7 @@ class NotSerializableExn(val notSer: NotSerializableClass) extends Throwable() { class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContext { - val clusterUrl = "local-cluster[2,1,512]" + val clusterUrl = "local-cluster[2,1,1024]" test("task throws not serializable exception") { // Ensures that executors do not crash when an exn is not serializable. If executors crash, @@ -40,7 +40,7 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex val numSlaves = 3 val numPartitions = 10 - sc = new SparkContext("local-cluster[%s,1,512]".format(numSlaves), "test") + sc = new SparkContext("local-cluster[%s,1,1024]".format(numSlaves), "test") val data = sc.parallelize(1 to 100, numPartitions). 
map(x => throw new NotSerializableExn(new NotSerializableClass)) intercept[SparkException] { @@ -50,16 +50,16 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex } test("local-cluster format") { - sc = new SparkContext("local-cluster[2,1,512]", "test") + sc = new SparkContext("local-cluster[2,1,1024]", "test") assert(sc.parallelize(1 to 2, 2).count() == 2) resetSparkContext() - sc = new SparkContext("local-cluster[2 , 1 , 512]", "test") + sc = new SparkContext("local-cluster[2 , 1 , 1024]", "test") assert(sc.parallelize(1 to 2, 2).count() == 2) resetSparkContext() - sc = new SparkContext("local-cluster[2, 1, 512]", "test") + sc = new SparkContext("local-cluster[2, 1, 1024]", "test") assert(sc.parallelize(1 to 2, 2).count() == 2) resetSparkContext() - sc = new SparkContext("local-cluster[ 2, 1, 512 ]", "test") + sc = new SparkContext("local-cluster[ 2, 1, 1024 ]", "test") assert(sc.parallelize(1 to 2, 2).count() == 2) resetSparkContext() } @@ -276,7 +276,7 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex DistributedSuite.amMaster = true // Using more than two nodes so we don't have a symmetric communication pattern and might // cache a partially correct list of peers. - sc = new SparkContext("local-cluster[3,1,512]", "test") + sc = new SparkContext("local-cluster[3,1,1024]", "test") for (i <- 1 to 3) { val data = sc.parallelize(Seq(true, false, false, false), 4) data.persist(StorageLevel.MEMORY_ONLY_2) @@ -294,7 +294,7 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex test("unpersist RDDs") { DistributedSuite.amMaster = true - sc = new SparkContext("local-cluster[3,1,512]", "test") + sc = new SparkContext("local-cluster[3,1,1024]", "test") val data = sc.parallelize(Seq(true, false, false, false), 4) data.persist(StorageLevel.MEMORY_ONLY_2) data.count diff --git a/core/src/test/scala/org/apache/spark/DriverSuite.scala b/core/src/test/scala/org/apache/spark/DriverSuite.scala index b2262033ca238..454b7e607a51b 100644 --- a/core/src/test/scala/org/apache/spark/DriverSuite.scala +++ b/core/src/test/scala/org/apache/spark/DriverSuite.scala @@ -29,7 +29,7 @@ class DriverSuite extends SparkFunSuite with Timeouts { ignore("driver should exit after finishing without cleanup (SPARK-530)") { val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) - val masters = Table("master", "local", "local-cluster[2,1,512]") + val masters = Table("master", "local", "local-cluster[2,1,1024]") forAll(masters) { (master: String) => val process = Utils.executeCommand( Seq(s"$sparkHome/bin/spark-class", "org.apache.spark.DriverWithoutCleanup", master), diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala index 140012226fdbb..c38d70252add1 100644 --- a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala @@ -51,7 +51,7 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll { // This test ensures that the external shuffle service is actually in use for the other tests. 
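Much of the test churn above and below is a mechanical bump of the local-cluster master string from 512 to 1024 MB per executor, presumably to give each forked executor enough memory headroom. The format is local-cluster[numWorkers, coresPerWorker, memoryPerWorkerMB], and unlike local[N] it launches real worker JVMs, so it needs a built Spark distribution (spark.test.home) to run. A small illustrative snippet (app name and workload are arbitrary):

    import org.apache.spark.{SparkConf, SparkContext}

    object LocalClusterSketch {
      def main(args: Array[String]): Unit = {
        // 2 workers, 1 core each, 1024 MB per worker: the value these suites now use.
        val conf = new SparkConf()
          .setMaster("local-cluster[2, 1, 1024]")
          .setAppName("local-cluster-sketch")
        val sc = new SparkContext(conf)
        try {
          // A trivial distributed job across the two workers.
          assert(sc.parallelize(1 to 10, 4).map(_ + 1).reduce(_ + _) == 65)
        } finally {
          sc.stop()
        }
      }
    }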
test("using external shuffle service") { - sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) sc.env.blockManager.externalShuffleServiceEnabled should equal(true) sc.env.blockManager.shuffleClient.getClass should equal(classOf[ExternalShuffleClient]) diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala index b099cd3fb7965..69cb4b44cf7ef 100644 --- a/core/src/test/scala/org/apache/spark/FailureSuite.scala +++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala @@ -141,5 +141,30 @@ class FailureSuite extends SparkFunSuite with LocalSparkContext { FailureSuiteState.clear() } + test("managed memory leak error should not mask other failures (SPARK-9266") { + val conf = new SparkConf().set("spark.unsafe.exceptionOnMemoryLeak", "true") + sc = new SparkContext("local[1,1]", "test", conf) + + // If a task leaks memory but fails due to some other cause, then make sure that the original + // cause is preserved + val thrownDueToTaskFailure = intercept[SparkException] { + sc.parallelize(Seq(0)).mapPartitions { iter => + TaskContext.get().taskMemoryManager().allocate(128) + throw new Exception("intentional task failure") + iter + }.count() + } + assert(thrownDueToTaskFailure.getMessage.contains("intentional task failure")) + + // If the task succeeded but memory was leaked, then the task should fail due to that leak + val thrownDueToMemoryLeak = intercept[SparkException] { + sc.parallelize(Seq(0)).mapPartitions { iter => + TaskContext.get().taskMemoryManager().allocate(128) + iter + }.count() + } + assert(thrownDueToMemoryLeak.getMessage.contains("memory leak")) + } + // TODO: Need to add tests with shuffle fetch failures. 
} diff --git a/core/src/test/scala/org/apache/spark/FileServerSuite.scala b/core/src/test/scala/org/apache/spark/FileServerSuite.scala index 876418aa13029..1255e71af6c0b 100644 --- a/core/src/test/scala/org/apache/spark/FileServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileServerSuite.scala @@ -139,7 +139,7 @@ class FileServerSuite extends SparkFunSuite with LocalSparkContext { } test("Distributing files on a standalone cluster") { - sc = new SparkContext("local-cluster[1,1,512]", "test", newConf) + sc = new SparkContext("local-cluster[1,1,1024]", "test", newConf) sc.addFile(tmpFile.toString) val testData = Array((1, 1), (1, 1), (2, 1), (3, 5), (2, 2), (3, 0)) val result = sc.parallelize(testData).reduceByKey { @@ -153,7 +153,7 @@ class FileServerSuite extends SparkFunSuite with LocalSparkContext { } test ("Dynamically adding JARS on a standalone cluster") { - sc = new SparkContext("local-cluster[1,1,512]", "test", newConf) + sc = new SparkContext("local-cluster[1,1,1024]", "test", newConf) sc.addJar(tmpJarUrl) val testData = Array((1, 1)) sc.parallelize(testData).foreach { x => @@ -164,7 +164,7 @@ class FileServerSuite extends SparkFunSuite with LocalSparkContext { } test ("Dynamically adding JARS on a standalone cluster using local: URL") { - sc = new SparkContext("local-cluster[1,1,512]", "test", newConf) + sc = new SparkContext("local-cluster[1,1,1024]", "test", newConf) sc.addJar(tmpJarUrl.replace("file", "local")) val testData = Array((1, 1)) sc.parallelize(testData).foreach { x => diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index 1d8fade90f398..418763f4e5ffa 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -179,6 +179,7 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { } test("object files of classes from a JAR") { + // scalastyle:off classforname val original = Thread.currentThread().getContextClassLoader val className = "FileSuiteObjectFileTest" val jar = TestUtils.createJarWithClasses(Seq(className)) @@ -201,6 +202,7 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { finally { Thread.currentThread().setContextClassLoader(original) } + // scalastyle:on classforname } test("write SequenceFile using new Hadoop API") { diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala index b31b09196608f..5a2670e4d1cf0 100644 --- a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala +++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala @@ -17,6 +17,9 @@ package org.apache.spark +import java.util.concurrent.{ExecutorService, TimeUnit} + +import scala.collection.mutable import scala.language.postfixOps import org.scalatest.{BeforeAndAfterEach, PrivateMethodTester} @@ -25,11 +28,16 @@ import org.mockito.Matchers import org.mockito.Matchers._ import org.apache.spark.executor.TaskMetrics -import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv, RpcEndpointRef} import org.apache.spark.scheduler._ +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ +import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.ManualClock +/** + * A test suite for the heartbeating behavior between the driver and the executors. 
+ */ class HeartbeatReceiverSuite extends SparkFunSuite with BeforeAndAfterEach @@ -40,23 +48,40 @@ class HeartbeatReceiverSuite private val executorId2 = "executor-2" // Shared state that must be reset before and after each test - private var scheduler: TaskScheduler = null + private var scheduler: TaskSchedulerImpl = null private var heartbeatReceiver: HeartbeatReceiver = null private var heartbeatReceiverRef: RpcEndpointRef = null private var heartbeatReceiverClock: ManualClock = null + // Helper private method accessors for HeartbeatReceiver + private val _executorLastSeen = PrivateMethod[collection.Map[String, Long]]('executorLastSeen) + private val _executorTimeoutMs = PrivateMethod[Long]('executorTimeoutMs) + private val _killExecutorThread = PrivateMethod[ExecutorService]('killExecutorThread) + + /** + * Before each test, set up the SparkContext and a custom [[HeartbeatReceiver]] + * that uses a manual clock. + */ override def beforeEach(): Unit = { - sc = spy(new SparkContext("local[2]", "test")) - scheduler = mock(classOf[TaskScheduler]) + val conf = new SparkConf() + .setMaster("local[2]") + .setAppName("test") + .set("spark.dynamicAllocation.testing", "true") + sc = spy(new SparkContext(conf)) + scheduler = mock(classOf[TaskSchedulerImpl]) when(sc.taskScheduler).thenReturn(scheduler) + when(scheduler.sc).thenReturn(sc) heartbeatReceiverClock = new ManualClock heartbeatReceiver = new HeartbeatReceiver(sc, heartbeatReceiverClock) heartbeatReceiverRef = sc.env.rpcEnv.setupEndpoint("heartbeat", heartbeatReceiver) when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(true) } + /** + * After each test, clean up all state and stop the [[SparkContext]]. + */ override def afterEach(): Unit = { - resetSparkContext() + super.afterEach() scheduler = null heartbeatReceiver = null heartbeatReceiverRef = null @@ -75,7 +100,7 @@ class HeartbeatReceiverSuite heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId2, null)) triggerHeartbeat(executorId1, executorShouldReregister = false) triggerHeartbeat(executorId2, executorShouldReregister = false) - val trackedExecutors = executorLastSeen(heartbeatReceiver) + val trackedExecutors = heartbeatReceiver.invokePrivate(_executorLastSeen()) assert(trackedExecutors.size === 2) assert(trackedExecutors.contains(executorId1)) assert(trackedExecutors.contains(executorId2)) @@ -83,15 +108,15 @@ class HeartbeatReceiverSuite test("reregister if scheduler is not ready yet") { heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId1, null)) - // Task scheduler not set in HeartbeatReceiver + // Task scheduler is not set yet in HeartbeatReceiver, so executors should reregister triggerHeartbeat(executorId1, executorShouldReregister = true) } test("reregister if heartbeat from unregistered executor") { heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet) - // Received heartbeat from unknown receiver, so we ask it to re-register + // Received heartbeat from unknown executor, so we ask it to re-register triggerHeartbeat(executorId1, executorShouldReregister = true) - assert(executorLastSeen(heartbeatReceiver).isEmpty) + assert(heartbeatReceiver.invokePrivate(_executorLastSeen()).isEmpty) } test("reregister if heartbeat from removed executor") { @@ -104,14 +129,14 @@ class HeartbeatReceiverSuite // A heartbeat from the second executor should require reregistering triggerHeartbeat(executorId1, executorShouldReregister = false) triggerHeartbeat(executorId2, executorShouldReregister = true) - val 
trackedExecutors = executorLastSeen(heartbeatReceiver) + val trackedExecutors = heartbeatReceiver.invokePrivate(_executorLastSeen()) assert(trackedExecutors.size === 1) assert(trackedExecutors.contains(executorId1)) assert(!trackedExecutors.contains(executorId2)) } test("expire dead hosts") { - val executorTimeout = executorTimeoutMs(heartbeatReceiver) + val executorTimeout = heartbeatReceiver.invokePrivate(_executorTimeoutMs()) heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet) heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId1, null)) heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId2, null)) @@ -124,12 +149,61 @@ class HeartbeatReceiverSuite heartbeatReceiverRef.askWithRetry[Boolean](ExpireDeadHosts) // Only the second executor should be expired as a dead host verify(scheduler).executorLost(Matchers.eq(executorId2), any()) - val trackedExecutors = executorLastSeen(heartbeatReceiver) + val trackedExecutors = heartbeatReceiver.invokePrivate(_executorLastSeen()) assert(trackedExecutors.size === 1) assert(trackedExecutors.contains(executorId1)) assert(!trackedExecutors.contains(executorId2)) } + test("expire dead hosts should kill executors with replacement (SPARK-8119)") { + // Set up a fake backend and cluster manager to simulate killing executors + val rpcEnv = sc.env.rpcEnv + val fakeClusterManager = new FakeClusterManager(rpcEnv) + val fakeClusterManagerRef = rpcEnv.setupEndpoint("fake-cm", fakeClusterManager) + val fakeSchedulerBackend = new FakeSchedulerBackend(scheduler, rpcEnv, fakeClusterManagerRef) + when(sc.schedulerBackend).thenReturn(fakeSchedulerBackend) + + // Register fake executors with our fake scheduler backend + // This is necessary because the backend refuses to kill executors it does not know about + fakeSchedulerBackend.start() + val dummyExecutorEndpoint1 = new FakeExecutorEndpoint(rpcEnv) + val dummyExecutorEndpoint2 = new FakeExecutorEndpoint(rpcEnv) + val dummyExecutorEndpointRef1 = rpcEnv.setupEndpoint("fake-executor-1", dummyExecutorEndpoint1) + val dummyExecutorEndpointRef2 = rpcEnv.setupEndpoint("fake-executor-2", dummyExecutorEndpoint2) + fakeSchedulerBackend.driverEndpoint.askWithRetry[RegisteredExecutor.type]( + RegisterExecutor(executorId1, dummyExecutorEndpointRef1, "dummy:4040", 0, Map.empty)) + fakeSchedulerBackend.driverEndpoint.askWithRetry[RegisteredExecutor.type]( + RegisterExecutor(executorId2, dummyExecutorEndpointRef2, "dummy:4040", 0, Map.empty)) + heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet) + heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId1, null)) + heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId2, null)) + triggerHeartbeat(executorId1, executorShouldReregister = false) + triggerHeartbeat(executorId2, executorShouldReregister = false) + + // Adjust the target number of executors on the cluster manager side + assert(fakeClusterManager.getTargetNumExecutors === 0) + sc.requestTotalExecutors(2) + assert(fakeClusterManager.getTargetNumExecutors === 2) + assert(fakeClusterManager.getExecutorIdsToKill.isEmpty) + + // Expire the executors. This should trigger our fake backend to kill the executors. + // Since the kill request is sent to the cluster manager asynchronously, we need to block + // on the kill thread to ensure that the cluster manager actually received our requests. + // Here we use a timeout of O(seconds), but in practice this whole test takes O(10ms). 
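The suite above now reaches into HeartbeatReceiver's private state through ScalaTest's PrivateMethodTester rather than dedicated accessor helpers. A self-contained sketch of how PrivateMethod and invokePrivate work together (the Counter class is made up for illustration):

    import org.scalatest.PrivateMethodTester

    class Counter {
      private var count = 0
      private def increment(by: Int): Int = { count += by; count }
    }

    object PrivateMethodSketch extends PrivateMethodTester {
      def main(args: Array[String]): Unit = {
        val counter = new Counter
        // A PrivateMethod handle is keyed by the method's symbol name and result type.
        val _increment = PrivateMethod[Int]('increment)
        // invokePrivate calls the private method reflectively on the target instance.
        val result = counter invokePrivate _increment(2)
        assert(result == 2)
      }
    }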
+ val executorTimeout = heartbeatReceiver.invokePrivate(_executorTimeoutMs()) + heartbeatReceiverClock.advance(executorTimeout * 2) + heartbeatReceiverRef.askWithRetry[Boolean](ExpireDeadHosts) + val killThread = heartbeatReceiver.invokePrivate(_killExecutorThread()) + killThread.shutdown() // needed for awaitTermination + killThread.awaitTermination(10L, TimeUnit.SECONDS) + + // The target number of executors should not change! Otherwise, having an expired + // executor means we permanently adjust the target number downwards until we + // explicitly request new executors. For more detail, see SPARK-8119. + assert(fakeClusterManager.getTargetNumExecutors === 2) + assert(fakeClusterManager.getExecutorIdsToKill === Set(executorId1, executorId2)) + } + /** Manually send a heartbeat and return the response. */ private def triggerHeartbeat( executorId: String, @@ -148,14 +222,49 @@ class HeartbeatReceiverSuite } } - // Helper methods to access private fields in HeartbeatReceiver - private val _executorLastSeen = PrivateMethod[collection.Map[String, Long]]('executorLastSeen) - private val _executorTimeoutMs = PrivateMethod[Long]('executorTimeoutMs) - private def executorLastSeen(receiver: HeartbeatReceiver): collection.Map[String, Long] = { - receiver invokePrivate _executorLastSeen() +} + +// TODO: use these classes to add end-to-end tests for dynamic allocation! + +/** + * Dummy RPC endpoint to simulate executors. + */ +private class FakeExecutorEndpoint(override val rpcEnv: RpcEnv) extends RpcEndpoint + +/** + * Dummy scheduler backend to simulate executor allocation requests to the cluster manager. + */ +private class FakeSchedulerBackend( + scheduler: TaskSchedulerImpl, + rpcEnv: RpcEnv, + clusterManagerEndpoint: RpcEndpointRef) + extends CoarseGrainedSchedulerBackend(scheduler, rpcEnv) { + + protected override def doRequestTotalExecutors(requestedTotal: Int): Boolean = { + clusterManagerEndpoint.askWithRetry[Boolean](RequestExecutors(requestedTotal)) } - private def executorTimeoutMs(receiver: HeartbeatReceiver): Long = { - receiver invokePrivate _executorTimeoutMs() + + protected override def doKillExecutors(executorIds: Seq[String]): Boolean = { + clusterManagerEndpoint.askWithRetry[Boolean](KillExecutors(executorIds)) } +} +/** + * Dummy cluster manager to simulate responses to executor allocation requests. 
+ */ +private class FakeClusterManager(override val rpcEnv: RpcEnv) extends RpcEndpoint { + private var targetNumExecutors = 0 + private val executorIdsToKill = new mutable.HashSet[String] + + def getTargetNumExecutors: Int = targetNumExecutors + def getExecutorIdsToKill: Set[String] = executorIdsToKill.toSet + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case RequestExecutors(requestedTotal) => + targetNumExecutors = requestedTotal + context.reply(true) + case KillExecutors(executorIds) => + executorIdsToKill ++= executorIds + context.reply(true) + } } diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala index 340a9e327107e..1168eb0b802f2 100644 --- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -64,7 +64,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft test("cluster mode, FIFO scheduler") { val conf = new SparkConf().set("spark.scheduler.mode", "FIFO") - sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) testCount() testTake() // Make sure we can still launch tasks. @@ -75,7 +75,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft val conf = new SparkConf().set("spark.scheduler.mode", "FAIR") val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile() conf.set("spark.scheduler.allocation.file", xmlPath) - sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) testCount() testTake() // Make sure we can still launch tasks. 
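FakeClusterManager above shows the minimal shape of an RpcEndpoint that answers ask-style messages: receiveAndReply pattern-matches on the request and replies through the RpcCallContext. A stripped-down sketch of the same shape; since RpcEnv, RpcEndpoint, and SecurityManager are private[spark], the sketch assumes code living under the org.apache.spark package, and the endpoint, message type, and port are illustrative:

    package org.apache.spark.sketch   // needed because the RPC API is private[spark]

    import org.apache.spark.{SecurityManager, SparkConf}
    import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv}

    case class Ping(payload: String)

    class EchoEndpoint(override val rpcEnv: RpcEnv) extends RpcEndpoint {
      // Handle ask-style messages: match on the request and reply through the context.
      override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
        case Ping(payload) => context.reply(s"pong: $payload")
      }
    }

    object RpcEndpointSketch {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf()
        // Fixed port as in the PersistenceEngineSuite below; pick a free one on your machine.
        val rpcEnv = RpcEnv.create("sketch", "localhost", 12345, conf, new SecurityManager(conf))
        try {
          val ref = rpcEnv.setupEndpoint("echo", new EchoEndpoint(rpcEnv))
          // askWithRetry blocks until the endpoint replies (or retries are exhausted).
          val answer = ref.askWithRetry[String](Ping("hello"))
          assert(answer == "pong: hello")
        } finally {
          rpcEnv.shutdown()
          rpcEnv.awaitTermination()
        }
      }
    }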
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 7a1961137cce5..af4e68950f75a 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -17,13 +17,15 @@ package org.apache.spark +import scala.collection.mutable.ArrayBuffer + import org.mockito.Mockito._ import org.mockito.Matchers.{any, isA} import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcCallContext, RpcEnv} import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus} import org.apache.spark.shuffle.FetchFailedException -import org.apache.spark.storage.BlockManagerId +import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId} class MapOutputTrackerSuite extends SparkFunSuite { private val conf = new SparkConf @@ -55,9 +57,11 @@ class MapOutputTrackerSuite extends SparkFunSuite { Array(1000L, 10000L))) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), Array(10000L, 1000L))) - val statuses = tracker.getServerStatuses(10, 0) - assert(statuses.toSeq === Seq((BlockManagerId("a", "hostA", 1000), size1000), - (BlockManagerId("b", "hostB", 1000), size10000))) + val statuses = tracker.getMapSizesByExecutorId(10, 0) + assert(statuses.toSet === + Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))), + (BlockManagerId("b", "hostB", 1000), ArrayBuffer((ShuffleBlockId(10, 1, 0), size10000)))) + .toSet) tracker.stop() rpcEnv.shutdown() } @@ -75,10 +79,10 @@ class MapOutputTrackerSuite extends SparkFunSuite { tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), Array(compressedSize10000, compressedSize1000))) assert(tracker.containsShuffle(10)) - assert(tracker.getServerStatuses(10, 0).nonEmpty) + assert(tracker.getMapSizesByExecutorId(10, 0).nonEmpty) tracker.unregisterShuffle(10) assert(!tracker.containsShuffle(10)) - assert(tracker.getServerStatuses(10, 0).isEmpty) + assert(tracker.getMapSizesByExecutorId(10, 0).isEmpty) tracker.stop() rpcEnv.shutdown() @@ -104,7 +108,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { // The remaining reduce task might try to grab the output despite the shuffle failure; // this should cause it to fail, and the scheduler will ignore the failure due to the // stage already being aborted. 
- intercept[FetchFailedException] { tracker.getServerStatuses(10, 1) } + intercept[FetchFailedException] { tracker.getMapSizesByExecutorId(10, 1) } tracker.stop() rpcEnv.shutdown() @@ -126,23 +130,23 @@ class MapOutputTrackerSuite extends SparkFunSuite { masterTracker.registerShuffle(10, 1) masterTracker.incrementEpoch() slaveTracker.updateEpoch(masterTracker.getEpoch) - intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } + intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) } val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) masterTracker.registerMapOutput(10, 0, MapStatus( BlockManagerId("a", "hostA", 1000), Array(1000L))) masterTracker.incrementEpoch() slaveTracker.updateEpoch(masterTracker.getEpoch) - assert(slaveTracker.getServerStatuses(10, 0).toSeq === - Seq((BlockManagerId("a", "hostA", 1000), size1000))) + assert(slaveTracker.getMapSizesByExecutorId(10, 0) === + Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000)) masterTracker.incrementEpoch() slaveTracker.updateEpoch(masterTracker.getEpoch) - intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } + intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) } // failure should be cached - intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } + intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) } masterTracker.stop() slaveTracker.stop() diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index c3c2b1ffc1efa..d91b799ecfc08 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -47,7 +47,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC } test("shuffle non-zero block size") { - sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) val NUM_BLOCKS = 3 val a = sc.parallelize(1 to 10, 2) @@ -66,14 +66,14 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC // All blocks must have non-zero size (0 until NUM_BLOCKS).foreach { id => - val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id) - assert(statuses.forall(s => s._2 > 0)) + val statuses = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(shuffleId, id) + assert(statuses.forall(_._2.forall(blockIdSizePair => blockIdSizePair._2 > 0))) } } test("shuffle serializer") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) val a = sc.parallelize(1 to 10, 2) val b = a.map { x => (x, new NonJavaSerializableClass(x * 2)) @@ -89,7 +89,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC test("zero sized blocks") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) // 201 partitions (greater than "spark.shuffle.sort.bypassMergeThreshold") from 4 keys val NUM_BLOCKS = 201 @@ -105,8 +105,8 @@ abstract class ShuffleSuite extends SparkFunSuite with 
Matchers with LocalSparkC assert(c.count === 4) val blockSizes = (0 until NUM_BLOCKS).flatMap { id => - val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id) - statuses.map(x => x._2) + val statuses = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(shuffleId, id) + statuses.flatMap(_._2.map(_._2)) } val nonEmptyBlocks = blockSizes.filter(x => x > 0) @@ -116,7 +116,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC test("zero sized blocks without kryo") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) // 201 partitions (greater than "spark.shuffle.sort.bypassMergeThreshold") from 4 keys val NUM_BLOCKS = 201 @@ -130,8 +130,8 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC assert(c.count === 4) val blockSizes = (0 until NUM_BLOCKS).flatMap { id => - val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id) - statuses.map(x => x._2) + val statuses = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(shuffleId, id) + statuses.flatMap(_._2.map(_._2)) } val nonEmptyBlocks = blockSizes.filter(x => x > 0) @@ -141,7 +141,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC test("shuffle on mutable pairs") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) def p[T1, T2](_1: T1, _2: T2): MutablePair[T1, T2] = MutablePair(_1, _2) val data = Array(p(1, 1), p(1, 2), p(1, 3), p(2, 1)) val pairs: RDD[MutablePair[Int, Int]] = sc.parallelize(data, 2) @@ -154,7 +154,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC test("sorting on mutable pairs") { // This is not in SortingSuite because of the local cluster setup. 
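These ShuffleSuite assertions reflect the MapOutputTracker API change in this patch: getServerStatuses returned flat (BlockManagerId, size) pairs, while getMapSizesByExecutorId groups shuffle blocks per executor as (BlockManagerId, Seq[(ShuffleBlockId, size)]). A small sketch of working with the new shape (the statuses value is fabricated for illustration):

    import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId}

    object MapSizesShapeSketch {
      def main(args: Array[String]): Unit = {
        // Shape returned by MapOutputTracker.getMapSizesByExecutorId(shuffleId, reduceId):
        // one entry per executor, each carrying the shuffle blocks it serves and their sizes.
        val statuses: Seq[(BlockManagerId, Seq[(BlockId, Long)])] = Seq(
          BlockManagerId("a", "hostA", 1000) -> Seq((ShuffleBlockId(10, 0, 0), 1000L)),
          BlockManagerId("b", "hostB", 1000) -> Seq((ShuffleBlockId(10, 1, 0), 10000L))
        )

        // Flatten to just the block sizes, as the updated ShuffleSuite assertions do.
        val blockSizes = statuses.flatMap(_._2.map(_._2))
        assert(blockSizes.forall(_ > 0))

        // Or check every block on every executor in one pass.
        assert(statuses.forall(_._2.forall { case (_, size) => size > 0 }))
      }
    }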
// Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) def p[T1, T2](_1: T1, _2: T2): MutablePair[T1, T2] = MutablePair(_1, _2) val data = Array(p(1, 11), p(3, 33), p(100, 100), p(2, 22)) val pairs: RDD[MutablePair[Int, Int]] = sc.parallelize(data, 2) @@ -168,7 +168,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC test("cogroup using mutable pairs") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) def p[T1, T2](_1: T1, _2: T2): MutablePair[T1, T2] = MutablePair(_1, _2) val data1 = Seq(p(1, 1), p(1, 2), p(1, 3), p(2, 1)) val data2 = Seq(p(1, "11"), p(1, "12"), p(2, "22"), p(3, "3")) @@ -195,7 +195,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC test("subtract mutable pairs") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) def p[T1, T2](_1: T1, _2: T2): MutablePair[T1, T2] = MutablePair(_1, _2) val data1 = Seq(p(1, 1), p(1, 2), p(1, 3), p(2, 1), p(3, 33)) val data2 = Seq(p(1, "11"), p(1, "12"), p(2, "22")) @@ -210,7 +210,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC test("sort with Java non serializable class - Kryo") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks val myConf = conf.clone().set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - sc = new SparkContext("local-cluster[2,1,512]", "test", myConf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", myConf) val a = sc.parallelize(1 to 10, 2) val b = a.map { x => (new NonJavaSerializableClass(x), x) @@ -223,7 +223,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC test("sort with Java non serializable class - Java") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) val a = sc.parallelize(1 to 10, 2) val b = a.map { x => (new NonJavaSerializableClass(x), x) diff --git a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala index f89e3d0a49920..e5a14a69ef05f 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark import org.scalatest.PrivateMethodTester +import org.apache.spark.util.Utils import org.apache.spark.scheduler.{SchedulerBackend, TaskScheduler, TaskSchedulerImpl} import org.apache.spark.scheduler.cluster.{SimrSchedulerBackend, SparkDeploySchedulerBackend} import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} @@ -122,7 +123,7 @@ class SparkContextSchedulerCreationSuite } test("local-cluster") { - createTaskScheduler("local-cluster[3, 14, 512]").backend match { + createTaskScheduler("local-cluster[3, 14, 1024]").backend match { 
case s: SparkDeploySchedulerBackend => // OK case _ => fail() } @@ -131,7 +132,7 @@ class SparkContextSchedulerCreationSuite def testYarn(master: String, expectedClassName: String) { try { val sched = createTaskScheduler(master) - assert(sched.getClass === Class.forName(expectedClassName)) + assert(sched.getClass === Utils.classForName(expectedClassName)) } catch { case e: SparkException => assert(e.getMessage.contains("YARN mode not available")) diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index c054c718075f8..48e74f06f79b1 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -69,7 +69,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { val conf = httpConf.clone conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") conf.set("spark.broadcast.compress", "true") - sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", conf) + sc = new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test", conf) val list = List[Int](1, 2, 3, 4) val broadcast = sc.broadcast(list) val results = sc.parallelize(1 to numSlaves).map(x => (x, broadcast.value.sum)) @@ -97,7 +97,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { val conf = torrentConf.clone conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") conf.set("spark.broadcast.compress", "true") - sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", conf) + sc = new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test", conf) val list = List[Int](1, 2, 3, 4) val broadcast = sc.broadcast(list) val results = sc.parallelize(1 to numSlaves).map(x => (x, broadcast.value.sum)) @@ -125,7 +125,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { test("Test Lazy Broadcast variables with TorrentBroadcast") { val numSlaves = 2 val conf = torrentConf.clone - sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", conf) + sc = new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test", conf) val rdd = sc.parallelize(1 to numSlaves) val results = new DummyBroadcastClass(rdd).doSomething() @@ -308,7 +308,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { sc = if (distributed) { val _sc = - new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", broadcastConf) + new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test", broadcastConf) // Wait until all salves are up _sc.jobProgressListener.waitUntilExecutorsUp(numSlaves, 10000) _sc diff --git a/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala index ddc92814c0acf..cbd2aee10c0e2 100644 --- a/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala @@ -33,7 +33,7 @@ class LogUrlsStandaloneSuite extends SparkFunSuite with LocalSparkContext { private val WAIT_TIMEOUT_MILLIS = 10000 test("verify that correct log urls get propagated from workers") { - sc = new SparkContext("local-cluster[2,1,512]", "test") + sc = new SparkContext("local-cluster[2,1,1024]", "test") val listener = new SaveExecutorInfo sc.addSparkListener(listener) @@ -66,7 +66,7 @@ class LogUrlsStandaloneSuite 
extends SparkFunSuite with LocalSparkContext { } val conf = new MySparkConf().set( "spark.extraListeners", classOf[SaveExecutorInfo].getName) - sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) // Trigger a job so that executors get added sc.parallelize(1 to 100, 4).map(_.toString).count() diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 1b64c329b5d4b..aa78bfe30974c 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -246,7 +246,7 @@ class SparkSubmitSuite mainClass should be ("org.apache.spark.deploy.Client") } classpath should have size 0 - sysProps should have size 8 + sysProps should have size 9 sysProps.keys should contain ("SPARK_SUBMIT") sysProps.keys should contain ("spark.master") sysProps.keys should contain ("spark.app.name") @@ -255,6 +255,7 @@ class SparkSubmitSuite sysProps.keys should contain ("spark.driver.cores") sysProps.keys should contain ("spark.driver.supervise") sysProps.keys should contain ("spark.shuffle.spill") + sysProps.keys should contain ("spark.submit.deployMode") sysProps("spark.shuffle.spill") should be ("false") } @@ -336,7 +337,7 @@ class SparkSubmitSuite val args = Seq( "--class", JarCreationTest.getClass.getName.stripSuffix("$"), "--name", "testApp", - "--master", "local-cluster[2,1,512]", + "--master", "local-cluster[2,1,1024]", "--jars", jarsString, unusedJar.toString, "SparkSubmitClassA", "SparkSubmitClassB") runSparkSubmit(args) @@ -351,7 +352,7 @@ class SparkSubmitSuite val args = Seq( "--class", JarCreationTest.getClass.getName.stripSuffix("$"), "--name", "testApp", - "--master", "local-cluster[2,1,512]", + "--master", "local-cluster[2,1,1024]", "--packages", Seq(main, dep).mkString(","), "--repositories", repo, "--conf", "spark.ui.enabled=false", @@ -540,8 +541,8 @@ object JarCreationTest extends Logging { val result = sc.makeRDD(1 to 100, 10).mapPartitions { x => var exception: String = null try { - Class.forName(args(0), true, Thread.currentThread().getContextClassLoader) - Class.forName(args(1), true, Thread.currentThread().getContextClassLoader) + Utils.classForName(args(0)) + Utils.classForName(args(1)) } catch { case t: Throwable => exception = t + "\n" + t.getStackTraceString diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 2a62450bcdbad..73cff89544dc3 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -243,13 +243,12 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc appListAfterRename.size should be (1) } - test("apps with multiple attempts") { + test("apps with multiple attempts with order") { val provider = new FsHistoryProvider(createTestConf()) - val attempt1 = newLogFile("app1", Some("attempt1"), inProgress = false) + val attempt1 = newLogFile("app1", Some("attempt1"), inProgress = true) writeFile(attempt1, true, None, - SparkListenerApplicationStart("app1", Some("app1"), 1L, "test", Some("attempt1")), - SparkListenerApplicationEnd(2L) + SparkListenerApplicationStart("app1", Some("app1"), 1L, "test", Some("attempt1")) ) updateAndCheck(provider) { list => @@ 
-259,7 +258,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc val attempt2 = newLogFile("app1", Some("attempt2"), inProgress = true) writeFile(attempt2, true, None, - SparkListenerApplicationStart("app1", Some("app1"), 3L, "test", Some("attempt2")) + SparkListenerApplicationStart("app1", Some("app1"), 2L, "test", Some("attempt2")) ) updateAndCheck(provider) { list => @@ -268,22 +267,21 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc list.head.attempts.head.attemptId should be (Some("attempt2")) } - val completedAttempt2 = newLogFile("app1", Some("attempt2"), inProgress = false) - attempt2.delete() - writeFile(attempt2, true, None, - SparkListenerApplicationStart("app1", Some("app1"), 3L, "test", Some("attempt2")), + val attempt3 = newLogFile("app1", Some("attempt3"), inProgress = false) + writeFile(attempt3, true, None, + SparkListenerApplicationStart("app1", Some("app1"), 3L, "test", Some("attempt3")), SparkListenerApplicationEnd(4L) ) updateAndCheck(provider) { list => list should not be (null) list.size should be (1) - list.head.attempts.size should be (2) - list.head.attempts.head.attemptId should be (Some("attempt2")) + list.head.attempts.size should be (3) + list.head.attempts.head.attemptId should be (Some("attempt3")) } val app2Attempt1 = newLogFile("app2", Some("attempt1"), inProgress = false) - writeFile(attempt2, true, None, + writeFile(attempt1, true, None, SparkListenerApplicationStart("app2", Some("app2"), 5L, "test", Some("attempt1")), SparkListenerApplicationEnd(6L) ) @@ -291,7 +289,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc updateAndCheck(provider) { list => list.size should be (2) list.head.attempts.size should be (1) - list.last.attempts.size should be (2) + list.last.attempts.size should be (3) list.head.attempts.head.attemptId should be (Some("attempt1")) list.foreach { case app => diff --git a/core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala b/core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala index f4e56632e426a..8c96b0e71dfdd 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala @@ -19,18 +19,19 @@ // when they are outside of org.apache.spark. package other.supplier +import java.nio.ByteBuffer + import scala.collection.mutable import scala.reflect.ClassTag -import akka.serialization.Serialization - import org.apache.spark.SparkConf import org.apache.spark.deploy.master._ +import org.apache.spark.serializer.Serializer class CustomRecoveryModeFactory( conf: SparkConf, - serialization: Serialization -) extends StandaloneRecoveryModeFactory(conf, serialization) { + serializer: Serializer +) extends StandaloneRecoveryModeFactory(conf, serializer) { CustomRecoveryModeFactory.instantiationAttempts += 1 @@ -40,7 +41,7 @@ class CustomRecoveryModeFactory( * */ override def createPersistenceEngine(): PersistenceEngine = - new CustomPersistenceEngine(serialization) + new CustomPersistenceEngine(serializer) /** * Create an instance of LeaderAgent that decides who gets elected as master. 
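CustomRecoveryModeFactory above is ported from akka's Serialization to Spark's own Serializer, whose instances hand back ByteBuffers. A minimal round-trip sketch with JavaSerializer showing the ByteBuffer-to-Array[Byte] copy used in persist() and the wrap on the read side (object and value names are illustrative):

    import java.nio.ByteBuffer

    import org.apache.spark.SparkConf
    import org.apache.spark.serializer.JavaSerializer

    object SerializerRoundTripSketch {
      def main(args: Array[String]): Unit = {
        val serializer = new JavaSerializer(new SparkConf())
        val instance = serializer.newInstance()

        // serialize() returns a ByteBuffer; copy it into a byte array for storage,
        // as CustomPersistenceEngine.persist() now does.
        val buffer = instance.serialize("some recovery state")
        val bytes = new Array[Byte](buffer.remaining())
        buffer.get(bytes)

        // On read, wrap the stored bytes and deserialize back to the expected type.
        val restored = instance.deserialize[String](ByteBuffer.wrap(bytes))
        assert(restored == "some recovery state")
      }
    }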
@@ -53,7 +54,7 @@ object CustomRecoveryModeFactory { @volatile var instantiationAttempts = 0 } -class CustomPersistenceEngine(serialization: Serialization) extends PersistenceEngine { +class CustomPersistenceEngine(serializer: Serializer) extends PersistenceEngine { val data = mutable.HashMap[String, Array[Byte]]() CustomPersistenceEngine.lastInstance = Some(this) @@ -64,10 +65,10 @@ class CustomPersistenceEngine(serialization: Serialization) extends PersistenceE */ override def persist(name: String, obj: Object): Unit = { CustomPersistenceEngine.persistAttempts += 1 - serialization.serialize(obj) match { - case util.Success(bytes) => data += name -> bytes - case util.Failure(cause) => throw new RuntimeException(cause) - } + val serialized = serializer.newInstance().serialize(obj) + val bytes = new Array[Byte](serialized.remaining()) + serialized.get(bytes) + data += name -> bytes } /** @@ -84,15 +85,9 @@ class CustomPersistenceEngine(serialization: Serialization) extends PersistenceE */ override def read[T: ClassTag](prefix: String): Seq[T] = { CustomPersistenceEngine.readAttempts += 1 - val clazz = implicitly[ClassTag[T]].runtimeClass.asInstanceOf[Class[T]] val results = for ((name, bytes) <- data; if name.startsWith(prefix)) - yield serialization.deserialize(bytes, clazz) - - results.find(_.isFailure).foreach { - case util.Failure(cause) => throw new RuntimeException(cause) - } - - results.flatMap(_.toOption).toSeq + yield serializer.newInstance().deserialize[T](ByteBuffer.wrap(bytes)) + results.toSeq } } diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala index 9cb6dd43bac47..a8fbaf1d9da0a 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala @@ -105,7 +105,7 @@ class MasterSuite extends SparkFunSuite with Matchers with Eventually { persistenceEngine.addDriver(driverToPersist) persistenceEngine.addWorker(workerToPersist) - val (apps, drivers, workers) = persistenceEngine.readPersistedData() + val (apps, drivers, workers) = persistenceEngine.readPersistedData(rpcEnv) apps.map(_.id) should contain(appToPersist.id) drivers.map(_.id) should contain(driverToPersist.id) diff --git a/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala new file mode 100644 index 0000000000000..11e87bd1dd8eb --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + + +package org.apache.spark.deploy.master + +import java.net.ServerSocket + +import org.apache.commons.lang3.RandomUtils +import org.apache.curator.test.TestingServer + +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} +import org.apache.spark.rpc.{RpcEndpoint, RpcEnv} +import org.apache.spark.serializer.{Serializer, JavaSerializer} +import org.apache.spark.util.Utils + +class PersistenceEngineSuite extends SparkFunSuite { + + test("FileSystemPersistenceEngine") { + val dir = Utils.createTempDir() + try { + val conf = new SparkConf() + testPersistenceEngine(conf, serializer => + new FileSystemPersistenceEngine(dir.getAbsolutePath, serializer) + ) + } finally { + Utils.deleteRecursively(dir) + } + } + + test("ZooKeeperPersistenceEngine") { + val conf = new SparkConf() + // TestingServer logs the port conflict exception rather than throwing an exception. + // So we have to find a free port by ourselves. This approach cannot guarantee always starting + // zkTestServer successfully because there is a time gap between finding a free port and + // starting zkTestServer. But the failure possibility should be very low. + val zkTestServer = new TestingServer(findFreePort(conf)) + try { + testPersistenceEngine(conf, serializer => { + conf.set("spark.deploy.zookeeper.url", zkTestServer.getConnectString) + new ZooKeeperPersistenceEngine(conf, serializer) + }) + } finally { + zkTestServer.stop() + } + } + + private def testPersistenceEngine( + conf: SparkConf, persistenceEngineCreator: Serializer => PersistenceEngine): Unit = { + val serializer = new JavaSerializer(conf) + val persistenceEngine = persistenceEngineCreator(serializer) + persistenceEngine.persist("test_1", "test_1_value") + assert(Seq("test_1_value") === persistenceEngine.read[String]("test_")) + persistenceEngine.persist("test_2", "test_2_value") + assert(Set("test_1_value", "test_2_value") === persistenceEngine.read[String]("test_").toSet) + persistenceEngine.unpersist("test_1") + assert(Seq("test_2_value") === persistenceEngine.read[String]("test_")) + persistenceEngine.unpersist("test_2") + assert(persistenceEngine.read[String]("test_").isEmpty) + + // Test deserializing objects that contain RpcEndpointRef + val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) + try { + // Create a real endpoint so that we can test RpcEndpointRef deserialization + val workerEndpoint = rpcEnv.setupEndpoint("worker", new RpcEndpoint { + override val rpcEnv: RpcEnv = rpcEnv + }) + + val workerToPersist = new WorkerInfo( + id = "test_worker", + host = "127.0.0.1", + port = 10000, + cores = 0, + memory = 0, + endpoint = workerEndpoint, + webUiPort = 0, + publicAddress = "" + ) + + persistenceEngine.addWorker(workerToPersist) + + val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData(rpcEnv) + + assert(storedApps.isEmpty) + assert(storedDrivers.isEmpty) + + // Check deserializing WorkerInfo + assert(storedWorkers.size == 1) + val recoveryWorkerInfo = storedWorkers.head + assert(workerToPersist.id === recoveryWorkerInfo.id) + assert(workerToPersist.host === recoveryWorkerInfo.host) + assert(workerToPersist.port === recoveryWorkerInfo.port) + assert(workerToPersist.cores === recoveryWorkerInfo.cores) + assert(workerToPersist.memory === recoveryWorkerInfo.memory) + assert(workerToPersist.endpoint === recoveryWorkerInfo.endpoint) + assert(workerToPersist.webUiPort === recoveryWorkerInfo.webUiPort) + assert(workerToPersist.publicAddress === 
recoveryWorkerInfo.publicAddress) + } finally { + rpcEnv.shutdown() + rpcEnv.awaitTermination() + } + } + + private def findFreePort(conf: SparkConf): Int = { + val candidatePort = RandomUtils.nextInt(1024, 65536) + Utils.startServiceOnPort(candidatePort, (trialPort: Int) => { + val socket = new ServerSocket(trialPort) + socket.close() + (null, trialPort) + }, conf)._2 + } +} diff --git a/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala index 08215a2bafc09..05013fbc49b8e 100644 --- a/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala @@ -22,11 +22,12 @@ import java.sql._ import org.scalatest.BeforeAndAfter import org.apache.spark.{LocalSparkContext, SparkContext, SparkFunSuite} +import org.apache.spark.util.Utils class JdbcRDDSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext { before { - Class.forName("org.apache.derby.jdbc.EmbeddedDriver") + Utils.classForName("org.apache.derby.jdbc.EmbeddedDriver") val conn = DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb;create=true") try { diff --git a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala index 32f04d54eff94..3e8816a4c65be 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala @@ -175,7 +175,7 @@ class PipedRDDSuite extends SparkFunSuite with SharedSparkContext { } val hadoopPart1 = generateFakeHadoopPartition() val pipedRdd = new PipedRDD(nums, "printenv " + varName) - val tContext = new TaskContextImpl(0, 0, 0, 0, null) + val tContext = new TaskContextImpl(0, 0, 0, 0, null, null) val rddIter = pipedRdd.compute(hadoopPart1, tContext) val arr = rddIter.toArray assert(arr(0) == "/some/path") diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index f6da9f98ad253..5f718ea9f7be1 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -679,7 +679,7 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { test("runJob on an invalid partition") { intercept[IllegalArgumentException] { - sc.runJob(sc.parallelize(1 to 10, 2), {iter: Iterator[Int] => iter.size}, Seq(0, 1, 2), false) + sc.runJob(sc.parallelize(1 to 10, 2), {iter: Iterator[Int] => iter.size}, Seq(0, 1, 2)) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala index 34145691153ce..eef6aafa624ee 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala @@ -26,7 +26,7 @@ class CoarseGrainedSchedulerBackendSuite extends SparkFunSuite with LocalSparkCo val conf = new SparkConf conf.set("spark.akka.frameSize", "1") conf.set("spark.default.parallelism", "1") - sc = new SparkContext("local-cluster[2 , 1 , 512]", "test", conf) + sc = new SparkContext("local-cluster[2, 1, 1024]", "test", conf) val frameSize = AkkaUtils.maxFrameSizeBytes(sc.conf) val buffer = new SerializableBuffer(java.nio.ByteBuffer.allocate(2 * frameSize)) val larger = sc.parallelize(Seq(buffer)) diff --git 
a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 6bc45f249f975..86dff8fb577d5 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -101,9 +101,15 @@ class DAGSchedulerSuite /** Length of time to wait while draining listener events. */ val WAIT_TIMEOUT_MILLIS = 10000 val sparkListener = new SparkListener() { + val submittedStageInfos = new HashSet[StageInfo] val successfulStages = new HashSet[Int] val failedStages = new ArrayBuffer[Int] val stageByOrderOfExecution = new ArrayBuffer[Int] + + override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) { + submittedStageInfos += stageSubmitted.stageInfo + } + override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) { val stageInfo = stageCompleted.stageInfo stageByOrderOfExecution += stageInfo.stageId @@ -147,9 +153,8 @@ class DAGSchedulerSuite } before { - // Enable local execution for this test - val conf = new SparkConf().set("spark.localExecution.enabled", "true") - sc = new SparkContext("local", "DAGSchedulerSuite", conf) + sc = new SparkContext("local", "DAGSchedulerSuite") + sparkListener.submittedStageInfos.clear() sparkListener.successfulStages.clear() sparkListener.failedStages.clear() failure = null @@ -165,12 +170,7 @@ class DAGSchedulerSuite sc.listenerBus, mapOutputTracker, blockManagerMaster, - sc.env) { - override def runLocally(job: ActiveJob) { - // don't bother with the thread while unit testing - runLocallyWithinThread(job) - } - } + sc.env) dagEventProcessLoopTester = new DAGSchedulerEventProcessLoopTester(scheduler) } @@ -234,10 +234,9 @@ class DAGSchedulerSuite rdd: RDD[_], partitions: Array[Int], func: (TaskContext, Iterator[_]) => _ = jobComputeFunc, - allowLocal: Boolean = false, listener: JobListener = jobListener): Int = { val jobId = scheduler.nextJobId.getAndIncrement() - runEvent(JobSubmitted(jobId, rdd, func, partitions, allowLocal, CallSite("", ""), listener)) + runEvent(JobSubmitted(jobId, rdd, func, partitions, CallSite("", ""), listener)) jobId } @@ -277,37 +276,6 @@ class DAGSchedulerSuite assertDataStructuresEmpty() } - test("local job") { - val rdd = new PairOfIntsRDD(sc, Nil) { - override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = - Array(42 -> 0).iterator - override def getPartitions: Array[Partition] = - Array( new Partition { override def index: Int = 0 } ) - override def getPreferredLocations(split: Partition): List[String] = Nil - override def toString: String = "DAGSchedulerSuite Local RDD" - } - val jobId = scheduler.nextJobId.getAndIncrement() - runEvent( - JobSubmitted(jobId, rdd, jobComputeFunc, Array(0), true, CallSite("", ""), jobListener)) - assert(results === Map(0 -> 42)) - assertDataStructuresEmpty() - } - - test("local job oom") { - val rdd = new PairOfIntsRDD(sc, Nil) { - override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = - throw new java.lang.OutOfMemoryError("test local job oom") - override def getPartitions = Array( new Partition { override def index = 0 } ) - override def getPreferredLocations(split: Partition) = Nil - override def toString = "DAGSchedulerSuite Local RDD" - } - val jobId = scheduler.nextJobId.getAndIncrement() - runEvent( - JobSubmitted(jobId, rdd, jobComputeFunc, Array(0), true, CallSite("", ""), jobListener)) - assert(results.size == 0) - 
assertDataStructuresEmpty() - } - test("run trivial job w/ dependency") { val baseRdd = new MyRDD(sc, 1, Nil) val finalRdd = new MyRDD(sc, 1, List(new OneToOneDependency(baseRdd))) @@ -445,12 +413,7 @@ class DAGSchedulerSuite sc.listenerBus, mapOutputTracker, blockManagerMaster, - sc.env) { - override def runLocally(job: ActiveJob) { - // don't bother with the thread while unit testing - runLocallyWithinThread(job) - } - } + sc.env) dagEventProcessLoopTester = new DAGSchedulerEventProcessLoopTester(noKillScheduler) val jobId = submit(new MyRDD(sc, 1, Nil), Array(0)) cancel(jobId) @@ -476,8 +439,8 @@ class DAGSchedulerSuite complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", 1)), (Success, makeMapStatus("hostB", 1)))) - assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === - Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) + assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === + HashSet(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) complete(taskSets(1), Seq((Success, 42))) assert(results === Map(0 -> 42)) assertDataStructuresEmpty() @@ -503,8 +466,8 @@ class DAGSchedulerSuite // have the 2nd attempt pass complete(taskSets(2), Seq((Success, makeMapStatus("hostA", reduceRdd.partitions.size)))) // we can see both result blocks now - assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) === - Array("hostA", "hostB")) + assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1.host).toSet === + HashSet("hostA", "hostB")) complete(taskSets(3), Seq((Success, 43))) assert(results === Map(0 -> 42, 1 -> 43)) assertDataStructuresEmpty() @@ -520,8 +483,8 @@ class DAGSchedulerSuite (Success, makeMapStatus("hostA", reduceRdd.partitions.size)), (Success, makeMapStatus("hostB", reduceRdd.partitions.size)))) // The MapOutputTracker should know about both map output locations. - assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) === - Array("hostA", "hostB")) + assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1.host).toSet === + HashSet("hostA", "hostB")) // The first result task fails, with a fetch failure for the output from the first mapper. runEvent(CompletionEvent( @@ -547,6 +510,140 @@ class DAGSchedulerSuite assert(sparkListener.failedStages.size == 1) } + /** + * This tests the case where another FetchFailed comes in while the map stage is getting + * re-run. + */ + test("late fetch failures don't cause multiple concurrent attempts for the same map stage") { + val shuffleMapRdd = new MyRDD(sc, 2, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = new MyRDD(sc, 2, List(shuffleDep)) + submit(reduceRdd, Array(0, 1)) + + val mapStageId = 0 + def countSubmittedMapStageAttempts(): Int = { + sparkListener.submittedStageInfos.count(_.stageId == mapStageId) + } + + // The map stage should have been submitted. + assert(countSubmittedMapStageAttempts() === 1) + + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostA", 2)), + (Success, makeMapStatus("hostB", 2)))) + // The MapOutputTracker should know about both map output locations. + assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1.host).toSet === + HashSet("hostA", "hostB")) + assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 1).map(_._1.host).toSet === + HashSet("hostA", "hostB")) + + // The first result task fails, with a fetch failure for the output from the first mapper. 
+ runEvent(CompletionEvent( + taskSets(1).tasks(0), + FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), + null, + Map[Long, Any](), + createFakeTaskInfo(), + null)) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(sparkListener.failedStages.contains(1)) + + // Trigger resubmission of the failed map stage. + runEvent(ResubmitFailedStages) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + + // Another attempt for the map stage should have been submitted, resulting in 2 total attempts. + assert(countSubmittedMapStageAttempts() === 2) + + // The second ResultTask fails, with a fetch failure for the output from the second mapper. + runEvent(CompletionEvent( + taskSets(1).tasks(1), + FetchFailed(makeBlockManagerId("hostB"), shuffleId, 1, 1, "ignored"), + null, + Map[Long, Any](), + createFakeTaskInfo(), + null)) + + // Another ResubmitFailedStages event should not result in another attempt for the map + // stage being run concurrently. + // NOTE: the actual ResubmitFailedStages may get called at any time during this, but it + // shouldn't affect anything -- our calling it just makes *SURE* it gets called between the + // desired event and our check. + runEvent(ResubmitFailedStages) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(countSubmittedMapStageAttempts() === 2) + + } + + /** + * This tests the case where a late FetchFailed comes in after the map stage has finished getting + * retried and a new reduce stage starts running. + */ + test("extremely late fetch failures don't cause multiple concurrent attempts for " + + "the same stage") { + val shuffleMapRdd = new MyRDD(sc, 2, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = new MyRDD(sc, 2, List(shuffleDep)) + submit(reduceRdd, Array(0, 1)) + + def countSubmittedReduceStageAttempts(): Int = { + sparkListener.submittedStageInfos.count(_.stageId == 1) + } + def countSubmittedMapStageAttempts(): Int = { + sparkListener.submittedStageInfos.count(_.stageId == 0) + } + + // The map stage should have been submitted. + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(countSubmittedMapStageAttempts() === 1) + + // Complete the map stage. + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostA", 2)), + (Success, makeMapStatus("hostB", 2)))) + + // The reduce stage should have been submitted. + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(countSubmittedReduceStageAttempts() === 1) + + // The first result task fails, with a fetch failure for the output from the first mapper. + runEvent(CompletionEvent( + taskSets(1).tasks(0), + FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), + null, + Map[Long, Any](), + createFakeTaskInfo(), + null)) + + // Trigger resubmission of the failed map stage and finish the re-started map task. + runEvent(ResubmitFailedStages) + complete(taskSets(2), Seq((Success, makeMapStatus("hostA", 1)))) + + // Because the map stage finished, another attempt for the reduce stage should have been + // submitted, resulting in 2 total attempts for each of the map and the reduce stages. + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(countSubmittedMapStageAttempts() === 2) + assert(countSubmittedReduceStageAttempts() === 2) + + // A late FetchFailed arrives from the second task in the original reduce stage.
+ runEvent(CompletionEvent( + taskSets(1).tasks(1), + FetchFailed(makeBlockManagerId("hostB"), shuffleId, 1, 1, "ignored"), + null, + Map[Long, Any](), + createFakeTaskInfo(), + null)) + + // Running ResubmitFailedStages shouldn't result in any more attempts for the map stage, because + // the FetchFailed should have been ignored + runEvent(ResubmitFailedStages) + + // The FetchFailed from the original reduce stage should be ignored. + assert(countSubmittedMapStageAttempts() === 2) + } + test("ignore late map task completions") { val shuffleMapRdd = new MyRDD(sc, 2, Nil) val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) @@ -572,8 +669,8 @@ class DAGSchedulerSuite taskSet.tasks(1).epoch = newEpoch runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", reduceRdd.partitions.size), null, createFakeTaskInfo(), null)) - assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === - Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA"))) + assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === + HashSet(makeBlockManagerId("hostB"), makeBlockManagerId("hostA"))) complete(taskSets(1), Seq((Success, 42), (Success, 43))) assert(results === Map(0 -> 42, 1 -> 43)) assertDataStructuresEmpty() @@ -668,8 +765,8 @@ class DAGSchedulerSuite (Success, makeMapStatus("hostB", 1)))) // have hostC complete the resubmitted task complete(taskSets(1), Seq((Success, makeMapStatus("hostC", 1)))) - assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === - Array(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) + assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === + HashSet(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) complete(taskSets(2), Seq((Success, 42))) assert(results === Map(0 -> 42)) assertDataStructuresEmpty() @@ -748,40 +845,23 @@ class DAGSchedulerSuite // Run this on executors sc.parallelize(1 to 10, 2).foreach { item => acc.add(1) } - // Run this within a local thread - sc.parallelize(1 to 10, 2).map { item => acc.add(1) }.take(1) - - // Make sure we can still run local commands as well as cluster commands. + // Make sure we can still run commands assert(sc.parallelize(1 to 10, 2).count() === 10) - assert(sc.parallelize(1 to 10, 2).first() === 1) } test("misbehaved resultHandler should not crash DAGScheduler and SparkContext") { - val e1 = intercept[SparkDriverExecutionException] { - val rdd = sc.parallelize(1 to 10, 2) - sc.runJob[Int, Int]( - rdd, - (context: TaskContext, iter: Iterator[Int]) => iter.size, - Seq(0), - allowLocal = true, - (part: Int, result: Int) => throw new DAGSchedulerSuiteDummyException) - } - assert(e1.getCause.isInstanceOf[DAGSchedulerSuiteDummyException]) - - val e2 = intercept[SparkDriverExecutionException] { + val e = intercept[SparkDriverExecutionException] { val rdd = sc.parallelize(1 to 10, 2) sc.runJob[Int, Int]( rdd, (context: TaskContext, iter: Iterator[Int]) => iter.size, Seq(0, 1), - allowLocal = false, (part: Int, result: Int) => throw new DAGSchedulerSuiteDummyException) } - assert(e2.getCause.isInstanceOf[DAGSchedulerSuiteDummyException]) + assert(e.getCause.isInstanceOf[DAGSchedulerSuiteDummyException]) - // Make sure we can still run local commands as well as cluster commands.
+ // Make sure we can still run commands assert(sc.parallelize(1 to 10, 2).count() === 10) - assert(sc.parallelize(1 to 10, 2).first() === 1) } test("getPartitions exceptions should not crash DAGScheduler and SparkContext (SPARK-8606)") { @@ -794,9 +874,8 @@ class DAGSchedulerSuite rdd.reduceByKey(_ + _, 1).count() } - // Make sure we can still run local commands as well as cluster commands. + // Make sure we can still run commands assert(sc.parallelize(1 to 10, 2).count() === 10) - assert(sc.parallelize(1 to 10, 2).first() === 1) } test("getPreferredLocations errors should not crash DAGScheduler and SparkContext (SPARK-8606)") { @@ -810,9 +889,8 @@ class DAGSchedulerSuite } assert(e1.getMessage.contains(classOf[DAGSchedulerSuiteDummyException].getName)) - // Make sure we can still run local commands as well as cluster commands. + // Make sure we can still run commands assert(sc.parallelize(1 to 10, 2).count() === 10) - assert(sc.parallelize(1 to 10, 2).first() === 1) } test("accumulator not calculated for resubmitted result stage") { @@ -840,8 +918,8 @@ class DAGSchedulerSuite submit(reduceRdd, Array(0)) complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", 1)))) - assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === - Array(makeBlockManagerId("hostA"))) + assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === + HashSet(makeBlockManagerId("hostA"))) // Reducer should run on the same host that map task ran val reduceTaskSet = taskSets(1) @@ -875,6 +953,21 @@ class DAGSchedulerSuite assertDataStructuresEmpty } + test("Spark exceptions should include call site in stack trace") { + val e = intercept[SparkException] { + sc.parallelize(1 to 10, 2).map { _ => throw new RuntimeException("uh-oh!") }.count() + } + + // Does not include message, ONLY stack trace. + val stackTraceString = e.getStackTraceString + + // should actually include the RDD operation that invoked the method: + assert(stackTraceString.contains("org.apache.spark.rdd.RDD.count")) + + // should include the FunSuite setup: + assert(stackTraceString.contains("org.scalatest.FunSuite")) + } + /** * Assert that the supplied TaskSet has exactly the given hosts as its preferred locations. * Note that this checks only the host and not the executor ID. diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index f681f21b6205e..5cb2d4225d281 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -180,7 +180,7 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit // into SPARK-6688. 
val conf = getLoggingConf(testDirPath, compressionCodec) .set("spark.hadoop.fs.defaultFS", "unsupported://example.com") - val sc = new SparkContext("local-cluster[2,2,512]", "test", conf) + val sc = new SparkContext("local-cluster[2,2,1024]", "test", conf) assert(sc.eventLogger.isDefined) val eventLogger = sc.eventLogger.get val eventLogPath = eventLogger.logPath diff --git a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala index 0a7cb69416a08..b3ca150195a5f 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala @@ -19,7 +19,7 @@ package org.apache.spark.scheduler import org.apache.spark.TaskContext -class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0) { +class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0, 0) { override def runTask(context: TaskContext): Int = 0 override def preferredLocations: Seq[TaskLocation] = prefLocs @@ -31,12 +31,16 @@ object FakeTask { * locations for each task (given as varargs) if this sequence is not empty. */ def createTaskSet(numTasks: Int, prefLocs: Seq[TaskLocation]*): TaskSet = { + createTaskSet(numTasks, 0, prefLocs: _*) + } + + def createTaskSet(numTasks: Int, stageAttemptId: Int, prefLocs: Seq[TaskLocation]*): TaskSet = { if (prefLocs.size != 0 && prefLocs.size != numTasks) { throw new IllegalArgumentException("Wrong number of task locations") } val tasks = Array.tabulate[Task[_]](numTasks) { i => new FakeTask(i, if (prefLocs.size != 0) prefLocs(i) else Nil) } - new TaskSet(tasks, 0, 0, 0, null) + new TaskSet(tasks, 0, stageAttemptId, 0, null) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala index 9b92f8de56759..383855caefa2f 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala @@ -25,7 +25,7 @@ import org.apache.spark.TaskContext * A Task implementation that fails to serialize. 
*/ private[spark] class NotSerializableFakeTask(myId: Int, stageId: Int) - extends Task[Array[Byte]](stageId, 0) { + extends Task[Array[Byte]](stageId, 0, 0) { override def runTask(context: TaskContext): Array[Byte] = Array.empty[Byte] override def preferredLocations: Seq[TaskLocation] = Seq[TaskLocation]() diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala index a9036da9cc93d..e5ecd4b7c2610 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala @@ -134,14 +134,14 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { test("Only one of two duplicate commit tasks should commit") { val rdd = sc.parallelize(Seq(1), 1) sc.runJob(rdd, OutputCommitFunctions(tempDir.getAbsolutePath).commitSuccessfully _, - 0 until rdd.partitions.size, allowLocal = false) + 0 until rdd.partitions.size) assert(tempDir.list().size === 1) } test("If commit fails, if task is retried it should not be locked, and will succeed.") { val rdd = sc.parallelize(Seq(1), 1) sc.runJob(rdd, OutputCommitFunctions(tempDir.getAbsolutePath).failFirstCommitAttempt _, - 0 until rdd.partitions.size, allowLocal = false) + 0 until rdd.partitions.size) assert(tempDir.list().size === 1) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala index 4e3defb43a021..103fc19369c97 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala @@ -102,7 +102,7 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter { fileSystem.mkdirs(logDirPath) val conf = EventLoggingListenerSuite.getLoggingConf(logDirPath, codecName) - val sc = new SparkContext("local-cluster[2,1,512]", "Test replay", conf) + val sc = new SparkContext("local-cluster[2,1,1024]", "Test replay", conf) // Run a few jobs sc.parallelize(1 to 100, 1).count() diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 651295b7344c5..730535ece7878 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -188,7 +188,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match sc.addSparkListener(listener) val rdd1 = sc.parallelize(1 to 100, 4) val rdd2 = rdd1.map(_.toString) - sc.runJob(rdd2, (items: Iterator[String]) => items.size, Seq(0, 1), true) + sc.runJob(rdd2, (items: Iterator[String]) => items.size, Seq(0, 1)) sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala index d97fba00976d2..d1e23ed527ff1 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala @@ -34,7 +34,7 @@ class SparkListenerWithClusterSuite extends SparkFunSuite with LocalSparkContext val WAIT_TIMEOUT_MILLIS = 10000 before { - sc = new SparkContext("local-cluster[2,1,512]", 
"SparkListenerSuite") + sc = new SparkContext("local-cluster[2,1,1024]", "SparkListenerSuite") } test("SparkListener sends executor added message") { diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index 7c1adc1aef1b6..9201d1e1f328b 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -24,11 +24,27 @@ import org.scalatest.BeforeAndAfter import org.apache.spark._ import org.apache.spark.rdd.RDD -import org.apache.spark.util.{TaskCompletionListenerException, TaskCompletionListener} +import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException} +import org.apache.spark.metrics.source.JvmSource class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext { + test("provide metrics sources") { + val filePath = getClass.getClassLoader.getResource("test_metrics_config.properties").getFile + val conf = new SparkConf(loadDefaults = false) + .set("spark.metrics.conf", filePath) + sc = new SparkContext("local", "test", conf) + val rdd = sc.makeRDD(1 to 1) + val result = sc.runJob(rdd, (tc: TaskContext, it: Iterator[Int]) => { + tc.getMetricsSources("jvm").count { + case source: JvmSource => true + case _ => false + } + }).sum + assert(result > 0) + } + test("calls TaskCompletionListener after failure") { TaskContextSuite.completed = false sc = new SparkContext("local", "test") @@ -41,16 +57,16 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark } val closureSerializer = SparkEnv.get.closureSerializer.newInstance() val func = (c: TaskContext, i: Iterator[String]) => i.next() - val task = new ResultTask[String, String]( - 0, sc.broadcast(closureSerializer.serialize((rdd, func)).array), rdd.partitions(0), Seq(), 0) + val task = new ResultTask[String, String](0, 0, + sc.broadcast(closureSerializer.serialize((rdd, func)).array), rdd.partitions(0), Seq(), 0) intercept[RuntimeException] { - task.run(0, 0) + task.run(0, 0, null) } assert(TaskContextSuite.completed === true) } test("all TaskCompletionListeners should be called even if some fail") { - val context = new TaskContextImpl(0, 0, 0, 0, null) + val context = new TaskContextImpl(0, 0, 0, 0, null, null) val listener = mock(classOf[TaskCompletionListener]) context.addTaskCompletionListener(_ => throw new Exception("blah")) context.addTaskCompletionListener(listener) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index a6d5232feb8de..c2edd4c317d6e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -33,7 +33,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L val taskScheduler = new TaskSchedulerImpl(sc) taskScheduler.initialize(new FakeSchedulerBackend) // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks. 
- val dagScheduler = new DAGScheduler(sc, taskScheduler) { + new DAGScheduler(sc, taskScheduler) { override def taskStarted(task: Task[_], taskInfo: TaskInfo) {} override def executorAdded(execId: String, host: String) {} } @@ -67,7 +67,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L val taskScheduler = new TaskSchedulerImpl(sc) taskScheduler.initialize(new FakeSchedulerBackend) // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks. - val dagScheduler = new DAGScheduler(sc, taskScheduler) { + new DAGScheduler(sc, taskScheduler) { override def taskStarted(task: Task[_], taskInfo: TaskInfo) {} override def executorAdded(execId: String, host: String) {} } @@ -128,4 +128,113 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L assert(taskDescriptions.map(_.executorId) === Seq("executor0")) } + test("refuse to schedule concurrent attempts for the same stage (SPARK-8103)") { + sc = new SparkContext("local", "TaskSchedulerImplSuite") + val taskScheduler = new TaskSchedulerImpl(sc) + taskScheduler.initialize(new FakeSchedulerBackend) + // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks. + val dagScheduler = new DAGScheduler(sc, taskScheduler) { + override def taskStarted(task: Task[_], taskInfo: TaskInfo) {} + override def executorAdded(execId: String, host: String) {} + } + taskScheduler.setDAGScheduler(dagScheduler) + val attempt1 = FakeTask.createTaskSet(1, 0) + val attempt2 = FakeTask.createTaskSet(1, 1) + taskScheduler.submitTasks(attempt1) + intercept[IllegalStateException] { taskScheduler.submitTasks(attempt2) } + + // OK to submit multiple if previous attempts are all zombie + taskScheduler.taskSetManagerForAttempt(attempt1.stageId, attempt1.stageAttemptId) + .get.isZombie = true + taskScheduler.submitTasks(attempt2) + val attempt3 = FakeTask.createTaskSet(1, 2) + intercept[IllegalStateException] { taskScheduler.submitTasks(attempt3) } + taskScheduler.taskSetManagerForAttempt(attempt2.stageId, attempt2.stageAttemptId) + .get.isZombie = true + taskScheduler.submitTasks(attempt3) + } + + test("don't schedule more tasks after a taskset is zombie") { + sc = new SparkContext("local", "TaskSchedulerImplSuite") + val taskScheduler = new TaskSchedulerImpl(sc) + taskScheduler.initialize(new FakeSchedulerBackend) + // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks. 
+ new DAGScheduler(sc, taskScheduler) { + override def taskStarted(task: Task[_], taskInfo: TaskInfo) {} + override def executorAdded(execId: String, host: String) {} + } + + val numFreeCores = 1 + val workerOffers = Seq(new WorkerOffer("executor0", "host0", numFreeCores)) + val attempt1 = FakeTask.createTaskSet(10) + + // submit attempt 1, offer some resources, some tasks get scheduled + taskScheduler.submitTasks(attempt1) + val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten + assert(1 === taskDescriptions.length) + + // now mark attempt 1 as a zombie + taskScheduler.taskSetManagerForAttempt(attempt1.stageId, attempt1.stageAttemptId) + .get.isZombie = true + + // don't schedule anything on another resource offer + val taskDescriptions2 = taskScheduler.resourceOffers(workerOffers).flatten + assert(0 === taskDescriptions2.length) + + // if we schedule another attempt for the same stage, it should get scheduled + val attempt2 = FakeTask.createTaskSet(10, 1) + + // submit attempt 2, offer some resources, some tasks get scheduled + taskScheduler.submitTasks(attempt2) + val taskDescriptions3 = taskScheduler.resourceOffers(workerOffers).flatten + assert(1 === taskDescriptions3.length) + val mgr = taskScheduler.taskIdToTaskSetManager.get(taskDescriptions3(0).taskId).get + assert(mgr.taskSet.stageAttemptId === 1) + } + + test("if a zombie attempt finishes, continue scheduling tasks for non-zombie attempts") { + sc = new SparkContext("local", "TaskSchedulerImplSuite") + val taskScheduler = new TaskSchedulerImpl(sc) + taskScheduler.initialize(new FakeSchedulerBackend) + // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks. + new DAGScheduler(sc, taskScheduler) { + override def taskStarted(task: Task[_], taskInfo: TaskInfo) {} + override def executorAdded(execId: String, host: String) {} + } + + val numFreeCores = 10 + val workerOffers = Seq(new WorkerOffer("executor0", "host0", numFreeCores)) + val attempt1 = FakeTask.createTaskSet(10) + + // submit attempt 1, offer some resources, some tasks get scheduled + taskScheduler.submitTasks(attempt1) + val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten + assert(10 === taskDescriptions.length) + + // now mark attempt 1 as a zombie + val mgr1 = taskScheduler.taskSetManagerForAttempt(attempt1.stageId, attempt1.stageAttemptId).get + mgr1.isZombie = true + + // don't schedule anything on another resource offer + val taskDescriptions2 = taskScheduler.resourceOffers(workerOffers).flatten + assert(0 === taskDescriptions2.length) + + // submit attempt 2 + val attempt2 = FakeTask.createTaskSet(10, 1) + taskScheduler.submitTasks(attempt2) + + // attempt 1 finished (this can happen even if it was marked zombie earlier -- all tasks were + // already submitted, and then they finish) + taskScheduler.taskSetFinished(mgr1) + + // now with another resource offer, we should still schedule all the tasks in attempt2 + val taskDescriptions3 = taskScheduler.resourceOffers(workerOffers).flatten + assert(10 === taskDescriptions3.length) + + taskDescriptions3.foreach { task => + val mgr = taskScheduler.taskIdToTaskSetManager.get(task.taskId).get + assert(mgr.taskSet.stageAttemptId === 1) + } + } + } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 0060f3396dcde..3abb99c4b2b54 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ 
b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -19,12 +19,13 @@ package org.apache.spark.scheduler import java.util.Random -import scala.collection.mutable.ArrayBuffer +import scala.collection.Map import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer import org.apache.spark._ import org.apache.spark.executor.TaskMetrics -import org.apache.spark.util.{ManualClock, Utils} +import org.apache.spark.util.ManualClock class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler) extends DAGScheduler(sc) { @@ -37,7 +38,7 @@ class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler) task: Task[_], reason: TaskEndReason, result: Any, - accumUpdates: mutable.Map[Long, Any], + accumUpdates: Map[Long, Any], taskInfo: TaskInfo, taskMetrics: TaskMetrics) { taskScheduler.endedTasks(taskInfo.index) = reason @@ -135,7 +136,7 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex /** * A Task implementation that results in a large serialized task. */ -class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0) { +class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0, 0) { val randomBuffer = new Array[Byte](TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024) val random = new Random(0) random.nextBytes(randomBuffer) diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala index 3f1692917a357..4b504df7b8851 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala @@ -22,7 +22,7 @@ import java.util.Collections import org.apache.mesos.Protos.Value.Scalar import org.apache.mesos.Protos._ -import org.apache.mesos.SchedulerDriver +import org.apache.mesos.{Protos, Scheduler, SchedulerDriver} import org.mockito.Matchers._ import org.mockito.Mockito._ import org.mockito.Matchers @@ -60,7 +60,16 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite taskScheduler: TaskSchedulerImpl, driver: SchedulerDriver): CoarseMesosSchedulerBackend = { val backend = new CoarseMesosSchedulerBackend(taskScheduler, sc, "master") { - mesosDriver = driver + override protected def createSchedulerDriver( + masterUrl: String, + scheduler: Scheduler, + sparkUser: String, + appName: String, + conf: SparkConf, + webuiUrl: Option[String] = None, + checkpoint: Option[Boolean] = None, + failoverTimeout: Option[Double] = None, + frameworkId: Option[String] = None): SchedulerDriver = driver markRegistered() } backend.start() @@ -80,6 +89,7 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite test("mesos supports killing and limiting executors") { val driver = mock[SchedulerDriver] + when(driver.start()).thenReturn(Protos.Status.DRIVER_RUNNING) val taskScheduler = mock[TaskSchedulerImpl] when(taskScheduler.sc).thenReturn(sc) @@ -87,7 +97,7 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite sparkConf.set("spark.driver.port", "1234") val backend = createSchedulerBackend(taskScheduler, driver) - val minMem = backend.calculateTotalMemory(sc).toInt + val minMem = backend.calculateTotalMemory(sc) val minCpu = 4 val mesosOffers = new java.util.ArrayList[Offer] @@ -130,11 +140,12 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite test("mesos supports killing and relaunching 
tasks with executors") { val driver = mock[SchedulerDriver] + when(driver.start()).thenReturn(Protos.Status.DRIVER_RUNNING) val taskScheduler = mock[TaskSchedulerImpl] when(taskScheduler.sc).thenReturn(sc) val backend = createSchedulerBackend(taskScheduler, driver) - val minMem = backend.calculateTotalMemory(sc).toInt + 1024 + val minMem = backend.calculateTotalMemory(sc) + 1024 val minCpu = 4 val mesosOffers = new java.util.ArrayList[Offer] diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala index d01837fe78957..5ed30f64d705f 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala @@ -21,6 +21,7 @@ import java.nio.ByteBuffer import java.util import java.util.Collections +import scala.collection.JavaConversions._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -60,14 +61,17 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi val mesosSchedulerBackend = new MesosSchedulerBackend(taskScheduler, sc, "master") + val resources = List( + mesosSchedulerBackend.createResource("cpus", 4), + mesosSchedulerBackend.createResource("mem", 1024)) // uri is null. - val executorInfo = mesosSchedulerBackend.createExecutorInfo("test-id") + val (executorInfo, _) = mesosSchedulerBackend.createExecutorInfo(resources, "test-id") assert(executorInfo.getCommand.getValue === s" /mesos-home/bin/spark-class ${classOf[MesosExecutorBackend].getName}") // uri exists. conf.set("spark.executor.uri", "hdfs:///test-app-1.0.0.tgz") - val executorInfo1 = mesosSchedulerBackend.createExecutorInfo("test-id") + val (executorInfo1, _) = mesosSchedulerBackend.createExecutorInfo(resources, "test-id") assert(executorInfo1.getCommand.getValue === s"cd test-app-1*; ./bin/spark-class ${classOf[MesosExecutorBackend].getName}") } @@ -93,7 +97,8 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi val backend = new MesosSchedulerBackend(taskScheduler, sc, "master") - val execInfo = backend.createExecutorInfo("mockExecutor") + val (execInfo, _) = backend.createExecutorInfo( + List(backend.createResource("cpus", 4)), "mockExecutor") assert(execInfo.getContainer.getDocker.getImage.equals("spark/mock")) val portmaps = execInfo.getContainer.getDocker.getPortMappingsList assert(portmaps.get(0).getHostPort.equals(80)) @@ -194,7 +199,7 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi ) verify(driver, times(1)).declineOffer(mesosOffers.get(1).getId) verify(driver, times(1)).declineOffer(mesosOffers.get(2).getId) - assert(capture.getValue.size() == 1) + assert(capture.getValue.size() === 1) val taskInfo = capture.getValue.iterator().next() assert(taskInfo.getName.equals("n1")) val cpus = taskInfo.getResourcesList.get(0) @@ -214,4 +219,97 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi backend.resourceOffers(driver, mesosOffers2) verify(driver, times(1)).declineOffer(mesosOffers2.get(0).getId) } + + test("can handle multiple roles") { + val driver = mock[SchedulerDriver] + val taskScheduler = mock[TaskSchedulerImpl] + + val listenerBus = mock[LiveListenerBus] + listenerBus.post( + SparkListenerExecutorAdded(anyLong, "s1", new ExecutorInfo("host1", 2, Map.empty))) + + val sc = 
mock[SparkContext] + when(sc.executorMemory).thenReturn(100) + when(sc.getSparkHome()).thenReturn(Option("/path")) + when(sc.executorEnvs).thenReturn(new mutable.HashMap[String, String]) + when(sc.conf).thenReturn(new SparkConf) + when(sc.listenerBus).thenReturn(listenerBus) + + val id = 1 + val builder = Offer.newBuilder() + builder.addResourcesBuilder() + .setName("mem") + .setType(Value.Type.SCALAR) + .setRole("prod") + .setScalar(Scalar.newBuilder().setValue(500)) + builder.addResourcesBuilder() + .setName("cpus") + .setRole("prod") + .setType(Value.Type.SCALAR) + .setScalar(Scalar.newBuilder().setValue(1)) + builder.addResourcesBuilder() + .setName("mem") + .setRole("dev") + .setType(Value.Type.SCALAR) + .setScalar(Scalar.newBuilder().setValue(600)) + builder.addResourcesBuilder() + .setName("cpus") + .setRole("dev") + .setType(Value.Type.SCALAR) + .setScalar(Scalar.newBuilder().setValue(2)) + val offer = builder.setId(OfferID.newBuilder().setValue(s"o${id.toString}").build()) + .setFrameworkId(FrameworkID.newBuilder().setValue("f1")) + .setSlaveId(SlaveID.newBuilder().setValue(s"s${id.toString}")) + .setHostname(s"host${id.toString}").build() + + + val mesosOffers = new java.util.ArrayList[Offer] + mesosOffers.add(offer) + + val backend = new MesosSchedulerBackend(taskScheduler, sc, "master") + + val expectedWorkerOffers = new ArrayBuffer[WorkerOffer](1) + expectedWorkerOffers.append(new WorkerOffer( + mesosOffers.get(0).getSlaveId.getValue, + mesosOffers.get(0).getHostname, + 2 // Deducting 1 for executor + )) + + val taskDesc = new TaskDescription(1L, 0, "s1", "n1", 0, ByteBuffer.wrap(new Array[Byte](0))) + when(taskScheduler.resourceOffers(expectedWorkerOffers)).thenReturn(Seq(Seq(taskDesc))) + when(taskScheduler.CPUS_PER_TASK).thenReturn(1) + + val capture = ArgumentCaptor.forClass(classOf[util.Collection[TaskInfo]]) + when( + driver.launchTasks( + Matchers.eq(Collections.singleton(mesosOffers.get(0).getId)), + capture.capture(), + any(classOf[Filters]) + ) + ).thenReturn(Status.valueOf(1)) + + backend.resourceOffers(driver, mesosOffers) + + verify(driver, times(1)).launchTasks( + Matchers.eq(Collections.singleton(mesosOffers.get(0).getId)), + capture.capture(), + any(classOf[Filters]) + ) + + assert(capture.getValue.size() === 1) + val taskInfo = capture.getValue.iterator().next() + assert(taskInfo.getName.equals("n1")) + assert(taskInfo.getResourcesCount === 1) + val cpusDev = taskInfo.getResourcesList.get(0) + assert(cpusDev.getName.equals("cpus")) + assert(cpusDev.getScalar.getValue.equals(1.0)) + assert(cpusDev.getRole.equals("dev")) + val executorResources = taskInfo.getExecutor.getResourcesList + assert(executorResources.exists { r => + r.getName.equals("mem") && r.getScalar.getValue.equals(484.0) && r.getRole.equals("prod") + }) + assert(executorResources.exists { r => + r.getName.equals("cpus") && r.getScalar.getValue.equals(1.0) && r.getRole.equals("prod") + }) + } } diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala index 63a8480c9b57b..935a091f14f9b 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala @@ -35,7 +35,7 @@ class KryoSerializerDistributedSuite extends SparkFunSuite { val jar = TestUtils.createJarWithClasses(List(AppJarRegistrator.customClassName)) conf.setJars(List(jar.getPath)) - val sc = 
new SparkContext("local-cluster[2,1,512]", "test", conf) + val sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) val original = Thread.currentThread.getContextClassLoader val loader = new java.net.URLClassLoader(Array(jar), Utils.getContextOrSparkClassLoader) SparkEnv.get.serializer.setDefaultClassLoader(loader) @@ -59,7 +59,9 @@ object KryoDistributedTest { class AppJarRegistrator extends KryoRegistrator { override def registerClasses(k: Kryo) { val classLoader = Thread.currentThread.getContextClassLoader + // scalastyle:off classforname k.register(Class.forName(AppJarRegistrator.customClassName, true, classLoader)) + // scalastyle:on classforname } } diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala index 28ca68698e3dc..db718ecabbdb9 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala @@ -115,11 +115,15 @@ class HashShuffleReaderSuite extends SparkFunSuite with LocalSparkContext { // Make a mocked MapOutputTracker for the shuffle reader to use to determine what // shuffle data to read. val mapOutputTracker = mock(classOf[MapOutputTracker]) - // Test a scenario where all data is local, just to avoid creating a bunch of additional mocks - // for the code to read data over the network. - val statuses: Array[(BlockManagerId, Long)] = - Array.fill(numMaps)((localBlockManagerId, byteOutputStream.size().toLong)) - when(mapOutputTracker.getServerStatuses(shuffleId, reduceId)).thenReturn(statuses) + when(mapOutputTracker.getMapSizesByExecutorId(shuffleId, reduceId)).thenReturn { + // Test a scenario where all data is local, to avoid creating a bunch of additional mocks + // for the code to read data over the network. + val shuffleBlockIdsAndSizes = (0 until numMaps).map { mapId => + val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId) + (shuffleBlockId, byteOutputStream.size().toLong) + } + Seq((localBlockManagerId, shuffleBlockIdsAndSizes)) + } // Create a mocked shuffle handle to pass into HashShuffleReader. 
val shuffleHandle = { @@ -134,7 +138,7 @@ class HashShuffleReaderSuite extends SparkFunSuite with LocalSparkContext { shuffleHandle, reduceId, reduceId + 1, - new TaskContextImpl(0, 0, 0, 0, null), + new TaskContextImpl(0, 0, 0, 0, null, null), blockManager, mapOutputTracker) diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index 542f8f45125a4..cc7342f1ecd78 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -68,8 +68,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte any[SerializerInstance], anyInt(), any[ShuffleWriteMetrics] - )).thenAnswer(new Answer[BlockObjectWriter] { - override def answer(invocation: InvocationOnMock): BlockObjectWriter = { + )).thenAnswer(new Answer[DiskBlockObjectWriter] { + override def answer(invocation: InvocationOnMock): DiskBlockObjectWriter = { val args = invocation.getArguments new DiskBlockObjectWriter( args(0).asInstanceOf[BlockId], diff --git a/core/src/test/scala/org/apache/spark/storage/BlockStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockStatusListenerSuite.scala new file mode 100644 index 0000000000000..d7ffde1e7864e --- /dev/null +++ b/core/src/test/scala/org/apache/spark/storage/BlockStatusListenerSuite.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.storage + +import org.apache.spark.SparkFunSuite +import org.apache.spark.scheduler._ + +class BlockStatusListenerSuite extends SparkFunSuite { + + test("basic functions") { + val blockManagerId = BlockManagerId("0", "localhost", 10000) + val listener = new BlockStatusListener() + + // Add a block manager and a new block status + listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(0, blockManagerId, 0)) + listener.onBlockUpdated(SparkListenerBlockUpdated( + BlockUpdatedInfo( + blockManagerId, + StreamBlockId(0, 100), + StorageLevel.MEMORY_AND_DISK, + memSize = 100, + diskSize = 100, + externalBlockStoreSize = 0))) + // The new block status should be added to the listener + val expectedBlock = BlockUIData( + StreamBlockId(0, 100), + "localhost:10000", + StorageLevel.MEMORY_AND_DISK, + memSize = 100, + diskSize = 100, + externalBlockStoreSize = 0 + ) + val expectedExecutorStreamBlockStatus = Seq( + ExecutorStreamBlockStatus("0", "localhost:10000", Seq(expectedBlock)) + ) + assert(listener.allExecutorStreamBlockStatus === expectedExecutorStreamBlockStatus) + + // Add the second block manager + val blockManagerId2 = BlockManagerId("1", "localhost", 10001) + listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(0, blockManagerId2, 0)) + // Add a new replication of the same block id from the second manager + listener.onBlockUpdated(SparkListenerBlockUpdated( + BlockUpdatedInfo( + blockManagerId2, + StreamBlockId(0, 100), + StorageLevel.MEMORY_AND_DISK, + memSize = 100, + diskSize = 100, + externalBlockStoreSize = 0))) + val expectedBlock2 = BlockUIData( + StreamBlockId(0, 100), + "localhost:10001", + StorageLevel.MEMORY_AND_DISK, + memSize = 100, + diskSize = 100, + externalBlockStoreSize = 0 + ) + // Each block manager should contain one block + val expectedExecutorStreamBlockStatus2 = Set( + ExecutorStreamBlockStatus("0", "localhost:10000", Seq(expectedBlock)), + ExecutorStreamBlockStatus("1", "localhost:10001", Seq(expectedBlock2)) + ) + assert(listener.allExecutorStreamBlockStatus.toSet === expectedExecutorStreamBlockStatus2) + + // Remove a replication of the same block + listener.onBlockUpdated(SparkListenerBlockUpdated( + BlockUpdatedInfo( + blockManagerId2, + StreamBlockId(0, 100), + StorageLevel.NONE, // StorageLevel.NONE means removing it + memSize = 0, + diskSize = 0, + externalBlockStoreSize = 0))) + // Only the first block manager contains a block + val expectedExecutorStreamBlockStatus3 = Set( + ExecutorStreamBlockStatus("0", "localhost:10000", Seq(expectedBlock)), + ExecutorStreamBlockStatus("1", "localhost:10001", Seq.empty) + ) + assert(listener.allExecutorStreamBlockStatus.toSet === expectedExecutorStreamBlockStatus3) + + // Remove the second block manager at first but add a new block status + // from this removed block manager + listener.onBlockManagerRemoved(SparkListenerBlockManagerRemoved(0, blockManagerId2)) + listener.onBlockUpdated(SparkListenerBlockUpdated( + BlockUpdatedInfo( + blockManagerId2, + StreamBlockId(0, 100), + StorageLevel.MEMORY_AND_DISK, + memSize = 100, + diskSize = 100, + externalBlockStoreSize = 0))) + // The second block manager is removed so we should not see the new block + val expectedExecutorStreamBlockStatus4 = Seq( + ExecutorStreamBlockStatus("0", "localhost:10000", Seq(expectedBlock)) + ) + assert(listener.allExecutorStreamBlockStatus === expectedExecutorStreamBlockStatus4) + + // Remove the last block manager + listener.onBlockManagerRemoved(SparkListenerBlockManagerRemoved(0, blockManagerId)) + // 
No block manager now so we should drop all block managers + assert(listener.allExecutorStreamBlockStatus.isEmpty) + } + +} diff --git a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala similarity index 98% rename from core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala rename to core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala index 7bdea724fea58..66af6e1a79740 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.serializer.JavaSerializer import org.apache.spark.util.Utils -class BlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { +class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { var tempDir: File = _ diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 9ced4148d7206..cf8bd8ae69625 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -33,6 +33,7 @@ import org.apache.spark.{SparkFunSuite, TaskContextImpl} import org.apache.spark.network._ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.shuffle.BlockFetchingListener +import org.apache.spark.shuffle.FetchFailedException class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodTester { @@ -94,7 +95,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT ) val iterator = new ShuffleBlockFetcherIterator( - new TaskContextImpl(0, 0, 0, 0, null), + new TaskContextImpl(0, 0, 0, 0, null, null), transfer, blockManager, blocksByAddress, @@ -106,13 +107,11 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT for (i <- 0 until 5) { assert(iterator.hasNext, s"iterator should have 5 elements but actually has $i elements") val (blockId, inputStream) = iterator.next() - assert(inputStream.isSuccess, - s"iterator should have 5 elements defined but actually has $i elements") // Make sure we release buffers when a wrapped input stream is closed.
val mockBuf = localBlocks.getOrElse(blockId, remoteBlocks(blockId)) // Note: ShuffleBlockFetcherIterator wraps input streams in a BufferReleasingInputStream - val wrappedInputStream = inputStream.get.asInstanceOf[BufferReleasingInputStream] + val wrappedInputStream = inputStream.asInstanceOf[BufferReleasingInputStream] verify(mockBuf, times(0)).release() val delegateAccess = PrivateMethod[InputStream]('delegate) @@ -166,7 +165,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) - val taskContext = new TaskContextImpl(0, 0, 0, 0, null) + val taskContext = new TaskContextImpl(0, 0, 0, 0, null, null) val iterator = new ShuffleBlockFetcherIterator( taskContext, transfer, @@ -175,11 +174,11 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT 48 * 1024 * 1024) verify(blocks(ShuffleBlockId(0, 0, 0)), times(0)).release() - iterator.next()._2.get.close() // close() first block's input stream + iterator.next()._2.close() // close() first block's input stream verify(blocks(ShuffleBlockId(0, 0, 0)), times(1)).release() // Get the 2nd block but do not exhaust the iterator - val subIter = iterator.next()._2.get + val subIter = iterator.next()._2 // Complete the task; then the 2nd block buffer should be exhausted verify(blocks(ShuffleBlockId(0, 1, 0)), times(0)).release() @@ -228,7 +227,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) - val taskContext = new TaskContextImpl(0, 0, 0, 0, null) + val taskContext = new TaskContextImpl(0, 0, 0, 0, null, null) val iterator = new ShuffleBlockFetcherIterator( taskContext, transfer, @@ -239,9 +238,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // Continue only after the mock calls onBlockFetchFailure sem.acquire() - // The first block should be defined, and the last two are not defined (due to failure) - assert(iterator.next()._2.isSuccess) - assert(iterator.next()._2.isFailure) - assert(iterator.next()._2.isFailure) + // The first block should be returned without an exception, and the last two should throw + // FetchFailedExceptions (due to failure) + iterator.next() + intercept[FetchFailedException] { iterator.next() } + intercept[FetchFailedException] { iterator.next() } } } diff --git a/core/src/test/scala/org/apache/spark/ui/PagedTableSuite.scala b/core/src/test/scala/org/apache/spark/ui/PagedTableSuite.scala new file mode 100644 index 0000000000000..cc76c141c53cc --- /dev/null +++ b/core/src/test/scala/org/apache/spark/ui/PagedTableSuite.scala @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui + +import scala.xml.Node + +import org.apache.spark.SparkFunSuite + +class PagedDataSourceSuite extends SparkFunSuite { + + test("basic") { + val dataSource1 = new SeqPagedDataSource[Int](1 to 5, pageSize = 2) + assert(dataSource1.pageData(1) === PageData(3, (1 to 2))) + + val dataSource2 = new SeqPagedDataSource[Int](1 to 5, pageSize = 2) + assert(dataSource2.pageData(2) === PageData(3, (3 to 4))) + + val dataSource3 = new SeqPagedDataSource[Int](1 to 5, pageSize = 2) + assert(dataSource3.pageData(3) === PageData(3, Seq(5))) + + val dataSource4 = new SeqPagedDataSource[Int](1 to 5, pageSize = 2) + val e1 = intercept[IndexOutOfBoundsException] { + dataSource4.pageData(4) + } + assert(e1.getMessage === "Page 4 is out of range. Please select a page number between 1 and 3.") + + val dataSource5 = new SeqPagedDataSource[Int](1 to 5, pageSize = 2) + val e2 = intercept[IndexOutOfBoundsException] { + dataSource5.pageData(0) + } + assert(e2.getMessage === "Page 0 is out of range. Please select a page number between 1 and 3.") + + } +} + +class PagedTableSuite extends SparkFunSuite { + test("pageNavigation") { + // Create a fake PagedTable to test pageNavigation + val pagedTable = new PagedTable[Int] { + override def tableId: String = "" + + override def tableCssClass: String = "" + + override def dataSource: PagedDataSource[Int] = null + + override def pageLink(page: Int): String = page.toString + + override def headers: Seq[Node] = Nil + + override def row(t: Int): Seq[Node] = Nil + + override def goButtonJavascriptFunction: (String, String) = ("", "") + } + + assert(pagedTable.pageNavigation(1, 10, 1) === Nil) + assert( + (pagedTable.pageNavigation(1, 10, 2).head \\ "li").map(_.text.trim) === Seq("1", "2", ">")) + assert( + (pagedTable.pageNavigation(2, 10, 2).head \\ "li").map(_.text.trim) === Seq("<", "1", "2")) + + assert((pagedTable.pageNavigation(1, 10, 100).head \\ "li").map(_.text.trim) === + (1 to 10).map(_.toString) ++ Seq(">", ">>")) + assert((pagedTable.pageNavigation(2, 10, 100).head \\ "li").map(_.text.trim) === + Seq("<") ++ (1 to 10).map(_.toString) ++ Seq(">", ">>")) + + assert((pagedTable.pageNavigation(100, 10, 100).head \\ "li").map(_.text.trim) === + Seq("<<", "<") ++ (91 to 100).map(_.toString)) + assert((pagedTable.pageNavigation(99, 10, 100).head \\ "li").map(_.text.trim) === + Seq("<<", "<") ++ (91 to 100).map(_.toString) ++ Seq(">")) + + assert((pagedTable.pageNavigation(11, 10, 100).head \\ "li").map(_.text.trim) === + Seq("<<", "<") ++ (11 to 20).map(_.toString) ++ Seq(">", ">>")) + assert((pagedTable.pageNavigation(93, 10, 97).head \\ "li").map(_.text.trim) === + Seq("<<", "<") ++ (91 to 97).map(_.toString) ++ Seq(">")) + } +} + +private[spark] class SeqPagedDataSource[T](seq: Seq[T], pageSize: Int) + extends PagedDataSource[T](pageSize) { + + override protected def dataSize: Int = seq.size + + override protected def sliceData(from: Int, to: Int): Seq[T] = seq.slice(from, to) +} diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala 
b/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala new file mode 100644 index 0000000000000..3dab15a9d4691 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui.storage + +import scala.xml.Utility + +import org.mockito.Mockito._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.storage._ + +class StoragePageSuite extends SparkFunSuite { + + val storageTab = mock(classOf[StorageTab]) + when(storageTab.basePath).thenReturn("http://localhost:4040") + val storagePage = new StoragePage(storageTab) + + test("rddTable") { + val rdd1 = new RDDInfo(1, + "rdd1", + 10, + StorageLevel.MEMORY_ONLY, + Seq.empty) + rdd1.memSize = 100 + rdd1.numCachedPartitions = 10 + + val rdd2 = new RDDInfo(2, + "rdd2", + 10, + StorageLevel.DISK_ONLY, + Seq.empty) + rdd2.diskSize = 200 + rdd2.numCachedPartitions = 5 + + val rdd3 = new RDDInfo(3, + "rdd3", + 10, + StorageLevel.MEMORY_AND_DISK_SER, + Seq.empty) + rdd3.memSize = 400 + rdd3.diskSize = 500 + rdd3.numCachedPartitions = 10 + + val xmlNodes = storagePage.rddTable(Seq(rdd1, rdd2, rdd3)) + + val headers = Seq( + "RDD Name", + "Storage Level", + "Cached Partitions", + "Fraction Cached", + "Size in Memory", + "Size in ExternalBlockStore", + "Size on Disk") + assert((xmlNodes \\ "th").map(_.text) === headers) + + assert((xmlNodes \\ "tr").size === 3) + assert(((xmlNodes \\ "tr")(0) \\ "td").map(_.text.trim) === + Seq("rdd1", "Memory Deserialized 1x Replicated", "10", "100%", "100.0 B", "0.0 B", "0.0 B")) + // Check the url + assert(((xmlNodes \\ "tr")(0) \\ "td" \ "a")(0).attribute("href").map(_.text) === + Some("http://localhost:4040/storage/rdd?id=1")) + + assert(((xmlNodes \\ "tr")(1) \\ "td").map(_.text.trim) === + Seq("rdd2", "Disk Serialized 1x Replicated", "5", "50%", "0.0 B", "0.0 B", "200.0 B")) + // Check the url + assert(((xmlNodes \\ "tr")(1) \\ "td" \ "a")(0).attribute("href").map(_.text) === + Some("http://localhost:4040/storage/rdd?id=2")) + + assert(((xmlNodes \\ "tr")(2) \\ "td").map(_.text.trim) === + Seq("rdd3", "Disk Memory Serialized 1x Replicated", "10", "100%", "400.0 B", "0.0 B", + "500.0 B")) + // Check the url + assert(((xmlNodes \\ "tr")(2) \\ "td" \ "a")(0).attribute("href").map(_.text) === + Some("http://localhost:4040/storage/rdd?id=3")) + } + + test("empty rddTable") { + assert(storagePage.rddTable(Seq.empty).isEmpty) + } + + test("streamBlockStorageLevelDescriptionAndSize") { + val memoryBlock = BlockUIData(StreamBlockId(0, 0), + "localhost:1111", + StorageLevel.MEMORY_ONLY, + memSize = 100, + diskSize = 0, + externalBlockStoreSize = 0) + assert(("Memory", 100) === 
storagePage.streamBlockStorageLevelDescriptionAndSize(memoryBlock)) + + val memorySerializedBlock = BlockUIData(StreamBlockId(0, 0), + "localhost:1111", + StorageLevel.MEMORY_ONLY_SER, + memSize = 100, + diskSize = 0, + externalBlockStoreSize = 0) + assert(("Memory Serialized", 100) === + storagePage.streamBlockStorageLevelDescriptionAndSize(memorySerializedBlock)) + + val diskBlock = BlockUIData(StreamBlockId(0, 0), + "localhost:1111", + StorageLevel.DISK_ONLY, + memSize = 0, + diskSize = 100, + externalBlockStoreSize = 0) + assert(("Disk", 100) === storagePage.streamBlockStorageLevelDescriptionAndSize(diskBlock)) + + val externalBlock = BlockUIData(StreamBlockId(0, 0), + "localhost:1111", + StorageLevel.OFF_HEAP, + memSize = 0, + diskSize = 0, + externalBlockStoreSize = 100) + assert(("External", 100) === + storagePage.streamBlockStorageLevelDescriptionAndSize(externalBlock)) + } + + test("receiverBlockTables") { + val blocksForExecutor0 = Seq( + BlockUIData(StreamBlockId(0, 0), + "localhost:10000", + StorageLevel.MEMORY_ONLY, + memSize = 100, + diskSize = 0, + externalBlockStoreSize = 0), + BlockUIData(StreamBlockId(1, 1), + "localhost:10000", + StorageLevel.DISK_ONLY, + memSize = 0, + diskSize = 100, + externalBlockStoreSize = 0) + ) + val executor0 = ExecutorStreamBlockStatus("0", "localhost:10000", blocksForExecutor0) + + val blocksForExecutor1 = Seq( + BlockUIData(StreamBlockId(0, 0), + "localhost:10001", + StorageLevel.MEMORY_ONLY, + memSize = 100, + diskSize = 0, + externalBlockStoreSize = 0), + BlockUIData(StreamBlockId(2, 2), + "localhost:10001", + StorageLevel.OFF_HEAP, + memSize = 0, + diskSize = 0, + externalBlockStoreSize = 200), + BlockUIData(StreamBlockId(1, 1), + "localhost:10001", + StorageLevel.MEMORY_ONLY_SER, + memSize = 100, + diskSize = 0, + externalBlockStoreSize = 0) + ) + val executor1 = ExecutorStreamBlockStatus("1", "localhost:10001", blocksForExecutor1) + val xmlNodes = storagePage.receiverBlockTables(Seq(executor0, executor1)) + + val executorTable = (xmlNodes \\ "table")(0) + val executorHeaders = Seq( + "Executor ID", + "Address", + "Total Size in Memory", + "Total Size in ExternalBlockStore", + "Total Size on Disk", + "Stream Blocks") + assert((executorTable \\ "th").map(_.text) === executorHeaders) + + assert((executorTable \\ "tr").size === 2) + assert(((executorTable \\ "tr")(0) \\ "td").map(_.text.trim) === + Seq("0", "localhost:10000", "100.0 B", "0.0 B", "100.0 B", "2")) + assert(((executorTable \\ "tr")(1) \\ "td").map(_.text.trim) === + Seq("1", "localhost:10001", "200.0 B", "200.0 B", "0.0 B", "3")) + + val blockTable = (xmlNodes \\ "table")(1) + val blockHeaders = Seq( + "Block ID", + "Replication Level", + "Location", + "Storage Level", + "Size") + assert((blockTable \\ "th").map(_.text) === blockHeaders) + + assert((blockTable \\ "tr").size === 5) + assert(((blockTable \\ "tr")(0) \\ "td").map(_.text.trim) === + Seq("input-0-0", "2", "localhost:10000", "Memory", "100.0 B")) + // Check "rowspan=2" for the first 2 columns + assert(((blockTable \\ "tr")(0) \\ "td")(0).attribute("rowspan").map(_.text) === Some("2")) + assert(((blockTable \\ "tr")(0) \\ "td")(1).attribute("rowspan").map(_.text) === Some("2")) + + assert(((blockTable \\ "tr")(1) \\ "td").map(_.text.trim) === + Seq("localhost:10001", "Memory", "100.0 B")) + + assert(((blockTable \\ "tr")(2) \\ "td").map(_.text.trim) === + Seq("input-1-1", "2", "localhost:10000", "Disk", "100.0 B")) + // Check "rowspan=2" for the first 2 columns + assert(((blockTable \\ "tr")(2) \\ 
"td")(0).attribute("rowspan").map(_.text) === Some("2")) + assert(((blockTable \\ "tr")(2) \\ "td")(1).attribute("rowspan").map(_.text) === Some("2")) + + assert(((blockTable \\ "tr")(3) \\ "td").map(_.text.trim) === + Seq("localhost:10001", "Memory Serialized", "100.0 B")) + + assert(((blockTable \\ "tr")(4) \\ "td").map(_.text.trim) === + Seq("input-2-2", "1", "localhost:10001", "External", "200.0 B")) + // Check "rowspan=1" for the first 2 columns + assert(((blockTable \\ "tr")(4) \\ "td")(0).attribute("rowspan").map(_.text) === Some("1")) + assert(((blockTable \\ "tr")(4) \\ "td")(1).attribute("rowspan").map(_.text) === Some("1")) + } + + test("empty receiverBlockTables") { + assert(storagePage.receiverBlockTables(Seq.empty).isEmpty) + + val executor0 = ExecutorStreamBlockStatus("0", "localhost:10000", Seq.empty) + val executor1 = ExecutorStreamBlockStatus("1", "localhost:10001", Seq.empty) + assert(storagePage.receiverBlockTables(Seq(executor0, executor1)).isEmpty) + } +} diff --git a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala index 6c40685484ed4..61601016e005e 100644 --- a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.util +import scala.collection.mutable.ArrayBuffer + import java.util.concurrent.TimeoutException import akka.actor.ActorNotFound @@ -24,7 +26,7 @@ import akka.actor.ActorNotFound import org.apache.spark._ import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.MapStatus -import org.apache.spark.storage.BlockManagerId +import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId} import org.apache.spark.SSLSampleConfigs._ @@ -107,8 +109,9 @@ class AkkaUtilsSuite extends SparkFunSuite with LocalSparkContext with ResetSyst slaveTracker.updateEpoch(masterTracker.getEpoch) // this should succeed since security off - assert(slaveTracker.getServerStatuses(10, 0).toSeq === - Seq((BlockManagerId("a", "hostA", 1000), size1000))) + assert(slaveTracker.getMapSizesByExecutorId(10, 0).toSeq === + Seq((BlockManagerId("a", "hostA", 1000), + ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) rpcEnv.shutdown() slaveRpcEnv.shutdown() @@ -153,8 +156,9 @@ class AkkaUtilsSuite extends SparkFunSuite with LocalSparkContext with ResetSyst slaveTracker.updateEpoch(masterTracker.getEpoch) // this should succeed since security on and passwords match - assert(slaveTracker.getServerStatuses(10, 0).toSeq === - Seq((BlockManagerId("a", "hostA", 1000), size1000))) + assert(slaveTracker.getMapSizesByExecutorId(10, 0) === + Seq((BlockManagerId("a", "hostA", 1000), + ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) rpcEnv.shutdown() slaveRpcEnv.shutdown() @@ -232,8 +236,8 @@ class AkkaUtilsSuite extends SparkFunSuite with LocalSparkContext with ResetSyst slaveTracker.updateEpoch(masterTracker.getEpoch) // this should succeed since security off - assert(slaveTracker.getServerStatuses(10, 0).toSeq === - Seq((BlockManagerId("a", "hostA", 1000), size1000))) + assert(slaveTracker.getMapSizesByExecutorId(10, 0) === + Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) rpcEnv.shutdown() slaveRpcEnv.shutdown() @@ -278,8 +282,8 @@ class AkkaUtilsSuite extends SparkFunSuite with LocalSparkContext with ResetSyst masterTracker.incrementEpoch() slaveTracker.updateEpoch(masterTracker.getEpoch) - assert(slaveTracker.getServerStatuses(10, 0).toSeq 
=== - Seq((BlockManagerId("a", "hostA", 1000), size1000))) + assert(slaveTracker.getMapSizesByExecutorId(10, 0) === + Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) rpcEnv.shutdown() slaveRpcEnv.shutdown() diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala index 3147c937769d2..a829b099025e9 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala @@ -120,8 +120,8 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri // Accessors for private methods private val _isClosure = PrivateMethod[Boolean]('isClosure) private val _getInnerClosureClasses = PrivateMethod[List[Class[_]]]('getInnerClosureClasses) - private val _getOuterClasses = PrivateMethod[List[Class[_]]]('getOuterClasses) - private val _getOuterObjects = PrivateMethod[List[AnyRef]]('getOuterObjects) + private val _getOuterClassesAndObjects = + PrivateMethod[(List[Class[_]], List[AnyRef])]('getOuterClassesAndObjects) private def isClosure(obj: AnyRef): Boolean = { ClosureCleaner invokePrivate _isClosure(obj) @@ -131,12 +131,8 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri ClosureCleaner invokePrivate _getInnerClosureClasses(closure) } - private def getOuterClasses(closure: AnyRef): List[Class[_]] = { - ClosureCleaner invokePrivate _getOuterClasses(closure) - } - - private def getOuterObjects(closure: AnyRef): List[AnyRef] = { - ClosureCleaner invokePrivate _getOuterObjects(closure) + private def getOuterClassesAndObjects(closure: AnyRef): (List[Class[_]], List[AnyRef]) = { + ClosureCleaner invokePrivate _getOuterClassesAndObjects(closure) } test("get inner closure classes") { @@ -171,14 +167,11 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri val closure2 = () => localValue val closure3 = () => someSerializableValue val closure4 = () => someSerializableMethod() - val outerClasses1 = getOuterClasses(closure1) - val outerClasses2 = getOuterClasses(closure2) - val outerClasses3 = getOuterClasses(closure3) - val outerClasses4 = getOuterClasses(closure4) - val outerObjects1 = getOuterObjects(closure1) - val outerObjects2 = getOuterObjects(closure2) - val outerObjects3 = getOuterObjects(closure3) - val outerObjects4 = getOuterObjects(closure4) + + val (outerClasses1, outerObjects1) = getOuterClassesAndObjects(closure1) + val (outerClasses2, outerObjects2) = getOuterClassesAndObjects(closure2) + val (outerClasses3, outerObjects3) = getOuterClassesAndObjects(closure3) + val (outerClasses4, outerObjects4) = getOuterClassesAndObjects(closure4) // The classes and objects should have the same size assert(outerClasses1.size === outerObjects1.size) @@ -211,10 +204,8 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri val x = 1 val closure1 = () => 1 val closure2 = () => x - val outerClasses1 = getOuterClasses(closure1) - val outerClasses2 = getOuterClasses(closure2) - val outerObjects1 = getOuterObjects(closure1) - val outerObjects2 = getOuterObjects(closure2) + val (outerClasses1, outerObjects1) = getOuterClassesAndObjects(closure1) + val (outerClasses2, outerObjects2) = getOuterClassesAndObjects(closure2) assert(outerClasses1.size === outerObjects1.size) assert(outerClasses2.size === outerObjects2.size) // These inner closures only reference local variables, and so do 
not have $outer pointers @@ -227,12 +218,9 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri val closure1 = () => 1 val closure2 = () => y val closure3 = () => localValue - val outerClasses1 = getOuterClasses(closure1) - val outerClasses2 = getOuterClasses(closure2) - val outerClasses3 = getOuterClasses(closure3) - val outerObjects1 = getOuterObjects(closure1) - val outerObjects2 = getOuterObjects(closure2) - val outerObjects3 = getOuterObjects(closure3) + val (outerClasses1, outerObjects1) = getOuterClassesAndObjects(closure1) + val (outerClasses2, outerObjects2) = getOuterClassesAndObjects(closure2) + val (outerClasses3, outerObjects3) = getOuterClassesAndObjects(closure3) assert(outerClasses1.size === outerObjects1.size) assert(outerClasses2.size === outerObjects2.size) assert(outerClasses3.size === outerObjects3.size) @@ -265,9 +253,9 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri val closure1 = () => 1 val closure2 = () => localValue val closure3 = () => someSerializableValue - val outerClasses1 = getOuterClasses(closure1) - val outerClasses2 = getOuterClasses(closure2) - val outerClasses3 = getOuterClasses(closure3) + val (outerClasses1, _) = getOuterClassesAndObjects(closure1) + val (outerClasses2, _) = getOuterClassesAndObjects(closure2) + val (outerClasses3, _) = getOuterClassesAndObjects(closure3) val fields1 = findAccessedFields(closure1, outerClasses1, findTransitively = false) val fields2 = findAccessedFields(closure2, outerClasses2, findTransitively = false) @@ -307,10 +295,10 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri val closure2 = () => a val closure3 = () => localValue val closure4 = () => someSerializableValue - val outerClasses1 = getOuterClasses(closure1) - val outerClasses2 = getOuterClasses(closure2) - val outerClasses3 = getOuterClasses(closure3) - val outerClasses4 = getOuterClasses(closure4) + val (outerClasses1, _) = getOuterClassesAndObjects(closure1) + val (outerClasses2, _) = getOuterClassesAndObjects(closure2) + val (outerClasses3, _) = getOuterClassesAndObjects(closure3) + val (outerClasses4, _) = getOuterClassesAndObjects(closure4) // First, find only fields accessed directly, not transitively, by these closures val fields1 = findAccessedFields(closure1, outerClasses1, findTransitively = false) diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index e0ef9c70a5fc3..dde95f3778434 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -83,6 +83,9 @@ class JsonProtocolSuite extends SparkFunSuite { val executorAdded = SparkListenerExecutorAdded(executorAddedTime, "exec1", new ExecutorInfo("Hostee.awesome.com", 11, logUrlMap)) val executorRemoved = SparkListenerExecutorRemoved(executorRemovedTime, "exec2", "test reason") + val executorMetricsUpdate = SparkListenerExecutorMetricsUpdate("exec3", Seq( + (1L, 2, 3, makeTaskMetrics(300L, 400L, 500L, 600L, 700, 800, + hasHadoopInput = true, hasOutput = true)))) testEvent(stageSubmitted, stageSubmittedJsonString) testEvent(stageCompleted, stageCompletedJsonString) @@ -102,6 +105,7 @@ class JsonProtocolSuite extends SparkFunSuite { testEvent(applicationEnd, applicationEndJsonString) testEvent(executorAdded, executorAddedJsonString) testEvent(executorRemoved, executorRemovedJsonString) + testEvent(executorMetricsUpdate, 
executorMetricsUpdateJsonString) } test("Dependent Classes") { @@ -440,10 +444,20 @@ class JsonProtocolSuite extends SparkFunSuite { case (e1: SparkListenerEnvironmentUpdate, e2: SparkListenerEnvironmentUpdate) => assertEquals(e1.environmentDetails, e2.environmentDetails) case (e1: SparkListenerExecutorAdded, e2: SparkListenerExecutorAdded) => - assert(e1.executorId == e1.executorId) + assert(e1.executorId === e1.executorId) assertEquals(e1.executorInfo, e2.executorInfo) case (e1: SparkListenerExecutorRemoved, e2: SparkListenerExecutorRemoved) => - assert(e1.executorId == e1.executorId) + assert(e1.executorId === e1.executorId) + case (e1: SparkListenerExecutorMetricsUpdate, e2: SparkListenerExecutorMetricsUpdate) => + assert(e1.execId === e2.execId) + assertSeqEquals[(Long, Int, Int, TaskMetrics)](e1.taskMetrics, e2.taskMetrics, (a, b) => { + val (taskId1, stageId1, stageAttemptId1, metrics1) = a + val (taskId2, stageId2, stageAttemptId2, metrics2) = b + assert(taskId1 === taskId2) + assert(stageId1 === stageId2) + assert(stageAttemptId1 === stageAttemptId2) + assertEquals(metrics1, metrics2) + }) case (e1, e2) => assert(e1 === e2) case _ => fail("Events don't match in types!") @@ -1598,4 +1612,55 @@ class JsonProtocolSuite extends SparkFunSuite { | "Removed Reason": "test reason" |} """ + + private val executorMetricsUpdateJsonString = + s""" + |{ + | "Event": "SparkListenerExecutorMetricsUpdate", + | "Executor ID": "exec3", + | "Metrics Updated": [ + | { + | "Task ID": 1, + | "Stage ID": 2, + | "Stage Attempt ID": 3, + | "Task Metrics": { + | "Host Name": "localhost", + | "Executor Deserialize Time": 300, + | "Executor Run Time": 400, + | "Result Size": 500, + | "JVM GC Time": 600, + | "Result Serialization Time": 700, + | "Memory Bytes Spilled": 800, + | "Disk Bytes Spilled": 0, + | "Input Metrics": { + | "Data Read Method": "Hadoop", + | "Bytes Read": 2100, + | "Records Read": 21 + | }, + | "Output Metrics": { + | "Data Write Method": "Hadoop", + | "Bytes Written": 1200, + | "Records Written": 12 + | }, + | "Updated Blocks": [ + | { + | "Block ID": "rdd_0_0", + | "Status": { + | "Storage Level": { + | "Use Disk": true, + | "Use Memory": true, + | "Use ExternalBlockStore": false, + | "Deserialized": false, + | "Replication": 2 + | }, + | "Memory Size": 0, + | "ExternalBlockStore Size": 0, + | "Disk Size": 0 + | } + | } + | ] + | } + | }] + |} + """.stripMargin } diff --git a/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala b/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala index 42125547436cb..d3d464e84ffd7 100644 --- a/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala @@ -84,7 +84,9 @@ class MutableURLClassLoaderSuite extends SparkFunSuite { try { sc.makeRDD(1 to 5, 2).mapPartitions { x => val loader = Thread.currentThread().getContextClassLoader + // scalastyle:off classforname Class.forName(className, true, loader).newInstance() + // scalastyle:on classforname Seq().iterator }.count() } diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index c7638507c88c6..8f7e402d5f2a6 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.util import java.io.{File, ByteArrayOutputStream, ByteArrayInputStream, FileOutputStream} +import 
java.lang.{Double => JDouble, Float => JFloat} import java.net.{BindException, ServerSocket, URI} import java.nio.{ByteBuffer, ByteOrder} import java.text.DecimalFormatSymbols @@ -689,4 +690,34 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { // scalastyle:on println assert(buffer.toString === "t circular test circular\n") } + + test("nanSafeCompareDoubles") { + def shouldMatchDefaultOrder(a: Double, b: Double): Unit = { + assert(Utils.nanSafeCompareDoubles(a, b) === JDouble.compare(a, b)) + assert(Utils.nanSafeCompareDoubles(b, a) === JDouble.compare(b, a)) + } + shouldMatchDefaultOrder(0d, 0d) + shouldMatchDefaultOrder(0d, 1d) + shouldMatchDefaultOrder(Double.MinValue, Double.MaxValue) + assert(Utils.nanSafeCompareDoubles(Double.NaN, Double.NaN) === 0) + assert(Utils.nanSafeCompareDoubles(Double.NaN, Double.PositiveInfinity) === 1) + assert(Utils.nanSafeCompareDoubles(Double.NaN, Double.NegativeInfinity) === 1) + assert(Utils.nanSafeCompareDoubles(Double.PositiveInfinity, Double.NaN) === -1) + assert(Utils.nanSafeCompareDoubles(Double.NegativeInfinity, Double.NaN) === -1) + } + + test("nanSafeCompareFloats") { + def shouldMatchDefaultOrder(a: Float, b: Float): Unit = { + assert(Utils.nanSafeCompareFloats(a, b) === JFloat.compare(a, b)) + assert(Utils.nanSafeCompareFloats(b, a) === JFloat.compare(b, a)) + } + shouldMatchDefaultOrder(0f, 0f) + shouldMatchDefaultOrder(1f, 1f) + shouldMatchDefaultOrder(Float.MinValue, Float.MaxValue) + assert(Utils.nanSafeCompareFloats(Float.NaN, Float.NaN) === 0) + assert(Utils.nanSafeCompareFloats(Float.NaN, Float.PositiveInfinity) === 1) + assert(Utils.nanSafeCompareFloats(Float.NaN, Float.NegativeInfinity) === 1) + assert(Utils.nanSafeCompareFloats(Float.PositiveInfinity, Float.NaN) === -1) + assert(Utils.nanSafeCompareFloats(Float.NegativeInfinity, Float.NaN) === -1) + } } diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala index 79eba61a87251..9c362f0de7076 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala @@ -244,7 +244,7 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { private def testSimpleSpilling(codec: Option[String] = None): Unit = { val conf = createSparkConf(loadDefaults = true, codec) // Load defaults for Spark home conf.set("spark.shuffle.memoryFraction", "0.001") - sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) // reduceByKey - should spill ~8 times val rddA = sc.parallelize(0 until 100000).map(i => (i/2, i)) @@ -292,7 +292,7 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { test("spilling with hash collisions") { val conf = createSparkConf(loadDefaults = true) conf.set("spark.shuffle.memoryFraction", "0.001") - sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) val map = createExternalMap[String] val collisionPairs = Seq( @@ -341,7 +341,7 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { test("spilling with many hash collisions") { val conf = createSparkConf(loadDefaults = true) conf.set("spark.shuffle.memoryFraction", "0.0001") - sc = new SparkContext("local-cluster[1,1,512]", 
"test", conf) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) val map = new ExternalAppendOnlyMap[FixedHashObject, Int, Int](_ => 1, _ + _, _ + _) // Insert 10 copies each of lots of objects whose hash codes are either 0 or 1. This causes @@ -366,7 +366,7 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { test("spilling with hash collisions using the Int.MaxValue key") { val conf = createSparkConf(loadDefaults = true) conf.set("spark.shuffle.memoryFraction", "0.001") - sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) val map = createExternalMap[Int] (1 to 100000).foreach { i => map.insert(i, i) } @@ -383,7 +383,7 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { test("spilling with null keys and values") { val conf = createSparkConf(loadDefaults = true) conf.set("spark.shuffle.memoryFraction", "0.001") - sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) val map = createExternalMap[Int] map.insertAll((1 to 100000).iterator.map(i => (i, i))) diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala index 9cefa612f5491..986cd8623d145 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -176,7 +176,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { def testSpillingInLocalCluster(conf: SparkConf) { conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") - sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) // reduceByKey - should spill ~8 times val rddA = sc.parallelize(0 until 100000).map(i => (i/2, i)) @@ -254,7 +254,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { def spillingInLocalClusterWithManyReduceTasks(conf: SparkConf) { conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") - sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) // reduceByKey - should spill ~4 times per executor val rddA = sc.parallelize(0 until 100000).map(i => (i/2, i)) @@ -554,7 +554,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { test("spilling with hash collisions") { val conf = createSparkConf(true, false) conf.set("spark.shuffle.memoryFraction", "0.001") - sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) def createCombiner(i: String): ArrayBuffer[String] = ArrayBuffer[String](i) def mergeValue(buffer: ArrayBuffer[String], i: String): ArrayBuffer[String] = buffer += i @@ -611,7 +611,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { test("spilling with many hash collisions") { val conf = createSparkConf(true, false) conf.set("spark.shuffle.memoryFraction", "0.0001") - sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) val agg = new Aggregator[FixedHashObject, Int, Int](_ => 
1, _ + _, _ + _) val sorter = new ExternalSorter[FixedHashObject, Int, Int](Some(agg), None, None, None) @@ -634,7 +634,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { test("spilling with hash collisions using the Int.MaxValue key") { val conf = createSparkConf(true, false) conf.set("spark.shuffle.memoryFraction", "0.001") - sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) def createCombiner(i: Int): ArrayBuffer[Int] = ArrayBuffer[Int](i) def mergeValue(buffer: ArrayBuffer[Int], i: Int): ArrayBuffer[Int] = buffer += i @@ -658,7 +658,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { test("spilling with null keys and values") { val conf = createSparkConf(true, false) conf.set("spark.shuffle.memoryFraction", "0.001") - sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) def createCombiner(i: String): ArrayBuffer[String] = ArrayBuffer[String](i) def mergeValue(buffer: ArrayBuffer[String], i: String): ArrayBuffer[String] = buffer += i @@ -695,7 +695,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { def sortWithoutBreakingSortingContracts(conf: SparkConf) { conf.set("spark.shuffle.memoryFraction", "0.01") conf.set("spark.shuffle.manager", "sort") - sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) // Using wrongOrdering to show integer overflow introduced exception. val rand = new Random(100L) diff --git a/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala index 6d2459d48d326..3b67f6206495a 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala @@ -17,15 +17,20 @@ package org.apache.spark.util.collection -import java.io.{ByteArrayInputStream, ByteArrayOutputStream, InputStream} +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import com.google.common.io.ByteStreams +import org.mockito.Matchers.any +import org.mockito.Mockito._ +import org.mockito.Mockito.RETURNS_SMART_NULLS +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer import org.scalatest.Matchers._ import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.KryoSerializer -import org.apache.spark.storage.{FileSegment, BlockObjectWriter} +import org.apache.spark.storage.DiskBlockObjectWriter class PartitionedSerializedPairBufferSuite extends SparkFunSuite { test("OrderedInputStream single record") { @@ -79,13 +84,13 @@ class PartitionedSerializedPairBufferSuite extends SparkFunSuite { val struct = SomeStruct("something", 5) buffer.insert(4, 10, struct) val it = buffer.destructiveSortedWritablePartitionedIterator(None) - val writer = new SimpleBlockObjectWriter + val (writer, baos) = createMockWriter() assert(it.hasNext) it.nextPartition should be (4) it.writeNext(writer) assert(!it.hasNext) - val stream = serializerInstance.deserializeStream(writer.getInputStream) + val stream = serializerInstance.deserializeStream(new ByteArrayInputStream(baos.toByteArray)) stream.readObject[AnyRef]() should be (10) stream.readObject[AnyRef]() should be (struct) } @@ -101,7 +106,7 
@@ class PartitionedSerializedPairBufferSuite extends SparkFunSuite { buffer.insert(5, 3, struct3) val it = buffer.destructiveSortedWritablePartitionedIterator(None) - val writer = new SimpleBlockObjectWriter + val (writer, baos) = createMockWriter() assert(it.hasNext) it.nextPartition should be (4) it.writeNext(writer) @@ -113,7 +118,7 @@ class PartitionedSerializedPairBufferSuite extends SparkFunSuite { it.writeNext(writer) assert(!it.hasNext) - val stream = serializerInstance.deserializeStream(writer.getInputStream) + val stream = serializerInstance.deserializeStream(new ByteArrayInputStream(baos.toByteArray)) val iter = stream.asIterator iter.next() should be (2) iter.next() should be (struct2) @@ -123,26 +128,21 @@ class PartitionedSerializedPairBufferSuite extends SparkFunSuite { iter.next() should be (struct1) assert(!iter.hasNext) } -} - -case class SomeStruct(val str: String, val num: Int) - -class SimpleBlockObjectWriter extends BlockObjectWriter(null) { - val baos = new ByteArrayOutputStream() - override def write(bytes: Array[Byte], offs: Int, len: Int): Unit = { - baos.write(bytes, offs, len) + def createMockWriter(): (DiskBlockObjectWriter, ByteArrayOutputStream) = { + val writer = mock(classOf[DiskBlockObjectWriter], RETURNS_SMART_NULLS) + val baos = new ByteArrayOutputStream() + when(writer.write(any(), any(), any())).thenAnswer(new Answer[Unit] { + override def answer(invocationOnMock: InvocationOnMock): Unit = { + val args = invocationOnMock.getArguments + val bytes = args(0).asInstanceOf[Array[Byte]] + val offset = args(1).asInstanceOf[Int] + val length = args(2).asInstanceOf[Int] + baos.write(bytes, offset, length) + } + }) + (writer, baos) } - - def getInputStream(): InputStream = new ByteArrayInputStream(baos.toByteArray) - - override def open(): BlockObjectWriter = this - override def close(): Unit = { } - override def isOpen: Boolean = true - override def commitAndClose(): Unit = { } - override def revertPartialWritesAndClose(): Unit = { } - override def fileSegment(): FileSegment = null - override def write(key: Any, value: Any): Unit = { } - override def recordWritten(): Unit = { } - override def write(b: Int): Unit = { } } + +case class SomeStruct(str: String, num: Int) diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala index dd505dfa7d758..dc03e374b51db 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala @@ -47,4 +47,29 @@ class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks { forAll (regressionTests) { (s1: String, s2: String) => testPrefixComparison(s1, s2) } forAll { (s1: String, s2: String) => testPrefixComparison(s1, s2) } } + + test("float prefix comparator handles NaN properly") { + val nan1: Float = java.lang.Float.intBitsToFloat(0x7f800001) + val nan2: Float = java.lang.Float.intBitsToFloat(0x7fffffff) + assert(nan1.isNaN) + assert(nan2.isNaN) + val nan1Prefix = PrefixComparators.FLOAT.computePrefix(nan1) + val nan2Prefix = PrefixComparators.FLOAT.computePrefix(nan2) + assert(nan1Prefix === nan2Prefix) + val floatMaxPrefix = PrefixComparators.FLOAT.computePrefix(Float.MaxValue) + assert(PrefixComparators.FLOAT.compare(nan1Prefix, floatMaxPrefix) === 1) + } + + test("double prefix comparator handles NaNs properly") { + val nan1: Double 
= java.lang.Double.longBitsToDouble(0x7ff0000000000001L) + val nan2: Double = java.lang.Double.longBitsToDouble(0x7fffffffffffffffL) + assert(nan1.isNaN) + assert(nan2.isNaN) + val nan1Prefix = PrefixComparators.DOUBLE.computePrefix(nan1) + val nan2Prefix = PrefixComparators.DOUBLE.computePrefix(nan2) + assert(nan1Prefix === nan2Prefix) + val doubleMaxPrefix = PrefixComparators.DOUBLE.computePrefix(Double.MaxValue) + assert(PrefixComparators.DOUBLE.compare(nan1Prefix, doubleMaxPrefix) === 1) + } + } diff --git a/data/mllib/sample_naive_bayes_data.txt b/data/mllib/sample_naive_bayes_data.txt index 981da382d6ac8..bd22bea3a59d6 100644 --- a/data/mllib/sample_naive_bayes_data.txt +++ b/data/mllib/sample_naive_bayes_data.txt @@ -1,6 +1,12 @@ 0,1 0 0 0,2 0 0 +0,3 0 0 +0,4 0 0 1,0 1 0 1,0 2 0 +1,0 3 0 +1,0 4 0 2,0 0 1 2,0 0 2 +2,0 0 3 +2,0 0 4 \ No newline at end of file diff --git a/dev/change-scala-version.sh b/dev/change-scala-version.sh new file mode 100755 index 0000000000000..b81c00c9d6d9d --- /dev/null +++ b/dev/change-scala-version.sh @@ -0,0 +1,66 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +set -e + +usage() { + echo "Usage: $(basename $0) " 1>&2 + exit 1 +} + +if [ $# -ne 1 ]; then + usage +fi + +TO_VERSION=$1 + +VALID_VERSIONS=( 2.10 2.11 ) + +check_scala_version() { + for i in ${VALID_VERSIONS[*]}; do [ $i = "$1" ] && return 0; done + echo "Invalid Scala version: $1. Valid versions: ${VALID_VERSIONS[*]}" 1>&2 + exit 1 +} + +check_scala_version "$TO_VERSION" + +if [ $TO_VERSION = "2.11" ]; then + FROM_VERSION="2.10" +else + FROM_VERSION="2.11" +fi + +sed_i() { + sed -e "$1" "$2" > "$2.tmp" && mv "$2.tmp" "$2" +} + +export -f sed_i + +BASEDIR=$(dirname $0)/.. +find "$BASEDIR" -name 'pom.xml' -not -path '*target*' -print \ + -exec bash -c "sed_i 's/\(artifactId.*\)_'$FROM_VERSION'/\1_'$TO_VERSION'/g' {}" \; + +# Also update in parent POM +# Match any scala binary version to ensure idempotency +sed_i '1,/[0-9]*\.[0-9]*[0-9]*\.[0-9]*'$TO_VERSION' in parent POM -sed -i -e '0,/2.112.10 in parent POM -sed -i -e '0,/2.102.11 "$PYTHON_LINT_REPORT_PATH" +python -B -m compileall -q -l $PATHS_TO_CHECK > "$PEP8_REPORT_PATH" compile_status="${PIPESTATUS[0]}" # Get pep8 at runtime so that we don't rely on it being installed on the build server. @@ -47,11 +49,36 @@ if [ ! -e "$PEP8_SCRIPT_PATH" ]; then fi fi +# Easy install pylint in /dev/pylint. To easy_install into a directory, the PYTHONPATH should +# be set to the directory. +# dev/pylint should be appended to the PATH variable as well. +# Jenkins by default installs the pylint3 version, so for now this just checks the code quality +# of python3. +export "PYTHONPATH=$SPARK_ROOT_DIR/dev/pylint" +export "PYLINT_HOME=$PYTHONPATH" +export "PATH=$PYTHONPATH:$PATH" + +if [ ! 
-d "$PYLINT_HOME" ]; then + mkdir "$PYLINT_HOME" + # Redirect the annoying pylint installation output. + easy_install -d "$PYLINT_HOME" pylint==1.4.4 &>> "$PYLINT_INSTALL_INFO" + easy_install_status="$?" + + if [ "$easy_install_status" -ne 0 ]; then + echo "Unable to install pylint locally in \"$PYTHONPATH\"." + cat "$PYLINT_INSTALL_INFO" + exit "$easy_install_status" + fi + + rm "$PYLINT_INSTALL_INFO" + +fi + # There is no need to write this output to a file #+ first, but we do so so that the check status can #+ be output before the report, like with the #+ scalastyle and RAT checks. -python "$PEP8_SCRIPT_PATH" --ignore=E402,E731,E241,W503,E226 $PATHS_TO_CHECK >> "$PYTHON_LINT_REPORT_PATH" +python "$PEP8_SCRIPT_PATH" --ignore=E402,E731,E241,W503,E226 $PATHS_TO_CHECK >> "$PEP8_REPORT_PATH" pep8_status="${PIPESTATUS[0]}" if [ "$compile_status" -eq 0 -a "$pep8_status" -eq 0 ]; then @@ -61,13 +88,27 @@ else fi if [ "$lint_status" -ne 0 ]; then - echo "Python lint checks failed." - cat "$PYTHON_LINT_REPORT_PATH" + echo "PEP8 checks failed." + cat "$PEP8_REPORT_PATH" +else + echo "PEP8 checks passed." +fi + +rm "$PEP8_REPORT_PATH" + +for to_be_checked in "$PATHS_TO_CHECK" +do + pylint --rcfile="$SPARK_ROOT_DIR/pylintrc" $to_be_checked >> "$PYLINT_REPORT_PATH" +done + +if [ "${PIPESTATUS[0]}" -ne 0 ]; then + lint_status=1 + echo "Pylint checks failed." + cat "$PYLINT_REPORT_PATH" else - echo "Python lint checks passed." + echo "Pylint checks passed." fi -# rm "$PEP8_SCRIPT_PATH" -rm "$PYTHON_LINT_REPORT_PATH" +rm "$PYLINT_REPORT_PATH" exit "$lint_status" diff --git a/dev/lint-r.R b/dev/lint-r.R index dcb1a184291e1..48bd6246096ae 100644 --- a/dev/lint-r.R +++ b/dev/lint-r.R @@ -15,15 +15,21 @@ # limitations under the License. # +argv <- commandArgs(TRUE) +SPARK_ROOT_DIR <- as.character(argv[1]) + # Installs lintr from Github. # NOTE: The CRAN's version is too old to adapt to our rules. if ("lintr" %in% row.names(installed.packages()) == FALSE) { devtools::install_github("jimhester/lintr") } -library(lintr) -argv <- commandArgs(TRUE) -SPARK_ROOT_DIR <- as.character(argv[1]) +library(lintr) +library(methods) +library(testthat) +if (! 
library(SparkR, lib.loc = file.path(SPARK_ROOT_DIR, "R", "lib"), logical.return = TRUE)) { + stop("You should install SparkR in a local directory with `R/install-dev.sh`.") +} path.to.package <- file.path(SPARK_ROOT_DIR, "R", "pkg") lint_package(path.to.package, cache = FALSE) diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index 4a17d48d8171d..ad4b76695c9ff 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -130,7 +130,12 @@ def merge_pr(pr_num, target_ref, title, body, pr_repo_desc): '--pretty=format:%an <%ae>']).split("\n") distinct_authors = sorted(set(commit_authors), key=lambda x: commit_authors.count(x), reverse=True) - primary_author = distinct_authors[0] + primary_author = raw_input( + "Enter primary author in the format of \"name \" [%s]: " % + distinct_authors[0]) + if primary_author == "": + primary_author = distinct_authors[0] + commits = run_cmd(['git', 'log', 'HEAD..%s' % pr_branch_name, '--pretty=format:%h [%an] %s']).split("\n\n") @@ -281,7 +286,7 @@ def get_version_json(version_str): resolve = filter(lambda a: a['name'] == "Resolve Issue", asf_jira.transitions(jira_id))[0] resolution = filter(lambda r: r.raw['name'] == "Fixed", asf_jira.resolutions())[0] asf_jira.transition_issue( - jira_id, resolve["id"], fixVersions = jira_fix_versions, + jira_id, resolve["id"], fixVersions = jira_fix_versions, comment = comment, resolution = {'id': resolution.raw['id']}) print "Successfully resolved %s with fixVersions=%s!" % (jira_id, fix_versions) @@ -300,7 +305,7 @@ def standardize_jira_ref(text): """ Standardize the [SPARK-XXXXX] [MODULE] prefix Converts "[SPARK-XXX][mllib] Issue", "[MLLib] SPARK-XXX. Issue" or "SPARK XXX [MLLIB]: Issue" to "[SPARK-XXX] [MLLIB] Issue" - + >>> standardize_jira_ref("[SPARK-5821] [SQL] ParquetRelation2 CTAS should check if delete is successful") '[SPARK-5821] [SQL] ParquetRelation2 CTAS should check if delete is successful' >>> standardize_jira_ref("[SPARK-4123][Project Infra][WIP]: Show new dependencies added in pull requests") @@ -322,11 +327,11 @@ def standardize_jira_ref(text): """ jira_refs = [] components = [] - + # If the string is compliant, no need to process any further if (re.search(r'^\[SPARK-[0-9]{3,6}\] (\[[A-Z0-9_\s,]+\] )+\S+', text)): return text - + # Extract JIRA ref(s): pattern = re.compile(r'(SPARK[-\s]*[0-9]{3,6})+', re.IGNORECASE) for ref in pattern.findall(text): @@ -348,18 +353,18 @@ def standardize_jira_ref(text): # Assemble full text (JIRA ref(s), module(s), remaining text) clean_text = ' '.join(jira_refs).strip() + " " + ' '.join(components).strip() + " " + text.strip() - + # Replace multiple spaces with a single space, e.g. 
if no jira refs and/or components were included clean_text = re.sub(r'\s+', ' ', clean_text.strip()) - + return clean_text def main(): global original_head - + os.chdir(SPARK_HOME) original_head = run_cmd("git rev-parse HEAD")[:8] - + branches = get_json("%s/branches" % GITHUB_API_BASE) branch_names = filter(lambda x: x.startswith("branch-"), [x['name'] for x in branches]) # Assumes branch names can be sorted lexicographically @@ -448,5 +453,5 @@ def main(): (failure_count, test_count) = doctest.testmod() if failure_count: exit(-1) - + main() diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 993583e2f4119..3073d489bad4a 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -338,6 +338,7 @@ def contains_file(self, filename): python_test_goals=[ "pyspark.ml.feature", "pyspark.ml.classification", + "pyspark.ml.clustering", "pyspark.ml.recommendation", "pyspark.ml.regression", "pyspark.ml.tuning", diff --git a/docker/spark-test/base/Dockerfile b/docker/spark-test/base/Dockerfile index 5956d59130fbf..5dbdb8b22a44f 100644 --- a/docker/spark-test/base/Dockerfile +++ b/docker/spark-test/base/Dockerfile @@ -17,13 +17,13 @@ FROM ubuntu:precise -RUN echo "deb http://archive.ubuntu.com/ubuntu precise main universe" > /etc/apt/sources.list - # Upgrade package index -RUN apt-get update - # install a few other useful packages plus Open Jdk 7 -RUN apt-get install -y less openjdk-7-jre-headless net-tools vim-tiny sudo openssh-server +# Remove unneeded /var/lib/apt/lists/* after install to reduce the +# docker image size (by ~30MB) +RUN apt-get update && \ + apt-get install -y less openjdk-7-jre-headless net-tools vim-tiny sudo openssh-server && \ + rm -rf /var/lib/apt/lists/* ENV SCALA_VERSION 2.10.4 ENV CDH_VERSION cdh4 diff --git a/docs/building-spark.md b/docs/building-spark.md index 2128fdffecc05..a5da3b39502e2 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -124,7 +124,7 @@ mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -Phive -Phive-thriftserver -Dskip # Building for Scala 2.11 To produce a Spark package compiled with Scala 2.11, use the `-Dscala-2.11` property: - dev/change-version-to-2.11.sh + dev/change-scala-version.sh 2.11 mvn -Pyarn -Phadoop-2.4 -Dscala-2.11 -DskipTests clean package Spark does not yet support its JDBC component for Scala 2.11. diff --git a/docs/configuration.md b/docs/configuration.md index 443322e1eadf4..200f3cd212e46 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -31,7 +31,6 @@ which can help detect bugs that only exist when we run in a distributed context. val conf = new SparkConf() .setMaster("local[2]") .setAppName("CountingSheep") - .set("spark.executor.memory", "1g") val sc = new SparkContext(conf) {% endhighlight %} @@ -84,7 +83,7 @@ Running `./bin/spark-submit --help` will show the entire list of these options. each line consists of a key and a value separated by whitespace. For example: spark.master spark://5.6.7.8:7077 - spark.executor.memory 512m + spark.executor.memory 4g spark.eventLog.enabled true spark.serializer org.apache.spark.serializer.KryoSerializer @@ -150,10 +149,9 @@ of the most common options to set are: spark.executor.memory - 512m + 1g - Amount of memory to use per executor process, in the same format as JVM memory strings - (e.g. 512m, 2g). + Amount of memory to use per executor process (e.g. 2g, 8g). @@ -665,7 +663,7 @@ Apart from these, the following properties are also available, and may be useful Initial size of Kryo's serialization buffer. 
Note that there will be one buffer per core on each worker. This buffer will grow up to - spark.kryoserializer.buffer.max.mb if needed. + spark.kryoserializer.buffer.max if needed. @@ -886,11 +884,11 @@ Apart from these, the following properties are also available, and may be useful spark.akka.frameSize - 10 + 128 - Maximum message size to allow in "control plane" communication (for serialized tasks and task - results), in MB. Increase this if your tasks need to send back large results to the driver - (e.g. using collect() on a large dataset). + Maximum message size to allow in "control plane" communication; generally only applies to map + output size information sent between executors and the driver. Increase this if you are running + jobs with many thousands of map and reduce tasks and see messages about the frame size. @@ -1050,15 +1048,6 @@ Apart from these, the following properties are also available, and may be useful infinite (all available cores) on Mesos. - - spark.localExecution.enabled - false - - Enables Spark to run certain jobs, such as first() or take() on the driver, without sending - tasks to the cluster. This can make certain jobs execute very quickly, but may require - shipping a whole partition of data to the driver. - - spark.locality.wait 3s diff --git a/docs/ml-guide.md b/docs/ml-guide.md index c74cb1f1ef8ea..8c46adf256a9a 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -3,6 +3,24 @@ layout: global title: Spark ML Programming Guide --- +`\[ +\newcommand{\R}{\mathbb{R}} +\newcommand{\E}{\mathbb{E}} +\newcommand{\x}{\mathbf{x}} +\newcommand{\y}{\mathbf{y}} +\newcommand{\wv}{\mathbf{w}} +\newcommand{\av}{\mathbf{\alpha}} +\newcommand{\bv}{\mathbf{b}} +\newcommand{\N}{\mathbb{N}} +\newcommand{\id}{\mathbf{I}} +\newcommand{\ind}{\mathbf{1}} +\newcommand{\0}{\mathbf{0}} +\newcommand{\unit}{\mathbf{e}} +\newcommand{\one}{\mathbf{1}} +\newcommand{\zero}{\mathbf{0}} +\]` + + Spark 1.2 introduced a new package called `spark.ml`, which aims to provide a uniform set of high-level APIs that help users create and tune practical machine learning pipelines. @@ -154,6 +172,19 @@ Parameters belong to specific instances of `Estimator`s and `Transformer`s. For example, if we have two `LogisticRegression` instances `lr1` and `lr2`, then we can build a `ParamMap` with both `maxIter` parameters specified: `ParamMap(lr1.maxIter -> 10, lr2.maxIter -> 20)`. This is useful if there are two algorithms with the `maxIter` parameter in a `Pipeline`. +# Algorithm Guides + +There are now several algorithms in the Pipelines API which are not in the lower-level MLlib API, so we link to documentation for them here. These algorithms are mostly feature transformers, which fit naturally into the `Transformer` abstraction in Pipelines, and ensembles, which fit naturally into the `Estimator` abstraction in the Pipelines. + +**Pipelines API Algorithm Guides** + +* [Feature Extraction, Transformation, and Selection](ml-features.html) +* [Ensembles](ml-ensembles.html) + +**Algorithms in `spark.ml`** + +* [Linear methods with elastic net regularization](ml-linear-methods.html) + # Code Examples This section gives code examples illustrating the functionality discussed above. 
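The `ParamMap` behaviour described in the ml-guide.md text above (one map carrying parameters for several estimator instances) can be illustrated with a minimal sketch. This is not part of the patch: it assumes the `spark.ml` Scala API in a spark-shell session, and `lr1`, `lr2`, and `training` are placeholder names.

{% highlight scala %}
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.mllib.util.MLUtils

// `sc` is the shell's SparkContext, as in the other documentation examples.
val training = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF()

// Two estimator instances whose maxIter parameters are set independently.
val lr1 = new LogisticRegression()
val lr2 = new LogisticRegression()

// A single ParamMap can carry parameters for several instances at once.
val paramMap = ParamMap(lr1.maxIter -> 10, lr2.maxIter -> 20)

// Parameter maps are applied at fit time and override the instances' defaults.
val model1 = lr1.fit(training, paramMap)
val model2 = lr2.fit(training, paramMap)
{% endhighlight %}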
diff --git a/docs/ml-linear-methods.md b/docs/ml-linear-methods.md new file mode 100644 index 0000000000000..1ac83d94c9e81 --- /dev/null +++ b/docs/ml-linear-methods.md @@ -0,0 +1,129 @@ +--- +layout: global +title: Linear Methods - ML +displayTitle: ML - Linear Methods +--- + + +`\[ +\newcommand{\R}{\mathbb{R}} +\newcommand{\E}{\mathbb{E}} +\newcommand{\x}{\mathbf{x}} +\newcommand{\y}{\mathbf{y}} +\newcommand{\wv}{\mathbf{w}} +\newcommand{\av}{\mathbf{\alpha}} +\newcommand{\bv}{\mathbf{b}} +\newcommand{\N}{\mathbb{N}} +\newcommand{\id}{\mathbf{I}} +\newcommand{\ind}{\mathbf{1}} +\newcommand{\0}{\mathbf{0}} +\newcommand{\unit}{\mathbf{e}} +\newcommand{\one}{\mathbf{1}} +\newcommand{\zero}{\mathbf{0}} +\]` + + +In MLlib, we implement popular linear methods such as logistic regression and linear least squares with L1 or L2 regularization. Refer to [the linear methods in mllib](mllib-linear-methods.html) for details. In `spark.ml`, we also include Pipelines API for [Elastic net](http://en.wikipedia.org/wiki/Elastic_net_regularization), a hybrid of L1 and L2 regularization proposed in [this paper](http://users.stat.umn.edu/~zouxx019/Papers/elasticnet.pdf). Mathematically it is defined as a linear combination of the L1-norm and the L2-norm: +`\[ +\alpha \|\wv\|_1 + (1-\alpha) \frac{1}{2}\|\wv\|_2^2, \alpha \in [0, 1]. +\]` +By setting $\alpha$ properly, it contains both L1 and L2 regularization as special cases. For example, if a [linear regression](https://en.wikipedia.org/wiki/Linear_regression) model is trained with the elastic net parameter $\alpha$ set to $1$, it is equivalent to a [Lasso](http://en.wikipedia.org/wiki/Least_squares#Lasso_method) model. On the other hand, if $\alpha$ is set to $0$, the trained model reduces to a [ridge regression](http://en.wikipedia.org/wiki/Tikhonov_regularization) model. We implement Pipelines API for both linear regression and logistic regression with elastic net regularization. + +**Examples** + +
+<div class="codetabs"> +
+<div data-lang="scala" markdown="1"> +{% highlight scala %} + +import org.apache.spark.ml.classification.LogisticRegression +import org.apache.spark.mllib.util.MLUtils + +// Load training data +val training = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + +val lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.3) + .setElasticNetParam(0.8) + +// Fit the model +val lrModel = lr.fit(training) + +// Print the weights and intercept for logistic regression +println(s"Weights: ${lrModel.weights} Intercept: ${lrModel.intercept}") + +{% endhighlight %} +
+</div> +
+<div data-lang="java" markdown="1"> +{% highlight java %} + +import org.apache.spark.ml.classification.LogisticRegression; +import org.apache.spark.ml.classification.LogisticRegressionModel; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; + +public class LogisticRegressionWithElasticNetExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf() + .setAppName("Logistic Regression with Elastic Net Example"); + + SparkContext sc = new SparkContext(conf); + SQLContext sql = new SQLContext(sc); + String path = "sample_libsvm_data.txt"; + + // Load training data + DataFrame training = sql.createDataFrame(MLUtils.loadLibSVMFile(sc, path).toJavaRDD(), LabeledPoint.class); + + LogisticRegression lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.3) + .setElasticNetParam(0.8); + + // Fit the model + LogisticRegressionModel lrModel = lr.fit(training); + + // Print the weights and intercept for logistic regression + System.out.println("Weights: " + lrModel.weights() + " Intercept: " + lrModel.intercept()); + } +} {% endhighlight %}
+ +<div data-lang="python" markdown="1">
    + +{% highlight python %} + +from pyspark.ml.classification import LogisticRegression +from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.util import MLUtils + +# Load training data +training = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + +lr = LogisticRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8) + +# Fit the model +lrModel = lr.fit(training) + +# Print the weights and intercept for logistic regression +print("Weights: " + str(lrModel.weights)) +print("Intercept: " + str(lrModel.intercept)) +{% endhighlight %} + +
+ +</div>
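The examples above use logistic regression; the paragraph before them notes that linear regression with elastic net is implemented as well. As a minimal, illustrative sketch (reusing the `training` DataFrame from the Scala example above), the corresponding `LinearRegression` estimator is configured the same way:

{% highlight scala %}
import org.apache.spark.ml.regression.LinearRegression

// Reuse the `training` DataFrame loaded in the logistic regression example above.
val lir = new LinearRegression()
  .setMaxIter(10)
  .setRegParam(0.3)
  .setElasticNetParam(0.8)  // alpha = 0.8 blends L1 and L2 as described above

val lirModel = lir.fit(training)
println(s"Weights: ${lirModel.weights} Intercept: ${lirModel.intercept}")
{% endhighlight %}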
+ +### Optimization + +The optimization algorithm underlying the implementation is called [Orthant-Wise Limited-memory Quasi-Newton](http://research-srv.microsoft.com/en-us/um/people/jfgao/paper/icml07scalable.pdf) +(OWL-QN). It is an extension of L-BFGS that can effectively handle L1 regularization and elastic net. diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index d72dc20a5ad6e..bb875ae2ae6cb 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -33,6 +33,7 @@ guaranteed to find a globally optimal solution, and when run multiple times on a given dataset, the algorithm returns the best clustering result). * *initializationSteps* determines the number of steps in the k-means\|\| algorithm. * *epsilon* determines the distance threshold within which we consider k-means to have converged. +* *initialModel* is an optional set of cluster centers used for initialization. If this parameter is supplied, only one run is performed. **Examples** @@ -471,7 +472,7 @@ to the algorithm. We then output the topics, represented as probability distribu
    {% highlight scala %} -import org.apache.spark.mllib.clustering.LDA +import org.apache.spark.mllib.clustering.{LDA, DistributedLDAModel} import org.apache.spark.mllib.linalg.Vectors // Load and parse the data @@ -491,6 +492,11 @@ for (topic <- Range(0, 3)) { for (word <- Range(0, ldaModel.vocabSize)) { print(" " + topics(word, topic)); } println() } + +// Save and load model. +ldaModel.save(sc, "myLDAModel") +val sameModel = DistributedLDAModel.load(sc, "myLDAModel") + {% endhighlight %}
    @@ -550,6 +556,9 @@ public class JavaLDAExample { } System.out.println(); } + + ldaModel.save(sc.sc(), "myLDAModel"); + DistributedLDAModel sameModel = DistributedLDAModel.load(sc.sc(), "myLDAModel"); } } {% endhighlight %} diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index 3927d65fbf8fb..07655baa414b5 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -10,7 +10,7 @@ displayTitle: MLlib - Linear Methods `\[ \newcommand{\R}{\mathbb{R}} -\newcommand{\E}{\mathbb{E}} +\newcommand{\E}{\mathbb{E}} \newcommand{\x}{\mathbf{x}} \newcommand{\y}{\mathbf{y}} \newcommand{\wv}{\mathbf{w}} @@ -18,10 +18,10 @@ displayTitle: MLlib - Linear Methods \newcommand{\bv}{\mathbf{b}} \newcommand{\N}{\mathbb{N}} \newcommand{\id}{\mathbf{I}} -\newcommand{\ind}{\mathbf{1}} -\newcommand{\0}{\mathbf{0}} -\newcommand{\unit}{\mathbf{e}} -\newcommand{\one}{\mathbf{1}} +\newcommand{\ind}{\mathbf{1}} +\newcommand{\0}{\mathbf{0}} +\newcommand{\unit}{\mathbf{e}} +\newcommand{\one}{\mathbf{1}} \newcommand{\zero}{\mathbf{0}} \]` @@ -29,7 +29,7 @@ displayTitle: MLlib - Linear Methods Many standard *machine learning* methods can be formulated as a convex optimization problem, i.e. the task of finding a minimizer of a convex function `$f$` that depends on a variable vector -`$\wv$` (called `weights` in the code), which has `$d$` entries. +`$\wv$` (called `weights` in the code), which has `$d$` entries. Formally, we can write this as the optimization problem `$\min_{\wv \in\R^d} \; f(\wv)$`, where the objective function is of the form `\begin{equation} @@ -39,7 +39,7 @@ the objective function is of the form \ . \end{equation}` Here the vectors `$\x_i\in\R^d$` are the training data examples, for `$1\le i\le n$`, and -`$y_i\in\R$` are their corresponding labels, which we want to predict. +`$y_i\in\R$` are their corresponding labels, which we want to predict. We call the method *linear* if $L(\wv; \x, y)$ can be expressed as a function of $\wv^T x$ and $y$. Several of MLlib's classification and regression algorithms fall into this category, and are discussed here. @@ -99,6 +99,9 @@ regularizers in MLlib: L1$\|\wv\|_1$$\mathrm{sign}(\wv)$ + + elastic net$\alpha \|\wv\|_1 + (1-\alpha)\frac{1}{2}\|\wv\|_2^2$$\alpha \mathrm{sign}(\wv) + (1-\alpha) \wv$ + @@ -107,7 +110,7 @@ of `$\wv$`. L2-regularized problems are generally easier to solve than L1-regularized due to smoothness. However, L1 regularization can help promote sparsity in weights leading to smaller and more interpretable models, the latter of which can be useful for feature selection. -It is not recommended to train models without any regularization, +[Elastic net](http://en.wikipedia.org/wiki/Elastic_net_regularization) is a combination of L1 and L2 regularization. It is not recommended to train models without any regularization, especially when the number of training examples is small. ### Optimization @@ -531,7 +534,7 @@ sameModel = LogisticRegressionModel.load(sc, "myModelPath") ### Linear least squares, Lasso, and ridge regression -Linear least squares is the most common formulation for regression problems. +Linear least squares is the most common formulation for regression problems. It is a linear method as described above in equation `$\eqref{eq:regPrimal}$`, with the loss function in the formulation given by the squared loss: `\[ @@ -539,8 +542,8 @@ L(\wv;\x,y) := \frac{1}{2} (\wv^T \x - y)^2. 
\]` Various related regression methods are derived by using different types of regularization: -[*ordinary least squares*](http://en.wikipedia.org/wiki/Ordinary_least_squares) or -[*linear least squares*](http://en.wikipedia.org/wiki/Linear_least_squares_(mathematics)) uses +[*ordinary least squares*](http://en.wikipedia.org/wiki/Ordinary_least_squares) or +[*linear least squares*](http://en.wikipedia.org/wiki/Linear_least_squares_(mathematics)) uses no regularization; [*ridge regression*](http://en.wikipedia.org/wiki/Ridge_regression) uses L2 regularization; and [*Lasso*](http://en.wikipedia.org/wiki/Lasso_(statistics)) uses L1 regularization. For all of these models, the average loss or training error, $\frac{1}{n} \sum_{i=1}^n (\wv^T x_i - y_i)^2$, is @@ -552,7 +555,7 @@ known as the [mean squared error](http://en.wikipedia.org/wiki/Mean_squared_erro
    The following example demonstrate how to load training data, parse it as an RDD of LabeledPoint. -The example then uses LinearRegressionWithSGD to build a simple linear model to predict label +The example then uses LinearRegressionWithSGD to build a simple linear model to predict label values. We compute the mean squared error at the end to evaluate [goodness of fit](http://en.wikipedia.org/wiki/Goodness_of_fit). @@ -614,7 +617,7 @@ public class LinearRegression { public static void main(String[] args) { SparkConf conf = new SparkConf().setAppName("Linear Regression Example"); JavaSparkContext sc = new JavaSparkContext(conf); - + // Load and parse the data String path = "data/mllib/ridge-data/lpsa.data"; JavaRDD data = sc.textFile(path); @@ -634,7 +637,7 @@ public class LinearRegression { // Building the model int numIterations = 100; - final LinearRegressionModel model = + final LinearRegressionModel model = LinearRegressionWithSGD.train(JavaRDD.toRDD(parsedData), numIterations); // Evaluate model on training examples and compute training error @@ -665,7 +668,7 @@ public class LinearRegression {
    The following example demonstrate how to load training data, parse it as an RDD of LabeledPoint. -The example then uses LinearRegressionWithSGD to build a simple linear model to predict label +The example then uses LinearRegressionWithSGD to build a simple linear model to predict label values. We compute the mean squared error at the end to evaluate [goodness of fit](http://en.wikipedia.org/wiki/Goodness_of_fit). @@ -706,8 +709,8 @@ a dependency. ###Streaming linear regression -When data arrive in a streaming fashion, it is useful to fit regression models online, -updating the parameters of the model as new data arrives. MLlib currently supports +When data arrive in a streaming fashion, it is useful to fit regression models online, +updating the parameters of the model as new data arrives. MLlib currently supports streaming linear regression using ordinary least squares. The fitting is similar to that performed offline, except fitting occurs on each batch of data, so that the model continually updates to reflect the data from the stream. @@ -722,7 +725,7 @@ online to the first stream, and make predictions on the second stream.
    -First, we import the necessary classes for parsing our input data and creating the model. +First, we import the necessary classes for parsing our input data and creating the model. {% highlight scala %} @@ -734,7 +737,7 @@ import org.apache.spark.mllib.regression.StreamingLinearRegressionWithSGD Then we make input streams for training and testing data. We assume a StreamingContext `ssc` has already been created, see [Spark Streaming Programming Guide](streaming-programming-guide.html#initializing) -for more info. For this example, we use labeled points in training and testing streams, +for more info. For this example, we use labeled points in training and testing streams, but in practice you will likely want to use unlabeled vectors for test data. {% highlight scala %} @@ -754,7 +757,7 @@ val model = new StreamingLinearRegressionWithSGD() {% endhighlight %} -Now we register the streams for training and testing and start the job. +Now we register the streams for training and testing and start the job. Printing predictions alongside true labels lets us easily see the result. {% highlight scala %} @@ -764,14 +767,14 @@ model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print() ssc.start() ssc.awaitTermination() - + {% endhighlight %} We can now save text files with data to the training or testing folders. -Each line should be a data point formatted as `(y,[x1,x2,x3])` where `y` is the label -and `x1,x2,x3` are the features. Anytime a text file is placed in `/training/data/dir` -the model will update. Anytime a text file is placed in `/testing/data/dir` you will see predictions. -As you feed more data to the training directory, the predictions +Each line should be a data point formatted as `(y,[x1,x2,x3])` where `y` is the label +and `x1,x2,x3` are the features. Anytime a text file is placed in `/training/data/dir` +the model will update. Anytime a text file is placed in `/testing/data/dir` you will see predictions. +As you feed more data to the training directory, the predictions will get better!
    diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 1f915d8ea1d73..debdd2adf22d6 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -306,6 +306,28 @@ See the [configuration page](configuration.html) for information on Spark config the final overhead will be this value. + + spark.mesos.principal + Framework principal to authenticate to Mesos + + Set the principal with which Spark framework will use to authenticate with Mesos. + + + + spark.mesos.secret + Framework secret to authenticate to Mesos + + Set the secret with which Spark framework will use to authenticate with Mesos. + + + + spark.mesos.role + Role for the Spark framework + + Set the role of this Spark framework for Mesos. Roles are used in Mesos for reservations + and resource weight sharing. + + spark.mesos.constraints Attribute based constraints to be matched against when accepting resource offers. diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index ae4f2ecc5bde7..7c83d68e7993e 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -793,7 +793,7 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): ssh_write(slave_address, opts, ['tar', 'x'], dot_ssh_tar) modules = ['spark', 'ephemeral-hdfs', 'persistent-hdfs', - 'mapreduce', 'spark-standalone', 'tachyon'] + 'mapreduce', 'spark-standalone', 'tachyon', 'rstudio'] if opts.hadoop_major_version == "1": modules = list(filter(lambda x: x != "mapreduce", modules)) diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/Logging.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/Logging.scala index 17cbc6707b5ea..d87b86932dd41 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/Logging.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/Logging.scala @@ -113,7 +113,9 @@ private[sink] object Logging { try { // We use reflection here to handle the case where users remove the // slf4j-to-jul bridge order to route their logs to JUL. + // scalastyle:off classforname val bridgeClass = Class.forName("org.slf4j.bridge.SLF4JBridgeHandler") + // scalastyle:on classforname bridgeClass.getMethod("removeHandlersForRootLogger").invoke(null) val installed = bridgeClass.getMethod("isInstalled").invoke(null).asInstanceOf[Boolean] if (!installed) { diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala index c5cd2154772ac..1a9d78c0d4f59 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala @@ -98,8 +98,7 @@ class KafkaRDD[ val res = context.runJob( this, (tc: TaskContext, it: Iterator[R]) => it.take(parts(tc.partitionId)).toArray, - parts.keys.toArray, - allowLocal = true) + parts.keys.toArray) res.foreach(buf ++= _) buf.toArray } diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala new file mode 100644 index 0000000000000..f6bf552e6bb8e --- /dev/null +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kinesis + +import java.nio.ByteBuffer +import java.util.concurrent.TimeUnit + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scala.util.{Failure, Random, Success, Try} + +import com.amazonaws.auth.{AWSCredentials, DefaultAWSCredentialsProviderChain} +import com.amazonaws.regions.RegionUtils +import com.amazonaws.services.dynamodbv2.AmazonDynamoDBClient +import com.amazonaws.services.dynamodbv2.document.DynamoDB +import com.amazonaws.services.kinesis.AmazonKinesisClient +import com.amazonaws.services.kinesis.model._ + +import org.apache.spark.Logging + +/** + * Shared utility methods for performing Kinesis tests that actually transfer data + */ +private class KinesisTestUtils( + val endpointUrl: String = "https://kinesis.us-west-2.amazonaws.com", + _regionName: String = "") extends Logging { + + val regionName = if (_regionName.length == 0) { + RegionUtils.getRegionByEndpoint(endpointUrl).getName() + } else { + RegionUtils.getRegion(_regionName).getName() + } + + val streamShardCount = 2 + + private val createStreamTimeoutSeconds = 300 + private val describeStreamPollTimeSeconds = 1 + + @volatile + private var streamCreated = false + private var _streamName: String = _ + + private lazy val kinesisClient = { + val client = new AmazonKinesisClient(KinesisTestUtils.getAWSCredentials()) + client.setEndpoint(endpointUrl) + client + } + + private lazy val dynamoDB = { + val dynamoDBClient = new AmazonDynamoDBClient(new DefaultAWSCredentialsProviderChain()) + dynamoDBClient.setRegion(RegionUtils.getRegion(regionName)) + new DynamoDB(dynamoDBClient) + } + + def streamName: String = { + require(streamCreated, "Stream not yet created, call createStream() to create one") + _streamName + } + + def createStream(): Unit = { + logInfo("Creating stream") + require(!streamCreated, "Stream already created") + _streamName = findNonExistentStreamName() + + // Create a stream. The number of shards determines the provisioned throughput. + val createStreamRequest = new CreateStreamRequest() + createStreamRequest.setStreamName(_streamName) + createStreamRequest.setShardCount(2) + kinesisClient.createStream(createStreamRequest) + + // The stream is now being created. Wait for it to become active. 
+ waitForStreamToBeActive(_streamName) + streamCreated = true + logInfo("Created stream") + } + + /** + * Push data to Kinesis stream and return a map of + * shardId -> seq of (data, seq number) pushed to corresponding shard + */ + def pushData(testData: Seq[Int]): Map[String, Seq[(Int, String)]] = { + require(streamCreated, "Stream not yet created, call createStream() to create one") + val shardIdToSeqNumbers = new mutable.HashMap[String, ArrayBuffer[(Int, String)]]() + + testData.foreach { num => + val str = num.toString + val putRecordRequest = new PutRecordRequest().withStreamName(streamName) + .withData(ByteBuffer.wrap(str.getBytes())) + .withPartitionKey(str) + + val putRecordResult = kinesisClient.putRecord(putRecordRequest) + val shardId = putRecordResult.getShardId + val seqNumber = putRecordResult.getSequenceNumber() + val sentSeqNumbers = shardIdToSeqNumbers.getOrElseUpdate(shardId, + new ArrayBuffer[(Int, String)]()) + sentSeqNumbers += ((num, seqNumber)) + } + + logInfo(s"Pushed $testData:\n\t ${shardIdToSeqNumbers.mkString("\n\t")}") + shardIdToSeqNumbers.toMap + } + + def describeStream(streamNameToDescribe: String = streamName): Option[StreamDescription] = { + try { + val describeStreamRequest = new DescribeStreamRequest().withStreamName(streamNameToDescribe) + val desc = kinesisClient.describeStream(describeStreamRequest).getStreamDescription() + Some(desc) + } catch { + case rnfe: ResourceNotFoundException => + None + } + } + + def deleteStream(): Unit = { + try { + if (describeStream().nonEmpty) { + val deleteStreamRequest = new DeleteStreamRequest() + kinesisClient.deleteStream(streamName) + } + } catch { + case e: Exception => + logWarning(s"Could not delete stream $streamName") + } + } + + def deleteDynamoDBTable(tableName: String): Unit = { + try { + val table = dynamoDB.getTable(tableName) + table.delete() + table.waitForDelete() + } catch { + case e: Exception => + logWarning(s"Could not delete DynamoDB table $tableName") + } + } + + private def findNonExistentStreamName(): String = { + var testStreamName: String = null + do { + Thread.sleep(TimeUnit.SECONDS.toMillis(describeStreamPollTimeSeconds)) + testStreamName = s"KinesisTestUtils-${math.abs(Random.nextLong())}" + } while (describeStream(testStreamName).nonEmpty) + testStreamName + } + + private def waitForStreamToBeActive(streamNameToWaitFor: String): Unit = { + val startTime = System.currentTimeMillis() + val endTime = startTime + TimeUnit.SECONDS.toMillis(createStreamTimeoutSeconds) + while (System.currentTimeMillis() < endTime) { + Thread.sleep(TimeUnit.SECONDS.toMillis(describeStreamPollTimeSeconds)) + describeStream(streamNameToWaitFor).foreach { description => + val streamStatus = description.getStreamStatus() + logDebug(s"\t- current state: $streamStatus\n") + if ("ACTIVE".equals(streamStatus)) { + return + } + } + } + require(false, s"Stream $streamName never became active") + } +} + +private[kinesis] object KinesisTestUtils { + + val envVarName = "RUN_KINESIS_TESTS" + + val shouldRunTests = sys.env.get(envVarName) == Some("1") + + def isAWSCredentialsPresent: Boolean = { + Try { new DefaultAWSCredentialsProviderChain().getCredentials() }.isSuccess + } + + def getAWSCredentials(): AWSCredentials = { + assert(shouldRunTests, + "Kinesis test not enabled, should not attempt to get AWS credentials") + Try { new DefaultAWSCredentialsProviderChain().getCredentials() } match { + case Success(cred) => cred + case Failure(e) => + throw new Exception("Kinesis tests enabled, but could get not AWS credentials") + 
} + } +} diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala new file mode 100644 index 0000000000000..6d011f295e7f7 --- /dev/null +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kinesis + +import org.apache.spark.SparkFunSuite + +/** + * Helper class that runs Kinesis real data transfer tests or + * ignores them based on env variable is set or not. + */ +trait KinesisSuiteHelper { self: SparkFunSuite => + import KinesisTestUtils._ + + /** Run the test if environment variable is set or ignore the test */ + def testOrIgnore(testName: String)(testBody: => Unit) { + if (shouldRunTests) { + test(testName)(testBody) + } else { + ignore(s"$testName [enable by setting env var $envVarName=1]")(testBody) + } + } +} diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala index 2103dca6b766f..98f2c7c4f1bfb 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala @@ -73,23 +73,6 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft checkpointStateMock, currentClockMock) } - test("KinesisUtils API") { - val ssc = new StreamingContext(master, framework, batchDuration) - // Tests the API, does not actually test data receiving - val kinesisStream1 = KinesisUtils.createStream(ssc, "mySparkStream", - "https://kinesis.us-west-2.amazonaws.com", Seconds(2), - InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2) - val kinesisStream2 = KinesisUtils.createStream(ssc, "myAppNam", "mySparkStream", - "https://kinesis.us-west-2.amazonaws.com", "us-west-2", - InitialPositionInStream.LATEST, Seconds(2), StorageLevel.MEMORY_AND_DISK_2) - val kinesisStream3 = KinesisUtils.createStream(ssc, "myAppNam", "mySparkStream", - "https://kinesis.us-west-2.amazonaws.com", "us-west-2", - InitialPositionInStream.LATEST, Seconds(2), StorageLevel.MEMORY_AND_DISK_2, - "awsAccessKey", "awsSecretKey") - - ssc.stop() - } - test("check serializability of SerializableAWSCredentials") { Utils.deserialize[SerializableAWSCredentials]( Utils.serialize(new SerializableAWSCredentials("x", "y"))) diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala 
b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala new file mode 100644 index 0000000000000..50f71413abf37 --- /dev/null +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kinesis + +import scala.collection.mutable +import scala.concurrent.duration._ +import scala.language.postfixOps +import scala.util.Random + +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream +import org.scalatest.concurrent.Eventually +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} + +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming._ +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} + +class KinesisStreamSuite extends SparkFunSuite with KinesisSuiteHelper + with Eventually with BeforeAndAfter with BeforeAndAfterAll { + + // This is the name that KCL uses to save metadata to DynamoDB + private val kinesisAppName = s"KinesisStreamSuite-${math.abs(Random.nextLong())}" + + private var ssc: StreamingContext = _ + private var sc: SparkContext = _ + + override def beforeAll(): Unit = { + val conf = new SparkConf() + .setMaster("local[4]") + .setAppName("KinesisStreamSuite") // Setting Spark app name to Kinesis app name + sc = new SparkContext(conf) + } + + override def afterAll(): Unit = { + sc.stop() + // Delete the Kinesis stream as well as the DynamoDB table generated by + // Kinesis Client Library when consuming the stream + } + + after { + if (ssc != null) { + ssc.stop(stopSparkContext = false) + ssc = null + } + } + + test("KinesisUtils API") { + ssc = new StreamingContext(sc, Seconds(1)) + // Tests the API, does not actually test data receiving + val kinesisStream1 = KinesisUtils.createStream(ssc, "mySparkStream", + "https://kinesis.us-west-2.amazonaws.com", Seconds(2), + InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2) + val kinesisStream2 = KinesisUtils.createStream(ssc, "myAppNam", "mySparkStream", + "https://kinesis.us-west-2.amazonaws.com", "us-west-2", + InitialPositionInStream.LATEST, Seconds(2), StorageLevel.MEMORY_AND_DISK_2) + val kinesisStream3 = KinesisUtils.createStream(ssc, "myAppNam", "mySparkStream", + "https://kinesis.us-west-2.amazonaws.com", "us-west-2", + InitialPositionInStream.LATEST, Seconds(2), StorageLevel.MEMORY_AND_DISK_2, + "awsAccessKey", "awsSecretKey") + } + + + /** + * Test the stream by sending data to a Kinesis stream and receiving from it. + * This test is not run by default as it requires AWS credentials that the test + * environment may not have. 
Even if there is AWS credentials available, the user + * may not want to run these tests to avoid the Kinesis costs. To enable this test, + * you must have AWS credentials available through the default AWS provider chain, + * and you have to set the system environment variable RUN_KINESIS_TESTS=1 . + */ + testOrIgnore("basic operation") { + val kinesisTestUtils = new KinesisTestUtils() + try { + kinesisTestUtils.createStream() + ssc = new StreamingContext(sc, Seconds(1)) + val aWSCredentials = KinesisTestUtils.getAWSCredentials() + val stream = KinesisUtils.createStream(ssc, kinesisAppName, kinesisTestUtils.streamName, + kinesisTestUtils.endpointUrl, kinesisTestUtils.regionName, InitialPositionInStream.LATEST, + Seconds(10), StorageLevel.MEMORY_ONLY, + aWSCredentials.getAWSAccessKeyId, aWSCredentials.getAWSSecretKey) + + val collected = new mutable.HashSet[Int] with mutable.SynchronizedSet[Int] + stream.map { bytes => new String(bytes).toInt }.foreachRDD { rdd => + collected ++= rdd.collect() + logInfo("Collected = " + rdd.collect().toSeq.mkString(", ")) + } + ssc.start() + + val testData = 1 to 10 + eventually(timeout(120 seconds), interval(10 second)) { + kinesisTestUtils.pushData(testData) + assert(collected === testData.toSet, "\nData received does not match data sent") + } + ssc.stop() + } finally { + kinesisTestUtils.deleteStream() + kinesisTestUtils.deleteDynamoDBTable(kinesisAppName) + } + } +} diff --git a/graphx/src/main/scala/org/apache/spark/graphx/PartitionStrategy.scala b/graphx/src/main/scala/org/apache/spark/graphx/PartitionStrategy.scala index 7372dfbd9fe98..70a7592da8ae3 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/PartitionStrategy.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/PartitionStrategy.scala @@ -32,7 +32,7 @@ trait PartitionStrategy extends Serializable { object PartitionStrategy { /** * Assigns edges to partitions using a 2D partitioning of the sparse edge adjacency matrix, - * guaranteeing a `2 * sqrt(numParts) - 1` bound on vertex replication. + * guaranteeing a `2 * sqrt(numParts)` bound on vertex replication. * * Suppose we have a graph with 12 vertices that we want to partition * over 9 machines. We can use the following sparse matrix representation: @@ -61,26 +61,36 @@ object PartitionStrategy { * that edges adjacent to `v11` can only be in the first column of blocks `(P0, P3, * P6)` or the last * row of blocks `(P6, P7, P8)`. As a consequence we can guarantee that `v11` will need to be - * replicated to at most `2 * sqrt(numParts) - 1` machines. + * replicated to at most `2 * sqrt(numParts)` machines. * * Notice that `P0` has many edges and as a consequence this partitioning would lead to poor work * balance. To improve balance we first multiply each vertex id by a large prime to shuffle the * vertex locations. * - * One of the limitations of this approach is that the number of machines must either be a - * perfect square. We partially address this limitation by computing the machine assignment to - * the next - * largest perfect square and then mapping back down to the actual number of machines. - * Unfortunately, this can also lead to work imbalance and so it is suggested that a perfect - * square is used. + * When the number of partitions requested is not a perfect square we use a slightly different + * method where the last column can have a different number of rows than the others while still + * maintaining the same size per block. 
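+ * For example, with `numParts = 10` the code below uses `cols = ceil(sqrt(10)) = 4` columns and `rows = 3` rows per column, + * so the first three columns hold 3 partitions each and the last column holds the remaining 10 - 3 * 3 = 1 partition.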
*/ case object EdgePartition2D extends PartitionStrategy { override def getPartition(src: VertexId, dst: VertexId, numParts: PartitionID): PartitionID = { val ceilSqrtNumParts: PartitionID = math.ceil(math.sqrt(numParts)).toInt val mixingPrime: VertexId = 1125899906842597L - val col: PartitionID = (math.abs(src * mixingPrime) % ceilSqrtNumParts).toInt - val row: PartitionID = (math.abs(dst * mixingPrime) % ceilSqrtNumParts).toInt - (col * ceilSqrtNumParts + row) % numParts + if (numParts == ceilSqrtNumParts * ceilSqrtNumParts) { + // Use old method for perfect squared to ensure we get same results + val col: PartitionID = (math.abs(src * mixingPrime) % ceilSqrtNumParts).toInt + val row: PartitionID = (math.abs(dst * mixingPrime) % ceilSqrtNumParts).toInt + (col * ceilSqrtNumParts + row) % numParts + + } else { + // Otherwise use new method + val cols = ceilSqrtNumParts + val rows = (numParts + cols - 1) / cols + val lastColRows = numParts - rows * (cols - 1) + val col = (math.abs(src * mixingPrime) % numParts / rows).toInt + val row = (math.abs(dst * mixingPrime) % (if (col < cols - 1) rows else lastColRows)).toInt + col * rows + row + + } } } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala index 90a74d23a26cc..da95314440d86 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala @@ -332,9 +332,9 @@ object GraphImpl { edgeStorageLevel: StorageLevel, vertexStorageLevel: StorageLevel): GraphImpl[VD, ED] = { val edgeRDD = EdgeRDD.fromEdges(edges)(classTag[ED], classTag[VD]) - .withTargetStorageLevel(edgeStorageLevel).cache() + .withTargetStorageLevel(edgeStorageLevel) val vertexRDD = VertexRDD(vertices, edgeRDD, defaultVertexAttr) - .withTargetStorageLevel(vertexStorageLevel).cache() + .withTargetStorageLevel(vertexStorageLevel) GraphImpl(vertexRDD, edgeRDD) } @@ -346,9 +346,14 @@ object GraphImpl { def apply[VD: ClassTag, ED: ClassTag]( vertices: VertexRDD[VD], edges: EdgeRDD[ED]): GraphImpl[VD, ED] = { + + vertices.cache() + // Convert the vertex partitions in edges to the correct type val newEdges = edges.asInstanceOf[EdgeRDDImpl[ED, _]] .mapEdgePartitions((pid, part) => part.withoutVertexAttributes[VD]) + .cache() + GraphImpl.fromExistingRDDs(vertices, newEdges) } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala index 5c07b415cd796..74a7de18d4161 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala @@ -121,7 +121,7 @@ private[graphx] object BytecodeUtils { override def visitMethodInsn(op: Int, owner: String, name: String, desc: String) { if (op == INVOKEVIRTUAL || op == INVOKESPECIAL || op == INVOKESTATIC) { if (!skipClass(owner)) { - methodsInvoked.add((Class.forName(owner.replace("/", ".")), name)) + methodsInvoked.add((Utils.classForName(owner.replace("/", ".")), name)) } } } diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java index d4cfeacb6ef18..c0f89c9230692 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java @@ -25,11 +25,12 @@ import static 
org.apache.spark.launcher.CommandBuilderUtils.*; -/** +/** * Launcher for Spark applications. - *

    + *

    * Use this class to start Spark applications programmatically. The class uses a builder pattern * to allow clients to configure the Spark application and launch it as a child process. + *

    */ public class SparkLauncher { diff --git a/launcher/src/main/java/org/apache/spark/launcher/package-info.java b/launcher/src/main/java/org/apache/spark/launcher/package-info.java index 7ed756f4b8591..7c97dba511b28 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/package-info.java +++ b/launcher/src/main/java/org/apache/spark/launcher/package-info.java @@ -17,13 +17,17 @@ /** * Library for launching Spark applications. - *

    + * + *

    * This library allows applications to launch Spark programmatically. There's only one entry * point to the library - the {@link org.apache.spark.launcher.SparkLauncher} class. - *

    + *

    + * + *

    * To launch a Spark application, just instantiate a {@link org.apache.spark.launcher.SparkLauncher} * and configure the application to run. For example: - * + *

    + * *
      * {@code
      *   import org.apache.spark.launcher.SparkLauncher;
    diff --git a/make-distribution.sh b/make-distribution.sh
    index 9f063da3a16c0..cac7032bb2e87 100755
    --- a/make-distribution.sh
    +++ b/make-distribution.sh
    @@ -219,6 +219,7 @@ cp -r "$SPARK_HOME/ec2" "$DISTDIR"
     if [ -d "$SPARK_HOME"/R/lib/SparkR ]; then
       mkdir -p "$DISTDIR"/R/lib
       cp -r "$SPARK_HOME/R/lib/SparkR" "$DISTDIR"/R/lib
    +  cp "$SPARK_HOME/R/lib/sparkr.zip" "$DISTDIR"/R/lib
     fi
     
     # Download and copy in tachyon, if requested
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
    index 333b42711ec52..19fe039b8fd03 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
    @@ -169,10 +169,7 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType,
       override def transform(dataset: DataFrame): DataFrame = {
         transformSchema(dataset.schema, logging = true)
         if ($(predictionCol).nonEmpty) {
    -      val predictUDF = udf { (features: Any) =>
    -        predict(features.asInstanceOf[FeaturesType])
    -      }
    -      dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
    +      transformImpl(dataset)
         } else {
           this.logWarning(s"$uid: Predictor.transform() was called as NOOP" +
             " since no output columns were set.")
    @@ -180,6 +177,13 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType,
         }
       }
     
    +  protected def transformImpl(dataset: DataFrame): DataFrame = {
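+    // Default implementation: apply predict() through a UDF over the features column and append it as the prediction column.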
    +    val predictUDF = udf { (features: Any) =>
    +      predict(features.asInstanceOf[FeaturesType])
    +    }
    +    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
    +  }
    +
       /**
        * Predict label for the given features.
        * This internal method is used to implement [[transform()]] and output [[predictionCol]].
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
    index 2dc1824964a42..36fe1bd40469c 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
    @@ -21,10 +21,10 @@ import org.apache.spark.annotation.Experimental
     import org.apache.spark.ml.{PredictionModel, Predictor}
     import org.apache.spark.ml.param.ParamMap
     import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeClassifierParams}
    +import org.apache.spark.ml.tree.impl.RandomForest
     import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
     import org.apache.spark.mllib.linalg.Vector
     import org.apache.spark.mllib.regression.LabeledPoint
    -import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree}
     import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
     import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
     import org.apache.spark.rdd.RDD
    @@ -75,8 +75,9 @@ final class DecisionTreeClassifier(override val uid: String)
         }
         val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
         val strategy = getOldStrategy(categoricalFeatures, numClasses)
    -    val oldModel = OldDecisionTree.train(oldDataset, strategy)
    -    DecisionTreeClassificationModel.fromOld(oldModel, this, categoricalFeatures)
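+    // Train a single tree through the shared RandomForest code path, considering all features at each split.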
    +    val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
    +      seed = 0L, parentUID = Some(uid))
    +    trees.head.asInstanceOf[DecisionTreeClassificationModel]
       }
     
       /** (private[ml]) Create a Strategy instance to use with the old API. */
    @@ -112,6 +113,12 @@ final class DecisionTreeClassificationModel private[ml] (
       require(rootNode != null,
         "DecisionTreeClassificationModel given null rootNode, but it requires a non-null rootNode.")
     
    +  /**
    +   * Construct a decision tree classification model.
    +   * @param rootNode  Root node of tree, with other nodes attached.
    +   */
    +  def this(rootNode: Node) = this(Identifiable.randomUID("dtc"), rootNode)
    +
       override protected def predict(features: Vector): Double = {
         rootNode.predict(features)
       }
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
    index 554e3b8e052b2..eb0b1a0a405fc 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
    @@ -34,6 +34,8 @@ import org.apache.spark.mllib.tree.loss.{LogLoss => OldLogLoss, Loss => OldLoss}
     import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
     import org.apache.spark.rdd.RDD
     import org.apache.spark.sql.DataFrame
    +import org.apache.spark.sql.functions._
    +import org.apache.spark.sql.types.DoubleType
     
     /**
      * :: Experimental ::
    @@ -177,8 +179,15 @@ final class GBTClassificationModel(
     
       override def treeWeights: Array[Double] = _treeWeights
     
    +  override protected def transformImpl(dataset: DataFrame): DataFrame = {
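+    // Broadcast the model so each executor fetches a single copy of the tree ensemble instead of having it serialized into every task closure.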
    +    val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
    +    val predictUDF = udf { (features: Any) =>
    +      bcastModel.value.predict(features.asInstanceOf[Vector])
    +    }
    +    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
    +  }
    +
       override protected def predict(features: Vector): Double = {
    -    // TODO: Override transform() to broadcast model: SPARK-7127
         // TODO: When we add a generic Boosting class, handle transform there?  SPARK-7129
         // Classifies by thresholding sum of weighted tree predictions
         val treePredictions = _trees.map(_.rootNode.predict(features))
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
    new file mode 100644
    index 0000000000000..1f547e4a98af7
    --- /dev/null
    +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
    @@ -0,0 +1,178 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.classification
    +
    +import org.apache.spark.SparkException
    +import org.apache.spark.ml.{PredictorParams, PredictionModel, Predictor}
    +import org.apache.spark.ml.param.{ParamMap, ParamValidators, Param, DoubleParam}
    +import org.apache.spark.ml.util.Identifiable
    +import org.apache.spark.mllib.classification.{NaiveBayes => OldNaiveBayes}
    +import org.apache.spark.mllib.classification.{NaiveBayesModel => OldNaiveBayesModel}
    +import org.apache.spark.mllib.linalg._
    +import org.apache.spark.mllib.regression.LabeledPoint
    +import org.apache.spark.rdd.RDD
    +import org.apache.spark.sql.DataFrame
    +
    +/**
    + * Params for Naive Bayes Classifiers.
    + */
    +private[ml] trait NaiveBayesParams extends PredictorParams {
    +
    +  /**
    +   * The smoothing parameter.
    +   * (default = 1.0).
    +   * @group param
    +   */
    +  final val lambda: DoubleParam = new DoubleParam(this, "lambda", "The smoothing parameter.",
    +    ParamValidators.gtEq(0))
    +
    +  /** @group getParam */
    +  final def getLambda: Double = $(lambda)
    +
    +  /**
    +   * The model type which is a string (case-sensitive).
    +   * Supported options: "multinomial" and "bernoulli".
    +   * (default = multinomial)
    +   * @group param
    +   */
    +  final val modelType: Param[String] = new Param[String](this, "modelType", "The model type " +
    +    "which is a string (case-sensitive). Supported options: multinomial (default) and bernoulli.",
    +    ParamValidators.inArray[String](OldNaiveBayes.supportedModelTypes.toArray))
    +
    +  /** @group getParam */
    +  final def getModelType: String = $(modelType)
    +}
    +
    +/**
    + * Naive Bayes Classifiers.
    + * It supports both Multinomial NB
    + * ([[http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html]])
    + * which can handle finitely supported discrete data. For example, by converting documents into
+ * TF-IDF vectors, it can be used for document classification. By treating every vector as
+ * binary (0/1) data, it can also be used as Bernoulli NB
    + * ([[http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html]]).
    + * The input feature values must be nonnegative.
    + */
    +class NaiveBayes(override val uid: String)
    +  extends Predictor[Vector, NaiveBayes, NaiveBayesModel]
    +  with NaiveBayesParams {
    +
    +  def this() = this(Identifiable.randomUID("nb"))
    +
    +  /**
    +   * Set the smoothing parameter.
    +   * Default is 1.0.
    +   * @group setParam
    +   */
    +  def setLambda(value: Double): this.type = set(lambda, value)
    +  setDefault(lambda -> 1.0)
    +
    +  /**
    +   * Set the model type using a string (case-sensitive).
    +   * Supported options: "multinomial" and "bernoulli".
+   * Default is "multinomial".
    +   */
    +  def setModelType(value: String): this.type = set(modelType, value)
    +  setDefault(modelType -> OldNaiveBayes.Multinomial)
    +
    +  override protected def train(dataset: DataFrame): NaiveBayesModel = {
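+    // Fit with the existing spark.mllib NaiveBayes and convert the resulting model to the new ml API.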
    +    val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
    +    val oldModel = OldNaiveBayes.train(oldDataset, $(lambda), $(modelType))
    +    NaiveBayesModel.fromOld(oldModel, this)
    +  }
    +
    +  override def copy(extra: ParamMap): NaiveBayes = defaultCopy(extra)
    +}
    +
    +/**
    + * Model produced by [[NaiveBayes]]
    + */
    +class NaiveBayesModel private[ml] (
    +    override val uid: String,
    +    val pi: Vector,
    +    val theta: Matrix)
    +  extends PredictionModel[Vector, NaiveBayesModel] with NaiveBayesParams {
    +
    +  import OldNaiveBayes.{Bernoulli, Multinomial}
    +
    +  /**
    +   * Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0.
    +   * This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra
    +   * application of this condition (in predict function).
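+   * Since theta stores log probabilities, the per-class Bernoulli log-likelihood is
+   * x . (theta_c - log(1 - exp(theta_c))) + sum_i log(1 - exp(theta_ci)) (plus the class prior pi_c),
+   * so predict() only needs a matrix-vector multiply against thetaMinusNegTheta plus negThetaSum.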
    +   */
    +  private lazy val (thetaMinusNegTheta, negThetaSum) = $(modelType) match {
    +    case Multinomial => (None, None)
    +    case Bernoulli =>
    +      val negTheta = theta.map(value => math.log(1.0 - math.exp(value)))
    +      val ones = new DenseVector(Array.fill(theta.numCols){1.0})
    +      val thetaMinusNegTheta = theta.map { value =>
    +        value - math.log(1.0 - math.exp(value))
    +      }
    +      (Option(thetaMinusNegTheta), Option(negTheta.multiply(ones)))
    +    case _ =>
    +      // This should never happen.
    +      throw new UnknownError(s"Invalid modelType: ${$(modelType)}.")
    +  }
    +
    +  override protected def predict(features: Vector): Double = {
    +    $(modelType) match {
    +      case Multinomial =>
    +        val prob = theta.multiply(features)
    +        BLAS.axpy(1.0, pi, prob)
    +        prob.argmax
    +      case Bernoulli =>
+        features.foreachActive { (index, value) =>
    +          if (value != 0.0 && value != 1.0) {
    +            throw new SparkException(
    +              s"Bernoulli naive Bayes requires 0 or 1 feature values but found $features")
    +          }
    +        }
    +        val prob = thetaMinusNegTheta.get.multiply(features)
    +        BLAS.axpy(1.0, pi, prob)
    +        BLAS.axpy(1.0, negThetaSum.get, prob)
    +        prob.argmax
    +      case _ =>
    +        // This should never happen.
    +        throw new UnknownError(s"Invalid modelType: ${$(modelType)}.")
    +    }
    +  }
    +
    +  override def copy(extra: ParamMap): NaiveBayesModel = {
    +    copyValues(new NaiveBayesModel(uid, pi, theta).setParent(this.parent), extra)
    +  }
    +
    +  override def toString: String = {
    +    s"NaiveBayesModel with ${pi.size} classes"
    +  }
    +
    +}
    +
    +private[ml] object NaiveBayesModel {
    +
    +  /** Convert a model from the old API */
    +  def fromOld(
    +      oldModel: OldNaiveBayesModel,
    +      parent: NaiveBayes): NaiveBayesModel = {
    +    val uid = if (parent != null) parent.uid else Identifiable.randomUID("nb")
    +    val labels = Vectors.dense(oldModel.labels)
    +    val pi = Vectors.dense(oldModel.pi)
    +    val theta = new DenseMatrix(oldModel.labels.length, oldModel.theta(0).length,
    +      oldModel.theta.flatten, true)
    +    new NaiveBayesModel(uid, pi, theta)
    +  }
    +}
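For orientation, a minimal usage sketch of the `NaiveBayes` estimator defined above (the `training` DataFrame with "label" and "features" columns is an assumed input):

{% highlight scala %}
import org.apache.spark.ml.classification.NaiveBayes

// Assumed input: `training` is a DataFrame with "label" and "features" columns.
val nb = new NaiveBayes()
  .setLambda(1.0)               // smoothing; 1.0 is also the default
  .setModelType("multinomial")  // or "bernoulli" for binary (0/1) features

val model = nb.fit(training)
val predictions = model.transform(training)  // appends the prediction column
{% endhighlight %}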
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
    index d3c67494a31e4..fc0693f67cc2e 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
    @@ -20,17 +20,19 @@ package org.apache.spark.ml.classification
     import scala.collection.mutable
     
     import org.apache.spark.annotation.Experimental
    +import org.apache.spark.ml.tree.impl.RandomForest
     import org.apache.spark.ml.{PredictionModel, Predictor}
     import org.apache.spark.ml.param.ParamMap
     import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeClassifierParams, TreeEnsembleModel}
     import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
     import org.apache.spark.mllib.linalg.Vector
     import org.apache.spark.mllib.regression.LabeledPoint
    -import org.apache.spark.mllib.tree.{RandomForest => OldRandomForest}
     import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
     import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
     import org.apache.spark.rdd.RDD
     import org.apache.spark.sql.DataFrame
    +import org.apache.spark.sql.functions._
    +import org.apache.spark.sql.types.DoubleType
     
     /**
      * :: Experimental ::
    @@ -93,9 +95,10 @@ final class RandomForestClassifier(override val uid: String)
         val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
         val strategy =
           super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity)
    -    val oldModel = OldRandomForest.trainClassifier(
    -      oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed.toInt)
    -    RandomForestClassificationModel.fromOld(oldModel, this, categoricalFeatures)
    +    val trees =
    +      RandomForest.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed)
    +        .map(_.asInstanceOf[DecisionTreeClassificationModel])
    +    new RandomForestClassificationModel(trees)
       }
     
       override def copy(extra: ParamMap): RandomForestClassifier = defaultCopy(extra)
    @@ -128,6 +131,13 @@ final class RandomForestClassificationModel private[ml] (
     
       require(numTrees > 0, "RandomForestClassificationModel requires at least 1 tree.")
     
    +  /**
    +   * Construct a random forest classification model, with all trees weighted equally.
    +   * @param trees  Component trees
    +   */
    +  def this(trees: Array[DecisionTreeClassificationModel]) =
    +    this(Identifiable.randomUID("rfc"), trees)
    +
       override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
     
       // Note: We may add support for weights (based on tree performance) later on.
    @@ -135,8 +145,15 @@ final class RandomForestClassificationModel private[ml] (
     
       override def treeWeights: Array[Double] = _treeWeights
     
    +  override protected def transformImpl(dataset: DataFrame): DataFrame = {
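+    // Same pattern as GBTClassificationModel: broadcast the forest once instead of serializing it with every task.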
    +    val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
    +    val predictUDF = udf { (features: Any) =>
    +      bcastModel.value.predict(features.asInstanceOf[Vector])
    +    }
    +    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
    +  }
    +
       override protected def predict(features: Vector): Double = {
    -    // TODO: Override transform() to broadcast model.  SPARK-7127
         // TODO: When we add a generic Bagging class, handle transform there: SPARK-7128
         // Classifies using majority votes.
         // Ignore the weights since all are 1.0 for now.
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
    new file mode 100644
    index 0000000000000..dc192add6ca13
    --- /dev/null
    +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
    @@ -0,0 +1,205 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.clustering
    +
    +import org.apache.spark.annotation.Experimental
    +import org.apache.spark.ml.param.{Param, Params, IntParam, DoubleParam, ParamMap}
    +import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasMaxIter, HasPredictionCol, HasSeed}
    +import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
    +import org.apache.spark.ml.{Estimator, Model}
    +import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel}
    +import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
    +import org.apache.spark.sql.functions.{col, udf}
    +import org.apache.spark.sql.types.{IntegerType, StructType}
    +import org.apache.spark.sql.{DataFrame, Row}
    +import org.apache.spark.util.Utils
    +
    +
    +/**
    + * Common params for KMeans and KMeansModel
    + */
    +private[clustering] trait KMeansParams
    +    extends Params with HasMaxIter with HasFeaturesCol with HasSeed with HasPredictionCol {
    +
    +  /**
+   * The number of clusters to create (k). Must be > 1. Default: 2.
    +   * @group param
    +   */
    +  final val k = new IntParam(this, "k", "number of clusters to create", (x: Int) => x > 1)
    +
    +  /** @group getParam */
    +  def getK: Int = $(k)
    +
    +  /**
+   * Param for the number of runs of the algorithm to execute in parallel. We initialize the algorithm
    +   * this many times with random starting conditions (configured by the initialization mode), then
    +   * return the best clustering found over any run. Must be >= 1. Default: 1.
    +   * @group param
    +   */
    +  final val runs = new IntParam(this, "runs",
    +    "number of runs of the algorithm to execute in parallel", (value: Int) => value >= 1)
    +
    +  /** @group getParam */
    +  def getRuns: Int = $(runs)
    +
    +  /**
+   * Param for the distance threshold within which we consider centers to have converged.
    +   * If all centers move less than this Euclidean distance, we stop iterating one run.
    +   * Must be >= 0.0. Default: 1e-4
    +   * @group param
    +   */
    +  final val epsilon = new DoubleParam(this, "epsilon",
    +    "distance threshold within which we've consider centers to have converge",
    +    (value: Double) => value >= 0.0)
    +
    +  /** @group getParam */
    +  def getEpsilon: Double = $(epsilon)
    +
    +  /**
    +   * Param for the initialization algorithm. This can be either "random" to choose random points as
    +   * initial cluster centers, or "k-means||" to use a parallel variant of k-means++
    +   * (Bahmani et al., Scalable K-Means++, VLDB 2012). Default: k-means||.
    +   * @group expertParam
    +   */
    +  final val initMode = new Param[String](this, "initMode", "initialization algorithm",
    +    (value: String) => MLlibKMeans.validateInitMode(value))
    +
    +  /** @group expertGetParam */
    +  def getInitMode: String = $(initMode)
    +
    +  /**
    +   * Param for the number of steps for the k-means|| initialization mode. This is an advanced
    +   * setting -- the default of 5 is almost always enough. Must be > 0. Default: 5.
    +   * @group expertParam
    +   */
    +  final val initSteps = new IntParam(this, "initSteps", "number of steps for k-means||",
    +    (value: Int) => value > 0)
    +
    +  /** @group expertGetParam */
    +  def getInitSteps: Int = $(initSteps)
    +
    +  /**
    +   * Validates and transforms the input schema.
    +   * @param schema input schema
    +   * @return output schema
    +   */
    +  protected def validateAndTransformSchema(schema: StructType): StructType = {
    +    SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
    +    SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType)
    +  }
    +}
    +
    +/**
    + * :: Experimental ::
    + * Model fitted by KMeans.
    + *
    + * @param parentModel a model trained by spark.mllib.clustering.KMeans.
    + */
    +@Experimental
    +class KMeansModel private[ml] (
    +    override val uid: String,
    +    private val parentModel: MLlibKMeansModel) extends Model[KMeansModel] with KMeansParams {
    +
    +  override def copy(extra: ParamMap): KMeansModel = {
    +    val copied = new KMeansModel(uid, parentModel)
    +    copyValues(copied, extra)
    +  }
    +
    +  override def transform(dataset: DataFrame): DataFrame = {
    +    val predictUDF = udf((vector: Vector) => predict(vector))
    +    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
    +  }
    +
    +  override def transformSchema(schema: StructType): StructType = {
    +    validateAndTransformSchema(schema)
    +  }
    +
    +  private[clustering] def predict(features: Vector): Int = parentModel.predict(features)
    +
    +  def clusterCenters: Array[Vector] = parentModel.clusterCenters
    +}
    +
    +/**
    + * :: Experimental ::
    + * K-means clustering with support for multiple parallel runs and a k-means++ like initialization
    + * mode (the k-means|| algorithm by Bahmani et al). When multiple concurrent runs are requested,
    + * they are executed together with joint passes over the data for efficiency.
    + */
    +@Experimental
    +class KMeans(override val uid: String) extends Estimator[KMeansModel] with KMeansParams {
    +
    +  setDefault(
    +    k -> 2,
    +    maxIter -> 20,
    +    runs -> 1,
    +    initMode -> MLlibKMeans.K_MEANS_PARALLEL,
    +    initSteps -> 5,
    +    epsilon -> 1e-4)
    +
    +  override def copy(extra: ParamMap): KMeans = defaultCopy(extra)
    +
    +  def this() = this(Identifiable.randomUID("kmeans"))
    +
    +  /** @group setParam */
    +  def setFeaturesCol(value: String): this.type = set(featuresCol, value)
    +
    +  /** @group setParam */
    +  def setPredictionCol(value: String): this.type = set(predictionCol, value)
    +
    +  /** @group setParam */
    +  def setK(value: Int): this.type = set(k, value)
    +
    +  /** @group expertSetParam */
    +  def setInitMode(value: String): this.type = set(initMode, value)
    +
    +  /** @group expertSetParam */
    +  def setInitSteps(value: Int): this.type = set(initSteps, value)
    +
    +  /** @group setParam */
    +  def setMaxIter(value: Int): this.type = set(maxIter, value)
    +
    +  /** @group setParam */
    +  def setRuns(value: Int): this.type = set(runs, value)
    +
    +  /** @group setParam */
    +  def setEpsilon(value: Double): this.type = set(epsilon, value)
    +
    +  /** @group setParam */
    +  def setSeed(value: Long): this.type = set(seed, value)
    +
    +  override def fit(dataset: DataFrame): KMeansModel = {
    +    val rdd = dataset.select(col($(featuresCol))).map { case Row(point: Vector) => point }
    +
    +    val algo = new MLlibKMeans()
    +      .setK($(k))
    +      .setInitializationMode($(initMode))
    +      .setInitializationSteps($(initSteps))
    +      .setMaxIterations($(maxIter))
    +      .setSeed($(seed))
    +      .setEpsilon($(epsilon))
    +      .setRuns($(runs))
    +    val parentModel = algo.run(rdd)
    +    val model = new KMeansModel(uid, parentModel)
    +    copyValues(model)
    +  }
    +
    +  override def transformSchema(schema: StructType): StructType = {
    +    validateAndTransformSchema(schema)
    +  }
    +}
    +
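A minimal usage sketch of the new estimator (illustrative only; it assumes a DataFrame named dataset with a Vector column "features", neither of which is part of this patch):

    import org.apache.spark.ml.clustering.KMeans

    val kmeans = new KMeans()
      .setK(3)
      .setMaxIter(10)
      .setFeaturesCol("features")
      .setPredictionCol("cluster")
    val model = kmeans.fit(dataset)            // delegates to spark.mllib's KMeans
    val clustered = model.transform(dataset)   // appends an IntegerType "cluster" column
    model.clusterCenters.foreach(println)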
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
    new file mode 100644
    index 0000000000000..f7b46efa10e90
    --- /dev/null
    +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
    @@ -0,0 +1,159 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.feature
    +
    +import scala.util.parsing.combinator.RegexParsers
    +
    +import org.apache.spark.annotation.Experimental
    +import org.apache.spark.ml.Transformer
    +import org.apache.spark.ml.param.{Param, ParamMap}
    +import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol}
    +import org.apache.spark.ml.util.Identifiable
    +import org.apache.spark.sql.DataFrame
    +import org.apache.spark.sql.functions._
    +import org.apache.spark.sql.types._
    +
    +/**
    + * :: Experimental ::
    + * Implements the transforms required for fitting a dataset against an R model formula. Currently
    + * we support a limited subset of the R operators, including '~' and '+'. Also see the R formula
    + * docs here: http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html
    + */
    +@Experimental
    +class RFormula(override val uid: String)
    +  extends Transformer with HasFeaturesCol with HasLabelCol {
    +
    +  def this() = this(Identifiable.randomUID("rFormula"))
    +
    +  /**
    +   * R formula parameter. The formula is provided in string form.
    +   * @group param
    +   */
    +  val formula: Param[String] = new Param(this, "formula", "R model formula")
    +
    +  private var parsedFormula: Option[ParsedRFormula] = None
    +
    +  /**
    +   * Sets the formula to use for this transformer. Must be called before use.
    +   * @group setParam
    +   * @param value an R formula in string form (e.g. "y ~ x + z")
    +   */
    +  def setFormula(value: String): this.type = {
    +    parsedFormula = Some(RFormulaParser.parse(value))
    +    set(formula, value)
    +    this
    +  }
    +
    +  /** @group getParam */
    +  def getFormula: String = $(formula)
    +
    +  /** @group setParam */
    +  def setFeaturesCol(value: String): this.type = set(featuresCol, value)
    +
    +  /** @group setParam */
    +  def setLabelCol(value: String): this.type = set(labelCol, value)
    +
    +  override def transformSchema(schema: StructType): StructType = {
    +    checkCanTransform(schema)
    +    val withFeatures = transformFeatures.transformSchema(schema)
    +    if (hasLabelCol(schema)) {
    +      withFeatures
    +    } else if (schema.exists(_.name == parsedFormula.get.label)) {
    +      val nullable = schema(parsedFormula.get.label).dataType match {
    +        case _: NumericType | BooleanType => false
    +        case _ => true
    +      }
    +      StructType(withFeatures.fields :+ StructField($(labelCol), DoubleType, nullable))
    +    } else {
    +      // Ignore the label field. This is a hack so that this transformer can also work on test
    +      // datasets in a Pipeline.
    +      withFeatures
    +    }
    +  }
    +
    +  override def transform(dataset: DataFrame): DataFrame = {
    +    checkCanTransform(dataset.schema)
    +    transformLabel(transformFeatures.transform(dataset))
    +  }
    +
    +  override def copy(extra: ParamMap): RFormula = defaultCopy(extra)
    +
    +  override def toString: String = s"RFormula(${get(formula)})"
    +
    +  private def transformLabel(dataset: DataFrame): DataFrame = {
    +    val labelName = parsedFormula.get.label
    +    if (hasLabelCol(dataset.schema)) {
    +      dataset
    +    } else if (dataset.schema.exists(_.name == labelName)) {
    +      dataset.schema(labelName).dataType match {
    +        case _: NumericType | BooleanType =>
    +          dataset.withColumn($(labelCol), dataset(labelName).cast(DoubleType))
    +        // TODO(ekl) add support for string-type labels
    +        case other =>
    +          throw new IllegalArgumentException("Unsupported type for label: " + other)
    +      }
    +    } else {
    +      // Ignore the label field. This is a hack so that this transformer can also work on test
    +      // datasets in a Pipeline.
    +      dataset
    +    }
    +  }
    +
    +  private def transformFeatures: Transformer = {
    +    // TODO(ekl) add support for non-numeric features and feature interactions
    +    new VectorAssembler(uid)
    +      .setInputCols(parsedFormula.get.terms.toArray)
    +      .setOutputCol($(featuresCol))
    +  }
    +
    +  private def checkCanTransform(schema: StructType) {
    +    require(parsedFormula.isDefined, "Must call setFormula() first.")
    +    val columnNames = schema.map(_.name)
    +    require(!columnNames.contains($(featuresCol)), "Features column already exists.")
    +    require(
    +      !columnNames.contains($(labelCol)) || schema($(labelCol)).dataType == DoubleType,
    +      "Label column already exists and is not of type DoubleType.")
    +  }
    +
    +  private def hasLabelCol(schema: StructType): Boolean = {
    +    schema.map(_.name).contains($(labelCol))
    +  }
    +}
    +
    +/**
    + * Represents a parsed R formula.
    + */
    +private[ml] case class ParsedRFormula(label: String, terms: Seq[String])
    +
    +/**
    + * Limited implementation of R formula parsing. Currently supports: '~', '+'.
    + */
    +private[ml] object RFormulaParser extends RegexParsers {
    +  def term: Parser[String] = "([a-zA-Z]|\\.[a-zA-Z_])[a-zA-Z0-9._]*".r
    +
    +  def expr: Parser[List[String]] = term ~ rep("+" ~> term) ^^ { case a ~ list => a :: list }
    +
    +  def formula: Parser[ParsedRFormula] =
    +    (term ~ "~" ~ expr) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t) }
    +
    +  def parse(value: String): ParsedRFormula = parseAll(formula, value) match {
    +    case Success(result, _) => result
    +    case failure: NoSuccess => throw new IllegalArgumentException(
    +      "Could not parse formula: " + value)
    +  }
    +}
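A sketch of the intended usage (assuming a DataFrame df with numeric columns y, x and z; the names are illustrative):

    import org.apache.spark.ml.feature.RFormula

    val formula = new RFormula().setFormula("y ~ x + z")  // setFormula must be called before use
    val output = formula.transform(df)
    // output gains a "features" Vector column assembled from x and z (via VectorAssembler)
    // and a DoubleType "label" column cast from y.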
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
    index 5f9f57a2ebcfa..0b3af4747e693 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
    @@ -42,7 +42,7 @@ class Tokenizer(override val uid: String) extends UnaryTransformer[String, Seq[S
         require(inputType == StringType, s"Input type must be string type but got $inputType.")
       }
     
    -  override protected def outputDataType: DataType = new ArrayType(StringType, false)
    +  override protected def outputDataType: DataType = new ArrayType(StringType, true)
     
       override def copy(extra: ParamMap): Tokenizer = defaultCopy(extra)
     }
    @@ -113,7 +113,7 @@ class RegexTokenizer(override val uid: String)
         require(inputType == StringType, s"Input type must be string type but got $inputType.")
       }
     
    -  override protected def outputDataType: DataType = new ArrayType(StringType, false)
    +  override protected def outputDataType: DataType = new ArrayType(StringType, true)
     
       override def copy(extra: ParamMap): RegexTokenizer = defaultCopy(extra)
     }
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
    index 9f83c2ee16178..086917fa680f8 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
    @@ -116,7 +116,7 @@ class VectorAssembler(override val uid: String)
         if (schema.fieldNames.contains(outputColName)) {
           throw new IllegalArgumentException(s"Output column $outputColName already exists.")
         }
    -    StructType(schema.fields :+ new StructField(outputColName, new VectorUDT, false))
    +    StructType(schema.fields :+ new StructField(outputColName, new VectorUDT, true))
       }
     
       override def copy(extra: ParamMap): VectorAssembler = defaultCopy(extra)
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
    index d034d7ec6b60e..954aa17e26a02 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
    @@ -295,6 +295,22 @@ class DoubleArrayParam(parent: Params, name: String, doc: String, isValid: Array
         w(value.asScala.map(_.asInstanceOf[Double]).toArray)
     }
     
    +/**
    + * :: DeveloperApi ::
    + * Specialized version of [[Param[Array[Int]]]] for Java.
    + */
    +@DeveloperApi
    +class IntArrayParam(parent: Params, name: String, doc: String, isValid: Array[Int] => Boolean)
    +  extends Param[Array[Int]](parent, name, doc, isValid) {
    +
    +  def this(parent: Params, name: String, doc: String) =
    +    this(parent, name, doc, ParamValidators.alwaysTrue)
    +
    +  /** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */
    +  def w(value: java.util.List[java.lang.Integer]): ParamPair[Array[Int]] =
    +    w(value.asScala.map(_.asInstanceOf[Int]).toArray)
    +}
    +
     /**
      * :: Experimental ::
      * A param and its value.
    @@ -460,11 +476,14 @@ trait Params extends Identifiable with Serializable {
       /**
        * Sets default values for a list of params.
        *
    +   * Note: Java developers should use the single-parameter [[setDefault()]].
    +   *       Annotating this with varargs can cause compilation failures due to a Scala compiler bug.
    +   *       See SPARK-9268.
    +   *
        * @param paramPairs  a list of param pairs that specify params and their default values to set
        *                    respectively. Make sure that the params are initialized before this method
        *                    gets called.
        */
    -  @varargs
       protected final def setDefault(paramPairs: ParamPair[_]*): this.type = {
         paramPairs.foreach { p =>
           setDefault(p.param.asInstanceOf[Param[Any]], p.value)
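A sketch of declaring the new IntArrayParam and giving it a default inside a Params subclass (the trait and param names here are hypothetical, not part of this patch):

    import org.apache.spark.ml.param.{IntArrayParam, Params}

    trait HasLayerSizes extends Params {
      // Hypothetical param: sizes of hidden layers, each of which must be positive.
      final val layerSizes: IntArrayParam = new IntArrayParam(this, "layerSizes",
        "sizes of hidden layers", (sizes: Array[Int]) => sizes.forall(_ > 0))

      // Per the note above, prefer the single-pair setDefault from Java-friendly code paths.
      setDefault(layerSizes, Array(4, 2))

      def getLayerSizes: Array[Int] = $(layerSizes)
    }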
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
    index 66b751a1b02ee..f7ae1de522e01 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
    @@ -134,7 +134,7 @@ private[shared] object SharedParamsCodeGen {
     
         s"""
           |/**
    -      | * (private[ml]) Trait for shared param $name$defaultValueDoc.
    +      | * Trait for shared param $name$defaultValueDoc.
           | */
           |private[ml] trait Has$Name extends Params {
           |
    @@ -173,7 +173,6 @@ private[shared] object SharedParamsCodeGen {
             |package org.apache.spark.ml.param.shared
             |
             |import org.apache.spark.ml.param._
    -        |import org.apache.spark.util.Utils
             |
             |// DO NOT MODIFY THIS FILE! It was generated by SharedParamsCodeGen.
             |
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
    index f81bd76c22376..65e48e4ee5083 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
    @@ -18,14 +18,13 @@
     package org.apache.spark.ml.param.shared
     
     import org.apache.spark.ml.param._
    -import org.apache.spark.util.Utils
     
     // DO NOT MODIFY THIS FILE! It was generated by SharedParamsCodeGen.
     
     // scalastyle:off
     
     /**
    - * (private[ml]) Trait for shared param regParam.
    + * Trait for shared param regParam.
      */
     private[ml] trait HasRegParam extends Params {
     
    @@ -40,7 +39,7 @@ private[ml] trait HasRegParam extends Params {
     }
     
     /**
    - * (private[ml]) Trait for shared param maxIter.
    + * Trait for shared param maxIter.
      */
     private[ml] trait HasMaxIter extends Params {
     
    @@ -55,7 +54,7 @@ private[ml] trait HasMaxIter extends Params {
     }
     
     /**
    - * (private[ml]) Trait for shared param featuresCol (default: "features").
    + * Trait for shared param featuresCol (default: "features").
      */
     private[ml] trait HasFeaturesCol extends Params {
     
    @@ -72,7 +71,7 @@ private[ml] trait HasFeaturesCol extends Params {
     }
     
     /**
    - * (private[ml]) Trait for shared param labelCol (default: "label").
    + * Trait for shared param labelCol (default: "label").
      */
     private[ml] trait HasLabelCol extends Params {
     
    @@ -89,7 +88,7 @@ private[ml] trait HasLabelCol extends Params {
     }
     
     /**
    - * (private[ml]) Trait for shared param predictionCol (default: "prediction").
    + * Trait for shared param predictionCol (default: "prediction").
      */
     private[ml] trait HasPredictionCol extends Params {
     
    @@ -106,7 +105,7 @@ private[ml] trait HasPredictionCol extends Params {
     }
     
     /**
    - * (private[ml]) Trait for shared param rawPredictionCol (default: "rawPrediction").
    + * Trait for shared param rawPredictionCol (default: "rawPrediction").
      */
     private[ml] trait HasRawPredictionCol extends Params {
     
    @@ -123,7 +122,7 @@ private[ml] trait HasRawPredictionCol extends Params {
     }
     
     /**
    - * (private[ml]) Trait for shared param probabilityCol (default: "probability").
    + * Trait for shared param probabilityCol (default: "probability").
      */
     private[ml] trait HasProbabilityCol extends Params {
     
    @@ -140,7 +139,7 @@ private[ml] trait HasProbabilityCol extends Params {
     }
     
     /**
    - * (private[ml]) Trait for shared param threshold.
    + * Trait for shared param threshold.
      */
     private[ml] trait HasThreshold extends Params {
     
    @@ -155,7 +154,7 @@ private[ml] trait HasThreshold extends Params {
     }
     
     /**
    - * (private[ml]) Trait for shared param inputCol.
    + * Trait for shared param inputCol.
      */
     private[ml] trait HasInputCol extends Params {
     
    @@ -170,7 +169,7 @@ private[ml] trait HasInputCol extends Params {
     }
     
     /**
    - * (private[ml]) Trait for shared param inputCols.
    + * Trait for shared param inputCols.
      */
     private[ml] trait HasInputCols extends Params {
     
    @@ -185,7 +184,7 @@ private[ml] trait HasInputCols extends Params {
     }
     
     /**
    - * (private[ml]) Trait for shared param outputCol (default: uid + "__output").
    + * Trait for shared param outputCol (default: uid + "__output").
      */
     private[ml] trait HasOutputCol extends Params {
     
    @@ -202,7 +201,7 @@ private[ml] trait HasOutputCol extends Params {
     }
     
     /**
    - * (private[ml]) Trait for shared param checkpointInterval.
    + * Trait for shared param checkpointInterval.
      */
     private[ml] trait HasCheckpointInterval extends Params {
     
    @@ -217,7 +216,7 @@ private[ml] trait HasCheckpointInterval extends Params {
     }
     
     /**
    - * (private[ml]) Trait for shared param fitIntercept (default: true).
    + * Trait for shared param fitIntercept (default: true).
      */
     private[ml] trait HasFitIntercept extends Params {
     
    @@ -234,7 +233,7 @@ private[ml] trait HasFitIntercept extends Params {
     }
     
     /**
    - * (private[ml]) Trait for shared param standardization (default: true).
    + * Trait for shared param standardization (default: true).
      */
     private[ml] trait HasStandardization extends Params {
     
    @@ -251,7 +250,7 @@ private[ml] trait HasStandardization extends Params {
     }
     
     /**
    - * (private[ml]) Trait for shared param seed (default: this.getClass.getName.hashCode.toLong).
    + * Trait for shared param seed (default: this.getClass.getName.hashCode.toLong).
      */
     private[ml] trait HasSeed extends Params {
     
    @@ -268,7 +267,7 @@ private[ml] trait HasSeed extends Params {
     }
     
     /**
    - * (private[ml]) Trait for shared param elasticNetParam.
    + * Trait for shared param elasticNetParam.
      */
     private[ml] trait HasElasticNetParam extends Params {
     
    @@ -283,7 +282,7 @@ private[ml] trait HasElasticNetParam extends Params {
     }
     
     /**
    - * (private[ml]) Trait for shared param tol.
    + * Trait for shared param tol.
      */
     private[ml] trait HasTol extends Params {
     
    @@ -298,7 +297,7 @@ private[ml] trait HasTol extends Params {
     }
     
     /**
    - * (private[ml]) Trait for shared param stepSize.
    + * Trait for shared param stepSize.
      */
     private[ml] trait HasStepSize extends Params {
     
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
    new file mode 100644
    index 0000000000000..1ee080641e3e3
    --- /dev/null
    +++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
    @@ -0,0 +1,41 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.api.r
    +
    +import org.apache.spark.ml.feature.RFormula
    +import org.apache.spark.ml.classification.LogisticRegression
    +import org.apache.spark.ml.regression.LinearRegression
    +import org.apache.spark.ml.{Pipeline, PipelineModel}
    +import org.apache.spark.sql.DataFrame
    +
    +private[r] object SparkRWrappers {
    +  def fitRModelFormula(
    +      value: String,
    +      df: DataFrame,
    +      family: String,
    +      lambda: Double,
    +      alpha: Double): PipelineModel = {
    +    val formula = new RFormula().setFormula(value)
    +    val estimator = family match {
    +      case "gaussian" => new LinearRegression().setRegParam(lambda).setElasticNetParam(alpha)
    +      case "binomial" => new LogisticRegression().setRegParam(lambda).setElasticNetParam(alpha)
    +    }
    +    val pipeline = new Pipeline().setStages(Array(formula, estimator))
    +    pipeline.fit(df)
    +  }
    +}
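This is the JVM entry point that SparkR's new glm() wrapper presumably reaches over the R-to-JVM bridge. A direct Scala call would look roughly like this (the object is private[r], so this is conceptual rather than a compilable external call; DataFrame and column names are illustrative):

    import org.apache.spark.ml.api.r.SparkRWrappers

    // training is assumed to be a DataFrame with numeric columns "mpg" and "wt".
    val model = SparkRWrappers.fitRModelFormula(
      "mpg ~ wt", training, family = "gaussian", lambda = 0.0, alpha = 0.0)
    val predictions = model.transform(training)  // adds a "prediction" column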
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
    index be1f8063d41d8..6f3340c2f02be 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
    @@ -21,10 +21,10 @@ import org.apache.spark.annotation.Experimental
     import org.apache.spark.ml.{PredictionModel, Predictor}
     import org.apache.spark.ml.param.ParamMap
     import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeRegressorParams}
    +import org.apache.spark.ml.tree.impl.RandomForest
     import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
     import org.apache.spark.mllib.linalg.Vector
     import org.apache.spark.mllib.regression.LabeledPoint
    -import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree}
     import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
     import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
     import org.apache.spark.rdd.RDD
    @@ -67,8 +67,9 @@ final class DecisionTreeRegressor(override val uid: String)
           MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
         val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
         val strategy = getOldStrategy(categoricalFeatures)
    -    val oldModel = OldDecisionTree.train(oldDataset, strategy)
    -    DecisionTreeRegressionModel.fromOld(oldModel, this, categoricalFeatures)
    +    val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
    +      seed = 0L, parentUID = Some(uid))
    +    trees.head.asInstanceOf[DecisionTreeRegressionModel]
       }
     
       /** (private[ml]) Create a Strategy instance to use with the old API. */
    @@ -102,6 +103,12 @@ final class DecisionTreeRegressionModel private[ml] (
       require(rootNode != null,
         "DecisionTreeClassificationModel given null rootNode, but it requires a non-null rootNode.")
     
    +  /**
    +   * Construct a decision tree regression model.
    +   * @param rootNode  Root node of tree, with other nodes attached.
    +   */
    +  def this(rootNode: Node) = this(Identifiable.randomUID("dtr"), rootNode)
    +
       override protected def predict(features: Vector): Double = {
         rootNode.predict(features)
       }
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
    index 47c110d027d67..e38dc73ee0ba7 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
    @@ -33,6 +33,8 @@ import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, Loss
     import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
     import org.apache.spark.rdd.RDD
     import org.apache.spark.sql.DataFrame
    +import org.apache.spark.sql.functions._
    +import org.apache.spark.sql.types.DoubleType
     
     /**
      * :: Experimental ::
    @@ -167,8 +169,15 @@ final class GBTRegressionModel(
     
       override def treeWeights: Array[Double] = _treeWeights
     
    +  override protected def transformImpl(dataset: DataFrame): DataFrame = {
    +    val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
    +    val predictUDF = udf { (features: Any) =>
    +      bcastModel.value.predict(features.asInstanceOf[Vector])
    +    }
    +    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
    +  }
    +
       override protected def predict(features: Vector): Double = {
    -    // TODO: Override transform() to broadcast model. SPARK-7127
         // TODO: When we add a generic Boosting class, handle transform there?  SPARK-7129
         // Classifies by thresholding sum of weighted tree predictions
         val treePredictions = _trees.map(_.rootNode.predict(features))
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
    index 8fc986056657d..89718e0f3e15a 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
    @@ -355,9 +355,9 @@ class LinearRegressionSummary private[regression] (
        */
       val r2: Double = metrics.r2
     
    -  /** Residuals (predicted value - label value) */
    +  /** Residuals (label - predicted value) */
       @transient lazy val residuals: DataFrame = {
    -    val t = udf { (pred: Double, label: Double) => pred - label}
    +    val t = udf { (pred: Double, label: Double) => label - pred }
         predictions.select(t(col(predictionCol), col(labelCol)).as("residuals"))
       }
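With the corrected sign, residuals follow the usual observed-minus-fitted convention; for instance:

    // label = 3.0, prediction = 2.5  =>  residual = 3.0 - 2.5 = 0.5 (previously reported as -0.5)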
     
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
    index 21c59061a02fa..506a878c2553b 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
    @@ -21,14 +21,16 @@ import org.apache.spark.annotation.Experimental
     import org.apache.spark.ml.{PredictionModel, Predictor}
     import org.apache.spark.ml.param.ParamMap
     import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeEnsembleModel, TreeRegressorParams}
    +import org.apache.spark.ml.tree.impl.RandomForest
     import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
     import org.apache.spark.mllib.linalg.Vector
     import org.apache.spark.mllib.regression.LabeledPoint
    -import org.apache.spark.mllib.tree.{RandomForest => OldRandomForest}
     import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
     import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
     import org.apache.spark.rdd.RDD
     import org.apache.spark.sql.DataFrame
    +import org.apache.spark.sql.functions._
    +import org.apache.spark.sql.types.DoubleType
     
     /**
      * :: Experimental ::
    @@ -82,9 +84,10 @@ final class RandomForestRegressor(override val uid: String)
         val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
         val strategy =
           super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity)
    -    val oldModel = OldRandomForest.trainRegressor(
    -      oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed.toInt)
    -    RandomForestRegressionModel.fromOld(oldModel, this, categoricalFeatures)
    +    val trees =
    +      RandomForest.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed)
    +        .map(_.asInstanceOf[DecisionTreeRegressionModel])
    +    new RandomForestRegressionModel(trees)
       }
     
       override def copy(extra: ParamMap): RandomForestRegressor = defaultCopy(extra)
    @@ -115,6 +118,12 @@ final class RandomForestRegressionModel private[ml] (
     
       require(numTrees > 0, "RandomForestRegressionModel requires at least 1 tree.")
     
    +  /**
    +   * Construct a random forest regression model, with all trees weighted equally.
    +   * @param trees  Component trees
    +   */
    +  def this(trees: Array[DecisionTreeRegressionModel]) = this(Identifiable.randomUID("rfr"), trees)
    +
       override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
     
       // Note: We may add support for weights (based on tree performance) later on.
    @@ -122,8 +131,15 @@ final class RandomForestRegressionModel private[ml] (
     
       override def treeWeights: Array[Double] = _treeWeights
     
    +  override protected def transformImpl(dataset: DataFrame): DataFrame = {
    +    val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
    +    val predictUDF = udf { (features: Any) =>
    +      bcastModel.value.predict(features.asInstanceOf[Vector])
    +    }
    +    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
    +  }
    +
       override protected def predict(features: Vector): Double = {
    -    // TODO: Override transform() to broadcast model.  SPARK-7127
         // TODO: When we add a generic Bagging class, handle transform there.  SPARK-7128
         // Predict average of tree predictions.
         // Ignore the weights since all are 1.0 for now.
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
    index 4242154be14ce..bbc2427ca7d3d 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
    @@ -209,3 +209,132 @@ private object InternalNode {
         }
       }
     }
    +
    +/**
    + * Version of a node used in learning.  This uses vars so that we can modify nodes as we split the
    + * tree by adding children, etc.
    + *
    + * For now, we use node IDs.  These will be kept internal since we hope to remove node IDs
    + * in the future, or at least change the indexing (so that we can support much deeper trees).
    + *
    + * This node can either be:
    + *  - a leaf node, with leftChild, rightChild, split set to null, or
    + *  - an internal node, with all values set
    + *
    + * @param id  We currently use the same indexing as the old implementation in
    + *            [[org.apache.spark.mllib.tree.model.Node]], but this will change later.
    + * @param predictionStats  Predicted label + class probability (for classification).
    + *                         We will later modify this to store aggregate statistics for labels
    + *                         to provide all class probabilities (for classification) and maybe a
    + *                         distribution (for regression).
    + * @param isLeaf  Indicates whether this node will definitely be a leaf in the learned tree,
    + *                so that we do not need to consider splitting it further.
    + * @param stats  Old structure for storing stats about information gain, prediction, etc.
    + *               This is legacy and will be modified in the future.
    + */
    +private[tree] class LearningNode(
    +    var id: Int,
    +    var predictionStats: OldPredict,
    +    var impurity: Double,
    +    var leftChild: Option[LearningNode],
    +    var rightChild: Option[LearningNode],
    +    var split: Option[Split],
    +    var isLeaf: Boolean,
    +    var stats: Option[OldInformationGainStats]) extends Serializable {
    +
    +  /**
    +   * Convert this [[LearningNode]] to a regular [[Node]], and recurse on any children.
    +   */
    +  def toNode: Node = {
    +    if (leftChild.nonEmpty) {
    +      assert(rightChild.nonEmpty && split.nonEmpty && stats.nonEmpty,
    +        "Unknown error during Decision Tree learning.  Could not convert LearningNode to Node.")
    +      new InternalNode(predictionStats.predict, impurity, stats.get.gain,
    +        leftChild.get.toNode, rightChild.get.toNode, split.get)
    +    } else {
    +      new LeafNode(predictionStats.predict, impurity)
    +    }
    +  }
    +
    +}
    +
    +private[tree] object LearningNode {
    +
    +  /** Create a node with some of its fields set. */
    +  def apply(
    +      id: Int,
    +      predictionStats: OldPredict,
    +      impurity: Double,
    +      isLeaf: Boolean): LearningNode = {
    +    new LearningNode(id, predictionStats, impurity, None, None, None, false, None)
    +  }
    +
    +  /** Create an empty node with the given node index.  Values must be set later on. */
    +  def emptyNode(nodeIndex: Int): LearningNode = {
    +    new LearningNode(nodeIndex, new OldPredict(Double.NaN, Double.NaN), Double.NaN,
    +      None, None, None, false, None)
    +  }
    +
    +  // The below indexing methods were copied from spark.mllib.tree.model.Node
    +
    +  /**
    +   * Return the index of the left child of this node.
    +   */
    +  def leftChildIndex(nodeIndex: Int): Int = nodeIndex << 1
    +
    +  /**
    +   * Return the index of the right child of this node.
    +   */
    +  def rightChildIndex(nodeIndex: Int): Int = (nodeIndex << 1) + 1
    +
    +  /**
    +   * Get the parent index of the given node, or 0 if it is the root.
    +   */
    +  def parentIndex(nodeIndex: Int): Int = nodeIndex >> 1
    +
    +  /**
    +   * Return the level of the tree that the given node is in.
    +   */
    +  def indexToLevel(nodeIndex: Int): Int = if (nodeIndex == 0) {
    +    throw new IllegalArgumentException(s"0 is not a valid node index.")
    +  } else {
    +    java.lang.Integer.numberOfTrailingZeros(java.lang.Integer.highestOneBit(nodeIndex))
    +  }
    +
    +  /**
    +   * Returns true if this is a left child.
    +   * Note: Returns false for the root.
    +   */
    +  def isLeftChild(nodeIndex: Int): Boolean = nodeIndex > 1 && nodeIndex % 2 == 0
    +
    +  /**
    +   * Return the maximum number of nodes which can be in the given level of the tree.
    +   * @param level  Level of tree (0 = root).
    +   */
    +  def maxNodesInLevel(level: Int): Int = 1 << level
    +
    +  /**
    +   * Return the index of the first node in the given level.
    +   * @param level  Level of tree (0 = root).
    +   */
    +  def startIndexInLevel(level: Int): Int = 1 << level
    +
    +  /**
    +   * Traces down from a root node to get the node with the given node index.
    +   * This assumes the node exists.
    +   */
    +  def getNode(nodeIndex: Int, rootNode: LearningNode): LearningNode = {
    +    var tmpNode: LearningNode = rootNode
    +    var levelsToGo = indexToLevel(nodeIndex)
    +    while (levelsToGo > 0) {
    +      if ((nodeIndex & (1 << levelsToGo - 1)) == 0) {
    +        tmpNode = tmpNode.leftChild.get
    +      } else {
    +        tmpNode = tmpNode.rightChild.get
    +      }
    +      levelsToGo -= 1
    +    }
    +    tmpNode
    +  }
    +
    +}
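The indexing helpers above implement the standard 1-based binary-heap layout; a few concrete values make the arithmetic easy to check (LearningNode is private[tree], so these calls are purely illustrative):

    LearningNode.leftChildIndex(1)     // 2   (root's left child)
    LearningNode.rightChildIndex(1)    // 3
    LearningNode.leftChildIndex(2)     // 4
    LearningNode.parentIndex(5)        // 2
    LearningNode.indexToLevel(1)       // 0   (the root is at level 0)
    LearningNode.indexToLevel(4)       // 2
    LearningNode.isLeftChild(6)        // true, since 6 == leftChildIndex(3)
    LearningNode.maxNodesInLevel(3)    // 8
    LearningNode.startIndexInLevel(3)  // 8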
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
    index 7acdeeee72d23..78199cc2df582 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
    @@ -34,9 +34,19 @@ sealed trait Split extends Serializable {
       /** Index of feature which this split tests */
       def featureIndex: Int
     
    -  /** Return true (split to left) or false (split to right) */
    +  /**
    +   * Return true (split to left) or false (split to right).
    +   * @param features  Vector of features (original values, not binned).
    +   */
       private[ml] def shouldGoLeft(features: Vector): Boolean
     
    +  /**
    +   * Return true (split to left) or false (split to right).
    +   * @param binnedFeature Binned feature value.
    +   * @param splits All splits for the given feature.
    +   */
    +  private[tree] def shouldGoLeft(binnedFeature: Int, splits: Array[Split]): Boolean
    +
       /** Convert to old Split format */
       private[tree] def toOld: OldSplit
     }
    @@ -94,6 +104,14 @@ final class CategoricalSplit private[ml] (
         }
       }
     
    +  override private[tree] def shouldGoLeft(binnedFeature: Int, splits: Array[Split]): Boolean = {
    +    if (isLeft) {
    +      categories.contains(binnedFeature.toDouble)
    +    } else {
    +      !categories.contains(binnedFeature.toDouble)
    +    }
    +  }
    +
       override def equals(o: Any): Boolean = {
         o match {
           case other: CategoricalSplit => featureIndex == other.featureIndex &&
    @@ -144,6 +162,16 @@ final class ContinuousSplit private[ml] (override val featureIndex: Int, val thr
         features(featureIndex) <= threshold
       }
     
    +  override private[tree] def shouldGoLeft(binnedFeature: Int, splits: Array[Split]): Boolean = {
    +    if (binnedFeature == splits.length) {
    +      // > last split, so split right
    +      false
    +    } else {
    +      val featureValueUpperBound = splits(binnedFeature).asInstanceOf[ContinuousSplit].threshold
    +      featureValueUpperBound <= threshold
    +    }
    +  }
    +
       override def equals(o: Any): Boolean = {
         o match {
           case other: ContinuousSplit =>
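A concrete reading of the binned variant for continuous features (the numbers are illustrative): suppose the candidate thresholds for this feature are -1.0, 0.0 and 2.5, and this split's threshold is 0.0.

    // splits(i).threshold is the upper bound of bin i, so:
    //   binnedFeature = 0  ->  upper bound -1.0 <= 0.0  ->  go left
    //   binnedFeature = 1  ->  upper bound  0.0 <= 0.0  ->  go left
    //   binnedFeature = 2  ->  upper bound  2.5 >  0.0  ->  go right
    //   binnedFeature = 3  ->  equals splits.length, i.e. beyond the last threshold -> go right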
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala
    new file mode 100644
    index 0000000000000..488e8e4fb5dcd
    --- /dev/null
    +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala
    @@ -0,0 +1,194 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.tree.impl
    +
    +import java.io.IOException
    +
    +import scala.collection.mutable
    +
    +import org.apache.hadoop.fs.{Path, FileSystem}
    +
    +import org.apache.spark.Logging
    +import org.apache.spark.annotation.DeveloperApi
    +import org.apache.spark.ml.tree.{LearningNode, Split}
    +import org.apache.spark.mllib.tree.impl.BaggedPoint
    +import org.apache.spark.rdd.RDD
    +import org.apache.spark.storage.StorageLevel
    +
    +
    +/**
    + * This is used by the node id cache to find the child id that a data point would belong to.
    + * @param split Split information.
    + * @param nodeIndex The current node index of a data point that this will update.
    + */
    +private[tree] case class NodeIndexUpdater(split: Split, nodeIndex: Int) {
    +
    +  /**
    +   * Determine a child node index based on the feature value and the split.
    +   * @param binnedFeature Binned feature value.
    +   * @param splits Split information to convert the bin indices to approximate feature values.
    +   * @return Child node index to update to.
    +   */
    +  def updateNodeIndex(binnedFeature: Int, splits: Array[Split]): Int = {
    +    if (split.shouldGoLeft(binnedFeature, splits)) {
    +      LearningNode.leftChildIndex(nodeIndex)
    +    } else {
    +      LearningNode.rightChildIndex(nodeIndex)
    +    }
    +  }
    +}
    +
    +/**
    + * Each TreePoint belongs to a particular node per tree.
    + * Each row in the nodeIdsForInstances RDD is an array, indexed by tree, of the node index
    + * that the instance currently belongs to in that tree. Initially, all values should be 1,
    + * i.e. the root node. The nodeIdsForInstances RDD needs to be updated at each iteration.
    + * @param nodeIdsForInstances The initial values in the cache
    + *                            (should be an array of all 1's, meaning the root nodes).
    + * @param checkpointInterval The checkpointing interval
    + *                           (how often the cache should be checkpointed).
    + */
    +private[spark] class NodeIdCache(
    +  var nodeIdsForInstances: RDD[Array[Int]],
    +  val checkpointInterval: Int) extends Logging {
    +
    +  // Keep a reference to a previous node Ids for instances.
    +  // Because we will keep on re-persisting updated node Ids,
    +  // we want to unpersist the previous RDD.
    +  private var prevNodeIdsForInstances: RDD[Array[Int]] = null
    +
    +  // To keep track of the past checkpointed RDDs.
    +  private val checkpointQueue = mutable.Queue[RDD[Array[Int]]]()
    +  private var rddUpdateCount = 0
    +
    +  // Indicates whether we can checkpoint
    +  private val canCheckpoint = nodeIdsForInstances.sparkContext.getCheckpointDir.nonEmpty
    +
    +  // FileSystem instance for deleting checkpoints as needed
    +  private val fs = FileSystem.get(nodeIdsForInstances.sparkContext.hadoopConfiguration)
    +
    +  /**
    +   * Update the node index values in the cache.
    +   * This updates the RDD and its lineage.
    +   * TODO: Passing bin information to executors seems unnecessary and costly.
    +   * @param data The RDD of training rows.
    +   * @param nodeIdUpdaters A map of node index updaters.
    +   *                       The key is the indices of nodes that we want to update.
    +   * @param splits  Split information needed to find child node indices.
    +   */
    +  def updateNodeIndices(
    +      data: RDD[BaggedPoint[TreePoint]],
    +      nodeIdUpdaters: Array[mutable.Map[Int, NodeIndexUpdater]],
    +      splits: Array[Array[Split]]): Unit = {
    +    if (prevNodeIdsForInstances != null) {
    +      // Unpersist the previous one if one exists.
    +      prevNodeIdsForInstances.unpersist()
    +    }
    +
    +    prevNodeIdsForInstances = nodeIdsForInstances
    +    nodeIdsForInstances = data.zip(nodeIdsForInstances).map { case (point, ids) =>
    +      var treeId = 0
    +      while (treeId < nodeIdUpdaters.length) {
    +        val nodeIdUpdater = nodeIdUpdaters(treeId).getOrElse(ids(treeId), null)
    +        if (nodeIdUpdater != null) {
    +          val featureIndex = nodeIdUpdater.split.featureIndex
    +          val newNodeIndex = nodeIdUpdater.updateNodeIndex(
    +            binnedFeature = point.datum.binnedFeatures(featureIndex),
    +            splits = splits(featureIndex))
    +          ids(treeId) = newNodeIndex
    +        }
    +        treeId += 1
    +      }
    +      ids
    +    }
    +
    +    // Keep on persisting new ones.
    +    nodeIdsForInstances.persist(StorageLevel.MEMORY_AND_DISK)
    +    rddUpdateCount += 1
    +
    +    // Handle checkpointing if the directory is not None.
    +    if (canCheckpoint && (rddUpdateCount % checkpointInterval) == 0) {
    +      // Let's see if we can delete previous checkpoints.
    +      var canDelete = true
    +      while (checkpointQueue.size > 1 && canDelete) {
    +        // We can delete the oldest checkpoint iff
    +        // the next checkpoint actually exists in the file system.
    +        if (checkpointQueue(1).getCheckpointFile.isDefined) {
    +          val old = checkpointQueue.dequeue()
    +          // Since the old checkpoint is not deleted by Spark, we'll manually delete it here.
    +          try {
    +            fs.delete(new Path(old.getCheckpointFile.get), true)
    +          } catch {
    +            case e: IOException =>
    +              logError("Decision Tree learning using cacheNodeIds failed to remove checkpoint" +
    +                s" file: ${old.getCheckpointFile.get}")
    +          }
    +        } else {
    +          canDelete = false
    +        }
    +      }
    +
    +      nodeIdsForInstances.checkpoint()
    +      checkpointQueue.enqueue(nodeIdsForInstances)
    +    }
    +  }
    +
    +  /**
    +   * Call this after training is finished to delete any remaining checkpoints.
    +   */
    +  def deleteAllCheckpoints(): Unit = {
    +    while (checkpointQueue.nonEmpty) {
    +      val old = checkpointQueue.dequeue()
    +      if (old.getCheckpointFile.isDefined) {
    +        try {
    +          fs.delete(new Path(old.getCheckpointFile.get), true)
    +        } catch {
    +          case e: IOException =>
    +            logError("Decision Tree learning using cacheNodeIds failed to remove checkpoint" +
    +              s" file: ${old.getCheckpointFile.get}")
    +        }
    +      }
    +    }
    +    if (prevNodeIdsForInstances != null) {
    +      // Unpersist the previous one if one exists.
    +      prevNodeIdsForInstances.unpersist()
    +    }
    +  }
    +}
    +
    +@DeveloperApi
    +private[spark] object NodeIdCache {
    +  /**
    +   * Initialize the node Id cache with initial node Id values.
    +   * @param data The RDD of training rows.
    +   * @param numTrees The number of trees that we want to create cache for.
    +   * @param checkpointInterval The checkpointing interval
    +   *                           (how often the cache should be checkpointed).
    +   * @param initVal The initial values in the cache.
    +   * @return A node Id cache containing an RDD of initial root node indices.
    +   */
    +  def init(
    +      data: RDD[BaggedPoint[TreePoint]],
    +      numTrees: Int,
    +      checkpointInterval: Int,
    +      initVal: Int = 1): NodeIdCache = {
    +    new NodeIdCache(
    +      data.map(_ => Array.fill[Int](numTrees)(initVal)),
    +      checkpointInterval)
    +  }
    +}
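The intended lifecycle, in rough strokes (all of these types are package-private, and baggedInput, nodeIdUpdaters and splits are assumed to be built by the caller, as RandomForest does below):

    // Build the cache once: every instance starts at the root (node index 1) of every tree.
    val nodeIdCache = NodeIdCache.init(baggedInput, numTrees = 10, checkpointInterval = 10)

    // After choosing splits for a group of nodes, record one NodeIndexUpdater per split node
    // in nodeIdUpdaters(treeIndex)(nodeIndex), then push the updates into the cache:
    nodeIdCache.updateNodeIndices(baggedInput, nodeIdUpdaters, splits)

    // When training finishes, remove any checkpoint files left on disk:
    nodeIdCache.deleteAllCheckpoints()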
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
    new file mode 100644
    index 0000000000000..15b56bd844bad
    --- /dev/null
    +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
    @@ -0,0 +1,1132 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.tree.impl
    +
    +import java.io.IOException
    +
    +import scala.collection.mutable
    +import scala.util.Random
    +
    +import org.apache.spark.Logging
    +import org.apache.spark.ml.classification.DecisionTreeClassificationModel
    +import org.apache.spark.ml.regression.DecisionTreeRegressionModel
    +import org.apache.spark.ml.tree._
    +import org.apache.spark.mllib.regression.LabeledPoint
    +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
    +import org.apache.spark.mllib.tree.impl.{BaggedPoint, DTStatsAggregator, DecisionTreeMetadata,
    +  TimeTracker}
    +import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
    +import org.apache.spark.mllib.tree.model.{InformationGainStats, Predict}
    +import org.apache.spark.rdd.RDD
    +import org.apache.spark.storage.StorageLevel
    +import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom}
    +
    +
    +private[ml] object RandomForest extends Logging {
    +
    +  /**
    +   * Train a random forest.
    +   * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
    +   * @return an unweighted set of trees
    +   */
    +  def run(
    +      input: RDD[LabeledPoint],
    +      strategy: OldStrategy,
    +      numTrees: Int,
    +      featureSubsetStrategy: String,
    +      seed: Long,
    +      parentUID: Option[String] = None): Array[DecisionTreeModel] = {
    +
    +    val timer = new TimeTracker()
    +
    +    timer.start("total")
    +
    +    timer.start("init")
    +
    +    val retaggedInput = input.retag(classOf[LabeledPoint])
    +    val metadata =
    +      DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy)
    +    logDebug("algo = " + strategy.algo)
    +    logDebug("numTrees = " + numTrees)
    +    logDebug("seed = " + seed)
    +    logDebug("maxBins = " + metadata.maxBins)
    +    logDebug("featureSubsetStrategy = " + featureSubsetStrategy)
    +    logDebug("numFeaturesPerNode = " + metadata.numFeaturesPerNode)
    +    logDebug("subsamplingRate = " + strategy.subsamplingRate)
    +
    +    // Find the splits and the corresponding bins (interval between the splits) using a sample
    +    // of the input data.
    +    timer.start("findSplitsBins")
    +    val splits = findSplits(retaggedInput, metadata)
    +    timer.stop("findSplitsBins")
    +    logDebug("numBins: feature: number of bins")
    +    logDebug(Range(0, metadata.numFeatures).map { featureIndex =>
    +      s"\t$featureIndex\t${metadata.numBins(featureIndex)}"
    +    }.mkString("\n"))
    +
    +    // Bin feature values (TreePoint representation).
    +    // Cache input RDD for speedup during multiple passes.
    +    val treeInput = TreePoint.convertToTreeRDD(retaggedInput, splits, metadata)
    +
    +    val withReplacement = numTrees > 1
    +
    +    val baggedInput = BaggedPoint
    +      .convertToBaggedRDD(treeInput, strategy.subsamplingRate, numTrees, withReplacement, seed)
    +      .persist(StorageLevel.MEMORY_AND_DISK)
    +
    +    // depth of the decision tree
    +    val maxDepth = strategy.maxDepth
    +    require(maxDepth <= 30,
    +      s"DecisionTree currently only supports maxDepth <= 30, but was given maxDepth = $maxDepth.")
    +
    +    // Max memory usage for aggregates
    +    // TODO: Calculate memory usage more precisely.
    +    val maxMemoryUsage: Long = strategy.maxMemoryInMB * 1024L * 1024L
    +    logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.")
    +    val maxMemoryPerNode = {
    +      val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
    +        // Find numFeaturesPerNode largest bins to get an upper bound on memory usage.
    +        Some(metadata.numBins.zipWithIndex.sortBy(- _._1)
    +          .take(metadata.numFeaturesPerNode).map(_._2))
    +      } else {
    +        None
    +      }
    +      RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L
    +    }
    +    require(maxMemoryPerNode <= maxMemoryUsage,
    +      s"RandomForest/DecisionTree given maxMemoryInMB = ${strategy.maxMemoryInMB}," +
    +        " which is too small for the given features." +
    +        s"  Minimum value = ${maxMemoryPerNode / (1024L * 1024L)}")
    +
    +    timer.stop("init")
    +
    +    /*
    +     * The main idea here is to perform group-wise training of the decision tree nodes thus
    +     * reducing the passes over the data from (# nodes) to (# nodes / maxNumberOfNodesPerGroup).
    +     * Each data sample is handled by a particular node (or it reaches a leaf and is not used
    +     * in lower levels).
    +     */
    +
    +    // Create an RDD of node Id cache.
    +    // At first, all the rows belong to the root nodes (node Id == 1).
    +    val nodeIdCache = if (strategy.useNodeIdCache) {
    +      Some(NodeIdCache.init(
    +        data = baggedInput,
    +        numTrees = numTrees,
    +        checkpointInterval = strategy.checkpointInterval,
    +        initVal = 1))
    +    } else {
    +      None
    +    }
    +
    +    // FIFO queue of nodes to train: (treeIndex, node)
    +    val nodeQueue = new mutable.Queue[(Int, LearningNode)]()
    +
    +    val rng = new Random()
    +    rng.setSeed(seed)
    +
    +    // Allocate and queue root nodes.
    +    val topNodes = Array.fill[LearningNode](numTrees)(LearningNode.emptyNode(nodeIndex = 1))
    +    Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex))))
    +
    +    while (nodeQueue.nonEmpty) {
    +      // Collect some nodes to split, and choose features for each node (if subsampling).
    +      // Each group of nodes may come from one or multiple trees, and at multiple levels.
    +      val (nodesForGroup, treeToNodeToIndexInfo) =
    +        RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng)
    +      // Sanity check (should never occur):
    +      assert(nodesForGroup.nonEmpty,
    +        s"RandomForest selected empty nodesForGroup.  Error for unknown reason.")
    +
    +      // Choose node splits, and enqueue new nodes as needed.
    +      timer.start("findBestSplits")
    +      RandomForest.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup,
    +        treeToNodeToIndexInfo, splits, nodeQueue, timer, nodeIdCache)
    +      timer.stop("findBestSplits")
    +    }
    +
    +    baggedInput.unpersist()
    +
    +    timer.stop("total")
    +
    +    logInfo("Internal timing for DecisionTree:")
    +    logInfo(s"$timer")
    +
    +    // Delete any remaining checkpoints used for node Id cache.
    +    if (nodeIdCache.nonEmpty) {
    +      try {
    +        nodeIdCache.get.deleteAllCheckpoints()
    +      } catch {
    +        case e: IOException =>
    +          logWarning(s"delete all checkpoints failed. Error reason: ${e.getMessage}")
    +      }
    +    }
    +
    +    parentUID match {
    +      case Some(uid) =>
    +        if (strategy.algo == OldAlgo.Classification) {
    +          topNodes.map(rootNode => new DecisionTreeClassificationModel(uid, rootNode.toNode))
    +        } else {
    +          topNodes.map(rootNode => new DecisionTreeRegressionModel(uid, rootNode.toNode))
    +        }
    +      case None =>
    +        if (strategy.algo == OldAlgo.Classification) {
    +          topNodes.map(rootNode => new DecisionTreeClassificationModel(rootNode.toNode))
    +        } else {
    +          topNodes.map(rootNode => new DecisionTreeRegressionModel(rootNode.toNode))
    +        }
    +    }
    +  }
    +
    +  /**
    +   * Get the node index corresponding to this data point.
    +   * This function mimics prediction, passing an example from the root node down to a leaf
    +   * or unsplit node; that node's index is returned.
    +   *
    +   * @param node  Node in tree from which to classify the given data point.
    +   * @param binnedFeatures  Binned feature vector for data point.
    +   * @param splits possible splits for all features, indexed (numFeatures)(numSplits)
    +   * @return  Leaf index if the data point reaches a leaf.
    +   *          Otherwise, last node reachable in tree matching this example.
    +   *          Note: This is the global node index, i.e., the index used in the tree.
    +   *                This index is different from the index used during training a particular
    +   *                group of nodes on one call to [[findBestSplits()]].
    +   */
    +  private def predictNodeIndex(
    +      node: LearningNode,
    +      binnedFeatures: Array[Int],
    +      splits: Array[Array[Split]]): Int = {
    +    if (node.isLeaf || node.split.isEmpty) {
    +      node.id
    +    } else {
    +      val split = node.split.get
    +      val featureIndex = split.featureIndex
    +      val splitLeft = split.shouldGoLeft(binnedFeatures(featureIndex), splits(featureIndex))
    +      if (node.leftChild.isEmpty) {
    +        // Not yet split. Return index from next layer of nodes to train
    +        if (splitLeft) {
    +          LearningNode.leftChildIndex(node.id)
    +        } else {
    +          LearningNode.rightChildIndex(node.id)
    +        }
    +      } else {
    +        if (splitLeft) {
    +          predictNodeIndex(node.leftChild.get, binnedFeatures, splits)
    +        } else {
    +          predictNodeIndex(node.rightChild.get, binnedFeatures, splits)
    +        }
    +      }
    +    }
    +  }
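For reference, predictNodeIndex relies on the 1-based heap numbering of node ids used throughout this code (the root node id is 1, as initialized in run()). A minimal sketch of that arithmetic, assuming the conventional left child = 2*id, right child = 2*id + 1 layout; the real LearningNode helpers are defined elsewhere and are not shown in this diff:

    object NodeIndexArithmeticSketch {
      // Assumed convention: root id = 1, children of node id are 2*id and 2*id + 1.
      def leftChildIndex(id: Int): Int = 2 * id
      def rightChildIndex(id: Int): Int = 2 * id + 1
      // Level of a node is floor(log2(id)): the root is level 0, its children level 1, and so on.
      def indexToLevel(id: Int): Int = 31 - Integer.numberOfLeadingZeros(id)

      def main(args: Array[String]): Unit = {
        assert(leftChildIndex(1) == 2 && rightChildIndex(1) == 3)
        assert(indexToLevel(1) == 0 && indexToLevel(5) == 2)
      }
    }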
    +
    +  /**
    +   * Helper for binSeqOp, for data which can contain a mix of ordered and unordered features.
    +   *
    +   * For ordered features, a single bin is updated.
    +   * For unordered features, bins correspond to subsets of categories; either the left or right bin
    +   * for each subset is updated.
    +   *
    +   * @param agg  Array storing aggregate calculation, with a set of sufficient statistics for
    +   *             each (feature, bin).
    +   * @param treePoint  Data point being aggregated.
    +   * @param splits possible splits indexed (numFeatures)(numSplits)
    +   * @param unorderedFeatures  Set of indices of unordered features.
    +   * @param instanceWeight  Weight (importance) of instance in dataset.
    +   */
    +  private def mixedBinSeqOp(
    +      agg: DTStatsAggregator,
    +      treePoint: TreePoint,
    +      splits: Array[Array[Split]],
    +      unorderedFeatures: Set[Int],
    +      instanceWeight: Double,
    +      featuresForNode: Option[Array[Int]]): Unit = {
    +    val numFeaturesPerNode = if (featuresForNode.nonEmpty) {
    +      // Use subsampled features
    +      featuresForNode.get.length
    +    } else {
    +      // Use all features
    +      agg.metadata.numFeatures
    +    }
    +    // Iterate over features.
    +    var featureIndexIdx = 0
    +    while (featureIndexIdx < numFeaturesPerNode) {
    +      val featureIndex = if (featuresForNode.nonEmpty) {
    +        featuresForNode.get.apply(featureIndexIdx)
    +      } else {
    +        featureIndexIdx
    +      }
    +      if (unorderedFeatures.contains(featureIndex)) {
    +        // Unordered feature
    +        val featureValue = treePoint.binnedFeatures(featureIndex)
    +        val (leftNodeFeatureOffset, rightNodeFeatureOffset) =
    +          agg.getLeftRightFeatureOffsets(featureIndexIdx)
    +        // Update the left or right bin for each split.
    +        val numSplits = agg.metadata.numSplits(featureIndex)
    +        val featureSplits = splits(featureIndex)
    +        var splitIndex = 0
    +        while (splitIndex < numSplits) {
    +          if (featureSplits(splitIndex).shouldGoLeft(featureValue, featureSplits)) {
    +            agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight)
    +          } else {
    +            agg.featureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight)
    +          }
    +          splitIndex += 1
    +        }
    +      } else {
    +        // Ordered feature
    +        val binIndex = treePoint.binnedFeatures(featureIndex)
    +        agg.update(featureIndexIdx, binIndex, treePoint.label, instanceWeight)
    +      }
    +      featureIndexIdx += 1
    +    }
    +  }
    +
    +  /**
    +   * Helper for binSeqOp, for regression and for classification with only ordered features.
    +   *
    +   * For each feature, the sufficient statistics of one bin are updated.
    +   *
    +   * @param agg  Array storing aggregate calculation, with a set of sufficient statistics for
    +   *             each (feature, bin).
    +   * @param treePoint  Data point being aggregated.
    +   * @param instanceWeight  Weight (importance) of instance in dataset.
    +   */
    +  private def orderedBinSeqOp(
    +      agg: DTStatsAggregator,
    +      treePoint: TreePoint,
    +      instanceWeight: Double,
    +      featuresForNode: Option[Array[Int]]): Unit = {
    +    val label = treePoint.label
    +
    +    // Iterate over features.
    +    if (featuresForNode.nonEmpty) {
    +      // Use subsampled features
    +      var featureIndexIdx = 0
    +      while (featureIndexIdx < featuresForNode.get.length) {
    +        val binIndex = treePoint.binnedFeatures(featuresForNode.get.apply(featureIndexIdx))
    +        agg.update(featureIndexIdx, binIndex, label, instanceWeight)
    +        featureIndexIdx += 1
    +      }
    +    } else {
    +      // Use all features
    +      val numFeatures = agg.metadata.numFeatures
    +      var featureIndex = 0
    +      while (featureIndex < numFeatures) {
    +        val binIndex = treePoint.binnedFeatures(featureIndex)
    +        agg.update(featureIndex, binIndex, label, instanceWeight)
    +        featureIndex += 1
    +      }
    +    }
    +  }
    +
    +  /**
    +   * Given a group of nodes, this finds the best split for each node.
    +   *
    +   * @param input Training data: RDD of `BaggedPoint[TreePoint]`
    +   * @param metadata Learning and dataset metadata
    +   * @param topNodes Root node for each tree.  Used for matching instances with nodes.
    +   * @param nodesForGroup Mapping: treeIndex --> nodes to be split in tree
    +   * @param treeToNodeToIndexInfo Mapping: treeIndex --> nodeIndex --> nodeIndexInfo,
    +   *                              where nodeIndexInfo stores the index in the group and the
    +   *                              feature subsets (if using feature subsets).
    +   * @param splits possible splits for all features, indexed (numFeatures)(numSplits)
    +   * @param nodeQueue  Queue of nodes to split, with values (treeIndex, node).
    +   *                   Updated with new non-leaf nodes which are created.
    +   * @param nodeIdCache Node Id cache containing an RDD of Array[Int] where
    +   *                    each value in the array is the data point's node Id
    +   *                    for a corresponding tree. This is used to prevent the need
    +   *                    to pass the entire tree to the executors during
    +   *                    the node stat aggregation phase.
    +   */
    +  private[tree] def findBestSplits(
    +      input: RDD[BaggedPoint[TreePoint]],
    +      metadata: DecisionTreeMetadata,
    +      topNodes: Array[LearningNode],
    +      nodesForGroup: Map[Int, Array[LearningNode]],
    +      treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]],
    +      splits: Array[Array[Split]],
    +      nodeQueue: mutable.Queue[(Int, LearningNode)],
    +      timer: TimeTracker = new TimeTracker,
    +      nodeIdCache: Option[NodeIdCache] = None): Unit = {
    +
    +    /*
    +     * The high-level description of the best-split optimizations is given here.
    +     *
    +     * *Group-wise training*
    +     * We perform bin calculations for groups of nodes to reduce the number of
    +     * passes over the data.  Each iteration requires more computation and storage,
    +     * but saves several iterations over the data.
    +     *
    +     * *Bin-wise computation*
    +     * We use a bin-wise best split computation strategy instead of a straightforward best split
    +     * computation strategy. Instead of analyzing each sample for contribution to the left/right
    +     * child node impurity of every split, we first categorize each feature of a sample into a
    +     * bin. We exploit this structure to calculate aggregates for bins and then use these aggregates
    +     * to calculate information gain for each split.
    +     *
    +     * *Aggregation over partitions*
    +     * Instead of performing a flatMap/reduceByKey operation, we exploit the fact that we know
    +     * the number of splits in advance. Thus, we store the aggregates (at the appropriate
    +     * indices) in a single array for all bins and rely upon the RDD aggregate method to
    +     * drastically reduce the communication overhead.
    +     */
    +
    +    // numNodes:  Number of nodes in this group
    +    val numNodes = nodesForGroup.values.map(_.length).sum
    +    logDebug("numNodes = " + numNodes)
    +    logDebug("numFeatures = " + metadata.numFeatures)
    +    logDebug("numClasses = " + metadata.numClasses)
    +    logDebug("isMulticlass = " + metadata.isMulticlass)
    +    logDebug("isMulticlassWithCategoricalFeatures = " +
    +      metadata.isMulticlassWithCategoricalFeatures)
    +    logDebug("using nodeIdCache = " + nodeIdCache.nonEmpty.toString)
    +
    +    /**
    +     * Performs a sequential aggregation over a partition for a particular tree and node.
    +     *
    +     * For each feature, the aggregate sufficient statistics are updated for the relevant
    +     * bins.
    +     *
    +     * @param treeIndex Index of the tree that we want to perform aggregation for.
    +     * @param nodeInfo The node info for the tree node.
    +     * @param agg Array storing aggregate calculation, with a set of sufficient statistics
    +     *            for each (node, feature, bin).
    +     * @param baggedPoint Data point being aggregated.
    +     */
    +    def nodeBinSeqOp(
    +        treeIndex: Int,
    +        nodeInfo: NodeIndexInfo,
    +        agg: Array[DTStatsAggregator],
    +        baggedPoint: BaggedPoint[TreePoint]): Unit = {
    +      if (nodeInfo != null) {
    +        val aggNodeIndex = nodeInfo.nodeIndexInGroup
    +        val featuresForNode = nodeInfo.featureSubset
    +        val instanceWeight = baggedPoint.subsampleWeights(treeIndex)
    +        if (metadata.unorderedFeatures.isEmpty) {
    +          orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode)
    +        } else {
    +          mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits,
    +            metadata.unorderedFeatures, instanceWeight, featuresForNode)
    +        }
    +      }
    +    }
    +
    +    /**
    +     * Performs a sequential aggregation over a partition.
    +     *
    +     * Each data point contributes to one node. For each feature,
    +     * the aggregate sufficient statistics are updated for the relevant bins.
    +     *
    +     * @param agg  Array storing aggregate calculation, with a set of sufficient statistics for
    +     *             each (node, feature, bin).
    +     * @param baggedPoint   Data point being aggregated.
    +     * @return  agg
    +     */
    +    def binSeqOp(
    +        agg: Array[DTStatsAggregator],
    +        baggedPoint: BaggedPoint[TreePoint]): Array[DTStatsAggregator] = {
    +      treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
    +        val nodeIndex =
    +          predictNodeIndex(topNodes(treeIndex), baggedPoint.datum.binnedFeatures, splits)
    +        nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint)
    +      }
    +      agg
    +    }
    +
    +    /**
    +     * Do the same thing as binSeqOp, but with nodeIdCache.
    +     */
    +    def binSeqOpWithNodeIdCache(
    +        agg: Array[DTStatsAggregator],
    +        dataPoint: (BaggedPoint[TreePoint], Array[Int])): Array[DTStatsAggregator] = {
    +      treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
    +        val baggedPoint = dataPoint._1
    +        val nodeIdCache = dataPoint._2
    +        val nodeIndex = nodeIdCache(treeIndex)
    +        nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint)
    +      }
    +
    +      agg
    +    }
    +
    +    /**
    +     * Get a map from node index in group to feature indices,
    +     * i.e., a shortcut for looking up the feature indices of a node given its index in the group.
    +     */
    +    def getNodeToFeatures(
    +        treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]]): Option[Map[Int, Array[Int]]] = {
    +      if (!metadata.subsamplingFeatures) {
    +        None
    +      } else {
    +        val mutableNodeToFeatures = new mutable.HashMap[Int, Array[Int]]()
    +        treeToNodeToIndexInfo.values.foreach { nodeIdToNodeInfo =>
    +          nodeIdToNodeInfo.values.foreach { nodeIndexInfo =>
    +            assert(nodeIndexInfo.featureSubset.isDefined)
    +            mutableNodeToFeatures(nodeIndexInfo.nodeIndexInGroup) = nodeIndexInfo.featureSubset.get
    +          }
    +        }
    +        Some(mutableNodeToFeatures.toMap)
    +      }
    +    }
    +
    +    // array of nodes to train indexed by node index in group
    +    val nodes = new Array[LearningNode](numNodes)
    +    nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
    +      nodesForTree.foreach { node =>
    +        nodes(treeToNodeToIndexInfo(treeIndex)(node.id).nodeIndexInGroup) = node
    +      }
    +    }
    +
    +    // Calculate best splits for all nodes in the group
    +    timer.start("chooseSplits")
    +
    +    // In each partition, iterate over all instances and compute the aggregate stats for each node,
    +    // yielding a (nodeIndex, nodeAggregateStats) pair for each node.
    +    // After a `reduceByKey` operation,
    +    // the stats for a node are shuffled to a particular partition and combined together,
    +    // and the best splits for the nodes are found there.
    +    // Finally, only the best splits are collected to the driver to construct the decision trees.
    +    val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo)
    +    val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures)
    +
    +    val partitionAggregates : RDD[(Int, DTStatsAggregator)] = if (nodeIdCache.nonEmpty) {
    +      input.zip(nodeIdCache.get.nodeIdsForInstances).mapPartitions { points =>
    +        // Construct a nodeStatsAggregators array to hold node aggregate stats;
    +        // each node has its own nodeStatsAggregator.
    +        val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
    +          val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
    +            Some(nodeToFeatures(nodeIndex))
    +          }
    +          new DTStatsAggregator(metadata, featuresForNode)
    +        }
    +
    +        // iterate over all instances in the current partition and update the aggregate stats
    +        points.foreach(binSeqOpWithNodeIdCache(nodeStatsAggregators, _))
    +
    +        // transform the nodeStatsAggregators array into (nodeIndex, nodeAggregateStats) pairs,
    +        // which can be combined with those from other partitions using `reduceByKey`
    +        nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
    +      }
    +    } else {
    +      input.mapPartitions { points =>
    +        // Construct a nodeStatsAggregators array to hold node aggregate stats;
    +        // each node has its own nodeStatsAggregator.
    +        val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
    +          val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
    +            Some(nodeToFeatures(nodeIndex))
    +          }
    +          new DTStatsAggregator(metadata, featuresForNode)
    +        }
    +
    +        // iterate over all instances in the current partition and update the aggregate stats
    +        points.foreach(binSeqOp(nodeStatsAggregators, _))
    +
    +        // transform the nodeStatsAggregators array into (nodeIndex, nodeAggregateStats) pairs,
    +        // which can be combined with those from other partitions using `reduceByKey`
    +        nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
    +      }
    +    }
    +
    +    val nodeToBestSplits = partitionAggregates.reduceByKey((a, b) => a.merge(b)).map {
    +      case (nodeIndex, aggStats) =>
    +        val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
    +          Some(nodeToFeatures(nodeIndex))
    +        }
    +
    +        // find best split for each node
    +        val (split: Split, stats: InformationGainStats, predict: Predict) =
    +          binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex))
    +        (nodeIndex, (split, stats, predict))
    +    }.collectAsMap()
    +
    +    timer.stop("chooseSplits")
    +
    +    val nodeIdUpdaters = if (nodeIdCache.nonEmpty) {
    +      Array.fill[mutable.Map[Int, NodeIndexUpdater]](
    +        metadata.numTrees)(mutable.Map[Int, NodeIndexUpdater]())
    +    } else {
    +      null
    +    }
    +    // Iterate over all nodes in this group.
    +    nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
    +      nodesForTree.foreach { node =>
    +        val nodeIndex = node.id
    +        val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex)
    +        val aggNodeIndex = nodeInfo.nodeIndexInGroup
    +        val (split: Split, stats: InformationGainStats, predict: Predict) =
    +          nodeToBestSplits(aggNodeIndex)
    +        logDebug("best split = " + split)
    +
    +        // Extract info for this node.  Create children if not leaf.
    +        val isLeaf =
    +          (stats.gain <= 0) || (LearningNode.indexToLevel(nodeIndex) == metadata.maxDepth)
    +        node.predictionStats = predict
    +        node.isLeaf = isLeaf
    +        node.stats = Some(stats)
    +        node.impurity = stats.impurity
    +        logDebug("Node = " + node)
    +
    +        if (!isLeaf) {
    +          node.split = Some(split)
    +          val childIsLeaf = (LearningNode.indexToLevel(nodeIndex) + 1) == metadata.maxDepth
    +          val leftChildIsLeaf = childIsLeaf || (stats.leftImpurity == 0.0)
    +          val rightChildIsLeaf = childIsLeaf || (stats.rightImpurity == 0.0)
    +          node.leftChild = Some(LearningNode(LearningNode.leftChildIndex(nodeIndex),
    +            stats.leftPredict, stats.leftImpurity, leftChildIsLeaf))
    +          node.rightChild = Some(LearningNode(LearningNode.rightChildIndex(nodeIndex),
    +            stats.rightPredict, stats.rightImpurity, rightChildIsLeaf))
    +
    +          if (nodeIdCache.nonEmpty) {
    +            val nodeIndexUpdater = NodeIndexUpdater(
    +              split = split,
    +              nodeIndex = nodeIndex)
    +            nodeIdUpdaters(treeIndex).put(nodeIndex, nodeIndexUpdater)
    +          }
    +
    +          // enqueue left child and right child if they are not leaves
    +          if (!leftChildIsLeaf) {
    +            nodeQueue.enqueue((treeIndex, node.leftChild.get))
    +          }
    +          if (!rightChildIsLeaf) {
    +            nodeQueue.enqueue((treeIndex, node.rightChild.get))
    +          }
    +
    +          logDebug("leftChildIndex = " + node.leftChild.get.id +
    +            ", impurity = " + stats.leftImpurity)
    +          logDebug("rightChildIndex = " + node.rightChild.get.id +
    +            ", impurity = " + stats.rightImpurity)
    +        }
    +      }
    +    }
    +
    +    if (nodeIdCache.nonEmpty) {
    +      // Update the cache if needed.
    +      nodeIdCache.get.updateNodeIndices(input, nodeIdUpdaters, splits)
    +    }
    +  }
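The per-partition aggregation pattern used above (build one aggregator per node inside mapPartitions, emit (nodeIndex, aggregator) pairs, then merge them with reduceByKey) can be illustrated in isolation. A toy sketch, assuming a live SparkContext named `sc` and using plain Long counts as stand-ins for DTStatsAggregator:

    // Toy version of the aggregate-then-reduceByKey pattern from findBestSplits.
    val numNodes = 3
    val data = sc.parallelize(1 to 1000, numSlices = 4)
    val perPartition = data.mapPartitions { points =>
      // One "aggregator" (here just a count) per node, filled in a single pass over the partition.
      val aggs = Array.fill(numNodes)(0L)
      points.foreach { p => aggs(p % numNodes) += 1 }
      aggs.zipWithIndex.map(_.swap).iterator   // (nodeIndex, partial aggregate)
    }
    // Partial aggregates for the same node are merged on one reducer,
    // so only numNodes small records ever reach the driver.
    val merged = perPartition.reduceByKey(_ + _).collectAsMap()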
    +
    +  /**
    +   * Calculate the information gain for a given (feature, split) based upon left/right aggregates.
    +   * @param leftImpurityCalculator left node aggregates for this (feature, split)
    +   * @param rightImpurityCalculator right node aggregate for this (feature, split)
    +   * @return information gain and statistics for split
    +   */
    +  private def calculateGainForSplit(
    +      leftImpurityCalculator: ImpurityCalculator,
    +      rightImpurityCalculator: ImpurityCalculator,
    +      metadata: DecisionTreeMetadata,
    +      impurity: Double): InformationGainStats = {
    +    val leftCount = leftImpurityCalculator.count
    +    val rightCount = rightImpurityCalculator.count
    +
    +    // If the left or right child doesn't satisfy the minimum instances per node,
    +    // this split is invalid; return invalid information gain stats.
    +    if ((leftCount < metadata.minInstancesPerNode) ||
    +      (rightCount < metadata.minInstancesPerNode)) {
    +      return InformationGainStats.invalidInformationGainStats
    +    }
    +
    +    val totalCount = leftCount + rightCount
    +
    +    val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0
    +    val rightImpurity = rightImpurityCalculator.calculate()
    +
    +    val leftWeight = leftCount / totalCount.toDouble
    +    val rightWeight = rightCount / totalCount.toDouble
    +
    +    val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
    +
    +    // If the information gain doesn't satisfy the minimum information gain,
    +    // this split is invalid; return invalid information gain stats.
    +    if (gain < metadata.minInfoGain) {
    +      return InformationGainStats.invalidInformationGainStats
    +    }
    +
    +    // calculate left and right predict
    +    val leftPredict = calculatePredict(leftImpurityCalculator)
    +    val rightPredict = calculatePredict(rightImpurityCalculator)
    +
    +    new InformationGainStats(gain, impurity, leftImpurity, rightImpurity,
    +      leftPredict, rightPredict)
    +  }
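As a worked example of the gain formula above (parent impurity minus the count-weighted child impurities), using made-up Gini numbers rather than anything computed by this patch:

    // Hypothetical split of 100 samples (parent Gini impurity = 0.5):
    //   left  child: 40 samples, Gini = 0.375
    //   right child: 60 samples, Gini ~= 0.4444
    val impurity = 0.5
    val (leftCount, rightCount) = (40.0, 60.0)
    val (leftImpurity, rightImpurity) = (0.375, 4.0 / 9.0)
    val total = leftCount + rightCount
    val gain = impurity - (leftCount / total) * leftImpurity - (rightCount / total) * rightImpurity
    // gain ~= 0.5 - 0.4 * 0.375 - 0.6 * 0.4444 ~= 0.0833; the split is kept only if gain >= minInfoGain.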
    +
    +  private def calculatePredict(impurityCalculator: ImpurityCalculator): Predict = {
    +    val predict = impurityCalculator.predict
    +    val prob = impurityCalculator.prob(predict)
    +    new Predict(predict, prob)
    +  }
    +
    +  /**
    +   * Calculate predict value for current node, given stats of any split.
    +   * Note that this function is called only once for each node.
    +   * @param leftImpurityCalculator left node aggregates for a split
    +   * @param rightImpurityCalculator right node aggregates for a split
    +   * @return predict value and impurity for current node
    +   */
    +  private def calculatePredictImpurity(
    +      leftImpurityCalculator: ImpurityCalculator,
    +      rightImpurityCalculator: ImpurityCalculator): (Predict, Double) = {
    +    val parentNodeAgg = leftImpurityCalculator.copy
    +    parentNodeAgg.add(rightImpurityCalculator)
    +    val predict = calculatePredict(parentNodeAgg)
    +    val impurity = parentNodeAgg.calculate()
    +
    +    (predict, impurity)
    +  }
    +
    +  /**
    +   * Find the best split for a node.
    +   * @param binAggregates Bin statistics.
    +   * @return tuple for best split: (Split, information gain, prediction at node)
    +   */
    +  private def binsToBestSplit(
    +      binAggregates: DTStatsAggregator,
    +      splits: Array[Array[Split]],
    +      featuresForNode: Option[Array[Int]],
    +      node: LearningNode): (Split, InformationGainStats, Predict) = {
    +
    +    // For the top node (level 0), the prediction and impurity are computed lazily below;
    +    // otherwise, reuse the stats already stored on the node.
    +    val level = LearningNode.indexToLevel(node.id)
    +    var predictionAndImpurity: Option[(Predict, Double)] = if (level == 0) {
    +      None
    +    } else {
    +      Some((node.predictionStats, node.impurity))
    +    }
    +
    +    // For each (feature, split), calculate the gain, and select the best (feature, split).
    +    val (bestSplit, bestSplitStats) =
    +      Range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx =>
    +        val featureIndex = if (featuresForNode.nonEmpty) {
    +          featuresForNode.get.apply(featureIndexIdx)
    +        } else {
    +          featureIndexIdx
    +        }
    +        val numSplits = binAggregates.metadata.numSplits(featureIndex)
    +        if (binAggregates.metadata.isContinuous(featureIndex)) {
    +          // Cumulative sum (scanLeft) of bin statistics.
    +          // Afterwards, binAggregates for a bin is the sum of aggregates for
    +          // that bin + all preceding bins.
    +          val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
    +          var splitIndex = 0
    +          while (splitIndex < numSplits) {
    +            binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
    +            splitIndex += 1
    +          }
    +          // Find best split.
    +          val (bestFeatureSplitIndex, bestFeatureGainStats) =
    +            Range(0, numSplits).map { case splitIdx =>
    +              val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
    +              val rightChildStats =
    +                binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
    +              rightChildStats.subtract(leftChildStats)
    +              predictionAndImpurity = Some(predictionAndImpurity.getOrElse(
    +                calculatePredictImpurity(leftChildStats, rightChildStats)))
    +              val gainStats = calculateGainForSplit(leftChildStats,
    +                rightChildStats, binAggregates.metadata, predictionAndImpurity.get._2)
    +              (splitIdx, gainStats)
    +            }.maxBy(_._2.gain)
    +          (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
    +        } else if (binAggregates.metadata.isUnordered(featureIndex)) {
    +          // Unordered categorical feature
    +          val (leftChildOffset, rightChildOffset) =
    +            binAggregates.getLeftRightFeatureOffsets(featureIndexIdx)
    +          val (bestFeatureSplitIndex, bestFeatureGainStats) =
    +            Range(0, numSplits).map { splitIndex =>
    +              val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
    +              val rightChildStats =
    +                binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
    +              predictionAndImpurity = Some(predictionAndImpurity.getOrElse(
    +                calculatePredictImpurity(leftChildStats, rightChildStats)))
    +              val gainStats = calculateGainForSplit(leftChildStats,
    +                rightChildStats, binAggregates.metadata, predictionAndImpurity.get._2)
    +              (splitIndex, gainStats)
    +            }.maxBy(_._2.gain)
    +          (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
    +        } else {
    +          // Ordered categorical feature
    +          val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
    +          val numCategories = binAggregates.metadata.numBins(featureIndex)
    +
    +          /* Each bin is one category (feature value).
    +           * The bins are ordered based on centroidForCategories, and this ordering determines which
    +           * splits are considered.  (With K categories, we consider K - 1 possible splits.)
    +           *
    +           * centroidForCategories is a list: (category, centroid)
    +           */
    +          val centroidForCategories = if (binAggregates.metadata.isMulticlass) {
    +            // For categorical variables in multiclass classification,
    +            // the bins are ordered by the impurity of their corresponding labels.
    +            Range(0, numCategories).map { case featureValue =>
    +              val categoryStats =
    +                binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
    +              val centroid = if (categoryStats.count != 0) {
    +                categoryStats.calculate()
    +              } else {
    +                Double.MaxValue
    +              }
    +              (featureValue, centroid)
    +            }
    +          } else { // regression or binary classification
    +            // For categorical variables in regression and binary classification,
    +            // the bins are ordered by the centroid of their corresponding labels.
    +            Range(0, numCategories).map { case featureValue =>
    +              val categoryStats =
    +                binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
    +              val centroid = if (categoryStats.count != 0) {
    +                categoryStats.predict
    +              } else {
    +                Double.MaxValue
    +              }
    +              (featureValue, centroid)
    +            }
    +          }
    +
    +          logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(","))
    +
    +          // bins sorted by centroids
    +          val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2)
    +
    +          logDebug("Sorted centroids for categorical variable = " +
    +            categoriesSortedByCentroid.mkString(","))
    +
    +          // Cumulative sum (scanLeft) of bin statistics.
    +          // Afterwards, binAggregates for a bin is the sum of aggregates for
    +          // that bin + all preceding bins.
    +          var splitIndex = 0
    +          while (splitIndex < numSplits) {
    +            val currentCategory = categoriesSortedByCentroid(splitIndex)._1
    +            val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1
    +            binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory)
    +            splitIndex += 1
    +          }
    +          // lastCategory = index of bin with total aggregates for this (node, feature)
    +          val lastCategory = categoriesSortedByCentroid.last._1
    +          // Find best split.
    +          val (bestFeatureSplitIndex, bestFeatureGainStats) =
    +            Range(0, numSplits).map { splitIndex =>
    +              val featureValue = categoriesSortedByCentroid(splitIndex)._1
    +              val leftChildStats =
    +                binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
    +              val rightChildStats =
    +                binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
    +              rightChildStats.subtract(leftChildStats)
    +              predictionAndImpurity = Some(predictionAndImpurity.getOrElse(
    +                calculatePredictImpurity(leftChildStats, rightChildStats)))
    +              val gainStats = calculateGainForSplit(leftChildStats,
    +                rightChildStats, binAggregates.metadata, predictionAndImpurity.get._2)
    +              (splitIndex, gainStats)
    +            }.maxBy(_._2.gain)
    +          val categoriesForSplit =
    +            categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1)
    +          val bestFeatureSplit =
    +            new CategoricalSplit(featureIndex, categoriesForSplit.toArray, numCategories)
    +          (bestFeatureSplit, bestFeatureGainStats)
    +        }
    +      }.maxBy(_._2.gain)
    +
    +    (bestSplit, bestSplitStats, predictionAndImpurity.get._1)
    +  }
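To make the ordered-categorical branch concrete: with made-up numbers, a regression feature with three categories whose centroids (mean labels) are 5.0, 1.0 and 3.0 is reordered by centroid, and only the K - 1 = 2 prefix splits of that ordering are evaluated. A small sketch of just the ordering step:

    // Hypothetical (category, centroid) pairs for a 3-category regression feature.
    val centroidForCategories = Seq((0, 5.0), (1, 1.0), (2, 3.0))
    val categoriesSortedByCentroid = centroidForCategories.sortBy(_._2).map(_._1)
    // categoriesSortedByCentroid == Seq(1, 2, 0)
    // Candidate left-side category sets are the prefixes of this ordering,
    // i.e. {1} and {1, 2}: K - 1 = 2 splits instead of 2^(K-1) - 1.
    val candidateSplits = (1 until categoriesSortedByCentroid.length).map { i =>
      categoriesSortedByCentroid.take(i).map(_.toDouble)
    }
    // candidateSplits: Vector(List(1.0), List(1.0, 2.0))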
    +
    +  /**
    +   * Returns splits and bins for decision tree calculation.
    +   * Continuous and categorical features are handled differently.
    +   *
    +   * Continuous features:
    +   *   For each feature, there are numBins - 1 possible splits representing the possible binary
    +   *   decisions at each node in the tree.
    +   *   This finds locations (feature values) for splits using a subsample of the data.
    +   *
    +   * Categorical features:
    +   *   For each feature, there is 1 bin per split.
    +   *   Splits and bins are handled in 2 ways:
    +   *   (a) "unordered features"
    +   *       For multiclass classification with a low-arity feature
    +   *       (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits),
    +   *       the feature is split based on subsets of categories.
    +   *   (b) "ordered features"
    +   *       For regression and binary classification,
    +   *       and for multiclass classification with a high-arity feature,
    +   *       there is one bin per category.
    +   *
    +   * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
    +   * @param metadata Learning and dataset metadata
    +   * @return Splits: an Array of [[Split]] of size (numFeatures, numSplits).
    +   */
    +  protected[tree] def findSplits(
    +      input: RDD[LabeledPoint],
    +      metadata: DecisionTreeMetadata): Array[Array[Split]] = {
    +
    +    logDebug("isMulticlass = " + metadata.isMulticlass)
    +
    +    val numFeatures = metadata.numFeatures
    +
    +    // Sample the input only if there are continuous features.
    +    val hasContinuousFeatures = Range(0, numFeatures).exists(metadata.isContinuous)
    +    val sampledInput = if (hasContinuousFeatures) {
    +      // Calculate the number of samples for approximate quantile calculation.
    +      val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000)
    +      val fraction = if (requiredSamples < metadata.numExamples) {
    +        requiredSamples.toDouble / metadata.numExamples
    +      } else {
    +        1.0
    +      }
    +      logDebug("fraction of data used for calculating quantiles = " + fraction)
    +      input.sample(withReplacement = false, fraction, new XORShiftRandom(1).nextInt()).collect()
    +    } else {
    +      new Array[LabeledPoint](0)
    +    }
    +
    +    val splits = new Array[Array[Split]](numFeatures)
    +
    +    // Find all splits.
    +    // Iterate over all features.
    +    var featureIndex = 0
    +    while (featureIndex < numFeatures) {
    +      if (metadata.isContinuous(featureIndex)) {
    +        val featureSamples = sampledInput.map(_.features(featureIndex))
    +        val featureSplits = findSplitsForContinuousFeature(featureSamples, metadata, featureIndex)
    +
    +        val numSplits = featureSplits.length
    +        logDebug(s"featureIndex = $featureIndex, numSplits = $numSplits")
    +        splits(featureIndex) = new Array[Split](numSplits)
    +
    +        var splitIndex = 0
    +        while (splitIndex < numSplits) {
    +          val threshold = featureSplits(splitIndex)
    +          splits(featureIndex)(splitIndex) = new ContinuousSplit(featureIndex, threshold)
    +          splitIndex += 1
    +        }
    +      } else {
    +        // Categorical feature
    +        if (metadata.isUnordered(featureIndex)) {
    +          val numSplits = metadata.numSplits(featureIndex)
    +          val featureArity = metadata.featureArity(featureIndex)
    +          // TODO: Use an implicit representation mapping each category to a subset of indices.
    +          //       I.e., track indices such that we can calculate the set of bins for which
    +          //       feature value x splits to the left.
    +          // Unordered features
    +          // 2^(maxFeatureValue - 1) - 1 combinations
    +          splits(featureIndex) = new Array[Split](numSplits)
    +          var splitIndex = 0
    +          while (splitIndex < numSplits) {
    +            val categories: List[Double] =
    +              extractMultiClassCategories(splitIndex + 1, featureArity)
    +            splits(featureIndex)(splitIndex) =
    +              new CategoricalSplit(featureIndex, categories.toArray, featureArity)
    +            splitIndex += 1
    +          }
    +        } else {
    +          // Ordered features
    +          //   Bins correspond to feature values, so we do not need to compute splits or bins
    +          //   beforehand.  Splits are constructed as needed during training.
    +          splits(featureIndex) = new Array[Split](0)
    +        }
    +      }
    +      featureIndex += 1
    +    }
    +    splits
    +  }
    +
    +  /**
    +   * Extracts the list of eligible categories given an index. It extracts the
    +   * positions of the ones in the binary representation of the input. If the binary
    +   * representation of a number is 01101 (13), the output list should be (3.0, 2.0,
    +   * 0.0). maxFeatureValue gives the number of rightmost bits that will be tested for ones.
    +   */
    +  private[tree] def extractMultiClassCategories(
    +      input: Int,
    +      maxFeatureValue: Int): List[Double] = {
    +    var categories = List[Double]()
    +    var j = 0
    +    var bitShiftedInput = input
    +    while (j < maxFeatureValue) {
    +      if (bitShiftedInput % 2 != 0) {
    +        // updating the list of categories.
    +        categories = j.toDouble :: categories
    +      }
    +      // Right shift by one
    +      bitShiftedInput = bitShiftedInput >> 1
    +      j += 1
    +    }
    +    categories
    +  }
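The doc comment's own example can be restated as a sanity check. Since extractMultiClassCategories is private[tree], a call like this would have to live in test code under the same package; it is not part of the patch:

    // 13 = binary 01101 -> ones at positions 0, 2 and 3 (counting from the right).
    val categories = RandomForest.extractMultiClassCategories(13, maxFeatureValue = 5)
    assert(categories == List(3.0, 2.0, 0.0))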
    +
    +  /**
    +   * Find splits for a continuous feature
    +   * NOTE: Returned number of splits is set based on `featureSamples` and
    +   *       could be different from the specified `numSplits`.
    +   *       The `numSplits` attribute in the `DecisionTreeMetadata` class will be set accordingly.
    +   * @param featureSamples feature values of each sample
    +   * @param metadata decision tree metadata
    +   *                 NOTE: `metadata.numBins` will be changed accordingly
    +   *                       if there are not enough splits to be found
    +   * @param featureIndex feature index to find splits
    +   * @return array of splits
    +   */
    +  private[tree] def findSplitsForContinuousFeature(
    +      featureSamples: Array[Double],
    +      metadata: DecisionTreeMetadata,
    +      featureIndex: Int): Array[Double] = {
    +    require(metadata.isContinuous(featureIndex),
    +      "findSplitsForContinuousFeature can only be used to find splits for a continuous feature.")
    +
    +    val splits = {
    +      val numSplits = metadata.numSplits(featureIndex)
    +
    +      // get count for each distinct value
    +      val valueCountMap = featureSamples.foldLeft(Map.empty[Double, Int]) { (m, x) =>
    +        m + ((x, m.getOrElse(x, 0) + 1))
    +      }
    +      // sort distinct values
    +      val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray
    +
    +      // if there are not enough (or just enough) possible splits, return all possible splits
    +      val possibleSplits = valueCounts.length
    +      if (possibleSplits <= numSplits) {
    +        valueCounts.map(_._1)
    +      } else {
    +        // stride between splits
    +        val stride: Double = featureSamples.length.toDouble / (numSplits + 1)
    +        logDebug("stride = " + stride)
    +
    +        // iterate over `valueCounts` to find splits
    +        val splitsBuilder = mutable.ArrayBuilder.make[Double]
    +        var index = 1
    +        // currentCount: sum of counts of values that have been visited
    +        var currentCount = valueCounts(0)._2
    +        // targetCount: target value for `currentCount`.
    +        // When the cumulative count passes `targetCount`, the value whose cumulative count
    +        // was closest to `targetCount` becomes a split threshold.
    +        // After finding a split threshold, `targetCount` is increased by `stride`.
    +        var targetCount = stride
    +        while (index < valueCounts.length) {
    +          val previousCount = currentCount
    +          currentCount += valueCounts(index)._2
    +          val previousGap = math.abs(previousCount - targetCount)
    +          val currentGap = math.abs(currentCount - targetCount)
    +          // If the previous cumulative count was closer to `targetCount` than the current one
    +          // (i.e., adding the current value's count widened the gap),
    +          // the previous value is a split threshold.
    +          if (previousGap < currentGap) {
    +            splitsBuilder += valueCounts(index - 1)._1
    +            targetCount += stride
    +          }
    +          index += 1
    +        }
    +
    +        splitsBuilder.result()
    +      }
    +    }
    +
    +    // TODO: Do not fail; just ignore the useless feature.
    +    assert(splits.length > 0,
    +      s"DecisionTree could not handle feature $featureIndex since it had only 1 unique value." +
    +        "  Please remove this feature and then try again.")
    +    // set number of splits accordingly
    +    metadata.setNumSplits(featureIndex, splits.length)
    +
    +    splits
    +  }
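A worked trace of the equi-frequency logic above on a toy sample. This re-implements just the stride loop locally for illustration; it does not call the private method, which also needs a DecisionTreeMetadata instance:

    // Toy feature sample with repeated values; ask for 2 splits.
    val featureSamples = Array(1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 3.0)
    val numSplits = 2
    val valueCounts = featureSamples.groupBy(identity).toSeq
      .map { case (v, occurrences) => (v, occurrences.length) }
      .sortBy(_._1).toArray
    // valueCounts == Array((1.0, 2), (2.0, 3), (3.0, 2)); 3 distinct values > 2 requested splits.
    val stride = featureSamples.length.toDouble / (numSplits + 1)   // 7 / 3 ~= 2.33
    var (currentCount, targetCount) = (valueCounts(0)._2.toDouble, stride)
    val splits = scala.collection.mutable.ArrayBuffer.empty[Double]
    for (index <- 1 until valueCounts.length) {
      val previousCount = currentCount
      currentCount += valueCounts(index)._2
      if (math.abs(previousCount - targetCount) < math.abs(currentCount - targetCount)) {
        splits += valueCounts(index - 1)._1   // the previous value was closest to the target count
        targetCount += stride
      }
    }
    // splits == ArrayBuffer(1.0, 2.0): one threshold roughly every stride-th sample.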
    +
    +  private[tree] class NodeIndexInfo(
    +      val nodeIndexInGroup: Int,
    +      val featureSubset: Option[Array[Int]]) extends Serializable
    +
    +  /**
    +   * Pull nodes off of the queue, and collect a group of nodes to be split on this iteration.
    +   * This tracks the memory usage for aggregates and stops adding nodes when too much memory
    +   * will be needed; this allows an adaptive number of nodes since different nodes may require
    +   * different amounts of memory (if featureSubsetStrategy is not "all").
    +   *
    +   * @param nodeQueue  Queue of nodes to split.
    +   * @param maxMemoryUsage  Bound on size of aggregate statistics.
    +   * @return  (nodesForGroup, treeToNodeToIndexInfo).
    +   *          nodesForGroup holds the nodes to split: treeIndex --> nodes in tree.
    +   *
    +   *          treeToNodeToIndexInfo holds the group index and selected feature indices for each node:
    +   *            treeIndex --> (global) node index --> (node index in group, feature indices).
    +   *          The (global) node index is the index in the tree; the node index in group is the
    +   *           index in [0, numNodesInGroup) of the node in this group.
    +   *          The feature indices are None if not subsampling features.
    +   */
    +  private[tree] def selectNodesToSplit(
    +      nodeQueue: mutable.Queue[(Int, LearningNode)],
    +      maxMemoryUsage: Long,
    +      metadata: DecisionTreeMetadata,
    +      rng: Random): (Map[Int, Array[LearningNode]], Map[Int, Map[Int, NodeIndexInfo]]) = {
    +    // Collect some nodes to split:
    +    //  nodesForGroup(treeIndex) = nodes to split
    +    val mutableNodesForGroup = new mutable.HashMap[Int, mutable.ArrayBuffer[LearningNode]]()
    +    val mutableTreeToNodeToIndexInfo =
    +      new mutable.HashMap[Int, mutable.HashMap[Int, NodeIndexInfo]]()
    +    var memUsage: Long = 0L
    +    var numNodesInGroup = 0
    +    while (nodeQueue.nonEmpty && memUsage < maxMemoryUsage) {
    +      val (treeIndex, node) = nodeQueue.head
    +      // Choose subset of features for node (if subsampling).
    +      val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
    +        Some(SamplingUtils.reservoirSampleAndCount(Range(0,
    +          metadata.numFeatures).iterator, metadata.numFeaturesPerNode, rng.nextLong())._1)
    +      } else {
    +        None
    +      }
    +      // Check if enough memory remains to add this node to the group.
    +      val nodeMemUsage = RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L
    +      if (memUsage + nodeMemUsage <= maxMemoryUsage) {
    +        nodeQueue.dequeue()
    +        mutableNodesForGroup.getOrElseUpdate(treeIndex, new mutable.ArrayBuffer[LearningNode]()) +=
    +          node
    +        mutableTreeToNodeToIndexInfo
    +          .getOrElseUpdate(treeIndex, new mutable.HashMap[Int, NodeIndexInfo]())(node.id)
    +          = new NodeIndexInfo(numNodesInGroup, featureSubset)
    +      }
    +      numNodesInGroup += 1
    +      memUsage += nodeMemUsage
    +    }
    +    // Convert mutable maps to immutable ones.
    +    val nodesForGroup: Map[Int, Array[LearningNode]] =
    +      mutableNodesForGroup.mapValues(_.toArray).toMap
    +    val treeToNodeToIndexInfo = mutableTreeToNodeToIndexInfo.mapValues(_.toMap).toMap
    +    (nodesForGroup, treeToNodeToIndexInfo)
    +  }
    +
    +  /**
    +   * Get the number of values to be stored for this node in the bin aggregates.
    +   * @param featureSubset  Indices of features which may be split at this node.
    +   *                       If None, then use all features.
    +   */
    +  private def aggregateSizeForNode(
    +      metadata: DecisionTreeMetadata,
    +      featureSubset: Option[Array[Int]]): Long = {
    +    val totalBins = if (featureSubset.nonEmpty) {
    +      featureSubset.get.map(featureIndex => metadata.numBins(featureIndex).toLong).sum
    +    } else {
    +      metadata.numBins.map(_.toLong).sum
    +    }
    +    if (metadata.isClassification) {
    +      metadata.numClasses * totalBins
    +    } else {
    +      3 * totalBins
    +    }
    +  }
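A small numeric illustration of this sizing, with made-up metadata: 3 classes and per-feature bin counts of 4, 4 and 2 give 3 * 10 = 30 sufficient statistics per node, i.e. 240 bytes at 8 bytes per Double (the 8L multiplier applied in run() and selectNodesToSplit):

    // Hypothetical classification problem: numClasses = 3, numBins = [4, 4, 2].
    val numClasses = 3
    val numBins = Array(4L, 4L, 2L)
    val totalBins = numBins.sum                   // 10
    val statsPerNode = numClasses * totalBins     // 30 (would be 3 * totalBins for regression)
    val bytesPerNode = statsPerNode * 8L          // 240 bytes of Doubles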
    +
    +}
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala
    new file mode 100644
    index 0000000000000..9fa27e5e1f721
    --- /dev/null
    +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala
    @@ -0,0 +1,134 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.tree.impl
    +
    +import org.apache.spark.ml.tree.{ContinuousSplit, Split}
    +import org.apache.spark.mllib.regression.LabeledPoint
    +import org.apache.spark.mllib.tree.impl.DecisionTreeMetadata
    +import org.apache.spark.rdd.RDD
    +
    +
    +/**
    + * Internal representation of LabeledPoint for DecisionTree.
    + * This bins feature values based on a subsample of the data as follows:
    + *  (a) Continuous features are binned into ranges.
    + *  (b) Unordered categorical features are binned based on subsets of feature values.
    + *      "Unordered categorical features" are categorical features with low arity used in
    + *      multiclass classification.
    + *  (c) Ordered categorical features are binned based on feature values.
    + *      "Ordered categorical features" are categorical features with high arity,
    + *      or any categorical feature used in regression or binary classification.
    + *
    + * @param label  Label from LabeledPoint
    + * @param binnedFeatures  Binned feature values.
    + *                        Same length as LabeledPoint.features, but values are bin indices.
    + */
    +private[spark] class TreePoint(val label: Double, val binnedFeatures: Array[Int])
    +  extends Serializable {
    +}
    +
    +private[spark] object TreePoint {
    +
    +  /**
    +   * Convert an input dataset into its TreePoint representation,
    +   * binning feature values in preparation for DecisionTree training.
    +   * @param input     Input dataset.
    +   * @param splits    Splits for features, of size (numFeatures, numSplits).
    +   * @param metadata  Learning and dataset metadata
    +   * @return  TreePoint dataset representation
    +   */
    +  def convertToTreeRDD(
    +      input: RDD[LabeledPoint],
    +      splits: Array[Array[Split]],
    +      metadata: DecisionTreeMetadata): RDD[TreePoint] = {
    +    // Construct arrays for featureArity for efficiency in the inner loop.
    +    val featureArity: Array[Int] = new Array[Int](metadata.numFeatures)
    +    var featureIndex = 0
    +    while (featureIndex < metadata.numFeatures) {
    +      featureArity(featureIndex) = metadata.featureArity.getOrElse(featureIndex, 0)
    +      featureIndex += 1
    +    }
    +    val thresholds: Array[Array[Double]] = featureArity.zipWithIndex.map { case (arity, idx) =>
    +      if (arity == 0) {
    +        splits(idx).map(_.asInstanceOf[ContinuousSplit].threshold)
    +      } else {
    +        Array.empty[Double]
    +      }
    +    }
    +    input.map { x =>
    +      TreePoint.labeledPointToTreePoint(x, thresholds, featureArity)
    +    }
    +  }
    +
    +  /**
    +   * Convert one LabeledPoint into its TreePoint representation.
    +   * @param thresholds  For each feature, split thresholds for continuous features,
    +   *                    empty for categorical features.
    +   * @param featureArity  Array indexed by feature, with value 0 for continuous and numCategories
    +   *                      for categorical features.
    +   */
    +  private def labeledPointToTreePoint(
    +      labeledPoint: LabeledPoint,
    +      thresholds: Array[Array[Double]],
    +      featureArity: Array[Int]): TreePoint = {
    +    val numFeatures = labeledPoint.features.size
    +    val arr = new Array[Int](numFeatures)
    +    var featureIndex = 0
    +    while (featureIndex < numFeatures) {
    +      arr(featureIndex) =
    +        findBin(featureIndex, labeledPoint, featureArity(featureIndex), thresholds(featureIndex))
    +      featureIndex += 1
    +    }
    +    new TreePoint(labeledPoint.label, arr)
    +  }
    +
    +  /**
    +   * Find discretized value for one (labeledPoint, feature).
    +   *
    +   * NOTE: We cannot use Bucketizer since it handles split thresholds differently than the old
    +   *       (mllib) tree API.  We want to maintain the same behavior as the old tree API.
    +   *
    +   * @param featureArity  0 for continuous features; number of categories for categorical features.
    +   */
    +  private def findBin(
    +      featureIndex: Int,
    +      labeledPoint: LabeledPoint,
    +      featureArity: Int,
    +      thresholds: Array[Double]): Int = {
    +    val featureValue = labeledPoint.features(featureIndex)
    +
    +    if (featureArity == 0) {
    +      val idx = java.util.Arrays.binarySearch(thresholds, featureValue)
    +      if (idx >= 0) {
    +        idx
    +      } else {
    +        -idx - 1
    +      }
    +    } else {
    +      // Categorical feature bins are indexed by feature values.
    +      if (featureValue < 0 || featureValue >= featureArity) {
    +        throw new IllegalArgumentException(
    +          s"DecisionTree given invalid data:" +
    +            s" Feature $featureIndex is categorical with values in {0,...,${featureArity - 1}," +
    +            s" but a data point gives it value $featureValue.\n" +
    +            "  Bad data point: " + labeledPoint.toString)
    +      }
    +      featureValue.toInt
    +    }
    +  }
    +}
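The continuous branch of findBin maps a value to the index of the first threshold that is greater than or equal to it (a value exactly equal to a threshold lands in that threshold's bin). A quick illustration of the binarySearch arithmetic on hypothetical thresholds:

    // Hypothetical split thresholds for one continuous feature.
    val thresholds = Array(0.5, 1.5)
    def bin(value: Double): Int = {
      val idx = java.util.Arrays.binarySearch(thresholds, value)
      if (idx >= 0) idx else -idx - 1
    }
    assert(bin(0.2) == 0)   // below the first threshold
    assert(bin(1.0) == 1)   // between 0.5 and 1.5
    assert(bin(1.5) == 1)   // exactly on a threshold -> that threshold's bin
    assert(bin(2.0) == 2)   // above all thresholds -> last bin (index numSplits)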
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
    index e2444ab65b43b..f979319cc4b58 100644
    --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
    +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
    @@ -32,38 +32,7 @@ import org.apache.spark.sql.types.StructType
     /**
      * Params for [[CrossValidator]] and [[CrossValidatorModel]].
      */
    -private[ml] trait CrossValidatorParams extends Params {
    -
    -  /**
    -   * param for the estimator to be cross-validated
    -   * @group param
    -   */
    -  val estimator: Param[Estimator[_]] = new Param(this, "estimator", "estimator for selection")
    -
    -  /** @group getParam */
    -  def getEstimator: Estimator[_] = $(estimator)
    -
    -  /**
    -   * param for estimator param maps
    -   * @group param
    -   */
    -  val estimatorParamMaps: Param[Array[ParamMap]] =
    -    new Param(this, "estimatorParamMaps", "param maps for the estimator")
    -
    -  /** @group getParam */
    -  def getEstimatorParamMaps: Array[ParamMap] = $(estimatorParamMaps)
    -
    -  /**
    -   * param for the evaluator used to select hyper-parameters that maximize the cross-validated
    -   * metric
    -   * @group param
    -   */
    -  val evaluator: Param[Evaluator] = new Param(this, "evaluator",
    -    "evaluator used to select hyper-parameters that maximize the cross-validated metric")
    -
    -  /** @group getParam */
    -  def getEvaluator: Evaluator = $(evaluator)
    -
    +private[ml] trait CrossValidatorParams extends ValidatorParams {
       /**
        * Param for number of folds for cross validation.  Must be >= 2.
        * Default: 3
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
    new file mode 100644
    index 0000000000000..c0edc730b6fd6
    --- /dev/null
    +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
    @@ -0,0 +1,168 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.tuning
    +
    +import org.apache.spark.Logging
    +import org.apache.spark.annotation.Experimental
    +import org.apache.spark.ml.evaluation.Evaluator
    +import org.apache.spark.ml.{Estimator, Model}
    +import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators}
    +import org.apache.spark.ml.util.Identifiable
    +import org.apache.spark.sql.DataFrame
    +import org.apache.spark.sql.types.StructType
    +
    +/**
    + * Params for [[TrainValidationSplit]] and [[TrainValidationSplitModel]].
    + */
    +private[ml] trait TrainValidationSplitParams extends ValidatorParams {
    +  /**
    +   * Param for ratio between train and validation data. Must be between 0 and 1.
    +   * Default: 0.75
    +   * @group param
    +   */
    +  val trainRatio: DoubleParam = new DoubleParam(this, "trainRatio",
    +    "ratio between training set and validation set (>= 0 && <= 1)", ParamValidators.inRange(0, 1))
    +
    +  /** @group getParam */
    +  def getTrainRatio: Double = $(trainRatio)
    +
    +  setDefault(trainRatio -> 0.75)
    +}
    +
    +/**
    + * :: Experimental ::
    + * Validation for hyper-parameter tuning.
    + * Randomly splits the input dataset into train and validation sets,
    + * and uses the evaluation metric on the validation set to select the best model.
    + * Similar to [[CrossValidator]], but only splits the set once.
    + */
    +@Experimental
    +class TrainValidationSplit(override val uid: String) extends Estimator[TrainValidationSplitModel]
    +  with TrainValidationSplitParams with Logging {
    +
    +  def this() = this(Identifiable.randomUID("tvs"))
    +
    +  /** @group setParam */
    +  def setEstimator(value: Estimator[_]): this.type = set(estimator, value)
    +
    +  /** @group setParam */
    +  def setEstimatorParamMaps(value: Array[ParamMap]): this.type = set(estimatorParamMaps, value)
    +
    +  /** @group setParam */
    +  def setEvaluator(value: Evaluator): this.type = set(evaluator, value)
    +
    +  /** @group setParam */
    +  def setTrainRatio(value: Double): this.type = set(trainRatio, value)
    +
    +  override def fit(dataset: DataFrame): TrainValidationSplitModel = {
    +    val schema = dataset.schema
    +    transformSchema(schema, logging = true)
    +    val sqlCtx = dataset.sqlContext
    +    val est = $(estimator)
    +    val eval = $(evaluator)
    +    val epm = $(estimatorParamMaps)
    +    val numModels = epm.length
    +    val metrics = new Array[Double](epm.length)
    +
    +    val Array(training, validation) =
    +      dataset.rdd.randomSplit(Array($(trainRatio), 1 - $(trainRatio)))
    +    val trainingDataset = sqlCtx.createDataFrame(training, schema).cache()
    +    val validationDataset = sqlCtx.createDataFrame(validation, schema).cache()
    +
    +    // multi-model training
    +    logDebug(s"Train split with multiple sets of parameters.")
    +    val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]]
    +    trainingDataset.unpersist()
    +    var i = 0
    +    while (i < numModels) {
    +      // TODO: duplicate evaluator to take extra params from input
    +      val metric = eval.evaluate(models(i).transform(validationDataset, epm(i)))
    +      logDebug(s"Got metric $metric for model trained with ${epm(i)}.")
    +      metrics(i) += metric
    +      i += 1
    +    }
    +    validationDataset.unpersist()
    +
    +    logInfo(s"Train validation split metrics: ${metrics.toSeq}")
    +    val (bestMetric, bestIndex) = metrics.zipWithIndex.maxBy(_._1)
    +    logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
    +    logInfo(s"Best train validation split metric: $bestMetric.")
    +    val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
    +    copyValues(new TrainValidationSplitModel(uid, bestModel, metrics).setParent(this))
    +  }
    +
    +  override def transformSchema(schema: StructType): StructType = {
    +    $(estimator).transformSchema(schema)
    +  }
    +
    +  override def validateParams(): Unit = {
    +    super.validateParams()
    +    val est = $(estimator)
    +    for (paramMap <- $(estimatorParamMaps)) {
    +      est.copy(paramMap).validateParams()
    +    }
    +  }
    +
    +  override def copy(extra: ParamMap): TrainValidationSplit = {
    +    val copied = defaultCopy(extra).asInstanceOf[TrainValidationSplit]
    +    if (copied.isDefined(estimator)) {
    +      copied.setEstimator(copied.getEstimator.copy(extra))
    +    }
    +    if (copied.isDefined(evaluator)) {
    +      copied.setEvaluator(copied.getEvaluator.copy(extra))
    +    }
    +    copied
    +  }
    +}
    +
    +/**
    + * :: Experimental ::
    + * Model from train validation split.
    + *
    + * @param uid Id.
    + * @param bestModel The best model selected by the train validation split.
    + * @param validationMetrics Evaluated validation metrics.
    + */
    +@Experimental
    +class TrainValidationSplitModel private[ml] (
    +    override val uid: String,
    +    val bestModel: Model[_],
    +    val validationMetrics: Array[Double])
    +  extends Model[TrainValidationSplitModel] with TrainValidationSplitParams {
    +
    +  override def validateParams(): Unit = {
    +    bestModel.validateParams()
    +  }
    +
    +  override def transform(dataset: DataFrame): DataFrame = {
    +    transformSchema(dataset.schema, logging = true)
    +    bestModel.transform(dataset)
    +  }
    +
    +  override def transformSchema(schema: StructType): StructType = {
    +    bestModel.transformSchema(schema)
    +  }
    +
    +  override def copy(extra: ParamMap): TrainValidationSplitModel = {
    +    val copied = new TrainValidationSplitModel(
    +      uid,
    +      bestModel.copy(extra).asInstanceOf[Model[_]],
    +      validationMetrics.clone())
    +    copyValues(copied, extra)
    +  }
    +}
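
    The fit method above trains one model per ParamMap on the train split and keeps the model with
    the best validation metric. A minimal usage sketch, assuming a DataFrame `training` with the
    default "label"/"features" columns, a binary label, and arbitrary grid values:

    import org.apache.spark.ml.classification.LogisticRegression
    import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
    import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}
    import org.apache.spark.sql.DataFrame

    // `training` is assumed to already contain "label" and "features" columns.
    def tuneWithTrainValidationSplit(training: DataFrame): Unit = {
      val lr = new LogisticRegression()
      val paramGrid = new ParamGridBuilder()
        .addGrid(lr.regParam, Array(0.01, 0.1))
        .addGrid(lr.maxIter, Array(10, 50))
        .build()

      val tvs = new TrainValidationSplit()
        .setEstimator(lr)
        .setEvaluator(new BinaryClassificationEvaluator()) // areaUnderROC by default
        .setEstimatorParamMaps(paramGrid)
        .setTrainRatio(0.8)                                // 80% train, 20% validation

      val model = tvs.fit(training)                        // TrainValidationSplitModel
      model.transform(training).show()                     // predictions from the best model
    }
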
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
    new file mode 100644
    index 0000000000000..8897ab0825acd
    --- /dev/null
    +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
    @@ -0,0 +1,60 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.tuning
    +
    +import org.apache.spark.annotation.DeveloperApi
    +import org.apache.spark.ml.Estimator
    +import org.apache.spark.ml.evaluation.Evaluator
    +import org.apache.spark.ml.param.{ParamMap, Param, Params}
    +
    +/**
    + * :: DeveloperApi ::
    + * Common params for [[TrainValidationSplitParams]] and [[CrossValidatorParams]].
    + */
    +@DeveloperApi
    +private[ml] trait ValidatorParams extends Params {
    +
    +  /**
    +   * param for the estimator to be validated
    +   * @group param
    +   */
    +  val estimator: Param[Estimator[_]] = new Param(this, "estimator", "estimator for selection")
    +
    +  /** @group getParam */
    +  def getEstimator: Estimator[_] = $(estimator)
    +
    +  /**
    +   * param for estimator param maps
    +   * @group param
    +   */
    +  val estimatorParamMaps: Param[Array[ParamMap]] =
    +    new Param(this, "estimatorParamMaps", "param maps for the estimator")
    +
    +  /** @group getParam */
    +  def getEstimatorParamMaps: Array[ParamMap] = $(estimatorParamMaps)
    +
    +  /**
    +   * param for the evaluator used to select hyper-parameters that maximize the validated metric
    +   * @group param
    +   */
    +  val evaluator: Param[Evaluator] = new Param(this, "evaluator",
    +    "evaluator used to select hyper-parameters that maximize the validated metric")
    +
    +  /** @group getParam */
    +  def getEvaluator: Evaluator = $(evaluator)
    +}
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala b/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala
    new file mode 100644
    index 0000000000000..8d4174124b5c4
    --- /dev/null
    +++ b/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala
    @@ -0,0 +1,153 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.util
    +
    +import scala.collection.mutable
    +
    +import org.apache.spark.{Accumulator, SparkContext}
    +
    +/**
    + * Abstract class for stopwatches.
    + */
    +private[spark] abstract class Stopwatch extends Serializable {
    +
    +  @transient private var running: Boolean = false
    +  private var startTime: Long = _
    +
    +  /**
    +   * Name of the stopwatch.
    +   */
    +  val name: String
    +
    +  /**
    +   * Starts the stopwatch.
    +   * Throws an exception if the stopwatch is already running.
    +   */
    +  def start(): Unit = {
    +    assume(!running, "start() called but the stopwatch is already running.")
    +    running = true
    +    startTime = now
    +  }
    +
    +  /**
    +   * Stops the stopwatch and returns the duration of the last session in milliseconds.
    +   * Throws an exception if the stopwatch is not running.
    +   */
    +  def stop(): Long = {
    +    assume(running, "stop() called but the stopwatch is not running.")
    +    val duration = now - startTime
    +    add(duration)
    +    running = false
    +    duration
    +  }
    +
    +  /**
    +   * Checks whether the stopwatch is running.
    +   */
    +  def isRunning: Boolean = running
    +
    +  /**
    +   * Returns total elapsed time in milliseconds, not counting the current session if the stopwatch
    +   * is running.
    +   */
    +  def elapsed(): Long
    +
    +  override def toString: String = s"$name: ${elapsed()}ms"
    +
    +  /**
    +   * Gets the current time in milliseconds.
    +   */
    +  protected def now: Long = System.currentTimeMillis()
    +
    +  /**
    +   * Adds input duration to total elapsed time.
    +   */
    +  protected def add(duration: Long): Unit
    +}
    +
    +/**
    + * A local [[Stopwatch]].
    + */
    +private[spark] class LocalStopwatch(override val name: String) extends Stopwatch {
    +
    +  private var elapsedTime: Long = 0L
    +
    +  override def elapsed(): Long = elapsedTime
    +
    +  override protected def add(duration: Long): Unit = {
    +    elapsedTime += duration
    +  }
    +}
    +
    +/**
    + * A distributed [[Stopwatch]] using Spark accumulator.
    + * @param sc SparkContext
    + */
    +private[spark] class DistributedStopwatch(
    +    sc: SparkContext,
    +    override val name: String) extends Stopwatch {
    +
    +  private val elapsedTime: Accumulator[Long] = sc.accumulator(0L, s"DistributedStopwatch($name)")
    +
    +  override def elapsed(): Long = elapsedTime.value
    +
    +  override protected def add(duration: Long): Unit = {
    +    elapsedTime += duration
    +  }
    +}
    +
    +/**
    + * A container for multiple stopwatches, both local and distributed.
    + * @param sc SparkContext
    + */
    +private[spark] class MultiStopwatch(@transient private val sc: SparkContext) extends Serializable {
    +
    +  private val stopwatches: mutable.Map[String, Stopwatch] = mutable.Map.empty
    +
    +  /**
    +   * Adds a local stopwatch.
    +   * @param name stopwatch name
    +   */
    +  def addLocal(name: String): this.type = {
    +    require(!stopwatches.contains(name), s"Stopwatch with name $name already exists.")
    +    stopwatches(name) = new LocalStopwatch(name)
    +    this
    +  }
    +
    +  /**
    +   * Adds a distributed stopwatch.
    +   * @param name stopwatch name
    +   */
    +  def addDistributed(name: String): this.type = {
    +    require(!stopwatches.contains(name), s"Stopwatch with name $name already exists.")
    +    stopwatches(name) = new DistributedStopwatch(sc, name)
    +    this
    +  }
    +
    +  /**
    +   * Gets a stopwatch.
    +   * @param name stopwatch name
    +   */
    +  def apply(name: String): Stopwatch = stopwatches(name)
    +
    +  override def toString: String = {
    +    stopwatches.values.toArray.sortBy(_.name)
    +      .map(c => s"  $c")
    +      .mkString("{\n", ",\n", "\n}")
    +  }
    +}
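
    Because these classes are private[spark], they are usable only from Spark's own packages; a
    minimal sketch of the intended call pattern, assuming code compiled under org.apache.spark
    (the package name here is illustrative) and an existing SparkContext:

    package org.apache.spark.ml.example // must live under org.apache.spark for private[spark] access

    import org.apache.spark.SparkContext
    import org.apache.spark.ml.util.MultiStopwatch

    object StopwatchSketch {
      def timedRun(sc: SparkContext): Unit = {
        val timer = new MultiStopwatch(sc)
          .addLocal("init")        // driver-side timing
          .addDistributed("train") // accumulator-backed, so tasks can add to it as well

        timer("init").start()
        Thread.sleep(10)           // stand-in for driver-side setup work
        val initMillis = timer("init").stop()

        timer("train").start()
        sc.parallelize(1 to 4, 2).count() // stand-in for a distributed training step
        timer("train").stop()

        println(s"init took $initMillis ms")
        println(timer)             // prints a "{ name: Xms, ... }" summary, sorted by name
      }
    }
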
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
    index e628059c4af8e..fda8d5a0b048f 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
    @@ -43,7 +43,7 @@ import org.apache.spark.mllib.recommendation._
     import org.apache.spark.mllib.regression._
     import org.apache.spark.mllib.stat.correlation.CorrelationNames
     import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
    -import org.apache.spark.mllib.stat.test.ChiSqTestResult
    +import org.apache.spark.mllib.stat.test.{ChiSqTestResult, KolmogorovSmirnovTestResult}
     import org.apache.spark.mllib.stat.{
       KernelDensity, MultivariateStatisticalSummary, Statistics}
     import org.apache.spark.mllib.tree.configuration.{Algo, BoostingStrategy, Strategy}
    @@ -502,6 +502,39 @@ private[python] class PythonMLLibAPI extends Serializable {
         new MatrixFactorizationModelWrapper(model)
       }
     
    +  /**
    +   * Java stub for Python mllib LDA.run()
    +   */
    +  def trainLDAModel(
    +      data: JavaRDD[java.util.List[Any]],
    +      k: Int,
    +      maxIterations: Int,
    +      docConcentration: Double,
    +      topicConcentration: Double,
    +      seed: java.lang.Long,
    +      checkpointInterval: Int,
    +      optimizer: String): LDAModel = {
    +    val algo = new LDA()
    +      .setK(k)
    +      .setMaxIterations(maxIterations)
    +      .setDocConcentration(docConcentration)
    +      .setTopicConcentration(topicConcentration)
    +      .setCheckpointInterval(checkpointInterval)
    +      .setOptimizer(optimizer)
    +
    +    if (seed != null) algo.setSeed(seed)
    +
    +    val documents = data.rdd.map(_.asScala.toArray).map { r =>
    +      r(0) match {
    +        case i: java.lang.Integer => (i.toLong, r(1).asInstanceOf[Vector])
    +        case i: java.lang.Long => (i.toLong, r(1).asInstanceOf[Vector])
    +        case _ => throw new IllegalArgumentException("input values contain an invalid type.")
    +      }
    +    }
    +    algo.run(documents)
    +  }
    +
    +
       /**
        * Java stub for Python mllib FPGrowth.train().  This stub returns a handle
        * to the Java object instead of the content of the Java object.  Extra care
    @@ -1060,6 +1093,18 @@ private[python] class PythonMLLibAPI extends Serializable {
         LinearDataGenerator.generateLinearRDD(
           sc, nexamples, nfeatures, eps, nparts, intercept)
       }
    +
    +  /**
    +   * Java stub for Statistics.kolmogorovSmirnovTest()
    +   */
    +  def kolmogorovSmirnovTest(
    +      data: JavaRDD[Double],
    +      distName: String,
    +      params: JList[Double]): KolmogorovSmirnovTestResult = {
    +    val paramsSeq = params.asScala.toSeq
    +    Statistics.kolmogorovSmirnovTest(data, distName, paramsSeq: _*)
    +  }
    +
     }
     
     /**
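
    The stub above simply forwards to Statistics.kolmogorovSmirnovTest; a minimal Scala-side
    sketch of the wrapped call, assuming a toy sample and a standard normal reference
    distribution:

    import org.apache.spark.SparkContext
    import org.apache.spark.mllib.stat.Statistics
    import org.apache.spark.rdd.RDD

    // `sc` is an assumed, already-constructed SparkContext.
    def ksExample(sc: SparkContext): Unit = {
      val sample: RDD[Double] = sc.parallelize(Seq(-1.2, -0.3, 0.1, 0.8, 1.5))
      // Test against N(mean = 0.0, stddev = 1.0); the varargs are the distribution parameters.
      val result = Statistics.kolmogorovSmirnovTest(sample, "norm", 0.0, 1.0)
      println(result.statistic)
      println(result.pValue)
    }
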
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala
    index 35a0db76f3a8c..ba73024e3c04d 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala
    @@ -36,6 +36,7 @@ trait ClassificationModel extends Serializable {
        *
        * @param testData RDD representing data points to be predicted
        * @return an RDD[Double] where each entry contains the corresponding prediction
    +   * @since 0.8.0
        */
       def predict(testData: RDD[Vector]): RDD[Double]
     
    @@ -44,6 +45,7 @@ trait ClassificationModel extends Serializable {
        *
        * @param testData array representing a single data point
        * @return predicted category from the trained model
    +   * @since 0.8.0
        */
       def predict(testData: Vector): Double
     
    @@ -51,6 +53,7 @@ trait ClassificationModel extends Serializable {
        * Predict values for examples stored in a JavaRDD.
        * @param testData JavaRDD representing data points to be predicted
        * @return a JavaRDD[java.lang.Double] where each entry contains the corresponding prediction
    +   * @since 0.8.0
        */
       def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] =
         predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]]
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
    index 2df4d21e8cd55..268642ac6a2f6 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
    @@ -85,6 +85,7 @@ class LogisticRegressionModel (
        * in Binary Logistic Regression. An example with prediction score greater than or equal to
        * this threshold is identified as a positive, and negative otherwise. The default value is 0.5.
        * It is only used for binary classification.
    +   * @since 1.0.0
        */
       @Experimental
       def setThreshold(threshold: Double): this.type = {
    @@ -96,6 +97,7 @@ class LogisticRegressionModel (
        * :: Experimental ::
        * Returns the threshold (if any) used for converting raw prediction scores into 0/1 predictions.
        * It is only used for binary classification.
    +   * @since 1.3.0
        */
       @Experimental
       def getThreshold: Option[Double] = threshold
    @@ -104,6 +106,7 @@ class LogisticRegressionModel (
        * :: Experimental ::
        * Clears the threshold so that `predict` will output raw prediction scores.
        * It is only used for binary classification.
    +   * @since 1.0.0
        */
       @Experimental
       def clearThreshold(): this.type = {
    @@ -155,6 +158,9 @@ class LogisticRegressionModel (
         }
       }
     
    +  /**
    +   * @since 1.3.0
    +   */
       override def save(sc: SparkContext, path: String): Unit = {
         GLMClassificationModel.SaveLoadV1_0.save(sc, path, this.getClass.getName,
           numFeatures, numClasses, weights, intercept, threshold)
    @@ -162,6 +168,9 @@ class LogisticRegressionModel (
     
       override protected def formatVersion: String = "1.0"
     
    +  /**
    +   * @since 1.4.0
    +   */
       override def toString: String = {
         s"${super.toString}, numClasses = ${numClasses}, threshold = ${threshold.getOrElse("None")}"
       }
    @@ -169,6 +178,9 @@ class LogisticRegressionModel (
     
     object LogisticRegressionModel extends Loader[LogisticRegressionModel] {
     
    +  /**
    +   * @since 1.3.0
    +   */
       override def load(sc: SparkContext, path: String): LogisticRegressionModel = {
         val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
         // Hard-code class name string in case it changes in the future
    @@ -249,6 +261,7 @@ object LogisticRegressionWithSGD {
        * @param miniBatchFraction Fraction of data to be used per iteration.
        * @param initialWeights Initial set of weights to be used. Array should be equal in size to
        *        the number of features in the data.
    +   * @since 1.0.0
        */
       def train(
           input: RDD[LabeledPoint],
    @@ -271,6 +284,7 @@ object LogisticRegressionWithSGD {
        * @param stepSize Step size to be used for each iteration of gradient descent.
     
        * @param miniBatchFraction Fraction of data to be used per iteration.
    +   * @since 1.0.0
        */
       def train(
           input: RDD[LabeledPoint],
    @@ -292,6 +306,7 @@ object LogisticRegressionWithSGD {
     
        * @param numIterations Number of iterations of gradient descent to run.
        * @return a LogisticRegressionModel which has the weights and offset from training.
    +   * @since 1.0.0
        */
       def train(
           input: RDD[LabeledPoint],
    @@ -309,6 +324,7 @@ object LogisticRegressionWithSGD {
        * @param input RDD of (label, array of features) pairs.
        * @param numIterations Number of iterations of gradient descent to run.
        * @return a LogisticRegressionModel which has the weights and offset from training.
    +   * @since 1.0.0
        */
       def train(
           input: RDD[LabeledPoint],
    @@ -345,6 +361,7 @@ class LogisticRegressionWithLBFGS
        * Set the number of possible outcomes for k classes classification problem in
        * Multinomial Logistic Regression.
        * By default, it is binary logistic regression so k will be set to 2.
    +   * @since 1.3.0
        */
       @Experimental
       def setNumClasses(numClasses: Int): this.type = {
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
    index f51ee36d0dfcb..2df91c09421e9 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
    @@ -40,7 +40,7 @@ import org.apache.spark.sql.{DataFrame, SQLContext}
      *              where D is number of features
      * @param modelType The type of NB model to fit, can be "multinomial" or "bernoulli"
      */
    -class NaiveBayesModel private[mllib] (
    +class NaiveBayesModel private[spark] (
         val labels: Array[Double],
         val pi: Array[Double],
         val theta: Array[Array[Double]],
    @@ -93,26 +93,70 @@ class NaiveBayesModel private[mllib] (
       override def predict(testData: Vector): Double = {
         modelType match {
           case Multinomial =>
    -        val prob = thetaMatrix.multiply(testData)
    -        BLAS.axpy(1.0, piVector, prob)
    -        labels(prob.argmax)
    +        labels(multinomialCalculation(testData).argmax)
           case Bernoulli =>
    -        testData.foreachActive { (index, value) =>
    -          if (value != 0.0 && value != 1.0) {
    -            throw new SparkException(
    -              s"Bernoulli naive Bayes requires 0 or 1 feature values but found $testData.")
    -          }
    -        }
    -        val prob = thetaMinusNegTheta.get.multiply(testData)
    -        BLAS.axpy(1.0, piVector, prob)
    -        BLAS.axpy(1.0, negThetaSum.get, prob)
    -        labels(prob.argmax)
    -      case _ =>
    -        // This should never happen.
    -        throw new UnknownError(s"Invalid modelType: $modelType.")
    +        labels(bernoulliCalculation(testData).argmax)
    +    }
    +  }
    +
    +  /**
    +   * Predict values for the given data set using the model trained.
    +   *
    +   * @param testData RDD representing data points to be predicted
    +   * @return an RDD[Vector] where each entry contains the predicted posterior class probabilities,
    +   *         in the same order as class labels
    +   */
    +  def predictProbabilities(testData: RDD[Vector]): RDD[Vector] = {
    +    val bcModel = testData.context.broadcast(this)
    +    testData.mapPartitions { iter =>
    +      val model = bcModel.value
    +      iter.map(model.predictProbabilities)
         }
       }
     
    +  /**
    +   * Predict posterior class probabilities for a single data point using the model trained.
    +   *
    +   * @param testData array representing a single data point
    +   * @return predicted posterior class probabilities from the trained model,
    +   *         in the same order as class labels
    +   */
    +  def predictProbabilities(testData: Vector): Vector = {
    +    modelType match {
    +      case Multinomial =>
    +        posteriorProbabilities(multinomialCalculation(testData))
    +      case Bernoulli =>
    +        posteriorProbabilities(bernoulliCalculation(testData))
    +    }
    +  }
    +
    +  private def multinomialCalculation(testData: Vector) = {
    +    val prob = thetaMatrix.multiply(testData)
    +    BLAS.axpy(1.0, piVector, prob)
    +    prob
    +  }
    +
    +  private def bernoulliCalculation(testData: Vector) = {
    +    testData.foreachActive((_, value) =>
    +      if (value != 0.0 && value != 1.0) {
    +        throw new SparkException(
    +          s"Bernoulli naive Bayes requires 0 or 1 feature values but found $testData.")
    +      }
    +    )
    +    val prob = thetaMinusNegTheta.get.multiply(testData)
    +    BLAS.axpy(1.0, piVector, prob)
    +    BLAS.axpy(1.0, negThetaSum.get, prob)
    +    prob
    +  }
    +
    +  private def posteriorProbabilities(logProb: DenseVector) = {
    +    val logProbArray = logProb.toArray
    +    val maxLog = logProbArray.max
    +    val scaledProbs = logProbArray.map(lp => math.exp(lp - maxLog))
    +    val probSum = scaledProbs.sum
    +    new DenseVector(scaledProbs.map(_ / probSum))
    +  }
    +
       override def save(sc: SparkContext, path: String): Unit = {
         val data = NaiveBayesModel.SaveLoadV2_0.Data(labels, pi, theta, modelType)
         NaiveBayesModel.SaveLoadV2_0.save(sc, path, data)
    @@ -338,7 +382,7 @@ class NaiveBayes private (
             BLAS.axpy(1.0, c2._2, c1._2)
             (c1._1 + c2._1, c1._2)
           }
    -    ).collect()
    +    ).collect().sortBy(_._1)
     
         val numLabels = aggregated.length
         var numDocuments = 0L
    @@ -381,13 +425,13 @@ class NaiveBayes private (
     object NaiveBayes {
     
       /** String name for multinomial model type. */
    -  private[classification] val Multinomial: String = "multinomial"
    +  private[spark] val Multinomial: String = "multinomial"
     
       /** String name for Bernoulli model type. */
    -  private[classification] val Bernoulli: String = "bernoulli"
    +  private[spark] val Bernoulli: String = "bernoulli"
     
       /* Set of modelTypes that NaiveBayes supports */
    -  private[classification] val supportedModelTypes = Set(Multinomial, Bernoulli)
    +  private[spark] val supportedModelTypes = Set(Multinomial, Bernoulli)
     
       /**
        * Trains a Naive Bayes model given an RDD of `(label, features)` pairs.
    @@ -400,6 +444,7 @@ object NaiveBayes {
        *
        * @param input RDD of `(label, array of features)` pairs.  Every vector should be a frequency
        *              vector or a count vector.
    +   * @since 0.9.0
        */
       def train(input: RDD[LabeledPoint]): NaiveBayesModel = {
         new NaiveBayes().run(input)
    @@ -415,6 +460,7 @@ object NaiveBayes {
        * @param input RDD of `(label, array of features)` pairs.  Every vector should be a frequency
        *              vector or a count vector.
        * @param lambda The smoothing parameter
    +   * @since 0.9.0
        */
       def train(input: RDD[LabeledPoint], lambda: Double): NaiveBayesModel = {
         new NaiveBayes(lambda, Multinomial).run(input)
    @@ -437,6 +483,7 @@ object NaiveBayes {
        *
        * @param modelType The type of NB model to fit from the enumeration NaiveBayesModels, can be
        *              multinomial or bernoulli
    +   * @since 0.9.0
        */
       def train(input: RDD[LabeledPoint], lambda: Double, modelType: String): NaiveBayesModel = {
         require(supportedModelTypes.contains(modelType),
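
    The new predictProbabilities path converts log class scores to probabilities in
    posteriorProbabilities by subtracting the maximum log value before exponentiating, so large
    negative log-probabilities do not underflow to zero. A minimal standalone sketch of that
    normalization (toy numbers, no Spark dependencies):

    // Same normalization as posteriorProbabilities: a softmax over log-scores,
    // shifted by the maximum for numerical stability.
    def logScoresToProbabilities(logProb: Array[Double]): Array[Double] = {
      val maxLog = logProb.max
      val scaled = logProb.map(lp => math.exp(lp - maxLog))
      val total = scaled.sum
      scaled.map(_ / total)
    }

    // Exponentiating (-1000.0, -1001.0) directly would underflow to (0.0, 0.0);
    // the shifted version yields approximately (0.731, 0.269).
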
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
    index 348485560713e..5b54feeb10467 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
    @@ -46,6 +46,7 @@ class SVMModel (
        * Sets the threshold that separates positive predictions from negative predictions. An example
        * with prediction score greater than or equal to this threshold is identified as a positive,
        * and negative otherwise. The default value is 0.0.
    +   * @since 1.3.0
        */
       @Experimental
       def setThreshold(threshold: Double): this.type = {
    @@ -56,6 +57,7 @@ class SVMModel (
       /**
        * :: Experimental ::
        * Returns the threshold (if any) used for converting raw prediction scores into 0/1 predictions.
    +   * @since 1.3.0
        */
       @Experimental
       def getThreshold: Option[Double] = threshold
    @@ -63,6 +65,7 @@ class SVMModel (
       /**
        * :: Experimental ::
        * Clears the threshold so that `predict` will output raw prediction scores.
    +   * @since 1.0.0
        */
       @Experimental
       def clearThreshold(): this.type = {
    @@ -81,6 +84,9 @@ class SVMModel (
         }
       }
     
    +  /**
    +   * @since 1.3.0
    +   */
       override def save(sc: SparkContext, path: String): Unit = {
         GLMClassificationModel.SaveLoadV1_0.save(sc, path, this.getClass.getName,
           numFeatures = weights.size, numClasses = 2, weights, intercept, threshold)
    @@ -88,6 +94,9 @@ class SVMModel (
     
       override protected def formatVersion: String = "1.0"
     
    +  /**
    +   * @since 1.4.0
    +   */
       override def toString: String = {
         s"${super.toString}, numClasses = 2, threshold = ${threshold.getOrElse("None")}"
       }
    @@ -95,6 +104,9 @@ class SVMModel (
     
     object SVMModel extends Loader[SVMModel] {
     
    +  /**
    +   * @since 1.3.0
    +   */
       override def load(sc: SparkContext, path: String): SVMModel = {
         val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
         // Hard-code class name string in case it changes in the future
    @@ -173,6 +185,7 @@ object SVMWithSGD {
        * @param miniBatchFraction Fraction of data to be used per iteration.
        * @param initialWeights Initial set of weights to be used. Array should be equal in size to
        *        the number of features in the data.
    +   * @since 0.8.0
        */
       def train(
           input: RDD[LabeledPoint],
    @@ -196,6 +209,7 @@ object SVMWithSGD {
        * @param stepSize Step size to be used for each iteration of gradient descent.
        * @param regParam Regularization parameter.
        * @param miniBatchFraction Fraction of data to be used per iteration.
    +   * @since 0.8.0
        */
       def train(
           input: RDD[LabeledPoint],
    @@ -217,6 +231,7 @@ object SVMWithSGD {
        * @param regParam Regularization parameter.
        * @param numIterations Number of iterations of gradient descent to run.
        * @return a SVMModel which has the weights and offset from training.
    +   * @since 0.8.0
        */
       def train(
           input: RDD[LabeledPoint],
    @@ -235,6 +250,7 @@ object SVMWithSGD {
        * @param input RDD of (label, array of features) pairs.
        * @param numIterations Number of iterations of gradient descent to run.
        * @return a SVMModel which has the weights and offset from training.
    +   * @since 0.8.0
        */
       def train(input: RDD[LabeledPoint], numIterations: Int): SVMModel = {
         train(input, numIterations, 1.0, 0.01, 1.0)
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
    index 0f8d6a399682d..0a65403f4ec95 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
    @@ -85,9 +85,7 @@ class KMeans private (
        * (Bahmani et al., Scalable K-Means++, VLDB 2012). Default: k-means||.
        */
       def setInitializationMode(initializationMode: String): this.type = {
    -    if (initializationMode != KMeans.RANDOM && initializationMode != KMeans.K_MEANS_PARALLEL) {
    -      throw new IllegalArgumentException("Invalid initialization mode: " + initializationMode)
    -    }
    +    KMeans.validateInitMode(initializationMode)
         this.initializationMode = initializationMode
         this
       }
    @@ -156,6 +154,21 @@ class KMeans private (
         this
       }
     
    +  // Initial cluster centers can be provided as a KMeansModel object rather than using the
    +  // random or k-means|| initializationMode
    +  private var initialModel: Option[KMeansModel] = None
    +
    +  /**
    +   * Set the initial starting point, bypassing the random initialization or k-means||.
    +   * The condition model.k == this.k must be met; failure to meet it results
    +   * in an IllegalArgumentException.
    +   */
    +  def setInitialModel(model: KMeansModel): this.type = {
    +    require(model.k == k, "mismatched cluster count")
    +    initialModel = Some(model)
    +    this
    +  }
    +
       /**
        * Train a K-means model on the given set of points; `data` should be cached for high
        * performance, because this is an iterative algorithm.
    @@ -193,20 +206,34 @@ class KMeans private (
     
         val initStartTime = System.nanoTime()
     
    -    val centers = if (initializationMode == KMeans.RANDOM) {
    -      initRandom(data)
    +    // Only one run is allowed when initialModel is given
    +    val numRuns = if (initialModel.nonEmpty) {
    +      if (runs > 1) logWarning("Ignoring runs; one run is allowed when initialModel is given.")
    +      1
         } else {
    -      initKMeansParallel(data)
    +      runs
         }
     
    +    val centers = initialModel match {
    +      case Some(kMeansCenters) => {
    +        Array(kMeansCenters.clusterCenters.map(s => new VectorWithNorm(s)))
    +      }
    +      case None => {
    +        if (initializationMode == KMeans.RANDOM) {
    +          initRandom(data)
    +        } else {
    +          initKMeansParallel(data)
    +        }
    +      }
    +    }
         val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9
         logInfo(s"Initialization with $initializationMode took " + "%.3f".format(initTimeInSeconds) +
           " seconds.")
     
    -    val active = Array.fill(runs)(true)
    -    val costs = Array.fill(runs)(0.0)
    +    val active = Array.fill(numRuns)(true)
    +    val costs = Array.fill(numRuns)(0.0)
     
    -    var activeRuns = new ArrayBuffer[Int] ++ (0 until runs)
    +    var activeRuns = new ArrayBuffer[Int] ++ (0 until numRuns)
         var iteration = 0
     
         val iterationStartTime = System.nanoTime()
    @@ -521,6 +548,14 @@ object KMeans {
           v2: VectorWithNorm): Double = {
         MLUtils.fastSquaredDistance(v1.vector, v1.norm, v2.vector, v2.norm)
       }
    +
    +  private[spark] def validateInitMode(initMode: String): Boolean = {
    +    initMode match {
    +      case KMeans.RANDOM => true
    +      case KMeans.K_MEANS_PARALLEL => true
    +      case _ => false
    +    }
    +  }
     }
     
     /**
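
    A minimal sketch of the new setInitialModel hook, assuming a toy dataset and hand-built
    centers; the key constraint, enforced by the require above, is that the seed model's k
    matches the k being trained:

    import org.apache.spark.SparkContext
    import org.apache.spark.mllib.clustering.{KMeans, KMeansModel}
    import org.apache.spark.mllib.linalg.Vectors

    // `sc` is an assumed, already-constructed SparkContext.
    def seededKMeans(sc: SparkContext): KMeansModel = {
      val data = sc.parallelize(Seq(
        Vectors.dense(0.0, 0.0), Vectors.dense(0.1, 0.1),
        Vectors.dense(9.0, 9.0), Vectors.dense(9.1, 9.1))).cache()

      // Hand-built starting centers; the length of this array must equal setK below,
      // otherwise setInitialModel's require(...) fails.
      val initial = new KMeansModel(Array(Vectors.dense(0.0, 0.0), Vectors.dense(9.0, 9.0)))

      new KMeans()
        .setK(2)
        .setMaxIterations(10)
        .setInitialModel(initial) // bypasses random / k-means|| initialization
        .run(data)
    }
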
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
    index a410547a72fda..ab124e6d77c5e 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
    @@ -23,11 +23,10 @@ import org.apache.spark.Logging
     import org.apache.spark.annotation.{DeveloperApi, Experimental}
     import org.apache.spark.api.java.JavaPairRDD
     import org.apache.spark.graphx._
    -import org.apache.spark.mllib.linalg.Vector
    +import org.apache.spark.mllib.linalg.{Vector, Vectors}
     import org.apache.spark.rdd.RDD
     import org.apache.spark.util.Utils
     
    -
     /**
      * :: Experimental ::
      *
    @@ -49,14 +48,15 @@ import org.apache.spark.util.Utils
     class LDA private (
         private var k: Int,
         private var maxIterations: Int,
    -    private var docConcentration: Double,
    +    private var docConcentration: Vector,
         private var topicConcentration: Double,
         private var seed: Long,
         private var checkpointInterval: Int,
         private var ldaOptimizer: LDAOptimizer) extends Logging {
     
    -  def this() = this(k = 10, maxIterations = 20, docConcentration = -1, topicConcentration = -1,
    -    seed = Utils.random.nextLong(), checkpointInterval = 10, ldaOptimizer = new EMLDAOptimizer)
    +  def this() = this(k = 10, maxIterations = 20, docConcentration = Vectors.dense(-1),
    +    topicConcentration = -1, seed = Utils.random.nextLong(), checkpointInterval = 10,
    +    ldaOptimizer = new EMLDAOptimizer)
     
       /**
        * Number of topics to infer.  I.e., the number of soft cluster centers.
    @@ -77,37 +77,50 @@ class LDA private (
        * Concentration parameter (commonly named "alpha") for the prior placed on documents'
        * distributions over topics ("theta").
        *
    -   * This is the parameter to a symmetric Dirichlet distribution.
    +   * This is the parameter to a Dirichlet distribution.
        */
    -  def getDocConcentration: Double = this.docConcentration
    +  def getDocConcentration: Vector = this.docConcentration
     
       /**
        * Concentration parameter (commonly named "alpha") for the prior placed on documents'
        * distributions over topics ("theta").
        *
    -   * This is the parameter to a symmetric Dirichlet distribution, where larger values
    -   * mean more smoothing (more regularization).
    +   * This is the parameter to a Dirichlet distribution, where larger values mean more smoothing
    +   * (more regularization).
        *
    -   * If set to -1, then docConcentration is set automatically.
    -   *  (default = -1 = automatic)
    +   * If set to a singleton vector Vector(-1), then docConcentration is set automatically. If set to
    +   * singleton vector Vector(t) where t != -1, then t is replicated to a vector of length k during
    +   * [[LDAOptimizer.initialize()]]. Otherwise, the [[docConcentration]] vector must be length k.
    +   * (default = Vector(-1) = automatic)
        *
        * Optimizer-specific parameter settings:
        *  - EM
    -   *     - Value should be > 1.0
    -   *     - default = (50 / k) + 1, where 50/k is common in LDA libraries and +1 follows
    -   *       Asuncion et al. (2009), who recommend a +1 adjustment for EM.
    +   *     - Currently only supports symmetric distributions, so all values in the vector should be
    +   *       the same.
    +   *     - Values should be > 1.0
    +   *     - default = uniformly (50 / k) + 1, where 50/k is common in LDA libraries and +1 follows
    +   *       from Asuncion et al. (2009), who recommend a +1 adjustment for EM.
        *  - Online
    -   *     - Value should be >= 0
    -   *     - default = (1.0 / k), following the implementation from
    +   *     - Values should be >= 0
    +   *     - default = uniformly (1.0 / k), following the implementation from
        *       [[https://github.com/Blei-Lab/onlineldavb]].
        */
    -  def setDocConcentration(docConcentration: Double): this.type = {
    +  def setDocConcentration(docConcentration: Vector): this.type = {
         this.docConcentration = docConcentration
         this
       }
     
    +  /** Replicates Double to create a symmetric prior */
    +  def setDocConcentration(docConcentration: Double): this.type = {
    +    this.docConcentration = Vectors.dense(docConcentration)
    +    this
    +  }
    +
       /** Alias for [[getDocConcentration]] */
    -  def getAlpha: Double = getDocConcentration
    +  def getAlpha: Vector = getDocConcentration
    +
    +  /** Alias for [[setDocConcentration()]] */
    +  def setAlpha(alpha: Vector): this.type = setDocConcentration(alpha)
     
       /** Alias for [[setDocConcentration()]] */
       def setAlpha(alpha: Double): this.type = setDocConcentration(alpha)
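
    A minimal configuration sketch for the widened docConcentration API, assuming a prepared
    corpus; a Double still yields a symmetric prior, while a length-k Vector (online optimizer
    only, per the docs above) gives an asymmetric one:

    import org.apache.spark.mllib.clustering.LDA
    import org.apache.spark.mllib.linalg.{Vector, Vectors}
    import org.apache.spark.rdd.RDD

    // `corpus` is an assumed RDD of (document id, term-count vector) pairs.
    def configureLda(corpus: RDD[(Long, Vector)]) = {
      val k = 3
      val lda = new LDA()
        .setK(k)
        .setOptimizer("online")                            // asymmetric alpha requires online LDA
        .setDocConcentration(Vectors.dense(0.5, 1.0, 2.0)) // per-topic (asymmetric) alpha, length k
        .setTopicConcentration(1.1)
      // .setDocConcentration(0.5) would instead be replicated into a symmetric length-k prior
      lda.run(corpus)
    }
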
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
    index 974b26924dfb8..920b57756b625 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
    @@ -17,15 +17,25 @@
     
     package org.apache.spark.mllib.clustering
     
    -import breeze.linalg.{DenseMatrix => BDM, normalize, sum => brzSum}
    +import breeze.linalg.{DenseMatrix => BDM, normalize, sum => brzSum, DenseVector => BDV}
     
    +import org.apache.hadoop.fs.Path
    +
    +import org.json4s.DefaultFormats
    +import org.json4s.JsonDSL._
    +import org.json4s.jackson.JsonMethods._
    +
    +import org.apache.spark.SparkContext
     import org.apache.spark.annotation.Experimental
     import org.apache.spark.api.java.JavaPairRDD
    -import org.apache.spark.graphx.{VertexId, EdgeContext, Graph}
    -import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix}
    +import org.apache.spark.graphx.{VertexId, Edge, EdgeContext, Graph}
    +import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix, DenseVector}
    +import org.apache.spark.mllib.util.{Saveable, Loader}
     import org.apache.spark.rdd.RDD
    +import org.apache.spark.sql.{SQLContext, Row}
     import org.apache.spark.util.BoundedPriorityQueue
     
    +
     /**
      * :: Experimental ::
      *
    @@ -35,7 +45,7 @@ import org.apache.spark.util.BoundedPriorityQueue
      * including local and distributed data structures.
      */
     @Experimental
    -abstract class LDAModel private[clustering] {
    +abstract class LDAModel private[clustering] extends Saveable {
     
       /** Number of topics */
       def k: Int
    @@ -176,6 +186,11 @@ class LocalLDAModel private[clustering] (
         }.toArray
       }
     
    +  override protected def formatVersion = "1.0"
    +
    +  override def save(sc: SparkContext, path: String): Unit = {
    +    LocalLDAModel.SaveLoadV1_0.save(sc, path, topicsMatrix)
    +  }
       // TODO
       // override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ???
     
    @@ -184,6 +199,80 @@ class LocalLDAModel private[clustering] (
     
     }
     
    +@Experimental
    +object LocalLDAModel extends Loader[LocalLDAModel] {
    +
    +  private object SaveLoadV1_0 {
    +
    +    val thisFormatVersion = "1.0"
    +
    +    val thisClassName = "org.apache.spark.mllib.clustering.LocalLDAModel"
    +
    +    // Store the distribution of terms of each topic and the column index in topicsMatrix
    +    // as a Row in data.
    +    case class Data(topic: Vector, index: Int)
    +
    +    def save(sc: SparkContext, path: String, topicsMatrix: Matrix): Unit = {
    +      val sqlContext = SQLContext.getOrCreate(sc)
    +      import sqlContext.implicits._
    +
    +      val k = topicsMatrix.numCols
    +      val metadata = compact(render
    +        (("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
    +         ("k" -> k) ~ ("vocabSize" -> topicsMatrix.numRows)))
    +      sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
    +
    +      val topicsDenseMatrix = topicsMatrix.toBreeze.toDenseMatrix
    +      val topics = Range(0, k).map { topicInd =>
    +        Data(Vectors.dense((topicsDenseMatrix(::, topicInd).toArray)), topicInd)
    +      }.toSeq
    +      sc.parallelize(topics, 1).toDF().write.parquet(Loader.dataPath(path))
    +    }
    +
    +    def load(sc: SparkContext, path: String): LocalLDAModel = {
    +      val dataPath = Loader.dataPath(path)
    +      val sqlContext = SQLContext.getOrCreate(sc)
    +      val dataFrame = sqlContext.read.parquet(dataPath)
    +
    +      Loader.checkSchema[Data](dataFrame.schema)
    +      val topics = dataFrame.collect()
    +      val vocabSize = topics(0).getAs[Vector](0).size
    +      val k = topics.size
    +
    +      val brzTopics = BDM.zeros[Double](vocabSize, k)
    +      topics.foreach { case Row(vec: Vector, ind: Int) =>
    +        brzTopics(::, ind) := vec.toBreeze
    +      }
    +      new LocalLDAModel(Matrices.fromBreeze(brzTopics))
    +    }
    +  }
    +
    +  override def load(sc: SparkContext, path: String): LocalLDAModel = {
    +    val (loadedClassName, loadedVersion, metadata) = Loader.loadMetadata(sc, path)
    +    implicit val formats = DefaultFormats
    +    val expectedK = (metadata \ "k").extract[Int]
    +    val expectedVocabSize = (metadata \ "vocabSize").extract[Int]
    +    val classNameV1_0 = SaveLoadV1_0.thisClassName
    +
    +    val model = (loadedClassName, loadedVersion) match {
    +      case (className, "1.0") if className == classNameV1_0 =>
    +        SaveLoadV1_0.load(sc, path)
    +      case _ => throw new Exception(
    +        s"LocalLDAModel.load did not recognize model with (className, format version):" +
    +        s"($loadedClassName, $loadedVersion).  Supported:\n" +
    +        s"  ($classNameV1_0, 1.0)")
    +    }
    +
    +    val topicsMatrix = model.topicsMatrix
    +    require(expectedK == topicsMatrix.numCols,
    +      s"LocalLDAModel requires $expectedK topics, got ${topicsMatrix.numCols} topics")
    +    require(expectedVocabSize == topicsMatrix.numRows,
    +      s"LocalLDAModel requires $expectedVocabSize terms for each topic, " +
    +      s"but got ${topicsMatrix.numRows}")
    +    model
    +  }
    +}
    +
     /**
      * :: Experimental ::
      *
    @@ -354,4 +443,135 @@ class DistributedLDAModel private (
       // TODO:
       // override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ???
     
    +  override protected def formatVersion = "1.0"
    +
    +  override def save(sc: SparkContext, path: String): Unit = {
    +    DistributedLDAModel.SaveLoadV1_0.save(
    +      sc, path, graph, globalTopicTotals, k, vocabSize, docConcentration, topicConcentration,
    +      iterationTimes)
    +  }
    +}
    +
    +
    +@Experimental
    +object DistributedLDAModel extends Loader[DistributedLDAModel] {
    +
    +  private object SaveLoadV1_0 {
    +
    +    val thisFormatVersion = "1.0"
    +
    +    val classNameV1_0 = "org.apache.spark.mllib.clustering.DistributedLDAModel"
    +
    +    // Store globalTopicTotals as a Vector.
    +    case class Data(globalTopicTotals: Vector)
    +
    +    // Store each term and document vertex with an id and the topicWeights.
    +    case class VertexData(id: Long, topicWeights: Vector)
    +
    +    // Store each edge with the source id, destination id and tokenCounts.
    +    case class EdgeData(srcId: Long, dstId: Long, tokenCounts: Double)
    +
    +    def save(
    +        sc: SparkContext,
    +        path: String,
    +        graph: Graph[LDA.TopicCounts, LDA.TokenCount],
    +        globalTopicTotals: LDA.TopicCounts,
    +        k: Int,
    +        vocabSize: Int,
    +        docConcentration: Double,
    +        topicConcentration: Double,
    +        iterationTimes: Array[Double]): Unit = {
    +      val sqlContext = SQLContext.getOrCreate(sc)
    +      import sqlContext.implicits._
    +
    +      val metadata = compact(render
    +        (("class" -> classNameV1_0) ~ ("version" -> thisFormatVersion) ~
    +         ("k" -> k) ~ ("vocabSize" -> vocabSize) ~ ("docConcentration" -> docConcentration) ~
    +         ("topicConcentration" -> topicConcentration) ~
    +         ("iterationTimes" -> iterationTimes.toSeq)))
    +      sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
    +
    +      val newPath = new Path(Loader.dataPath(path), "globalTopicTotals").toUri.toString
    +      sc.parallelize(Seq(Data(Vectors.fromBreeze(globalTopicTotals)))).toDF()
    +        .write.parquet(newPath)
    +
    +      val verticesPath = new Path(Loader.dataPath(path), "topicCounts").toUri.toString
    +      graph.vertices.map { case (ind, vertex) =>
    +        VertexData(ind, Vectors.fromBreeze(vertex))
    +      }.toDF().write.parquet(verticesPath)
    +
    +      val edgesPath = new Path(Loader.dataPath(path), "tokenCounts").toUri.toString
    +      graph.edges.map { case Edge(srcId, dstId, prop) =>
    +        EdgeData(srcId, dstId, prop)
    +      }.toDF().write.parquet(edgesPath)
    +    }
    +
    +    def load(
    +        sc: SparkContext,
    +        path: String,
    +        vocabSize: Int,
    +        docConcentration: Double,
    +        topicConcentration: Double,
    +        iterationTimes: Array[Double]): DistributedLDAModel = {
    +      val dataPath = new Path(Loader.dataPath(path), "globalTopicTotals").toUri.toString
    +      val vertexDataPath = new Path(Loader.dataPath(path), "topicCounts").toUri.toString
    +      val edgeDataPath = new Path(Loader.dataPath(path), "tokenCounts").toUri.toString
    +      val sqlContext = SQLContext.getOrCreate(sc)
    +      val dataFrame = sqlContext.read.parquet(dataPath)
    +      val vertexDataFrame = sqlContext.read.parquet(vertexDataPath)
    +      val edgeDataFrame = sqlContext.read.parquet(edgeDataPath)
    +
    +      Loader.checkSchema[Data](dataFrame.schema)
    +      Loader.checkSchema[VertexData](vertexDataFrame.schema)
    +      Loader.checkSchema[EdgeData](edgeDataFrame.schema)
    +      val globalTopicTotals: LDA.TopicCounts =
    +        dataFrame.first().getAs[Vector](0).toBreeze.toDenseVector
    +      val vertices: RDD[(VertexId, LDA.TopicCounts)] = vertexDataFrame.map {
    +        case Row(ind: Long, vec: Vector) => (ind, vec.toBreeze.toDenseVector)
    +      }
    +
    +      val edges: RDD[Edge[LDA.TokenCount]] = edgeDataFrame.map {
    +        case Row(srcId: Long, dstId: Long, prop: Double) => Edge(srcId, dstId, prop)
    +      }
    +      val graph: Graph[LDA.TopicCounts, LDA.TokenCount] = Graph(vertices, edges)
    +
    +      new DistributedLDAModel(graph, globalTopicTotals, globalTopicTotals.length, vocabSize,
    +        docConcentration, topicConcentration, iterationTimes)
    +    }
    +
    +  }
    +
    +  override def load(sc: SparkContext, path: String): DistributedLDAModel = {
    +    val (loadedClassName, loadedVersion, metadata) = Loader.loadMetadata(sc, path)
    +    implicit val formats = DefaultFormats
    +    val expectedK = (metadata \ "k").extract[Int]
    +    val vocabSize = (metadata \ "vocabSize").extract[Int]
    +    val docConcentration = (metadata \ "docConcentration").extract[Double]
    +    val topicConcentration = (metadata \ "topicConcentration").extract[Double]
    +    val iterationTimes = (metadata \ "iterationTimes").extract[Seq[Double]]
    +    val classNameV1_0 = SaveLoadV1_0.classNameV1_0
    +
    +    val model = (loadedClassName, loadedVersion) match {
    +      case (className, "1.0") if className == classNameV1_0 => {
    +        DistributedLDAModel.SaveLoadV1_0.load(
    +          sc, path, vocabSize, docConcentration, topicConcentration, iterationTimes.toArray)
    +      }
    +      case _ => throw new Exception(
    +        s"DistributedLDAModel.load did not recognize model with (className, format version):" +
    +        s"($loadedClassName, $loadedVersion).  Supported: ($classNameV1_0, 1.0)")
    +    }
    +
    +    require(model.vocabSize == vocabSize,
    +      s"DistributedLDAModel requires $vocabSize vocabSize, got ${model.vocabSize} vocabSize")
    +    require(model.docConcentration == docConcentration,
    +      s"DistributedLDAModel requires $docConcentration docConcentration, " +
    +      s"got ${model.docConcentration} docConcentration")
    +    require(model.topicConcentration == topicConcentration,
    +      s"DistributedLDAModel requires $topicConcentration docConcentration, " +
    +      s"got ${model.topicConcentration} docConcentration")
    +    require(expectedK == model.k,
    +      s"DistributedLDAModel requires $expectedK topics, got ${model.k} topics")
    +    model
    +  }
    +
     }
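
    With save/load in place, a model round trip looks like the following sketch, assuming an
    already-trained DistributedLDAModel, an existing SparkContext, and a placeholder output path:

    import org.apache.spark.SparkContext
    import org.apache.spark.mllib.clustering.DistributedLDAModel

    def roundTrip(sc: SparkContext, model: DistributedLDAModel, path: String): Unit = {
      model.save(sc, path)                              // writes JSON metadata plus Parquet data
      val restored = DistributedLDAModel.load(sc, path) // validated against the saved metadata
      println(s"k = ${restored.k}, vocabSize = ${restored.vocabSize}")
    }
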
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
    index 8e5154b902d1d..f4170a3d98dd8 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
    @@ -19,15 +19,15 @@ package org.apache.spark.mllib.clustering
     
     import java.util.Random
     
    -import breeze.linalg.{DenseVector => BDV, DenseMatrix => BDM, sum, normalize, kron}
    -import breeze.numerics.{digamma, exp, abs}
    +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, normalize, sum}
    +import breeze.numerics.{abs, digamma, exp}
     import breeze.stats.distributions.{Gamma, RandBasis}
     
     import org.apache.spark.annotation.DeveloperApi
     import org.apache.spark.graphx._
     import org.apache.spark.graphx.impl.GraphImpl
     import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer
    -import org.apache.spark.mllib.linalg.{Matrices, SparseVector, DenseVector, Vector}
    +import org.apache.spark.mllib.linalg.{DenseVector, Matrices, SparseVector, Vector, Vectors}
     import org.apache.spark.rdd.RDD
     
     /**
    @@ -95,8 +95,11 @@ final class EMLDAOptimizer extends LDAOptimizer {
        * Compute bipartite term/doc graph.
        */
       override private[clustering] def initialize(docs: RDD[(Long, Vector)], lda: LDA): LDAOptimizer = {
    +    val docConcentration = lda.getDocConcentration(0)
    +    require({
    +      lda.getDocConcentration.toArray.forall(_ == docConcentration)
    +    }, "EMLDAOptimizer currently only supports symmetric document-topic priors")
     
    -    val docConcentration = lda.getDocConcentration
         val topicConcentration = lda.getTopicConcentration
         val k = lda.getK
     
    @@ -229,10 +232,10 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
       private var vocabSize: Int = 0
     
       /** alias for docConcentration */
    -  private var alpha: Double = 0
    +  private var alpha: Vector = Vectors.dense(0)
     
       /** (private[clustering] for debugging)  Get docConcentration */
    -  private[clustering] def getAlpha: Double = alpha
    +  private[clustering] def getAlpha: Vector = alpha
     
       /** alias for topicConcentration */
       private var eta: Double = 0
    @@ -343,7 +346,19 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
         this.k = lda.getK
         this.corpusSize = docs.count()
         this.vocabSize = docs.first()._2.size
    -    this.alpha = if (lda.getDocConcentration == -1) 1.0 / k else lda.getDocConcentration
    +    this.alpha = if (lda.getDocConcentration.size == 1) {
    +      if (lda.getDocConcentration(0) == -1) Vectors.dense(Array.fill(k)(1.0 / k))
    +      else {
    +        require(lda.getDocConcentration(0) >= 0,
    +          s"all entries in alpha must be >= 0, got: ${lda.getDocConcentration(0)}")
    +        Vectors.dense(Array.fill(k)(lda.getDocConcentration(0)))
    +      }
    +    } else {
    +      require(lda.getDocConcentration.size == k,
    +        s"alpha must have length k, got length: ${lda.getDocConcentration.size}")
    +      lda.getDocConcentration.foreachActive { case (_, x) =>
    +        require(x >= 0, s"all entries in alpha must be >= 0, got: $x")
    +      }
    +      lda.getDocConcentration
    +    }
         this.eta = if (lda.getTopicConcentration == -1) 1.0 / k else lda.getTopicConcentration
         this.randomGenerator = new Random(lda.getSeed)
     
    @@ -370,9 +385,9 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
         iteration += 1
         val k = this.k
         val vocabSize = this.vocabSize
    -    val Elogbeta = dirichletExpectation(lambda)
    +    val Elogbeta = dirichletExpectation(lambda).t
         val expElogbeta = exp(Elogbeta)
    -    val alpha = this.alpha
    +    val alpha = this.alpha.toBreeze
         val gammaShape = this.gammaShape
     
         val stats: RDD[BDM[Double]] = batch.mapPartitions { docs =>
    @@ -385,41 +400,36 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
               case v => throw new IllegalArgumentException("Online LDA does not support vector type "
                 + v.getClass)
             }
    +        if (!ids.isEmpty) {
    +
    +          // Initialize the variational distribution q(theta|gamma) for the mini-batch
    +          val gammad: BDV[Double] =
    +            new Gamma(gammaShape, 1.0 / gammaShape).samplesVector(k) // K
    +          val expElogthetad: BDV[Double] = exp(digamma(gammad) - digamma(sum(gammad))) // K
    +          val expElogbetad: BDM[Double] = expElogbeta(ids, ::).toDenseMatrix // ids * K
    +
    +          val phinorm: BDV[Double] = expElogbetad * expElogthetad :+ 1e-100 // ids
    +          var meanchange = 1D
    +          val ctsVector = new BDV[Double](cts) // ids
    +
    +          // Iterate between gamma and phi until convergence
    +          while (meanchange > 1e-3) {
    +            val lastgamma = gammad.copy
    +            //        K                  K * ids               ids
    +            gammad := (expElogthetad :* (expElogbetad.t * (ctsVector :/ phinorm))) :+ alpha
    +            expElogthetad := exp(digamma(gammad) - digamma(sum(gammad)))
    +            phinorm := expElogbetad * expElogthetad :+ 1e-100
    +            meanchange = sum(abs(gammad - lastgamma)) / k
    +          }
     
    -        // Initialize the variational distribution q(theta|gamma) for the mini-batch
    -        var gammad = new Gamma(gammaShape, 1.0 / gammaShape).samplesVector(k).t // 1 * K
    -        var Elogthetad = digamma(gammad) - digamma(sum(gammad))     // 1 * K
    -        var expElogthetad = exp(Elogthetad)                         // 1 * K
    -        val expElogbetad = expElogbeta(::, ids).toDenseMatrix       // K * ids
    -
    -        var phinorm = expElogthetad * expElogbetad + 1e-100         // 1 * ids
    -        var meanchange = 1D
    -        val ctsVector = new BDV[Double](cts).t                      // 1 * ids
    -
    -        // Iterate between gamma and phi until convergence
    -        while (meanchange > 1e-3) {
    -          val lastgamma = gammad
    -          //        1*K                  1 * ids               ids * k
    -          gammad = (expElogthetad :* ((ctsVector / phinorm) * expElogbetad.t)) + alpha
    -          Elogthetad = digamma(gammad) - digamma(sum(gammad))
    -          expElogthetad = exp(Elogthetad)
    -          phinorm = expElogthetad * expElogbetad + 1e-100
    -          meanchange = sum(abs(gammad - lastgamma)) / k
    -        }
    -
    -        val m1 = expElogthetad.t
    -        val m2 = (ctsVector / phinorm).t.toDenseVector
    -        var i = 0
    -        while (i < ids.size) {
    -          stat(::, ids(i)) := stat(::, ids(i)) + m1 * m2(i)
    -          i += 1
    +          stat(::, ids) := expElogthetad.asDenseMatrix.t * (ctsVector :/ phinorm).asDenseMatrix
             }
           }
           Iterator(stat)
         }
     
         val statsSum: BDM[Double] = stats.reduce(_ += _)
    -    val batchResult = statsSum :* expElogbeta
    +    val batchResult = statsSum :* expElogbeta.t
     
         // Note that this is an optimization to avoid batch.count
         update(batchResult, iteration, (miniBatchFraction * corpusSize).ceil.toInt)
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
    index e7a243f854e33..407e43a024a2e 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
    @@ -153,6 +153,27 @@ class PowerIterationClustering private[clustering] (
         this
       }
     
    +  /**
    +   * Run the PIC algorithm on an affinity graph.
    +   *
    +   * @param graph an affinity matrix represented as a graph, i.e., the matrix A in the PIC paper.
    +   *              The similarity s,,ij,, represented as the edge between vertices (i, j) must
    +   *              be nonnegative. This is a symmetric matrix and hence s,,ij,, = s,,ji,,. For
    +   *              any (i, j) with nonzero similarity, there should be either (i, j, s,,ij,,)
    +   *              or (j, i, s,,ji,,) in the input. Tuples with i = j are ignored, because we
    +   *              assume s,,ij,, = 0.0.
    +   *
    +   * @return a [[PowerIterationClusteringModel]] that contains the clustering result
    +   */
    +  def run(graph: Graph[Double, Double]): PowerIterationClusteringModel = {
    +    val w = normalize(graph)
    +    val w0 = initMode match {
    +      case "random" => randomInit(w)
    +      case "degree" => initDegreeVector(w)
    +    }
    +    pic(w0)
    +  }
    +
       /**
        * Run the PIC algorithm.
        *
    @@ -212,6 +233,31 @@ object PowerIterationClustering extends Logging {
       @Experimental
       case class Assignment(id: Long, cluster: Int)
     
    +  /**
    +   * Normalizes the affinity graph (A) and returns the normalized affinity matrix (W).
    +   */
    +  private[clustering]
    +  def normalize(graph: Graph[Double, Double]): Graph[Double, Double] = {
    +    val vD = graph.aggregateMessages[Double](
    +      sendMsg = ctx => {
    +        val i = ctx.srcId
    +        val j = ctx.dstId
    +        val s = ctx.attr
    +        if (s < 0.0) {
    +          throw new SparkException("Similarity must be nonnegative but found s($i, $j) = $s.")
    +        }
    +        if (s > 0.0) {
    +          ctx.sendToSrc(s)
    +        }
    +      },
    +      mergeMsg = _ + _,
    +      TripletFields.EdgeOnly)
    +    GraphImpl.fromExistingRDDs(vD, graph.edges)
    +      .mapTriplets(
    +        e => e.attr / math.max(e.srcAttr, MLUtils.EPSILON),
    +        TripletFields.Src)
    +  }
    +
       /**
        * Normalizes the affinity matrix (A) by row sums and returns the normalized affinity matrix (W).
        */
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala
    index e577bf87f885e..408847afa800d 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala
    @@ -53,14 +53,22 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend
           )
         summary
       }
    +  private lazy val SSerr = math.pow(summary.normL2(1), 2)
    +  private lazy val SStot = summary.variance(0) * (summary.count - 1)
    +  private lazy val SSreg = {
    +    val yMean = summary.mean(0)
    +    predictionAndObservations.map {
    +      case (prediction, _) => math.pow(prediction - yMean, 2)
    +    }.sum()
    +  }
     
       /**
    -   * Returns the explained variance regression score.
    -   * explainedVariance = 1 - variance(y - \hat{y}) / variance(y)
    -   * Reference: [[http://en.wikipedia.org/wiki/Explained_variation]]
    +   * Returns the variance explained by regression.
    +   * explainedVariance = \sum_i (\hat{y_i} - \bar{y})^2 / n
    +   * @see [[https://en.wikipedia.org/wiki/Fraction_of_variance_unexplained]]
        */
       def explainedVariance: Double = {
    -    1 - summary.variance(1) / summary.variance(0)
    +    SSreg / summary.count
       }
     
       /**
    @@ -76,8 +84,7 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend
        * expected value of the squared error loss or quadratic loss.
        */
       def meanSquaredError: Double = {
    -    val rmse = summary.normL2(1) / math.sqrt(summary.count)
    -    rmse * rmse
    +    SSerr / summary.count
       }
     
       /**
    @@ -85,14 +92,14 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend
        * the mean squared error.
        */
       def rootMeanSquaredError: Double = {
    -    summary.normL2(1) / math.sqrt(summary.count)
    +    math.sqrt(this.meanSquaredError)
       }
     
       /**
    -   * Returns R^2^, the coefficient of determination.
    -   * Reference: [[http://en.wikipedia.org/wiki/Coefficient_of_determination]]
    +   * Returns R^2^, the unadjusted coefficient of determination.
    +   * @see [[http://en.wikipedia.org/wiki/Coefficient_of_determination]]
        */
       def r2: Double = {
    -    1 - math.pow(summary.normL2(1), 2) / (summary.variance(0) * (summary.count - 1))
    +    1 - SSerr / SStot
       }
     }
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
    index 39c48b084e550..7ead6327486cc 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
    @@ -17,58 +17,49 @@
     
     package org.apache.spark.mllib.fpm
     
    +import scala.collection.mutable
    +
     import org.apache.spark.Logging
    -import org.apache.spark.annotation.Experimental
     
     /**
    - *
    - * :: Experimental ::
    - *
      * Calculate all patterns of a projected database in local.
      */
    -@Experimental
     private[fpm] object LocalPrefixSpan extends Logging with Serializable {
     
       /**
        * Calculate all patterns of a projected database.
        * @param minCount minimum count
        * @param maxPatternLength maximum pattern length
    -   * @param prefix prefix
    -   * @param projectedDatabase the projected dabase
    +   * @param prefixes prefixes in reversed order
    +   * @param database the projected database
        * @return a set of sequential pattern pairs,
    -   *         the key of pair is sequential pattern (a list of items),
    +   *         the key of pair is sequential pattern (a list of items in reversed order),
        *         the value of pair is the pattern's count.
        */
       def run(
           minCount: Long,
           maxPatternLength: Int,
    -      prefix: Array[Int],
    -      projectedDatabase: Array[Array[Int]]): Array[(Array[Int], Long)] = {
    -    val frequentPrefixAndCounts = getFreqItemAndCounts(minCount, projectedDatabase)
    -    val frequentPatternAndCounts = frequentPrefixAndCounts
    -      .map(x => (prefix ++ Array(x._1), x._2))
    -    val prefixProjectedDatabases = getPatternAndProjectedDatabase(
    -      prefix, frequentPrefixAndCounts.map(_._1), projectedDatabase)
    -
    -    val continueProcess = prefixProjectedDatabases.nonEmpty && prefix.length + 1 < maxPatternLength
    -    if (continueProcess) {
    -      val nextPatterns = prefixProjectedDatabases
    -        .map(x => run(minCount, maxPatternLength, x._1, x._2))
    -        .reduce(_ ++ _)
    -      frequentPatternAndCounts ++ nextPatterns
    -    } else {
    -      frequentPatternAndCounts
    +      prefixes: List[Int],
    +      database: Array[Array[Int]]): Iterator[(List[Int], Long)] = {
    +    if (prefixes.length == maxPatternLength || database.isEmpty) return Iterator.empty
    +    val frequentItemAndCounts = getFreqItemAndCounts(minCount, database)
    +    val filteredDatabase = database.map(x => x.filter(frequentItemAndCounts.contains))
    +    frequentItemAndCounts.iterator.flatMap { case (item, count) =>
    +      val newPrefixes = item :: prefixes
    +      val newProjected = project(filteredDatabase, item)
    +      Iterator.single((newPrefixes, count)) ++
    +        run(minCount, maxPatternLength, newPrefixes, newProjected)
         }
       }
     
       /**
    -   * calculate suffix sequence following a prefix in a sequence
    -   * @param prefix prefix
    -   * @param sequence sequence
    +   * Calculate suffix sequence immediately after the first occurrence of an item.
    +   * @param item item to get suffix after
    +   * @param sequence sequence to extract suffix from
        * @return suffix sequence
        */
    -  def getSuffix(prefix: Int, sequence: Array[Int]): Array[Int] = {
    -    val index = sequence.indexOf(prefix)
    +  def getSuffix(item: Int, sequence: Array[Int]): Array[Int] = {
    +    val index = sequence.indexOf(item)
         if (index == -1) {
           Array()
         } else {
    @@ -76,38 +67,28 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
         }
       }
     
    +  def project(database: Array[Array[Int]], prefix: Int): Array[Array[Int]] = {
    +    database
    +      .map(getSuffix(prefix, _))
    +      .filter(_.nonEmpty)
    +  }
    +
       /**
        * Generates frequent items by filtering the input data using minimal count level.
    -   * @param minCount the absolute minimum count
    -   * @param sequences sequences data
    -   * @return array of item and count pair
    +   * @param minCount the minimum count for an item to be frequent
    +   * @param database database of sequences
    +   * @return freq item to count map
        */
       private def getFreqItemAndCounts(
           minCount: Long,
    -      sequences: Array[Array[Int]]): Array[(Int, Long)] = {
    -    sequences.flatMap(_.distinct)
    -      .groupBy(x => x)
    -      .mapValues(_.length.toLong)
    -      .filter(_._2 >= minCount)
    -      .toArray
    -  }
    -
    -  /**
    -   * Get the frequent prefixes' projected database.
    -   * @param prePrefix the frequent prefixes' prefix
    -   * @param frequentPrefixes frequent prefixes
    -   * @param sequences sequences data
    -   * @return prefixes and projected database
    -   */
    -  private def getPatternAndProjectedDatabase(
    -      prePrefix: Array[Int],
    -      frequentPrefixes: Array[Int],
    -      sequences: Array[Array[Int]]): Array[(Array[Int], Array[Array[Int]])] = {
    -    val filteredProjectedDatabase = sequences
    -      .map(x => x.filter(frequentPrefixes.contains(_)))
    -    frequentPrefixes.map { x =>
    -      val sub = filteredProjectedDatabase.map(y => getSuffix(x, y)).filter(_.nonEmpty)
    -      (prePrefix ++ Array(x), sub)
    -    }.filter(x => x._2.nonEmpty)
    +      database: Array[Array[Int]]): mutable.Map[Int, Long] = {
    +    // TODO: use PrimitiveKeyOpenHashMap
    +    val counts = mutable.Map[Int, Long]().withDefaultValue(0L)
    +    database.foreach { sequence =>
    +      sequence.distinct.foreach { item =>
    +        counts(item) += 1L
    +      }
    +    }
    +    counts.filter(_._2 >= minCount)
       }
     }
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
    index 9d8c60ef0fc45..6f52db7b073ae 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
    @@ -150,8 +150,9 @@ class PrefixSpan private (
       private def getPatternsInLocal(
           minCount: Long,
           data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(Array[Int], Long)] = {
    -    data.flatMap { x =>
    -      LocalPrefixSpan.run(minCount, maxPatternLength, x._1, x._2)
    +    data.flatMap { case (prefix, projDB) =>
    +      LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList, projDB)
    +        .map { case (pattern: List[Int], count: Long) => (pattern.toArray.reverse, count) }
         }
       }
     }
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
    index 3523f1804325d..9029093e0fa08 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
    @@ -303,8 +303,8 @@ private[spark] object BLAS extends Serializable with Logging {
           C: DenseMatrix): Unit = {
         require(!C.isTransposed,
           "The matrix C cannot be the product of a transpose() call. C.isTransposed must be false.")
    -    if (alpha == 0.0) {
    -      logDebug("gemm: alpha is equal to 0. Returning C.")
    +    if (alpha == 0.0 && beta == 1.0) {
    +      logDebug("gemm: alpha is equal to 0 and beta is equal to 1. Returning C.")
         } else {
           A match {
             case sparse: SparseMatrix => gemm(alpha, sparse, B, beta, C)
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
    index 0df07663405a3..55da0e094d132 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
    @@ -98,7 +98,7 @@ sealed trait Matrix extends Serializable {
       /** Map the values of this matrix using a function. Generates a new matrix. Performs the
         * function on only the backing array. For example, an operation such as addition or
         * subtraction will only be performed on the non-zero values in a `SparseMatrix`. */
    -  private[mllib] def map(f: Double => Double): Matrix
    +  private[spark] def map(f: Double => Double): Matrix
     
       /** Update all the values of this matrix using the function f. Performed in-place on the
         * backing array. For example, an operation such as addition or subtraction will only be
    @@ -289,7 +289,7 @@ class DenseMatrix(
     
       override def copy: DenseMatrix = new DenseMatrix(numRows, numCols, values.clone())
     
    -  private[mllib] def map(f: Double => Double) = new DenseMatrix(numRows, numCols, values.map(f),
    +  private[spark] def map(f: Double => Double) = new DenseMatrix(numRows, numCols, values.map(f),
         isTransposed)
     
       private[mllib] def update(f: Double => Double): DenseMatrix = {
    @@ -555,7 +555,7 @@ class SparseMatrix(
         new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values.clone())
       }
     
    -  private[mllib] def map(f: Double => Double) =
    +  private[spark] def map(f: Double => Double) =
         new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values.map(f), isTransposed)
     
       private[mllib] def update(f: Double => Double): SparseMatrix = {
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
    index e048b01d92462..9067b3ba9a7bb 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
    @@ -150,6 +150,12 @@ sealed trait Vector extends Serializable {
           toDense
         }
       }
    +
    +  /**
    +   * Find the index of a maximal element.  Returns the first maximal element in case of a tie.
    +   * Returns -1 if vector has length 0.
    +   */
    +  def argmax: Int
     }
     
     /**
    @@ -588,11 +594,7 @@ class DenseVector(val values: Array[Double]) extends Vector {
         new SparseVector(size, ii, vv)
       }
     
    -  /**
    -   * Find the index of a maximal element.  Returns the first maximal element in case of a tie.
    -   * Returns -1 if vector has length 0.
    -   */
    -  private[spark] def argmax: Int = {
    +  override def argmax: Int = {
         if (size == 0) {
           -1
         } else {
    @@ -717,6 +719,51 @@ class SparseVector(
           new SparseVector(size, ii, vv)
         }
       }
    +
    +  override def argmax: Int = {
    +    if (size == 0) {
    +      -1
    +    } else {
    +      // Find the max active entry.
    +      var maxIdx = indices(0)
    +      var maxValue = values(0)
    +      var maxJ = 0
    +      var j = 1
    +      val na = numActives
    +      while (j < na) {
    +        val v = values(j)
    +        if (v > maxValue) {
    +          maxValue = v
    +          maxIdx = indices(j)
    +          maxJ = j
    +        }
    +        j += 1
    +      }
    +
    +      // If the max active entry is nonpositive and inactive entries exist, find the first zero.
    +      if (maxValue <= 0.0 && na < size) {
    +        if (maxValue == 0.0) {
    +          // If there exists an inactive entry before maxIdx, find it and return its index.
    +          if (maxJ < maxIdx) {
    +            var k = 0
    +            while (k < maxJ && indices(k) == k) {
    +              k += 1
    +            }
    +            maxIdx = k
    +          }
    +        } else {
    +          // If the max active value is negative, find and return the first inactive index.
    +          var k = 0
    +          while (k < na && indices(k) == k) {
    +            k += 1
    +          }
    +          maxIdx = k
    +        }
    +      }
    +
    +      maxIdx
    +    }
    +  }
     }
     
     object SparseVector {
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala
    index 35e81fcb3de0d..1facf83d806d0 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala
    @@ -72,7 +72,7 @@ class SlidingRDD[T: ClassTag](@transient val parent: RDD[T], val windowSize: Int
           val w1 = windowSize - 1
           // Get the first w1 items of each partition, starting from the second partition.
           val nextHeads =
    -        parent.context.runJob(parent, (iter: Iterator[T]) => iter.take(w1).toArray, 1 until n, true)
    +        parent.context.runJob(parent, (iter: Iterator[T]) => iter.take(w1).toArray, 1 until n)
           val partitions = mutable.ArrayBuffer[SlidingRDDPartition[T]]()
           var i = 0
           var partitionIndex = 0
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala
    index d89b0059d83f3..2b3ed6df486c9 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala
    @@ -20,7 +20,7 @@ package org.apache.spark.mllib.stat.test
     import scala.annotation.varargs
     
     import org.apache.commons.math3.distribution.{NormalDistribution, RealDistribution}
    -import org.apache.commons.math3.stat.inference.KolmogorovSmirnovTest
    +import org.apache.commons.math3.stat.inference.{KolmogorovSmirnovTest => CommonMathKolmogorovSmirnovTest}
     
     import org.apache.spark.Logging
     import org.apache.spark.rdd.RDD
    @@ -187,7 +187,7 @@ private[stat] object KolmogorovSmirnovTest extends Logging {
       }
     
       private def evalOneSampleP(ksStat: Double, n: Long): KolmogorovSmirnovTestResult = {
    -    val pval = 1 - new KolmogorovSmirnovTest().cdf(ksStat, n.toInt)
    +    val pval = 1 - new CommonMathKolmogorovSmirnovTest().cdf(ksStat, n.toInt)
         new KolmogorovSmirnovTestResult(pval, ksStat, NullHypothesis.OneSampleTwoSided.toString)
       }
     }
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala
    index 089010c81ffb6..572815df0bc4a 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala
    @@ -38,10 +38,10 @@ import org.apache.spark.util.random.XORShiftRandom
      * TODO: This does not currently support (Double) weighted instances.  Once MLlib has weighted
      *       dataset support, update.  (We store subsampleWeights as Double for this future extension.)
      */
    -private[tree] class BaggedPoint[Datum](val datum: Datum, val subsampleWeights: Array[Double])
    +private[spark] class BaggedPoint[Datum](val datum: Datum, val subsampleWeights: Array[Double])
       extends Serializable
     
    -private[tree] object BaggedPoint {
    +private[spark] object BaggedPoint {
     
       /**
        * Convert an input dataset into its BaggedPoint representation,
    @@ -60,7 +60,7 @@ private[tree] object BaggedPoint {
           subsamplingRate: Double,
           numSubsamples: Int,
           withReplacement: Boolean,
    -      seed: Int = Utils.random.nextInt()): RDD[BaggedPoint[Datum]] = {
    +      seed: Long = Utils.random.nextLong()): RDD[BaggedPoint[Datum]] = {
         if (withReplacement) {
           convertToBaggedRDDSamplingWithReplacement(input, subsamplingRate, numSubsamples, seed)
         } else {
    @@ -76,7 +76,7 @@ private[tree] object BaggedPoint {
           input: RDD[Datum],
           subsamplingRate: Double,
           numSubsamples: Int,
    -      seed: Int): RDD[BaggedPoint[Datum]] = {
    +      seed: Long): RDD[BaggedPoint[Datum]] = {
         input.mapPartitionsWithIndex { (partitionIndex, instances) =>
           // Use random seed = seed + partitionIndex + 1 to make generation reproducible.
           val rng = new XORShiftRandom
    @@ -100,7 +100,7 @@ private[tree] object BaggedPoint {
           input: RDD[Datum],
           subsample: Double,
           numSubsamples: Int,
    -      seed: Int): RDD[BaggedPoint[Datum]] = {
    +      seed: Long): RDD[BaggedPoint[Datum]] = {
         input.mapPartitionsWithIndex { (partitionIndex, instances) =>
           // Use random seed = seed + partitionIndex + 1 to make generation reproducible.
           val poisson = new PoissonDistribution(subsample)
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
    index ce8825cc03229..7985ed4b4c0fa 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
    @@ -27,7 +27,7 @@ import org.apache.spark.mllib.tree.impurity._
      * and helps with indexing.
      * This class is abstract to support learning with and without feature subsampling.
      */
    -private[tree] class DTStatsAggregator(
    +private[spark] class DTStatsAggregator(
         val metadata: DecisionTreeMetadata,
         featureSubset: Option[Array[Int]]) extends Serializable {
     
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
    index f73896e37c05e..380291ac22bd3 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
    @@ -37,7 +37,7 @@ import org.apache.spark.rdd.RDD
      *                      I.e., the feature takes values in {0, ..., arity - 1}.
      * @param numBins  Number of bins for each feature.
      */
    -private[tree] class DecisionTreeMetadata(
    +private[spark] class DecisionTreeMetadata(
         val numFeatures: Int,
         val numExamples: Long,
         val numClasses: Int,
    @@ -94,7 +94,7 @@ private[tree] class DecisionTreeMetadata(
     
     }
     
    -private[tree] object DecisionTreeMetadata extends Logging {
    +private[spark] object DecisionTreeMetadata extends Logging {
     
       /**
        * Construct a [[DecisionTreeMetadata]] instance for this dataset and parameters.
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala
    index bdd0f576b048d..8f9eb24b57b55 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala
    @@ -75,7 +75,7 @@ private[tree] case class NodeIndexUpdater(
      *                           (how often should the cache be checkpointed.).
      */
     @DeveloperApi
    -private[tree] class NodeIdCache(
    +private[spark] class NodeIdCache(
       var nodeIdsForInstances: RDD[Array[Int]],
       val checkpointInterval: Int) {
     
    @@ -170,7 +170,7 @@ private[tree] class NodeIdCache(
     }
     
     @DeveloperApi
    -private[tree] object NodeIdCache {
    +private[spark] object NodeIdCache {
       /**
        * Initialize the node Id cache with initial node Id values.
        * @param data The RDD of training rows.
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala
    index d215d68c4279e..aac84243d5ce1 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala
    @@ -25,7 +25,7 @@ import org.apache.spark.annotation.Experimental
      * Time tracker implementation which holds labeled timers.
      */
     @Experimental
    -private[tree] class TimeTracker extends Serializable {
    +private[spark] class TimeTracker extends Serializable {
     
       private val starts: MutableHashMap[String, Long] = new MutableHashMap[String, Long]()
     
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
    index 50b292e71b067..21919d69a38a3 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
    @@ -37,11 +37,11 @@ import org.apache.spark.rdd.RDD
      * @param binnedFeatures  Binned feature values.
      *                        Same length as LabeledPoint.features, but values are bin indices.
      */
    -private[tree] class TreePoint(val label: Double, val binnedFeatures: Array[Int])
    +private[spark] class TreePoint(val label: Double, val binnedFeatures: Array[Int])
       extends Serializable {
     }
     
    -private[tree] object TreePoint {
    +private[spark] object TreePoint {
     
       /**
        * Convert an input dataset into its TreePoint representation,
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
    index 72eb24c49264a..578749d85a4e6 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
    @@ -57,7 +57,7 @@ trait Impurity extends Serializable {
      * Note: Instances of this class do not hold the data; they operate on views of the data.
      * @param statsSize  Length of the vector of sufficient statistics for one bin.
      */
    -private[tree] abstract class ImpurityAggregator(val statsSize: Int) extends Serializable {
    +private[spark] abstract class ImpurityAggregator(val statsSize: Int) extends Serializable {
     
       /**
        * Merge the stats from one bin into another.
    @@ -95,7 +95,7 @@ private[tree] abstract class ImpurityAggregator(val statsSize: Int) extends Seri
      * (node, feature, bin).
      * @param stats  Array of sufficient statistics for a (node, feature, bin).
      */
    -private[tree] abstract class ImpurityCalculator(val stats: Array[Double]) {
    +private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) {
     
       /**
        * Make a deep copy of this [[ImpurityCalculator]].
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
    index a5582d3ef3324..011a5d57422f7 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
    @@ -42,11 +42,11 @@ object SquaredError extends Loss {
        * @return Loss gradient
        */
       override def gradient(prediction: Double, label: Double): Double = {
    -    2.0 * (prediction - label)
    +    - 2.0 * (label - prediction)
       }
     
       override private[mllib] def computeError(prediction: Double, label: Double): Double = {
    -    val err = prediction - label
    +    val err = label - prediction
         err * err
       }
     }
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
    index 2d087c967f679..dc9e0f9f51ffb 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
    @@ -67,7 +67,7 @@ class InformationGainStats(
     }
     
     
    -private[tree] object InformationGainStats {
    +private[spark] object InformationGainStats {
       /**
        * An [[org.apache.spark.mllib.tree.model.InformationGainStats]] object to
        * denote that the current split doesn't satisfy minimum info gain or
    diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
    new file mode 100644
    index 0000000000000..09a9fba0c19cf
    --- /dev/null
    +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
    @@ -0,0 +1,98 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.classification;
    +
    +import java.io.Serializable;
    +
    +import com.google.common.collect.Lists;
    +import org.junit.After;
    +import org.junit.Before;
    +import org.junit.Test;
    +
    +import org.apache.spark.api.java.JavaRDD;
    +import org.apache.spark.api.java.JavaSparkContext;
    +import org.apache.spark.mllib.linalg.VectorUDT;
    +import org.apache.spark.mllib.linalg.Vectors;
    +import org.apache.spark.sql.DataFrame;
    +import org.apache.spark.sql.Row;
    +import org.apache.spark.sql.RowFactory;
    +import org.apache.spark.sql.SQLContext;
    +import org.apache.spark.sql.types.DataTypes;
    +import org.apache.spark.sql.types.Metadata;
    +import org.apache.spark.sql.types.StructField;
    +import org.apache.spark.sql.types.StructType;
    +
    +public class JavaNaiveBayesSuite implements Serializable {
    +
    +  private transient JavaSparkContext jsc;
    +  private transient SQLContext jsql;
    +
    +  @Before
    +  public void setUp() {
    +    jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite");
    +    jsql = new SQLContext(jsc);
    +  }
    +
    +  @After
    +  public void tearDown() {
    +    jsc.stop();
    +    jsc = null;
    +  }
    +
    +  public void validatePrediction(DataFrame predictionAndLabels) {
    +    for (Row r : predictionAndLabels.collect()) {
    +      double prediction = r.getAs(0);
    +      double label = r.getAs(1);
    +      assert(prediction == label);
    +    }
    +  }
    +
    +  @Test
    +  public void naiveBayesDefaultParams() {
    +    NaiveBayes nb = new NaiveBayes();
    +    assert(nb.getLabelCol().equals("label"));
    +    assert(nb.getFeaturesCol().equals("features"));
    +    assert(nb.getPredictionCol().equals("prediction"));
    +    assert(nb.getLambda() == 1.0);
    +    assert(nb.getModelType().equals("multinomial"));
    +  }
    +
    +  @Test
    +  public void testNaiveBayes() {
    +    JavaRDD<Row> jrdd = jsc.parallelize(Lists.newArrayList(
    +      RowFactory.create(0.0, Vectors.dense(1.0, 0.0, 0.0)),
    +      RowFactory.create(0.0, Vectors.dense(2.0, 0.0, 0.0)),
    +      RowFactory.create(1.0, Vectors.dense(0.0, 1.0, 0.0)),
    +      RowFactory.create(1.0, Vectors.dense(0.0, 2.0, 0.0)),
    +      RowFactory.create(2.0, Vectors.dense(0.0, 0.0, 1.0)),
    +      RowFactory.create(2.0, Vectors.dense(0.0, 0.0, 2.0))
    +    ));
    +
    +    StructType schema = new StructType(new StructField[]{
    +      new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
    +      new StructField("features", new VectorUDT(), false, Metadata.empty())
    +    });
    +
    +    DataFrame dataset = jsql.createDataFrame(jrdd, schema);
    +    NaiveBayes nb = new NaiveBayes().setLambda(0.5).setModelType("multinomial");
    +    NaiveBayesModel model = nb.fit(dataset);
    +
    +    DataFrame predictionAndLabels = model.transform(dataset).select("prediction", "label");
    +    validatePrediction(predictionAndLabels);
    +  }
    +}
    diff --git a/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java b/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java
    new file mode 100644
    index 0000000000000..d09fa7fd5637c
    --- /dev/null
    +++ b/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java
    @@ -0,0 +1,72 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.clustering;
    +
    +import java.io.Serializable;
    +import java.util.Arrays;
    +import java.util.List;
    +
    +import org.junit.After;
    +import org.junit.Before;
    +import org.junit.Test;
    +import static org.junit.Assert.assertArrayEquals;
    +import static org.junit.Assert.assertEquals;
    +import static org.junit.Assert.assertTrue;
    +
    +import org.apache.spark.api.java.JavaSparkContext;
    +import org.apache.spark.mllib.linalg.Vector;
    +import org.apache.spark.sql.DataFrame;
    +import org.apache.spark.sql.SQLContext;
    +
    +public class JavaKMeansSuite implements Serializable {
    +
    +  private transient int k = 5;
    +  private transient JavaSparkContext sc;
    +  private transient DataFrame dataset;
    +  private transient SQLContext sql;
    +
    +  @Before
    +  public void setUp() {
    +    sc = new JavaSparkContext("local", "JavaKMeansSuite");
    +    sql = new SQLContext(sc);
    +
    +    dataset = KMeansSuite.generateKMeansData(sql, 50, 3, k);
    +  }
    +
    +  @After
    +  public void tearDown() {
    +    sc.stop();
    +    sc = null;
    +  }
    +
    +  @Test
    +  public void fitAndTransform() {
    +    KMeans kmeans = new KMeans().setK(k).setSeed(1);
    +    KMeansModel model = kmeans.fit(dataset);
    +
    +    Vector[] centers = model.clusterCenters();
    +    assertEquals(k, centers.length);
    +
    +    DataFrame transformed = model.transform(dataset);
    +    List<String> columns = Arrays.asList(transformed.columns());
    +    List<String> expectedColumns = Arrays.asList("features", "prediction");
    +    for (String column: expectedColumns) {
    +      assertTrue(columns.contains(column));
    +    }
    +  }
    +}
    diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
    index 3ae09d39ef500..dc6ce8061f62b 100644
    --- a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
    +++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
    @@ -96,11 +96,8 @@ private void init() {
           new DoubleArrayParam(this, "myDoubleArrayParam", "this is a double param");
     
         setDefault(myIntParam(), 1);
    -    setDefault(myIntParam().w(1));
         setDefault(myDoubleParam(), 0.5);
    -    setDefault(myIntParam().w(1), myDoubleParam().w(0.5));
         setDefault(myDoubleArrayParam(), new double[] {1.0, 2.0});
    -    setDefault(myDoubleArrayParam().w(new double[] {1.0, 2.0}));
       }
     
       @Override
    diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
    index 71b041818d7ee..ebe800e749e05 100644
    --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
    +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
    @@ -57,7 +57,7 @@ public void runDT() {
         JavaRDD<LabeledPoint> data = sc.parallelize(
           LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
         Map<Integer, Integer> categoricalFeatures = new HashMap<Integer, Integer>();
    -    DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
    +    DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);
     
         // This tests setters. Training with various options is tested in Scala.
         DecisionTreeRegressor dt = new DecisionTreeRegressor()
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
    new file mode 100644
    index 0000000000000..76381a2741296
    --- /dev/null
    +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
    @@ -0,0 +1,116 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.classification
    +
    +import org.apache.spark.SparkFunSuite
    +import org.apache.spark.ml.param.ParamsSuite
    +import org.apache.spark.mllib.linalg._
    +import org.apache.spark.mllib.util.MLlibTestSparkContext
    +import org.apache.spark.mllib.util.TestingUtils._
    +import org.apache.spark.mllib.classification.NaiveBayesSuite._
    +import org.apache.spark.sql.DataFrame
    +import org.apache.spark.sql.Row
    +
    +class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
    +
    +  def validatePrediction(predictionAndLabels: DataFrame): Unit = {
    +    val numOfErrorPredictions = predictionAndLabels.collect().count {
    +      case Row(prediction: Double, label: Double) =>
    +        prediction != label
    +    }
    +    // At least 80% of the predictions should be correct.
    +    assert(numOfErrorPredictions < predictionAndLabels.count() / 5)
    +  }
    +
    +  def validateModelFit(
    +      piData: Vector,
    +      thetaData: Matrix,
    +      model: NaiveBayesModel): Unit = {
    +    assert(Vectors.dense(model.pi.toArray.map(math.exp)) ~==
    +      Vectors.dense(piData.toArray.map(math.exp)) absTol 0.05, "pi mismatch")
    +    assert(model.theta.map(math.exp) ~== thetaData.map(math.exp) absTol 0.05, "theta mismatch")
    +  }
    +
    +  test("params") {
    +    ParamsSuite.checkParams(new NaiveBayes)
    +    val model = new NaiveBayesModel("nb", pi = Vectors.dense(Array(0.2, 0.8)),
    +      theta = new DenseMatrix(2, 3, Array(0.1, 0.2, 0.3, 0.4, 0.6, 0.4)))
    +    ParamsSuite.checkParams(model)
    +  }
    +
    +  test("naive bayes: default params") {
    +    val nb = new NaiveBayes
    +    assert(nb.getLabelCol === "label")
    +    assert(nb.getFeaturesCol === "features")
    +    assert(nb.getPredictionCol === "prediction")
    +    assert(nb.getLambda === 1.0)
    +    assert(nb.getModelType === "multinomial")
    +  }
    +
    +  test("Naive Bayes Multinomial") {
    +    val nPoints = 1000
    +    val piArray = Array(0.5, 0.1, 0.4).map(math.log)
    +    val thetaArray = Array(
    +      Array(0.70, 0.10, 0.10, 0.10), // label 0
    +      Array(0.10, 0.70, 0.10, 0.10), // label 1
    +      Array(0.10, 0.10, 0.70, 0.10)  // label 2
    +    ).map(_.map(math.log))
    +    val pi = Vectors.dense(piArray)
    +    val theta = new DenseMatrix(3, 4, thetaArray.flatten, true)
    +
    +    val testDataset = sqlContext.createDataFrame(generateNaiveBayesInput(
    +      piArray, thetaArray, nPoints, 42, "multinomial"))
    +    val nb = new NaiveBayes().setLambda(1.0).setModelType("multinomial")
    +    val model = nb.fit(testDataset)
    +
    +    validateModelFit(pi, theta, model)
    +    assert(model.hasParent)
    +
    +    val validationDataset = sqlContext.createDataFrame(generateNaiveBayesInput(
    +      piArray, thetaArray, nPoints, 17, "multinomial"))
    +    val predictionAndLabels = model.transform(validationDataset).select("prediction", "label")
    +
    +    validatePrediction(predictionAndLabels)
    +  }
    +
    +  test("Naive Bayes Bernoulli") {
    +    val nPoints = 10000
    +    val piArray = Array(0.5, 0.3, 0.2).map(math.log)
    +    val thetaArray = Array(
    +      Array(0.50, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.40), // label 0
    +      Array(0.02, 0.70, 0.10, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02), // label 1
    +      Array(0.02, 0.02, 0.60, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.30)  // label 2
    +    ).map(_.map(math.log))
    +    val pi = Vectors.dense(piArray)
    +    val theta = new DenseMatrix(3, 12, thetaArray.flatten, true)
    +
    +    val testDataset = sqlContext.createDataFrame(generateNaiveBayesInput(
    +      piArray, thetaArray, nPoints, 45, "bernoulli"))
    +    val nb = new NaiveBayes().setLambda(1.0).setModelType("bernoulli")
    +    val model = nb.fit(testDataset)
    +
    +    validateModelFit(pi, theta, model)
    +    assert(model.hasParent)
    +
    +    val validationDataset = sqlContext.createDataFrame(generateNaiveBayesInput(
    +      piArray, thetaArray, nPoints, 20, "bernoulli"))
    +    val predictionAndLabels = model.transform(validationDataset).select("prediction", "label")
    +
    +    validatePrediction(predictionAndLabels)
    +  }
    +}
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
    new file mode 100644
    index 0000000000000..1f15ac02f4008
    --- /dev/null
    +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
    @@ -0,0 +1,114 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.clustering
    +
    +import org.apache.spark.SparkFunSuite
    +import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans}
    +import org.apache.spark.mllib.linalg.{Vector, Vectors}
    +import org.apache.spark.mllib.util.MLlibTestSparkContext
    +import org.apache.spark.sql.{DataFrame, SQLContext}
    +
    +private[clustering] case class TestRow(features: Vector)
    +
    +object KMeansSuite {
    +  def generateKMeansData(sql: SQLContext, rows: Int, dim: Int, k: Int): DataFrame = {
    +    val sc = sql.sparkContext
    +    val rdd = sc.parallelize(1 to rows).map(i => Vectors.dense(Array.fill(dim)((i % k).toDouble)))
    +      .map(v => new TestRow(v))
    +    sql.createDataFrame(rdd)
    +  }
    +}
    +
    +class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
    +
    +  final val k = 5
    +  @transient var dataset: DataFrame = _
    +
    +  override def beforeAll(): Unit = {
    +    super.beforeAll()
    +
    +    dataset = KMeansSuite.generateKMeansData(sqlContext, 50, 3, k)
    +  }
    +
    +  test("default parameters") {
    +    val kmeans = new KMeans()
    +
    +    assert(kmeans.getK === 2)
    +    assert(kmeans.getFeaturesCol === "features")
    +    assert(kmeans.getPredictionCol === "prediction")
    +    assert(kmeans.getMaxIter === 20)
    +    assert(kmeans.getRuns === 1)
    +    assert(kmeans.getInitMode === MLlibKMeans.K_MEANS_PARALLEL)
    +    assert(kmeans.getInitSteps === 5)
    +    assert(kmeans.getEpsilon === 1e-4)
    +  }
    +
    +  test("set parameters") {
    +    val kmeans = new KMeans()
    +      .setK(9)
    +      .setFeaturesCol("test_feature")
    +      .setPredictionCol("test_prediction")
    +      .setMaxIter(33)
    +      .setRuns(7)
    +      .setInitMode(MLlibKMeans.RANDOM)
    +      .setInitSteps(3)
    +      .setSeed(123)
    +      .setEpsilon(1e-3)
    +
    +    assert(kmeans.getK === 9)
    +    assert(kmeans.getFeaturesCol === "test_feature")
    +    assert(kmeans.getPredictionCol === "test_prediction")
    +    assert(kmeans.getMaxIter === 33)
    +    assert(kmeans.getRuns === 7)
    +    assert(kmeans.getInitMode === MLlibKMeans.RANDOM)
    +    assert(kmeans.getInitSteps === 3)
    +    assert(kmeans.getSeed === 123)
    +    assert(kmeans.getEpsilon === 1e-3)
    +  }
    +
    +  test("parameters validation") {
    +    intercept[IllegalArgumentException] {
    +      new KMeans().setK(1)
    +    }
    +    intercept[IllegalArgumentException] {
    +      new KMeans().setInitMode("no_such_a_mode")
    +    }
    +    intercept[IllegalArgumentException] {
    +      new KMeans().setInitSteps(0)
    +    }
    +    intercept[IllegalArgumentException] {
    +      new KMeans().setRuns(0)
    +    }
    +  }
    +
    +  test("fit & transform") {
    +    val predictionColName = "kmeans_prediction"
    +    val kmeans = new KMeans().setK(k).setPredictionCol(predictionColName).setSeed(1)
    +    val model = kmeans.fit(dataset)
    +    assert(model.clusterCenters.length === k)
    +
    +    val transformed = model.transform(dataset)
    +    val expectedColumns = Array("features", predictionColName)
    +    expectedColumns.foreach { column =>
    +      assert(transformed.columns.contains(column))
    +    }
    +    val clusters = transformed.select(predictionColName).map(_.getInt(0)).distinct().collect().toSet
    +    assert(clusters.size === k)
    +    assert(clusters === Set(0, 1, 2, 3, 4))
    +  }
    +}
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala
    new file mode 100644
    index 0000000000000..c8d065f37a605
    --- /dev/null
    +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala
    @@ -0,0 +1,34 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.feature
    +
    +import org.apache.spark.SparkFunSuite
    +
    +class RFormulaParserSuite extends SparkFunSuite {
    +  private def checkParse(formula: String, label: String, terms: Seq[String]) {
    +    val parsed = RFormulaParser.parse(formula)
    +    assert(parsed.label == label)
    +    assert(parsed.terms == terms)
    +  }
    +
    +  test("parse simple formulas") {
    +    checkParse("y ~ x", "y", Seq("x"))
    +    checkParse("y ~   ._foo  ", "y", Seq("._foo"))
    +    checkParse("resp ~ A_VAR + B + c123", "resp", Seq("A_VAR", "B", "c123"))
    +  }
    +}
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
    new file mode 100644
    index 0000000000000..79c4ccf02d4e0
    --- /dev/null
    +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
    @@ -0,0 +1,102 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.feature
    +
    +import org.apache.spark.SparkFunSuite
    +import org.apache.spark.ml.param.ParamsSuite
    +import org.apache.spark.mllib.linalg.Vectors
    +import org.apache.spark.mllib.util.MLlibTestSparkContext
    +
    +class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext {
    +  test("params") {
    +    ParamsSuite.checkParams(new RFormula())
    +  }
    +
    +  test("transform numeric data") {
    +    val formula = new RFormula().setFormula("id ~ v1 + v2")
    +    val original = sqlContext.createDataFrame(
    +      Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2")
    +    val result = formula.transform(original)
    +    val resultSchema = formula.transformSchema(original.schema)
    +    val expected = sqlContext.createDataFrame(
    +      Seq(
    +        (0, 1.0, 3.0, Vectors.dense(Array(1.0, 3.0)), 0.0),
    +        (2, 2.0, 5.0, Vectors.dense(Array(2.0, 5.0)), 2.0))
    +      ).toDF("id", "v1", "v2", "features", "label")
    +    // TODO(ekl) make schema comparisons ignore metadata, to avoid .toString
    +    assert(result.schema.toString == resultSchema.toString)
    +    assert(resultSchema == expected.schema)
    +    assert(result.collect().toSeq == expected.collect().toSeq)
    +  }
    +
    +  test("features column already exists") {
    +    val formula = new RFormula().setFormula("y ~ x").setFeaturesCol("x")
    +    val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y")
    +    intercept[IllegalArgumentException] {
    +      formula.transformSchema(original.schema)
    +    }
    +    intercept[IllegalArgumentException] {
    +      formula.transform(original)
    +    }
    +  }
    +
    +  test("label column already exists") {
    +    val formula = new RFormula().setFormula("y ~ x").setLabelCol("y")
    +    val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y")
    +    val resultSchema = formula.transformSchema(original.schema)
    +    assert(resultSchema.length == 3)
    +    assert(resultSchema.toString == formula.transform(original).schema.toString)
    +  }
    +
    +  test("label column already exists but is not double type") {
    +    val formula = new RFormula().setFormula("y ~ x").setLabelCol("y")
    +    val original = sqlContext.createDataFrame(Seq((0, 1), (2, 2))).toDF("x", "y")
    +    intercept[IllegalArgumentException] {
    +      formula.transformSchema(original.schema)
    +    }
    +    intercept[IllegalArgumentException] {
    +      formula.transform(original)
    +    }
    +  }
    +
    +  test("allow missing label column for test datasets") {
    +    val formula = new RFormula().setFormula("y ~ x").setLabelCol("label")
    +    val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "_not_y")
    +    val resultSchema = formula.transformSchema(original.schema)
    +    assert(resultSchema.length == 3)
    +    assert(!resultSchema.exists(_.name == "label"))
    +    assert(resultSchema.toString == formula.transform(original).schema.toString)
    +  }
    +
    +// TODO(ekl) enable after we implement string label support
    +//  test("transform string label") {
    +//    val formula = new RFormula().setFormula("name ~ id")
    +//    val original = sqlContext.createDataFrame(
    +//      Seq((1, "foo"), (2, "bar"), (3, "bar"))).toDF("id", "name")
    +//    val result = formula.transform(original)
    +//    val resultSchema = formula.transformSchema(original.schema)
    +//    val expected = sqlContext.createDataFrame(
    +//      Seq(
    +//        (1, "foo", Vectors.dense(Array(1.0)), 1.0),
    +//        (2, "bar", Vectors.dense(Array(2.0)), 0.0),
    +//        (3, "bar", Vectors.dense(Array(3.0)), 0.0))
    +//      ).toDF("id", "name", "features", "label")
    +//    assert(result.schema.toString == resultSchema.toString)
    +//    assert(result.collect().toSeq == expected.collect().toSeq)
    +//  }
    +}
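
// A minimal sketch of the RFormula transformer covered above: the right-hand-side terms are
// assembled into a "features" vector column and the left-hand-side variable becomes a double
// "label" column (illustrative only; assumes an existing SQLContext named `sqlContext`):
import org.apache.spark.ml.feature.RFormula

val df = sqlContext.createDataFrame(Seq(
  (1.0, 2.0, 3.0),
  (0.0, 5.0, 6.0)
)).toDF("y", "v1", "v2")

val formula = new RFormula().setFormula("y ~ v1 + v2")
val prepared = formula.transform(df) // adds "features" and "label" columns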
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
    index cf120cf2a4b47..7cdda3db88ad1 100644
    --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
    @@ -18,6 +18,7 @@
     package org.apache.spark.ml.regression
     
     import org.apache.spark.SparkFunSuite
    +import org.apache.spark.ml.param.ParamsSuite
     import org.apache.spark.mllib.linalg.{DenseVector, Vectors}
     import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
     import org.apache.spark.mllib.util.TestingUtils._
    @@ -55,6 +56,30 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
     
       }
     
    +  test("params") {
    +    ParamsSuite.checkParams(new LinearRegression)
    +    val model = new LinearRegressionModel("linearReg", Vectors.dense(0.0), 0.0)
    +    ParamsSuite.checkParams(model)
    +  }
    +
    +  test("linear regression: default params") {
    +    val lir = new LinearRegression
    +    assert(lir.getLabelCol === "label")
    +    assert(lir.getFeaturesCol === "features")
    +    assert(lir.getPredictionCol === "prediction")
    +    assert(lir.getRegParam === 0.0)
    +    assert(lir.getElasticNetParam === 0.0)
    +    assert(lir.getFitIntercept)
    +    val model = lir.fit(dataset)
    +    model.transform(dataset)
    +      .select("label", "prediction")
    +      .collect()
    +    assert(model.getFeaturesCol === "features")
    +    assert(model.getPredictionCol === "prediction")
    +    assert(model.intercept !== 0.0)
    +    assert(model.hasParent)
    +  }
    +
       test("linear regression with intercept without regularization") {
         val trainer = new LinearRegression
         val model = trainer.fit(dataset)
    @@ -302,7 +327,7 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
           .map { case Row(features: DenseVector, label: Double) =>
           val prediction =
             features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
    -      prediction - label
    +      label - prediction
         }
           .zip(model.summary.residuals.map(_.getDouble(0)))
           .collect()
    @@ -314,7 +339,7 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
            Use the following R code to generate model training results.
     
            predictions <- predict(fit, newx=features)
    -       residuals <- predictions - label
    +       residuals <- label - predictions
            > mean(residuals^2) # MSE
            [1] 0.009720325
            > mean(abs(residuals)) # MAD
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
    new file mode 100644
    index 0000000000000..c8e58f216cceb
    --- /dev/null
    +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
    @@ -0,0 +1,139 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.tuning
    +
    +import org.apache.spark.SparkFunSuite
    +import org.apache.spark.ml.{Estimator, Model}
    +import org.apache.spark.ml.classification.LogisticRegression
    +import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator}
    +import org.apache.spark.ml.param.ParamMap
    +import org.apache.spark.ml.param.shared.HasInputCol
    +import org.apache.spark.ml.regression.LinearRegression
    +import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
    +import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
    +import org.apache.spark.sql.DataFrame
    +import org.apache.spark.sql.types.StructType
    +
    +class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext {
    +  test("train validation with logistic regression") {
    +    val dataset = sqlContext.createDataFrame(
    +      sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2))
    +
    +    val lr = new LogisticRegression
    +    val lrParamMaps = new ParamGridBuilder()
    +      .addGrid(lr.regParam, Array(0.001, 1000.0))
    +      .addGrid(lr.maxIter, Array(0, 10))
    +      .build()
    +    val eval = new BinaryClassificationEvaluator
    +    val cv = new TrainValidationSplit()
    +      .setEstimator(lr)
    +      .setEstimatorParamMaps(lrParamMaps)
    +      .setEvaluator(eval)
    +      .setTrainRatio(0.5)
    +    val cvModel = cv.fit(dataset)
    +    val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression]
    +    assert(cv.getTrainRatio === 0.5)
    +    assert(parent.getRegParam === 0.001)
    +    assert(parent.getMaxIter === 10)
    +    assert(cvModel.validationMetrics.length === lrParamMaps.length)
    +  }
    +
    +  test("train validation with linear regression") {
    +    val dataset = sqlContext.createDataFrame(
    +        sc.parallelize(LinearDataGenerator.generateLinearInput(
    +            6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2))
    +
    +    val trainer = new LinearRegression
    +    val lrParamMaps = new ParamGridBuilder()
    +      .addGrid(trainer.regParam, Array(1000.0, 0.001))
    +      .addGrid(trainer.maxIter, Array(0, 10))
    +      .build()
    +    val eval = new RegressionEvaluator()
    +    val cv = new TrainValidationSplit()
    +      .setEstimator(trainer)
    +      .setEstimatorParamMaps(lrParamMaps)
    +      .setEvaluator(eval)
    +      .setTrainRatio(0.5)
    +    val cvModel = cv.fit(dataset)
    +    val parent = cvModel.bestModel.parent.asInstanceOf[LinearRegression]
    +    assert(parent.getRegParam === 0.001)
    +    assert(parent.getMaxIter === 10)
    +    assert(cvModel.validationMetrics.length === lrParamMaps.length)
    +
+    eval.setMetricName("r2")
    +    val cvModel2 = cv.fit(dataset)
    +    val parent2 = cvModel2.bestModel.parent.asInstanceOf[LinearRegression]
    +    assert(parent2.getRegParam === 0.001)
    +    assert(parent2.getMaxIter === 10)
    +    assert(cvModel2.validationMetrics.length === lrParamMaps.length)
    +  }
    +
    +  test("validateParams should check estimatorParamMaps") {
    +    import TrainValidationSplitSuite._
    +
    +    val est = new MyEstimator("est")
    +    val eval = new MyEvaluator
    +    val paramMaps = new ParamGridBuilder()
    +      .addGrid(est.inputCol, Array("input1", "input2"))
    +      .build()
    +
    +    val cv = new TrainValidationSplit()
    +      .setEstimator(est)
    +      .setEstimatorParamMaps(paramMaps)
    +      .setEvaluator(eval)
    +      .setTrainRatio(0.5)
    +    cv.validateParams() // This should pass.
    +
    +    val invalidParamMaps = paramMaps :+ ParamMap(est.inputCol -> "")
    +    cv.setEstimatorParamMaps(invalidParamMaps)
    +    intercept[IllegalArgumentException] {
    +      cv.validateParams()
    +    }
    +  }
    +}
    +
    +object TrainValidationSplitSuite {
    +
    +  abstract class MyModel extends Model[MyModel]
    +
    +  class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol {
    +
    +    override def validateParams(): Unit = require($(inputCol).nonEmpty)
    +
    +    override def fit(dataset: DataFrame): MyModel = {
    +      throw new UnsupportedOperationException
    +    }
    +
    +    override def transformSchema(schema: StructType): StructType = {
    +      throw new UnsupportedOperationException
    +    }
    +
    +    override def copy(extra: ParamMap): MyEstimator = defaultCopy(extra)
    +  }
    +
    +  class MyEvaluator extends Evaluator {
    +
    +    override def evaluate(dataset: DataFrame): Double = {
    +      throw new UnsupportedOperationException
    +    }
    +
    +    override val uid: String = "eval"
    +
    +    override def copy(extra: ParamMap): MyEvaluator = defaultCopy(extra)
    +  }
    +}
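
// A minimal sketch of the TrainValidationSplit usage exercised above: a single random split (here
// 75% train / 25% validation) is used to score every ParamMap, validationMetrics holds one value
// per ParamMap, and bestModel exposes the winning model (illustrative only; assumes a labeled
// DataFrame named `data` with "features" and "label" columns):
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}

val lr = new LinearRegression()
val grid = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.01, 0.1))
  .build()
val tvs = new TrainValidationSplit()
  .setEstimator(lr)
  .setEstimatorParamMaps(grid)
  .setEvaluator(new RegressionEvaluator())
  .setTrainRatio(0.75)
val tvsModel = tvs.fit(data)
val metricPerParamMap = grid.zip(tvsModel.validationMetrics)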
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala
    new file mode 100644
    index 0000000000000..9e6bc7193c13b
    --- /dev/null
    +++ b/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala
    @@ -0,0 +1,125 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.util
    +
    +import java.util.Random
    +
    +import org.apache.spark.SparkFunSuite
    +import org.apache.spark.mllib.util.MLlibTestSparkContext
    +
    +class StopwatchSuite extends SparkFunSuite with MLlibTestSparkContext {
    +
    +  import StopwatchSuite._
    +
    +  private def testStopwatchOnDriver(sw: Stopwatch): Unit = {
    +    assert(sw.name === "sw")
    +    assert(sw.elapsed() === 0L)
    +    assert(!sw.isRunning)
    +    intercept[AssertionError] {
    +      sw.stop()
    +    }
    +    val duration = checkStopwatch(sw)
    +    val elapsed = sw.elapsed()
    +    assert(elapsed === duration)
    +    val duration2 = checkStopwatch(sw)
    +    val elapsed2 = sw.elapsed()
    +    assert(elapsed2 === duration + duration2)
    +    assert(sw.toString === s"sw: ${elapsed2}ms")
    +    sw.start()
    +    assert(sw.isRunning)
    +    intercept[AssertionError] {
    +      sw.start()
    +    }
    +  }
    +
    +  test("LocalStopwatch") {
    +    val sw = new LocalStopwatch("sw")
    +    testStopwatchOnDriver(sw)
    +  }
    +
    +  test("DistributedStopwatch on driver") {
    +    val sw = new DistributedStopwatch(sc, "sw")
    +    testStopwatchOnDriver(sw)
    +  }
    +
    +  test("DistributedStopwatch on executors") {
    +    val sw = new DistributedStopwatch(sc, "sw")
    +    val rdd = sc.parallelize(0 until 4, 4)
    +    val acc = sc.accumulator(0L)
    +    rdd.foreach { i =>
    +      acc += checkStopwatch(sw)
    +    }
    +    assert(!sw.isRunning)
    +    val elapsed = sw.elapsed()
    +    assert(elapsed === acc.value)
    +  }
    +
    +  test("MultiStopwatch") {
    +    val sw = new MultiStopwatch(sc)
    +      .addLocal("local")
    +      .addDistributed("spark")
    +    assert(sw("local").name === "local")
    +    assert(sw("spark").name === "spark")
    +    intercept[NoSuchElementException] {
    +      sw("some")
    +    }
    +    assert(sw.toString === "{\n  local: 0ms,\n  spark: 0ms\n}")
    +    val localDuration = checkStopwatch(sw("local"))
    +    val sparkDuration = checkStopwatch(sw("spark"))
    +    val localElapsed = sw("local").elapsed()
    +    val sparkElapsed = sw("spark").elapsed()
    +    assert(localElapsed === localDuration)
    +    assert(sparkElapsed === sparkDuration)
    +    assert(sw.toString ===
    +      s"{\n  local: ${localElapsed}ms,\n  spark: ${sparkElapsed}ms\n}")
    +    val rdd = sc.parallelize(0 until 4, 4)
    +    val acc = sc.accumulator(0L)
    +    rdd.foreach { i =>
    +      sw("local").start()
    +      val duration = checkStopwatch(sw("spark"))
    +      sw("local").stop()
    +      acc += duration
    +    }
    +    val localElapsed2 = sw("local").elapsed()
    +    assert(localElapsed2 === localElapsed)
    +    val sparkElapsed2 = sw("spark").elapsed()
    +    assert(sparkElapsed2 === sparkElapsed + acc.value)
    +  }
    +}
    +
    +private object StopwatchSuite extends SparkFunSuite {
    +
    +  /**
    +   * Checks the input stopwatch on a task that takes a random time (<10ms) to finish. Validates and
    +   * returns the duration reported by the stopwatch.
    +   */
    +  def checkStopwatch(sw: Stopwatch): Long = {
    +    val ubStart = now
    +    sw.start()
    +    val lbStart = now
    +    Thread.sleep(new Random().nextInt(10))
    +    val lb = now - lbStart
    +    val duration = sw.stop()
    +    val ub = now - ubStart
    +    assert(duration >= lb && duration <= ub)
    +    duration
    +  }
    +
    +  /** The current time in milliseconds. */
    +  private def now: Long = System.currentTimeMillis()
    +}
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
    index f7fc8730606af..cffa1ab700f80 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
    @@ -19,13 +19,14 @@ package org.apache.spark.mllib.classification
     
     import scala.util.Random
     
    -import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}
    +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, Vector => BV}
     import breeze.stats.distributions.{Multinomial => BrzMultinomial}
     
     import org.apache.spark.{SparkException, SparkFunSuite}
    -import org.apache.spark.mllib.linalg.Vectors
    +import org.apache.spark.mllib.linalg.{Vector, Vectors}
     import org.apache.spark.mllib.regression.LabeledPoint
     import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
    +import org.apache.spark.mllib.util.TestingUtils._
     import org.apache.spark.util.Utils
     
     object NaiveBayesSuite {
    @@ -154,6 +155,29 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
     
         // Test prediction on Array.
         validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
    +
    +    // Test posteriors
    +    validationData.map(_.features).foreach { features =>
    +      val predicted = model.predictProbabilities(features).toArray
    +      assert(predicted.sum ~== 1.0 relTol 1.0e-10)
    +      val expected = expectedMultinomialProbabilities(model, features)
    +      expected.zip(predicted).foreach { case (e, p) => assert(e ~== p relTol 1.0e-10) }
    +    }
    +  }
    +
    +  /**
    +   * @param model Multinomial Naive Bayes model
    +   * @param testData input to compute posterior probabilities for
    +   * @return posterior class probabilities (in order of labels) for input
    +   */
    +  private def expectedMultinomialProbabilities(model: NaiveBayesModel, testData: Vector) = {
    +    val piVector = new BDV(model.pi)
    +    // model.theta is row-major; treat it as col-major representation of transpose, and transpose:
    +    val thetaMatrix = new BDM(model.theta(0).length, model.theta.length, model.theta.flatten).t
    +    val logClassProbs: BV[Double] = piVector + (thetaMatrix * testData.toBreeze)
    +    val classProbs = logClassProbs.toArray.map(math.exp)
    +    val classProbsSum = classProbs.sum
    +    classProbs.map(_ / classProbsSum)
       }
     
       test("Naive Bayes Bernoulli") {
    @@ -182,6 +206,33 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
     
         // Test prediction on Array.
         validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
    +
    +    // Test posteriors
    +    validationData.map(_.features).foreach { features =>
    +      val predicted = model.predictProbabilities(features).toArray
    +      assert(predicted.sum ~== 1.0 relTol 1.0e-10)
    +      val expected = expectedBernoulliProbabilities(model, features)
    +      expected.zip(predicted).foreach { case (e, p) => assert(e ~== p relTol 1.0e-10) }
    +    }
    +  }
    +
    +  /**
    +   * @param model Bernoulli Naive Bayes model
    +   * @param testData input to compute posterior probabilities for
    +   * @return posterior class probabilities (in order of labels) for input
    +   */
    +  private def expectedBernoulliProbabilities(model: NaiveBayesModel, testData: Vector) = {
    +    val piVector = new BDV(model.pi)
    +    val thetaMatrix = new BDM(model.theta(0).length, model.theta.length, model.theta.flatten).t
    +    val negThetaMatrix = new BDM(model.theta(0).length, model.theta.length,
    +      model.theta.flatten.map(v => math.log(1.0 - math.exp(v)))).t
    +    val testBreeze = testData.toBreeze
    +    val negTestBreeze = new BDV(Array.fill(testBreeze.size)(1.0)) - testBreeze
    +    val piTheta: BV[Double] = piVector + (thetaMatrix * testBreeze)
    +    val logClassProbs: BV[Double] = piTheta + (negThetaMatrix * negTestBreeze)
    +    val classProbs = logClassProbs.toArray.map(math.exp)
    +    val classProbsSum = classProbs.sum
    +    classProbs.map(_ / classProbsSum)
       }
     
       test("detect negative values") {
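
// The posterior checks above follow the standard multinomial Naive Bayes rule: up to an additive
// constant, log P(c | x) = log pi(c) + theta(c) . x, which is exponentiated and normalized over
// the classes before being compared with model.predictProbabilities. A library-free sketch of the
// same arithmetic (pi holds log class priors, theta holds per-class log feature likelihoods; the
// helper name is illustrative):
def multinomialPosteriors(
    pi: Array[Double],
    theta: Array[Array[Double]],
    x: Array[Double]): Array[Double] = {
  val logUnnormalized = pi.indices.map { c =>
    pi(c) + theta(c).zip(x).map { case (logTheta, count) => logTheta * count }.sum
  }
  val unnormalized = logUnnormalized.map(math.exp)
  val total = unnormalized.sum
  unnormalized.map(_ / total).toArray
}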
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
    index 0dbbd7127444f..3003c62d9876c 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
    @@ -278,6 +278,28 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
           }
         }
       }
    +
    +  test("Initialize using given cluster centers") {
    +    val points = Seq(
    +      Vectors.dense(0.0, 0.0),
    +      Vectors.dense(1.0, 0.0),
    +      Vectors.dense(0.0, 1.0),
    +      Vectors.dense(1.0, 1.0)
    +    )
    +    val rdd = sc.parallelize(points, 3)
    +    // creating an initial model
    +    val initialModel = new KMeansModel(Array(points(0), points(2)))
    +
    +    val returnModel = new KMeans()
    +      .setK(2)
    +      .setMaxIterations(0)
    +      .setInitialModel(initialModel)
    +      .run(rdd)
+    // comparing the returned model and the initial model
    +    assert(returnModel.clusterCenters(0) === initialModel.clusterCenters(0))
    +    assert(returnModel.clusterCenters(1) === initialModel.clusterCenters(1))
    +  }
    +
     }
     
     object KMeansSuite extends SparkFunSuite {
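
// A minimal sketch of the initial-model support exercised by the new test above: run() starts from
// the supplied centers instead of random or k-means|| initialization, and with setMaxIterations(0)
// the returned centers are exactly the supplied ones (illustrative; assumes an RDD `points` of
// mllib Vectors):
import org.apache.spark.mllib.clustering.{KMeans, KMeansModel}
import org.apache.spark.mllib.linalg.Vectors

val seedCenters = Array(Vectors.dense(0.0, 0.0), Vectors.dense(1.0, 1.0))
val model = new KMeans()
  .setK(2)
  .setMaxIterations(10)
  .setInitialModel(new KMeansModel(seedCenters))
  .run(points)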
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
    index 03a8a2538b464..da70d9bd7c790 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
    @@ -20,9 +20,10 @@ package org.apache.spark.mllib.clustering
     import breeze.linalg.{DenseMatrix => BDM}
     
     import org.apache.spark.SparkFunSuite
    -import org.apache.spark.mllib.linalg.{Vector, DenseMatrix, Matrix, Vectors}
    +import org.apache.spark.mllib.linalg.{DenseMatrix, Matrix, Vector, Vectors}
     import org.apache.spark.mllib.util.MLlibTestSparkContext
     import org.apache.spark.mllib.util.TestingUtils._
    +import org.apache.spark.util.Utils
     
     class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
     
    @@ -131,22 +132,38 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
     
       test("setter alias") {
         val lda = new LDA().setAlpha(2.0).setBeta(3.0)
    -    assert(lda.getAlpha === 2.0)
    -    assert(lda.getDocConcentration === 2.0)
    +    assert(lda.getAlpha.toArray.forall(_ === 2.0))
    +    assert(lda.getDocConcentration.toArray.forall(_ === 2.0))
         assert(lda.getBeta === 3.0)
         assert(lda.getTopicConcentration === 3.0)
       }
     
    +  test("initializing with alpha length != k or 1 fails") {
    +    intercept[IllegalArgumentException] {
    +      val lda = new LDA().setK(2).setAlpha(Vectors.dense(1, 2, 3, 4))
    +      val corpus = sc.parallelize(tinyCorpus, 2)
    +      lda.run(corpus)
    +    }
    +  }
    +
    +  test("initializing with elements in alpha < 0 fails") {
    +    intercept[IllegalArgumentException] {
    +      val lda = new LDA().setK(4).setAlpha(Vectors.dense(-1, 2, 3, 4))
    +      val corpus = sc.parallelize(tinyCorpus, 2)
    +      lda.run(corpus)
    +    }
    +  }
    +
       test("OnlineLDAOptimizer initialization") {
         val lda = new LDA().setK(2)
         val corpus = sc.parallelize(tinyCorpus, 2)
         val op = new OnlineLDAOptimizer().initialize(corpus, lda)
         op.setKappa(0.9876).setMiniBatchFraction(0.123).setTau0(567)
    -    assert(op.getAlpha == 0.5) // default 1.0 / k
    -    assert(op.getEta == 0.5)   // default 1.0 / k
    -    assert(op.getKappa == 0.9876)
    -    assert(op.getMiniBatchFraction == 0.123)
    -    assert(op.getTau0 == 567)
    +    assert(op.getAlpha.toArray.forall(_ === 0.5)) // default 1.0 / k
    +    assert(op.getEta === 0.5)   // default 1.0 / k
    +    assert(op.getKappa === 0.9876)
    +    assert(op.getMiniBatchFraction === 0.123)
    +    assert(op.getTau0 === 567)
       }
     
       test("OnlineLDAOptimizer one iteration") {
    @@ -217,6 +234,96 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
         }
       }
     
    +  test("OnlineLDAOptimizer with asymmetric prior") {
    +    def toydata: Array[(Long, Vector)] = Array(
    +      Vectors.sparse(6, Array(0, 1), Array(1, 1)),
    +      Vectors.sparse(6, Array(1, 2), Array(1, 1)),
    +      Vectors.sparse(6, Array(0, 2), Array(1, 1)),
    +      Vectors.sparse(6, Array(3, 4), Array(1, 1)),
    +      Vectors.sparse(6, Array(3, 5), Array(1, 1)),
    +      Vectors.sparse(6, Array(4, 5), Array(1, 1))
    +    ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) }
    +
    +    val docs = sc.parallelize(toydata)
    +    val op = new OnlineLDAOptimizer().setMiniBatchFraction(1).setTau0(1024).setKappa(0.51)
    +      .setGammaShape(1e10)
    +    val lda = new LDA().setK(2)
    +      .setDocConcentration(Vectors.dense(0.00001, 0.1))
    +      .setTopicConcentration(0.01)
    +      .setMaxIterations(100)
    +      .setOptimizer(op)
    +      .setSeed(12345)
    +
    +    val ldaModel = lda.run(docs)
    +    val topicIndices = ldaModel.describeTopics(maxTermsPerTopic = 10)
    +    val topics = topicIndices.map { case (terms, termWeights) =>
    +      terms.zip(termWeights)
    +    }
    +
    +    /* Verify results with Python:
    +
    +       import numpy as np
    +       from gensim import models
    +       corpus = [
    +           [(0, 1.0), (1, 1.0)],
    +           [(1, 1.0), (2, 1.0)],
    +           [(0, 1.0), (2, 1.0)],
    +           [(3, 1.0), (4, 1.0)],
    +           [(3, 1.0), (5, 1.0)],
    +           [(4, 1.0), (5, 1.0)]]
    +       np.random.seed(10)
    +       lda = models.ldamodel.LdaModel(
    +           corpus=corpus, alpha=np.array([0.00001, 0.1]), num_topics=2, update_every=0, passes=100)
    +       lda.print_topics()
    +
    +       > ['0.167*0 + 0.167*1 + 0.167*2 + 0.167*3 + 0.167*4 + 0.167*5',
    +          '0.167*0 + 0.167*1 + 0.167*2 + 0.167*4 + 0.167*3 + 0.167*5']
    +     */
    +    topics.foreach { topic =>
    +      assert(topic.forall { case (_, p) => p ~= 0.167 absTol 0.05 })
    +    }
    +  }
    +
    +  test("model save/load") {
    +    // Test for LocalLDAModel.
    +    val localModel = new LocalLDAModel(tinyTopics)
    +    val tempDir1 = Utils.createTempDir()
    +    val path1 = tempDir1.toURI.toString
    +
    +    // Test for DistributedLDAModel.
    +    val k = 3
    +    val docConcentration = 1.2
    +    val topicConcentration = 1.5
    +    val lda = new LDA()
    +    lda.setK(k)
    +      .setDocConcentration(docConcentration)
    +      .setTopicConcentration(topicConcentration)
    +      .setMaxIterations(5)
    +      .setSeed(12345)
    +    val corpus = sc.parallelize(tinyCorpus, 2)
    +    val distributedModel: DistributedLDAModel = lda.run(corpus).asInstanceOf[DistributedLDAModel]
    +    val tempDir2 = Utils.createTempDir()
    +    val path2 = tempDir2.toURI.toString
    +
    +    try {
    +      localModel.save(sc, path1)
    +      distributedModel.save(sc, path2)
    +      val samelocalModel = LocalLDAModel.load(sc, path1)
    +      assert(samelocalModel.topicsMatrix === localModel.topicsMatrix)
    +      assert(samelocalModel.k === localModel.k)
    +      assert(samelocalModel.vocabSize === localModel.vocabSize)
    +
    +      val sameDistributedModel = DistributedLDAModel.load(sc, path2)
    +      assert(distributedModel.topicsMatrix === sameDistributedModel.topicsMatrix)
    +      assert(distributedModel.k === sameDistributedModel.k)
    +      assert(distributedModel.vocabSize === sameDistributedModel.vocabSize)
    +      assert(distributedModel.iterationTimes === sameDistributedModel.iterationTimes)
    +    } finally {
    +      Utils.deleteRecursively(tempDir1)
    +      Utils.deleteRecursively(tempDir2)
    +    }
    +  }
    +
     }
     
     private[clustering] object LDASuite {
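
// A minimal sketch of the asymmetric document-topic prior covered by the new tests: with the
// online optimizer, docConcentration (alpha) may be a Vector with one entry per topic rather than
// a single symmetric value, and its length must be k (or 1). Illustrative only; assumes an RDD
// `corpus` of (docId, term-count Vector) pairs:
import org.apache.spark.mllib.clustering.{LDA, OnlineLDAOptimizer}
import org.apache.spark.mllib.linalg.Vectors

val lda = new LDA()
  .setK(2)
  .setOptimizer(new OnlineLDAOptimizer())
  .setDocConcentration(Vectors.dense(0.1, 0.9)) // per-topic alpha of length k
  .setTopicConcentration(0.01)
  .setMaxIterations(10)
val ldaModel = lda.run(corpus)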
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala
    index 19e65f1b53ab5..189000512155f 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala
    @@ -68,6 +68,54 @@ class PowerIterationClusteringSuite extends SparkFunSuite with MLlibTestSparkCon
         assert(predictions2.toSet == Set((0 to 3).toSet, (4 to 15).toSet))
       }
     
    +  test("power iteration clustering on graph") {
    +    /*
    +     We use the following graph to test PIC. All edges are assigned similarity 1.0 except 0.1 for
    +     edge (3, 4).
    +
    +     15-14 -13 -12
    +     |           |
    +     4 . 3 - 2  11
    +     |   | x |   |
    +     5   0 - 1  10
    +     |           |
    +     6 - 7 - 8 - 9
    +     */
    +
    +    val similarities = Seq[(Long, Long, Double)]((0, 1, 1.0), (0, 2, 1.0), (0, 3, 1.0), (1, 2, 1.0),
    +      (1, 3, 1.0), (2, 3, 1.0), (3, 4, 0.1), // (3, 4) is a weak edge
    +      (4, 5, 1.0), (4, 15, 1.0), (5, 6, 1.0), (6, 7, 1.0), (7, 8, 1.0), (8, 9, 1.0), (9, 10, 1.0),
    +      (10, 11, 1.0), (11, 12, 1.0), (12, 13, 1.0), (13, 14, 1.0), (14, 15, 1.0))
    +
    +    val edges = similarities.flatMap { case (i, j, s) =>
    +      if (i != j) {
    +        Seq(Edge(i, j, s), Edge(j, i, s))
    +      } else {
    +        None
    +      }
    +    }
    +    val graph = Graph.fromEdges(sc.parallelize(edges, 2), 0.0)
    +
    +    val model = new PowerIterationClustering()
    +      .setK(2)
    +      .run(graph)
    +    val predictions = Array.fill(2)(mutable.Set.empty[Long])
    +    model.assignments.collect().foreach { a =>
    +      predictions(a.cluster) += a.id
    +    }
    +    assert(predictions.toSet == Set((0 to 3).toSet, (4 to 15).toSet))
    +
    +    val model2 = new PowerIterationClustering()
    +      .setK(2)
    +      .setInitializationMode("degree")
    +      .run(sc.parallelize(similarities, 2))
    +    val predictions2 = Array.fill(2)(mutable.Set.empty[Long])
    +    model2.assignments.collect().foreach { a =>
    +      predictions2(a.cluster) += a.id
    +    }
    +    assert(predictions2.toSet == Set((0 to 3).toSet, (4 to 15).toSet))
    +  }
    +
       test("normalize and powerIter") {
         /*
          Test normalize() with the following graph:
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala
    index 9de2bdb6d7246..4b7f1be58f99b 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala
    @@ -23,24 +23,85 @@ import org.apache.spark.mllib.util.TestingUtils._
     
     class RegressionMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
     
    -  test("regression metrics") {
    +  test("regression metrics for unbiased (includes intercept term) predictor") {
    +    /* Verify results in R:
    +       preds = c(2.25, -0.25, 1.75, 7.75)
    +       obs = c(3.0, -0.5, 2.0, 7.0)
    +
    +       SStot = sum((obs - mean(obs))^2)
    +       SSreg = sum((preds - mean(obs))^2)
    +       SSerr = sum((obs - preds)^2)
    +
    +       explainedVariance = SSreg / length(obs)
    +       explainedVariance
    +       > [1] 8.796875
    +       meanAbsoluteError = mean(abs(preds - obs))
    +       meanAbsoluteError
    +       > [1] 0.5
    +       meanSquaredError = mean((preds - obs)^2)
    +       meanSquaredError
    +       > [1] 0.3125
    +       rmse = sqrt(meanSquaredError)
    +       rmse
    +       > [1] 0.559017
    +       r2 = 1 - SSerr / SStot
    +       r2
    +       > [1] 0.9571734
    +     */
    +    val predictionAndObservations = sc.parallelize(
    +      Seq((2.25, 3.0), (-0.25, -0.5), (1.75, 2.0), (7.75, 7.0)), 2)
    +    val metrics = new RegressionMetrics(predictionAndObservations)
    +    assert(metrics.explainedVariance ~== 8.79687 absTol 1E-5,
    +      "explained variance regression score mismatch")
    +    assert(metrics.meanAbsoluteError ~== 0.5 absTol 1E-5, "mean absolute error mismatch")
    +    assert(metrics.meanSquaredError ~== 0.3125 absTol 1E-5, "mean squared error mismatch")
    +    assert(metrics.rootMeanSquaredError ~== 0.55901 absTol 1E-5,
    +      "root mean squared error mismatch")
    +    assert(metrics.r2 ~== 0.95717 absTol 1E-5, "r2 score mismatch")
    +  }
    +
    +  test("regression metrics for biased (no intercept term) predictor") {
    +    /* Verify results in R:
    +       preds = c(2.5, 0.0, 2.0, 8.0)
    +       obs = c(3.0, -0.5, 2.0, 7.0)
    +
    +       SStot = sum((obs - mean(obs))^2)
    +       SSreg = sum((preds - mean(obs))^2)
    +       SSerr = sum((obs - preds)^2)
    +
    +       explainedVariance = SSreg / length(obs)
    +       explainedVariance
    +       > [1] 8.859375
    +       meanAbsoluteError = mean(abs(preds - obs))
    +       meanAbsoluteError
    +       > [1] 0.5
    +       meanSquaredError = mean((preds - obs)^2)
    +       meanSquaredError
    +       > [1] 0.375
    +       rmse = sqrt(meanSquaredError)
    +       rmse
    +       > [1] 0.6123724
    +       r2 = 1 - SSerr / SStot
    +       r2
    +       > [1] 0.9486081
    +     */
         val predictionAndObservations = sc.parallelize(
           Seq((2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)), 2)
         val metrics = new RegressionMetrics(predictionAndObservations)
    -    assert(metrics.explainedVariance ~== 0.95717 absTol 1E-5,
    +    assert(metrics.explainedVariance ~== 8.85937 absTol 1E-5,
           "explained variance regression score mismatch")
         assert(metrics.meanAbsoluteError ~== 0.5 absTol 1E-5, "mean absolute error mismatch")
         assert(metrics.meanSquaredError ~== 0.375 absTol 1E-5, "mean squared error mismatch")
         assert(metrics.rootMeanSquaredError ~== 0.61237 absTol 1E-5,
           "root mean squared error mismatch")
    -    assert(metrics.r2 ~== 0.94861 absTol 1E-5, "r2 score mismatch")
    +    assert(metrics.r2 ~== 0.94860 absTol 1E-5, "r2 score mismatch")
       }
     
       test("regression metrics with complete fitting") {
         val predictionAndObservations = sc.parallelize(
           Seq((3.0, 3.0), (0.0, 0.0), (2.0, 2.0), (8.0, 8.0)), 2)
         val metrics = new RegressionMetrics(predictionAndObservations)
    -    assert(metrics.explainedVariance ~== 1.0 absTol 1E-5,
    +    assert(metrics.explainedVariance ~== 8.6875 absTol 1E-5,
           "explained variance regression score mismatch")
         assert(metrics.meanAbsoluteError ~== 0.0 absTol 1E-5, "mean absolute error mismatch")
         assert(metrics.meanSquaredError ~== 0.0 absTol 1E-5, "mean squared error mismatch")
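
// The R snippets above pin down the definitions these tests rely on; the same arithmetic for the
// unbiased-predictor case, written out directly as a worked check (values from that test):
val preds = Array(2.25, -0.25, 1.75, 7.75)
val obs = Array(3.0, -0.5, 2.0, 7.0)
val meanObs = obs.sum / obs.length                                      // 2.875
val ssErr = preds.zip(obs).map { case (p, o) => (o - p) * (o - p) }.sum // 1.25
val ssReg = preds.map(p => (p - meanObs) * (p - meanObs)).sum           // 35.1875
val ssTot = obs.map(o => (o - meanObs) * (o - meanObs)).sum             // 29.1875
val explainedVariance = ssReg / obs.length                              // 8.796875
val mse = ssErr / obs.length                                            // 0.3125
val rmse = math.sqrt(mse)                                               // 0.559017...
val r2 = 1.0 - ssErr / ssTot                                            // 0.9571734...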
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
    index 413436d3db85f..9f107c89f6d80 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
    @@ -18,9 +18,8 @@ package org.apache.spark.mllib.fpm
     
     import org.apache.spark.SparkFunSuite
     import org.apache.spark.mllib.util.MLlibTestSparkContext
    -import org.apache.spark.rdd.RDD
     
    -class PrefixspanSuite extends SparkFunSuite with MLlibTestSparkContext {
    +class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
     
       test("PrefixSpan using Integer type") {
     
    @@ -48,15 +47,8 @@ class PrefixspanSuite extends SparkFunSuite with MLlibTestSparkContext {
         def compareResult(
             expectedValue: Array[(Array[Int], Long)],
             actualValue: Array[(Array[Int], Long)]): Boolean = {
    -      val sortedExpectedValue = expectedValue.sortWith{ (x, y) =>
    -        x._1.mkString(",") + ":" + x._2 < y._1.mkString(",") + ":" + y._2
    -      }
    -      val sortedActualValue = actualValue.sortWith{ (x, y) =>
    -        x._1.mkString(",") + ":" + x._2 < y._1.mkString(",") + ":" + y._2
    -      }
    -      sortedExpectedValue.zip(sortedActualValue)
    -        .map(x => x._1._1.mkString(",") == x._2._1.mkString(",") && x._1._2 == x._2._2)
    -        .reduce(_&&_)
    +      expectedValue.map(x => (x._1.toSeq, x._2)).toSet ==
    +        actualValue.map(x => (x._1.toSeq, x._2)).toSet
         }
     
         val prefixspan = new PrefixSpan()
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
    index b0f3f71113c57..d119e0b50a393 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
    @@ -200,8 +200,14 @@ class BLASSuite extends SparkFunSuite {
         val C10 = C1.copy
         val C11 = C1.copy
         val C12 = C1.copy
    +    val C13 = C1.copy
    +    val C14 = C1.copy
    +    val C15 = C1.copy
    +    val C16 = C1.copy
         val expected2 = new DenseMatrix(4, 2, Array(2.0, 1.0, 4.0, 2.0, 4.0, 0.0, 4.0, 3.0))
         val expected3 = new DenseMatrix(4, 2, Array(2.0, 2.0, 4.0, 2.0, 8.0, 0.0, 6.0, 6.0))
    +    val expected4 = new DenseMatrix(4, 2, Array(5.0, 0.0, 10.0, 5.0, 0.0, 0.0, 5.0, 0.0))
    +    val expected5 = C1.copy
     
         gemm(1.0, dA, B, 2.0, C1)
         gemm(1.0, sA, B, 2.0, C2)
    @@ -248,6 +254,16 @@ class BLASSuite extends SparkFunSuite {
         assert(C10 ~== expected2 absTol 1e-15)
         assert(C11 ~== expected3 absTol 1e-15)
         assert(C12 ~== expected3 absTol 1e-15)
    +
    +    gemm(0, dA, B, 5, C13)
    +    gemm(0, sA, B, 5, C14)
    +    gemm(0, dA, B, 1, C15)
    +    gemm(0, sA, B, 1, C16)
    +    assert(C13 ~== expected4 absTol 1e-15)
    +    assert(C14 ~== expected4 absTol 1e-15)
    +    assert(C15 ~== expected5 absTol 1e-15)
    +    assert(C16 ~== expected5 absTol 1e-15)
    +
       }
     
       test("gemv") {
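
// The new gemm cases above exercise the alpha == 0 path of the BLAS update C := alpha * A * B +
// beta * C: with alpha = 0 the product term vanishes, so gemm(0, dA, B, 5, C13) should leave C13
// equal to 5 * C1 (expected4), and gemm(0, dA, B, 1, C15) should leave C15 unchanged (expected5
// is C1.copy).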
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
    index 178d95a7b94ec..03be4119bdaca 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
    @@ -62,11 +62,50 @@ class VectorsSuite extends SparkFunSuite with Logging {
         assert(vec.toArray.eq(arr))
       }
     
    +  test("dense argmax") {
    +    val vec = Vectors.dense(Array.empty[Double]).asInstanceOf[DenseVector]
    +    assert(vec.argmax === -1)
    +
    +    val vec2 = Vectors.dense(arr).asInstanceOf[DenseVector]
    +    assert(vec2.argmax === 3)
    +
    +    val vec3 = Vectors.dense(Array(-1.0, 0.0, -2.0, 1.0)).asInstanceOf[DenseVector]
    +    assert(vec3.argmax === 3)
    +  }
    +
       test("sparse to array") {
         val vec = Vectors.sparse(n, indices, values).asInstanceOf[SparseVector]
         assert(vec.toArray === arr)
       }
     
    +  test("sparse argmax") {
    +    val vec = Vectors.sparse(0, Array.empty[Int], Array.empty[Double]).asInstanceOf[SparseVector]
    +    assert(vec.argmax === -1)
    +
    +    val vec2 = Vectors.sparse(n, indices, values).asInstanceOf[SparseVector]
    +    assert(vec2.argmax === 3)
    +
    +    val vec3 = Vectors.sparse(5, Array(2, 3, 4), Array(1.0, 0.0, -.7))
    +    assert(vec3.argmax === 2)
    +
+    // check the case where the sparse vector is created with
+    // only negative explicit values: {0.0, 0.0, -1.0, -0.7, 0.0}
    +    val vec4 = Vectors.sparse(5, Array(2, 3), Array(-1.0, -.7))
    +    assert(vec4.argmax === 0)
    +
    +    val vec5 = Vectors.sparse(11, Array(0, 3, 10), Array(-1.0, -.7, 0.0))
    +    assert(vec5.argmax === 1)
    +
    +    val vec6 = Vectors.sparse(11, Array(0, 1, 2), Array(-1.0, -.7, 0.0))
    +    assert(vec6.argmax === 2)
    +
    +    val vec7 = Vectors.sparse(5, Array(0, 1, 3), Array(-1.0, 0.0, -.7))
    +    assert(vec7.argmax === 1)
    +
    +    val vec8 = Vectors.sparse(5, Array(1, 2), Array(0.0, -1.0))
    +    assert(vec8.argmax === 0)
    +  }
    +
       test("vector equals") {
         val dv1 = Vectors.dense(arr.clone())
         val dv2 = Vectors.dense(arr.clone())
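
// The sparse argmax cases above hinge on implicit zeros: for a SparseVector, indices that are not
// stored still count as candidates with value 0.0, and the smallest index attaining the maximum is
// returned whether that zero is stored or implicit. A quick illustration with values from the new
// tests:
import org.apache.spark.mllib.linalg.Vectors

Vectors.sparse(5, Array(2, 3), Array(-1.0, -0.7)).argmax         // 0: indices 0, 1, 4 are implicit 0.0
Vectors.sparse(5, Array(0, 1, 3), Array(-1.0, 0.0, -0.7)).argmax // 1: first index whose value is 0.0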
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala
    index 8972c229b7ecb..334bf3790fc7a 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala
    @@ -70,7 +70,7 @@ object EnsembleTestHelper {
           metricName: String = "mse") {
         val predictions = input.map(x => model.predict(x.features))
         val errors = predictions.zip(input.map(_.label)).map { case (prediction, label) =>
    -      prediction - label
    +      label - prediction
         }
         val metric = metricName match {
           case "mse" =>
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala
    index 5e9101cdd3804..525ab68c7921a 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala
    @@ -26,7 +26,7 @@ trait LocalClusterSparkContext extends BeforeAndAfterAll { self: Suite =>
     
       override def beforeAll() {
         val conf = new SparkConf()
    -      .setMaster("local-cluster[2, 1, 512]")
    +      .setMaster("local-cluster[2, 1, 1024]")
           .setAppName("test-cluster")
           .set("spark.akka.frameSize", "1") // set to 1MB to detect direct serialization of data
         sc = new SparkContext(conf)
    diff --git a/pom.xml b/pom.xml
    index c2ebc1a11e770..1f44dc8abe1d4 100644
    --- a/pom.xml
    +++ b/pom.xml
    @@ -128,7 +128,7 @@
     <yarn.version>${hadoop.version}</yarn.version>
     <hbase.version>0.98.7-hadoop2</hbase.version>
     <hbase.artifact>hbase</hbase.artifact>
-    <flume.version>1.4.0</flume.version>
+    <flume.version>1.6.0</flume.version>
     <zookeeper.version>3.4.5</zookeeper.version>
     <curator.version>2.4.0</curator.version>
     <hive.group>org.spark-project.hive</hive.group>
    @@ -144,7 +144,7 @@
     <chill.version>0.5.0</chill.version>
     <ivy.version>2.4.0</ivy.version>
     <oro.version>2.0.8</oro.version>
-    <codahale.metrics.version>3.1.0</codahale.metrics.version>
+    <codahale.metrics.version>3.1.2</codahale.metrics.version>
     <avro.version>1.7.7</avro.version>
     <avro.mapred.classifier>hadoop2</avro.mapred.classifier>
     <jets3t.version>0.7.1</jets3t.version>
    @@ -152,7 +152,6 @@
     <aws.kinesis.client.version>1.2.1</aws.kinesis.client.version>
     <commons.httpclient.version>4.3.2</commons.httpclient.version>
     <commons.math3.version>3.4.1</commons.math3.version>
-    <test_classpath_file>${project.build.directory}/spark-test-classpath.txt</test_classpath_file>
     <scala.version>2.10.4</scala.version>
     <scala.binary.version>2.10</scala.binary.version>
     <jline.version>${scala.version}</jline.version>
    @@ -283,18 +282,6 @@
         <enabled>false</enabled>
       </snapshots>
     </repository>
-    <!-- TODO: This can be deleted after Spark 1.4 is posted -->
-    <repository>
-      <id>spark-1.4-staging</id>
-      <name>Spark 1.4 RC4 Staging Repository</name>
-      <url>https://repository.apache.org/content/repositories/orgapachespark-1112</url>
-      <releases>
-        <enabled>true</enabled>
-      </releases>
-      <snapshots>
-        <enabled>false</enabled>
-      </snapshots>
-    </repository>
       
       
         
    @@ -318,17 +305,6 @@
           unused
           1.0.0
         
    -    
    -    
    -      org.codehaus.groovy
    -      groovy-all
    -      2.3.7
    -      provided
    -    
         
           
             com.fasterxml.jackson.module
    -        jackson-module-scala_2.10
    +        jackson-module-scala_${scala.binary.version}
             ${fasterxml.jackson.version}
             
               
    @@ -748,6 +724,12 @@
         <artifactId>curator-framework</artifactId>
         <version>${curator.version}</version>
       </dependency>
+      <dependency>
+        <groupId>org.apache.curator</groupId>
+        <artifactId>curator-test</artifactId>
+        <version>${curator.version}</version>
+        <scope>test</scope>
+      </dependency>
       <dependency>
         <groupId>org.apache.hadoop</groupId>
         <artifactId>hadoop-client</artifactId>
    @@ -1406,6 +1388,58 @@
           <artifactId>maven-deploy-plugin</artifactId>
           <version>2.8.2</version>
         </plugin>
+        <!-- This plugin's configuration is used to store Eclipse m2e settings only. -->
+        <!-- It has no influence on the Maven build itself. -->
+        <plugin>
+          <groupId>org.eclipse.m2e</groupId>
+          <artifactId>lifecycle-mapping</artifactId>
+          <version>1.0.0</version>
+          <configuration>
+            <lifecycleMappingMetadata>
+              <pluginExecutions>
+                <pluginExecution>
+                  <pluginExecutionFilter>
+                    <groupId>org.apache.maven.plugins</groupId>
+                    <artifactId>maven-dependency-plugin</artifactId>
+                    <versionRange>[2.8,)</versionRange>
+                    <goals>
+                      <goal>build-classpath</goal>
+                    </goals>
+                  </pluginExecutionFilter>
+                  <action>
+                    <ignore></ignore>
+                  </action>
+                </pluginExecution>
+                <pluginExecution>
+                  <pluginExecutionFilter>
+                    <groupId>org.apache.maven.plugins</groupId>
+                    <artifactId>maven-jar-plugin</artifactId>
+                    <versionRange>[2.6,)</versionRange>
+                    <goals>
+                      <goal>test-jar</goal>
+                    </goals>
+                  </pluginExecutionFilter>
+                  <action>
+                    <ignore></ignore>
+                  </action>
+                </pluginExecution>
+                <pluginExecution>
+                  <pluginExecutionFilter>
+                    <groupId>org.apache.maven.plugins</groupId>
+                    <artifactId>maven-antrun-plugin</artifactId>
+                    <versionRange>[1.8,)</versionRange>
+                    <goals>
+                      <goal>run</goal>
+                    </goals>
+                  </pluginExecutionFilter>
+                  <action>
+                    <ignore></ignore>
+                  </action>
+                </pluginExecution>
+              </pluginExecutions>
+            </lifecycleMappingMetadata>
+          </configuration>
+        </plugin>
       </plugins>
     </pluginManagement>
     
@@ -1423,34 +1457,12 @@
             </goals>
             <configuration>
               <includeScope>test</includeScope>
-              <outputFile>${test_classpath_file}</outputFile>
+              <outputProperty>test_classpath</outputProperty>
             </configuration>
           </execution>
         </executions>
       </plugin>
     
    -      
    -      
    -        org.codehaus.gmavenplus
    -        gmavenplus-plugin
    -        1.5
    -        
    -          
    -            process-test-classes
    -            
    -              execute
    -            
    -            
    -              
    -                
    -              
    -            
    -          
    -        
    -      
           
       
       
    diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
    index 79d55b36dab01..2f7e84a7f59e2 100644
    --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
    +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
    @@ -19,11 +19,9 @@
     
     import java.util.Iterator;
     
    -import scala.Function1;
    -
     import org.apache.spark.sql.catalyst.InternalRow;
    -import org.apache.spark.sql.catalyst.util.ObjectPool;
    -import org.apache.spark.sql.catalyst.util.UniqueObjectPool;
    +import org.apache.spark.sql.types.StructField;
    +import org.apache.spark.sql.types.StructType;
     import org.apache.spark.unsafe.PlatformDependent;
     import org.apache.spark.unsafe.map.BytesToBytesMap;
     import org.apache.spark.unsafe.memory.MemoryLocation;
    @@ -40,48 +38,26 @@ public final class UnsafeFixedWidthAggregationMap {
        * An empty aggregation buffer, encoded in UnsafeRow format. When inserting a new key into the
        * map, we copy this buffer and use it as the value.
        */
    -  private final byte[] emptyBuffer;
    +  private final byte[] emptyAggregationBuffer;
     
    -  /**
    -   * An empty row used by `initProjection`
    -   */
    -  private static final InternalRow emptyRow = new GenericInternalRow();
    +  private final StructType aggregationBufferSchema;
     
    -  /**
    -   * Whether can the empty aggregation buffer be reuse without calling `initProjection` or not.
    -   */
    -  private final boolean reuseEmptyBuffer;
    +  private final StructType groupingKeySchema;
     
       /**
    -   * The projection used to initialize the emptyBuffer
    +   * Encodes grouping keys as UnsafeRows.
        */
    -  private final Function1 initProjection;
    -
    -  /**
    -   * Encodes grouping keys or buffers as UnsafeRows.
    -   */
    -  private final UnsafeRowConverter keyConverter;
    -  private final UnsafeRowConverter bufferConverter;
    +  private final UnsafeRowConverter groupingKeyToUnsafeRowConverter;
     
       /**
        * A hashmap which maps from opaque bytearray keys to bytearray values.
        */
       private final BytesToBytesMap map;
     
    -  /**
    -   * An object pool for objects that are used in grouping keys.
    -   */
    -  private final UniqueObjectPool keyPool;
    -
    -  /**
    -   * An object pool for objects that are used in aggregation buffers.
    -   */
    -  private final ObjectPool bufferPool;
    -
       /**
        * Re-used pointer to the current aggregation buffer
        */
    -  private final UnsafeRow currentBuffer = new UnsafeRow();
    +  private final UnsafeRow currentAggregationBuffer = new UnsafeRow();
     
       /**
        * Scratch space that is used when encoding grouping keys into UnsafeRow format.
    @@ -93,41 +69,69 @@ public final class UnsafeFixedWidthAggregationMap {
     
       private final boolean enablePerfMetrics;
     
    +  /**
    +   * @return true if UnsafeFixedWidthAggregationMap supports grouping keys with the given schema,
    +   *         false otherwise.
    +   */
    +  public static boolean supportsGroupKeySchema(StructType schema) {
    +    for (StructField field: schema.fields()) {
    +      if (!UnsafeRow.readableFieldTypes.contains(field.dataType())) {
    +        return false;
    +      }
    +    }
    +    return true;
    +  }
    +
    +  /**
    +   * @return true if UnsafeFixedWidthAggregationMap supports aggregation buffers with the given
    +   *         schema, false otherwise.
    +   */
    +  public static boolean supportsAggregationBufferSchema(StructType schema) {
    +    for (StructField field: schema.fields()) {
    +      if (!UnsafeRow.settableFieldTypes.contains(field.dataType())) {
    +        return false;
    +      }
    +    }
    +    return true;
    +  }
    +
       /**
        * Create a new UnsafeFixedWidthAggregationMap.
        *
    -   * @param initProjection the default value for new keys (a "zero" of the agg. function)
    -   * @param keyConverter the converter of the grouping key, used for row conversion.
    -   * @param bufferConverter the converter of the aggregation buffer, used for row conversion.
    +   * @param emptyAggregationBuffer the default value for new keys (a "zero" of the agg. function)
    +   * @param aggregationBufferSchema the schema of the aggregation buffer, used for row conversion.
    +   * @param groupingKeySchema the schema of the grouping key, used for row conversion.
        * @param memoryManager the memory manager used to allocate our Unsafe memory structures.
        * @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing).
        * @param enablePerfMetrics if true, performance metrics will be recorded (has minor perf impact)
        */
       public UnsafeFixedWidthAggregationMap(
    -      Function1 initProjection,
    -      UnsafeRowConverter keyConverter,
    -      UnsafeRowConverter bufferConverter,
    +      InternalRow emptyAggregationBuffer,
    +      StructType aggregationBufferSchema,
    +      StructType groupingKeySchema,
           TaskMemoryManager memoryManager,
           int initialCapacity,
           boolean enablePerfMetrics) {
    -    this.initProjection = initProjection;
    -    this.keyConverter = keyConverter;
    -    this.bufferConverter = bufferConverter;
    -    this.enablePerfMetrics = enablePerfMetrics;
    -
    +    this.emptyAggregationBuffer =
    +      convertToUnsafeRow(emptyAggregationBuffer, aggregationBufferSchema);
    +    this.aggregationBufferSchema = aggregationBufferSchema;
    +    this.groupingKeyToUnsafeRowConverter = new UnsafeRowConverter(groupingKeySchema);
    +    this.groupingKeySchema = groupingKeySchema;
         this.map = new BytesToBytesMap(memoryManager, initialCapacity, enablePerfMetrics);
    -    this.keyPool = new UniqueObjectPool(100);
    -    this.bufferPool = new ObjectPool(initialCapacity);
    +    this.enablePerfMetrics = enablePerfMetrics;
    +  }
     
    -    InternalRow initRow = initProjection.apply(emptyRow);
    -    int emptyBufferSize = bufferConverter.getSizeRequirement(initRow);
    -    this.emptyBuffer = new byte[emptyBufferSize];
    -    int writtenLength = bufferConverter.writeRow(
    -      initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, emptyBufferSize,
    -      bufferPool);
    -    assert (writtenLength == emptyBuffer.length): "Size requirement calculation was wrong!";
    -    // re-use the empty buffer only when there is no object saved in pool.
    -    reuseEmptyBuffer = bufferPool.size() == 0;
    +  /**
    +   * Convert a Java object row into an UnsafeRow, allocating it into a new byte array.
    +   */
    +  private static byte[] convertToUnsafeRow(InternalRow javaRow, StructType schema) {
    +    final UnsafeRowConverter converter = new UnsafeRowConverter(schema);
    +    final int size = converter.getSizeRequirement(javaRow);
    +    final byte[] unsafeRow = new byte[size];
    +    final int writtenLength =
    +      converter.writeRow(javaRow, unsafeRow, PlatformDependent.BYTE_ARRAY_OFFSET, size);
    +    assert (writtenLength == unsafeRow.length): "Size requirement calculation was wrong!";
    +    return unsafeRow;
       }
     
       /**
    @@ -135,17 +139,16 @@ public UnsafeFixedWidthAggregationMap(
        * return the same object.
        */
       public UnsafeRow getAggregationBuffer(InternalRow groupingKey) {
    -    final int groupingKeySize = keyConverter.getSizeRequirement(groupingKey);
    +    final int groupingKeySize = groupingKeyToUnsafeRowConverter.getSizeRequirement(groupingKey);
         // Make sure that the buffer is large enough to hold the key. If it's not, grow it:
         if (groupingKeySize > groupingKeyConversionScratchSpace.length) {
           groupingKeyConversionScratchSpace = new byte[groupingKeySize];
         }
    -    final int actualGroupingKeySize = keyConverter.writeRow(
    +    final int actualGroupingKeySize = groupingKeyToUnsafeRowConverter.writeRow(
           groupingKey,
           groupingKeyConversionScratchSpace,
           PlatformDependent.BYTE_ARRAY_OFFSET,
    -      groupingKeySize,
    -      keyPool);
    +      groupingKeySize);
         assert (groupingKeySize == actualGroupingKeySize) : "Size requirement calculation was wrong!";
     
         // Probe our map using the serialized key
    @@ -156,32 +159,25 @@ public UnsafeRow getAggregationBuffer(InternalRow groupingKey) {
         if (!loc.isDefined()) {
           // This is the first time that we've seen this grouping key, so we'll insert a copy of the
           // empty aggregation buffer into the map:
    -      if (!reuseEmptyBuffer) {
    -        // There is some objects referenced by emptyBuffer, so generate a new one
    -        InternalRow initRow = initProjection.apply(emptyRow);
    -        bufferConverter.writeRow(initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET,
    -          groupingKeySize, bufferPool);
    -      }
           loc.putNewKey(
             groupingKeyConversionScratchSpace,
             PlatformDependent.BYTE_ARRAY_OFFSET,
             groupingKeySize,
    -        emptyBuffer,
    +        emptyAggregationBuffer,
             PlatformDependent.BYTE_ARRAY_OFFSET,
    -        emptyBuffer.length
    +        emptyAggregationBuffer.length
           );
         }
     
         // Reset the pointer to point to the value that we just stored or looked up:
         final MemoryLocation address = loc.getValueAddress();
    -    currentBuffer.pointTo(
    +    currentAggregationBuffer.pointTo(
           address.getBaseObject(),
           address.getBaseOffset(),
    -      bufferConverter.numFields(),
    -      loc.getValueLength(),
    -      bufferPool
    +      aggregationBufferSchema.length(),
    +      loc.getValueLength()
         );
    -    return currentBuffer;
    +    return currentAggregationBuffer;
       }
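
The lookup above is the classic probe-then-insert-default pattern: serialize the grouping key, probe the map, copy the pre-serialized empty buffer in on a miss, and hand back a pointer into the map's own storage so callers mutate the stored bytes in place. A rough on-heap analogy (a plain HashMap stands in for BytesToBytesMap; the class and method names are illustrative only):

    import java.util.Arrays;
    import java.util.HashMap;
    import java.util.Map;

    // Sketch of "get-or-insert empty aggregation buffer".
    public class GetOrInsertBuffer {
      private final Map<String, byte[]> map = new HashMap<>();
      private final byte[] emptyBuffer;  // pre-serialized "zero" buffer, copied on first use of a key

      GetOrInsertBuffer(byte[] emptyBuffer) {
        this.emptyBuffer = emptyBuffer;
      }

      // Returns the buffer stored for this key, inserting a copy of the empty buffer on first access.
      byte[] getAggregationBuffer(String serializedKey) {
        byte[] buf = map.get(serializedKey);
        if (buf == null) {
          buf = Arrays.copyOf(emptyBuffer, emptyBuffer.length);
          map.put(serializedKey, buf);
        }
        return buf;  // callers update this array in place, like the returned UnsafeRow pointer
      }

      public static void main(String[] args) {
        GetOrInsertBuffer m = new GetOrInsertBuffer(new byte[8]);
        byte[] b = m.getAggregationBuffer("key1");
        b[0] = 42;                                              // later lookups of "key1" see the update
        System.out.println(m.getAggregationBuffer("key1")[0]);  // 42
      }
    }
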
     
       /**
    @@ -217,16 +213,14 @@ public MapEntry next() {
             entry.key.pointTo(
               keyAddress.getBaseObject(),
               keyAddress.getBaseOffset(),
    -          keyConverter.numFields(),
    -          loc.getKeyLength(),
    -          keyPool
    +          groupingKeySchema.length(),
    +          loc.getKeyLength()
             );
             entry.value.pointTo(
               valueAddress.getBaseObject(),
               valueAddress.getBaseOffset(),
    -          bufferConverter.numFields(),
    -          loc.getValueLength(),
    -          bufferPool
    +          aggregationBufferSchema.length(),
    +          loc.getValueLength()
             );
             return entry;
           }
    @@ -254,8 +248,6 @@ public void printPerfMetrics() {
         System.out.println("Number of hash collisions: " + map.getNumHashCollisions());
         System.out.println("Time spent resizing (ns): " + map.getTimeSpentResizingNs());
         System.out.println("Total memory consumption (bytes): " + map.getTotalMemoryConsumption());
    -    System.out.println("Number of unique objects in keys: " + keyPool.size());
    -    System.out.println("Number of objects in buffers: " + bufferPool.size());
       }
     
     }
    diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
    index 4b99030d1046f..fa1216b455a9e 100644
    --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
    +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
    @@ -17,12 +17,21 @@
     
     package org.apache.spark.sql.catalyst.expressions;
     
    -import org.apache.spark.sql.catalyst.InternalRow;
    -import org.apache.spark.sql.catalyst.util.ObjectPool;
    +import java.io.IOException;
    +import java.io.OutputStream;
    +import java.util.Arrays;
    +import java.util.Collections;
    +import java.util.HashSet;
    +import java.util.Set;
    +
    +import org.apache.spark.sql.types.DataType;
     import org.apache.spark.unsafe.PlatformDependent;
    +import org.apache.spark.unsafe.array.ByteArrayMethods;
     import org.apache.spark.unsafe.bitset.BitSetMethods;
    +import org.apache.spark.unsafe.hash.Murmur3_x86_32;
     import org.apache.spark.unsafe.types.UTF8String;
     
    +import static org.apache.spark.sql.types.DataTypes.*;
     
     /**
      * An Unsafe implementation of Row which is backed by raw memory instead of Java objects.
    @@ -36,20 +45,7 @@
      * primitive types, such as long, double, or int, we store the value directly in the word. For
      * fields with non-primitive or variable-length values, we store a relative offset (w.r.t. the
      * base address of the row) that points to the beginning of the variable-length field, and length
    - * (they are combined into a long). For other objects, they are stored in a pool, the indexes of
    - * them are hold in the the word.
    - *
    - * In order to support fast hashing and equality checks for UnsafeRows that contain objects
    - * when used as grouping key in BytesToBytesMap, we put the objects in an UniqueObjectPool to make
    - * sure all the key have the same index for same object, then we can hash/compare the objects by
    - * hash/compare the index.
    - *
    - * For non-primitive types, the word of a field could be:
    - *   UNION {
    - *     [1] [offset: 31bits] [length: 31bits]  // StringType
    - *     [0] [offset: 31bits] [length: 31bits]  // BinaryType
    - *     - [index: 63bits]                      // StringType, Binary, index to object in pool
    - *   }
    + * (they are combined into a long).
      *
      * Instances of `UnsafeRow` act as pointers to row data stored in this format.
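
The fixed-width word for a variable-length field described above packs the field's offset and byte length into a single long; getBinary later unpacks the offset from the upper 32 bits and the size from the lower 32 bits. A self-contained sketch of that packing (no Spark classes involved; the class name is made up):

    // Packing an (offset, length) pair for a variable-length field into one 64-bit word,
    // mirroring the "offset in the high 32 bits, size in the low 32 bits" layout.
    public class OffsetAndSize {
      static long pack(int offset, int size) {
        return (((long) offset) << 32) | (size & 0xFFFFFFFFL);
      }

      static int offset(long word) {
        return (int) (word >> 32);
      }

      static int size(long word) {
        return (int) (word & ((1L << 32) - 1));
      }

      public static void main(String[] args) {
        long word = pack(64, 13);          // field data starts 64 bytes into the row, 13 bytes long
        System.out.println(offset(word));  // 64
        System.out.println(size(word));    // 13
      }
    }
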
      */
    @@ -58,13 +54,9 @@ public final class UnsafeRow extends MutableRow {
       private Object baseObject;
       private long baseOffset;
     
    -  /** A pool to hold non-primitive objects */
    -  private ObjectPool pool;
    -
       public Object getBaseObject() { return baseObject; }
       public long getBaseOffset() { return baseOffset; }
       public int getSizeInBytes() { return sizeInBytes; }
    -  public ObjectPool getPool() { return pool; }
     
       /** The number of fields in this row, used for calculating the bitset width (and in assertions) */
       private int numFields;
    @@ -85,7 +77,42 @@ public static int calculateBitSetWidthInBytes(int numFields) {
         return ((numFields / 64) + (numFields % 64 == 0 ? 0 : 1)) * 8;
       }
     
    -  public static final long OFFSET_BITS = 31L;
    +  /**
    +   * Field types that can be updated in place in UnsafeRows (e.g. we support set() for these types)
    +   */
    +  public static final Set<DataType> settableFieldTypes;
REPLACED_BY_NEXT_EDIT
    +
    +  /**
    +   * Field types that can be read but not set (e.g. set() will throw UnsupportedOperationException).
    +   */
    +  public static final Set<DataType> readableFieldTypes;
    +
    +  // TODO: support DecimalType
    +  static {
    +    settableFieldTypes = Collections.unmodifiableSet(
    +      new HashSet<>(
    +        Arrays.asList(new DataType[] {
    +          NullType,
    +          BooleanType,
    +          ByteType,
    +          ShortType,
    +          IntegerType,
    +          LongType,
    +          FloatType,
    +          DoubleType,
    +          DateType,
    +          TimestampType
    +    })));
    +
    +    // We support get() on a superset of the types for which we support set():
    +    final Set<DataType> _readableFieldTypes = new HashSet<>(
    +      Arrays.asList(new DataType[]{
    +        StringType,
    +        BinaryType
    +      }));
    +    _readableFieldTypes.addAll(settableFieldTypes);
    +    readableFieldTypes = Collections.unmodifiableSet(_readableFieldTypes);
    +  }
     
       /**
        * Construct a new UnsafeRow. The resulting row won't be usable until `pointTo()` has been called,
    @@ -100,17 +127,14 @@ public UnsafeRow() { }
        * @param baseOffset the offset within the base object
        * @param numFields the number of fields in this row
        * @param sizeInBytes the size of this row's backing data, in bytes
    -   * @param pool the object pool to hold arbitrary objects
        */
    -  public void pointTo(
    -      Object baseObject, long baseOffset, int numFields, int sizeInBytes, ObjectPool pool) {
    +  public void pointTo(Object baseObject, long baseOffset, int numFields, int sizeInBytes) {
         assert numFields >= 0 : "numFields should >= 0";
         this.bitSetWidthInBytes = calculateBitSetWidthInBytes(numFields);
         this.baseObject = baseObject;
         this.baseOffset = baseOffset;
         this.numFields = numFields;
         this.sizeInBytes = sizeInBytes;
    -    this.pool = pool;
       }
     
       private void assertIndexIsValid(int index) {
    @@ -133,68 +157,9 @@ private void setNotNullAt(int i) {
         BitSetMethods.unset(baseObject, baseOffset, i);
       }
     
    -  /**
    -   * Updates the column `i` as Object `value`, which cannot be primitive types.
    -   */
       @Override
    -  public void update(int i, Object value) {
    -    if (value == null) {
    -      if (!isNullAt(i)) {
    -        // remove the old value from pool
    -        long idx = getLong(i);
    -        if (idx <= 0) {
    -          // this is the index of old value in pool, remove it
    -          pool.replace((int)-idx, null);
    -        } else {
    -          // there will be some garbage left (UTF8String or byte[])
    -        }
    -        setNullAt(i);
    -      }
    -      return;
    -    }
    -
    -    if (isNullAt(i)) {
    -      // there is not an old value, put the new value into pool
    -      int idx = pool.put(value);
    -      setLong(i, (long)-idx);
    -    } else {
    -      // there is an old value, check the type, then replace it or update it
    -      long v = getLong(i);
    -      if (v <= 0) {
    -        // it's the index in the pool, replace old value with new one
    -        int idx = (int)-v;
    -        pool.replace(idx, value);
    -      } else {
    -        // old value is UTF8String or byte[], try to reuse the space
    -        boolean isString;
    -        byte[] newBytes;
    -        if (value instanceof UTF8String) {
    -          newBytes = ((UTF8String) value).getBytes();
    -          isString = true;
    -        } else {
    -          newBytes = (byte[]) value;
    -          isString = false;
    -        }
    -        int offset = (int) ((v >> OFFSET_BITS) & Integer.MAX_VALUE);
    -        int oldLength = (int) (v & Integer.MAX_VALUE);
    -        if (newBytes.length <= oldLength) {
    -          // the new value can fit in the old buffer, re-use it
    -          PlatformDependent.copyMemory(
    -            newBytes,
    -            PlatformDependent.BYTE_ARRAY_OFFSET,
    -            baseObject,
    -            baseOffset + offset,
    -            newBytes.length);
    -          long flag = isString ? 1L << (OFFSET_BITS * 2) : 0L;
    -          setLong(i, flag | (((long) offset) << OFFSET_BITS) | (long) newBytes.length);
    -        } else {
    -          // Cannot fit in the buffer
    -          int idx = pool.put(value);
    -          setLong(i, (long) -idx);
    -        }
    -      }
    -    }
    -    setNotNullAt(i);
    +  public void update(int ordinal, Object value) {
    +    throw new UnsupportedOperationException();
       }
     
       @Override
    @@ -215,6 +180,9 @@ public void setLong(int ordinal, long value) {
       public void setDouble(int ordinal, double value) {
         assertIndexIsValid(ordinal);
         setNotNullAt(ordinal);
    +    if (Double.isNaN(value)) {
    +      value = Double.NaN;
    +    }
         PlatformDependent.UNSAFE.putDouble(baseObject, getFieldOffset(ordinal), value);
       }
     
    @@ -243,6 +211,9 @@ public void setByte(int ordinal, byte value) {
       public void setFloat(int ordinal, float value) {
         assertIndexIsValid(ordinal);
         setNotNullAt(ordinal);
    +    if (Float.isNaN(value)) {
    +      value = Float.NaN;
    +    }
         PlatformDependent.UNSAFE.putFloat(baseObject, getFieldOffset(ordinal), value);
       }
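
The NaN handling in setDouble and setFloat above is deliberate: NaN has many bit patterns, and since UnsafeRow compares and hashes rows by their raw bytes, writing one canonical NaN keeps byte-wise equality consistent with value equality. A standalone illustration using only java.lang bit conversions (the non-canonical bit pattern below is just an example value):

    // Different NaN payloads have different bit patterns; canonicalizing before writing
    // keeps byte-for-byte comparisons of rows consistent.
    public class NaNCanonicalization {
      public static void main(String[] args) {
        double nan1 = Double.longBitsToDouble(0x7ff8000000000001L);  // a non-canonical NaN
        double nan2 = Double.NaN;                                    // the canonical NaN

        System.out.println(Double.isNaN(nan1) && Double.isNaN(nan2));  // true
        // usually false: raw NaN bits are preserved on common JVMs, so the patterns differ
        System.out.println(Double.doubleToRawLongBits(nan1) == Double.doubleToRawLongBits(nan2));

        // What the setters above effectively do before storing the value:
        double stored = Double.isNaN(nan1) ? Double.NaN : nan1;
        System.out.println(Double.doubleToRawLongBits(stored) == Double.doubleToRawLongBits(nan2));  // true
      }
    }
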
     
    @@ -251,40 +222,9 @@ public int size() {
         return numFields;
       }
     
    -  /**
    -   * Returns the object for column `i`, which should not be primitive type.
    -   */
       @Override
       public Object get(int i) {
    -    assertIndexIsValid(i);
    -    if (isNullAt(i)) {
    -      return null;
    -    }
    -    long v = PlatformDependent.UNSAFE.getLong(baseObject, getFieldOffset(i));
    -    if (v <= 0) {
    -      // It's an index to object in the pool.
    -      int idx = (int)-v;
    -      return pool.get(idx);
    -    } else {
    -      // The column could be StingType or BinaryType
    -      boolean isString = (v >> (OFFSET_BITS * 2)) > 0;
    -      int offset = (int) ((v >> OFFSET_BITS) & Integer.MAX_VALUE);
    -      int size = (int) (v & Integer.MAX_VALUE);
    -      final byte[] bytes = new byte[size];
    -      // TODO(davies): Avoid the copy once we can manage the life cycle of Row well.
    -      PlatformDependent.copyMemory(
    -        baseObject,
    -        baseOffset + offset,
    -        bytes,
    -        PlatformDependent.BYTE_ARRAY_OFFSET,
    -        size
    -      );
    -      if (isString) {
    -        return UTF8String.fromBytes(bytes);
    -      } else {
    -        return bytes;
    -      }
    -    }
    +    throw new UnsupportedOperationException();
       }
     
       @Override
    @@ -343,6 +283,38 @@ public double getDouble(int i) {
         }
       }
     
    +  @Override
    +  public UTF8String getUTF8String(int i) {
    +    assertIndexIsValid(i);
    +    return isNullAt(i) ? null : UTF8String.fromBytes(getBinary(i));
    +  }
    +
    +  @Override
    +  public byte[] getBinary(int i) {
    +    if (isNullAt(i)) {
    +      return null;
    +    } else {
    +      assertIndexIsValid(i);
    +      final long offsetAndSize = getLong(i);
    +      final int offset = (int) (offsetAndSize >> 32);
    +      final int size = (int) (offsetAndSize & ((1L << 32) - 1));
    +      final byte[] bytes = new byte[size];
    +      PlatformDependent.copyMemory(
    +        baseObject,
    +        baseOffset + offset,
    +        bytes,
    +        PlatformDependent.BYTE_ARRAY_OFFSET,
    +        size
    +      );
    +      return bytes;
    +    }
    +  }
    +
    +  @Override
    +  public String getString(int i) {
    +    return getUTF8String(i).toString();
    +  }
    +
       /**
        * Copies this row, returning a self-contained UnsafeRow that stores its data in an internal
        * byte array rather than referencing data stored in a data page.
    @@ -350,28 +322,95 @@ public double getDouble(int i) {
        * This method is only supported on UnsafeRows that do not use ObjectPools.
        */
       @Override
    -  public InternalRow copy() {
    -    if (pool != null) {
    -      throw new UnsupportedOperationException(
    -        "Copy is not supported for UnsafeRows that use object pools");
    +  public UnsafeRow copy() {
    +    UnsafeRow rowCopy = new UnsafeRow();
    +    final byte[] rowDataCopy = new byte[sizeInBytes];
    +    PlatformDependent.copyMemory(
    +      baseObject,
    +      baseOffset,
    +      rowDataCopy,
    +      PlatformDependent.BYTE_ARRAY_OFFSET,
    +      sizeInBytes
    +    );
    +    rowCopy.pointTo(rowDataCopy, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, sizeInBytes);
    +    return rowCopy;
    +  }
    +
    +  /**
    +   * Write this UnsafeRow's underlying bytes to the given OutputStream.
    +   *
    +   * @param out the stream to write to.
    +   * @param writeBuffer a byte array for buffering chunks of off-heap data while writing to the
    +   *                    output stream. If this row is backed by an on-heap byte array, then this
    +   *                    buffer will not be used and may be null.
    +   */
    +  public void writeToStream(OutputStream out, byte[] writeBuffer) throws IOException {
    +    if (baseObject instanceof byte[]) {
    +      int offsetInByteArray = (int) (PlatformDependent.BYTE_ARRAY_OFFSET - baseOffset);
    +      out.write((byte[]) baseObject, offsetInByteArray, sizeInBytes);
         } else {
    -      UnsafeRow rowCopy = new UnsafeRow();
    -      final byte[] rowDataCopy = new byte[sizeInBytes];
    -      PlatformDependent.copyMemory(
    -        baseObject,
    -        baseOffset,
    -        rowDataCopy,
    -        PlatformDependent.BYTE_ARRAY_OFFSET,
    -        sizeInBytes
    -      );
    -      rowCopy.pointTo(
    -        rowDataCopy, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, sizeInBytes, null);
    -      return rowCopy;
    +      int dataRemaining = sizeInBytes;
    +      long rowReadPosition = baseOffset;
    +      while (dataRemaining > 0) {
    +        int toTransfer = Math.min(writeBuffer.length, dataRemaining);
    +        PlatformDependent.copyMemory(
    +          baseObject,
    +          rowReadPosition,
    +          writeBuffer,
    +          PlatformDependent.BYTE_ARRAY_OFFSET,
    +          toTransfer);
    +        out.write(writeBuffer, 0, toTransfer);
    +        rowReadPosition += toTransfer;
    +        dataRemaining -= toTransfer;
    +      }
    +    }
    +  }
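
The off-heap branch of writeToStream is a standard chunked copy: fill a small reusable buffer from the row's memory region and write it out until nothing remains. A hedged on-heap sketch of the same loop, with System.arraycopy standing in for PlatformDependent.copyMemory and made-up names throughout:

    import java.io.ByteArrayOutputStream;
    import java.io.IOException;
    import java.io.OutputStream;

    // Chunked copy of a byte region to a stream through a small reusable buffer.
    public class ChunkedWrite {
      static void writeRegion(byte[] source, int offset, int length,
                              OutputStream out, byte[] writeBuffer) throws IOException {
        int remaining = length;
        int readPosition = offset;
        while (remaining > 0) {
          int toTransfer = Math.min(writeBuffer.length, remaining);
          System.arraycopy(source, readPosition, writeBuffer, 0, toTransfer);
          out.write(writeBuffer, 0, toTransfer);
          readPosition += toTransfer;
          remaining -= toTransfer;
        }
      }

      public static void main(String[] args) throws IOException {
        byte[] row = new byte[100];
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        writeRegion(row, 16, 64, out, new byte[8]);  // copies 64 bytes in 8-byte chunks
        System.out.println(out.size());              // 64
      }
    }
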
    +
    +  @Override
    +  public int hashCode() {
    +    return Murmur3_x86_32.hashUnsafeWords(baseObject, baseOffset, sizeInBytes, 42);
    +  }
    +
    +  @Override
    +  public boolean equals(Object other) {
    +    if (other instanceof UnsafeRow) {
    +      UnsafeRow o = (UnsafeRow) other;
    +      return (sizeInBytes == o.sizeInBytes) &&
    +        ByteArrayMethods.arrayEquals(baseObject, baseOffset, o.baseObject, o.baseOffset,
    +          sizeInBytes);
    +    }
    +    return false;
    +  }
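
With hashCode and equals defined purely over the backing bytes, two UnsafeRows compare equal exactly when their serialized forms are byte-for-byte identical, which is what lets them serve directly as hash keys. A minimal stand-in using on-heap arrays (Arrays.equals/hashCode in place of ByteArrayMethods.arrayEquals and Murmur3 hashing; the data is invented):

    import java.util.Arrays;

    // Equality and hashing over a row's raw bytes: equal serialized bytes <=> equal keys.
    public class ByteEquality {
      public static void main(String[] args) {
        byte[] rowA = {0, 0, 0, 0, 42, 0, 0, 0};
        byte[] rowB = {0, 0, 0, 0, 42, 0, 0, 0};
        byte[] rowC = {0, 0, 0, 0, 43, 0, 0, 0};

        System.out.println(Arrays.equals(rowA, rowB));                      // true
        System.out.println(Arrays.hashCode(rowA) == Arrays.hashCode(rowB)); // true
        System.out.println(Arrays.equals(rowA, rowC));                      // false
      }
    }
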
    +
    +  /**
    +   * Returns the underlying bytes for this UnsafeRow.
    +   */
    +  public byte[] getBytes() {
    +    if (baseObject instanceof byte[] && baseOffset == PlatformDependent.BYTE_ARRAY_OFFSET
    +        && (((byte[]) baseObject).length == sizeInBytes)) {
    +      return (byte[]) baseObject;
    +    } else {
    +      byte[] bytes = new byte[sizeInBytes];
    +      PlatformDependent.copyMemory(baseObject, baseOffset, bytes,
    +        PlatformDependent.BYTE_ARRAY_OFFSET, sizeInBytes);
    +      return bytes;
    +    }
    +  }
    +
    +  // This is for debugging
    +  @Override
    +  public String toString() {
    +    StringBuilder build = new StringBuilder("[");
    +    for (int i = 0; i < sizeInBytes; i += 8) {
    +      build.append(PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + i));
    +      build.append(',');
         }
    +    build.append(']');
    +    return build.toString();
       }
     
       @Override
       public boolean anyNull() {
    -    return BitSetMethods.anySet(baseObject, baseOffset, bitSetWidthInBytes);
    +    return BitSetMethods.anySet(baseObject, baseOffset, bitSetWidthInBytes / 8);
       }
     }
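
The anyNull change above passes the bitset width to BitSetMethods.anySet in 8-byte words rather than bytes; the bitset itself is sized by calculateBitSetWidthInBytes as one 64-bit word per 64 fields, rounded up. A small sketch of that sizing and of an any-set check over plain long[] words (names are illustrative, not Spark APIs):

    // Word-aligned null bitset: one 64-bit word per 64 fields (rounded up),
    // and "any null" is just an OR over the words.
    public class NullBitSet {
      static int bitSetWidthInBytes(int numFields) {
        return ((numFields / 64) + (numFields % 64 == 0 ? 0 : 1)) * 8;
      }

      static boolean anySet(long[] words) {
        for (long word : words) {
          if (word != 0L) {
            return true;
          }
        }
        return false;
      }

      public static void main(String[] args) {
        System.out.println(bitSetWidthInBytes(64));   // 8  (one word)
        System.out.println(bitSetWidthInBytes(65));   // 16 (two words)

        long[] nullBits = new long[2];
        System.out.println(anySet(nullBits));         // false
        nullBits[1] |= 1L << 3;                       // mark field 67 as null
        System.out.println(anySet(nullBits));         // true
      }
    }
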
    diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/ObjectPool.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/ObjectPool.java
    deleted file mode 100644
    index 97f89a7d0b758..0000000000000
    --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/ObjectPool.java
    +++ /dev/null
    @@ -1,78 +0,0 @@
    -/*
    - * Licensed to the Apache Software Foundation (ASF) under one or more
    - * contributor license agreements.  See the NOTICE file distributed with
    - * this work for additional information regarding copyright ownership.
    - * The ASF licenses this file to You under the Apache License, Version 2.0
    - * (the "License"); you may not use this file except in compliance with
    - * the License.  You may obtain a copy of the License at
    - *
    - *    http://www.apache.org/licenses/LICENSE-2.0
    - *
    - * Unless required by applicable law or agreed to in writing, software
    - * distributed under the License is distributed on an "AS IS" BASIS,
    - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    - * See the License for the specific language governing permissions and
    - * limitations under the License.
    - */
    -
    -package org.apache.spark.sql.catalyst.util;
    -
    -/**
    - * A object pool stores a collection of objects in array, then they can be referenced by the
    - * pool plus an index.
    - */
    -public class ObjectPool {
    -
    -  /**
    -   * An array to hold objects, which will grow as needed.
    -   */
    -  private Object[] objects;
    -
    -  /**
    -   * How many objects in the pool.
    -   */
    -  private int numObj;
    -
    -  public ObjectPool(int capacity) {
    -    objects = new Object[capacity];
    -    numObj = 0;
    -  }
    -
    -  /**
    -   * Returns how many objects in the pool.
    -   */
    -  public int size() {
    -    return numObj;
    -  }
    -
    -  /**
    -   * Returns the object at position `idx` in the array.
    -   */
    -  public Object get(int idx) {
    -    assert (idx < numObj);
    -    return objects[idx];
    -  }
    -
    -  /**
    -   * Puts an object `obj` at the end of array, returns the index of it.
    -   *
    -   * The array will grow as needed.
    -   */
    -  public int put(Object obj) {
    -    if (numObj >= objects.length) {
    -      Object[] tmp = new Object[objects.length * 2];
    -      System.arraycopy(objects, 0, tmp, 0, objects.length);
    -      objects = tmp;
    -    }
    -    objects[numObj++] = obj;
    -    return numObj - 1;
    -  }
    -
    -  /**
    -   * Replaces the object at `idx` with new one `obj`.
    -   */
    -  public void replace(int idx, Object obj) {
    -    assert (idx < numObj);
    -    objects[idx] = obj;
    -  }
    -}
    diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/UniqueObjectPool.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/UniqueObjectPool.java
    deleted file mode 100644
    index d512392dcaacc..0000000000000
    --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/UniqueObjectPool.java
    +++ /dev/null
    @@ -1,59 +0,0 @@
    -/*
    - * Licensed to the Apache Software Foundation (ASF) under one or more
    - * contributor license agreements.  See the NOTICE file distributed with
    - * this work for additional information regarding copyright ownership.
    - * The ASF licenses this file to You under the Apache License, Version 2.0
    - * (the "License"); you may not use this file except in compliance with
    - * the License.  You may obtain a copy of the License at
    - *
    - *    http://www.apache.org/licenses/LICENSE-2.0
    - *
    - * Unless required by applicable law or agreed to in writing, software
    - * distributed under the License is distributed on an "AS IS" BASIS,
    - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    - * See the License for the specific language governing permissions and
    - * limitations under the License.
    - */
    -
    -package org.apache.spark.sql.catalyst.util;
    -
    -import java.util.HashMap;
    -
    -/**
    - * An unique object pool stores a collection of unique objects in it.
    - */
    -public class UniqueObjectPool extends ObjectPool {
    -
    -  /**
    -   * A hash map from objects to their indexes in the array.
    -   */
    -  private HashMap objIndex;
    -
    -  public UniqueObjectPool(int capacity) {
    -    super(capacity);
    -    objIndex = new HashMap();
    -  }
    -
    -  /**
    -   * Put an object `obj` into the pool. If there is an existing object equals to `obj`, it will
    -   * return the index of the existing one.
    -   */
    -  @Override
    -  public int put(Object obj) {
    -    if (objIndex.containsKey(obj)) {
    -      return objIndex.get(obj);
    -    } else {
    -      int idx = super.put(obj);
    -      objIndex.put(obj, idx);
    -      return idx;
    -    }
    -  }
    -
    -  /**
    -   * The objects can not be replaced.
    -   */
    -  @Override
    -  public void replace(int idx, Object obj) {
    -    throw new UnsupportedOperationException();
    -  }
    -}
    diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
    index b94601cf6d818..be4ff400c4754 100644
    --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
    +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
    @@ -28,12 +28,8 @@
     import org.apache.spark.TaskContext;
     import org.apache.spark.sql.AbstractScalaRowIterator;
     import org.apache.spark.sql.catalyst.InternalRow;
    -import org.apache.spark.sql.catalyst.expressions.ObjectUnsafeColumnWriter;
    -import org.apache.spark.sql.catalyst.expressions.UnsafeColumnWriter;
    +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
     import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
    -import org.apache.spark.sql.catalyst.expressions.UnsafeRowConverter;
    -import org.apache.spark.sql.catalyst.util.ObjectPool;
    -import org.apache.spark.sql.types.StructField;
     import org.apache.spark.sql.types.StructType;
     import org.apache.spark.unsafe.PlatformDependent;
     import org.apache.spark.util.collection.unsafe.sort.PrefixComparator;
    @@ -52,10 +48,9 @@ final class UnsafeExternalRowSorter {
       private long numRowsInserted = 0;
     
       private final StructType schema;
    -  private final UnsafeRowConverter rowConverter;
    +  private final UnsafeProjection unsafeProjection;
       private final PrefixComputer prefixComputer;
       private final UnsafeExternalSorter sorter;
    -  private byte[] rowConversionBuffer = new byte[1024 * 8];
     
       public static abstract class PrefixComputer {
         abstract long computePrefix(InternalRow row);
    @@ -67,7 +62,7 @@ public UnsafeExternalRowSorter(
           PrefixComparator prefixComparator,
           PrefixComputer prefixComputer) throws IOException {
         this.schema = schema;
    -    this.rowConverter = new UnsafeRowConverter(schema);
    +    this.unsafeProjection = UnsafeProjection.create(schema);
         this.prefixComputer = prefixComputer;
         final SparkEnv sparkEnv = SparkEnv.get();
         final TaskContext taskContext = TaskContext.get();
    @@ -76,7 +71,7 @@ public UnsafeExternalRowSorter(
           sparkEnv.shuffleMemoryManager(),
           sparkEnv.blockManager(),
           taskContext,
    -      new RowComparator(ordering, schema.length(), null),
    +      new RowComparator(ordering, schema.length()),
           prefixComparator,
           4096,
           sparkEnv.conf()
    @@ -94,18 +89,12 @@ void setTestSpillFrequency(int frequency) {
     
       @VisibleForTesting
       void insertRow(InternalRow row) throws IOException {
    -    final int sizeRequirement = rowConverter.getSizeRequirement(row);
    -    if (sizeRequirement > rowConversionBuffer.length) {
    -      rowConversionBuffer = new byte[sizeRequirement];
    -    }
    -    final int bytesWritten = rowConverter.writeRow(
    -      row, rowConversionBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, sizeRequirement, null);
    -    assert (bytesWritten == sizeRequirement);
    +    UnsafeRow unsafeRow = unsafeProjection.apply(row);
         final long prefix = prefixComputer.computePrefix(row);
         sorter.insertRecord(
    -      rowConversionBuffer,
    -      PlatformDependent.BYTE_ARRAY_OFFSET,
    -      sizeRequirement,
    +      unsafeRow.getBaseObject(),
    +      unsafeRow.getBaseOffset(),
    +      unsafeRow.getSizeInBytes(),
           prefix
         );
         numRowsInserted++;
    @@ -150,8 +139,7 @@ public InternalRow next() {
               sortedIterator.getBaseObject(),
               sortedIterator.getBaseOffset(),
               numFields,
    -          sortedIterator.getRecordLength(),
    -          null);
    +          sortedIterator.getRecordLength());
             if (!hasNext()) {
               row.copy(); // so that we don't have dangling pointers to freed page
               cleanupResources();
    @@ -184,32 +172,25 @@ public Iterator sort(Iterator inputIterator) throws IO
        * Return true if UnsafeExternalRowSorter can sort rows with the given schema, false otherwise.
        */
       public static boolean supportsSchema(StructType schema) {
    -    // TODO: add spilling note to explain why we do this for now:
    -    for (StructField field : schema.fields()) {
    -      if (UnsafeColumnWriter.forType(field.dataType()) instanceof ObjectUnsafeColumnWriter) {
    -        return false;
    -      }
    -    }
    -    return true;
    +    return UnsafeProjection.canSupport(schema);
       }
     
       private static final class RowComparator extends RecordComparator {
         private final Ordering ordering;
         private final int numFields;
    -    private final ObjectPool objPool;
         private final UnsafeRow row1 = new UnsafeRow();
         private final UnsafeRow row2 = new UnsafeRow();
     
    -    public RowComparator(Ordering ordering, int numFields, ObjectPool objPool) {
    +    public RowComparator(Ordering ordering, int numFields) {
           this.numFields = numFields;
           this.ordering = ordering;
    -      this.objPool = objPool;
         }
     
         @Override
         public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) {
    -      row1.pointTo(baseObj1, baseOff1, numFields, -1, objPool);
    -      row2.pointTo(baseObj2, baseOff2, numFields, -1, objPool);
    +      // TODO: Why are the sizes -1?
    +      row1.pointTo(baseObj1, baseOff1, numFields, -1);
    +      row2.pointTo(baseObj2, baseOff2, numFields, -1);
           return ordering.compare(row1, row2);
         }
       }
    diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/AbstractScalaRowIterator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/AbstractScalaRowIterator.scala
    index cfefb13e7721e..1090bdb5a4bd3 100644
    --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/AbstractScalaRowIterator.scala
    +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/AbstractScalaRowIterator.scala
    @@ -17,11 +17,14 @@
     
     package org.apache.spark.sql
     
    -import org.apache.spark.sql.catalyst.InternalRow
    -
     /**
      * Shim to allow us to implement [[scala.Iterator]] in Java. Scala 2.11+ has an AbstractIterator
      * class for this, but that class is `private[scala]` in 2.10. We need to explicitly fix this to
    - * `Row` in order to work around a spurious IntelliJ compiler error.
    + * `Row` in order to work around a spurious IntelliJ compiler error. This cannot be an abstract
    + * class because that leads to compilation errors under Scala 2.11.
      */
    -private[spark] abstract class AbstractScalaRowIterator extends Iterator[InternalRow]
    +private[spark] class AbstractScalaRowIterator[T] extends Iterator[T] {
    +  override def hasNext: Boolean = throw new NotImplementedError
    +
    +  override def next(): T = throw new NotImplementedError
    +}
    diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
    index 0f2fd6a86d177..91449479fa539 100644
    --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
    +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
    @@ -17,6 +17,7 @@
     
     package org.apache.spark.sql
     
    +import org.apache.spark.sql.catalyst.InternalRow
     import org.apache.spark.sql.catalyst.expressions.GenericRow
     import org.apache.spark.sql.types.StructType
     
    @@ -151,7 +152,7 @@ trait Row extends Serializable {
        *   StructType -> org.apache.spark.sql.Row
        * }}}
        */
    -  def apply(i: Int): Any
    +  def apply(i: Int): Any = get(i)
     
       /**
        * Returns the value at position i. If the value is null, null is returned. The following
    @@ -176,10 +177,10 @@ trait Row extends Serializable {
        *   StructType -> org.apache.spark.sql.Row
        * }}}
        */
    -  def get(i: Int): Any = apply(i)
    +  def get(i: Int): Any
     
       /** Checks whether the value at position i is null. */
    -  def isNullAt(i: Int): Boolean = apply(i) == null
    +  def isNullAt(i: Int): Boolean = get(i) == null
     
       /**
        * Returns the value at position i as a primitive boolean.
    @@ -311,7 +312,7 @@ trait Row extends Serializable {
        *
        * @throws ClassCastException when data type does not match.
        */
    -  def getAs[T](i: Int): T = apply(i).asInstanceOf[T]
    +  def getAs[T](i: Int): T = get(i).asInstanceOf[T]
     
       /**
        * Returns the value of a given fieldName.
    @@ -363,6 +364,69 @@ trait Row extends Serializable {
         false
       }
     
    +  /**
    +   * Returns true if we can check equality for these 2 rows.
    +   * Equality check between external row and internal row is not allowed.
    +   * Here we do this check to prevent call `equals` on external row with internal row.
    +   */
    +  protected def canEqual(other: Row) = {
    +    // Note that `Row` is not only the interface of external row but also the parent
    +    // of `InternalRow`, so we have to ensure `other` is not a internal row here to prevent
    +    // call `equals` on external row with internal row.
    +    // `InternalRow` overrides canEqual, and these two canEquals together makes sure that
    +    // equality check between external Row and InternalRow will always fail.
    +    // In the future, InternalRow should not extend Row. In that case, we can remove these
    +    // canEqual methods.
    +    !other.isInstanceOf[InternalRow]
    +  }
    +
    +  override def equals(o: Any): Boolean = {
    +    if (!o.isInstanceOf[Row]) return false
    +    val other = o.asInstanceOf[Row]
    +
    +    if (!canEqual(other)) {
    +      throw new UnsupportedOperationException(
    +        "cannot check equality between external and internal rows")
    +    }
    +
    +    if (other eq null) return false
    +
    +    if (length != other.length) {
    +      return false
    +    }
    +
    +    var i = 0
    +    while (i < length) {
    +      if (isNullAt(i) != other.isNullAt(i)) {
    +        return false
    +      }
    +      if (!isNullAt(i)) {
    +        val o1 = get(i)
    +        val o2 = other.get(i)
    +        o1 match {
    +          case b1: Array[Byte] =>
    +            if (!o2.isInstanceOf[Array[Byte]] ||
    +              !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) {
    +              return false
    +            }
    +          case f1: Float if java.lang.Float.isNaN(f1) =>
    +            if (!o2.isInstanceOf[Float] || ! java.lang.Float.isNaN(o2.asInstanceOf[Float])) {
    +              return false
    +            }
    +          case d1: Double if java.lang.Double.isNaN(d1) =>
    +            if (!o2.isInstanceOf[Double] || !
java.lang.Double.isNaN(o2.asInstanceOf[Double])) { + return false + } + case _ => if (o1 != o2) { + return false + } + } + } + i += 1 + } + true + } + /* ---------------------- utility methods for Scala ---------------------- */ /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 8f63d2120ad0e..4067833d5e648 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -24,6 +24,7 @@ import java.util.{Map => JavaMap} import javax.annotation.Nullable import scala.collection.mutable.HashMap +import scala.language.existentials import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions._ @@ -280,7 +281,8 @@ object CatalystTypeConverters { } override def toScala(catalystValue: UTF8String): String = if (catalystValue == null) null else catalystValue.toString - override def toScalaImpl(row: InternalRow, column: Int): String = row(column).toString + override def toScalaImpl(row: InternalRow, column: Int): String = + row.getUTF8String(column).toString } private object DateConverter extends CatalystTypeConverter[Date, Date, Any] { @@ -401,7 +403,7 @@ object CatalystTypeConverters { case seq: Seq[Any] => seq.map(convertToCatalyst) case r: Row => InternalRow(r.toSeq.map(convertToCatalyst): _*) case arr: Array[Any] => arr.toSeq.map(convertToCatalyst).toArray - case m: Map[Any, Any] => + case m: Map[_, _] => m.map { case (k, v) => (convertToCatalyst(k), convertToCatalyst(v)) }.toMap case other => other } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index 57de0f26a9720..c7ec49b3d6c3d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -27,11 +27,12 @@ import org.apache.spark.unsafe.types.UTF8String */ abstract class InternalRow extends Row { + def getUTF8String(i: Int): UTF8String = getAs[UTF8String](i) + + def getBinary(i: Int): Array[Byte] = getAs[Array[Byte]](i) + // This is only use for test - override def getString(i: Int): String = { - val str = getAs[UTF8String](i) - if (str != null) str.toString else null - } + override def getString(i: Int): String = getAs[UTF8String](i).toString // These expensive API should not be used internally. 
final override def getDecimal(i: Int): java.math.BigDecimal = @@ -53,41 +54,13 @@ abstract class InternalRow extends Row { // A default implementation to change the return type override def copy(): InternalRow = this - override def apply(i: Int): Any = get(i) - - override def equals(o: Any): Boolean = { - if (!o.isInstanceOf[Row]) { - return false - } - - val other = o.asInstanceOf[Row] - if (length != other.length) { - return false - } - var i = 0 - while (i < length) { - if (isNullAt(i) != other.isNullAt(i)) { - return false - } - if (!isNullAt(i)) { - val o1 = apply(i) - val o2 = other.apply(i) - if (o1.isInstanceOf[Array[Byte]]) { - // handle equality of Array[Byte] - val b1 = o1.asInstanceOf[Array[Byte]] - if (!o2.isInstanceOf[Array[Byte]] || - !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) { - return false - } - } else if (o1 != o2) { - return false - } - } - i += 1 - } - true - } + /** + * Returns true if we can check equality for these 2 rows. + * Equality check between external row and internal row is not allowed. + * Here we do this check to prevent call `equals` on internal row with external row. + */ + protected override def canEqual(other: Row) = other.isInstanceOf[InternalRow] // Custom hashCode function that matches the efficient code generated version. override def hashCode: Int = { @@ -98,7 +71,7 @@ abstract class InternalRow extends Row { if (isNullAt(i)) { 0 } else { - apply(i) match { + get(i) match { case b: Boolean => if (b) 0 else 1 case b: Byte => b.toInt case s: Short => s.toInt diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index d4ef04c2294a2..29cfc064da89a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -266,12 +266,12 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { } } | ident ~ ("(" ~> repsep(expression, ",")) <~ ")" ^^ - { case udfName ~ exprs => UnresolvedFunction(udfName, exprs) } + { case udfName ~ exprs => UnresolvedFunction(udfName, exprs, isDistinct = false) } | ident ~ ("(" ~ DISTINCT ~> repsep(expression, ",")) <~ ")" ^^ { case udfName ~ exprs => lexical.normalizeKeyword(udfName) match { case "sum" => SumDistinct(exprs.head) case "count" => CountDistinct(exprs) - case _ => throw new AnalysisException(s"function $udfName does not support DISTINCT") + case _ => UnresolvedFunction(udfName, exprs, isDistinct = true) } } | APPROXIMATE ~> ident ~ ("(" ~ DISTINCT ~> expression <~ ")") ^^ { case udfName ~ exp => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 891408e310049..8cadbc57e87e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2, AggregateFunction2} import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -194,16 +195,52 @@ class Analyzer( } def apply(plan: LogicalPlan): LogicalPlan = plan 
transform { + case a if !a.childrenResolved => a // be sure all of the children are resolved. case a: Cube => GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations) case a: Rollup => GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations) case x: GroupingSets => val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)() + // We will insert another Projection if the GROUP BY keys contains the + // non-attribute expressions. And the top operators can references those + // expressions by its alias. + // e.g. SELECT key%5 as c1 FROM src GROUP BY key%5 ==> + // SELECT a as c1 FROM (SELECT key%5 AS a FROM src) GROUP BY a + + // find all of the non-attribute expressions in the GROUP BY keys + val nonAttributeGroupByExpressions = new ArrayBuffer[Alias]() + + // The pair of (the original GROUP BY key, associated attribute) + val groupByExprPairs = x.groupByExprs.map(_ match { + case e: NamedExpression => (e, e.toAttribute) + case other => { + val alias = Alias(other, other.toString)() + nonAttributeGroupByExpressions += alias // add the non-attributes expression alias + (other, alias.toAttribute) + } + }) + + // substitute the non-attribute expressions for aggregations. + val aggregation = x.aggregations.map(expr => expr.transformDown { + case e => groupByExprPairs.find(_._1.semanticEquals(e)).map(_._2).getOrElse(e) + }.asInstanceOf[NamedExpression]) + + // substitute the group by expressions. + val newGroupByExprs = groupByExprPairs.map(_._2) + + val child = if (nonAttributeGroupByExpressions.length > 0) { + // insert additional projection if contains the + // non-attribute expressions in the GROUP BY keys + Project(x.child.output ++ nonAttributeGroupByExpressions, x.child) + } else { + x.child + } + Aggregate( - x.groupByExprs :+ VirtualColumn.groupingIdAttribute, - x.aggregations, - Expand(x.bitmasks, x.groupByExprs, gid, x.child)) + newGroupByExprs :+ VirtualColumn.groupingIdAttribute, + aggregation, + Expand(x.bitmasks, newGroupByExprs, gid, child)) } } @@ -241,7 +278,7 @@ class Analyzer( Project( projectList.flatMap { case s: Star => s.expand(child.output, resolver) - case UnresolvedAlias(f @ UnresolvedFunction(_, args)) if containsStar(args) => + case UnresolvedAlias(f @ UnresolvedFunction(_, args, _)) if containsStar(args) => val expandedArgs = args.flatMap { case s: Star => s.expand(child.output, resolver) case o => o :: Nil @@ -280,7 +317,7 @@ class Analyzer( ) // Special handling for cases when self-join introduce duplicate expression ids. - case j @ Join(left, right, _, _) if left.outputSet.intersect(right.outputSet).nonEmpty => + case j @ Join(left, right, _, _) if !j.selfJoinResolved => val conflictingAttributes = left.outputSet.intersect(right.outputSet) logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} in $j") @@ -481,9 +518,26 @@ class Analyzer( def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressions { - case u @ UnresolvedFunction(name, children) => + case u @ UnresolvedFunction(name, children, isDistinct) => withPosition(u) { - registry.lookupFunction(name, children) + registry.lookupFunction(name, children) match { + // We get an aggregate function built based on AggregateFunction2 interface. + // So, we wrap it in AggregateExpression2. + case agg2: AggregateFunction2 => AggregateExpression2(agg2, Complete, isDistinct) + // Currently, our old aggregate function interface supports SUM(DISTINCT ...) + // and COUTN(DISTINCT ...). 
+ case sumDistinct: SumDistinct => sumDistinct + case countDistinct: CountDistinct => countDistinct + // DISTINCT is not meaningful with Max and Min. + case max: Max if isDistinct => max + case min: Min if isDistinct => min + // For other aggregate functions, DISTINCT keyword is not supported for now. + // Once we converted to the new code path, we will allow using DISTINCT keyword. + case other if isDistinct => + failAnalysis(s"$name does not support DISTINCT keyword.") + // If it does not have DISTINCT keyword, we will return it as is. + case other => other + } } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 476ac2b7cb474..c203fcecf20fb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2 import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ @@ -109,29 +110,27 @@ trait CheckAnalysis { s"resolved attribute(s) $missingAttributes missing from $input " + s"in operator ${operator.simpleString}") - case o if !o.resolved => - failAnalysis( - s"unresolved operator ${operator.simpleString}") - case p @ Project(exprs, _) if containsMultipleGenerators(exprs) => failAnalysis( s"""Only a single table generating function is allowed in a SELECT clause, found: | ${exprs.map(_.prettyString).mkString(",")}""".stripMargin) + // Special handling for cases when self-join introduce duplicate expression ids. + case j @ Join(left, right, _, _) if left.outputSet.intersect(right.outputSet).nonEmpty => + val conflictingAttributes = left.outputSet.intersect(right.outputSet) + failAnalysis( + s""" + |Failure when resolving conflicting references in Join: + |$plan + |Conflicting attributes: ${conflictingAttributes.mkString(",")} + |""".stripMargin) + + case o if !o.resolved => + failAnalysis( + s"unresolved operator ${operator.simpleString}") case _ => // Analysis successful! } - - // Special handling for cases when self-join introduce duplicate expression ids. 
- case j @ Join(left, right, _, _) if left.outputSet.intersect(right.outputSet).nonEmpty => - val conflictingAttributes = left.outputSet.intersect(right.outputSet) - failAnalysis( - s""" - |Failure when resolving conflicting references in Join: - |$plan - |Conflicting attributes: ${conflictingAttributes.mkString(",")} - |""".stripMargin) - } extendedCheckRules.foreach(_(plan)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index ed69c42dcb825..9c349838c28a1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.analysis +import scala.language.existentials import scala.reflect.ClassTag import scala.util.{Failure, Success, Try} @@ -78,6 +79,7 @@ object FunctionRegistry { expression[Explode]("explode"), expression[Greatest]("greatest"), expression[If]("if"), + expression[IsNaN]("isnan"), expression[IsNull]("isnull"), expression[IsNotNull]("isnotnull"), expression[Least]("least"), @@ -87,6 +89,7 @@ object FunctionRegistry { expression[CreateStruct]("struct"), expression[CreateNamedStruct]("named_struct"), expression[Sqrt]("sqrt"), + expression[NaNvl]("nanvl"), // math functions expression[Acos]("acos"), @@ -98,6 +101,7 @@ object FunctionRegistry { expression[Ceil]("ceil"), expression[Ceil]("ceiling"), expression[Cos]("cos"), + expression[Conv]("conv"), expression[EulerNumber]("e"), expression[Exp]("exp"), expression[Expm1]("expm1"), @@ -109,13 +113,15 @@ object FunctionRegistry { expression[Log]("ln"), expression[Log10]("log10"), expression[Log1p]("log1p"), + expression[Log2]("log2"), expression[UnaryMinus]("negative"), expression[Pi]("pi"), - expression[Log2]("log2"), expression[Pow]("pow"), expression[Pow]("power"), + expression[Pmod]("pmod"), expression[UnaryPositive]("positive"), expression[Rint]("rint"), + expression[Round]("round"), expression[ShiftLeft]("shiftleft"), expression[ShiftRight]("shiftright"), expression[ShiftRightUnsigned]("shiftrightunsigned"), @@ -147,17 +153,23 @@ object FunctionRegistry { // string functions expression[Ascii]("ascii"), expression[Base64]("base64"), + expression[Concat]("concat"), + expression[ConcatWs]("concat_ws"), expression[Encode]("encode"), expression[Decode]("decode"), - expression[StringInstr]("instr"), + expression[FormatNumber]("format_number"), expression[Lower]("lcase"), expression[Lower]("lower"), - expression[StringLength]("length"), + expression[Length]("length"), expression[Levenshtein]("levenshtein"), + expression[RegExpExtract]("regexp_extract"), + expression[RegExpReplace]("regexp_replace"), + expression[StringInstr]("instr"), expression[StringLocate]("locate"), expression[StringLPad]("lpad"), expression[StringTrimLeft]("ltrim"), - expression[StringFormat]("printf"), + expression[FormatString]("format_string"), + expression[FormatString]("printf"), expression[StringRPad]("rpad"), expression[StringRepeat]("repeat"), expression[StringReverse]("reverse"), @@ -174,7 +186,21 @@ object FunctionRegistry { // datetime functions expression[CurrentDate]("current_date"), - expression[CurrentTimestamp]("current_timestamp") + expression[CurrentTimestamp]("current_timestamp"), + expression[DateFormatClass]("date_format"), + expression[DayOfMonth]("day"), + expression[DayOfYear]("dayofyear"), + 
expression[DayOfMonth]("dayofmonth"), + expression[Hour]("hour"), + expression[Month]("month"), + expression[Minute]("minute"), + expression[Quarter]("quarter"), + expression[Second]("second"), + expression[WeekOfYear]("weekofyear"), + expression[Year]("year"), + + // collection functions + expression[Size]("size") ) val builtin: FunctionRegistry = { @@ -192,7 +218,10 @@ object FunctionRegistry { val builder = (expressions: Seq[Expression]) => { if (varargCtor.isDefined) { // If there is an apply method that accepts Seq[Expression], use that one. - varargCtor.get.newInstance(expressions).asInstanceOf[Expression] + Try(varargCtor.get.newInstance(expressions).asInstanceOf[Expression]) match { + case Success(e) => e + case Failure(e) => throw new AnalysisException(e.getMessage) + } } else { // Otherwise, find an ctor method that matches the number of arguments, and use that. val params = Seq.fill(expressions.size)(classOf[Expression]) @@ -202,7 +231,10 @@ object FunctionRegistry { case Failure(e) => throw new AnalysisException(s"Invalid number of arguments for function $name") } - f.newInstance(expressions : _*).asInstanceOf[Expression] + Try(f.newInstance(expressions : _*).asInstanceOf[Expression]) match { + case Success(e) => e + case Failure(e) => throw new AnalysisException(e.getMessage) + } } } (name, builder) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 8cb71995eb818..e214545726249 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.analysis import javax.annotation.Nullable import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types._ @@ -168,65 +168,65 @@ object HiveTypeCoercion { * - LongType to DoubleType */ object WidenTypes extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - // TODO: unions with fixed-precision decimals - case u @ Union(left, right) if u.childrenResolved && !u.resolved => - val castedInput = left.output.zip(right.output).map { - // When a string is found on one side, make the other side a string too. - case (lhs, rhs) if lhs.dataType == StringType && rhs.dataType != StringType => - (lhs, Alias(Cast(rhs, StringType), rhs.name)()) - case (lhs, rhs) if lhs.dataType != StringType && rhs.dataType == StringType => - (Alias(Cast(lhs, StringType), lhs.name)(), rhs) - - case (lhs, rhs) if lhs.dataType != rhs.dataType => - logDebug(s"Resolving mismatched union input ${lhs.dataType}, ${rhs.dataType}") - findTightestCommonTypeOfTwo(lhs.dataType, rhs.dataType).map { widestType => - val newLeft = - if (lhs.dataType == widestType) lhs else Alias(Cast(lhs, widestType), lhs.name)() - val newRight = - if (rhs.dataType == widestType) rhs else Alias(Cast(rhs, widestType), rhs.name)() - - (newLeft, newRight) - }.getOrElse { - // If there is no applicable conversion, leave expression unchanged. 
- (lhs, rhs) - } - case other => other - } - - val (castedLeft, castedRight) = castedInput.unzip - - val newLeft = - if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) { - logDebug(s"Widening numeric types in union $castedLeft ${left.output}") - Project(castedLeft, left) - } else { - left + private[this] def widenOutputTypes(planName: String, left: LogicalPlan, right: LogicalPlan): + (LogicalPlan, LogicalPlan) = { + + // TODO: with fixed-precision decimals + val castedInput = left.output.zip(right.output).map { + // When a string is found on one side, make the other side a string too. + case (lhs, rhs) if lhs.dataType == StringType && rhs.dataType != StringType => + (lhs, Alias(Cast(rhs, StringType), rhs.name)()) + case (lhs, rhs) if lhs.dataType != StringType && rhs.dataType == StringType => + (Alias(Cast(lhs, StringType), lhs.name)(), rhs) + + case (lhs, rhs) if lhs.dataType != rhs.dataType => + logDebug(s"Resolving mismatched $planName input ${lhs.dataType}, ${rhs.dataType}") + findTightestCommonTypeOfTwo(lhs.dataType, rhs.dataType).map { widestType => + val newLeft = + if (lhs.dataType == widestType) lhs else Alias(Cast(lhs, widestType), lhs.name)() + val newRight = + if (rhs.dataType == widestType) rhs else Alias(Cast(rhs, widestType), rhs.name)() + + (newLeft, newRight) + }.getOrElse { + // If there is no applicable conversion, leave expression unchanged. + (lhs, rhs) } - val newRight = - if (castedRight.map(_.dataType) != right.output.map(_.dataType)) { - logDebug(s"Widening numeric types in union $castedRight ${right.output}") - Project(castedRight, right) - } else { - right - } + case other => other + } - Union(newLeft, newRight) + val (castedLeft, castedRight) = castedInput.unzip - // Also widen types for BinaryOperator. - case q: LogicalPlan => q transformExpressions { - // Skip nodes who's children have not been resolved yet. - case e if !e.childrenResolved => e + val newLeft = + if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) { + logDebug(s"Widening numeric types in $planName $castedLeft ${left.output}") + Project(castedLeft, left) + } else { + left + } - case b @ BinaryOperator(left, right) if left.dataType != right.dataType => - findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { widestType => - val newLeft = if (left.dataType == widestType) left else Cast(left, widestType) - val newRight = if (right.dataType == widestType) right else Cast(right, widestType) - b.makeCopy(Array(newLeft, newRight)) - }.getOrElse(b) // If there is no applicable conversion, leave expression unchanged. 
- } + val newRight = + if (castedRight.map(_.dataType) != right.output.map(_.dataType)) { + logDebug(s"Widening numeric types in $planName $castedRight ${right.output}") + Project(castedRight, right) + } else { + right + } + (newLeft, newRight) + } + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case u @ Union(left, right) if u.childrenResolved && !u.resolved => + val (newLeft, newRight) = widenOutputTypes(u.nodeName, left, right) + Union(newLeft, newRight) + case e @ Except(left, right) if e.childrenResolved && !e.resolved => + val (newLeft, newRight) = widenOutputTypes(e.nodeName, left, right) + Except(newLeft, newRight) + case i @ Intersect(left, right) if i.childrenResolved && !i.resolved => + val (newLeft, newRight) = widenOutputTypes(i.nodeName, left, right) + Intersect(newLeft, newRight) } } @@ -335,7 +335,7 @@ object HiveTypeCoercion { * - INT gets turned into DECIMAL(10, 0) * - LONG gets turned into DECIMAL(20, 0) * - FLOAT and DOUBLE - * 1. Union operation: + * 1. Union, Intersect and Except operations: * FLOAT gets turned into DECIMAL(7, 7), DOUBLE gets turned into DECIMAL(15, 15) (this is the * same as Hive) * 2. Other operation: @@ -362,47 +362,59 @@ object HiveTypeCoercion { DoubleType -> DecimalType(15, 15) ) - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - // fix decimal precision for union - case u @ Union(left, right) if u.childrenResolved && !u.resolved => - val castedInput = left.output.zip(right.output).map { - case (lhs, rhs) if lhs.dataType != rhs.dataType => - (lhs.dataType, rhs.dataType) match { - case (DecimalType.Fixed(p1, s1), DecimalType.Fixed(p2, s2)) => - // Union decimals with precision/scale p1/s2 and p2/s2 will be promoted to - // DecimalType(max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2)) - val fixedType = DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2), max(s1, s2)) - (Alias(Cast(lhs, fixedType), lhs.name)(), Alias(Cast(rhs, fixedType), rhs.name)()) - case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) => - (Alias(Cast(lhs, intTypeToFixed(t)), lhs.name)(), rhs) - case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) => - (lhs, Alias(Cast(rhs, intTypeToFixed(t)), rhs.name)()) - case (t, DecimalType.Fixed(p, s)) if floatTypeToFixed.contains(t) => - (Alias(Cast(lhs, floatTypeToFixed(t)), lhs.name)(), rhs) - case (DecimalType.Fixed(p, s), t) if floatTypeToFixed.contains(t) => - (lhs, Alias(Cast(rhs, floatTypeToFixed(t)), rhs.name)()) - case _ => (lhs, rhs) - } - case other => other - } + private def castDecimalPrecision( + left: LogicalPlan, + right: LogicalPlan): (LogicalPlan, LogicalPlan) = { + val castedInput = left.output.zip(right.output).map { + case (lhs, rhs) if lhs.dataType != rhs.dataType => + (lhs.dataType, rhs.dataType) match { + case (DecimalType.Fixed(p1, s1), DecimalType.Fixed(p2, s2)) => + // Decimals with precision/scale p1/s2 and p2/s2 will be promoted to + // DecimalType(max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2)) + val fixedType = DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2), max(s1, s2)) + (Alias(Cast(lhs, fixedType), lhs.name)(), Alias(Cast(rhs, fixedType), rhs.name)()) + case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) => + (Alias(Cast(lhs, intTypeToFixed(t)), lhs.name)(), rhs) + case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) => + (lhs, Alias(Cast(rhs, intTypeToFixed(t)), rhs.name)()) + case (t, DecimalType.Fixed(p, s)) if floatTypeToFixed.contains(t) => + (Alias(Cast(lhs, floatTypeToFixed(t)), lhs.name)(), rhs) + case (DecimalType.Fixed(p, s), 
t) if floatTypeToFixed.contains(t) => + (lhs, Alias(Cast(rhs, floatTypeToFixed(t)), rhs.name)()) + case _ => (lhs, rhs) + } + case other => other + } - val (castedLeft, castedRight) = castedInput.unzip + val (castedLeft, castedRight) = castedInput.unzip - val newLeft = - if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) { - Project(castedLeft, left) - } else { - left - } + val newLeft = + if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) { + Project(castedLeft, left) + } else { + left + } - val newRight = - if (castedRight.map(_.dataType) != right.output.map(_.dataType)) { - Project(castedRight, right) - } else { - right - } + val newRight = + if (castedRight.map(_.dataType) != right.output.map(_.dataType)) { + Project(castedRight, right) + } else { + right + } + (newLeft, newRight) + } + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + // fix decimal precision for union, intersect and except + case u @ Union(left, right) if u.childrenResolved && !u.resolved => + val (newLeft, newRight) = castDecimalPrecision(left, right) Union(newLeft, newRight) + case i @ Intersect(left, right) if i.childrenResolved && !i.resolved => + val (newLeft, newRight) = castDecimalPrecision(left, right) + Intersect(newLeft, newRight) + case e @ Except(left, right) if e.childrenResolved && !e.resolved => + val (newLeft, newRight) = castDecimalPrecision(left, right) + Except(newLeft, newRight) // fix decimal precision for expressions case q => q.transformExpressions { @@ -439,6 +451,12 @@ object HiveTypeCoercion { DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) ) + case Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + Cast( + Pmod(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), + DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + ) + // When we compare 2 decimal types with different precisions, cast them to the smallest // common precision. case b @ BinaryComparison(e1 @ DecimalType.Expression(p1, s1), @@ -672,20 +690,44 @@ object HiveTypeCoercion { } /** - * Casts types according to the expected input types for Expressions that have the trait - * [[ExpectsInputTypes]]. + * Casts types according to the expected input types for [[Expression]]s. */ object ImplicitTypeCasts extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - case e: ExpectsInputTypes if (e.inputTypes.nonEmpty) => + case b @ BinaryOperator(left, right) if left.dataType != right.dataType => + findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { commonType => + if (b.inputType.acceptsType(commonType)) { + // If the expression accepts the tightest common type, cast to that. + val newLeft = if (left.dataType == commonType) left else Cast(left, commonType) + val newRight = if (right.dataType == commonType) right else Cast(right, commonType) + b.withNewChildren(Seq(newLeft, newRight)) + } else { + // Otherwise, don't do anything with the expression. + b + } + }.getOrElse(b) // If there is no applicable conversion, leave expression unchanged. + + case e: ImplicitCastInputTypes if e.inputTypes.nonEmpty => val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) => // If we cannot do the implicit cast, just use the original input. 
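The comment in the hunk above spells out the promotion rule for two fixed decimals: keep the larger scale and enough integer digits for either operand. A small worked sketch of that formula, where FixedDecimal is only a stand-in for DecimalType.Fixed:

import scala.math.max

case class FixedDecimal(precision: Int, scale: Int)

object DecimalPromotion {
  // DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2), max(s1, s2))
  def promote(a: FixedDecimal, b: FixedDecimal): FixedDecimal = {
    val scale = max(a.scale, b.scale)
    val intDigits = max(a.precision - a.scale, b.precision - b.scale)
    FixedDecimal(intDigits + scale, scale)
  }
}

object DecimalPromotionDemo extends App {
  // DECIMAL(10, 2) vs DECIMAL(5, 4): 8 integer digits and 4 fractional digits are
  // both needed, so the common type is DECIMAL(12, 4).
  println(DecimalPromotion.promote(FixedDecimal(10, 2), FixedDecimal(5, 4)))
}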
implicitCast(in, expected).getOrElse(in) } e.withNewChildren(children) + + case e: ExpectsInputTypes if e.inputTypes.nonEmpty => + // Convert NullType into some specific target type for ExpectsInputTypes that don't do + // general implicit casting. + val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) => + if (in.dataType == NullType && !expected.acceptsType(NullType)) { + Literal.create(null, expected.defaultConcreteType) + } else { + in + } + } + e.withNewChildren(children) } /** @@ -702,27 +744,22 @@ object HiveTypeCoercion { @Nullable val ret: Expression = (inType, expectedType) match { // If the expected type is already a parent of the input type, no need to cast. - case _ if expectedType.isSameType(inType) => e + case _ if expectedType.acceptsType(inType) => e // Cast null type (usually from null literals) into target types case (NullType, target) => Cast(e, target.defaultConcreteType) - // If the function accepts any numeric type (i.e. the ADT `NumericType`) and the input is - // already a number, leave it as is. - case (_: NumericType, NumericType) => e - // If the function accepts any numeric type and the input is a string, we follow the hive // convention and cast that input into a double case (StringType, NumericType) => Cast(e, NumericType.defaultConcreteType) - // Implicit cast among numeric types + // Implicit cast among numeric types. When we reach here, input type is not acceptable. + // If input is a numeric type but not decimal, and we expect a decimal type, // cast the input to unlimited precision decimal. - case (_: NumericType, DecimalType) if !inType.isInstanceOf[DecimalType] => - Cast(e, DecimalType.Unlimited) + case (_: NumericType, DecimalType) => Cast(e, DecimalType.Unlimited) // For any other numeric types, implicitly cast to each other, e.g. long -> int, int -> long - case (_: NumericType, target: NumericType) if e.dataType != target => Cast(e, target) - case (_: NumericType, target: NumericType) => e + case (_: NumericType, target: NumericType) => Cast(e, target) // Implicit cast between date time types case (DateType, TimestampType) => Cast(e, TimestampType) @@ -736,15 +773,9 @@ object HiveTypeCoercion { case (StringType, BinaryType) => Cast(e, BinaryType) case (any, StringType) if any != StringType => Cast(e, StringType) - // Type collection. - // First see if we can find our input type in the type collection. If we can, then just - // use the current expression; otherwise, find the first one we can implicitly cast. - case (_, TypeCollection(types)) => - if (types.exists(_.isSameType(inType))) { - e - } else { - types.flatMap(implicitCast(e, _)).headOption.orNull - } + // When we reach here, input type is not acceptable for any types in this type collection, + // try to find the first one we can implicitly cast. 
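As rewritten above, implicitCast assumes the input was already rejected by acceptsType and then walks a fixed list of conversions: null literals take the target's default concrete type, strings become doubles when any numeric type is expected (the Hive convention), numerics cast among themselves, dates widen to timestamps, nearly anything casts to string, and a type collection is satisfied by the first alternative that works. The standalone model below condenses that ordering; the MiniType values and the accepts/defaultOf helpers are illustrative, not Catalyst's AbstractDataType API.

sealed trait MiniType
case object NullT extends MiniType
case object IntT extends MiniType
case object DoubleT extends MiniType
case object StringT extends MiniType
case object DateT extends MiniType
case object TimestampT extends MiniType
case object AnyNumeric extends MiniType                     // plays the role of the NumericType ADT
case class Collection(alternatives: Seq[MiniType]) extends MiniType

object ImplicitCastSketch {
  private def isNumeric(t: MiniType) = t == IntT || t == DoubleT

  private def accepts(expected: MiniType, in: MiniType): Boolean = expected match {
    case AnyNumeric       => isNumeric(in)
    case Collection(alts) => alts.exists(accepts(_, in))
    case other            => other == in
  }

  private def defaultOf(t: MiniType): MiniType = t match {
    case AnyNumeric       => DoubleT
    case Collection(alts) => defaultOf(alts.head)
    case concrete         => concrete
  }

  /** Returns the type the input would be cast to, or None when no implicit cast applies. */
  def implicitCast(in: MiniType, expected: MiniType): Option[MiniType] = (in, expected) match {
    case _ if accepts(expected, in)             => Some(in)            // already acceptable as-is
    case (NullT, t)                             => Some(defaultOf(t))  // null literal -> target default
    case (StringT, AnyNumeric)                  => Some(DoubleT)       // Hive convention
    case (s, t) if isNumeric(s) && isNumeric(t) => Some(t)             // numeric <-> numeric
    case (DateT, TimestampT)                    => Some(TimestampT)
    case (t, StringT) if t != StringT           => Some(StringT)
    case (_, Collection(alts))                  => alts.flatMap(implicitCast(in, _)).headOption
    case _                                      => None
  }
}

object ImplicitCastDemo extends App {
  import ImplicitCastSketch._
  println(implicitCast(StringT, AnyNumeric))                      // Some(DoubleT)
  println(implicitCast(IntT, Collection(Seq(StringT, DoubleT))))  // Some(StringT): first castable alternative
}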
+ case (_, TypeCollection(types)) => types.flatMap(implicitCast(e, _)).headOption.orNull // Else, just return the same input expression case _ => null diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index f2e579afe833a..03da45b09f928 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -17,9 +17,8 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.{errors, trees} -import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.errors import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.LeafNode import org.apache.spark.sql.catalyst.trees.TreeNode @@ -50,8 +49,7 @@ case class UnresolvedRelation( /** * Holds the name of an attribute that has yet to be resolved. */ -case class UnresolvedAttribute(nameParts: Seq[String]) - extends Attribute with trees.LeafNode[Expression] { +case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute with Unevaluable { def name: String = nameParts.map(n => if (n.contains(".")) s"`$n`" else n).mkString(".") @@ -67,10 +65,6 @@ case class UnresolvedAttribute(nameParts: Seq[String]) override def withQualifiers(newQualifiers: Seq[String]): UnresolvedAttribute = this override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute.quoted(newName) - // Unresolved attributes are transient at compile time and don't get evaluated during execution. - override def eval(input: InternalRow = null): Any = - throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") - override def toString: String = s"'$name" } @@ -79,16 +73,17 @@ object UnresolvedAttribute { def quoted(name: String): UnresolvedAttribute = new UnresolvedAttribute(Seq(name)) } -case class UnresolvedFunction(name: String, children: Seq[Expression]) extends Expression { +case class UnresolvedFunction( + name: String, + children: Seq[Expression], + isDistinct: Boolean) + extends Expression with Unevaluable { + override def dataType: DataType = throw new UnresolvedException(this, "dataType") override def foldable: Boolean = throw new UnresolvedException(this, "foldable") override def nullable: Boolean = throw new UnresolvedException(this, "nullable") override lazy val resolved = false - // Unresolved functions are transient at compile time and don't get evaluated during execution. - override def eval(input: InternalRow = null): Any = - throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") - override def toString: String = s"'$name(${children.mkString(",")})" } @@ -96,8 +91,7 @@ case class UnresolvedFunction(name: String, children: Seq[Expression]) extends E * Represents all of the input attributes to a given relational operator, for example in * "SELECT * FROM ...". A [[Star]] gets automatically expanded during analysis. 
*/ -trait Star extends NamedExpression with trees.LeafNode[Expression] { - self: Product => +abstract class Star extends LeafExpression with NamedExpression { override def name: String = throw new UnresolvedException(this, "name") override def exprId: ExprId = throw new UnresolvedException(this, "exprId") @@ -107,10 +101,6 @@ trait Star extends NamedExpression with trees.LeafNode[Expression] { override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute") override lazy val resolved = false - // Star gets expanded at runtime so we never evaluate a Star. - override def eval(input: InternalRow = null): Any = - throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") - def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] } @@ -122,7 +112,7 @@ trait Star extends NamedExpression with trees.LeafNode[Expression] { * @param table an optional table that should be the target of the expansion. If omitted all * tables' columns are produced. */ -case class UnresolvedStar(table: Option[String]) extends Star { +case class UnresolvedStar(table: Option[String]) extends Star with Unevaluable { override def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] = { val expandedAttributes: Seq[Attribute] = table match { @@ -151,7 +141,7 @@ case class UnresolvedStar(table: Option[String]) extends Star { * @param names the names to be associated with each output of computing [[child]]. */ case class MultiAlias(child: Expression, names: Seq[String]) - extends NamedExpression with trees.UnaryNode[Expression] { + extends UnaryExpression with NamedExpression with CodegenFallback { override def name: String = throw new UnresolvedException(this, "name") @@ -167,9 +157,6 @@ case class MultiAlias(child: Expression, names: Seq[String]) override lazy val resolved = false - override def eval(input: InternalRow = null): Any = - throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") - override def toString: String = s"$child AS $names" } @@ -180,7 +167,7 @@ case class MultiAlias(child: Expression, names: Seq[String]) * * @param expressions Expressions to expand. */ -case class ResolvedStar(expressions: Seq[NamedExpression]) extends Star { +case class ResolvedStar(expressions: Seq[NamedExpression]) extends Star with Unevaluable { override def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] = expressions override def toString: String = expressions.mkString("ResolvedStar(", ", ", ")") } @@ -194,24 +181,21 @@ case class ResolvedStar(expressions: Seq[NamedExpression]) extends Star { * can be key of Map, index of Array, field name of Struct. */ case class UnresolvedExtractValue(child: Expression, extraction: Expression) - extends UnaryExpression { + extends UnaryExpression with Unevaluable { override def dataType: DataType = throw new UnresolvedException(this, "dataType") override def foldable: Boolean = throw new UnresolvedException(this, "foldable") override def nullable: Boolean = throw new UnresolvedException(this, "nullable") override lazy val resolved = false - override def eval(input: InternalRow = null): Any = - throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") - override def toString: String = s"$child[$extraction]" } /** * Holds the expression that has yet to be aliased. 
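Every placeholder in this file used to override eval with an identical throwing body; they now mix in Unevaluable (added later in this same change, in Expression.scala) and inherit the failure once. A minimal sketch of that pattern, using a hypothetical Mini* hierarchy rather than the real Catalyst classes:

abstract class MiniExpression {
  def children: Seq[MiniExpression]
  def eval(input: Map[String, Any]): Any
}

// Centralizes the "this should never be evaluated" failure.
trait MiniUnevaluable extends MiniExpression {
  final override def eval(input: Map[String, Any]): Any =
    throw new UnsupportedOperationException(s"Cannot evaluate expression: $this")
}

// Exists only between parsing and analysis; it is never evaluated directly.
case class MiniUnresolvedAttribute(name: String) extends MiniExpression with MiniUnevaluable {
  override def children: Seq[MiniExpression] = Nil
  override def toString: String = s"'$name"
}

object UnevaluableDemo extends App {
  val attr = MiniUnresolvedAttribute("age")
  println(attr)                                   // 'age
  try attr.eval(Map.empty)
  catch { case e: UnsupportedOperationException => println(e.getMessage) }
}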
*/ -case class UnresolvedAlias(child: Expression) extends NamedExpression - with trees.UnaryNode[Expression] { +case class UnresolvedAlias(child: Expression) + extends UnaryExpression with NamedExpression with Unevaluable { override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute") override def qualifiers: Seq[String] = throw new UnresolvedException(this, "qualifiers") @@ -221,7 +205,4 @@ case class UnresolvedAlias(child: Expression) extends NamedExpression override def name: String = throw new UnresolvedException(this, "name") override lazy val resolved = false - - override def eval(input: InternalRow = null): Any = - throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 3f0d7b803125f..6aa4930cb8587 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -21,7 +21,6 @@ import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors.attachTree import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} -import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.types._ /** @@ -30,11 +29,29 @@ import org.apache.spark.sql.types._ * the layout of intermediate tuples, BindReferences should be run after all such transformations. */ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) - extends NamedExpression with trees.LeafNode[Expression] { + extends LeafExpression with NamedExpression { - override def toString: String = s"input[$ordinal]" + override def toString: String = s"input[$ordinal, $dataType]" - override def eval(input: InternalRow): Any = input(ordinal) + // Use special getter for primitive types (for UnsafeRow) + override def eval(input: InternalRow): Any = { + if (input.isNullAt(ordinal)) { + null + } else { + dataType match { + case BooleanType => input.getBoolean(ordinal) + case ByteType => input.getByte(ordinal) + case ShortType => input.getShort(ordinal) + case IntegerType | DateType => input.getInt(ordinal) + case LongType | TimestampType => input.getLong(ordinal) + case FloatType => input.getFloat(ordinal) + case DoubleType => input.getDouble(ordinal) + case StringType => input.getUTF8String(ordinal) + case BinaryType => input.getBinary(ordinal) + case _ => input.get(ordinal) + } + } + } override def name: String = s"i[$ordinal]" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index ab02addfb4d25..e66cd828481bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -18,16 +18,16 @@ package org.apache.spark.sql.catalyst.expressions import java.math.{BigDecimal => JavaBigDecimal} -import java.sql.{Date, Timestamp} -import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import 
org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{Interval, UTF8String} +import scala.collection.mutable + object Cast { @@ -56,7 +56,6 @@ object Cast { case (_, DateType) => true case (StringType, IntervalType) => true - case (IntervalType, StringType) => true case (StringType, _: NumericType) => true case (BooleanType, _: NumericType) => true @@ -107,7 +106,8 @@ object Cast { } /** Cast the child expression to the target data type. */ -case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with Logging { +case class Cast(child: Expression, dataType: DataType) + extends UnaryExpression with CodegenFallback { override def checkInputDataTypes(): TypeCheckResult = { if (Cast.canCast(child.dataType, dataType)) { @@ -167,17 +167,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // TimestampConverter private[this] def castToTimestamp(from: DataType): Any => Any = from match { case StringType => - buildCast[UTF8String](_, utfs => { - // Throw away extra if more than 9 decimal places - val s = utfs.toString - val periodIdx = s.indexOf(".") - var n = s - if (periodIdx != -1 && n.length() - periodIdx > 9) { - n = n.substring(0, periodIdx + 10) - } - try DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf(n)) - catch { case _: java.lang.IllegalArgumentException => null } - }) + buildCast[UTF8String](_, utfs => DateTimeUtils.stringToTimestamp(utfs).orNull) case BooleanType => buildCast[Boolean](_, b => if (b) 1L else 0) case LongType => @@ -220,10 +210,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // DateConverter private[this] def castToDate(from: DataType): Any => Any = from match { case StringType => - buildCast[UTF8String](_, s => - try DateTimeUtils.fromJavaDate(Date.valueOf(s.toString)) - catch { case _: java.lang.IllegalArgumentException => null } - ) + buildCast[UTF8String](_, s => DateTimeUtils.stringToDate(s).orNull) case TimestampType => // throw valid precision more than seconds, according to Hive. // Timestamp.nanos is in 0 to 999,999,999, no more than a second. @@ -433,51 +420,506 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w protected override def nullSafeEval(input: Any): Any = cast(input) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - // TODO: Add support for more data types. 
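The rewritten genCode above builds a cast in two steps: nullSafeCastFunction picks a CastFunction for the (from, to) pair, and that function turns three variable names (child primitive, result primitive, result null flag) into a Java snippet. A toy version of that shape follows; the types and the emitted Java fragments are deliberately simplified and are not Spark's exact output.

object CastCodegenSketch {
  type CastFunction = (String, String, String) => String

  sealed trait MiniType
  case object IntT extends MiniType
  case object StringT extends MiniType
  case object BooleanT extends MiniType

  // Dispatch on the (from, to) pair and return a snippet generator.
  def nullSafeCastFunction(from: MiniType, to: MiniType): CastFunction = (from, to) match {
    case (f, t) if f == t =>
      (c, evPrim, _) => s"$evPrim = $c;"
    case (_, StringT) =>
      (c, evPrim, _) => s"$evPrim = String.valueOf($c);"
    case (StringT, IntT) =>
      (c, evPrim, evNull) =>
        s"""try { $evPrim = Integer.valueOf($c); }
           |catch (java.lang.NumberFormatException e) { $evNull = true; }""".stripMargin
    case (IntT, BooleanT) =>
      (c, evPrim, _) => s"$evPrim = $c != 0;"
    case _ =>
      (_, _, evNull) => s"$evNull = true;"   // unsupported cast yields null
  }
}

object CastCodegenDemo extends App {
  import CastCodegenSketch._
  val cast = nullSafeCastFunction(StringT, IntT)
  println(cast("childValue", "result", "resultIsNull"))
}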
- (child.dataType, dataType) match { + val eval = child.gen(ctx) + val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx) + eval.code + + castCode(ctx, eval.primitive, eval.isNull, ev.primitive, ev.isNull, dataType, nullSafeCast) + } + + // three function arguments are: child.primitive, result.primitive and result.isNull + // it returns the code snippets to be put in null safe evaluation region + private[this] type CastFunction = (String, String, String) => String + + private[this] def nullSafeCastFunction( + from: DataType, + to: DataType, + ctx: CodeGenContext): CastFunction = to match { + + case _ if from == NullType => (c, evPrim, evNull) => s"$evNull = true;" + case _ if to == from => (c, evPrim, evNull) => s"$evPrim = $c;" + case StringType => castToStringCode(from, ctx) + case BinaryType => castToBinaryCode(from) + case DateType => castToDateCode(from, ctx) + case decimal: DecimalType => castToDecimalCode(from, decimal) + case TimestampType => castToTimestampCode(from, ctx) + case IntervalType => castToIntervalCode(from) + case BooleanType => castToBooleanCode(from) + case ByteType => castToByteCode(from) + case ShortType => castToShortCode(from) + case IntegerType => castToIntCode(from) + case FloatType => castToFloatCode(from) + case LongType => castToLongCode(from) + case DoubleType => castToDoubleCode(from) + + case array: ArrayType => castArrayCode(from.asInstanceOf[ArrayType], array, ctx) + case map: MapType => castMapCode(from.asInstanceOf[MapType], map, ctx) + case struct: StructType => castStructCode(from.asInstanceOf[StructType], struct, ctx) + } + + // Since we need to cast child expressions recursively inside ComplexTypes, such as Map's + // Key and Value, Struct's field, we need to name out all the variable names involved in a cast. 
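As the comment above notes, complex types force the generator to name every intermediate explicitly so that nested element, key, value and field casts can reuse the same wrapper without clashing. A sketch of that freshName-plus-null-safe-region discipline; the variable names and Java text here are illustrative only.

object CastCodeSketch {
  private var counter = 0
  def freshName(prefix: String): String = { counter += 1; s"$prefix$counter" }

  /** Wraps a cast snippet in the usual "only cast when the input is non-null" region. */
  def castCode(childPrim: String, childNull: String,
               resultPrim: String, resultNull: String,
               javaType: String, defaultValue: String,
               cast: (String, String, String) => String): String =
    s"""boolean $resultNull = $childNull;
       |$javaType $resultPrim = $defaultValue;
       |if (!$childNull) {
       |  ${cast(childPrim, resultPrim, resultNull)}
       |}""".stripMargin
}

object CastCodeDemo extends App {
  import CastCodeSketch._
  val elemPrim = freshName("tePrim")
  val elemNull = freshName("teNull")
  // Generate the element-level cast for one array element ("fe" = from element).
  println(castCode("fePrim1", "feNull1", elemPrim, elemNull, "long", "-1L",
    (c, p, _) => s"$p = (long) $c;"))
}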
+ private[this] def castCode(ctx: CodeGenContext, childPrim: String, childNull: String, + resultPrim: String, resultNull: String, resultType: DataType, cast: CastFunction): String = { + s""" + boolean $resultNull = $childNull; + ${ctx.javaType(resultType)} $resultPrim = ${ctx.defaultValue(resultType)}; + if (!${childNull}) { + ${cast(childPrim, resultPrim, resultNull)} + } + """ + } + + private[this] def castToStringCode(from: DataType, ctx: CodeGenContext): CastFunction = { + from match { + case BinaryType => + (c, evPrim, evNull) => s"$evPrim = UTF8String.fromBytes($c);" + case DateType => + (c, evPrim, evNull) => s"""$evPrim = UTF8String.fromString( + org.apache.spark.sql.catalyst.util.DateTimeUtils.dateToString($c));""" + case TimestampType => + (c, evPrim, evNull) => s"""$evPrim = UTF8String.fromString( + org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c));""" + case _ => + (c, evPrim, evNull) => s"$evPrim = UTF8String.fromString(String.valueOf($c));" + } + } - case (BinaryType, StringType) => - defineCodeGen (ctx, ev, c => - s"${ctx.stringType}.fromBytes($c)") + private[this] def castToBinaryCode(from: DataType): CastFunction = from match { + case StringType => + (c, evPrim, evNull) => s"$evPrim = $c.getBytes();" + } + + private[this] def castToDateCode( + from: DataType, + ctx: CodeGenContext): CastFunction = from match { + case StringType => + val intOpt = ctx.freshName("intOpt") + (c, evPrim, evNull) => s""" + scala.Option $intOpt = + org.apache.spark.sql.catalyst.util.DateTimeUtils.stringToDate($c); + if ($intOpt.isDefined()) { + $evPrim = ((Integer) $intOpt.get()).intValue(); + } else { + $evNull = true; + } + """ + case TimestampType => + (c, evPrim, evNull) => + s"$evPrim = org.apache.spark.sql.catalyst.util.DateTimeUtils.millisToDays($c / 1000L);"; + case _ => + (c, evPrim, evNull) => s"$evNull = true;" + } - case (DateType, StringType) => - defineCodeGen(ctx, ev, c => - s"""${ctx.stringType}.fromString( - org.apache.spark.sql.catalyst.util.DateTimeUtils.dateToString($c))""") + private[this] def changePrecision(d: String, decimalType: DecimalType, + evPrim: String, evNull: String): String = { + decimalType match { + case DecimalType.Unlimited => + s"$evPrim = $d;" + case DecimalType.Fixed(precision, scale) => + s""" + if ($d.changePrecision($precision, $scale)) { + $evPrim = $d; + } else { + $evNull = true; + } + """ + } + } - case (TimestampType, StringType) => - defineCodeGen(ctx, ev, c => - s"""${ctx.stringType}.fromString( - org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c))""") + private[this] def castToDecimalCode(from: DataType, target: DecimalType): CastFunction = { + from match { + case StringType => + (c, evPrim, evNull) => + s""" + try { + org.apache.spark.sql.types.Decimal tmpDecimal = + new org.apache.spark.sql.types.Decimal().set( + new scala.math.BigDecimal( + new java.math.BigDecimal($c.toString()))); + ${changePrecision("tmpDecimal", target, evPrim, evNull)} + } catch (java.lang.NumberFormatException e) { + $evNull = true; + } + """ + case BooleanType => + (c, evPrim, evNull) => + s""" + org.apache.spark.sql.types.Decimal tmpDecimal = null; + if ($c) { + tmpDecimal = new org.apache.spark.sql.types.Decimal().set(1); + } else { + tmpDecimal = new org.apache.spark.sql.types.Decimal().set(0); + } + ${changePrecision("tmpDecimal", target, evPrim, evNull)} + """ + case DateType => + // date can't cast to decimal in Hive + (c, evPrim, evNull) => s"$evNull = true;" + case TimestampType => + // Note that we lose precision here. 
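The changePrecision helper above emits a runtime check: keep the value if it still fits the target precision after rescaling, otherwise mark the result null. The same decision is modelled below with plain java.math.BigDecimal instead of Spark's Decimal, purely so the sketch stays self-contained.

import java.math.{BigDecimal => JBigDecimal, RoundingMode}

object ChangePrecisionSketch {
  // Round to the target scale; give up (null in the generated code) if the result
  // then needs more significant digits than the target precision allows.
  def fitTo(value: JBigDecimal, precision: Int, scale: Int): Option[JBigDecimal] = {
    val rescaled = value.setScale(scale, RoundingMode.HALF_UP)
    if (rescaled.precision() <= precision) Some(rescaled) else None
  }
}

object ChangePrecisionDemo extends App {
  import ChangePrecisionSketch._
  println(fitTo(new JBigDecimal("123.456"), 5, 2))   // Some(123.46) -- fits DECIMAL(5, 2)
  println(fitTo(new JBigDecimal("12345.6"), 5, 2))   // None        -- overflows DECIMAL(5, 2)
}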
+ (c, evPrim, evNull) => + s""" + org.apache.spark.sql.types.Decimal tmpDecimal = + new org.apache.spark.sql.types.Decimal().set( + scala.math.BigDecimal.valueOf(${timestampToDoubleCode(c)})); + ${changePrecision("tmpDecimal", target, evPrim, evNull)} + """ + case DecimalType() => + (c, evPrim, evNull) => + s""" + org.apache.spark.sql.types.Decimal tmpDecimal = $c.clone(); + ${changePrecision("tmpDecimal", target, evPrim, evNull)} + """ + case LongType => + (c, evPrim, evNull) => + s""" + org.apache.spark.sql.types.Decimal tmpDecimal = + new org.apache.spark.sql.types.Decimal().set($c); + ${changePrecision("tmpDecimal", target, evPrim, evNull)} + """ + case x: NumericType => + // All other numeric types can be represented precisely as Doubles + (c, evPrim, evNull) => + s""" + try { + org.apache.spark.sql.types.Decimal tmpDecimal = + new org.apache.spark.sql.types.Decimal().set( + scala.math.BigDecimal.valueOf((double) $c)); + ${changePrecision("tmpDecimal", target, evPrim, evNull)} + } catch (java.lang.NumberFormatException e) { + $evNull = true; + } + """ + } + } + + private[this] def castToTimestampCode( + from: DataType, + ctx: CodeGenContext): CastFunction = from match { + case StringType => + val longOpt = ctx.freshName("longOpt") + (c, evPrim, evNull) => + s""" + scala.Option $longOpt = + org.apache.spark.sql.catalyst.util.DateTimeUtils.stringToTimestamp($c); + if ($longOpt.isDefined()) { + $evPrim = ((Long) $longOpt.get()).longValue(); + } else { + $evNull = true; + } + """ + case BooleanType => + (c, evPrim, evNull) => s"$evPrim = $c ? 1L : 0;" + case _: IntegralType => + (c, evPrim, evNull) => s"$evPrim = ${longToTimeStampCode(c)};" + case DateType => + (c, evPrim, evNull) => + s"$evPrim = org.apache.spark.sql.catalyst.util.DateTimeUtils.daysToMillis($c) * 1000;" + case DecimalType() => + (c, evPrim, evNull) => s"$evPrim = ${decimalToTimestampCode(c)};" + case DoubleType => + (c, evPrim, evNull) => + s""" + if (Double.isNaN($c) || Double.isInfinite($c)) { + $evNull = true; + } else { + $evPrim = (long)($c * 1000000L); + } + """ + case FloatType => + (c, evPrim, evNull) => + s""" + if (Float.isNaN($c) || Float.isInfinite($c)) { + $evNull = true; + } else { + $evPrim = (long)($c * 1000000L); + } + """ + } + + private[this] def castToIntervalCode(from: DataType): CastFunction = from match { + case StringType => + (c, evPrim, evNull) => + s"$evPrim = org.apache.spark.unsafe.types.Interval.fromString($c.toString());" + } + + private[this] def decimalToTimestampCode(d: String): String = + s"($d.toBigDecimal().bigDecimal().multiply(new java.math.BigDecimal(1000000L))).longValue()" + private[this] def longToTimeStampCode(l: String): String = s"$l * 1000L" + private[this] def timestampToIntegerCode(ts: String): String = + s"java.lang.Math.floor((double) $ts / 1000000L)" + private[this] def timestampToDoubleCode(ts: String): String = s"$ts / 1000000.0" + + private[this] def castToBooleanCode(from: DataType): CastFunction = from match { + case StringType => + (c, evPrim, evNull) => s"$evPrim = $c.numBytes() != 0;" + case TimestampType => + (c, evPrim, evNull) => s"$evPrim = $c != 0;" + case DateType => + // Hive would return null when cast from date to boolean + (c, evPrim, evNull) => s"$evNull = true;" + case DecimalType() => + (c, evPrim, evNull) => s"$evPrim = !$c.isZero();" + case n: NumericType => + (c, evPrim, evNull) => s"$evPrim = $c != 0;" + } + + private[this] def castToByteCode(from: DataType): CastFunction = from match { + case StringType => + (c, evPrim, evNull) => + s""" + try 
{ + $evPrim = Byte.valueOf($c.toString()); + } catch (java.lang.NumberFormatException e) { + $evNull = true; + } + """ + case BooleanType => + (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;" + case DateType => + (c, evPrim, evNull) => s"$evNull = true;" + case TimestampType => + (c, evPrim, evNull) => s"$evPrim = (byte) ${timestampToIntegerCode(c)};" + case DecimalType() => + (c, evPrim, evNull) => s"$evPrim = $c.toByte();" + case x: NumericType => + (c, evPrim, evNull) => s"$evPrim = (byte) $c;" + } - case (_, StringType) => - defineCodeGen(ctx, ev, c => s"${ctx.stringType}.fromString(String.valueOf($c))") + private[this] def castToShortCode(from: DataType): CastFunction = from match { + case StringType => + (c, evPrim, evNull) => + s""" + try { + $evPrim = Short.valueOf($c.toString()); + } catch (java.lang.NumberFormatException e) { + $evNull = true; + } + """ + case BooleanType => + (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;" + case DateType => + (c, evPrim, evNull) => s"$evNull = true;" + case TimestampType => + (c, evPrim, evNull) => s"$evPrim = (short) ${timestampToIntegerCode(c)};" + case DecimalType() => + (c, evPrim, evNull) => s"$evPrim = $c.toShort();" + case x: NumericType => + (c, evPrim, evNull) => s"$evPrim = (short) $c;" + } - case (StringType, IntervalType) => - defineCodeGen(ctx, ev, c => - s"org.apache.spark.unsafe.types.Interval.fromString($c.toString())") + private[this] def castToIntCode(from: DataType): CastFunction = from match { + case StringType => + (c, evPrim, evNull) => + s""" + try { + $evPrim = Integer.valueOf($c.toString()); + } catch (java.lang.NumberFormatException e) { + $evNull = true; + } + """ + case BooleanType => + (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;" + case DateType => + (c, evPrim, evNull) => s"$evNull = true;" + case TimestampType => + (c, evPrim, evNull) => s"$evPrim = (int) ${timestampToIntegerCode(c)};" + case DecimalType() => + (c, evPrim, evNull) => s"$evPrim = $c.toInt();" + case x: NumericType => + (c, evPrim, evNull) => s"$evPrim = (int) $c;" + } - // fallback for DecimalType, this must be before other numeric types - case (_, dt: DecimalType) => - super.genCode(ctx, ev) + private[this] def castToLongCode(from: DataType): CastFunction = from match { + case StringType => + (c, evPrim, evNull) => + s""" + try { + $evPrim = Long.valueOf($c.toString()); + } catch (java.lang.NumberFormatException e) { + $evNull = true; + } + """ + case BooleanType => + (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;" + case DateType => + (c, evPrim, evNull) => s"$evNull = true;" + case TimestampType => + (c, evPrim, evNull) => s"$evPrim = (long) ${timestampToIntegerCode(c)};" + case DecimalType() => + (c, evPrim, evNull) => s"$evPrim = $c.toLong();" + case x: NumericType => + (c, evPrim, evNull) => s"$evPrim = (long) $c;" + } - case (BooleanType, dt: NumericType) => - defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c ? 1 : 0)") + private[this] def castToFloatCode(from: DataType): CastFunction = from match { + case StringType => + (c, evPrim, evNull) => + s""" + try { + $evPrim = Float.valueOf($c.toString()); + } catch (java.lang.NumberFormatException e) { + $evNull = true; + } + """ + case BooleanType => + (c, evPrim, evNull) => s"$evPrim = $c ? 
1 : 0;" + case DateType => + (c, evPrim, evNull) => s"$evNull = true;" + case TimestampType => + (c, evPrim, evNull) => s"$evPrim = (float) (${timestampToDoubleCode(c)});" + case DecimalType() => + (c, evPrim, evNull) => s"$evPrim = $c.toFloat();" + case x: NumericType => + (c, evPrim, evNull) => s"$evPrim = (float) $c;" + } - case (dt: DecimalType, BooleanType) => - defineCodeGen(ctx, ev, c => s"!$c.isZero()") + private[this] def castToDoubleCode(from: DataType): CastFunction = from match { + case StringType => + (c, evPrim, evNull) => + s""" + try { + $evPrim = Double.valueOf($c.toString()); + } catch (java.lang.NumberFormatException e) { + $evNull = true; + } + """ + case BooleanType => + (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;" + case DateType => + (c, evPrim, evNull) => s"$evNull = true;" + case TimestampType => + (c, evPrim, evNull) => s"$evPrim = ${timestampToDoubleCode(c)};" + case DecimalType() => + (c, evPrim, evNull) => s"$evPrim = $c.toDouble();" + case x: NumericType => + (c, evPrim, evNull) => s"$evPrim = (double) $c;" + } - case (dt: NumericType, BooleanType) => - defineCodeGen(ctx, ev, c => s"$c != 0") + private[this] def castArrayCode( + from: ArrayType, to: ArrayType, ctx: CodeGenContext): CastFunction = { + val elementCast = nullSafeCastFunction(from.elementType, to.elementType, ctx) + + val arraySeqClass = classOf[mutable.ArraySeq[Any]].getName + val fromElementNull = ctx.freshName("feNull") + val fromElementPrim = ctx.freshName("fePrim") + val toElementNull = ctx.freshName("teNull") + val toElementPrim = ctx.freshName("tePrim") + val size = ctx.freshName("n") + val j = ctx.freshName("j") + val result = ctx.freshName("result") + + (c, evPrim, evNull) => + s""" + final int $size = $c.size(); + final $arraySeqClass $result = new $arraySeqClass($size); + for (int $j = 0; $j < $size; $j ++) { + if ($c.apply($j) == null) { + $result.update($j, null); + } else { + boolean $fromElementNull = false; + ${ctx.javaType(from.elementType)} $fromElementPrim = + (${ctx.boxedType(from.elementType)}) $c.apply($j); + ${castCode(ctx, fromElementPrim, + fromElementNull, toElementPrim, toElementNull, to.elementType, elementCast)} + if ($toElementNull) { + $result.update($j, null); + } else { + $result.update($j, $toElementPrim); + } + } + } + $evPrim = $result; + """ + } - case (_: DecimalType, dt: NumericType) => - defineCodeGen(ctx, ev, c => s"($c).to${ctx.primitiveTypeName(dt)}()") + private[this] def castMapCode(from: MapType, to: MapType, ctx: CodeGenContext): CastFunction = { + val keyCast = nullSafeCastFunction(from.keyType, to.keyType, ctx) + val valueCast = nullSafeCastFunction(from.valueType, to.valueType, ctx) + + val hashMapClass = classOf[mutable.HashMap[Any, Any]].getName + val fromKeyPrim = ctx.freshName("fkp") + val fromKeyNull = ctx.freshName("fkn") + val fromValuePrim = ctx.freshName("fvp") + val fromValueNull = ctx.freshName("fvn") + val toKeyPrim = ctx.freshName("tkp") + val toKeyNull = ctx.freshName("tkn") + val toValuePrim = ctx.freshName("tvp") + val toValueNull = ctx.freshName("tvn") + val result = ctx.freshName("result") + + (c, evPrim, evNull) => + s""" + final $hashMapClass $result = new $hashMapClass(); + scala.collection.Iterator iter = $c.iterator(); + while (iter.hasNext()) { + scala.Tuple2 kv = (scala.Tuple2) iter.next(); + boolean $fromKeyNull = false; + ${ctx.javaType(from.keyType)} $fromKeyPrim = + (${ctx.boxedType(from.keyType)}) kv._1(); + ${castCode(ctx, fromKeyPrim, + fromKeyNull, toKeyPrim, toKeyNull, to.keyType, keyCast)} + + boolean 
$fromValueNull = kv._2() == null; + if ($fromValueNull) { + $result.put($toKeyPrim, null); + } else { + ${ctx.javaType(from.valueType)} $fromValuePrim = + (${ctx.boxedType(from.valueType)}) kv._2(); + ${castCode(ctx, fromValuePrim, + fromValueNull, toValuePrim, toValueNull, to.valueType, valueCast)} + if ($toValueNull) { + $result.put($toKeyPrim, null); + } else { + $result.put($toKeyPrim, $toValuePrim); + } + } + } + $evPrim = $result; + """ + } - case (_: NumericType, dt: NumericType) => - defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c)") + private[this] def castStructCode( + from: StructType, to: StructType, ctx: CodeGenContext): CastFunction = { - case other => - super.genCode(ctx, ev) + val fieldsCasts = from.fields.zip(to.fields).map { + case (fromField, toField) => nullSafeCastFunction(fromField.dataType, toField.dataType, ctx) } + val rowClass = classOf[GenericMutableRow].getName + val result = ctx.freshName("result") + val tmpRow = ctx.freshName("tmpRow") + + val fieldsEvalCode = fieldsCasts.zipWithIndex.map { case (cast, i) => { + val fromFieldPrim = ctx.freshName("ffp") + val fromFieldNull = ctx.freshName("ffn") + val toFieldPrim = ctx.freshName("tfp") + val toFieldNull = ctx.freshName("tfn") + val fromType = ctx.javaType(from.fields(i).dataType) + s""" + boolean $fromFieldNull = $tmpRow.isNullAt($i); + if ($fromFieldNull) { + $result.setNullAt($i); + } else { + $fromType $fromFieldPrim = + ${ctx.getColumn(tmpRow, from.fields(i).dataType, i)}; + ${castCode(ctx, fromFieldPrim, + fromFieldNull, toFieldPrim, toFieldNull, to.fields(i).dataType, cast)} + if ($toFieldNull) { + $result.setNullAt($i); + } else { + ${ctx.setColumn(result, to.fields(i).dataType, i, toFieldPrim)}; + } + } + """ + } + }.mkString("\n") + + (c, evPrim, evNull) => + s""" + final $rowClass $result = new $rowClass(${fieldsCasts.size}); + final InternalRow $tmpRow = $c; + $fieldsEvalCode + $evPrim = $result.copy(); + """ } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala index 3eb0eb195c80d..abe6457747550 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala @@ -19,12 +19,17 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.types.AbstractDataType - +import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion.ImplicitTypeCasts /** * An trait that gets mixin to define the expected input types of an expression. + * + * This trait is typically used by operator expressions (e.g. [[Add]], [[Subtract]]) to define + * expected input types without any implicit casting. + * + * Most function expressions (e.g. [[Substring]] should extends [[ImplicitCastInputTypes]]) instead. */ -trait ExpectsInputTypes { self: Expression => +trait ExpectsInputTypes extends Expression { /** * Expected input types from child expressions. 
The i-th position in the returned seq indicates @@ -40,7 +45,7 @@ trait ExpectsInputTypes { self: Expression => val mismatches = children.zip(inputTypes).zipWithIndex.collect { case ((child, expected), idx) if !expected.acceptsType(child.dataType) => s"argument ${idx + 1} is expected to be of type ${expected.simpleString}, " + - s"however, '${child.prettyString}' is of type ${child.dataType.simpleString}." + s"however, '${child.prettyString}' is of type ${child.dataType.simpleString}." } if (mismatches.isEmpty) { @@ -50,3 +55,11 @@ trait ExpectsInputTypes { self: Expression => } } } + + +/** + * A mixin for the analyzer to perform implicit type casting using [[ImplicitTypeCasts]]. + */ +trait ImplicitCastInputTypes extends ExpectsInputTypes { + // No other methods +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 54ec10444c4f3..29ae47e842ddb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -19,21 +19,36 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute} -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} -import org.apache.spark.sql.catalyst.trees +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types._ +//////////////////////////////////////////////////////////////////////////////////////////////////// +// This file defines the basic expression abstract classes in Catalyst. +//////////////////////////////////////////////////////////////////////////////////////////////////// /** + * An expression in Catalyst. + * * If an expression wants to be exposed in the function registry (so users can call it with * "name(arguments...)", the concrete implementation must be a case class whose constructor - * arguments are all Expressions types. + * arguments are all Expressions types. See [[Substring]] for an example. + * + * There are a few important traits: + * + * - [[Nondeterministic]]: an expression that is not deterministic. + * - [[Unevaluable]]: an expression that is not supposed to be evaluated. + * - [[CodegenFallback]]: an expression that does not have code gen implemented and falls back to + * interpreted mode. + * + * - [[LeafExpression]]: an expression that has no child. + * - [[UnaryExpression]]: an expression that has one child. + * - [[BinaryExpression]]: an expression that has two children. + * - [[BinaryOperator]]: a special case of [[BinaryExpression]] that requires two children to have + * the same output data type. * - * See [[Substring]] for an example. */ abstract class Expression extends TreeNode[Expression] { - self: Product => /** * Returns true when an expression is a candidate for static evaluation before the query is @@ -49,10 +64,18 @@ abstract class Expression extends TreeNode[Expression] { def foldable: Boolean = false /** - * Returns true when the current expression always return the same result for fixed input values. + * Returns true when the current expression always return the same result for fixed inputs from + * children. 
+ * + * Note that this means that an expression should be considered as non-deterministic if: + * - if it relies on some mutable internal state, or + * - if it relies on some implicit input that is not part of the children expression list. + * - if it has non-deterministic child or children. + * + * An example would be `SparkPartitionID` that relies on the partition id returned by TaskContext. + * By default leaf expressions are deterministic as Nil.forall(_.deterministic) returns true. */ - // TODO: Need to define explicit input values vs implicit input values. - def deterministic: Boolean = true + def deterministic: Boolean = children.forall(_.deterministic) def nullable: Boolean @@ -73,7 +96,8 @@ abstract class Expression extends TreeNode[Expression] { val primitive = ctx.freshName("primitive") val ve = GeneratedExpressionCode("", isNull, primitive) ve.code = genCode(ctx, ve) - ve + // Add `this` in the comment. + ve.copy(s"/* $this */\n" + ve.code) } /** @@ -85,19 +109,7 @@ abstract class Expression extends TreeNode[Expression] { * @param ev an [[GeneratedExpressionCode]] with unique terms. * @return Java source code */ - protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - ctx.references += this - val objectTerm = ctx.freshName("obj") - s""" - /* expression: ${this} */ - Object $objectTerm = expressions[${ctx.references.size - 1}].eval(i); - boolean ${ev.isNull} = $objectTerm == null; - ${ctx.javaType(this.dataType)} ${ev.primitive} = ${ctx.defaultValue(this.dataType)}; - if (!${ev.isNull}) { - ${ev.primitive} = (${ctx.boxedType(this.dataType)}) $objectTerm; - } - """ - } + protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String /** * Returns `true` if this expression and all its children have been resolved to a specific schema @@ -166,11 +178,34 @@ abstract class Expression extends TreeNode[Expression] { } +/** + * An expression that cannot be evaluated. Some expressions don't live past analysis or optimization + * time (e.g. Star). This trait is used by those expressions. + */ +trait Unevaluable extends Expression { + + override def eval(input: InternalRow = null): Any = + throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = + throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") +} + + +/** + * An expression that is nondeterministic. + */ +trait Nondeterministic extends Expression { + override def deterministic: Boolean = false +} + + /** * A leaf expression, i.e. one without any child expressions. */ -abstract class LeafExpression extends Expression with trees.LeafNode[Expression] { - self: Product => +abstract class LeafExpression extends Expression { + + def children: Seq[Expression] = Nil } @@ -178,8 +213,11 @@ abstract class LeafExpression extends Expression with trees.LeafNode[Expression] * An expression with one input and one output. The output is by default evaluated to null * if the input is evaluated to null. 
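With the new default above, deterministic simply means "all children are deterministic", and the Nondeterministic mixin flips it for expressions such as SparkPartitionID, so the property propagates up the tree for free. A condensed illustration with stand-in classes, not the Catalyst ones:

abstract class TinyExpr {
  def children: Seq[TinyExpr]
  // Deterministic exactly when every child is; leaves are deterministic by default.
  def deterministic: Boolean = children.forall(_.deterministic)
}

trait TinyNondeterministic extends TinyExpr {
  override def deterministic: Boolean = false
}

case class TinyLiteral(value: Any) extends TinyExpr { val children = Nil }
case class TinyAdd(left: TinyExpr, right: TinyExpr) extends TinyExpr {
  val children = Seq(left, right)
}
// Depends on the task's partition id rather than on its children, so it opts out.
case class TinyPartitionId() extends TinyExpr with TinyNondeterministic { val children = Nil }

object DeterministicDemo extends App {
  println(TinyAdd(TinyLiteral(1), TinyLiteral(2)).deterministic)     // true
  println(TinyAdd(TinyLiteral(1), TinyPartitionId()).deterministic)  // false
}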
*/ -abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] { - self: Product => +abstract class UnaryExpression extends Expression { + + def child: Expression + + override def children: Seq[Expression] = child :: Nil override def foldable: Boolean = child.foldable override def nullable: Boolean = child.nullable @@ -253,8 +291,12 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio * An expression with two inputs and one output. The output is by default evaluated to null * if any input is evaluated to null. */ -abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] { - self: Product => +abstract class BinaryExpression extends Expression { + + def left: Expression + def right: Expression + + override def children: Seq[Expression] = Seq(left, right) override def foldable: Boolean = left.foldable && right.foldable @@ -335,15 +377,38 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express /** - * An expression that has two inputs that are expected to the be same type. If the two inputs have - * different types, the analyzer will find the tightest common type and do the proper type casting. + * A [[BinaryExpression]] that is an operator, with two properties: + * + * 1. The string representation is "x symbol y", rather than "funcName(x, y)". + * 2. Two inputs are expected to the be same type. If the two inputs have different types, + * the analyzer will find the tightest common type and do the proper type casting. */ -abstract class BinaryOperator extends BinaryExpression { - self: Product => +abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes { + + /** + * Expected input type from both left/right child expressions, similar to the + * [[ImplicitCastInputTypes]] trait. + */ + def inputType: AbstractDataType def symbol: String override def toString: String = s"($left $symbol $right)" + + override def inputTypes: Seq[AbstractDataType] = Seq(inputType, inputType) + + override def checkInputDataTypes(): TypeCheckResult = { + // First check whether left and right have the same type, then check if the type is acceptable. + if (left.dataType != right.dataType) { + TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " + + s"(${left.dataType.simpleString} and ${right.dataType.simpleString}).") + } else if (!inputType.acceptsType(left.dataType)) { + TypeCheckResult.TypeCheckFailure(s"'$prettyString' accepts ${inputType.simpleString} type," + + s" not ${left.dataType.simpleString}") + } else { + TypeCheckResult.TypeCheckSuccess + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 886a486bf5ee0..04872fbc8b091 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -18,6 +18,9 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateMutableProjection} +import org.apache.spark.sql.types.{StructType, DataType} +import org.apache.spark.unsafe.types.UTF8String /** * A [[Projection]] that is calculated by calling the `eval` of each of the specified expressions. 
@@ -73,6 +76,71 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu } } +/** + * A projection that returns UnsafeRow. + */ +abstract class UnsafeProjection extends Projection { + override def apply(row: InternalRow): UnsafeRow +} + +object UnsafeProjection { + + /* + * Returns whether UnsafeProjection can support given StructType, Array[DataType] or + * Seq[Expression]. + */ + def canSupport(schema: StructType): Boolean = canSupport(schema.fields.map(_.dataType)) + def canSupport(types: Array[DataType]): Boolean = types.forall(UnsafeColumnWriter.canEmbed(_)) + def canSupport(exprs: Seq[Expression]): Boolean = canSupport(exprs.map(_.dataType).toArray) + + /** + * Returns an UnsafeProjection for given StructType. + */ + def create(schema: StructType): UnsafeProjection = create(schema.fields.map(_.dataType)) + + /** + * Returns an UnsafeProjection for given Array of DataTypes. + */ + def create(fields: Array[DataType]): UnsafeProjection = { + val exprs = fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true)) + create(exprs) + } + + /** + * Returns an UnsafeProjection for given sequence of Expressions (bounded). + */ + def create(exprs: Seq[Expression]): UnsafeProjection = { + GenerateUnsafeProjection.generate(exprs) + } + + /** + * Returns an UnsafeProjection for given sequence of Expressions, which will be bound to + * `inputSchema`. + */ + def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): UnsafeProjection = { + create(exprs.map(BindReferences.bindReference(_, inputSchema))) + } +} + +/** + * A projection that could turn UnsafeRow into GenericInternalRow + */ +case class FromUnsafeProjection(fields: Seq[DataType]) extends Projection { + + def this(schema: StructType) = this(schema.fields.map(_.dataType)) + + private[this] val expressions = fields.zipWithIndex.map { case (dt, idx) => + new BoundReference(idx, dt, true) + } + + @transient private[this] lazy val generatedProj = + GenerateMutableProjection.generate(expressions)() + + override def apply(input: InternalRow): InternalRow = { + generatedProj(input) + } +} + /** * A mutable wrapper that makes two rows appear as a single concatenated row. Designed to * be instantiated once per thread and reused. 
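JoinedRow, described just above, presents two rows as one by translating the field index, and the typed getters added in the hunks that follow apply the same translation before delegating. A stripped-down model of that index arithmetic (MiniJoinedRow and its IndexedSeq-backed rows are stand-ins, not InternalRow):

class MiniJoinedRow(row1: IndexedSeq[Any], row2: IndexedSeq[Any]) {
  def length: Int = row1.length + row2.length

  // Indices below row1.length hit row1; everything else is shifted into row2.
  def get(i: Int): Any =
    if (i < row1.length) row1(i) else row2(i - row1.length)

  def isNullAt(i: Int): Boolean = get(i) == null

  // A "typed" getter in the same spirit as getUTF8String/getBinary.
  def getString(i: Int): String = get(i).asInstanceOf[String]
}

object JoinedRowDemo extends App {
  val joined = new MiniJoinedRow(IndexedSeq(1, "left"), IndexedSeq("right", null))
  println(joined.get(1))        // left
  println(joined.getString(2))  // right
  println(joined.isNullAt(3))   // true
}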
@@ -110,7 +178,15 @@ class JoinedRow extends InternalRow { override def length: Int = row1.length + row2.length - override def apply(i: Int): Any = + override def getUTF8String(i: Int): UTF8String = { + if (i < row1.length) row1.getUTF8String(i) else row2.getUTF8String(i - row1.length) + } + + override def getBinary(i: Int): Array[Byte] = { + if (i < row1.length) row1.getBinary(i) else row2.getBinary(i - row1.length) + } + + override def get(i: Int): Any = if (i < row1.length) row1(i) else row2(i - row1.length) override def isNullAt(i: Int): Boolean = @@ -204,7 +280,15 @@ class JoinedRow2 extends InternalRow { override def length: Int = row1.length + row2.length - override def apply(i: Int): Any = + override def getUTF8String(i: Int): UTF8String = { + if (i < row1.length) row1.getUTF8String(i) else row2.getUTF8String(i - row1.length) + } + + override def getBinary(i: Int): Array[Byte] = { + if (i < row1.length) row1.getBinary(i) else row2.getBinary(i - row1.length) + } + + override def get(i: Int): Any = if (i < row1.length) row1(i) else row2(i - row1.length) override def isNullAt(i: Int): Boolean = @@ -292,7 +376,16 @@ class JoinedRow3 extends InternalRow { override def length: Int = row1.length + row2.length - override def apply(i: Int): Any = + override def getUTF8String(i: Int): UTF8String = { + if (i < row1.length) row1.getUTF8String(i) else row2.getUTF8String(i - row1.length) + } + + override def getBinary(i: Int): Array[Byte] = { + if (i < row1.length) row1.getBinary(i) else row2.getBinary(i - row1.length) + } + + + override def get(i: Int): Any = if (i < row1.length) row1(i) else row2(i - row1.length) override def isNullAt(i: Int): Boolean = @@ -380,7 +473,16 @@ class JoinedRow4 extends InternalRow { override def length: Int = row1.length + row2.length - override def apply(i: Int): Any = + override def getUTF8String(i: Int): UTF8String = { + if (i < row1.length) row1.getUTF8String(i) else row2.getUTF8String(i - row1.length) + } + + override def getBinary(i: Int): Array[Byte] = { + if (i < row1.length) row1.getBinary(i) else row2.getBinary(i - row1.length) + } + + + override def get(i: Int): Any = if (i < row1.length) row1(i) else row2(i - row1.length) override def isNullAt(i: Int): Boolean = @@ -468,7 +570,16 @@ class JoinedRow5 extends InternalRow { override def length: Int = row1.length + row2.length - override def apply(i: Int): Any = + override def getUTF8String(i: Int): UTF8String = { + if (i < row1.length) row1.getUTF8String(i) else row2.getUTF8String(i - row1.length) + } + + override def getBinary(i: Int): Array[Byte] = { + if (i < row1.length) row1.getBinary(i) else row2.getBinary(i - row1.length) + } + + + override def get(i: Int): Any = if (i < row1.length) row1(i) else row2(i - row1.length) override def isNullAt(i: Int): Boolean = @@ -556,7 +667,16 @@ class JoinedRow6 extends InternalRow { override def length: Int = row1.length + row2.length - override def apply(i: Int): Any = + override def getUTF8String(i: Int): UTF8String = { + if (i < row1.length) row1.getUTF8String(i) else row2.getUTF8String(i - row1.length) + } + + override def getBinary(i: Int): Array[Byte] = { + if (i < row1.length) row1.getBinary(i) else row2.getBinary(i - row1.length) + } + + + override def get(i: Int): Any = if (i < row1.length) row1(i) else row2(i - row1.length) override def isNullAt(i: Int): Boolean = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 
6fb3343bb63f2..11c7950c0613b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.types.DataType /** @@ -29,7 +30,8 @@ case class ScalaUDF( function: AnyRef, dataType: DataType, children: Seq[Expression], - inputTypes: Seq[DataType] = Nil) extends Expression with ExpectsInputTypes { + inputTypes: Seq[DataType] = Nil) + extends Expression with ImplicitCastInputTypes with CodegenFallback { override def nullable: Boolean = true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index 8ab4ef060b68c..3f436c0eb893c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -17,9 +17,6 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.types.DataType abstract sealed class SortDirection @@ -30,15 +27,14 @@ case object Descending extends SortDirection * An expression that can be used to sort a tuple. This class extends expression primarily so that * transformations over expression will descend into its child. */ -case class SortOrder(child: Expression, direction: SortDirection) extends Expression - with trees.UnaryNode[Expression] { +case class SortOrder(child: Expression, direction: SortDirection) + extends UnaryExpression with Unevaluable { + + /** Sort order is not foldable because we don't have an eval for it. */ + override def foldable: Boolean = false override def dataType: DataType = child.dataType override def nullable: Boolean = child.nullable - // SortOrder itself is never evaluated. - override def eval(input: InternalRow = null): Any = - throw new TreeNodeException(this, s"No function to evaluate expression. 
type: ${this.nodeName}") - override def toString: String = s"$child ${if (direction == Ascending) "ASC" else "DESC"}" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index efa24710a5a67..6f291d2c86c1e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -219,7 +219,7 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR values(i).isNull = true } - override def apply(i: Int): Any = values(i).boxed + override def get(i: Int): Any = values(i).boxed override def isNullAt(i: Int): Boolean = values(i).isNull diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala index 6af5e6200e57b..c47b16c0f8585 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala @@ -17,13 +17,15 @@ package org.apache.spark.sql.catalyst.expressions +import scala.util.Try + import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.ObjectPool import org.apache.spark.sql.types._ import org.apache.spark.unsafe.PlatformDependent import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.types.UTF8String + /** * Converts Rows into UnsafeRow format. This class is NOT thread-safe. * @@ -35,8 +37,6 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) { this(schema.fields.map(_.dataType)) } - def numFields: Int = fieldTypes.length - /** Re-used pointer to the unsafe row being written */ private[this] val unsafeRow = new UnsafeRow() @@ -77,9 +77,8 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) { row: InternalRow, baseObject: Object, baseOffset: Long, - rowLengthInBytes: Int, - pool: ObjectPool): Int = { - unsafeRow.pointTo(baseObject, baseOffset, writers.length, rowLengthInBytes, pool) + rowLengthInBytes: Int): Int = { + unsafeRow.pointTo(baseObject, baseOffset, writers.length, rowLengthInBytes) if (writers.length > 0) { // zero-out the bitset @@ -94,16 +93,16 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) { } var fieldNumber = 0 - var cursor: Int = fixedLengthSize + var appendCursor: Int = fixedLengthSize while (fieldNumber < writers.length) { if (row.isNullAt(fieldNumber)) { unsafeRow.setNullAt(fieldNumber) } else { - cursor += writers(fieldNumber).write(row, unsafeRow, fieldNumber, cursor) + appendCursor += writers(fieldNumber).write(row, unsafeRow, fieldNumber, appendCursor) } fieldNumber += 1 } - cursor + appendCursor } } @@ -118,11 +117,11 @@ private abstract class UnsafeColumnWriter { * @param source the row being converted * @param target a pointer to the converted unsafe row * @param column the column to write - * @param cursor the offset from the start of the unsafe row to the end of the row; + * @param appendCursor the offset from the start of the unsafe row to the end of the row; * used for calculating where variable-length data should be written * @return the number of variable-length bytes written */ - def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int + def 
write(source: InternalRow, target: UnsafeRow, column: Int, appendCursor: Int): Int /** * Return the number of bytes that are needed to write this variable-length value. @@ -144,80 +143,74 @@ private object UnsafeColumnWriter { case DoubleType => DoubleUnsafeColumnWriter case StringType => StringUnsafeColumnWriter case BinaryType => BinaryUnsafeColumnWriter - case t => ObjectUnsafeColumnWriter + case t => + throw new UnsupportedOperationException(s"Do not know how to write columns of type $t") } } + + /** + * Returns whether the dataType can be embedded into UnsafeRow (not using ObjectPool). + */ + def canEmbed(dataType: DataType): Boolean = Try(forType(dataType)).isSuccess } // ------------------------------------------------------------------------------------------------ -private object NullUnsafeColumnWriter extends NullUnsafeColumnWriter -private object BooleanUnsafeColumnWriter extends BooleanUnsafeColumnWriter -private object ByteUnsafeColumnWriter extends ByteUnsafeColumnWriter -private object ShortUnsafeColumnWriter extends ShortUnsafeColumnWriter -private object IntUnsafeColumnWriter extends IntUnsafeColumnWriter -private object LongUnsafeColumnWriter extends LongUnsafeColumnWriter -private object FloatUnsafeColumnWriter extends FloatUnsafeColumnWriter -private object DoubleUnsafeColumnWriter extends DoubleUnsafeColumnWriter -private object StringUnsafeColumnWriter extends StringUnsafeColumnWriter -private object BinaryUnsafeColumnWriter extends BinaryUnsafeColumnWriter -private object ObjectUnsafeColumnWriter extends ObjectUnsafeColumnWriter - private abstract class PrimitiveUnsafeColumnWriter extends UnsafeColumnWriter { // Primitives don't write to the variable-length region: def getSize(sourceRow: InternalRow, column: Int): Int = 0 } -private class NullUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { +private object NullUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter { override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setNullAt(column) 0 } } -private class BooleanUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { +private object BooleanUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter { override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setBoolean(column, source.getBoolean(column)) 0 } } -private class ByteUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { +private object ByteUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter { override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setByte(column, source.getByte(column)) 0 } } -private class ShortUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { +private object ShortUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter { override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setShort(column, source.getShort(column)) 0 } } -private class IntUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { +private object IntUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter { override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setInt(column, source.getInt(column)) 0 } } -private class LongUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { +private object LongUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter { override def write(source: InternalRow, target: UnsafeRow, column: Int, 
cursor: Int): Int = { target.setLong(column, source.getLong(column)) 0 } } -private class FloatUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { +private object FloatUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter { override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setFloat(column, source.getFloat(column)) 0 } } -private class DoubleUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { +private object DoubleUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter { override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setDouble(column, source.getDouble(column)) 0 @@ -226,18 +219,21 @@ private class DoubleUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWr private abstract class BytesUnsafeColumnWriter extends UnsafeColumnWriter { - def getBytes(source: InternalRow, column: Int): Array[Byte] + protected[this] def isString: Boolean + protected[this] def getBytes(source: InternalRow, column: Int): Array[Byte] - def getSize(source: InternalRow, column: Int): Int = { + override def getSize(source: InternalRow, column: Int): Int = { val numBytes = getBytes(source, column).length ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes) } - protected[this] def isString: Boolean - override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { - val offset = target.getBaseOffset + cursor val bytes = getBytes(source, column) + write(target, bytes, column, cursor) + } + + def write(target: UnsafeRow, bytes: Array[Byte], column: Int, cursor: Int): Int = { + val offset = target.getBaseOffset + cursor val numBytes = bytes.length if ((numBytes & 0x07) > 0) { // zero-out the padding bytes @@ -250,32 +246,31 @@ private abstract class BytesUnsafeColumnWriter extends UnsafeColumnWriter { offset, numBytes ) - val flag = if (isString) 1L << (UnsafeRow.OFFSET_BITS * 2) else 0 - target.setLong(column, flag | (cursor.toLong << UnsafeRow.OFFSET_BITS) | numBytes.toLong) + target.setLong(column, (cursor.toLong << 32) | numBytes.toLong) ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes) } } -private class StringUnsafeColumnWriter private() extends BytesUnsafeColumnWriter { +private object StringUnsafeColumnWriter extends BytesUnsafeColumnWriter { protected[this] def isString: Boolean = true def getBytes(source: InternalRow, column: Int): Array[Byte] = { source.getAs[UTF8String](column).getBytes } -} - -private class BinaryUnsafeColumnWriter private() extends BytesUnsafeColumnWriter { - protected[this] def isString: Boolean = false - def getBytes(source: InternalRow, column: Int): Array[Byte] = { - source.getAs[Array[Byte]](column) + // TODO(davies): refactor this + // specialized for codegen + def getSize(value: UTF8String): Int = + ByteArrayMethods.roundNumberOfBytesToNearestWord(value.numBytes()) + def write(target: UnsafeRow, value: UTF8String, column: Int, cursor: Int): Int = { + write(target, value.getBytes, column, cursor) } } -private class ObjectUnsafeColumnWriter private() extends UnsafeColumnWriter { - def getSize(sourceRow: InternalRow, column: Int): Int = 0 - override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { - val obj = source.get(column) - val idx = target.getPool.put(obj) - target.setLong(column, - idx) - 0 +private object BinaryUnsafeColumnWriter extends BytesUnsafeColumnWriter { + protected[this] override def isString: Boolean = false + override def getBytes(source: InternalRow, 
column: Int): Array[Byte] = { + source.getAs[Array[Byte]](column) } + // specialized for codegen + def getSize(value: Array[Byte]): Int = + ByteArrayMethods.roundNumberOfBytesToNearestWord(value.length) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala new file mode 100644 index 0000000000000..b924af4cc84d8 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -0,0 +1,292 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ + +case class Average(child: Expression) extends AlgebraicAggregate { + + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = true + + // Return data type. + override def dataType: DataType = resultType + + // Expected input data type. + // TODO: Once we remove the old code path, we can use our analyzer to cast NullType + // to the default data type of the NumericType. + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType)) + + private val resultType = child.dataType match { + case DecimalType.Fixed(precision, scale) => + DecimalType(precision + 4, scale + 4) + case DecimalType.Unlimited => DecimalType.Unlimited + case _ => DoubleType + } + + private val sumDataType = child.dataType match { + case _ @ DecimalType() => DecimalType.Unlimited + case _ => DoubleType + } + + private val currentSum = AttributeReference("currentSum", sumDataType)() + private val currentCount = AttributeReference("currentCount", LongType)() + + override val bufferAttributes = currentSum :: currentCount :: Nil + + override val initialValues = Seq( + /* currentSum = */ Cast(Literal(0), sumDataType), + /* currentCount = */ Literal(0L) + ) + + override val updateExpressions = Seq( + /* currentSum = */ + Add( + currentSum, + Coalesce(Cast(child, sumDataType) :: Cast(Literal(0), sumDataType) :: Nil)), + /* currentCount = */ If(IsNull(child), currentCount, currentCount + 1L) + ) + + override val mergeExpressions = Seq( + /* currentSum = */ currentSum.left + currentSum.right, + /* currentCount = */ currentCount.left + currentCount.right + ) + + // If all input are nulls, currentCount will be 0 and we will get null after the division. 
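Concretely: feeding the rows 1.0, null, 3.0 through the update expressions above leaves the buffer at (currentSum = 4.0, currentCount = 2), because a null input adds zero to the sum and does not bump the count, so the evaluate expression below yields 4.0 / 2 = 2.0. If every input is null, the division becomes 0 / 0, which Spark SQL evaluates to null, exactly the all-null behaviour the preceding comment relies on.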
+ override val evaluateExpression = Cast(currentSum, resultType) / Cast(currentCount, resultType) +} + +case class Count(child: Expression) extends AlgebraicAggregate { + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = false + + // Return data type. + override def dataType: DataType = LongType + + // Expected input data type. + override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) + + private val currentCount = AttributeReference("currentCount", LongType)() + + override val bufferAttributes = currentCount :: Nil + + override val initialValues = Seq( + /* currentCount = */ Literal(0L) + ) + + override val updateExpressions = Seq( + /* currentCount = */ If(IsNull(child), currentCount, currentCount + 1L) + ) + + override val mergeExpressions = Seq( + /* currentCount = */ currentCount.left + currentCount.right + ) + + override val evaluateExpression = Cast(currentCount, LongType) +} + +case class First(child: Expression) extends AlgebraicAggregate { + + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = true + + // First is not a deterministic function. + override def deterministic: Boolean = false + + // Return data type. + override def dataType: DataType = child.dataType + + // Expected input data type. + override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) + + private val first = AttributeReference("first", child.dataType)() + + override val bufferAttributes = first :: Nil + + override val initialValues = Seq( + /* first = */ Literal.create(null, child.dataType) + ) + + override val updateExpressions = Seq( + /* first = */ If(IsNull(first), child, first) + ) + + override val mergeExpressions = Seq( + /* first = */ If(IsNull(first.left), first.right, first.left) + ) + + override val evaluateExpression = first +} + +case class Last(child: Expression) extends AlgebraicAggregate { + + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = true + + // Last is not a deterministic function. + override def deterministic: Boolean = false + + // Return data type. + override def dataType: DataType = child.dataType + + // Expected input data type. + override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) + + private val last = AttributeReference("last", child.dataType)() + + override val bufferAttributes = last :: Nil + + override val initialValues = Seq( + /* last = */ Literal.create(null, child.dataType) + ) + + override val updateExpressions = Seq( + /* last = */ If(IsNull(child), last, child) + ) + + override val mergeExpressions = Seq( + /* last = */ If(IsNull(last.right), last.left, last.right) + ) + + override val evaluateExpression = last +} + +case class Max(child: Expression) extends AlgebraicAggregate { + + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = true + + // Return data type. + override def dataType: DataType = child.dataType + + // Expected input data type. 
+ override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) + + private val max = AttributeReference("max", child.dataType)() + + override val bufferAttributes = max :: Nil + + override val initialValues = Seq( + /* max = */ Literal.create(null, child.dataType) + ) + + override val updateExpressions = Seq( + /* max = */ If(IsNull(child), max, If(IsNull(max), child, Greatest(Seq(max, child)))) + ) + + override val mergeExpressions = { + val greatest = Greatest(Seq(max.left, max.right)) + Seq( + /* max = */ If(IsNull(max.right), max.left, If(IsNull(max.left), max.right, greatest)) + ) + } + + override val evaluateExpression = max +} + +case class Min(child: Expression) extends AlgebraicAggregate { + + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = true + + // Return data type. + override def dataType: DataType = child.dataType + + // Expected input data type. + override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) + + private val min = AttributeReference("min", child.dataType)() + + override val bufferAttributes = min :: Nil + + override val initialValues = Seq( + /* min = */ Literal.create(null, child.dataType) + ) + + override val updateExpressions = Seq( + /* min = */ If(IsNull(child), min, If(IsNull(min), child, Least(Seq(min, child)))) + ) + + override val mergeExpressions = { + val least = Least(Seq(min.left, min.right)) + Seq( + /* min = */ If(IsNull(min.right), min.left, If(IsNull(min.left), min.right, least)) + ) + } + + override val evaluateExpression = min +} + +case class Sum(child: Expression) extends AlgebraicAggregate { + + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = true + + // Return data type. + override def dataType: DataType = resultType + + // Expected input data type. + override def inputTypes: Seq[AbstractDataType] = + Seq(TypeCollection(LongType, DoubleType, DecimalType, NullType)) + + private val resultType = child.dataType match { + case DecimalType.Fixed(precision, scale) => + DecimalType(precision + 4, scale + 4) + case DecimalType.Unlimited => DecimalType.Unlimited + case _ => child.dataType + } + + private val sumDataType = child.dataType match { + case _ @ DecimalType() => DecimalType.Unlimited + case _ => child.dataType + } + + private val currentSum = AttributeReference("currentSum", sumDataType)() + + private val zero = Cast(Literal(0), sumDataType) + + override val bufferAttributes = currentSum :: Nil + + override val initialValues = Seq( + /* currentSum = */ Literal.create(null, sumDataType) + ) + + override val updateExpressions = Seq( + /* currentSum = */ + Coalesce(Seq(Add(Coalesce(Seq(currentSum, zero)), Cast(child, sumDataType)), currentSum)) + ) + + override val mergeExpressions = { + val add = Add(Coalesce(Seq(currentSum.left, zero)), Cast(currentSum.right, sumDataType)) + Seq( + /* currentSum = */ + Coalesce(Seq(add, currentSum.left)) + ) + } + + override val evaluateExpression = Cast(currentSum, resultType) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala new file mode 100644 index 0000000000000..577ede73cb01f --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types._ + +/** The mode of an [[AggregateFunction1]]. */ +private[sql] sealed trait AggregateMode + +/** + * An [[AggregateFunction1]] with [[Partial]] mode is used for partial aggregation. + * This function updates the given aggregation buffer with the original input of this + * function. When it has processed all input rows, the aggregation buffer is returned. + */ +private[sql] case object Partial extends AggregateMode + +/** + * An [[AggregateFunction1]] with [[PartialMerge]] mode is used to merge aggregation buffers + * containing intermediate results for this function. + * This function updates the given aggregation buffer by merging multiple aggregation buffers. + * When it has processed all input rows, the aggregation buffer is returned. + */ +private[sql] case object PartialMerge extends AggregateMode + +/** + * An [[AggregateFunction1]] with [[PartialMerge]] mode is used to merge aggregation buffers + * containing intermediate results for this function and the generate final result. + * This function updates the given aggregation buffer by merging multiple aggregation buffers. + * When it has processed all input rows, the final result of this function is returned. + */ +private[sql] case object Final extends AggregateMode + +/** + * An [[AggregateFunction2]] with [[Partial]] mode is used to evaluate this function directly + * from original input rows without any partial aggregation. + * This function updates the given aggregation buffer with the original input of this + * function. When it has processed all input rows, the final result of this function is returned. + */ +private[sql] case object Complete extends AggregateMode + +/** + * A place holder expressions used in code-gen, it does not change the corresponding value + * in the row. + */ +private[sql] case object NoOp extends Expression with Unevaluable { + override def nullable: Boolean = true + override def eval(input: InternalRow): Any = { + throw new TreeNodeException( + this, s"No function to evaluate expression. type: ${this.nodeName}") + } + override def dataType: DataType = NullType + override def children: Seq[Expression] = Nil +} + +/** + * A container for an [[AggregateFunction2]] with its [[AggregateMode]] and a field + * (`isDistinct`) indicating if DISTINCT keyword is specified for this function. 
+ * @param aggregateFunction + * @param mode + * @param isDistinct + */ +private[sql] case class AggregateExpression2( + aggregateFunction: AggregateFunction2, + mode: AggregateMode, + isDistinct: Boolean) extends AggregateExpression { + + override def children: Seq[Expression] = aggregateFunction :: Nil + override def dataType: DataType = aggregateFunction.dataType + override def foldable: Boolean = false + override def nullable: Boolean = aggregateFunction.nullable + + override def references: AttributeSet = { + val childReferemces = mode match { + case Partial | Complete => aggregateFunction.references.toSeq + case PartialMerge | Final => aggregateFunction.bufferAttributes + } + + AttributeSet(childReferemces) + } + + override def toString: String = s"(${aggregateFunction}2,mode=$mode,isDistinct=$isDistinct)" +} + +abstract class AggregateFunction2 + extends Expression with ImplicitCastInputTypes { + + self: Product => + + /** An aggregate function is not foldable. */ + override def foldable: Boolean = false + + /** + * The offset of this function's buffer in the underlying buffer shared with other functions. + */ + var bufferOffset: Int = 0 + + /** The schema of the aggregation buffer. */ + def bufferSchema: StructType + + /** Attributes of fields in bufferSchema. */ + def bufferAttributes: Seq[AttributeReference] + + /** Clones bufferAttributes. */ + def cloneBufferAttributes: Seq[Attribute] + + /** + * Initializes its aggregation buffer located in `buffer`. + * It will use bufferOffset to find the starting point of + * its buffer in the given `buffer` shared with other functions. + */ + def initialize(buffer: MutableRow): Unit + + /** + * Updates its aggregation buffer located in `buffer` based on the given `input`. + * It will use bufferOffset to find the starting point of its buffer in the given `buffer` + * shared with other functions. + */ + def update(buffer: MutableRow, input: InternalRow): Unit + + /** + * Updates its aggregation buffer located in `buffer1` by combining intermediate results + * in the current buffer and intermediate results from another buffer `buffer2`. + * It will use bufferOffset to find the starting point of its buffer in the given `buffer1` + * and `buffer2`. + */ + def merge(buffer1: MutableRow, buffer2: InternalRow): Unit + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = + throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") +} + +/** + * A helper class for aggregate functions that can be implemented in terms of catalyst expressions. + */ +abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable { + self: Product => + + val initialValues: Seq[Expression] + val updateExpressions: Seq[Expression] + val mergeExpressions: Seq[Expression] + val evaluateExpression: Expression + + override lazy val cloneBufferAttributes = bufferAttributes.map(_.newInstance()) + + /** + * A helper class for representing an attribute used in merging two + * aggregation buffers. When merging two buffers, `bufferLeft` and `bufferRight`, + * we merge buffer values and then update bufferLeft. A [[RichAttribute]] + * of an [[AttributeReference]] `a` has two functions `left` and `right`, + * which represent `a` in `bufferLeft` and `bufferRight`, respectively. + * @param a + */ + implicit class RichAttribute(a: AttributeReference) { + /** Represents this attribute at the mutable buffer side. 
*/ + def left: AttributeReference = a + + /** Represents this attribute at the input buffer side (the data value is read-only). */ + def right: AttributeReference = cloneBufferAttributes(bufferAttributes.indexOf(a)) + } + + /** An AlgebraicAggregate's bufferSchema is derived from bufferAttributes. */ + override def bufferSchema: StructType = StructType.fromAttributes(bufferAttributes) + + override def initialize(buffer: MutableRow): Unit = { + var i = 0 + while (i < bufferAttributes.size) { + buffer(i + bufferOffset) = initialValues(i).eval() + i += 1 + } + } + + override def update(buffer: MutableRow, input: InternalRow): Unit = { + throw new UnsupportedOperationException( + "AlgebraicAggregate's update should not be called directly") + } + + override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { + throw new UnsupportedOperationException( + "AlgebraicAggregate's merge should not be called directly") + } + + override def eval(buffer: InternalRow): Any = { + throw new UnsupportedOperationException( + "AlgebraicAggregate's eval should not be called directly") + } +} + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index c0e17f97e9b3c..e07c920a41d0a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -20,28 +20,27 @@ package org.apache.spark.sql.catalyst.expressions import com.clearspring.analytics.stream.cardinality.HyperLogLog import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet -abstract class AggregateExpression extends Expression { - self: Product => + +trait AggregateExpression extends Expression with Unevaluable + +trait AggregateExpression1 extends AggregateExpression { /** - * Creates a new instance that can be used to compute this aggregate expression for a group - * of input rows/ + * Aggregate expressions should not be foldable. */ - def newInstance(): AggregateFunction + override def foldable: Boolean = false /** - * [[AggregateExpression.eval]] should never be invoked because [[AggregateExpression]]'s are - * replaced with a physical aggregate operator at runtime. + * Creates a new instance that can be used to compute this aggregate expression for a group + * of input rows/ */ - override def eval(input: InternalRow = null): Any = - throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") + def newInstance(): AggregateFunction1 } /** @@ -57,11 +56,10 @@ case class SplitEvaluation( partialEvaluations: Seq[NamedExpression]) /** - * An [[AggregateExpression]] that can be partially computed without seeing all relevant tuples. + * An [[AggregateExpression1]] that can be partially computed without seeing all relevant tuples. * These partial evaluations can then be combined to compute the actual answer. */ -abstract class PartialAggregate extends AggregateExpression { - self: Product => +trait PartialAggregate1 extends AggregateExpression1 { /** * Returns a [[SplitEvaluation]] that computes this aggregation using partial aggregation. 
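To keep the two code paths apart while they coexist, the old interfaces in this file are renamed with a `1` suffix (AggregateExpression1, PartialAggregate1, AggregateFunction1), and the new declarative path lives in the aggregate package above (AggregateExpression2, AggregateFunction2, AlgebraicAggregate), with both hanging off the common AggregateExpression trait. As a sketch of how small a new-style aggregate can be, here is a hypothetical CountNulls written against the AlgebraicAggregate contract and modelled directly on the new Count; the class itself is not part of the patch and is assumed to sit next to the other new functions so it can use the same DSL imports.

    // Hypothetical example, not in the patch: counts NULL inputs instead of non-NULL ones.
    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.expressions.aggregate.AlgebraicAggregate
    import org.apache.spark.sql.types._

    case class CountNulls(child: Expression) extends AlgebraicAggregate {
      override def children: Seq[Expression] = child :: Nil
      override def nullable: Boolean = false
      override def dataType: DataType = LongType
      override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)

      private val nullCount = AttributeReference("nullCount", LongType)()

      override val bufferAttributes = nullCount :: Nil

      override val initialValues = Seq(/* nullCount = */ Literal(0L))

      override val updateExpressions = Seq(
        /* nullCount = */ If(IsNull(child), nullCount + 1L, nullCount))

      override val mergeExpressions = Seq(
        /* nullCount = */ nullCount.left + nullCount.right)

      override val evaluateExpression = Cast(nullCount, LongType)
    }

Note that AlgebraicAggregate implements initialize generically by evaluating initialValues into the shared buffer, while its update and merge deliberately throw, since the physical operator is expected to evaluate updateExpressions and mergeExpressions directly.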
@@ -71,14 +69,13 @@ abstract class PartialAggregate extends AggregateExpression { /** * A specific implementation of an aggregate function. Used to wrap a generic - * [[AggregateExpression]] with an algorithm that will be used to compute one specific result. + * [[AggregateExpression1]] with an algorithm that will be used to compute one specific result. */ -abstract class AggregateFunction - extends AggregateExpression with Serializable with trees.LeafNode[Expression] { - self: Product => +abstract class AggregateFunction1 + extends LeafExpression with AggregateExpression1 with Serializable { /** Base should return the generic aggregate expression that this function is computing */ - val base: AggregateExpression + val base: AggregateExpression1 override def nullable: Boolean = base.nullable override def dataType: DataType = base.dataType @@ -86,12 +83,12 @@ abstract class AggregateFunction def update(input: InternalRow): Unit // Do we really need this? - override def newInstance(): AggregateFunction = { + override def newInstance(): AggregateFunction1 = { makeCopy(productIterator.map { case a: AnyRef => a }.toArray) } } -case class Min(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { +case class Min(child: Expression) extends UnaryExpression with PartialAggregate1 { override def nullable: Boolean = true override def dataType: DataType = child.dataType @@ -107,7 +104,7 @@ case class Min(child: Expression) extends PartialAggregate with trees.UnaryNode[ TypeUtils.checkForOrderingExpr(child.dataType, "function min") } -case class MinFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { +case class MinFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. val currentMin: MutableLiteral = MutableLiteral(null, expr.dataType) @@ -124,7 +121,7 @@ case class MinFunction(expr: Expression, base: AggregateExpression) extends Aggr override def eval(input: InternalRow): Any = currentMin.value } -case class Max(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { +case class Max(child: Expression) extends UnaryExpression with PartialAggregate1 { override def nullable: Boolean = true override def dataType: DataType = child.dataType @@ -140,7 +137,7 @@ case class Max(child: Expression) extends PartialAggregate with trees.UnaryNode[ TypeUtils.checkForOrderingExpr(child.dataType, "function max") } -case class MaxFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { +case class MaxFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. 
val currentMax: MutableLiteral = MutableLiteral(null, expr.dataType) @@ -157,7 +154,7 @@ case class MaxFunction(expr: Expression, base: AggregateExpression) extends Aggr override def eval(input: InternalRow): Any = currentMax.value } -case class Count(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { +case class Count(child: Expression) extends UnaryExpression with PartialAggregate1 { override def nullable: Boolean = false override def dataType: LongType.type = LongType @@ -170,7 +167,7 @@ case class Count(child: Expression) extends PartialAggregate with trees.UnaryNod override def newInstance(): CountFunction = new CountFunction(child, this) } -case class CountFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { +case class CountFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. var count: Long = _ @@ -185,7 +182,7 @@ case class CountFunction(expr: Expression, base: AggregateExpression) extends Ag override def eval(input: InternalRow): Any = count } -case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate { +case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate1 { def this() = this(null) override def children: Seq[Expression] = expressions @@ -205,8 +202,8 @@ case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate case class CountDistinctFunction( @transient expr: Seq[Expression], - @transient base: AggregateExpression) - extends AggregateFunction { + @transient base: AggregateExpression1) + extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. @@ -225,7 +222,7 @@ case class CountDistinctFunction( override def eval(input: InternalRow): Any = seen.size.toLong } -case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpression { +case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpression1 { def this() = this(null) override def children: Seq[Expression] = expressions @@ -238,8 +235,8 @@ case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpress case class CollectHashSetFunction( @transient expr: Seq[Expression], - @transient base: AggregateExpression) - extends AggregateFunction { + @transient base: AggregateExpression1) + extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. @@ -260,7 +257,7 @@ case class CollectHashSetFunction( } } -case class CombineSetsAndCount(inputSet: Expression) extends AggregateExpression { +case class CombineSetsAndCount(inputSet: Expression) extends AggregateExpression1 { def this() = this(null) override def children: Seq[Expression] = inputSet :: Nil @@ -274,8 +271,8 @@ case class CombineSetsAndCount(inputSet: Expression) extends AggregateExpression case class CombineSetsAndCountFunction( @transient inputSet: Expression, - @transient base: AggregateExpression) - extends AggregateFunction { + @transient base: AggregateExpression1) + extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. 
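For contrast with the declarative path, the renamed classes above keep the old imperative contract: the planner asks an AggregateExpression1 for a fresh AggregateFunction1 per group, feeds it rows through update, and reads the result with eval. A minimal sketch (the driver loop and the bound ordinal are hypothetical; Count here is the old-path class from this file, not the new aggregate.Count):

    // Hypothetical driver loop illustrating the AggregateExpression1 / AggregateFunction1 contract.
    import org.apache.spark.sql.catalyst.expressions.{BoundReference, Count}
    import org.apache.spark.sql.types.IntegerType

    val agg = Count(BoundReference(0, IntegerType, nullable = true))
    val fn = agg.newInstance()              // fresh CountFunction for one group
    // rows.foreach(row => fn.update(row))  // accumulate the group's rows
    // val result = fn.eval(null)           // CountFunction.eval ignores its input row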
@@ -310,7 +307,7 @@ private[sql] case object HyperLogLogUDT extends UserDefinedType[HyperLogLog] { } case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double) - extends AggregateExpression with trees.UnaryNode[Expression] { + extends UnaryExpression with AggregateExpression1 { override def nullable: Boolean = false override def dataType: DataType = HyperLogLogUDT @@ -322,9 +319,9 @@ case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double) case class ApproxCountDistinctPartitionFunction( expr: Expression, - base: AggregateExpression, + base: AggregateExpression1, relativeSD: Double) - extends AggregateFunction { + extends AggregateFunction1 { def this() = this(null, null, 0) // Required for serialization. private val hyperLogLog = new HyperLogLog(relativeSD) @@ -340,7 +337,7 @@ case class ApproxCountDistinctPartitionFunction( } case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double) - extends AggregateExpression with trees.UnaryNode[Expression] { + extends UnaryExpression with AggregateExpression1 { override def nullable: Boolean = false override def dataType: LongType.type = LongType @@ -352,9 +349,9 @@ case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double) case class ApproxCountDistinctMergeFunction( expr: Expression, - base: AggregateExpression, + base: AggregateExpression1, relativeSD: Double) - extends AggregateFunction { + extends AggregateFunction1 { def this() = this(null, null, 0) // Required for serialization. private val hyperLogLog = new HyperLogLog(relativeSD) @@ -368,7 +365,7 @@ case class ApproxCountDistinctMergeFunction( } case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05) - extends PartialAggregate with trees.UnaryNode[Expression] { + extends UnaryExpression with PartialAggregate1 { override def nullable: Boolean = false override def dataType: LongType.type = LongType @@ -386,7 +383,7 @@ case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05) override def newInstance(): CountDistinctFunction = new CountDistinctFunction(child :: Nil, this) } -case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { +case class Average(child: Expression) extends UnaryExpression with PartialAggregate1 { override def prettyName: String = "avg" @@ -432,8 +429,8 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN TypeUtils.checkForNumericExpr(child.dataType, "function average") } -case class AverageFunction(expr: Expression, base: AggregateExpression) - extends AggregateFunction { +case class AverageFunction(expr: Expression, base: AggregateExpression1) + extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. @@ -479,7 +476,7 @@ case class AverageFunction(expr: Expression, base: AggregateExpression) } } -case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { +case class Sum(child: Expression) extends UnaryExpression with PartialAggregate1 { override def nullable: Boolean = true @@ -514,7 +511,7 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[ TypeUtils.checkForNumericExpr(child.dataType, "function sum") } -case class SumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { +case class SumFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. 
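A note on the relativeSD parameter threaded through the Approx* classes above: it is handed straight to the clearspring HyperLogLog constructor as the target relative standard deviation, so for example ApproxCountDistinct(child, 0.01) asks for roughly 1% error in exchange for a larger sketch, while the default of 0.05 keeps the per-group state small.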
private val calcType = @@ -559,7 +556,7 @@ case class SumFunction(expr: Expression, base: AggregateExpression) extends Aggr * <-- null <-- no data * null <-- null <-- no data */ -case class CombineSum(child: Expression) extends AggregateExpression { +case class CombineSum(child: Expression) extends AggregateExpression1 { def this() = this(null) override def children: Seq[Expression] = child :: Nil @@ -569,8 +566,8 @@ case class CombineSum(child: Expression) extends AggregateExpression { override def newInstance(): CombineSumFunction = new CombineSumFunction(child, this) } -case class CombineSumFunction(expr: Expression, base: AggregateExpression) - extends AggregateFunction { +case class CombineSumFunction(expr: Expression, base: AggregateExpression1) + extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. @@ -606,8 +603,7 @@ case class CombineSumFunction(expr: Expression, base: AggregateExpression) } } -case class SumDistinct(child: Expression) - extends PartialAggregate with trees.UnaryNode[Expression] { +case class SumDistinct(child: Expression) extends UnaryExpression with PartialAggregate1 { def this() = this(null) override def nullable: Boolean = true @@ -633,8 +629,8 @@ case class SumDistinct(child: Expression) TypeUtils.checkForNumericExpr(child.dataType, "function sumDistinct") } -case class SumDistinctFunction(expr: Expression, base: AggregateExpression) - extends AggregateFunction { +case class SumDistinctFunction(expr: Expression, base: AggregateExpression1) + extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. @@ -659,7 +655,7 @@ case class SumDistinctFunction(expr: Expression, base: AggregateExpression) } } -case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends AggregateExpression { +case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends AggregateExpression1 { def this() = this(null, null) override def children: Seq[Expression] = inputSet :: Nil @@ -673,8 +669,8 @@ case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends Agg case class CombineSetsAndSumFunction( @transient inputSet: Expression, - @transient base: AggregateExpression) - extends AggregateFunction { + @transient base: AggregateExpression1) + extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. @@ -701,7 +697,7 @@ case class CombineSetsAndSumFunction( } } -case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { +case class First(child: Expression) extends UnaryExpression with PartialAggregate1 { override def nullable: Boolean = true override def dataType: DataType = child.dataType override def toString: String = s"FIRST($child)" @@ -715,7 +711,7 @@ case class First(child: Expression) extends PartialAggregate with trees.UnaryNod override def newInstance(): FirstFunction = new FirstFunction(child, this) } -case class FirstFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { +case class FirstFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. 
var result: Any = null @@ -729,7 +725,7 @@ case class FirstFunction(expr: Expression, base: AggregateExpression) extends Ag override def eval(input: InternalRow): Any = result } -case class Last(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { +case class Last(child: Expression) extends UnaryExpression with PartialAggregate1 { override def references: AttributeSet = child.references override def nullable: Boolean = true override def dataType: DataType = child.dataType @@ -744,7 +740,7 @@ case class Last(child: Expression) extends PartialAggregate with trees.UnaryNode override def newInstance(): LastFunction = new LastFunction(child, this) } -case class LastFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { +case class LastFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. var result: Any = null diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 8476af4a5d8d6..05b5ad88fee8f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -18,36 +18,44 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.Interval -abstract class UnaryArithmetic extends UnaryExpression { - self: Product => + +case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval) override def dataType: DataType = child.dataType -} -case class UnaryMinus(child: Expression) extends UnaryArithmetic { override def toString: String = s"-$child" - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForNumericExpr(child.dataType, "operator -") - private lazy val numeric = TypeUtils.getNumeric(dataType) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match { case dt: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()") case dt: NumericType => defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})(-($c))") + case dt: IntervalType => defineCodeGen(ctx, ev, c => s"$c.negate()") } - protected override def nullSafeEval(input: Any): Any = numeric.negate(input) + protected override def nullSafeEval(input: Any): Any = { + if (dataType.isInstanceOf[IntervalType]) { + input.asInstanceOf[Interval].negate() + } else { + numeric.negate(input) + } + } } -case class UnaryPositive(child: Expression) extends UnaryArithmetic { +case class UnaryPositive(child: Expression) extends UnaryExpression with ExpectsInputTypes { override def prettyName: String = "positive" + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval) + + override def dataType: DataType = child.dataType + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = defineCodeGen(ctx, ev, c => c) @@ -57,32 +65,29 @@ case class 
UnaryPositive(child: Expression) extends UnaryArithmetic { /** * A function that get the absolute value of the numeric value. */ -case class Abs(child: Expression) extends UnaryArithmetic { - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForNumericExpr(child.dataType, "function abs") +case class Abs(child: Expression) + extends UnaryExpression with ExpectsInputTypes with CodegenFallback { + + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) + + override def dataType: DataType = child.dataType private lazy val numeric = TypeUtils.getNumeric(dataType) + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match { + case dt: DecimalType => + defineCodeGen(ctx, ev, c => s"$c.abs()") + case dt: NumericType => + defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})(java.lang.Math.abs($c))") + } + protected override def nullSafeEval(input: Any): Any = numeric.abs(input) } abstract class BinaryArithmetic extends BinaryOperator { - self: Product => override def dataType: DataType = left.dataType - override def checkInputDataTypes(): TypeCheckResult = { - if (left.dataType != right.dataType) { - TypeCheckResult.TypeCheckFailure( - s"differing types in ${this.getClass.getSimpleName} " + - s"(${left.dataType} and ${right.dataType}).") - } else { - checkTypesInternal(dataType) - } - } - - protected def checkTypesInternal(t: DataType): TypeCheckResult - /** Name of the function for this expression on a [[Decimal]] type. */ def decimalMethod: String = sys.error("BinaryArithmetics must override either decimalMethod or genCode") @@ -104,62 +109,95 @@ private[sql] object BinaryArithmetic { } case class Add(left: Expression, right: Expression) extends BinaryArithmetic { + + override def inputType: AbstractDataType = TypeCollection.NumericAndInterval + override def symbol: String = "+" - override def decimalMethod: String = "$plus" override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) - protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForNumericExpr(t, "operator " + symbol) - private lazy val numeric = TypeUtils.getNumeric(dataType) - protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.plus(input1, input2) + protected override def nullSafeEval(input1: Any, input2: Any): Any = { + if (dataType.isInstanceOf[IntervalType]) { + input1.asInstanceOf[Interval].add(input2.asInstanceOf[Interval]) + } else { + numeric.plus(input1, input2) + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match { + case dt: DecimalType => + defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$plus($eval2)") + case ByteType | ShortType => + defineCodeGen(ctx, ev, + (eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)") + case IntervalType => + defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.add($eval2)") + case _ => + defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2") + } } case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic { + + override def inputType: AbstractDataType = TypeCollection.NumericAndInterval + override def symbol: String = "-" - override def decimalMethod: String = "$minus" override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) - protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForNumericExpr(t, "operator " + symbol) - private lazy val numeric = 
TypeUtils.getNumeric(dataType) - protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.minus(input1, input2) + protected override def nullSafeEval(input1: Any, input2: Any): Any = { + if (dataType.isInstanceOf[IntervalType]) { + input1.asInstanceOf[Interval].subtract(input2.asInstanceOf[Interval]) + } else { + numeric.minus(input1, input2) + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match { + case dt: DecimalType => + defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$minus($eval2)") + case ByteType | ShortType => + defineCodeGen(ctx, ev, + (eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)") + case IntervalType => + defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.subtract($eval2)") + case _ => + defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2") + } } case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic { + + override def inputType: AbstractDataType = NumericType + override def symbol: String = "*" override def decimalMethod: String = "$times" override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) - protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForNumericExpr(t, "operator " + symbol) - private lazy val numeric = TypeUtils.getNumeric(dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2) } case class Divide(left: Expression, right: Expression) extends BinaryArithmetic { + + override def inputType: AbstractDataType = NumericType + override def symbol: String = "/" override def decimalMethod: String = "$div" - override def nullable: Boolean = true override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) - protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForNumericExpr(t, "operator " + symbol) - private lazy val div: (Any, Any) => Any = dataType match { case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div case it: IntegralType => it.integral.asInstanceOf[Integral[Any]].quot @@ -215,17 +253,16 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic } case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic { + + override def inputType: AbstractDataType = NumericType + override def symbol: String = "%" override def decimalMethod: String = "remainder" - override def nullable: Boolean = true override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) - protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForNumericExpr(t, "operator " + symbol) - private lazy val integral = dataType match { case i: IntegralType => i.integral.asInstanceOf[Integral[Any]] case i: FractionalType => i.asIntegral.asInstanceOf[Integral[Any]] @@ -281,10 +318,11 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet } case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { - override def nullable: Boolean = left.nullable && right.nullable + // TODO: Remove MaxOf and MinOf, and replace its usage with Greatest and Least. 
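Two small observations on the arithmetic changes here. First, the ByteType | ShortType branches in the new Add and Subtract code generation exist because Java promotes byte and short operands to int before arithmetic, so the generated expression has to be cast back, e.g. (byte)(a + b), to fit the column's primitive slot; the Pmod operator introduced just below repeats the same trick. Second, a quick illustration of why Pmod differs from the % operator on negative inputs (plain Scala, mirroring the pmod helpers in the patch):

    // pmod always returns a value in [0, n) for positive n, unlike Java/Scala's %.
    def pmod(a: Int, n: Int): Int = { val r = a % n; if (r < 0) (r + n) % n else r }

    -7 % 3        // == -1  (the remainder keeps the sign of the dividend)
    pmod(-7, 3)   // ==  2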
- protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForOrderingExpr(t, "function maxOf") + override def inputType: AbstractDataType = TypeCollection.Ordered + + override def nullable: Boolean = left.nullable && right.nullable private lazy val ordering = TypeUtils.getOrdering(dataType) @@ -331,14 +369,14 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { } override def symbol: String = "max" - override def prettyName: String = symbol } case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { - override def nullable: Boolean = left.nullable && right.nullable + // TODO: Remove MaxOf and MinOf, and replace its usage with Greatest and Least. - protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForOrderingExpr(t, "function minOf") + override def inputType: AbstractDataType = TypeCollection.Ordered + + override def nullable: Boolean = left.nullable && right.nullable private lazy val ordering = TypeUtils.getOrdering(dataType) @@ -385,5 +423,98 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { } override def symbol: String = "min" - override def prettyName: String = symbol +} + +case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { + + override def toString: String = s"pmod($left, $right)" + + override def symbol: String = "pmod" + + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForNumericExpr(t, "pmod") + + override def inputType: AbstractDataType = NumericType + + protected override def nullSafeEval(left: Any, right: Any) = + dataType match { + case IntegerType => pmod(left.asInstanceOf[Int], right.asInstanceOf[Int]) + case LongType => pmod(left.asInstanceOf[Long], right.asInstanceOf[Long]) + case ShortType => pmod(left.asInstanceOf[Short], right.asInstanceOf[Short]) + case ByteType => pmod(left.asInstanceOf[Byte], right.asInstanceOf[Byte]) + case FloatType => pmod(left.asInstanceOf[Float], right.asInstanceOf[Float]) + case DoubleType => pmod(left.asInstanceOf[Double], right.asInstanceOf[Double]) + case _: DecimalType => pmod(left.asInstanceOf[Decimal], right.asInstanceOf[Decimal]) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (eval1, eval2) => { + dataType match { + case dt: DecimalType => + val decimalAdd = "$plus" + s""" + ${ctx.javaType(dataType)} r = $eval1.remainder($eval2); + if (r.compare(new org.apache.spark.sql.types.Decimal().set(0)) < 0) { + ${ev.primitive} = (r.$decimalAdd($eval2)).remainder($eval2); + } else { + ${ev.primitive} = r; + } + """ + // byte and short are casted into int when add, minus, times or divide + case ByteType | ShortType => + s""" + ${ctx.javaType(dataType)} r = (${ctx.javaType(dataType)})($eval1 % $eval2); + if (r < 0) { + ${ev.primitive} = (${ctx.javaType(dataType)})((r + $eval2) % $eval2); + } else { + ${ev.primitive} = r; + } + """ + case _ => + s""" + ${ctx.javaType(dataType)} r = $eval1 % $eval2; + if (r < 0) { + ${ev.primitive} = (r + $eval2) % $eval2; + } else { + ${ev.primitive} = r; + } + """ + } + }) + } + + private def pmod(a: Int, n: Int): Int = { + val r = a % n + if (r < 0) {(r + n) % n} else r + } + + private def pmod(a: Long, n: Long): Long = { + val r = a % n + if (r < 0) {(r + n) % n} else r + } + + private def pmod(a: Byte, n: Byte): Byte = { + val r = a % n + if (r < 0) {((r + n) % n).toByte} else r.toByte + } + + private def pmod(a: Double, n: Double): Double = { + val r = a % n + if (r < 0) {(r + n) % n} 
else r + } + + private def pmod(a: Short, n: Short): Short = { + val r = a % n + if (r < 0) {((r + n) % n).toShort} else r.toShort + } + + private def pmod(a: Float, n: Float): Float = { + val r = a % n + if (r < 0) {(r + n) % n} else r + } + + private def pmod(a: Decimal, n: Decimal): Decimal = { + val r = a % n + if (r.compare(Decimal(0)) < 0) {(r + n) % n} else r + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala index 2d47124d247e7..a1e48c4210877 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala @@ -17,9 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -29,10 +27,10 @@ import org.apache.spark.sql.types._ * Code generation inherited from BinaryArithmetic. */ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic { - override def symbol: String = "&" - protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForBitwiseExpr(t, "operator " + symbol) + override def inputType: AbstractDataType = IntegralType + + override def symbol: String = "&" private lazy val and: (Any, Any) => Any = dataType match { case ByteType => @@ -54,10 +52,10 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme * Code generation inherited from BinaryArithmetic. */ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic { - override def symbol: String = "|" - protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForBitwiseExpr(t, "operator " + symbol) + override def inputType: AbstractDataType = IntegralType + + override def symbol: String = "|" private lazy val or: (Any, Any) => Any = dataType match { case ByteType => @@ -79,10 +77,10 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet * Code generation inherited from BinaryArithmetic. */ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic { - override def symbol: String = "^" - protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForBitwiseExpr(t, "operator " + symbol) + override def inputType: AbstractDataType = IntegralType + + override def symbol: String = "^" private lazy val xor: (Any, Any) => Any = dataType match { case ByteType => @@ -101,11 +99,13 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme /** * A function that calculates bitwise not(~) of a number. 
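The pmod helpers above implement a positive modulus: unlike the % operator, whose sign follows the dividend, pmod shifts a negative remainder into [0, n) for a positive modulus n. A small standalone sketch of that contract (the helper mirrors the Int overload in the patch):

def pmod(a: Int, n: Int): Int = {
  val r = a % n                  // may be negative when a is negative
  if (r < 0) (r + n) % n else r  // shift into [0, n)
}
// -7 % 3      == -1
// pmod(-7, 3) ==  2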
*/ -case class BitwiseNot(child: Expression) extends UnaryArithmetic { - override def toString: String = s"~$child" +case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInputTypes { - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForBitwiseExpr(child.dataType, "operator ~") + override def inputTypes: Seq[AbstractDataType] = Seq(IntegralType) + + override def dataType: DataType = child.dataType + + override def toString: String = s"~$child" private lazy val not: (Any) => Any = dataType match { case ByteType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 9f6329bbda4ec..48225e1574600 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -24,9 +24,10 @@ import com.google.common.cache.{CacheBuilder, CacheLoader} import org.codehaus.janino.ClassBodyEvaluator import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.unsafe.types._ // These classes are here to avoid issues with serialization and integration with quasiquotes. @@ -56,9 +57,29 @@ class CodeGenContext { */ val references: mutable.ArrayBuffer[Expression] = new mutable.ArrayBuffer[Expression]() - val stringType: String = classOf[UTF8String].getName - val decimalType: String = classOf[Decimal].getName + /** + * Holding expressions' mutable states like `MonotonicallyIncreasingID.count` as a + * 3-tuple: java type, variable name, code to init it. + * As an example, ("int", "count", "count = 0;") will produce code: + * {{{ + * private int count; + * }}} + * as a member variable, and add + * {{{ + * count = 0; + * }}} + * to the constructor. + * + * They will be kept as member variables in generated classes like `SpecificProjection`. 
+ */ + val mutableStates: mutable.ArrayBuffer[(String, String, String)] = + mutable.ArrayBuffer.empty[(String, String, String)] + def addMutableState(javaType: String, variableName: String, initCode: String): Unit = { + mutableStates += ((javaType, variableName, initCode)) + } + + final val intervalType: String = classOf[Interval].getName final val JAVA_BOOLEAN = "boolean" final val JAVA_BYTE = "byte" final val JAVA_SHORT = "short" @@ -84,10 +105,11 @@ class CodeGenContext { */ def getColumn(row: String, dataType: DataType, ordinal: Int): String = { val jt = javaType(dataType) - if (isPrimitiveType(jt)) { - s"$row.get${primitiveTypeName(jt)}($ordinal)" - } else { - s"($jt)$row.apply($ordinal)" + dataType match { + case _ if isPrimitiveType(jt) => s"$row.get${primitiveTypeName(jt)}($ordinal)" + case StringType => s"$row.getUTF8String($ordinal)" + case BinaryType => s"$row.getBinary($ordinal)" + case _ => s"($jt)$row.apply($ordinal)" } } @@ -124,9 +146,10 @@ class CodeGenContext { case LongType | TimestampType => JAVA_LONG case FloatType => JAVA_FLOAT case DoubleType => JAVA_DOUBLE - case dt: DecimalType => decimalType + case dt: DecimalType => "Decimal" case BinaryType => "byte[]" - case StringType => stringType + case StringType => "UTF8String" + case IntervalType => intervalType case _: StructType => "InternalRow" case _: ArrayType => s"scala.collection.Seq" case _: MapType => s"scala.collection.Map" @@ -172,6 +195,8 @@ class CodeGenContext { */ def genEqual(dataType: DataType, c1: String, c2: String): String = dataType match { case BinaryType => s"java.util.Arrays.equals($c1, $c2)" + case FloatType => s"(java.lang.Float.isNaN($c1) && java.lang.Float.isNaN($c2)) || $c1 == $c2" + case DoubleType => s"(java.lang.Double.isNaN($c1) && java.lang.Double.isNaN($c2)) || $c1 == $c2" case dt: DataType if isPrimitiveType(dt) => s"$c1 == $c2" case other => s"$c1.equals($c2)" } @@ -182,6 +207,8 @@ class CodeGenContext { def genComp(dataType: DataType, c1: String, c2: String): String = dataType match { // java boolean doesn't support > or < operator case BooleanType => s"($c1 == $c2 ? 0 : ($c1 ? 1 : -1))" + case DoubleType => s"org.apache.spark.util.Utils.nanSafeCompareDoubles($c1, $c2)" + case FloatType => s"org.apache.spark.util.Utils.nanSafeCompareFloats($c1, $c2)" // use c1 - c2 may overflow case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)" case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)" @@ -203,7 +230,10 @@ class CodeGenContext { def isPrimitiveType(dt: DataType): Boolean = isPrimitiveType(javaType(dt)) } - +/** + * A wrapper for generated class, defines a `generate` method so that we can pass extra objects + * into generated class. + */ abstract class GeneratedClass { def generate(expressions: Array[Expression]): Any } @@ -219,6 +249,16 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin protected val mutableRowType: String = classOf[MutableRow].getName protected val genericMutableRowType: String = classOf[GenericMutableRow].getName + protected def declareMutableStates(ctx: CodeGenContext) = { + ctx.mutableStates.map { case (javaType, variableName, _) => + s"private $javaType $variableName;" + }.mkString("\n ") + } + + protected def initMutableStates(ctx: CodeGenContext) = { + ctx.mutableStates.map(_._3).mkString("\n ") + } + /** * Generates a class for a given input expression. Called when there is not cached code * already available. 
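The mutable-state plumbing added above works in two steps: an expression calls addMutableState with a (java type, variable name, init code) triple while generating its code, and declareMutableStates/initMutableStates later splice those triples into the generated class as member fields and constructor statements. A rough standalone sketch with a stripped-down context (MiniCodeGenContext is hypothetical, not the real CodeGenContext):

import scala.collection.mutable

class MiniCodeGenContext {
  // same 3-tuple shape as described in the comment above: (java type, name, init code)
  val mutableStates = mutable.ArrayBuffer.empty[(String, String, String)]
  def addMutableState(javaType: String, name: String, init: String): Unit =
    mutableStates += ((javaType, name, init))
}

val ctx = new MiniCodeGenContext
ctx.addMutableState("int", "count", "count = 0;")

// What the declare/init helpers would contribute to the generated class body:
val fieldDecls = ctx.mutableStates.map { case (jt, name, _) => s"private $jt $name;" }.mkString("\n")
val ctorInit   = ctx.mutableStates.map(_._3).mkString("\n")
// fieldDecls == "private int count;"   (member variable)
// ctorInit   == "count = 0;"           (constructor statement)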
@@ -247,14 +287,20 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin private[this] def doCompile(code: String): GeneratedClass = { val evaluator = new ClassBodyEvaluator() evaluator.setParentClassLoader(getClass.getClassLoader) - evaluator.setDefaultImports(Array("org.apache.spark.sql.catalyst.InternalRow")) + evaluator.setDefaultImports(Array( + classOf[InternalRow].getName, + classOf[UnsafeRow].getName, + classOf[UTF8String].getName, + classOf[Decimal].getName + )) evaluator.setExtendedClass(classOf[GeneratedClass]) try { evaluator.cook(code) } catch { case e: Exception => - logError(s"failed to compile:\n $code", e) - throw e + val msg = s"failed to compile:\n $code" + logError(msg, e) + throw new Exception(msg, e) } evaluator.getClazz().newInstance().asInstanceOf[GeneratedClass] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala new file mode 100644 index 0000000000000..6b187f05604fd --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.sql.catalyst.expressions.Expression + +/** + * A trait that can be used to provide a fallback mode for expression code generation. 
+ */ +trait CodegenFallback extends Expression { + + protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + ctx.references += this + val objectTerm = ctx.freshName("obj") + s""" + /* expression: ${this} */ + Object $objectTerm = expressions[${ctx.references.size - 1}].eval(i); + boolean ${ev.isNull} = $objectTerm == null; + ${ctx.javaType(this.dataType)} ${ev.primitive} = ${ctx.defaultValue(this.dataType)}; + if (!${ev.isNull}) { + ${ev.primitive} = (${ctx.boxedType(this.dataType)}) $objectTerm; + } + """ + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index addb8023d9c0b..d838268f46956 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -18,6 +18,9 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp + +import scala.collection.mutable.ArrayBuffer // MutableProjection is not accessible in Java abstract class BaseMutableProjection extends MutableProjection @@ -36,16 +39,48 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu protected def create(expressions: Seq[Expression]): (() => MutableProjection) = { val ctx = newCodeGenContext() - val projectionCode = expressions.zipWithIndex.map { case (e, i) => - val evaluationCode = e.gen(ctx) - evaluationCode.code + - s""" - if(${evaluationCode.isNull}) - mutableRow.setNullAt($i); - else - ${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.primitive)}; - """ - }.mkString("\n") + val projectionCode = expressions.zipWithIndex.map { + case (NoOp, _) => "" + case (e, i) => + val evaluationCode = e.gen(ctx) + evaluationCode.code + + s""" + if(${evaluationCode.isNull}) + mutableRow.setNullAt($i); + else + ${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.primitive)}; + """ + } + // collect projections into blocks as function has 64kb codesize limit in JVM + val projectionBlocks = new ArrayBuffer[String]() + val blockBuilder = new StringBuilder() + for (projection <- projectionCode) { + if (blockBuilder.length > 16 * 1000) { + projectionBlocks.append(blockBuilder.toString()) + blockBuilder.clear() + } + blockBuilder.append(projection) + } + projectionBlocks.append(blockBuilder.toString()) + + val (projectionFuns, projectionCalls) = { + // inline execution if codesize limit was not broken + if (projectionBlocks.length == 1) { + ("", projectionBlocks.head) + } else { + ( + projectionBlocks.zipWithIndex.map { case (body, i) => + s""" + |private void apply$i(InternalRow i) { + | $body + |} + """.stripMargin + }.mkString, + projectionBlocks.indices.map(i => s"apply$i(i);").mkString("\n") + ) + } + } + val code = s""" public Object generate($exprType[] expr) { return new SpecificProjection(expr); @@ -53,12 +88,14 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu class SpecificProjection extends ${classOf[BaseMutableProjection].getName} { - private $exprType[] expressions = null; - private $mutableRowType mutableRow = null; + private $exprType[] expressions; + private $mutableRowType mutableRow; + ${declareMutableStates(ctx)} public 
SpecificProjection($exprType[] expr) { expressions = expr; mutableRow = new $genericMutableRowType(${expressions.size}); + ${initMutableStates(ctx)} } public ${classOf[BaseMutableProjection].getName} target($mutableRowType row) { @@ -71,9 +108,11 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu return (InternalRow) mutableRow; } + $projectionFuns + public Object apply(Object _i) { InternalRow i = (InternalRow) _i; - $projectionCode + $projectionCalls return mutableRow; } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index d05dfc108e63a..2e6f9e204d813 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -46,30 +46,44 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR protected def create(ordering: Seq[SortOrder]): Ordering[InternalRow] = { val ctx = newCodeGenContext() - val comparisons = ordering.zipWithIndex.map { case (order, i) => - val evalA = order.child.gen(ctx) - val evalB = order.child.gen(ctx) + val comparisons = ordering.map { order => + val eval = order.child.gen(ctx) val asc = order.direction == Ascending + val isNullA = ctx.freshName("isNullA") + val primitiveA = ctx.freshName("primitiveA") + val isNullB = ctx.freshName("isNullB") + val primitiveB = ctx.freshName("primitiveB") s""" i = a; - ${evalA.code} + boolean $isNullA; + ${ctx.javaType(order.child.dataType)} $primitiveA; + { + ${eval.code} + $isNullA = ${eval.isNull}; + $primitiveA = ${eval.primitive}; + } i = b; - ${evalB.code} - if (${evalA.isNull} && ${evalB.isNull}) { + boolean $isNullB; + ${ctx.javaType(order.child.dataType)} $primitiveB; + { + ${eval.code} + $isNullB = ${eval.isNull}; + $primitiveB = ${eval.primitive}; + } + if ($isNullA && $isNullB) { // Nothing - } else if (${evalA.isNull}) { + } else if ($isNullA) { return ${if (order.direction == Ascending) "-1" else "1"}; - } else if (${evalB.isNull}) { + } else if ($isNullB) { return ${if (order.direction == Ascending) "1" else "-1"}; } else { - int comp = ${ctx.genComp(order.child.dataType, evalA.primitive, evalB.primitive)}; + int comp = ${ctx.genComp(order.child.dataType, primitiveA, primitiveB)}; if (comp != 0) { return ${if (asc) "comp" else "-comp"}; } } """ }.mkString("\n") - val code = s""" public SpecificOrdering generate($exprType[] expr) { return new SpecificOrdering(expr); @@ -77,10 +91,12 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR class SpecificOrdering extends ${classOf[BaseOrdering].getName} { - private $exprType[] expressions = null; + private $exprType[] expressions; + ${declareMutableStates(ctx)} public SpecificOrdering($exprType[] expr) { expressions = expr; + ${initMutableStates(ctx)} } @Override diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index 274a42cb69087..1dda5992c3654 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -47,8 +47,10 @@ object 
GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool class SpecificPredicate extends ${classOf[Predicate].getName} { private final $exprType[] expressions; + ${declareMutableStates(ctx)} public SpecificPredicate($exprType[] expr) { expressions = expr; + ${initMutableStates(ctx)} } @Override diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 3c7ee9cc16599..405d6b0e3bc76 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.types._ /** * Java can not access Projection (in package object) */ -abstract class BaseProject extends Projection {} +abstract class BaseProjection extends Projection {} /** * Generates bytecode that produces a new [[InternalRow]] object based on a fixed set of input @@ -156,74 +156,76 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { return new SpecificProjection(expr); } - class SpecificProjection extends ${classOf[BaseProject].getName} { - private $exprType[] expressions = null; + class SpecificProjection extends ${classOf[BaseProjection].getName} { + private $exprType[] expressions; + ${declareMutableStates(ctx)} public SpecificProjection($exprType[] expr) { expressions = expr; + ${initMutableStates(ctx)} } @Override public Object apply(Object r) { - return new SpecificRow(expressions, (InternalRow) r); + return new SpecificRow((InternalRow) r); } - } - final class SpecificRow extends ${classOf[MutableRow].getName} { + final class SpecificRow extends ${classOf[MutableRow].getName} { - $columns + $columns - public SpecificRow($exprType[] expressions, InternalRow i) { - $initColumns - } + public SpecificRow(InternalRow i) { + $initColumns + } - public int length() { return ${expressions.length};} - protected boolean[] nullBits = new boolean[${expressions.length}]; - public void setNullAt(int i) { nullBits[i] = true; } - public boolean isNullAt(int i) { return nullBits[i]; } + public int length() { return ${expressions.length};} + protected boolean[] nullBits = new boolean[${expressions.length}]; + public void setNullAt(int i) { nullBits[i] = true; } + public boolean isNullAt(int i) { return nullBits[i]; } - public Object get(int i) { - if (isNullAt(i)) return null; - switch (i) { - $getCases + public Object get(int i) { + if (isNullAt(i)) return null; + switch (i) { + $getCases + } + return null; } - return null; - } - public void update(int i, Object value) { - if (value == null) { - setNullAt(i); - return; + public void update(int i, Object value) { + if (value == null) { + setNullAt(i); + return; + } + nullBits[i] = false; + switch (i) { + $updateCases + } } - nullBits[i] = false; - switch (i) { - $updateCases + $specificAccessorFunctions + $specificMutatorFunctions + + @Override + public int hashCode() { + int result = 37; + $hashUpdates + return result; } - } - $specificAccessorFunctions - $specificMutatorFunctions - @Override - public int hashCode() { - int result = 37; - $hashUpdates - return result; - } - - @Override - public boolean equals(Object other) { - if (other instanceof SpecificRow) { - SpecificRow row = (SpecificRow) other; - $columnChecks - return true; + @Override + public boolean equals(Object other) { + if 
(other instanceof SpecificRow) { + SpecificRow row = (SpecificRow) other; + $columnChecks + return true; + } + return super.equals(other); } - return super.equals(other); - } - @Override - public InternalRow copy() { - Object[] arr = new Object[${expressions.length}]; - ${copyColumns} - return new ${classOf[GenericInternalRow].getName}(arr); + @Override + public InternalRow copy() { + Object[] arr = new Object[${expressions.length}]; + ${copyColumns} + return new ${classOf[GenericInternalRow].getName}(arr); + } } } """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala new file mode 100644 index 0000000000000..d65e5c38ebf5c --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.{NullType, BinaryType, StringType} + + +/** + * Generates a [[Projection]] that returns an [[UnsafeRow]]. + * + * It generates the code for all the expressions, compute the total length for all the columns + * (can be accessed via variables), and then copy the data into a scratch buffer space in the + * form of UnsafeRow (the scratch buffer will grow as needed). + * + * Note: The returned UnsafeRow will be pointed to a scratch buffer inside the projection. + */ +object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafeProjection] { + + protected def canonicalize(in: Seq[Expression]): Seq[Expression] = + in.map(ExpressionCanonicalizer.execute) + + protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = + in.map(BindReferences.bindReference(_, inputSchema)) + + protected def create(expressions: Seq[Expression]): UnsafeProjection = { + val ctx = newCodeGenContext() + val exprs = expressions.map(_.gen(ctx)) + val allExprs = exprs.map(_.code).mkString("\n") + val fixedSize = 8 * exprs.length + UnsafeRow.calculateBitSetWidthInBytes(exprs.length) + val stringWriter = "org.apache.spark.sql.catalyst.expressions.StringUnsafeColumnWriter" + val binaryWriter = "org.apache.spark.sql.catalyst.expressions.BinaryUnsafeColumnWriter" + val additionalSize = expressions.zipWithIndex.map { case (e, i) => + e.dataType match { + case StringType => + s" + (${exprs(i).isNull} ? 0 : $stringWriter.getSize(${exprs(i).primitive}))" + case BinaryType => + s" + (${exprs(i).isNull} ? 
0 : $binaryWriter.getSize(${exprs(i).primitive}))" + case _ => "" + } + }.mkString("") + + val writers = expressions.zipWithIndex.map { case (e, i) => + val update = e.dataType match { + case dt if ctx.isPrimitiveType(dt) => + s"${ctx.setColumn("target", dt, i, exprs(i).primitive)}" + case StringType => + s"cursor += $stringWriter.write(target, ${exprs(i).primitive}, $i, cursor)" + case BinaryType => + s"cursor += $binaryWriter.write(target, ${exprs(i).primitive}, $i, cursor)" + case NullType => "" + case _ => + throw new UnsupportedOperationException(s"Not supported DataType: ${e.dataType}") + } + s"""if (${exprs(i).isNull}) { + target.setNullAt($i); + } else { + $update; + }""" + }.mkString("\n ") + + val code = s""" + private $exprType[] expressions; + + public Object generate($exprType[] expr) { + this.expressions = expr; + return new SpecificProjection(); + } + + class SpecificProjection extends ${classOf[UnsafeProjection].getName} { + + private UnsafeRow target = new UnsafeRow(); + private byte[] buffer = new byte[64]; + ${declareMutableStates(ctx)} + + public SpecificProjection() { + ${initMutableStates(ctx)} + } + + // Scala.Function1 need this + public Object apply(Object row) { + return apply((InternalRow) row); + } + + public UnsafeRow apply(InternalRow i) { + $allExprs + + // additionalSize had '+' in the beginning + int numBytes = $fixedSize $additionalSize; + if (numBytes > buffer.length) { + buffer = new byte[numBytes]; + } + target.pointTo(buffer, org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET, + ${expressions.size}, numBytes); + int cursor = $fixedSize; + $writers + return target; + } + } + """ + + logDebug(s"code for ${expressions.mkString(",")}:\n$code") + + val c = compile(code) + c.generate(ctx.references.toArray).asInstanceOf[UnsafeProjection] + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala new file mode 100644 index 0000000000000..2d92dcf23a86e --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.types._ + +/** + * Given an array or map, returns its size. 
+ */ +case class Size(child: Expression) extends UnaryExpression with ExpectsInputTypes { + override def dataType: DataType = IntegerType + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(ArrayType, MapType)) + + override def nullSafeEval(value: Any): Int = child.dataType match { + case ArrayType(_, _) => value.asInstanceOf[Seq[Any]].size + case MapType(_, _, _) => value.asInstanceOf[Map[Any, Any]].size + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, c => s"${ev.primitive} = ($c).size();") + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index d1e4c458864f1..20b1eaab8e303 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -17,8 +17,11 @@ package org.apache.spark.sql.catalyst.expressions +import scala.collection.mutable + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -44,12 +47,29 @@ case class CreateArray(children: Seq[Expression]) extends Expression { children.map(_.eval(input)) } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val arraySeqClass = classOf[mutable.ArraySeq[Any]].getName + s""" + boolean ${ev.isNull} = false; + $arraySeqClass ${ev.primitive} = new $arraySeqClass(${children.size}); + """ + + children.zipWithIndex.map { case (e, i) => + val eval = e.gen(ctx) + eval.code + s""" + if (${eval.isNull}) { + ${ev.primitive}.update($i, null); + } else { + ${ev.primitive}.update($i, ${eval.primitive}); + } + """ + }.mkString("\n") + } + override def prettyName: String = "array" } /** * Returns a Row containing the evaluation of all children expressions. - * TODO: [[CreateStruct]] does not support codegen. 
*/ case class CreateStruct(children: Seq[Expression]) extends Expression { @@ -75,6 +95,24 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { InternalRow(children.map(_.eval(input)): _*) } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val rowClass = classOf[GenericMutableRow].getName + s""" + boolean ${ev.isNull} = false; + final $rowClass ${ev.primitive} = new $rowClass(${children.size}); + """ + + children.zipWithIndex.map { case (e, i) => + val eval = e.gen(ctx) + eval.code + s""" + if (${eval.isNull}) { + ${ev.primitive}.update($i, null); + } else { + ${ev.primitive}.update($i, ${eval.primitive}); + } + """ + }.mkString("\n") + } + override def prettyName: String = "struct" } @@ -103,11 +141,11 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { override def checkInputDataTypes(): TypeCheckResult = { if (children.size % 2 != 0) { - TypeCheckResult.TypeCheckFailure("CreateNamedStruct expects an even number of arguments.") + TypeCheckResult.TypeCheckFailure(s"$prettyName expects an even number of arguments.") } else { val invalidNames = nameExprs.filterNot(e => e.foldable && e.dataType == StringType && !nullable) - if (invalidNames.size != 0) { + if (invalidNames.nonEmpty) { TypeCheckResult.TypeCheckFailure( s"Odd position only allow foldable and not-null StringType expressions, got :" + s" ${invalidNames.mkString(",")}") @@ -121,5 +159,23 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { InternalRow(valExprs.map(_.eval(input)): _*) } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val rowClass = classOf[GenericMutableRow].getName + s""" + boolean ${ev.isNull} = false; + final $rowClass ${ev.primitive} = new $rowClass(${valExprs.size}); + """ + + valExprs.zipWithIndex.map { case (e, i) => + val eval = e.gen(ctx) + eval.code + s""" + if (${eval.isNull}) { + ${ev.primitive}.update($i, null); + } else { + ${ev.primitive}.update($i, ${eval.primitive}); + } + """ + }.mkString("\n") + } + override def prettyName: String = "named_struct" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala index e6a705fb8055b..15b33da884dcb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala @@ -35,8 +35,8 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi TypeCheckResult.TypeCheckFailure( s"type of predicate expression in If should be boolean, not ${predicate.dataType}") } else if (trueValue.dataType != falseValue.dataType) { - TypeCheckResult.TypeCheckFailure( - s"differing types in If (${trueValue.dataType} and ${falseValue.dataType}).") + TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " + + s"(${trueValue.dataType.simpleString} and ${falseValue.dataType.simpleString}).") } else { TypeCheckResult.TypeCheckSuccess } @@ -77,7 +77,6 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi } trait CaseWhenLike extends Expression { - self: Product => // Note that `branches` are considered in consecutive pairs (cond, val), and the optional last // element is the value for the default catch-all case (if provided). 
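The consecutive-pairs convention noted above is easiest to see with plain values standing in for expressions; sliding(2, 2) is the same grouping the patch itself uses for CASE branches (illustrative sketch only):

val branches  = Seq("cond1", "val1", "cond2", "val2", "elseVal")
val whenPairs = branches.sliding(2, 2).collect { case Seq(c, v) => (c, v) }.toList
// whenPairs == List((cond1,val1), (cond2,val2)); the odd element left over
// ("elseVal") is the optional default, used when no condition matches.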
@@ -230,24 +229,31 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW } } + private def evalElse(input: InternalRow): Any = { + if (branchesArr.length % 2 == 0) { + null + } else { + branchesArr(branchesArr.length - 1).eval(input) + } + } + /** Written in imperative fashion for performance considerations. */ override def eval(input: InternalRow): Any = { val evaluatedKey = key.eval(input) - val len = branchesArr.length - var i = 0 - // If all branches fail and an elseVal is not provided, the whole statement - // defaults to null, according to Hive's semantics. - while (i < len - 1) { - if (equalNullSafe(evaluatedKey, branchesArr(i).eval(input))) { - return branchesArr(i + 1).eval(input) + // If key is null, we can just return the else part or null if there is no else. + // If key is not null but doesn't match any when part, we need to return + // the else part or null if there is no else, according to Hive's semantics. + if (evaluatedKey != null) { + val len = branchesArr.length + var i = 0 + while (i < len - 1) { + if (evaluatedKey == branchesArr(i).eval(input)) { + return branchesArr(i + 1).eval(input) + } + i += 2 } - i += 2 - } - var res: Any = null - if (i == len - 1) { - res = branchesArr(i).eval(input) } - return res + evalElse(input) } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { @@ -261,9 +267,7 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW s""" if (!$got) { ${cond.code} - if (${keyEval.isNull} && ${cond.isNull} || - !${keyEval.isNull} && !${cond.isNull} - && ${ctx.genEqual(key.dataType, keyEval.primitive, cond.primitive)}) { + if (!${cond.isNull} && ${ctx.genEqual(key.dataType, keyEval.primitive, cond.primitive)}) { $got = true; ${res.code} ${ev.isNull} = ${res.isNull}; @@ -291,21 +295,13 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW boolean ${ev.isNull} = true; ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; ${keyEval.code} - $cases + if (!${keyEval.isNull}) { + $cases + } $other """ } - private def equalNullSafe(l: Any, r: Any) = { - if (l == null && r == null) { - true - } else if (l == null || r == null) { - false - } else { - l == r - } - } - override def toString: String = { s"CASE $key" + branches.sliding(2, 2).map { case Seq(cond, value) => s" WHEN $cond THEN $value" @@ -314,7 +310,11 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW } } -case class Least(children: Expression*) extends Expression { +/** + * A function that returns the least value of all parameters, skipping null values. + * It takes at least 2 parameters, and returns null iff all parameters are null. + */ +case class Least(children: Seq[Expression]) extends Expression { require(children.length > 1, "LEAST requires at least 2 arguments, got " + children.length) override def nullable: Boolean = children.forall(_.nullable) @@ -359,12 +359,16 @@ case class Least(children: Expression*) extends Expression { ${evalChildren.map(_.code).mkString("\n")} boolean ${ev.isNull} = true; ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - ${(0 until children.length).map(updateEval).mkString("\n")} + ${children.indices.map(updateEval).mkString("\n")} """ } } -case class Greatest(children: Expression*) extends Expression { +/** + * A function that returns the greatest value of all parameters, skipping null values. + * It takes at least 2 parameters, and returns null iff all parameters are null. 
+ */ +case class Greatest(children: Seq[Expression]) extends Expression { require(children.length > 1, "GREATEST requires at least 2 arguments, got " + children.length) override def nullable: Boolean = children.forall(_.nullable) @@ -409,7 +413,7 @@ case class Greatest(children: Expression*) extends Expression { ${evalChildren.map(_.code).mkString("\n")} boolean ${ev.isNull} = true; ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - ${(0 until children.length).map(updateEval).mkString("\n")} + ${children.indices.map(updateEval).mkString("\n")} """ } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala index dd5ec330a771b..9e55f0546e123 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala @@ -17,9 +17,16 @@ package org.apache.spark.sql.catalyst.expressions +import java.sql.Date +import java.text.SimpleDateFormat +import java.util.{Calendar, TimeZone} + +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String /** * Returns the current date at the start of query evaluation. @@ -27,7 +34,7 @@ import org.apache.spark.sql.types._ * * There is no code generation since this expression should get constant folded by the optimizer. */ -case class CurrentDate() extends LeafExpression { +case class CurrentDate() extends LeafExpression with CodegenFallback { override def foldable: Boolean = true override def nullable: Boolean = false @@ -44,7 +51,7 @@ case class CurrentDate() extends LeafExpression { * * There is no code generation since this expression should get constant folded by the optimizer. 
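Least and Greatest, as documented above, skip null arguments rather than propagate them, and only return null when every argument is null. A sketch of that semantics with Option standing in for SQL NULL (the least helper is hypothetical, not the Expression implementation):

def least(values: Seq[Option[Int]]): Option[Int] = values.flatten match {
  case Nil      => None               // null iff all inputs are null
  case nonNulls => Some(nonNulls.min)
}
// least(Seq(Some(3), None, Some(1))) == Some(1)   // the null is skipped
// least(Seq(None, None))             == None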
*/ -case class CurrentTimestamp() extends LeafExpression { +case class CurrentTimestamp() extends LeafExpression with CodegenFallback { override def foldable: Boolean = true override def nullable: Boolean = false @@ -54,3 +61,207 @@ case class CurrentTimestamp() extends LeafExpression { System.currentTimeMillis() * 1000L } } + +case class Hour(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType) + + override def dataType: DataType = IntegerType + + override protected def nullSafeEval(timestamp: Any): Any = { + DateTimeUtils.getHours(timestamp.asInstanceOf[Long]) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + defineCodeGen(ctx, ev, (c) => + s"""$dtu.getHours($c)""" + ) + } +} + +case class Minute(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType) + + override def dataType: DataType = IntegerType + + override protected def nullSafeEval(timestamp: Any): Any = { + DateTimeUtils.getMinutes(timestamp.asInstanceOf[Long]) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + defineCodeGen(ctx, ev, (c) => + s"""$dtu.getMinutes($c)""" + ) + } +} + +case class Second(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType) + + override def dataType: DataType = IntegerType + + override protected def nullSafeEval(timestamp: Any): Any = { + DateTimeUtils.getSeconds(timestamp.asInstanceOf[Long]) + } + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + defineCodeGen(ctx, ev, (c) => + s"""$dtu.getSeconds($c)""" + ) + } +} + +case class DayOfYear(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(DateType) + + override def dataType: DataType = IntegerType + + override protected def nullSafeEval(date: Any): Any = { + DateTimeUtils.getDayInYear(date.asInstanceOf[Int]) + } + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + defineCodeGen(ctx, ev, (c) => + s"""$dtu.getDayInYear($c)""" + ) + } +} + + +case class Year(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(DateType) + + override def dataType: DataType = IntegerType + + override protected def nullSafeEval(date: Any): Any = { + DateTimeUtils.getYear(date.asInstanceOf[Int]) + } + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + defineCodeGen(ctx, ev, c => + s"""$dtu.getYear($c)""" + ) + } +} + +case class Quarter(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(DateType) + + override def dataType: DataType = IntegerType + + override protected def nullSafeEval(date: Any): Any = { + DateTimeUtils.getQuarter(date.asInstanceOf[Int]) + } + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val 
dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + defineCodeGen(ctx, ev, (c) => + s"""$dtu.getQuarter($c)""" + ) + } +} + +case class Month(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(DateType) + + override def dataType: DataType = IntegerType + + override protected def nullSafeEval(date: Any): Any = { + DateTimeUtils.getMonth(date.asInstanceOf[Int]) + } + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + defineCodeGen(ctx, ev, (c) => + s"""$dtu.getMonth($c)""" + ) + } +} + +case class DayOfMonth(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(DateType) + + override def dataType: DataType = IntegerType + + override protected def nullSafeEval(date: Any): Any = { + DateTimeUtils.getDayOfMonth(date.asInstanceOf[Int]) + } + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + defineCodeGen(ctx, ev, (c) => + s"""$dtu.getDayOfMonth($c)""" + ) + } +} + +case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(DateType) + + override def dataType: DataType = IntegerType + + @transient private lazy val c = { + val c = Calendar.getInstance(TimeZone.getTimeZone("UTC")) + c.setFirstDayOfWeek(Calendar.MONDAY) + c.setMinimalDaysInFirstWeek(4) + c + } + + override protected def nullSafeEval(date: Any): Any = { + c.setTimeInMillis(date.asInstanceOf[Int] * 1000L * 3600L * 24L) + c.get(Calendar.WEEK_OF_YEAR) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (time) => { + val cal = classOf[Calendar].getName + val c = ctx.freshName("cal") + ctx.addMutableState(cal, c, + s""" + $c = $cal.getInstance(java.util.TimeZone.getTimeZone("UTC")); + $c.setFirstDayOfWeek($cal.MONDAY); + $c.setMinimalDaysInFirstWeek(4); + """) + s""" + $c.setTimeInMillis($time * 1000L * 3600L * 24L); + ${ev.primitive} = $c.get($cal.WEEK_OF_YEAR); + """ + }) + } +} + +case class DateFormatClass(left: Expression, right: Expression) extends BinaryExpression + with ImplicitCastInputTypes { + + override def dataType: DataType = StringType + + override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, StringType) + + override def prettyName: String = "date_format" + + override protected def nullSafeEval(timestamp: Any, format: Any): Any = { + val sdf = new SimpleDateFormat(format.toString) + UTF8String.fromString(sdf.format(new Date(timestamp.asInstanceOf[Long] / 1000))) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val sdf = classOf[SimpleDateFormat].getName + defineCodeGen(ctx, ev, (timestamp, format) => { + s"""UTF8String.fromString((new $sdf($format.toString())) + .format(new java.sql.Date($timestamp / 1000)))""" + }) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala index 2fa74b4ffc5da..b9d4736a65e26 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala @@ -54,7 +54,7 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { nullSafeCodeGen(ctx, ev, eval => { s""" - ${ev.primitive} = (new ${ctx.decimalType}()).setOrNull($eval, $precision, $scale); + ${ev.primitive} = (new Decimal()).setOrNull($eval, $precision, $scale); ${ev.isNull} = ${ev.primitive} == null; """ }) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index b68d30a26abd8..2dbcf2830f876 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -20,9 +20,9 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.Map import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, trees} +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.types._ /** @@ -40,13 +40,14 @@ import org.apache.spark.sql.types._ * requested. The attributes produced by this function will be automatically copied anytime rules * result in changes to the Generator or its children. */ -abstract class Generator extends Expression { - self: Product => +trait Generator extends Expression { // TODO ideally we should return the type of ArrayType(StructType), // however, we don't keep the output field names in the Generator. override def dataType: DataType = throw new UnsupportedOperationException + override def foldable: Boolean = false + override def nullable: Boolean = false /** @@ -72,7 +73,7 @@ case class UserDefinedGenerator( elementTypes: Seq[(DataType, Boolean)], function: Row => TraversableOnce[InternalRow], children: Seq[Expression]) - extends Generator { + extends Generator with CodegenFallback { @transient private[this] var inputRow: InterpretedProjection = _ @transient private[this] var convertToScala: (InternalRow) => Row = _ @@ -99,8 +100,9 @@ case class UserDefinedGenerator( /** * Given an input array produces a sequence of rows for each value in the array. 
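Explode, described above, turns one input value into zero or more output rows: one row per array element, or one (key, value) row per map entry. A sketch with plain Scala collections standing in for column values (helper names are hypothetical):

def explodeArray(values: Seq[Any]): Seq[Seq[Any]] = values.map(v => Seq(v))
def explodeMap(m: Map[Any, Any]): Seq[Seq[Any]]   = m.toSeq.map { case (k, v) => Seq(k, v) }
// explodeArray(Seq(1, 2, 3))  -> Seq(Seq(1), Seq(2), Seq(3))   (one row per element)
// explodeMap(Map("a" -> 1))   -> Seq(Seq("a", 1))              (one row per entry)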
*/ -case class Explode(child: Expression) - extends Generator with trees.UnaryNode[Expression] { +case class Explode(child: Expression) extends UnaryExpression with Generator with CodegenFallback { + + override def children: Seq[Expression] = child :: Nil override def checkInputDataTypes(): TypeCheckResult = { if (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType]) { @@ -127,6 +129,4 @@ case class Explode(child: Expression) else inputMap.map { case (k, v) => InternalRow(k, v) } } } - - override def toString: String = s"explode($child)" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 3a7a7ae440036..f25ac32679587 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -21,10 +21,10 @@ import java.sql.{Date, Timestamp} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.unsafe.types._ object Literal { def apply(v: Any): Literal = v match { @@ -42,6 +42,7 @@ object Literal { case t: Timestamp => Literal(DateTimeUtils.fromJavaTimestamp(t), TimestampType) case d: Date => Literal(DateTimeUtils.fromJavaDate(d), DateType) case a: Array[Byte] => Literal(a, BinaryType) + case i: Interval => Literal(i, IntervalType) case null => Literal(null, NullType) case _ => throw new RuntimeException("Unsupported literal type " + v.getClass + " " + v) @@ -74,7 +75,8 @@ object IntegerLiteral { /** * In order to do type checking, use Literal.create() instead of constructor */ -case class Literal protected (value: Any, dataType: DataType) extends LeafExpression { +case class Literal protected (value: Any, dataType: DataType) + extends LeafExpression with CodegenFallback { override def foldable: Boolean = true override def nullable: Boolean = value == null @@ -141,7 +143,7 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres // TODO: Specialize case class MutableLiteral(var value: Any, dataType: DataType, nullable: Boolean = true) - extends LeafExpression { + extends LeafExpression with CodegenFallback { def update(expression: Expression, input: InternalRow): Unit = { value = expression.eval(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index c31890e27fb54..68cca0ad3d067 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -19,19 +19,24 @@ package org.apache.spark.sql.catalyst.expressions import java.{lang => jl} -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckSuccess, TypeCheckFailure} import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.InternalRow +import 
org.apache.spark.sql.catalyst.util.NumberConverter import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String /** * A leaf expression specifically for math constants. Math constants expect no input. + * + * There is no code generation because they should get constant folded by the optimizer. + * * @param c The math constant. * @param name The short name of the function */ abstract class LeafMathExpression(c: Double, name: String) - extends LeafExpression with Serializable { - self: Product => + extends LeafExpression with CodegenFallback { override def dataType: DataType = DoubleType override def foldable: Boolean = true @@ -39,13 +44,6 @@ abstract class LeafMathExpression(c: Double, name: String) override def toString: String = s"$name()" override def eval(input: InternalRow): Any = c - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - s""" - boolean ${ev.isNull} = false; - ${ctx.javaType(dataType)} ${ev.primitive} = java.lang.Math.$name; - """ - } } /** @@ -55,7 +53,7 @@ abstract class LeafMathExpression(c: Double, name: String) * @param name The short name of the function */ abstract class UnaryMathExpression(f: Double => Double, name: String) - extends UnaryExpression with Serializable with ExpectsInputTypes { self: Product => + extends UnaryExpression with Serializable with ImplicitCastInputTypes { override def inputTypes: Seq[DataType] = Seq(DoubleType) override def dataType: DataType = DoubleType @@ -63,22 +61,38 @@ abstract class UnaryMathExpression(f: Double => Double, name: String) override def toString: String = s"$name($child)" protected override def nullSafeEval(input: Any): Any = { - val result = f(input.asInstanceOf[Double]) - if (result.isNaN) null else result + f(input.asInstanceOf[Double]) } // name of function in java.lang.Math def funcName: String = name.toLowerCase override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - nullSafeCodeGen(ctx, ev, eval => { + defineCodeGen(ctx, ev, c => s"java.lang.Math.${funcName}($c)") + } +} + +abstract class UnaryLogExpression(f: Double => Double, name: String) + extends UnaryMathExpression(f, name) { + + // values less than or equal to yAsymptote eval to null in Hive, instead of NaN or -Infinity + protected val yAsymptote: Double = 0.0 + + protected override def nullSafeEval(input: Any): Any = { + val d = input.asInstanceOf[Double] + if (d <= yAsymptote) null else f(d) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, c => s""" - ${ev.primitive} = java.lang.Math.${funcName}($eval); - if (Double.valueOf(${ev.primitive}).isNaN()) { + if ($c <= $yAsymptote) { ${ev.isNull} = true; + } else { + ${ev.primitive} = java.lang.Math.${funcName}($c); } """ - }) + ) } } @@ -89,7 +103,7 @@ abstract class UnaryMathExpression(f: Double => Double, name: String) * @param name The short name of the function */ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) - extends BinaryExpression with Serializable with ExpectsInputTypes { self: Product => + extends BinaryExpression with Serializable with ImplicitCastInputTypes { override def inputTypes: Seq[DataType] = Seq(DoubleType, DoubleType) @@ -98,8 +112,7 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) override def dataType: DataType = DoubleType protected override def nullSafeEval(input1: Any, input2: Any): Any = { - val result = f(input1.asInstanceOf[Double], 
input2.asInstanceOf[Double]) - if (result.isNaN) null else result + f(input1.asInstanceOf[Double], input2.asInstanceOf[Double]) } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { @@ -113,8 +126,16 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) //////////////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////////////// +/** + * Euler's number. Note that there is no code generation because this is only + * evaluated by the optimizer during constant folding. + */ case class EulerNumber() extends LeafMathExpression(math.E, "E") +/** + * Pi. Note that there is no code generation because this is only + * evaluated by the optimizer during constant folding. + */ case class Pi() extends LeafMathExpression(math.Pi, "PI") //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -137,6 +158,79 @@ case class Cos(child: Expression) extends UnaryMathExpression(math.cos, "COS") case class Cosh(child: Expression) extends UnaryMathExpression(math.cosh, "COSH") +/** + * Convert a num from one base to another + * @param numExpr the number to be converted + * @param fromBaseExpr from which base + * @param toBaseExpr to which base + */ +case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expression) + extends Expression with ImplicitCastInputTypes { + + override def foldable: Boolean = numExpr.foldable && fromBaseExpr.foldable && toBaseExpr.foldable + + override def nullable: Boolean = numExpr.nullable || fromBaseExpr.nullable || toBaseExpr.nullable + + override def children: Seq[Expression] = Seq(numExpr, fromBaseExpr, toBaseExpr) + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType, IntegerType, IntegerType) + + override def dataType: DataType = StringType + + /** Returns the result of evaluating this expression on a given input Row */ + override def eval(input: InternalRow): Any = { + val num = numExpr.eval(input) + if (num != null) { + val fromBase = fromBaseExpr.eval(input) + if (fromBase != null) { + val toBase = toBaseExpr.eval(input) + if (toBase != null) { + NumberConverter.convert( + num.asInstanceOf[UTF8String].getBytes, + fromBase.asInstanceOf[Int], + toBase.asInstanceOf[Int]) + } else { + null + } + } else { + null + } + } else { + null + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val numGen = numExpr.gen(ctx) + val from = fromBaseExpr.gen(ctx) + val to = toBaseExpr.gen(ctx) + + val numconv = NumberConverter.getClass.getName.stripSuffix("$") + s""" + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + ${numGen.code} + boolean ${ev.isNull} = ${numGen.isNull}; + if (!${ev.isNull}) { + ${from.code} + if (!${from.isNull}) { + ${to.code} + if (!${to.isNull}) { + ${ev.primitive} = $numconv.convert(${numGen.primitive}.getBytes(), + ${from.primitive}, ${to.primitive}); + if (${ev.primitive} == null) { + ${ev.isNull} = true; + } + } else { + ${ev.isNull} = true; + } + } else { + ${ev.isNull} = true; + } + } + """ + } +} + case class Exp(child: Expression) extends UnaryMathExpression(math.exp, "EXP") case class Expm1(child: Expression) extends UnaryMathExpression(math.expm1, "EXPM1") @@ -174,7 +268,7 @@ object Factorial { ) } -case class Factorial(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class Factorial(child: Expression) 
extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[DataType] = Seq(IntegerType) @@ -206,25 +300,28 @@ case class Factorial(child: Expression) extends UnaryExpression with ExpectsInpu } } -case class Log(child: Expression) extends UnaryMathExpression(math.log, "LOG") +case class Log(child: Expression) extends UnaryLogExpression(math.log, "LOG") case class Log2(child: Expression) - extends UnaryMathExpression((x: Double) => math.log(x) / math.log(2), "LOG2") { + extends UnaryLogExpression((x: Double) => math.log(x) / math.log(2), "LOG2") { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - nullSafeCodeGen(ctx, ev, eval => { + nullSafeCodeGen(ctx, ev, c => s""" - ${ev.primitive} = java.lang.Math.log($eval) / java.lang.Math.log(2); - if (Double.valueOf(${ev.primitive}).isNaN()) { + if ($c <= $yAsymptote) { ${ev.isNull} = true; + } else { + ${ev.primitive} = java.lang.Math.log($c) / java.lang.Math.log(2); } """ - }) + ) } } -case class Log10(child: Expression) extends UnaryMathExpression(math.log10, "LOG10") +case class Log10(child: Expression) extends UnaryLogExpression(math.log10, "LOG10") -case class Log1p(child: Expression) extends UnaryMathExpression(math.log1p, "LOG1P") +case class Log1p(child: Expression) extends UnaryLogExpression(math.log1p, "LOG1P") { + protected override val yAsymptote: Double = -1.0 +} case class Rint(child: Expression) extends UnaryMathExpression(math.rint, "ROUND") { override def funcName: String = "rint" @@ -251,7 +348,7 @@ case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadia } case class Bin(child: Expression) - extends UnaryExpression with Serializable with ExpectsInputTypes { + extends UnaryExpression with Serializable with ImplicitCastInputTypes { override def inputTypes: Seq[DataType] = Seq(LongType) override def dataType: DataType = StringType @@ -261,7 +358,7 @@ case class Bin(child: Expression) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, (c) => - s"${ctx.stringType}.fromString(java.lang.Long.toBinaryString($c))") + s"UTF8String.fromString(java.lang.Long.toBinaryString($c))") } } @@ -278,28 +375,8 @@ object Hex { (0 to 5).foreach(i => array('a' + i) = (i + 10).toByte) array } -} - -/** - * If the argument is an INT or binary, hex returns the number as a STRING in hexadecimal format. - * Otherwise if the number is a STRING, it converts each character into its hex representation - * and returns the resulting STRING. Negative numbers would be treated as two's complement. - */ -case class Hex(child: Expression) extends UnaryExpression with ExpectsInputTypes { - // TODO: Create code-gen version. 
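// A small, self-contained sketch of the contract introduced by UnaryLogExpression above:
// inputs at or below the asymptote evaluate to null (modelled here with Option) instead of
// NaN or -Infinity. The object and helper names are hypothetical, not part of Spark.
object UnaryLogSemanticsSketch {
  private def guarded(asymptote: Double, f: Double => Double): Double => Option[Double] =
    x => if (x <= asymptote) None else Some(f(x))

  val log: Double => Option[Double]   = guarded(0.0, math.log)
  val log2: Double => Option[Double]  = guarded(0.0, x => math.log(x) / math.log(2))
  val log10: Double => Option[Double] = guarded(0.0, math.log10)
  val log1p: Double => Option[Double] = guarded(-1.0, math.log1p)  // asymptote moves to -1

  def main(args: Array[String]): Unit = {
    println(log(math.E))   // ~ Some(1.0)
    println(log(0.0))      // None, where math.log(0.0) alone would give -Infinity
    println(log1p(-1.0))   // None, where math.log1p(-1.0) alone would give -Infinity
    println(log1p(-0.5))   // ~ Some(-0.693)
  }
}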
- - override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(LongType, BinaryType, StringType)) - - override def dataType: DataType = StringType - - protected override def nullSafeEval(num: Any): Any = child.dataType match { - case LongType => hex(num.asInstanceOf[Long]) - case BinaryType => hex(num.asInstanceOf[Array[Byte]]) - case StringType => hex(num.asInstanceOf[UTF8String].getBytes) - } - private[this] def hex(bytes: Array[Byte]): UTF8String = { + def hex(bytes: Array[Byte]): UTF8String = { val length = bytes.length val value = new Array[Byte](length * 2) var i = 0 @@ -311,7 +388,7 @@ case class Hex(child: Expression) extends UnaryExpression with ExpectsInputTypes UTF8String.fromBytes(value) } - private def hex(num: Long): UTF8String = { + def hex(num: Long): UTF8String = { // Extract the hex digits of num into value[] from right to left val value = new Array[Byte](16) var numBuf = num @@ -323,24 +400,8 @@ case class Hex(child: Expression) extends UnaryExpression with ExpectsInputTypes } while (numBuf != 0) UTF8String.fromBytes(java.util.Arrays.copyOfRange(value, value.length - len, value.length)) } -} - -/** - * Performs the inverse operation of HEX. - * Resulting characters are returned as a byte array. - */ -case class Unhex(child: Expression) extends UnaryExpression with ExpectsInputTypes { - // TODO: Create code-gen version. - override def inputTypes: Seq[AbstractDataType] = Seq(StringType) - - override def nullable: Boolean = true - override def dataType: DataType = BinaryType - - protected override def nullSafeEval(num: Any): Any = - unhex(num.asInstanceOf[UTF8String].getBytes) - - private[this] def unhex(bytes: Array[Byte]): Array[Byte] = { + def unhex(bytes: Array[Byte]): Array[Byte] = { val out = new Array[Byte]((bytes.length + 1) >> 1) var i = 0 if ((bytes.length & 0x01) != 0) { @@ -372,6 +433,60 @@ case class Unhex(child: Expression) extends UnaryExpression with ExpectsInputTyp } } +/** + * If the argument is an INT or binary, hex returns the number as a STRING in hexadecimal format. + * Otherwise if the number is a STRING, it converts each character into its hex representation + * and returns the resulting STRING. Negative numbers would be treated as two's complement. + */ +case class Hex(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + override def inputTypes: Seq[AbstractDataType] = + Seq(TypeCollection(LongType, BinaryType, StringType)) + + override def dataType: DataType = StringType + + protected override def nullSafeEval(num: Any): Any = child.dataType match { + case LongType => Hex.hex(num.asInstanceOf[Long]) + case BinaryType => Hex.hex(num.asInstanceOf[Array[Byte]]) + case StringType => Hex.hex(num.asInstanceOf[UTF8String].getBytes) + } + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (c) => { + val hex = Hex.getClass.getName.stripSuffix("$") + s"${ev.primitive} = " + (child.dataType match { + case StringType => s"""$hex.hex($c.getBytes());""" + case _ => s"""$hex.hex($c);""" + }) + }) + } +} + +/** + * Performs the inverse operation of HEX. + * Resulting characters are returned as a byte array. 
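 * Illustrative examples (well-formed hexadecimal input assumed):
 * {{{
 *   unhex("4D7953514C")  => the bytes of the ASCII string "MySQL"
 *   unhex("F")           => the single byte 0x0F (an odd length implies a leading zero)
 *   hex(unhex("1F"))     => "1F"
 * }}}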
+ */ +case class Unhex(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType) + + override def nullable: Boolean = true + override def dataType: DataType = BinaryType + + protected override def nullSafeEval(num: Any): Any = + Hex.unhex(num.asInstanceOf[UTF8String].getBytes) + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (c) => { + val hex = Hex.getClass.getName.stripSuffix("$") + s""" + ${ev.primitive} = $hex.unhex($c.getBytes()); + ${ev.isNull} = ${ev.primitive} == null; + """ + }) + } +} + //////////////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -385,27 +500,18 @@ case class Atan2(left: Expression, right: Expression) protected override def nullSafeEval(input1: Any, input2: Any): Any = { // With codegen, the values returned by -0.0 and 0.0 are different. Handled with +0.0 - val result = math.atan2(input1.asInstanceOf[Double] + 0.0, input2.asInstanceOf[Double] + 0.0) - if (result.isNaN) null else result + math.atan2(input1.asInstanceOf[Double] + 0.0, input2.asInstanceOf[Double] + 0.0) } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.atan2($c1 + 0.0, $c2 + 0.0)") + s""" - if (Double.valueOf(${ev.primitive}).isNaN()) { - ${ev.isNull} = true; - } - """ + defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.atan2($c1 + 0.0, $c2 + 0.0)") } } case class Pow(left: Expression, right: Expression) extends BinaryMathExpression(math.pow, "POWER") { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)") + s""" - if (Double.valueOf(${ev.primitive}).isNaN()) { - ${ev.isNull} = true; - } - """ + defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)") } } @@ -416,7 +522,7 @@ case class Pow(left: Expression, right: Expression) * @param right number of bits to left shift. */ case class ShiftLeft(left: Expression, right: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(IntegerType, LongType), IntegerType) @@ -442,7 +548,7 @@ case class ShiftLeft(left: Expression, right: Expression) * @param right number of bits to left shift. */ case class ShiftRight(left: Expression, right: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(IntegerType, LongType), IntegerType) @@ -468,7 +574,7 @@ case class ShiftRight(left: Expression, right: Expression) * @param right the number of bits to right shift. 
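 * For a 32-bit integer input, for example (values follow Java's `>>` and `>>>` semantics):
 * {{{
 *   ShiftRight(-8, 1)          => -4
 *   ShiftRightUnsigned(-8, 1)  => 2147483644
 * }}}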
*/ case class ShiftRightUnsigned(left: Expression, right: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(IntegerType, LongType), IntegerType) @@ -507,16 +613,231 @@ case class Logarithm(left: Expression, right: Expression) this(EulerNumber(), child) } + protected override def nullSafeEval(input1: Any, input2: Any): Any = { + val dLeft = input1.asInstanceOf[Double] + val dRight = input2.asInstanceOf[Double] + // Unlike Hive, we support Log base in (0.0, 1.0] + if (dLeft <= 0.0 || dRight <= 0.0) null else math.log(dRight) / math.log(dLeft) + } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val logCode = if (left.isInstanceOf[EulerNumber]) { - defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.log($c2)") + if (left.isInstanceOf[EulerNumber]) { + nullSafeCodeGen(ctx, ev, (c1, c2) => + s""" + if ($c2 <= 0.0) { + ${ev.isNull} = true; + } else { + ${ev.primitive} = java.lang.Math.log($c2); + } + """) } else { - defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.log($c2) / java.lang.Math.log($c1)") + nullSafeCodeGen(ctx, ev, (c1, c2) => + s""" + if ($c1 <= 0.0 || $c2 <= 0.0) { + ${ev.isNull} = true; + } else { + ${ev.primitive} = java.lang.Math.log($c2) / java.lang.Math.log($c1); + } + """) + } + } +} + +/** + * Round the `child`'s result to `scale` decimal place when `scale` >= 0 + * or round at integral part when `scale` < 0. + * For example, round(31.415, 2) would eval to 31.42 and round(31.415, -1) would eval to 30. + * + * Child of IntegralType would eval to itself when `scale` >= 0. + * Child of FractionalType whose value is NaN or Infinite would always eval to itself. + * + * Round's dataType would always equal to `child`'s dataType except for [[DecimalType.Fixed]], + * which leads to scale update in DecimalType's [[PrecisionInfo]] + * + * @param child expr to be round, all [[NumericType]] is allowed as Input + * @param scale new scale to be round to, this should be a constant int at runtime + */ +case class Round(child: Expression, scale: Expression) + extends BinaryExpression with ExpectsInputTypes { + + import BigDecimal.RoundingMode.HALF_UP + + def this(child: Expression) = this(child, Literal(0)) + + override def left: Expression = child + override def right: Expression = scale + + // round of Decimal would eval to null if it fails to `changePrecision` + override def nullable: Boolean = true + + override def foldable: Boolean = child.foldable + + override lazy val dataType: DataType = child.dataType match { + // if the new scale is bigger which means we are scaling up, + // keep the original scale as `Decimal` does + case DecimalType.Fixed(p, s) => DecimalType(p, if (_scale > s) s else _scale) + case t => t + } + + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType) + + override def checkInputDataTypes(): TypeCheckResult = { + super.checkInputDataTypes() match { + case TypeCheckSuccess => + if (scale.foldable) { + TypeCheckSuccess + } else { + TypeCheckFailure("Only foldable Expression is allowed for scale arguments") + } + case f => f } - logCode + s""" - if (Double.isNaN(${ev.primitive})) { - ${ev.isNull} = true; + } + + // Avoid repeated evaluation since `scale` is a constant int, + // avoid unnecessary `child` evaluation in both codegen and non-codegen eval + // by checking if scaleV == null as well. 
+ private lazy val scaleV: Any = scale.eval(EmptyRow) + private lazy val _scale: Int = scaleV.asInstanceOf[Int] + + override def eval(input: InternalRow): Any = { + if (scaleV == null) { // if scale is null, no need to eval its child at all + null + } else { + val evalE = child.eval(input) + if (evalE == null) { + null + } else { + nullSafeEval(evalE) } - """ + } } + + // not overriding since _scale is a constant int at runtime + def nullSafeEval(input1: Any): Any = { + child.dataType match { + case _: DecimalType => + val decimal = input1.asInstanceOf[Decimal] + if (decimal.changePrecision(decimal.precision, _scale)) decimal else null + case ByteType => + BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, HALF_UP).toByte + case ShortType => + BigDecimal(input1.asInstanceOf[Short]).setScale(_scale, HALF_UP).toShort + case IntegerType => + BigDecimal(input1.asInstanceOf[Int]).setScale(_scale, HALF_UP).toInt + case LongType => + BigDecimal(input1.asInstanceOf[Long]).setScale(_scale, HALF_UP).toLong + case FloatType => + val f = input1.asInstanceOf[Float] + if (f.isNaN || f.isInfinite) { + f + } else { + BigDecimal(f).setScale(_scale, HALF_UP).toFloat + } + case DoubleType => + val d = input1.asInstanceOf[Double] + if (d.isNaN || d.isInfinite) { + d + } else { + BigDecimal(d).setScale(_scale, HALF_UP).toDouble + } + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val ce = child.gen(ctx) + + val evaluationCode = child.dataType match { + case _: DecimalType => + s""" + if (${ce.primitive}.changePrecision(${ce.primitive}.precision(), ${_scale})) { + ${ev.primitive} = ${ce.primitive}; + } else { + ${ev.isNull} = true; + }""" + case ByteType => + if (_scale < 0) { + s""" + ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}). + setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).byteValue();""" + } else { + s"${ev.primitive} = ${ce.primitive};" + } + case ShortType => + if (_scale < 0) { + s""" + ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}). + setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).shortValue();""" + } else { + s"${ev.primitive} = ${ce.primitive};" + } + case IntegerType => + if (_scale < 0) { + s""" + ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}). + setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).intValue();""" + } else { + s"${ev.primitive} = ${ce.primitive};" + } + case LongType => + if (_scale < 0) { + s""" + ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}). + setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).longValue();""" + } else { + s"${ev.primitive} = ${ce.primitive};" + } + case FloatType => // if child eval to NaN or Infinity, just return it. + if (_scale == 0) { + s""" + if (Float.isNaN(${ce.primitive}) || Float.isInfinite(${ce.primitive})){ + ${ev.primitive} = ${ce.primitive}; + } else { + ${ev.primitive} = Math.round(${ce.primitive}); + }""" + } else { + s""" + if (Float.isNaN(${ce.primitive}) || Float.isInfinite(${ce.primitive})){ + ${ev.primitive} = ${ce.primitive}; + } else { + ${ev.primitive} = java.math.BigDecimal.valueOf(${ce.primitive}). + setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).floatValue(); + }""" + } + case DoubleType => // if child eval to NaN or Infinity, just return it. 
+ if (_scale == 0) { + s""" + if (Double.isNaN(${ce.primitive}) || Double.isInfinite(${ce.primitive})){ + ${ev.primitive} = ${ce.primitive}; + } else { + ${ev.primitive} = Math.round(${ce.primitive}); + }""" + } else { + s""" + if (Double.isNaN(${ce.primitive}) || Double.isInfinite(${ce.primitive})){ + ${ev.primitive} = ${ce.primitive}; + } else { + ${ev.primitive} = java.math.BigDecimal.valueOf(${ce.primitive}). + setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).doubleValue(); + }""" + } + } + + if (scaleV == null) { // if scale is null, no need to eval its child at all + s""" + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + """ + } else { + s""" + ${ce.code} + boolean ${ev.isNull} = ${ce.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + $evaluationCode + } + """ + } + } + + override def prettyName: String = "round" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 3b59cd431b871..8d8d66ddeb341 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -17,12 +17,11 @@ package org.apache.spark.sql.catalyst.expressions -import java.security.MessageDigest -import java.security.NoSuchAlgorithmException +import java.security.{MessageDigest, NoSuchAlgorithmException} import java.util.zip.CRC32 import org.apache.commons.codec.digest.DigestUtils -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult + import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -31,7 +30,7 @@ import org.apache.spark.unsafe.types.UTF8String * A function that calculates an MD5 128-bit checksum and returns it as a hex string * For input of type [[BinaryType]] */ -case class Md5(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class Md5(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = StringType @@ -42,7 +41,7 @@ case class Md5(child: Expression) extends UnaryExpression with ExpectsInputTypes override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => - s"${ctx.stringType}.fromString(org.apache.commons.codec.digest.DigestUtils.md5Hex($c))") + s"UTF8String.fromString(org.apache.commons.codec.digest.DigestUtils.md5Hex($c))") } } @@ -55,7 +54,7 @@ case class Md5(child: Expression) extends UnaryExpression with ExpectsInputTypes * the hash length is not one of the permitted values, the return value is NULL. 
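 * Illustrative examples (digests shown by description rather than by value):
 * {{{
 *   sha2("ABC", 256)  => the 64-character lowercase hex SHA-256 digest of "ABC"
 *   sha2("ABC", 0)    => same as sha2("ABC", 256)
 *   sha2("ABC", 100)  => null  (100 is not a permitted bit length)
 * }}}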
*/ case class Sha2(left: Expression, right: Expression) - extends BinaryExpression with Serializable with ExpectsInputTypes { + extends BinaryExpression with Serializable with ImplicitCastInputTypes { override def dataType: DataType = StringType @@ -93,19 +92,19 @@ case class Sha2(left: Expression, right: Expression) try { java.security.MessageDigest md = java.security.MessageDigest.getInstance("SHA-224"); md.update($eval1); - ${ev.primitive} = ${ctx.stringType}.fromBytes(md.digest()); + ${ev.primitive} = UTF8String.fromBytes(md.digest()); } catch (java.security.NoSuchAlgorithmException e) { ${ev.isNull} = true; } } else if ($eval2 == 256 || $eval2 == 0) { ${ev.primitive} = - ${ctx.stringType}.fromString($digestUtils.sha256Hex($eval1)); + UTF8String.fromString($digestUtils.sha256Hex($eval1)); } else if ($eval2 == 384) { ${ev.primitive} = - ${ctx.stringType}.fromString($digestUtils.sha384Hex($eval1)); + UTF8String.fromString($digestUtils.sha384Hex($eval1)); } else if ($eval2 == 512) { ${ev.primitive} = - ${ctx.stringType}.fromString($digestUtils.sha512Hex($eval1)); + UTF8String.fromString($digestUtils.sha512Hex($eval1)); } else { ${ev.isNull} = true; } @@ -118,7 +117,7 @@ case class Sha2(left: Expression, right: Expression) * A function that calculates a sha1 hash value and returns it as a hex string * For input of type [[BinaryType]] or [[StringType]] */ -case class Sha1(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class Sha1(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = StringType @@ -129,7 +128,7 @@ case class Sha1(child: Expression) extends UnaryExpression with ExpectsInputType override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => - s"${ctx.stringType}.fromString(org.apache.commons.codec.digest.DigestUtils.shaHex($c))" + s"UTF8String.fromString(org.apache.commons.codec.digest.DigestUtils.shaHex($c))" ) } } @@ -138,7 +137,7 @@ case class Sha1(child: Expression) extends UnaryExpression with ExpectsInputType * A function that computes a cyclic redundancy check value and returns it as a bigint * For input of type [[BinaryType]] */ -case class Crc32(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class Crc32(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = LongType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 6181c60c0e453..6f173b52ad9b9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -19,9 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} -import org.apache.spark.sql.catalyst.trees +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ object NamedExpression { @@ -37,8 +35,13 @@ object NamedExpression { */ case class ExprId(id: Long) -abstract class NamedExpression extends Expression { - self: Product => +/** + * An 
[[Expression]] that is named. + */ +trait NamedExpression extends Expression { + + /** We should never fold named expressions in order to not remove the alias. */ + override def foldable: Boolean = false def name: String def exprId: ExprId @@ -78,8 +81,7 @@ abstract class NamedExpression extends Expression { } } -abstract class Attribute extends NamedExpression { - self: Product => +abstract class Attribute extends LeafExpression with NamedExpression { override def references: AttributeSet = AttributeSet(this) @@ -110,7 +112,7 @@ case class Alias(child: Expression, name: String)( val exprId: ExprId = NamedExpression.newExprId, val qualifiers: Seq[String] = Nil, val explicitMetadata: Option[Metadata] = None) - extends NamedExpression with trees.UnaryNode[Expression] { + extends UnaryExpression with NamedExpression { // Alias(Generator, xx) need to be transformed into Generate(generator, ...) override lazy val resolved = @@ -118,7 +120,9 @@ case class Alias(child: Expression, name: String)( override def eval(input: InternalRow): Any = child.eval(input) + /** Just a simple passthrough for code generation. */ override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx) + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = "" override def dataType: DataType = child.dataType override def nullable: Boolean = child.nullable @@ -172,7 +176,8 @@ case class AttributeReference( nullable: Boolean = true, override val metadata: Metadata = Metadata.empty)( val exprId: ExprId = NamedExpression.newExprId, - val qualifiers: Seq[String] = Nil) extends Attribute with trees.LeafNode[Expression] { + val qualifiers: Seq[String] = Nil) + extends Attribute with Unevaluable { /** * Returns true iff the expression id is the same for both attributes. @@ -231,10 +236,6 @@ case class AttributeReference( } } - // Unresolved attributes are transient at compile time and don't get evaluated during execution. - override def eval(input: InternalRow = null): Any = - throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") - override def toString: String = s"$name#${exprId.id}$typeSuffix" } @@ -242,7 +243,7 @@ case class AttributeReference( * A place holder used when printing expressions without debugging information such as the * expression id or the unresolved indicator. 
*/ -case class PrettyAttribute(name: String) extends Attribute with trees.LeafNode[Expression] { +case class PrettyAttribute(name: String) extends Attribute with Unevaluable { override def toString: String = name @@ -254,7 +255,6 @@ case class PrettyAttribute(name: String) extends Attribute with trees.LeafNode[E override def withName(newName: String): Attribute = throw new UnsupportedOperationException override def qualifiers: Seq[String] = throw new UnsupportedOperationException override def exprId: ExprId = throw new UnsupportedOperationException - override def eval(input: InternalRow): Any = throw new UnsupportedOperationException override def nullable: Boolean = throw new UnsupportedOperationException override def dataType: DataType = NullType } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index 1522bcae08d17..287718fab7f0d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -21,8 +21,19 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.catalyst.util.TypeUtils -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types._ + +/** + * An expression that is evaluated to the first non-null input. + * + * {{{ + * coalesce(1, 2) => 1 + * coalesce(null, 1, 2) => 1 + * coalesce(null, null, 2) => 2 + * coalesce(null, null, null) => null + * }}} + */ case class Coalesce(children: Seq[Expression]) extends Expression { /** Coalesce is nullable if all of its children are nullable, or if it has no children. */ @@ -70,6 +81,101 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } } + +/** + * Evaluates to `true` iff it's NaN. + */ +case class IsNaN(child: Expression) extends UnaryExpression + with Predicate with ImplicitCastInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(DoubleType, FloatType)) + + override def nullable: Boolean = false + + override def eval(input: InternalRow): Any = { + val value = child.eval(input) + if (value == null) { + false + } else { + child.dataType match { + case DoubleType => value.asInstanceOf[Double].isNaN + case FloatType => value.asInstanceOf[Float].isNaN + } + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val eval = child.gen(ctx) + child.dataType match { + case DoubleType | FloatType => + s""" + ${eval.code} + boolean ${ev.isNull} = false; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + ${ev.primitive} = !${eval.isNull} && Double.isNaN(${eval.primitive}); + """ + } + } +} + +/** + * An Expression evaluates to `left` iff it's not NaN, or evaluates to `right` otherwise. + * This Expression is useful for mapping NaN values to null. 
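 * For example:
 * {{{
 *   nanvl(1.0, 2.0)         => 1.0
 *   nanvl(Double.NaN, 2.0)  => 2.0
 *   nanvl(null, 2.0)        => null
 * }}}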
+ */ +case class NaNvl(left: Expression, right: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = left.dataType + + override def inputTypes: Seq[AbstractDataType] = + Seq(TypeCollection(DoubleType, FloatType), TypeCollection(DoubleType, FloatType)) + + override def eval(input: InternalRow): Any = { + val value = left.eval(input) + if (value == null) { + null + } else { + left.dataType match { + case DoubleType => + if (!value.asInstanceOf[Double].isNaN) value else right.eval(input) + case FloatType => + if (!value.asInstanceOf[Float].isNaN) value else right.eval(input) + } + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val leftGen = left.gen(ctx) + val rightGen = right.gen(ctx) + left.dataType match { + case DoubleType | FloatType => + s""" + ${leftGen.code} + boolean ${ev.isNull} = false; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (${leftGen.isNull}) { + ${ev.isNull} = true; + } else { + if (!Double.isNaN(${leftGen.primitive})) { + ${ev.primitive} = ${leftGen.primitive}; + } else { + ${rightGen.code} + if (${rightGen.isNull}) { + ${ev.isNull} = true; + } else { + ${ev.primitive} = ${rightGen.primitive}; + } + } + } + """ + } + } +} + + +/** + * An expression that is evaluated to true if the input is null. + */ case class IsNull(child: Expression) extends UnaryExpression with Predicate { override def nullable: Boolean = false @@ -83,13 +189,14 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate { ev.primitive = eval.isNull eval.code } - - override def toString: String = s"IS NULL $child" } + +/** + * An expression that is evaluated to true if the input is not null. + */ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate { override def nullable: Boolean = false - override def toString: String = s"IS NOT NULL $child" override def eval(input: InternalRow): Any = { child.eval(input) != null @@ -103,12 +210,13 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate { } } + /** - * A predicate that is evaluated to be true if there are at least `n` non-null values. + * A predicate that is evaluated to be true if there are at least `n` non-null and non-NaN values. 
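 * Shown informally, with literal children:
 * {{{
 *   AtLeastNNonNulls(2, [1.0, NaN, null])  => false  (only 1.0 counts)
 *   AtLeastNNonNulls(1, [NaN, "a"])        => true   ("a" counts)
 * }}}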
*/ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate { override def nullable: Boolean = false - override def foldable: Boolean = false + override def foldable: Boolean = children.forall(_.foldable) override def toString: String = s"AtLeastNNulls(n, ${children.mkString(",")})" private[this] val childrenArray = children.toArray @@ -117,8 +225,15 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate var numNonNulls = 0 var i = 0 while (i < childrenArray.length && numNonNulls < n) { - if (childrenArray(i).eval(input) != null) { - numNonNulls += 1 + val evalC = childrenArray(i).eval(input) + if (evalC != null) { + childrenArray(i).dataType match { + case DoubleType => + if (!evalC.asInstanceOf[Double].isNaN) numNonNulls += 1 + case FloatType => + if (!evalC.asInstanceOf[Float].isNaN) numNonNulls += 1 + case _ => numNonNulls += 1 + } } i += 1 } @@ -129,14 +244,26 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate val nonnull = ctx.freshName("nonnull") val code = children.map { e => val eval = e.gen(ctx) - s""" - if ($nonnull < $n) { - ${eval.code} - if (!${eval.isNull}) { - $nonnull += 1; - } - } - """ + e.dataType match { + case DoubleType | FloatType => + s""" + if ($nonnull < $n) { + ${eval.code} + if (!${eval.isNull} && !Double.isNaN(${eval.primitive})) { + $nonnull += 1; + } + } + """ + case _ => + s""" + if ($nonnull < $n) { + ${eval.code} + if (!${eval.isNull}) { + $nonnull += 1; + } + } + """ + } }.mkString("\n") s""" int $nonnull = 0; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index f74fd04619714..3f1bd2a925fe7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -17,12 +17,13 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenFallback, GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.util.TypeUtils -import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils + object InterpretedPredicate { def create(expression: Expression, inputSchema: Seq[Attribute]): (InternalRow => Boolean) = @@ -33,12 +34,15 @@ object InterpretedPredicate { } } -trait Predicate extends Expression { - self: Product => +/** + * An [[Expression]] that returns a boolean value. 
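 * The boolean operators defined below follow SQL's three-valued logic, for example:
 * {{{
 *   true  AND null  => null
 *   false AND null  => false
 *   true  OR  null  => true
 *   false OR  null  => null
 * }}}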
+ */ +trait Predicate extends Expression { override def dataType: DataType = BooleanType } + trait PredicateHelper { protected def splitConjunctivePredicates(condition: Expression): Seq[Expression] = { condition match { @@ -70,7 +74,10 @@ trait PredicateHelper { expr.references.subsetOf(plan.outputSet) } -case class Not(child: Expression) extends UnaryExpression with Predicate with ExpectsInputTypes { + +case class Not(child: Expression) + extends UnaryExpression with Predicate with ImplicitCastInputTypes { + override def toString: String = s"NOT $child" override def inputTypes: Seq[DataType] = Seq(BooleanType) @@ -82,10 +89,11 @@ case class Not(child: Expression) extends UnaryExpression with Predicate with Ex } } + /** * Evaluates to `true` if `list` contains `value`. */ -case class In(value: Expression, list: Seq[Expression]) extends Predicate { +case class In(value: Expression, list: Seq[Expression]) extends Predicate with CodegenFallback { override def children: Seq[Expression] = value +: list override def nullable: Boolean = true // TODO: Figure out correct nullability semantics of IN. @@ -97,12 +105,13 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { } } + /** * Optimized version of In clause, when all filter values of In clause are * static. */ case class InSet(child: Expression, hset: Set[Any]) - extends UnaryExpression with Predicate { + extends UnaryExpression with Predicate with CodegenFallback { override def nullable: Boolean = true // TODO: Figure out correct nullability semantics of IN. override def toString: String = s"$child INSET ${hset.mkString("(", ",", ")")}" @@ -112,12 +121,11 @@ case class InSet(child: Expression, hset: Set[Any]) } } -case class And(left: Expression, right: Expression) - extends BinaryExpression with Predicate with ExpectsInputTypes { +case class And(left: Expression, right: Expression) extends BinaryOperator with Predicate { - override def toString: String = s"($left && $right)" + override def inputType: AbstractDataType = BooleanType - override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType) + override def symbol: String = "&&" override def eval(input: InternalRow): Any = { val input1 = left.eval(input) @@ -161,12 +169,12 @@ case class And(left: Expression, right: Expression) } } -case class Or(left: Expression, right: Expression) - extends BinaryExpression with Predicate with ExpectsInputTypes { - override def toString: String = s"($left || $right)" +case class Or(left: Expression, right: Expression) extends BinaryOperator with Predicate { - override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType) + override def inputType: AbstractDataType = BooleanType + + override def symbol: String = "||" override def eval(input: InternalRow): Any = { val input1 = left.eval(input) @@ -210,23 +218,13 @@ case class Or(left: Expression, right: Expression) } } -abstract class BinaryComparison extends BinaryOperator with Predicate { - self: Product => - - override def checkInputDataTypes(): TypeCheckResult = { - if (left.dataType != right.dataType) { - TypeCheckResult.TypeCheckFailure( - s"differing types in ${this.getClass.getSimpleName} " + - s"(${left.dataType} and ${right.dataType}).") - } else { - checkTypesInternal(dataType) - } - } - protected def checkTypesInternal(t: DataType): TypeCheckResult +abstract class BinaryComparison extends BinaryOperator with Predicate { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - if (ctx.isPrimitiveType(left.dataType)) { + if 
(ctx.isPrimitiveType(left.dataType) + && left.dataType != FloatType + && left.dataType != DoubleType) { // faster version defineCodeGen(ctx, ev, (c1, c2) => s"$c1 $symbol $c2") } else { @@ -235,10 +233,12 @@ abstract class BinaryComparison extends BinaryOperator with Predicate { } } + private[sql] object BinaryComparison { def unapply(e: BinaryComparison): Option[(Expression, Expression)] = Some((e.left, e.right)) } + /** An extractor that matches both standard 3VL equality and null-safe equality. */ private[sql] object Equality { def unapply(e: BinaryComparison): Option[(Expression, Expression)] = e match { @@ -248,14 +248,23 @@ private[sql] object Equality { } } + case class EqualTo(left: Expression, right: Expression) extends BinaryComparison { - override def symbol: String = "=" - override protected def checkTypesInternal(t: DataType) = TypeCheckResult.TypeCheckSuccess + override def inputType: AbstractDataType = AnyDataType + + override def symbol: String = "=" protected override def nullSafeEval(input1: Any, input2: Any): Any = { - if (left.dataType != BinaryType) input1 == input2 - else java.util.Arrays.equals(input1.asInstanceOf[Array[Byte]], input2.asInstanceOf[Array[Byte]]) + if (left.dataType == FloatType) { + Utils.nanSafeCompareFloats(input1.asInstanceOf[Float], input2.asInstanceOf[Float]) == 0 + } else if (left.dataType == DoubleType) { + Utils.nanSafeCompareDoubles(input1.asInstanceOf[Double], input2.asInstanceOf[Double]) == 0 + } else if (left.dataType != BinaryType) { + input1 == input2 + } else { + java.util.Arrays.equals(input1.asInstanceOf[Array[Byte]], input2.asInstanceOf[Array[Byte]]) + } } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { @@ -263,13 +272,15 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison } } + case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComparison { + + override def inputType: AbstractDataType = AnyDataType + override def symbol: String = "<=>" override def nullable: Boolean = false - override protected def checkTypesInternal(t: DataType) = TypeCheckResult.TypeCheckSuccess - override def eval(input: InternalRow): Any = { val input1 = left.eval(input) val input2 = right.eval(input) @@ -278,7 +289,11 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp } else if (input1 == null || input2 == null) { false } else { - if (left.dataType != BinaryType) { + if (left.dataType == FloatType) { + Utils.nanSafeCompareFloats(input1.asInstanceOf[Float], input2.asInstanceOf[Float]) == 0 + } else if (left.dataType == DoubleType) { + Utils.nanSafeCompareDoubles(input1.asInstanceOf[Double], input2.asInstanceOf[Double]) == 0 + } else if (left.dataType != BinaryType) { input1 == input2 } else { java.util.Arrays.equals(input1.asInstanceOf[Array[Byte]], input2.asInstanceOf[Array[Byte]]) @@ -298,44 +313,48 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp } } + case class LessThan(left: Expression, right: Expression) extends BinaryComparison { - override def symbol: String = "<" - override protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol) + override def inputType: AbstractDataType = TypeCollection.Ordered + + override def symbol: String = "<" private lazy val ordering = TypeUtils.getOrdering(left.dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lt(input1, input2) } + case class LessThanOrEqual(left: 
Expression, right: Expression) extends BinaryComparison { - override def symbol: String = "<=" - override protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol) + override def inputType: AbstractDataType = TypeCollection.Ordered + + override def symbol: String = "<=" private lazy val ordering = TypeUtils.getOrdering(left.dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lteq(input1, input2) } + case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison { - override def symbol: String = ">" - override protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol) + override def inputType: AbstractDataType = TypeCollection.Ordered + + override def symbol: String = ">" private lazy val ordering = TypeUtils.getOrdering(left.dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gt(input1, input2) } + case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { - override def symbol: String = ">=" - override protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol) + override def inputType: AbstractDataType = TypeCollection.Ordered + + override def symbol: String = ">=" private lazy val ordering = TypeUtils.getOrdering(left.dataType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala index 6cdc3000382e2..aef24a5486466 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.TaskContext import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.types.{DataType, DoubleType} import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom @@ -31,20 +32,15 @@ import org.apache.spark.util.random.XORShiftRandom * * Since this expression is stateful, it cannot be a case object. */ -abstract class RDG(seed: Long) extends LeafExpression with Serializable { - self: Product => +abstract class RDG extends LeafExpression with Nondeterministic { + + protected def seed: Long /** * Record ID within each partition. By being transient, the Random Number Generator is * reset every time we serialize and deserialize it. */ - @transient protected lazy val partitionId = TaskContext.get() match { - case null => 0 - case _ => TaskContext.get().partitionId() - } - @transient protected lazy val rng = new XORShiftRandom(seed + partitionId) - - override def deterministic: Boolean = false + @transient protected lazy val rng = new XORShiftRandom(seed + TaskContext.getPartitionId) override def nullable: Boolean = false @@ -52,7 +48,7 @@ abstract class RDG(seed: Long) extends LeafExpression with Serializable { } /** Generate a random column with i.i.d. uniformly distributed values in [0, 1). 
*/ -case class Rand(seed: Long) extends RDG(seed) { +case class Rand(seed: Long) extends RDG { override def eval(input: InternalRow): Double = rng.nextDouble() def this() = this(Utils.random.nextLong()) @@ -61,10 +57,21 @@ case class Rand(seed: Long) extends RDG(seed) { case IntegerLiteral(s) => s case _ => throw new AnalysisException("Input argument to rand must be an integer literal.") }) + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val rngTerm = ctx.freshName("rng") + val className = classOf[XORShiftRandom].getName + ctx.addMutableState(className, rngTerm, + s"$rngTerm = new $className($seed + org.apache.spark.TaskContext.getPartitionId());") + ev.isNull = "false" + s""" + final ${ctx.javaType(dataType)} ${ev.primitive} = $rngTerm.nextDouble(); + """ + } } /** Generate a random column with i.i.d. gaussian random distribution. */ -case class Randn(seed: Long) extends RDG(seed) { +case class Randn(seed: Long) extends RDG { override def eval(input: InternalRow): Double = rng.nextGaussian() def this() = this(Utils.random.nextLong()) @@ -73,4 +80,15 @@ case class Randn(seed: Long) extends RDG(seed) { case IntegerLiteral(s) => s case _ => throw new AnalysisException("Input argument to rand must be an integer literal.") }) + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val rngTerm = ctx.freshName("rng") + val className = classOf[XORShiftRandom].getName + ctx.addMutableState(className, rngTerm, + s"$rngTerm = new $className($seed + org.apache.spark.TaskContext.getPartitionId());") + ev.isNull = "false" + s""" + final ${ctx.javaType(dataType)} ${ev.primitive} = $rngTerm.nextGaussian(); + """ + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 094904bbf9c15..d78be5a5958f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -66,7 +66,7 @@ trait ArrayBackedRow { def length: Int = values.length - override def apply(i: Int): Any = values(i) + override def get(i: Int): Any = values(i) def setNullAt(i: Int): Unit = { values(i) = null} @@ -84,27 +84,6 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row with ArrayBa def this(size: Int) = this(new Array[Any](size)) - // This is used by test or outside - override def equals(o: Any): Boolean = o match { - case other: Row if other.length == length => - var i = 0 - while (i < length) { - if (isNullAt(i) != other.isNullAt(i)) { - return false - } - val equal = (apply(i), other.apply(i)) match { - case (a: Array[Byte], b: Array[Byte]) => java.util.Arrays.equals(a, b) - case (a, b) => a == b - } - if (!equal) { - return false - } - i += 1 - } - true - case _ => false - } - override def copy(): Row = this } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala index 49b2026364cd6..5b0fe8dfe2fc8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, 
GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet @@ -52,7 +52,7 @@ private[sql] class OpenHashSetUDT( /** * Creates a new set of the specified type */ -case class NewSet(elementType: DataType) extends LeafExpression { +case class NewSet(elementType: DataType) extends LeafExpression with CodegenFallback { override def nullable: Boolean = false @@ -82,7 +82,8 @@ case class NewSet(elementType: DataType) extends LeafExpression { * Note: this expression is internal and created only by the GeneratedAggregate, * we don't need to do type check for it. */ -case class AddItemToSet(item: Expression, set: Expression) extends Expression { +case class AddItemToSet(item: Expression, set: Expression) + extends Expression with CodegenFallback { override def children: Seq[Expression] = item :: set :: Nil @@ -134,7 +135,8 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression { * Note: this expression is internal and created only by the GeneratedAggregate, * we don't need to do type check for it. */ -case class CombineSets(left: Expression, right: Expression) extends BinaryExpression { +case class CombineSets(left: Expression, right: Expression) + extends BinaryExpression with CodegenFallback { override def nullable: Boolean = left.nullable override def dataType: DataType = left.dataType @@ -181,7 +183,7 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres * Note: this expression is internal and created only by the GeneratedAggregate, * we don't need to do type check for it. */ -case class CountSet(child: Expression) extends UnaryExpression { +case class CountSet(child: Expression) extends UnaryExpression with CodegenFallback { override def dataType: DataType = LongType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index f64899c1ed84c..cf187ad5a0a9f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -17,10 +17,9 @@ package org.apache.spark.sql.catalyst.expressions +import java.text.DecimalFormat import java.util.Locale -import java.util.regex.Pattern - -import org.apache.commons.lang3.StringUtils +import java.util.regex.{MatchResult, Pattern} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.UnresolvedException @@ -28,8 +27,101 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String +//////////////////////////////////////////////////////////////////////////////////////////////////// +// This file defines expressions for string operations. +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +/** + * An expression that concatenates multiple input strings into a single string. + * If any input is null, concat returns null. 
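 * For example:
 * {{{
 *   concat("ab", "cd")  => "abcd"
 *   concat("ab", null)  => null
 * }}}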
+ */ +case class Concat(children: Seq[Expression]) extends Expression with ImplicitCastInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringType) + override def dataType: DataType = StringType + + override def nullable: Boolean = children.exists(_.nullable) + override def foldable: Boolean = children.forall(_.foldable) + + override def eval(input: InternalRow): Any = { + val inputs = children.map(_.eval(input).asInstanceOf[UTF8String]) + UTF8String.concat(inputs : _*) + } + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val evals = children.map(_.gen(ctx)) + val inputs = evals.map { eval => + s"${eval.isNull} ? (UTF8String)null : ${eval.primitive}" + }.mkString(", ") + evals.map(_.code).mkString("\n") + s""" + boolean ${ev.isNull} = false; + UTF8String ${ev.primitive} = UTF8String.concat($inputs); + if (${ev.primitive} == null) { + ${ev.isNull} = true; + } + """ + } +} + + +/** + * An expression that concatenates multiple input strings or array of strings into a single string, + * using a given separator (the first child). + * + * Returns null if the separator is null. Otherwise, concat_ws skips all null values. + */ +case class ConcatWs(children: Seq[Expression]) + extends Expression with ImplicitCastInputTypes with CodegenFallback { + + require(children.nonEmpty, s"$prettyName requires at least one argument.") + + override def prettyName: String = "concat_ws" + + /** The 1st child (separator) is str, and rest are either str or array of str. */ + override def inputTypes: Seq[AbstractDataType] = { + val arrayOrStr = TypeCollection(ArrayType(StringType), StringType) + StringType +: Seq.fill(children.size - 1)(arrayOrStr) + } -trait StringRegexExpression extends ExpectsInputTypes { + override def dataType: DataType = StringType + + override def nullable: Boolean = children.head.nullable + override def foldable: Boolean = children.forall(_.foldable) + + override def eval(input: InternalRow): Any = { + val flatInputs = children.flatMap { child => + child.eval(input) match { + case s: UTF8String => Iterator(s) + case arr: Seq[_] => arr.asInstanceOf[Seq[UTF8String]] + case null => Iterator(null.asInstanceOf[UTF8String]) + } + } + UTF8String.concatWs(flatInputs.head, flatInputs.tail : _*) + } + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + if (children.forall(_.dataType == StringType)) { + // All children are strings. In that case we can construct a fixed size array. + val evals = children.map(_.gen(ctx)) + + val inputs = evals.map { eval => + s"${eval.isNull} ? (UTF8String)null : ${eval.primitive}" + }.mkString(", ") + + evals.map(_.code).mkString("\n") + s""" + UTF8String ${ev.primitive} = UTF8String.concatWs($inputs); + boolean ${ev.isNull} = ${ev.primitive} == null; + """ + } else { + // Contains a mix of strings and arrays. Fall back to interpreted mode for now. 
+ super.genCode(ctx, ev) + } + } +} + + +trait StringRegexExpression extends ImplicitCastInputTypes { self: BinaryExpression => def escape(v: String): String @@ -67,7 +159,7 @@ trait StringRegexExpression extends ExpectsInputTypes { * Simple RegEx pattern matching function */ case class Like(left: Expression, right: Expression) - extends BinaryExpression with StringRegexExpression { + extends BinaryExpression with StringRegexExpression with CodegenFallback { // replace the _ with .{1} exactly match 1 time of any character // replace the % with .*, match 0 or more times with any character @@ -97,15 +189,17 @@ case class Like(left: Expression, right: Expression) override def toString: String = s"$left LIKE $right" } + case class RLike(left: Expression, right: Expression) - extends BinaryExpression with StringRegexExpression { + extends BinaryExpression with StringRegexExpression with CodegenFallback { override def escape(v: String): String = v override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) override def toString: String = s"$left RLIKE $right" } -trait String2StringExpression extends ExpectsInputTypes { + +trait String2StringExpression extends ImplicitCastInputTypes { self: UnaryExpression => def convert(v: UTF8String): UTF8String @@ -120,7 +214,8 @@ trait String2StringExpression extends ExpectsInputTypes { /** * A function that converts the characters of a string to uppercase. */ -case class Upper(child: Expression) extends UnaryExpression with String2StringExpression { +case class Upper(child: Expression) + extends UnaryExpression with String2StringExpression { override def convert(v: UTF8String): UTF8String = v.toUpperCase @@ -142,7 +237,7 @@ case class Lower(child: Expression) extends UnaryExpression with String2StringEx } /** A base trait for functions that compare two strings, returning a boolean. */ -trait StringComparison extends ExpectsInputTypes { +trait StringComparison extends ImplicitCastInputTypes { self: BinaryExpression => def compare(l: UTF8String, r: UTF8String): Boolean @@ -241,7 +336,7 @@ case class StringTrimRight(child: Expression) * NOTE: that this is not zero based, but 1-based index. The first character in str has index 1. */ case class StringInstr(str: Expression, substr: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ImplicitCastInputTypes { override def left: Expression = str override def right: Expression = substr @@ -265,7 +360,7 @@ case class StringInstr(str: Expression, substr: Expression) * in given string after position pos. */ case class StringLocate(substr: Expression, str: Expression, start: Expression) - extends Expression with ExpectsInputTypes { + extends Expression with ImplicitCastInputTypes with CodegenFallback { def this(substr: Expression, str: Expression) = { this(substr, str, Literal(0)) @@ -306,7 +401,7 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) * Returns str, left-padded with pad to a length of len. 
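 * For example, assuming the usual SQL lpad semantics (pad on the left and truncate when the
 * input is already longer than len):
 * {{{
 *   lpad("hi", 5, "??")  => "???hi"
 *   lpad("hi", 1, "??")  => "h"
 * }}}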
*/ case class StringLPad(str: Expression, len: Expression, pad: Expression) - extends Expression with ExpectsInputTypes { + extends Expression with ImplicitCastInputTypes { override def children: Seq[Expression] = str :: len :: pad :: Nil override def foldable: Boolean = children.forall(_.foldable) @@ -337,6 +432,31 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression) } } + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val lenGen = len.gen(ctx) + val strGen = str.gen(ctx) + val padGen = pad.gen(ctx) + + s""" + ${lenGen.code} + boolean ${ev.isNull} = ${lenGen.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${strGen.code} + if (!${strGen.isNull}) { + ${padGen.code} + if (!${padGen.isNull}) { + ${ev.primitive} = ${strGen.primitive}.lpad(${lenGen.primitive}, ${padGen.primitive}); + } else { + ${ev.isNull} = true; + } + } else { + ${ev.isNull} = true; + } + } + """ + } + override def prettyName: String = "lpad" } @@ -344,7 +464,7 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression) * Returns str, right-padded with pad to a length of len. */ case class StringRPad(str: Expression, len: Expression, pad: Expression) - extends Expression with ExpectsInputTypes { + extends Expression with ImplicitCastInputTypes { override def children: Seq[Expression] = str :: len :: pad :: Nil override def foldable: Boolean = children.forall(_.foldable) @@ -375,45 +495,107 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression) } } + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val lenGen = len.gen(ctx) + val strGen = str.gen(ctx) + val padGen = pad.gen(ctx) + + s""" + ${lenGen.code} + boolean ${ev.isNull} = ${lenGen.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${strGen.code} + if (!${strGen.isNull}) { + ${padGen.code} + if (!${padGen.isNull}) { + ${ev.primitive} = ${strGen.primitive}.rpad(${lenGen.primitive}, ${padGen.primitive}); + } else { + ${ev.isNull} = true; + } + } else { + ${ev.isNull} = true; + } + } + """ + } + override def prettyName: String = "rpad" } /** * Returns the input formatted according do printf-style format strings */ -case class StringFormat(children: Expression*) extends Expression { +case class FormatString(children: Expression*) extends Expression with ImplicitCastInputTypes { - require(children.length >=1, "printf() should take at least 1 argument") + require(children.nonEmpty, "format_string() should take at least 1 argument") override def foldable: Boolean = children.forall(_.foldable) override def nullable: Boolean = children(0).nullable override def dataType: DataType = StringType - private def format: Expression = children(0) - private def args: Seq[Expression] = children.tail + + override def inputTypes: Seq[AbstractDataType] = + StringType :: List.fill(children.size - 1)(AnyDataType) override def eval(input: InternalRow): Any = { - val pattern = format.eval(input) + val pattern = children(0).eval(input) if (pattern == null) { null } else { val sb = new StringBuffer() val formatter = new java.util.Formatter(sb, Locale.US) - val arglist = args.map(_.eval(input).asInstanceOf[AnyRef]) - formatter.format(pattern.asInstanceOf[UTF8String].toString(), arglist: _*) + val arglist = children.tail.map(_.eval(input).asInstanceOf[AnyRef]) + formatter.format(pattern.asInstanceOf[UTF8String].toString, 
arglist: _*) UTF8String.fromString(sb.toString) } } - override def prettyName: String = "printf" + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val pattern = children.head.gen(ctx) + + val argListGen = children.tail.map(x => (x.dataType, x.gen(ctx))) + val argListCode = argListGen.map(_._2.code + "\n") + + val argListString = argListGen.foldLeft("")((s, v) => { + val nullSafeString = + if (ctx.boxedType(v._1) != ctx.javaType(v._1)) { + // Java primitives get boxed in order to allow null values. + s"(${v._2.isNull}) ? (${ctx.boxedType(v._1)}) null : " + + s"new ${ctx.boxedType(v._1)}(${v._2.primitive})" + } else { + s"(${v._2.isNull}) ? null : ${v._2.primitive}" + } + s + "," + nullSafeString + }) + + val form = ctx.freshName("formatter") + val formatter = classOf[java.util.Formatter].getName + val sb = ctx.freshName("sb") + val stringBuffer = classOf[StringBuffer].getName + s""" + ${pattern.code} + boolean ${ev.isNull} = ${pattern.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${argListCode.mkString} + $stringBuffer $sb = new $stringBuffer(); + $formatter $form = new $formatter($sb, ${classOf[Locale].getName}.US); + $form.format(${pattern.primitive}.toString() $argListString); + ${ev.primitive} = UTF8String.fromString($sb.toString()); + } + """ + } + + override def prettyName: String = "format_string" } /** * Returns the string which repeat the given string value n times. */ case class StringRepeat(str: Expression, times: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ImplicitCastInputTypes { override def left: Expression = str override def right: Expression = times @@ -447,17 +629,20 @@ case class StringReverse(child: Expression) extends UnaryExpression with String2 /** * Returns a n spaces string. */ -case class StringSpace(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class StringSpace(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(IntegerType) override def nullSafeEval(s: Any): Any = { - val length = s.asInstanceOf[Integer] + val length = s.asInstanceOf[Int] + UTF8String.blankString(if (length < 0) 0 else length) + } - val spaces = new Array[Byte](if (length < 0) 0 else length) - java.util.Arrays.fill(spaces, ' '.asInstanceOf[Byte]) - UTF8String.fromBytes(spaces) + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (length) => + s"""${ev.primitive} = UTF8String.blankString(($length < 0) ? 0 : $length);""") } override def prettyName: String = "space" @@ -467,7 +652,7 @@ case class StringSpace(child: Expression) extends UnaryExpression with ExpectsIn * Splits str around pat (pattern is a regular expression). 
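 * For example (illustrative):
 * {{{
 *   split("aAbBcC", "[AB]")   // ["a", "b", "cC"]
 *   split("one,two,", ",")    // ["one", "two", ""]   (limit -1 keeps trailing empty strings)
 * }}}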
*/ case class StringSplit(str: Expression, pattern: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ImplicitCastInputTypes { override def left: Expression = str override def right: Expression = pattern @@ -475,9 +660,13 @@ case class StringSplit(str: Expression, pattern: Expression) override def inputTypes: Seq[DataType] = Seq(StringType, StringType) override def nullSafeEval(string: Any, regex: Any): Any = { - val splits = - string.asInstanceOf[UTF8String].toString.split(regex.asInstanceOf[UTF8String].toString, -1) - splits.toSeq.map(UTF8String.fromString) + string.asInstanceOf[UTF8String].split(regex.asInstanceOf[UTF8String], -1).toSeq + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (str, pattern) => + s"""${ev.primitive} = scala.collection.JavaConversions.asScalaBuffer( + java.util.Arrays.asList($str.split($pattern, -1)));""") } override def prettyName: String = "split" @@ -488,7 +677,7 @@ case class StringSplit(str: Expression, pattern: Expression) * Defined for String and Binary types. */ case class Substring(str: Expression, pos: Expression, len: Expression) - extends Expression with ExpectsInputTypes { + extends Expression with ImplicitCastInputTypes { def this(str: Expression, pos: Expression) = { this(str, pos, Literal(Integer.MAX_VALUE)) @@ -497,83 +686,87 @@ case class Substring(str: Expression, pos: Expression, len: Expression) override def foldable: Boolean = str.foldable && pos.foldable && len.foldable override def nullable: Boolean = str.nullable || pos.nullable || len.nullable - override def dataType: DataType = { - if (!resolved) { - throw new UnresolvedException(this, s"Cannot resolve since $children are not resolved") - } - if (str.dataType == BinaryType) str.dataType else StringType - } + override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, IntegerType) override def children: Seq[Expression] = str :: pos :: len :: Nil - @inline - def slicePos(startPos: Int, sliceLen: Int, length: () => Int): (Int, Int) = { - // Hive and SQL use one-based indexing for SUBSTR arguments but also accept zero and - // negative indices for start positions. If a start index i is greater than 0, it - // refers to element i-1 in the sequence. If a start index i is less than 0, it refers - // to the -ith element before the end of the sequence. If a start index i is 0, it - // refers to the first element. 
- - val start = startPos match { - case pos if pos > 0 => pos - 1 - case neg if neg < 0 => length() + neg - case _ => 0 - } - - val end = sliceLen match { - case max if max == Integer.MAX_VALUE => max - case x => start + x + override def eval(input: InternalRow): Any = { + val stringEval = str.eval(input) + if (stringEval != null) { + val posEval = pos.eval(input) + if (posEval != null) { + val lenEval = len.eval(input) + if (lenEval != null) { + stringEval.asInstanceOf[UTF8String] + .substringSQL(posEval.asInstanceOf[Int], lenEval.asInstanceOf[Int]) + } else { + null + } + } else { + null + } + } else { + null } - - (start, end) } - override def eval(input: InternalRow): Any = { - val string = str.eval(input) - val po = pos.eval(input) - val ln = len.eval(input) - - if ((string == null) || (po == null) || (ln == null)) { - null - } else { - val start = po.asInstanceOf[Int] - val length = ln.asInstanceOf[Int] - string match { - case ba: Array[Byte] => - val (st, end) = slicePos(start, length, () => ba.length) - ba.slice(st, end) - case s: UTF8String => - val (st, end) = slicePos(start, length, () => s.numChars()) - s.substring(st, end) + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val strGen = str.gen(ctx) + val posGen = pos.gen(ctx) + val lenGen = len.gen(ctx) + + val start = ctx.freshName("start") + val end = ctx.freshName("end") + + s""" + ${strGen.code} + boolean ${ev.isNull} = ${strGen.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${posGen.code} + if (!${posGen.isNull}) { + ${lenGen.code} + if (!${lenGen.isNull}) { + ${ev.primitive} = ${strGen.primitive} + .substringSQL(${posGen.primitive}, ${lenGen.primitive}); + } else { + ${ev.isNull} = true; + } + } else { + ${ev.isNull} = true; + } } - } + """ } } /** - * A function that return the length of the given string expression. + * A function that return the length of the given string or binary expression. */ -case class StringLength(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class Length(child: Expression) extends UnaryExpression with ExpectsInputTypes { override def dataType: DataType = IntegerType - override def inputTypes: Seq[DataType] = Seq(StringType) + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType)) - protected override def nullSafeEval(string: Any): Any = - string.asInstanceOf[UTF8String].numChars + protected override def nullSafeEval(value: Any): Any = child.dataType match { + case StringType => value.asInstanceOf[UTF8String].numChars + case BinaryType => value.asInstanceOf[Array[Byte]].length + } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - defineCodeGen(ctx, ev, c => s"($c).numChars()") + child.dataType match { + case StringType => defineCodeGen(ctx, ev, c => s"($c).numChars()") + case BinaryType => defineCodeGen(ctx, ev, c => s"($c).length") + } } - - override def prettyName: String = "length" } /** * A function that return the Levenshtein distance between the two given strings. */ case class Levenshtein(left: Expression, right: Expression) extends BinaryExpression - with ExpectsInputTypes { + with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType) @@ -591,7 +784,8 @@ case class Levenshtein(left: Expression, right: Expression) extends BinaryExpres /** * Returns the numeric value of the first character of str. 
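 * For example (illustrative): ascii("Spark") returns 83, the code point of 'S',
 * and ascii("") returns 0.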
*/ -case class Ascii(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + override def dataType: DataType = IntegerType override def inputTypes: Seq[DataType] = Seq(StringType) @@ -603,12 +797,26 @@ case class Ascii(child: Expression) extends UnaryExpression with ExpectsInputTyp 0 } } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (child) => { + val bytes = ctx.freshName("bytes") + s""" + byte[] $bytes = $child.getBytes(); + if ($bytes.length > 0) { + ${ev.primitive} = (int) $bytes[0]; + } else { + ${ev.primitive} = 0; + } + """}) + } } /** * Converts the argument from binary to a base 64 string. */ -case class Base64(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class Base64(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(BinaryType) @@ -617,17 +825,33 @@ case class Base64(child: Expression) extends UnaryExpression with ExpectsInputTy org.apache.commons.codec.binary.Base64.encodeBase64( bytes.asInstanceOf[Array[Byte]])) } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (child) => { + s"""${ev.primitive} = UTF8String.fromBytes( + org.apache.commons.codec.binary.Base64.encodeBase64($child)); + """}) + } + } /** * Converts the argument from a base 64 string to BINARY. */ -case class UnBase64(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class UnBase64(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + override def dataType: DataType = BinaryType override def inputTypes: Seq[DataType] = Seq(StringType) protected override def nullSafeEval(string: Any): Any = org.apache.commons.codec.binary.Base64.decodeBase64(string.asInstanceOf[UTF8String].toString) + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (child) => { + s""" + ${ev.primitive} = org.apache.commons.codec.binary.Base64.decodeBase64($child.toString()); + """}) + } } /** @@ -636,7 +860,7 @@ case class UnBase64(child: Expression) extends UnaryExpression with ExpectsInput * If either argument is null, the result will also be null. */ case class Decode(bin: Expression, charset: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ImplicitCastInputTypes { override def left: Expression = bin override def right: Expression = charset @@ -647,6 +871,17 @@ case class Decode(bin: Expression, charset: Expression) val fromCharset = input2.asInstanceOf[UTF8String].toString UTF8String.fromString(new String(input1.asInstanceOf[Array[Byte]], fromCharset)) } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (bytes, charset) => + s""" + try { + ${ev.primitive} = UTF8String.fromString(new String($bytes, $charset.toString())); + } catch (java.io.UnsupportedEncodingException e) { + org.apache.spark.unsafe.PlatformDependent.throwException(e); + } + """) + } } /** @@ -655,7 +890,7 @@ case class Decode(bin: Expression, charset: Expression) * If either argument is null, the result will also be null. 
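 * For example (illustrative): encode("abc", "utf-8") produces the three bytes 0x61 0x62 0x63;
 * an unknown charset name results in an UnsupportedEncodingException.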
*/ case class Encode(value: Expression, charset: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ImplicitCastInputTypes { override def left: Expression = value override def right: Expression = charset @@ -666,5 +901,344 @@ case class Encode(value: Expression, charset: Expression) val toCharset = input2.asInstanceOf[UTF8String].toString input1.asInstanceOf[UTF8String].toString.getBytes(toCharset) } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (string, charset) => + s""" + try { + ${ev.primitive} = $string.toString().getBytes($charset.toString()); + } catch (java.io.UnsupportedEncodingException e) { + org.apache.spark.unsafe.PlatformDependent.throwException(e); + }""") + } +} + +/** + * Replace all substrings of str that match regexp with rep. + * + * NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status. + */ +case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expression) + extends Expression with ImplicitCastInputTypes { + + // last regex in string, we will update the pattern iff regexp value changed. + @transient private var lastRegex: UTF8String = _ + // last regex pattern, we cache it for performance concern + @transient private var pattern: Pattern = _ + // last replacement string, we don't want to convert a UTF8String => java.langString every time. + @transient private var lastReplacement: String = _ + @transient private var lastReplacementInUTF8: UTF8String = _ + // result buffer write by Matcher + @transient private val result: StringBuffer = new StringBuffer + + override def nullable: Boolean = subject.nullable || regexp.nullable || rep.nullable + override def foldable: Boolean = subject.foldable && regexp.foldable && rep.foldable + + override def eval(input: InternalRow): Any = { + val s = subject.eval(input) + if (null != s) { + val p = regexp.eval(input) + if (null != p) { + val r = rep.eval(input) + if (null != r) { + if (!p.equals(lastRegex)) { + // regex value changed + lastRegex = p.asInstanceOf[UTF8String] + pattern = Pattern.compile(lastRegex.toString) + } + if (!r.equals(lastReplacementInUTF8)) { + // replacement string changed + lastReplacementInUTF8 = r.asInstanceOf[UTF8String] + lastReplacement = lastReplacementInUTF8.toString + } + val m = pattern.matcher(s.toString()) + result.delete(0, result.length()) + + while (m.find) { + m.appendReplacement(result, lastReplacement) + } + m.appendTail(result) + + return UTF8String.fromString(result.toString) + } + } + } + + null + } + + override def dataType: DataType = StringType + override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, StringType) + override def children: Seq[Expression] = subject :: regexp :: rep :: Nil + override def prettyName: String = "regexp_replace" + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val termLastRegex = ctx.freshName("lastRegex") + val termPattern = ctx.freshName("pattern") + + val termLastReplacement = ctx.freshName("lastReplacement") + val termLastReplacementInUTF8 = ctx.freshName("lastReplacementInUTF8") + + val termResult = ctx.freshName("result") + + val classNameUTF8String = classOf[UTF8String].getCanonicalName + val classNamePattern = classOf[Pattern].getCanonicalName + val classNameString = classOf[java.lang.String].getCanonicalName + val classNameStringBuffer = classOf[java.lang.StringBuffer].getCanonicalName + + 
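+      // Illustrative note: each addMutableState call below registers an instance-level field
+      // (plus its initializer) on the generated class, so the compiled Pattern, the cached
+      // regex/replacement values and the result StringBuffer persist across input rows,
+      // mirroring the @transient fields used by the interpreted eval() path above.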
ctx.addMutableState(classNameUTF8String, + termLastRegex, s"${termLastRegex} = null;") + ctx.addMutableState(classNamePattern, + termPattern, s"${termPattern} = null;") + ctx.addMutableState(classNameString, + termLastReplacement, s"${termLastReplacement} = null;") + ctx.addMutableState(classNameUTF8String, + termLastReplacementInUTF8, s"${termLastReplacementInUTF8} = null;") + ctx.addMutableState(classNameStringBuffer, + termResult, s"${termResult} = new $classNameStringBuffer();") + + val evalSubject = subject.gen(ctx) + val evalRegexp = regexp.gen(ctx) + val evalRep = rep.gen(ctx) + + s""" + ${evalSubject.code} + boolean ${ev.isNull} = ${evalSubject.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${evalSubject.isNull}) { + ${evalRegexp.code} + if (!${evalRegexp.isNull}) { + ${evalRep.code} + if (!${evalRep.isNull}) { + if (!${evalRegexp.primitive}.equals(${termLastRegex})) { + // regex value changed + ${termLastRegex} = ${evalRegexp.primitive}; + ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString()); + } + if (!${evalRep.primitive}.equals(${termLastReplacementInUTF8})) { + // replacement string changed + ${termLastReplacementInUTF8} = ${evalRep.primitive}; + ${termLastReplacement} = ${termLastReplacementInUTF8}.toString(); + } + ${termResult}.delete(0, ${termResult}.length()); + ${classOf[java.util.regex.Matcher].getCanonicalName} m = + ${termPattern}.matcher(${evalSubject.primitive}.toString()); + + while (m.find()) { + m.appendReplacement(${termResult}, ${termLastReplacement}); + } + m.appendTail(${termResult}); + ${ev.primitive} = ${classNameUTF8String}.fromString(${termResult}.toString()); + ${ev.isNull} = false; + } + } + } + """ + } +} + +/** + * Extract a specific(idx) group identified by a Java regex. + * + * NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status. + */ +case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expression) + extends Expression with ImplicitCastInputTypes { + def this(s: Expression, r: Expression) = this(s, r, Literal(1)) + + // last regex in string, we will update the pattern iff regexp value changed. 
+ @transient private var lastRegex: UTF8String = _ + // last regex pattern, we cache it for performance concern + @transient private var pattern: Pattern = _ + + override def nullable: Boolean = subject.nullable || regexp.nullable || idx.nullable + override def foldable: Boolean = subject.foldable && regexp.foldable && idx.foldable + + override def eval(input: InternalRow): Any = { + val s = subject.eval(input) + if (null != s) { + val p = regexp.eval(input) + if (null != p) { + val r = idx.eval(input) + if (null != r) { + if (!p.equals(lastRegex)) { + // regex value changed + lastRegex = p.asInstanceOf[UTF8String] + pattern = Pattern.compile(lastRegex.toString) + } + val m = pattern.matcher(s.toString()) + if (m.find) { + val mr: MatchResult = m.toMatchResult + return UTF8String.fromString(mr.group(r.asInstanceOf[Int])) + } + return UTF8String.EMPTY_UTF8 + } + } + } + + null + } + + override def dataType: DataType = StringType + override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, IntegerType) + override def children: Seq[Expression] = subject :: regexp :: idx :: Nil + override def prettyName: String = "regexp_extract" + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val termLastRegex = ctx.freshName("lastRegex") + val termPattern = ctx.freshName("pattern") + val classNameUTF8String = classOf[UTF8String].getCanonicalName + val classNamePattern = classOf[Pattern].getCanonicalName + + ctx.addMutableState(classNameUTF8String, termLastRegex, s"${termLastRegex} = null;") + ctx.addMutableState(classNamePattern, termPattern, s"${termPattern} = null;") + + val evalSubject = subject.gen(ctx) + val evalRegexp = regexp.gen(ctx) + val evalIdx = idx.gen(ctx) + + s""" + ${ctx.javaType(dataType)} ${ev.primitive} = null; + boolean ${ev.isNull} = true; + ${evalSubject.code} + if (!${evalSubject.isNull}) { + ${evalRegexp.code} + if (!${evalRegexp.isNull}) { + ${evalIdx.code} + if (!${evalIdx.isNull}) { + if (!${evalRegexp.primitive}.equals(${termLastRegex})) { + // regex value changed + ${termLastRegex} = ${evalRegexp.primitive}; + ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString()); + } + ${classOf[java.util.regex.Matcher].getCanonicalName} m = + ${termPattern}.matcher(${evalSubject.primitive}.toString()); + if (m.find()) { + ${classOf[java.util.regex.MatchResult].getCanonicalName} mr = m.toMatchResult(); + ${ev.primitive} = ${classNameUTF8String}.fromString(mr.group(${evalIdx.primitive})); + ${ev.isNull} = false; + } else { + ${ev.primitive} = ${classNameUTF8String}.EMPTY_UTF8; + ${ev.isNull} = false; + } + } + } + } + """ + } +} + +/** + * Formats the number X to a format like '#,###,###.##', rounded to D decimal places, + * and returns the result as a string. If D is 0, the result has no decimal point or + * fractional part. + */ +case class FormatNumber(x: Expression, d: Expression) + extends BinaryExpression with ExpectsInputTypes with CodegenFallback { + + override def left: Expression = x + override def right: Expression = d + override def dataType: DataType = StringType + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType) + + // Associated with the pattern, for the last d value, and we will update the + // pattern (DecimalFormat) once the new coming d value differ with the last one. + @transient + private var lastDValue: Int = -100 + + // A cached DecimalFormat, for performance concern, we will change it + // only if the d value changed. 
+ @transient + private val pattern: StringBuffer = new StringBuffer() + + @transient + private val numberFormat: DecimalFormat = new DecimalFormat("") + + override protected def nullSafeEval(xObject: Any, dObject: Any): Any = { + val dValue = dObject.asInstanceOf[Int] + if (dValue < 0) { + return null + } + + if (dValue != lastDValue) { + // construct a new DecimalFormat only if a new dValue + pattern.delete(0, pattern.length) + pattern.append("#,###,###,###,###,###,##0") + + // decimal place + if (dValue > 0) { + pattern.append(".") + + var i = 0 + while (i < dValue) { + i += 1 + pattern.append("0") + } + } + val dFormat = new DecimalFormat(pattern.toString) + lastDValue = dValue + + numberFormat.applyPattern(dFormat.toPattern) + } + + x.dataType match { + case ByteType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Byte])) + case ShortType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Short])) + case FloatType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Float])) + case IntegerType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Int])) + case LongType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Long])) + case DoubleType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Double])) + case _: DecimalType => + UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Decimal].toJavaBigDecimal)) + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (num, d) => { + + def typeHelper(p: String): String = { + x.dataType match { + case _ : DecimalType => s"""$p.toJavaBigDecimal()""" + case _ => s"$p" + } + } + + val sb = classOf[StringBuffer].getName + val df = classOf[DecimalFormat].getName + val lastDValue = ctx.freshName("lastDValue") + val pattern = ctx.freshName("pattern") + val numberFormat = ctx.freshName("numberFormat") + val i = ctx.freshName("i") + val dFormat = ctx.freshName("dFormat") + ctx.addMutableState("int", lastDValue, s"$lastDValue = -100;") + ctx.addMutableState(sb, pattern, s"$pattern = new $sb();") + ctx.addMutableState(df, numberFormat, s"""$numberFormat = new $df("");""") + + s""" + if ($d >= 0) { + $pattern.delete(0, $pattern.length()); + if ($d != $lastDValue) { + $pattern.append("#,###,###,###,###,###,##0"); + + if ($d > 0) { + $pattern.append("."); + for (int $i = 0; $i < $d; $i++) { + $pattern.append("0"); + } + } + $df $dFormat = new $df($pattern.toString()); + $lastDValue = $d; + $numberFormat.applyPattern($dFormat.toPattern()); + ${ev.primitive} = UTF8String.fromString($numberFormat.format(${typeHelper(num)})); + } + } else { + ${ev.primitive} = null; + ${ev.isNull} = true; + } + """ + }) + } + + override def prettyName: String = "format_number" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 344361685853f..09ec0e333aa44 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.UnresolvedException -import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.types.{DataType, NumericType} 
/** @@ -37,7 +36,7 @@ sealed trait WindowSpec case class WindowSpecDefinition( partitionSpec: Seq[Expression], orderSpec: Seq[SortOrder], - frameSpecification: WindowFrame) extends Expression with WindowSpec { + frameSpecification: WindowFrame) extends Expression with WindowSpec with Unevaluable { def validate: Option[String] = frameSpecification match { case UnspecifiedFrame => @@ -75,7 +74,6 @@ case class WindowSpecDefinition( override def toString: String = simpleString - override def eval(input: InternalRow): Any = throw new UnsupportedOperationException override def nullable: Boolean = true override def foldable: Boolean = false override def dataType: DataType = throw new UnsupportedOperationException @@ -254,8 +252,6 @@ object SpecifiedWindowFrame { * to retrieve value corresponding with these n rows. */ trait WindowFunction extends Expression { - self: Product => - def init(): Unit def reset(): Unit @@ -276,60 +272,43 @@ trait WindowFunction extends Expression { case class UnresolvedWindowFunction( name: String, children: Seq[Expression]) - extends Expression with WindowFunction { + extends Expression with WindowFunction with Unevaluable { override def dataType: DataType = throw new UnresolvedException(this, "dataType") override def foldable: Boolean = throw new UnresolvedException(this, "foldable") override def nullable: Boolean = throw new UnresolvedException(this, "nullable") override lazy val resolved = false - override def init(): Unit = - throw new UnresolvedException(this, "init") - override def reset(): Unit = - throw new UnresolvedException(this, "reset") + override def init(): Unit = throw new UnresolvedException(this, "init") + override def reset(): Unit = throw new UnresolvedException(this, "reset") override def prepareInputParameters(input: InternalRow): AnyRef = throw new UnresolvedException(this, "prepareInputParameters") - override def update(input: AnyRef): Unit = - throw new UnresolvedException(this, "update") + override def update(input: AnyRef): Unit = throw new UnresolvedException(this, "update") override def batchUpdate(inputs: Array[AnyRef]): Unit = throw new UnresolvedException(this, "batchUpdate") - override def evaluate(): Unit = - throw new UnresolvedException(this, "evaluate") - override def get(index: Int): Any = - throw new UnresolvedException(this, "get") - // Unresolved functions are transient at compile time and don't get evaluated during execution. - override def eval(input: InternalRow = null): Any = - throw new TreeNodeException(this, s"No function to evaluate expression. 
type: ${this.nodeName}") + override def evaluate(): Unit = throw new UnresolvedException(this, "evaluate") + override def get(index: Int): Any = throw new UnresolvedException(this, "get") override def toString: String = s"'$name(${children.mkString(",")})" - override def newInstance(): WindowFunction = - throw new UnresolvedException(this, "newInstance") + override def newInstance(): WindowFunction = throw new UnresolvedException(this, "newInstance") } case class UnresolvedWindowExpression( child: UnresolvedWindowFunction, - windowSpec: WindowSpecReference) extends UnaryExpression { + windowSpec: WindowSpecReference) extends UnaryExpression with Unevaluable { override def dataType: DataType = throw new UnresolvedException(this, "dataType") override def foldable: Boolean = throw new UnresolvedException(this, "foldable") override def nullable: Boolean = throw new UnresolvedException(this, "nullable") override lazy val resolved = false - - // Unresolved functions are transient at compile time and don't get evaluated during execution. - override def eval(input: InternalRow = null): Any = - throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") } case class WindowExpression( windowFunction: WindowFunction, - windowSpec: WindowSpecDefinition) extends Expression { + windowSpec: WindowSpecDefinition) extends Expression with Unevaluable { - override def children: Seq[Expression] = - windowFunction :: windowSpec :: Nil - - override def eval(input: InternalRow): Any = - throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") + override def children: Seq[Expression] = windowFunction :: windowSpec :: Nil override def dataType: DataType = windowFunction.dataType override def foldable: Boolean = windowFunction.foldable @@ -337,3 +316,15 @@ case class WindowExpression( override def toString: String = s"$windowFunction $windowSpec" } + +/** + * Extractor for making working with frame boundaries easier. + */ +object FrameBoundaryExtractor { + def unapply(boundary: FrameBoundary): Option[Int] = boundary match { + case CurrentRow => Some(0) + case ValuePreceding(offset) => Some(-offset) + case ValueFollowing(offset) => Some(offset) + case _ => None + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 5d80214abf141..d2db3dd3d078e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -40,7 +40,7 @@ object DefaultOptimizer extends Optimizer { ReplaceDistinctWithAggregate) :: Batch("Operator Optimizations", FixedPoint(100), // Operator push down - UnionPushDown, + SetOperationPushDown, SamplePushDown, PushPredicateThroughJoin, PushPredicateThroughProject, @@ -84,23 +84,24 @@ object SamplePushDown extends Rule[LogicalPlan] { } /** - * Pushes operations to either side of a Union. + * Pushes operations to either side of a Union, Intersect or Except. */ -object UnionPushDown extends Rule[LogicalPlan] { +object SetOperationPushDown extends Rule[LogicalPlan] { /** * Maps Attributes from the left side to the corresponding Attribute on the right side. 
*/ - private def buildRewrites(union: Union): AttributeMap[Attribute] = { - assert(union.left.output.size == union.right.output.size) + private def buildRewrites(bn: BinaryNode): AttributeMap[Attribute] = { + assert(bn.isInstanceOf[Union] || bn.isInstanceOf[Intersect] || bn.isInstanceOf[Except]) + assert(bn.left.output.size == bn.right.output.size) - AttributeMap(union.left.output.zip(union.right.output)) + AttributeMap(bn.left.output.zip(bn.right.output)) } /** - * Rewrites an expression so that it can be pushed to the right side of a Union operator. - * This method relies on the fact that the output attributes of a union are always equal - * to the left child's output. + * Rewrites an expression so that it can be pushed to the right side of a + * Union, Intersect or Except operator. This method relies on the fact that the output attributes + * of a union/intersect/except are always equal to the left child's output. */ private def pushToRight[A <: Expression](e: A, rewrites: AttributeMap[Attribute]) = { val result = e transform { @@ -126,6 +127,34 @@ object UnionPushDown extends Rule[LogicalPlan] { Union( Project(projectList, left), Project(projectList.map(pushToRight(_, rewrites)), right)) + + // Push down filter into intersect + case Filter(condition, i @ Intersect(left, right)) => + val rewrites = buildRewrites(i) + Intersect( + Filter(condition, left), + Filter(pushToRight(condition, rewrites), right)) + + // Push down projection into intersect + case Project(projectList, i @ Intersect(left, right)) => + val rewrites = buildRewrites(i) + Intersect( + Project(projectList, left), + Project(projectList.map(pushToRight(_, rewrites)), right)) + + // Push down filter into except + case Filter(condition, e @ Except(left, right)) => + val rewrites = buildRewrites(e) + Except( + Filter(condition, left), + Filter(pushToRight(condition, rewrites), right)) + + // Push down projection into except + case Project(projectList, e @ Except(left, right)) => + val rewrites = buildRewrites(e) + Except( + Project(projectList, left), + Project(projectList.map(pushToRight(_, rewrites)), right)) } } @@ -206,31 +235,33 @@ object ColumnPruning extends Rule[LogicalPlan] { */ object ProjectCollapsing extends Rule[LogicalPlan] { - /** Returns true if any expression in projectList is non-deterministic. */ - private def hasNondeterministic(projectList: Seq[NamedExpression]): Boolean = { - projectList.exists(expr => expr.find(!_.deterministic).isDefined) - } - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - // We only collapse these two Projects if the child Project's expressions are all - // deterministic. - case Project(projectList1, Project(projectList2, child)) - if !hasNondeterministic(projectList2) => + case p @ Project(projectList1, Project(projectList2, child)) => // Create a map of Aliases to their values from the child projection. // e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)). val aliasMap = AttributeMap(projectList2.collect { - case a @ Alias(e, _) => (a.toAttribute, a) + case a: Alias => (a.toAttribute, a) }) - // Substitute any attributes that are produced by the child projection, so that we safely - // eliminate it. - // e.g., 'SELECT c + 1 FROM (SELECT a + b AS C ...' produces 'SELECT a + b + 1 ...' - // TODO: Fix TransformBase to avoid the cast below. 
- val substitutedProjection = projectList1.map(_.transform { - case a: Attribute if aliasMap.contains(a) => aliasMap(a) - }).asInstanceOf[Seq[NamedExpression]] + // We only collapse these two Projects if their overlapped expressions are all + // deterministic. + val hasNondeterministic = projectList1.exists(_.collect { + case a: Attribute if aliasMap.contains(a) => aliasMap(a).child + }.exists(!_.deterministic)) - Project(substitutedProjection, child) + if (hasNondeterministic) { + p + } else { + // Substitute any attributes that are produced by the child projection, so that we safely + // eliminate it. + // e.g., 'SELECT c + 1 FROM (SELECT a + b AS C ...' produces 'SELECT a + b + 1 ...' + // TODO: Fix TransformBase to avoid the cast below. + val substitutedProjection = projectList1.map(_.transform { + case a: Attribute => aliasMap.getOrElse(a, a) + }).asInstanceOf[Seq[NamedExpression]] + + Project(substitutedProjection, child) + } } } @@ -342,7 +373,7 @@ object ConstantFolding extends Rule[LogicalPlan] { case l: Literal => l // Fold expressions that are foldable. - case e if e.foldable => Literal.create(e.eval(null), e.dataType) + case e if e.foldable => Literal.create(e.eval(EmptyRow), e.dataType) // Fold "literal in (item1, item2, ..., literal, ...)" into true directly. case In(Literal(v, _), list) if list.exists { @@ -361,7 +392,7 @@ object OptimizeIn extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsDown { case In(v, list) if !list.exists(!_.isInstanceOf[Literal]) => - val hSet = list.map(e => e.eval(null)) + val hSet = list.map(e => e.eval(EmptyRow)) InSet(v, HashSet() ++ hSet) } } @@ -391,26 +422,26 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { // (a || b) && (a || c) => a || (b && c) case _ => // 1. Split left and right to get the disjunctive predicates, - // i.e. lhsSet = (a, b), rhsSet = (a, c) + // i.e. lhs = (a, b), rhs = (a, c) // 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a) // 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c) // 4. Apply the formula, get the optimized predicate: common || (ldiff && rdiff) - val lhsSet = splitDisjunctivePredicates(left).toSet - val rhsSet = splitDisjunctivePredicates(right).toSet - val common = lhsSet.intersect(rhsSet) + val lhs = splitDisjunctivePredicates(left) + val rhs = splitDisjunctivePredicates(right) + val common = lhs.filter(e => rhs.exists(e.semanticEquals(_))) if (common.isEmpty) { // No common factors, return the original predicate and } else { - val ldiff = lhsSet.diff(common) - val rdiff = rhsSet.diff(common) + val ldiff = lhs.filterNot(e => common.exists(e.semanticEquals(_))) + val rdiff = rhs.filterNot(e => common.exists(e.semanticEquals(_))) if (ldiff.isEmpty || rdiff.isEmpty) { // (a || b || c || ...) && (a || b) => (a || b) common.reduce(Or) } else { // (a || b || c || ...) && (a || b || d || ...) => // ((c || ...) && (d || ...)) || a || b - (common + And(ldiff.reduce(Or), rdiff.reduce(Or))).reduce(Or) + (common :+ And(ldiff.reduce(Or), rdiff.reduce(Or))).reduce(Or) } } } // end of And(left, right) @@ -429,26 +460,26 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { // (a && b) || (a && c) => a && (b || c) case _ => // 1. Split left and right to get the conjunctive predicates, - // i.e. lhsSet = (a, b), rhsSet = (a, c) + // i.e. lhs = (a, b), rhs = (a, c) // 2. Find the common predict between lhsSet and rhsSet, i.e. 
common = (a) // 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c) // 4. Apply the formula, get the optimized predicate: common && (ldiff || rdiff) - val lhsSet = splitConjunctivePredicates(left).toSet - val rhsSet = splitConjunctivePredicates(right).toSet - val common = lhsSet.intersect(rhsSet) + val lhs = splitConjunctivePredicates(left) + val rhs = splitConjunctivePredicates(right) + val common = lhs.filter(e => rhs.exists(e.semanticEquals(_))) if (common.isEmpty) { // No common factors, return the original predicate or } else { - val ldiff = lhsSet.diff(common) - val rdiff = rhsSet.diff(common) + val ldiff = lhs.filterNot(e => common.exists(e.semanticEquals(_))) + val rdiff = rhs.filterNot(e => common.exists(e.semanticEquals(_))) if (ldiff.isEmpty || rdiff.isEmpty) { // (a && b) || (a && b && c && ...) => a && b common.reduce(And) } else { // (a && b && c && ...) || (a && b && d && ...) => // ((c && ...) || (d && ...)) && a && b - (common + Or(ldiff.reduce(And), rdiff.reduce(And))).reduce(And) + (common :+ Or(ldiff.reduce(And), rdiff.reduce(And))).reduce(And) } } } // end of Or(left, right) @@ -510,20 +541,50 @@ object SimplifyFilters extends Rule[LogicalPlan] { * * This heuristic is valid assuming the expression evaluation cost is minimal. */ -object PushPredicateThroughProject extends Rule[LogicalPlan] { +object PushPredicateThroughProject extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case filter @ Filter(condition, project @ Project(fields, grandChild)) => - val sourceAliases = fields.collect { case a @ Alias(c, _) => - (a.toAttribute: Attribute) -> c - }.toMap - project.copy(child = filter.copy( - replaceAlias(condition, sourceAliases), - grandChild)) + // Create a map of Aliases to their values from the child projection. + // e.g., 'SELECT a + b AS c, d ...' produces Map(c -> a + b). + val aliasMap = AttributeMap(fields.collect { + case a: Alias => (a.toAttribute, a.child) + }) + + // Split the condition into small conditions by `And`, so that we can push down part of this + // condition without nondeterministic expressions. + val andConditions = splitConjunctivePredicates(condition) + val nondeterministicConditions = andConditions.filter(hasNondeterministic(_, aliasMap)) + + // If there is no nondeterministic conditions, push down the whole condition. + if (nondeterministicConditions.isEmpty) { + project.copy(child = Filter(replaceAlias(condition, aliasMap), grandChild)) + } else { + // If they are all nondeterministic conditions, leave it un-changed. + if (nondeterministicConditions.length == andConditions.length) { + filter + } else { + val deterministicConditions = andConditions.filterNot(hasNondeterministic(_, aliasMap)) + // Push down the small conditions without nondeterministic expressions. 
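+            // Illustrative example: for Filter(r <= 0.5 && a > 2, Project(rand() AS r, a, ...)),
+            // only `a > 2` is pushed below the Project, while `r <= 0.5` references the
+            // nondeterministic alias r and therefore stays in the Filter kept above it.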
+ val pushedCondition = deterministicConditions.map(replaceAlias(_, aliasMap)).reduce(And) + Filter(nondeterministicConditions.reduce(And), + project.copy(child = Filter(pushedCondition, grandChild))) + } + } + } + + private def hasNondeterministic( + condition: Expression, + sourceAliases: AttributeMap[Expression]) = { + condition.collect { + case a: Attribute if sourceAliases.contains(a) => sourceAliases(a) + }.exists(!_.deterministic) } - private def replaceAlias(condition: Expression, sourceAliases: Map[Attribute, Expression]) = { - condition transform { - case a: AttributeReference => sourceAliases.getOrElse(a, a) + // Substitute any attributes that are produced by the child projection, so that we safely + // eliminate it. + private def replaceAlias(condition: Expression, sourceAliases: AttributeMap[Expression]) = { + condition.transform { + case a: Attribute => sourceAliases.getOrElse(a, a) } } } @@ -682,7 +743,7 @@ object CombineLimits extends Rule[LogicalPlan] { } /** - * Removes the inner [[CaseConversionExpression]] that are unnecessary because + * Removes the inner case conversion expressions that are unnecessary because * the inner conversion is overwritten by the outer one. */ object SimplifyCaseConversionExpressions extends Rule[LogicalPlan] { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 179a348d5baac..b8e3b0d53a505 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -129,10 +129,10 @@ object PartialAggregation { case logical.Aggregate(groupingExpressions, aggregateExpressions, child) => // Collect all aggregate expressions. val allAggregates = - aggregateExpressions.flatMap(_ collect { case a: AggregateExpression => a}) + aggregateExpressions.flatMap(_ collect { case a: AggregateExpression1 => a}) // Collect all aggregate expressions that can be computed partially. val partialAggregates = - aggregateExpressions.flatMap(_ collect { case p: PartialAggregate => p}) + aggregateExpressions.flatMap(_ collect { case p: PartialAggregate1 => p}) // Only do partial aggregation if supported by all aggregate expressions. 
if (allAggregates.size == partialAggregates.size) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index b89e3382f06a9..d06a7a2add754 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types.{ArrayType, DataType, StructField, StructType} abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanType] { - self: PlanType with Product => + self: PlanType => def output: Seq[Attribute] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala index 1868f119f0e97..e3e7a11dba973 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, analysis} import org.apache.spark.sql.types.{StructField, StructType} @@ -28,6 +29,12 @@ object LocalRelation { new LocalRelation(StructType(output1 +: output).toAttributes) } + def fromExternalRows(output: Seq[Attribute], data: Seq[Row]): LocalRelation = { + val schema = StructType.fromAttributes(output) + val converter = CatalystTypeConverters.createToCatalystConverter(schema) + LocalRelation(output, data.map(converter(_).asInstanceOf[InternalRow])) + } + def fromProduct(output: Seq[Attribute], data: Seq[Product]): LocalRelation = { val schema = StructType.fromAttributes(output) val converter = CatalystTypeConverters.createToCatalystConverter(schema) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index e911b907e8536..bedeaf06adf12 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -23,11 +23,9 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.trees.TreeNode -import org.apache.spark.sql.catalyst.trees abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { - self: Product => /** * Computes [[Statistics]] for this plan. The default implementation assumes the output @@ -277,20 +275,25 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { /** * A logical plan node with no children. */ -abstract class LeafNode extends LogicalPlan with trees.LeafNode[LogicalPlan] { - self: Product => +abstract class LeafNode extends LogicalPlan { + override def children: Seq[LogicalPlan] = Nil } /** * A logical plan node with single child. 
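 * Typical examples (illustrative) are Filter, Project and Limit: each defines `child` and
 * inherits `children = child :: Nil`.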
*/ -abstract class UnaryNode extends LogicalPlan with trees.UnaryNode[LogicalPlan] { - self: Product => +abstract class UnaryNode extends LogicalPlan { + def child: LogicalPlan + + override def children: Seq[LogicalPlan] = child :: Nil } /** * A logical plan node with a left and right child. */ -abstract class BinaryNode extends LogicalPlan with trees.BinaryNode[LogicalPlan] { - self: Product => +abstract class BinaryNode extends LogicalPlan { + def left: LogicalPlan + def right: LogicalPlan + + override def children: Seq[LogicalPlan] = Seq(left, right) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index fae339808c233..6aefa9f67556a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2 import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet @@ -123,11 +124,11 @@ case class Join( } } - private def selfJoinResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty + def selfJoinResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty - // Joins are only resolved if they don't introduce ambiguious expression ids. + // Joins are only resolved if they don't introduce ambiguous expression ids. override lazy val resolved: Boolean = { - childrenResolved && !expressions.exists(!_.resolved) && selfJoinResolved + childrenResolved && expressions.forall(_.resolved) && selfJoinResolved } } @@ -141,6 +142,10 @@ case class BroadcastHint(child: LogicalPlan) extends UnaryNode { case class Except(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { override def output: Seq[Attribute] = left.output + + override lazy val resolved: Boolean = + childrenResolved && + left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } } case class InsertIntoTable( @@ -298,7 +303,7 @@ case class Expand( } trait GroupingAnalytics extends UnaryNode { - self: Product => + def groupByExprs: Seq[Expression] def aggregations: Seq[NamedExpression] @@ -437,4 +442,8 @@ case object OneRowRelation extends LeafNode { case class Intersect(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { override def output: Seq[Attribute] = left.output + + override lazy val resolved: Boolean = + childrenResolved && + left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala index 63df2c1ee72ff..1f76b03bcb0f6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala @@ -24,8 +24,6 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, SortOrd * result have expectations about the distribution and ordering of partitioned input data. 
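 * SortPartitions and RepartitionByExpression are illustrative examples of such nodes.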
*/ abstract class RedistributeData extends UnaryNode { - self: Product => - override def output: Seq[Attribute] = child.output } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 42dead7c28425..2dcfa19fec383 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -17,9 +17,7 @@ package org.apache.spark.sql.catalyst.plans.physical -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.sql.catalyst.expressions.{Expression, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Unevaluable, Expression, SortOrder} import org.apache.spark.sql.types.{DataType, IntegerType} /** @@ -146,8 +144,7 @@ case object BroadcastPartitioning extends Partitioning { * in the same partition. */ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) - extends Expression - with Partitioning { + extends Expression with Partitioning with Unevaluable { override def children: Seq[Expression] = expressions override def nullable: Boolean = false @@ -169,9 +166,6 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) } override def keyExpressions: Seq[Expression] = expressions - - override def eval(input: InternalRow = null): Any = - throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") } /** @@ -187,8 +181,7 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) * into its child. */ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) - extends Expression - with Partitioning { + extends Expression with Partitioning with Unevaluable { override def children: Seq[SortOrder] = ordering override def nullable: Boolean = false @@ -213,7 +206,4 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) } override def keyExpressions: Seq[Expression] = ordering.map(_.child) - - override def eval(input: InternalRow): Any = - throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 09f6c6b0ec423..122e9fc5ed77f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -54,8 +54,8 @@ object CurrentOrigin { } } -abstract class TreeNode[BaseType <: TreeNode[BaseType]] { - self: BaseType with Product => +abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { + self: BaseType => val origin: Origin = CurrentOrigin.get @@ -452,28 +452,3 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { s"$nodeName(${args.mkString(",")})" } } - -/** - * A [[TreeNode]] that has two children, [[left]] and [[right]]. - */ -trait BinaryNode[BaseType <: TreeNode[BaseType]] { - def left: BaseType - def right: BaseType - - def children: Seq[BaseType] = Seq(left, right) -} - -/** - * A [[TreeNode]] with no children. 
- */ -trait LeafNode[BaseType <: TreeNode[BaseType]] { - def children: Seq[BaseType] = Nil -} - -/** - * A [[TreeNode]] with a single [[child]]. - */ -trait UnaryNode[BaseType <: TreeNode[BaseType]] { - def child: BaseType - def children: Seq[BaseType] = child :: Nil -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index c1ddee3ef0230..07412e73b6a5b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -19,7 +19,9 @@ package org.apache.spark.sql.catalyst.util import java.sql.{Date, Timestamp} import java.text.{DateFormat, SimpleDateFormat} -import java.util.{Calendar, TimeZone} +import java.util.{TimeZone, Calendar} + +import org.apache.spark.unsafe.types.UTF8String /** * Helper functions for converting between internal and external date and time representations. @@ -29,14 +31,23 @@ import java.util.{Calendar, TimeZone} * precision. */ object DateTimeUtils { - final val MILLIS_PER_DAY = SECONDS_PER_DAY * 1000L - // see http://stackoverflow.com/questions/466321/convert-unix-timestamp-to-julian final val JULIAN_DAY_OF_EPOCH = 2440587 // and .5 final val SECONDS_PER_DAY = 60 * 60 * 24L final val MICROS_PER_SECOND = 1000L * 1000L final val NANOS_PER_SECOND = MICROS_PER_SECOND * 1000L + final val MILLIS_PER_DAY = SECONDS_PER_DAY * 1000L + + // number of days in 400 years + final val daysIn400Years: Int = 146097 + // number of days between 1.1.1970 and 1.1.2001 + final val to2001 = -11323 + + // this is year -17999, calculation: 50 * daysIn400Year + final val toYearZero = to2001 + 7304850 + + @transient lazy val defaultTimeZone = TimeZone.getDefault // Java TimeZone has no mention of thread safety. Use thread local instance to be safe. private val threadLocalLocalTimeZone = new ThreadLocal[TimeZone] { @@ -63,8 +74,8 @@ object DateTimeUtils { def millisToDays(millisUtc: Long): Int = { // SPARK-6785: use Math.floor so negative number of days (dates before 1970) // will correctly work as input for function toJavaDate(Int) - val millisLocal = millisUtc.toDouble + threadLocalLocalTimeZone.get().getOffset(millisUtc) - Math.floor(millisLocal / MILLIS_PER_DAY).toInt + val millisLocal = millisUtc + threadLocalLocalTimeZone.get().getOffset(millisUtc) + Math.floor(millisLocal.toDouble / MILLIS_PER_DAY).toInt } // reverse of millisToDays @@ -180,4 +191,386 @@ object DateTimeUtils { val nanos = (us % MICROS_PER_SECOND) * 1000L (day.toInt, secondsInDay * NANOS_PER_SECOND + nanos) } + + /** + * Parses a given UTF8 date string to the corresponding a corresponding [[Long]] value. + * The return type is [[Option]] in order to distinguish between 0L and null. 
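+   * For example (illustrative): "2015-07-22 10:00:00" parses to Some(microseconds since the
+   * epoch, interpreted in the default JVM time zone), while a malformed input such as
+   * "2015-07-22X" yields None.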
The following + * formats are allowed: + * + * `yyyy` + * `yyyy-[m]m` + * `yyyy-[m]m-[d]d` + * `yyyy-[m]m-[d]d ` + * `yyyy-[m]m-[d]d [h]h:[m]m:[s]s.[ms][ms][ms][us][us][us]` + * `yyyy-[m]m-[d]d [h]h:[m]m:[s]s.[ms][ms][ms][us][us][us]Z` + * `yyyy-[m]m-[d]d [h]h:[m]m:[s]s.[ms][ms][ms][us][us][us]-[h]h:[m]m` + * `yyyy-[m]m-[d]d [h]h:[m]m:[s]s.[ms][ms][ms][us][us][us]+[h]h:[m]m` + * `yyyy-[m]m-[d]dT[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us]` + * `yyyy-[m]m-[d]dT[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us]Z` + * `yyyy-[m]m-[d]dT[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us]-[h]h:[m]m` + * `yyyy-[m]m-[d]dT[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us]+[h]h:[m]m` + * `[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us]` + * `[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us]Z` + * `[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us]-[h]h:[m]m` + * `[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us]+[h]h:[m]m` + * `T[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us]` + * `T[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us]Z` + * `T[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us]-[h]h:[m]m` + * `T[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us]+[h]h:[m]m` + */ + def stringToTimestamp(s: UTF8String): Option[Long] = { + if (s == null) { + return None + } + var timeZone: Option[Byte] = None + val segments: Array[Int] = Array[Int](1, 1, 1, 0, 0, 0, 0, 0, 0) + var i = 0 + var currentSegmentValue = 0 + val bytes = s.getBytes + var j = 0 + var digitsMilli = 0 + var justTime = false + while (j < bytes.length) { + val b = bytes(j) + val parsedValue = b - '0'.toByte + if (parsedValue < 0 || parsedValue > 9) { + if (j == 0 && b == 'T') { + justTime = true + i += 3 + } else if (i < 2) { + if (b == '-') { + segments(i) = currentSegmentValue + currentSegmentValue = 0 + i += 1 + } else if (i == 0 && b == ':') { + justTime = true + segments(3) = currentSegmentValue + currentSegmentValue = 0 + i = 4 + } else { + return None + } + } else if (i == 2) { + if (b == ' ' || b == 'T') { + segments(i) = currentSegmentValue + currentSegmentValue = 0 + i += 1 + } else { + return None + } + } else if (i == 3 || i == 4) { + if (b == ':') { + segments(i) = currentSegmentValue + currentSegmentValue = 0 + i += 1 + } else { + return None + } + } else if (i == 5 || i == 6) { + if (b == 'Z') { + segments(i) = currentSegmentValue + currentSegmentValue = 0 + i += 1 + timeZone = Some(43) + } else if (b == '-' || b == '+') { + segments(i) = currentSegmentValue + currentSegmentValue = 0 + i += 1 + timeZone = Some(b) + } else if (b == '.' 
&& i == 5) { + segments(i) = currentSegmentValue + currentSegmentValue = 0 + i += 1 + } else { + return None + } + if (i == 6 && b != '.') { + i += 1 + } + } else { + if (b == ':' || b == ' ') { + segments(i) = currentSegmentValue + currentSegmentValue = 0 + i += 1 + } else { + return None + } + } + } else { + if (i == 6) { + digitsMilli += 1 + } + currentSegmentValue = currentSegmentValue * 10 + parsedValue + } + j += 1 + } + + segments(i) = currentSegmentValue + + while (digitsMilli < 6) { + segments(6) *= 10 + digitsMilli += 1 + } + + if (!justTime && (segments(0) < 1000 || segments(0) > 9999 || segments(1) < 1 || + segments(1) > 12 || segments(2) < 1 || segments(2) > 31)) { + return None + } + + if (segments(3) < 0 || segments(3) > 23 || segments(4) < 0 || segments(4) > 59 || + segments(5) < 0 || segments(5) > 59 || segments(6) < 0 || segments(6) > 999999 || + segments(7) < 0 || segments(7) > 23 || segments(8) < 0 || segments(8) > 59) { + return None + } + + val c = if (timeZone.isEmpty) { + Calendar.getInstance() + } else { + Calendar.getInstance( + TimeZone.getTimeZone(f"GMT${timeZone.get.toChar}${segments(7)}%02d:${segments(8)}%02d")) + } + c.set(Calendar.MILLISECOND, 0) + + if (justTime) { + c.set(Calendar.HOUR_OF_DAY, segments(3)) + c.set(Calendar.MINUTE, segments(4)) + c.set(Calendar.SECOND, segments(5)) + } else { + c.set(segments(0), segments(1) - 1, segments(2), segments(3), segments(4), segments(5)) + } + + Some(c.getTimeInMillis * 1000 + segments(6)) + } + + /** + * Parses a given UTF8 date string to the corresponding a corresponding [[Int]] value. + * The return type is [[Option]] in order to distinguish between 0 and null. The following + * formats are allowed: + * + * `yyyy`, + * `yyyy-[m]m` + * `yyyy-[m]m-[d]d` + * `yyyy-[m]m-[d]d ` + * `yyyy-[m]m-[d]d *` + * `yyyy-[m]m-[d]dT*` + */ + def stringToDate(s: UTF8String): Option[Int] = { + if (s == null) { + return None + } + val segments: Array[Int] = Array[Int](1, 1, 1) + var i = 0 + var currentSegmentValue = 0 + val bytes = s.getBytes + var j = 0 + while (j < bytes.length && (i < 3 && !(bytes(j) == ' ' || bytes(j) == 'T'))) { + val b = bytes(j) + if (i < 2 && b == '-') { + segments(i) = currentSegmentValue + currentSegmentValue = 0 + i += 1 + } else { + val parsedValue = b - '0'.toByte + if (parsedValue < 0 || parsedValue > 9) { + return None + } else { + currentSegmentValue = currentSegmentValue * 10 + parsedValue + } + } + j += 1 + } + segments(i) = currentSegmentValue + if (segments(0) < 1000 || segments(0) > 9999 || segments(1) < 1 || segments(1) > 12 || + segments(2) < 1 || segments(2) > 31) { + return None + } + val c = Calendar.getInstance(TimeZone.getTimeZone("GMT")) + c.set(segments(0), segments(1) - 1, segments(2), 0, 0, 0) + c.set(Calendar.MILLISECOND, 0) + Some((c.getTimeInMillis / MILLIS_PER_DAY).toInt) + } + + /** + * Returns the hour value of a given timestamp value. The timestamp is expressed in microseconds. + */ + def getHours(timestamp: Long): Int = { + val localTs = (timestamp / 1000) + defaultTimeZone.getOffset(timestamp / 1000) + ((localTs / 1000 / 3600) % 24).toInt + } + + /** + * Returns the minute value of a given timestamp value. The timestamp is expressed in + * microseconds. + */ + def getMinutes(timestamp: Long): Int = { + val localTs = (timestamp / 1000) + defaultTimeZone.getOffset(timestamp / 1000) + ((localTs / 1000 / 60) % 60).toInt + } + + /** + * Returns the second value of a given timestamp value. The timestamp is expressed in + * microseconds. 
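The two parsers above return raw internal values (days since 1970-01-01 for dates, microseconds since the epoch for timestamps) wrapped in Option so that malformed input is reported as None rather than by throwing. A minimal usage sketch, not part of the patch, assuming the patched DateTimeUtils and UTF8String are on the classpath:

import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.unsafe.types.UTF8String

object DateTimeParsingSketch {
  def main(args: Array[String]): Unit = {
    // Dates come back as days since 1970-01-01.
    println(DateTimeUtils.stringToDate(UTF8String.fromString("2015-03-18")))
    // Some(16512)

    // Timestamps come back as microseconds since the epoch; the explicit Z pins UTC.
    println(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18T12:03:17.123Z")))
    // Some(1426680197123000)

    // Malformed input yields None instead of an exception.
    println(DateTimeUtils.stringToDate(UTF8String.fromString("2015/03/18")))      // None
    println(DateTimeUtils.stringToTimestamp(UTF8String.fromString("20150318")))   // None
  }
}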
+ */ + def getSeconds(timestamp: Long): Int = { + ((timestamp / 1000 / 1000) % 60).toInt + } + + private[this] def isLeapYear(year: Int): Boolean = { + (year % 4) == 0 && ((year % 100) != 0 || (year % 400) == 0) + } + + /** + * Returns the number of days since the start of the 400-year period. + * The second year of a 400 year period (year 1) starts on day 365. + */ + private[this] def yearBoundary(year: Int): Int = { + year * 365 + ((year / 4 ) - (year / 100) + (year / 400)) + } + + /** + * Calculates the number of years for the given number of days. This depends + * on a 400-year period. + * @param days days since the beginning of the 400-year period + * @return (number of years, day in year) + */ + private[this] def numYears(days: Int): (Int, Int) = { + val year = days / 365 + val boundary = yearBoundary(year) + if (days > boundary) (year, days - boundary) else (year - 1, days - yearBoundary(year - 1)) + } + + /** + * Calculates the year and the number of the day in the year for the given + * number of days. The given days is the number of days since 1.1.1970. + * + * The calculation uses the fact that the period 1.1.2001 until 31.12.2400 is + * equal to the period 1.1.1601 until 31.12.2000. + */ + private[this] def getYearAndDayInYear(daysSince1970: Int): (Int, Int) = { + // add the difference (in days) between 1.1.1970 and the artificial year 0 (-17999) + val daysNormalized = daysSince1970 + toYearZero + val numOfQuarterCenturies = daysNormalized / daysIn400Years + val daysInThis400 = daysNormalized % daysIn400Years + 1 + val (years, dayInYear) = numYears(daysInThis400) + val year: Int = (2001 - 20000) + 400 * numOfQuarterCenturies + years + (year, dayInYear) + } + + /** + * Returns the 'day in year' value for the given date. The date is expressed in days + * since 1.1.1970. + */ + def getDayInYear(date: Int): Int = { + getYearAndDayInYear(date)._2 + } + + /** + * Returns the year value for the given date. The date is expressed in days + * since 1.1.1970. + */ + def getYear(date: Int): Int = { + getYearAndDayInYear(date)._1 + } + + /** + * Returns the quarter for the given date. The date is expressed in days + * since 1.1.1970. + */ + def getQuarter(date: Int): Int = { + var (year, dayInYear) = getYearAndDayInYear(date) + if (isLeapYear(year)) { + dayInYear = dayInYear - 1 + } + if (dayInYear <= 90) { + 1 + } else if (dayInYear <= 181) { + 2 + } else if (dayInYear <= 273) { + 3 + } else { + 4 + } + } + + /** + * Returns the month value for the given date. The date is expressed in days + * since 1.1.1970. January is month 1. + */ + def getMonth(date: Int): Int = { + var (year, dayInYear) = getYearAndDayInYear(date) + if (isLeapYear(year)) { + if (dayInYear == 60) { + return 2 + } else if (dayInYear > 60) { + dayInYear = dayInYear - 1 + } + } + + if (dayInYear <= 31) { + 1 + } else if (dayInYear <= 59) { + 2 + } else if (dayInYear <= 90) { + 3 + } else if (dayInYear <= 120) { + 4 + } else if (dayInYear <= 151) { + 5 + } else if (dayInYear <= 181) { + 6 + } else if (dayInYear <= 212) { + 7 + } else if (dayInYear <= 243) { + 8 + } else if (dayInYear <= 273) { + 9 + } else if (dayInYear <= 304) { + 10 + } else if (dayInYear <= 334) { + 11 + } else { + 12 + } + } + + /** + * Returns the 'day of month' value for the given date. The date is expressed in days + * since 1.1.1970.
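The day-based accessors above (getYear, getMonth, getQuarter, and getDayOfMonth just below) reduce a days-since-epoch integer through the 400-year cycle instead of going through java.util.Calendar. A small cross-check sketch, not part of the patch, comparing them against Calendar in UTC for a few day offsets; it assumes the patched DateTimeUtils is on the classpath:

import java.util.{Calendar, TimeZone}
import org.apache.spark.sql.catalyst.util.DateTimeUtils

object DayMathSketch {
  def main(args: Array[String]): Unit = {
    val utc = TimeZone.getTimeZone("UTC")
    // 0 = 1970-01-01, 11016 = 2000-02-29 (leap day), 11017 = 2000-03-01, 16512 = 2015-03-18
    for (days <- Seq(0, 11016, 11017, 16512)) {
      val c = Calendar.getInstance(utc)
      c.setTimeInMillis(days * DateTimeUtils.MILLIS_PER_DAY)
      assert(DateTimeUtils.getYear(days) == c.get(Calendar.YEAR))
      assert(DateTimeUtils.getMonth(days) == c.get(Calendar.MONTH) + 1)
      assert(DateTimeUtils.getDayOfMonth(days) == c.get(Calendar.DAY_OF_MONTH))
    }
  }
}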
+ */ + def getDayOfMonth(date: Int): Int = { + var (year, dayInYear) = getYearAndDayInYear(date) + if (isLeapYear(year)) { + if (dayInYear == 60) { + return 29 + } else if (dayInYear > 60) { + dayInYear = dayInYear - 1 + } + } + + if (dayInYear <= 31) { + dayInYear + } else if (dayInYear <= 59) { + dayInYear - 31 + } else if (dayInYear <= 90) { + dayInYear - 59 + } else if (dayInYear <= 120) { + dayInYear - 90 + } else if (dayInYear <= 151) { + dayInYear - 120 + } else if (dayInYear <= 181) { + dayInYear - 151 + } else if (dayInYear <= 212) { + dayInYear - 181 + } else if (dayInYear <= 243) { + dayInYear - 212 + } else if (dayInYear <= 273) { + dayInYear - 243 + } else if (dayInYear <= 304) { + dayInYear - 273 + } else if (dayInYear <= 334) { + dayInYear - 304 + } else { + dayInYear - 334 + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala new file mode 100644 index 0000000000000..9fefc5656aac0 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import org.apache.spark.unsafe.types.UTF8String + +object NumberConverter { + + private val value = new Array[Byte](64) + + /** + * Divide x by m as if x is an unsigned 64-bit integer. Examples: + * unsignedLongDiv(-1, 2) == Long.MAX_VALUE unsignedLongDiv(6, 3) == 2 + * unsignedLongDiv(0, 5) == 0 + * + * @param x is treated as unsigned + * @param m is treated as signed + */ + private def unsignedLongDiv(x: Long, m: Int): Long = { + if (x >= 0) { + x / m + } else { + // Let uval be the value of the unsigned long with the same bits as x + // Two's complement => x = uval - 2*MAX - 2 + // => uval = x + 2*MAX + 2 + // Now, use the fact: (a+b)/c = a/c + b/c + (a%c+b%c)/c + x / m + 2 * (Long.MaxValue / m) + 2 / m + (x % m + 2 * (Long.MaxValue % m) + 2 % m) / m + } + } + + /** + * Decode v into value[]. + * + * @param v is treated as an unsigned 64-bit integer + * @param radix must be between MIN_RADIX and MAX_RADIX + */ + private def decode(v: Long, radix: Int): Unit = { + var tmpV = v + java.util.Arrays.fill(value, 0.asInstanceOf[Byte]) + var i = value.length - 1 + while (tmpV != 0) { + val q = unsignedLongDiv(tmpV, radix) + value(i) = (tmpV - q * radix).asInstanceOf[Byte] + tmpV = q + i -= 1 + } + } + + /** + * Convert value[] into a long. On overflow, return -1 (as mySQL does). If a + * negative digit is found, ignore the suffix starting there. 
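The unsignedLongDiv helper above relies on the two's-complement identity uval = x + 2 * Long.MaxValue + 2 and on splitting (a + b + c) / m into per-term quotients plus the quotient of the summed remainders, so every intermediate value stays within Long range. A spot-check sketch, not part of the patch, comparing that decomposition with the JDK's built-in Long.divideUnsigned (JDK 8+) for a few verified value/radix pairs; the helper itself is private, so the formula is simply restated here:

object UnsignedDivSketch {
  // Same decomposition as the private unsignedLongDiv in the patch.
  def unsignedDiv(x: Long, m: Int): Long =
    if (x >= 0) {
      x / m
    } else {
      x / m + 2 * (Long.MaxValue / m) + 2 / m +
        (x % m + 2 * (Long.MaxValue % m) + 2 % m) / m
    }

  def main(args: Array[String]): Unit = {
    for (x <- Seq(-1L, Long.MinValue, 6L, 0L); m <- Seq(2, 3, 5, 16)) {
      assert(unsignedDiv(x, m) == java.lang.Long.divideUnsigned(x, m))
    }
    // e.g. unsignedDiv(-1, 2) == Long.MaxValue, matching the comment in the patch.
  }
}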
+ * + * @param radix must be between MIN_RADIX and MAX_RADIX + * @param fromPos is the first element that should be conisdered + * @return the result should be treated as an unsigned 64-bit integer. + */ + private def encode(radix: Int, fromPos: Int): Long = { + var v: Long = 0L + val bound = unsignedLongDiv(-1 - radix, radix) // Possible overflow once + // val + // exceeds this value + var i = fromPos + while (i < value.length && value(i) >= 0) { + if (v >= bound) { + // Check for overflow + if (unsignedLongDiv(-1 - value(i), radix) < v) { + return -1 + } + } + v = v * radix + value(i) + i += 1 + } + v + } + + /** + * Convert the bytes in value[] to the corresponding chars. + * + * @param radix must be between MIN_RADIX and MAX_RADIX + * @param fromPos is the first nonzero element + */ + private def byte2char(radix: Int, fromPos: Int): Unit = { + var i = fromPos + while (i < value.length) { + value(i) = Character.toUpperCase(Character.forDigit(value(i), radix)).asInstanceOf[Byte] + i += 1 + } + } + + /** + * Convert the chars in value[] to the corresponding integers. Convert invalid + * characters to -1. + * + * @param radix must be between MIN_RADIX and MAX_RADIX + * @param fromPos is the first nonzero element + */ + private def char2byte(radix: Int, fromPos: Int): Unit = { + var i = fromPos + while ( i < value.length) { + value(i) = Character.digit(value(i), radix).asInstanceOf[Byte] + i += 1 + } + } + + /** + * Convert numbers between different number bases. If toBase>0 the result is + * unsigned, otherwise it is signed. + * NB: This logic is borrowed from org.apache.hadoop.hive.ql.ud.UDFConv + */ + def convert(n: Array[Byte] , fromBase: Int, toBase: Int ): UTF8String = { + if (fromBase < Character.MIN_RADIX || fromBase > Character.MAX_RADIX + || Math.abs(toBase) < Character.MIN_RADIX + || Math.abs(toBase) > Character.MAX_RADIX) { + return null + } + + if (n.length == 0) { + return null + } + + var (negative, first) = if (n(0) == '-') (true, 1) else (false, 0) + + // Copy the digits in the right side of the array + var i = 1 + while (i <= n.length - first) { + value(value.length - i) = n(n.length - i) + i += 1 + } + char2byte(fromBase, value.length - n.length + first) + + // Do the conversion by going through a 64 bit integer + var v = encode(fromBase, value.length - n.length + first) + if (negative && toBase > 0) { + if (v < 0) { + v = -1 + } else { + v = -v + } + } + if (toBase < 0 && v < 0) { + v = -v + negative = true + } + decode(v, Math.abs(toBase)) + + // Find the first non-zero digit or the last digits if all are zero. 
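Taken together, convert reproduces Hive's conv() semantics: the input digits are copied to the tail of a shared 64-byte buffer, folded into an unsigned 64-bit value in the source base, and re-expanded in the target base. A usage sketch, not part of the patch, assuming the patched NumberConverter and UTF8String are on the classpath:

import org.apache.spark.sql.catalyst.util.NumberConverter
import org.apache.spark.unsafe.types.UTF8String

object ConvSketch {
  def main(args: Array[String]): Unit = {
    // Convert the string form of a number from one base to another.
    def conv(s: String, from: Int, to: Int): String =
      NumberConverter.convert(UTF8String.fromString(s).getBytes, from, to).toString

    println(conv("100", 2, 10))   // 4
    println(conv("255", 10, 16))  // FF
    println(conv("ff", 16, 2))    // 11111111
    println(conv("xyz", 36, 10))  // 44027  (x=33, y=34, z=35)
  }
}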
+ val firstNonZeroPos = { + val firstNonZero = value.indexWhere( _ != 0) + if (firstNonZero != -1) firstNonZero else value.length - 1 + } + + byte2char(Math.abs(toBase), firstNonZeroPos) + + var resultStartPos = firstNonZeroPos + if (negative && toBase < 0) { + resultStartPos = firstNonZeroPos - 1 + value(resultStartPos) = '-' + } + UTF8String.fromBytes(java.util.Arrays.copyOfRange(value, resultStartPos, value.length)) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index 3148309a2166f..0103ddcf9cfb7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -32,14 +32,6 @@ object TypeUtils { } } - def checkForBitwiseExpr(t: DataType, caller: String): TypeCheckResult = { - if (t.isInstanceOf[IntegralType] || t == NullType) { - TypeCheckResult.TypeCheckSuccess - } else { - TypeCheckResult.TypeCheckFailure(s"$caller accepts integral types, not $t") - } - } - def checkForOrderingExpr(t: DataType, caller: String): TypeCheckResult = { if (t.isInstanceOf[AtomicType] || t == NullType) { TypeCheckResult.TypeCheckSuccess diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala index 32f87440b4e37..40bf4b299c990 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala @@ -34,32 +34,18 @@ private[sql] abstract class AbstractDataType { private[sql] def defaultConcreteType: DataType /** - * Returns true if this data type is the same type as `other`. This is different that equality - * as equality will also consider data type parametrization, such as decimal precision. + * Returns true if `other` is an acceptable input type for a function that expects this, + * possibly abstract DataType. * * {{{ * // this should return true - * DecimalType.isSameType(DecimalType(10, 2)) - * - * // this should return false - * NumericType.isSameType(DecimalType(10, 2)) - * }}} - */ - private[sql] def isSameType(other: DataType): Boolean - - /** - * Returns true if `other` is an acceptable input type for a function that expectes this, - * possibly abstract, DataType. - * - * {{{ - * // this should return true - * DecimalType.isSameType(DecimalType(10, 2)) + * DecimalType.acceptsType(DecimalType(10, 2)) * * // this should return true as well * NumericType.acceptsType(DecimalType(10, 2)) * }}} */ - private[sql] def acceptsType(other: DataType): Boolean = isSameType(other) + private[sql] def acceptsType(other: DataType): Boolean /** Readable string representation for the type. 
*/ private[sql] def simpleString: String @@ -83,10 +69,8 @@ private[sql] class TypeCollection(private val types: Seq[AbstractDataType]) override private[sql] def defaultConcreteType: DataType = types.head.defaultConcreteType - override private[sql] def isSameType(other: DataType): Boolean = false - override private[sql] def acceptsType(other: DataType): Boolean = - types.exists(_.isSameType(other)) + types.exists(_.acceptsType(other)) override private[sql] def simpleString: String = { types.map(_.simpleString).mkString("(", " or ", ")") @@ -96,6 +80,23 @@ private[sql] class TypeCollection(private val types: Seq[AbstractDataType]) private[sql] object TypeCollection { + /** + * Types that can be ordered/compared. In the long run we should probably make this a trait + * that can be mixed into each data type, and perhaps create an [[AbstractDataType]]. + */ + val Ordered = TypeCollection( + BooleanType, + ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DecimalType, + TimestampType, DateType, + StringType, BinaryType) + + /** + * Types that include numeric types and interval type. They are only used in unary_minus, + * unary_positive, add and subtract operations. + */ + val NumericAndInterval = TypeCollection(NumericType, IntervalType) + def apply(types: AbstractDataType*): TypeCollection = new TypeCollection(types) def unapply(typ: AbstractDataType): Option[Seq[AbstractDataType]] = typ match { @@ -105,6 +106,21 @@ private[sql] object TypeCollection { } +/** + * An [[AbstractDataType]] that matches any concrete data types. + */ +protected[sql] object AnyDataType extends AbstractDataType { + + // Note that since AnyDataType matches any concrete types, defaultConcreteType should never + // be invoked. + override private[sql] def defaultConcreteType: DataType = throw new UnsupportedOperationException + + override private[sql] def simpleString: String = "any" + + override private[sql] def acceptsType(other: DataType): Boolean = true +} + + /** * An internal type used to represent everything that is not null, UDTs, arrays, structs, and maps. 
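The consolidation above leaves a single one-directional contract: acceptsType answers whether a concrete input type is acceptable where this (possibly abstract) type is expected, and a TypeCollection accepts whatever any of its members accepts. Because acceptsType is private[sql], the sketch below is a simplified stand-alone model of that contract with hypothetical names, not the Spark API:

object AcceptsTypeModel {
  sealed trait AbstractType { def accepts(other: ConcreteType): Boolean }

  // Concrete types accept only themselves (up to equality).
  sealed trait ConcreteType extends AbstractType {
    def accepts(other: ConcreteType): Boolean = this == other
  }
  case object IntT extends ConcreteType
  case object DoubleT extends ConcreteType
  case object StringT extends ConcreteType

  // An abstract type accepts every concrete type it covers.
  case object NumericT extends AbstractType {
    def accepts(other: ConcreteType): Boolean = other == IntT || other == DoubleT
  }

  // A collection accepts anything any of its members accepts.
  final case class Collection(types: Seq[AbstractType]) extends AbstractType {
    def accepts(other: ConcreteType): Boolean = types.exists(_.accepts(other))
  }

  def main(args: Array[String]): Unit = {
    assert(NumericT.accepts(IntT))
    assert(!NumericT.accepts(StringT))
    assert(Collection(Seq(NumericT, StringT)).accepts(StringT))
  }
}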
*/ @@ -148,13 +164,11 @@ private[sql] object NumericType extends AbstractDataType { override private[sql] def simpleString: String = "numeric" - override private[sql] def isSameType(other: DataType): Boolean = false - override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[NumericType] } -private[sql] object IntegralType { +private[sql] object IntegralType extends AbstractDataType { /** * Enables matching against IntegralType for expressions: * {{{ @@ -163,6 +177,12 @@ private[sql] object IntegralType { * }}} */ def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[IntegralType] + + override private[sql] def defaultConcreteType: DataType = IntegerType + + override private[sql] def simpleString: String = "integral" + + override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[IntegralType] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala index 76ca7a84c1d1a..5094058164b2f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala @@ -28,7 +28,7 @@ object ArrayType extends AbstractDataType { override private[sql] def defaultConcreteType: DataType = ArrayType(NullType, containsNull = true) - override private[sql] def isSameType(other: DataType): Boolean = { + override private[sql] def acceptsType(other: DataType): Boolean = { other.isInstanceOf[ArrayType] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 57718228e490f..e98fd2583b931 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -27,6 +27,7 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.util.Utils /** @@ -78,7 +79,7 @@ abstract class DataType extends AbstractDataType { override private[sql] def defaultConcreteType: DataType = this - override private[sql] def isSameType(other: DataType): Boolean = this == other + override private[sql] def acceptsType(other: DataType): Boolean = sameType(other) } @@ -146,7 +147,7 @@ object DataType { ("pyClass", _), ("sqlType", _), ("type", JString("udt"))) => - Class.forName(udtClass).newInstance().asInstanceOf[UserDefinedType[_]] + Utils.classForName(udtClass).newInstance().asInstanceOf[UserDefinedType[_]] } private def parseStructField(json: JValue): StructField = json match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 5a169488c97eb..bc689810bc292 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.types -import java.math.{MathContext, RoundingMode} - import org.apache.spark.annotation.DeveloperApi /** @@ -139,9 +137,9 @@ final class Decimal extends Ordered[Decimal] with Serializable { def toBigDecimal: BigDecimal = { if (decimalVal.ne(null)) { - decimalVal(MathContext.UNLIMITED) + decimalVal } else { - BigDecimal(longVal, _scale)(MathContext.UNLIMITED) + BigDecimal(longVal, _scale) } } @@ -265,15 +263,8 @@ final class 
Decimal extends Ordered[Decimal] with Serializable { def * (that: Decimal): Decimal = Decimal(toBigDecimal * that.toBigDecimal) - def / (that: Decimal): Decimal = { - if (that.isZero) { - null - } else { - // To avoid non-terminating decimal expansion problem, we turn to Java BigDecimal's divide - // with specified ROUNDING_MODE. - Decimal(toJavaBigDecimal.divide(that.toJavaBigDecimal, ROUNDING_MODE.id)) - } - } + def / (that: Decimal): Decimal = + if (that.isZero) null else Decimal(toBigDecimal / that.toBigDecimal) def % (that: Decimal): Decimal = if (that.isZero) null else Decimal(toBigDecimal % that.toBigDecimal) @@ -287,6 +278,8 @@ final class Decimal extends Ordered[Decimal] with Serializable { Decimal(-longVal, precision, scale) } } + + def abs: Decimal = if (this.compare(Decimal(0)) < 0) this.unary_- else this } object Decimal { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index a1cafeab1704d..377c75f6e85a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -86,7 +86,7 @@ object DecimalType extends AbstractDataType { override private[sql] def defaultConcreteType: DataType = Unlimited - override private[sql] def isSameType(other: DataType): Boolean = { + override private[sql] def acceptsType(other: DataType): Boolean = { other.isInstanceOf[DecimalType] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala index 986c2ab055386..2a1bf0938e5a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala @@ -23,6 +23,7 @@ import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.ScalaReflectionLock +import org.apache.spark.util.Utils /** * :: DeveloperApi :: @@ -37,7 +38,9 @@ class DoubleType private() extends FractionalType { @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } private[sql] val numeric = implicitly[Numeric[Double]] private[sql] val fractional = implicitly[Fractional[Double]] - private[sql] val ordering = implicitly[Ordering[InternalType]] + private[sql] val ordering = new Ordering[Double] { + override def compare(x: Double, y: Double): Int = Utils.nanSafeCompareDoubles(x, y) + } private[sql] val asIntegral = DoubleAsIfIntegral /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala index 9bd48ece83a1c..08e22252aef82 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala @@ -23,6 +23,7 @@ import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.ScalaReflectionLock +import org.apache.spark.util.Utils /** * :: DeveloperApi :: @@ -37,7 +38,9 @@ class FloatType private() extends FractionalType { @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } private[sql] val numeric = implicitly[Numeric[Float]] private[sql] val fractional = implicitly[Fractional[Float]] - private[sql] val ordering = 
implicitly[Ordering[InternalType]] + private[sql] val ordering = new Ordering[Float] { + override def compare(x: Float, y: Float): Int = Utils.nanSafeCompareFloats(x, y) + } private[sql] val asIntegral = FloatAsIfIntegral /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala index ddead10bc2171..ac34b642827ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala @@ -71,7 +71,7 @@ object MapType extends AbstractDataType { override private[sql] def defaultConcreteType: DataType = apply(NullType, NullType) - override private[sql] def isSameType(other: DataType): Boolean = { + override private[sql] def acceptsType(other: DataType): Boolean = { other.isInstanceOf[MapType] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index b8097403ec3cc..2ef97a427c37e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -307,7 +307,7 @@ object StructType extends AbstractDataType { override private[sql] def defaultConcreteType: DataType = new StructType - override private[sql] def isSameType(other: DataType): Boolean = { + override private[sql] def acceptsType(other: DataType): Boolean = { other.isInstanceOf[StructType] } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala index bbb9739e9cc76..878a1bb9b7e6d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericRow, GenericRowWithSchema} import org.apache.spark.sql.types._ import org.scalatest.{Matchers, FunSpec} @@ -68,4 +69,29 @@ class RowTest extends FunSpec with Matchers { sampleRow.getValuesMap(List("col1", "col2")) shouldBe expected } } + + describe("row equals") { + val externalRow = Row(1, 2) + val externalRow2 = Row(1, 2) + val internalRow = InternalRow(1, 2) + val internalRow2 = InternalRow(1, 2) + + it("equality check for external rows") { + externalRow shouldEqual externalRow2 + } + + it("equality check for internal rows") { + internalRow shouldEqual internalRow2 + } + + it("throws an exception when check equality between external and internal rows") { + def assertError(f: => Unit): Unit = { + val e = intercept[UnsupportedOperationException](f) + e.getMessage.contains("cannot check equality between external and internal rows") + } + + assertError(internalRow.equals(externalRow)) + assertError(externalRow.equals(internalRow)) + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 9d0c69a2451d1..dca8c881f21ab 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -23,16 +23,17 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException import 
org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.types._ -import org.apache.spark.sql.catalyst.{InternalRow, SimpleCatalystConf} +import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.types._ case class TestFunction( children: Seq[Expression], - inputTypes: Seq[AbstractDataType]) extends Expression with ExpectsInputTypes { + inputTypes: Seq[AbstractDataType]) + extends Expression with ImplicitCastInputTypes with Unevaluable { override def nullable: Boolean = true - override def eval(input: InternalRow): Any = throw new UnsupportedOperationException override def dataType: DataType = StringType } @@ -164,4 +165,13 @@ class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter { assert(message.contains("resolved attribute(s) a#1 missing from a#2")) } + + test("error test for self-join") { + val join = Join(testRelation, testRelation, Inner, None) + val error = intercept[AnalysisException] { + SimpleAnalyzer.checkAnalysis(join) + } + error.message.contains("Failure when resolving conflicting references in Join") + error.message.contains("Conflicting attributes") + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 8e0551b23eea6..ad15136ee9a2f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.types.{TypeCollection, StringType} class ExpressionTypeCheckingSuite extends SparkFunSuite { @@ -49,13 +49,13 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { def assertErrorForDifferingTypes(expr: Expression): Unit = { assertError(expr, - s"differing types in ${expr.getClass.getSimpleName} (IntegerType and BooleanType).") + s"differing types in '${expr.prettyString}'") } test("check types for unary arithmetic") { - assertError(UnaryMinus('stringField), "operator - accepts numeric type") - assertError(Abs('stringField), "function abs accepts numeric type") - assertError(BitwiseNot('stringField), "operator ~ accepts integral type") + assertError(UnaryMinus('stringField), "type (numeric or interval)") + assertError(Abs('stringField), "expected to be of type numeric") + assertError(BitwiseNot('stringField), "expected to be of type integral") } test("check types for binary arithmetic") { @@ -78,18 +78,20 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertErrorForDifferingTypes(MaxOf('intField, 'booleanField)) assertErrorForDifferingTypes(MinOf('intField, 'booleanField)) - assertError(Add('booleanField, 'booleanField), "operator + accepts numeric type") - assertError(Subtract('booleanField, 'booleanField), "operator - accepts numeric type") - assertError(Multiply('booleanField, 'booleanField), "operator * accepts numeric type") - assertError(Divide('booleanField, 'booleanField), 
"operator / accepts numeric type") - assertError(Remainder('booleanField, 'booleanField), "operator % accepts numeric type") + assertError(Add('booleanField, 'booleanField), "accepts (numeric or interval) type") + assertError(Subtract('booleanField, 'booleanField), "accepts (numeric or interval) type") + assertError(Multiply('booleanField, 'booleanField), "accepts numeric type") + assertError(Divide('booleanField, 'booleanField), "accepts numeric type") + assertError(Remainder('booleanField, 'booleanField), "accepts numeric type") - assertError(BitwiseAnd('booleanField, 'booleanField), "operator & accepts integral type") - assertError(BitwiseOr('booleanField, 'booleanField), "operator | accepts integral type") - assertError(BitwiseXor('booleanField, 'booleanField), "operator ^ accepts integral type") + assertError(BitwiseAnd('booleanField, 'booleanField), "accepts integral type") + assertError(BitwiseOr('booleanField, 'booleanField), "accepts integral type") + assertError(BitwiseXor('booleanField, 'booleanField), "accepts integral type") - assertError(MaxOf('complexField, 'complexField), "function maxOf accepts non-complex type") - assertError(MinOf('complexField, 'complexField), "function minOf accepts non-complex type") + assertError(MaxOf('complexField, 'complexField), + s"accepts ${TypeCollection.Ordered.simpleString} type") + assertError(MinOf('complexField, 'complexField), + s"accepts ${TypeCollection.Ordered.simpleString} type") } test("check types for predicates") { @@ -105,25 +107,23 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertSuccess(EqualTo('intField, 'booleanField)) assertSuccess(EqualNullSafe('intField, 'booleanField)) - assertError(EqualTo('intField, 'complexField), "differing types") - assertError(EqualNullSafe('intField, 'complexField), "differing types") - + assertErrorForDifferingTypes(EqualTo('intField, 'complexField)) + assertErrorForDifferingTypes(EqualNullSafe('intField, 'complexField)) assertErrorForDifferingTypes(LessThan('intField, 'booleanField)) assertErrorForDifferingTypes(LessThanOrEqual('intField, 'booleanField)) assertErrorForDifferingTypes(GreaterThan('intField, 'booleanField)) assertErrorForDifferingTypes(GreaterThanOrEqual('intField, 'booleanField)) - assertError( - LessThan('complexField, 'complexField), "operator < accepts non-complex type") - assertError( - LessThanOrEqual('complexField, 'complexField), "operator <= accepts non-complex type") - assertError( - GreaterThan('complexField, 'complexField), "operator > accepts non-complex type") - assertError( - GreaterThanOrEqual('complexField, 'complexField), "operator >= accepts non-complex type") + assertError(LessThan('complexField, 'complexField), + s"accepts ${TypeCollection.Ordered.simpleString} type") + assertError(LessThanOrEqual('complexField, 'complexField), + s"accepts ${TypeCollection.Ordered.simpleString} type") + assertError(GreaterThan('complexField, 'complexField), + s"accepts ${TypeCollection.Ordered.simpleString} type") + assertError(GreaterThanOrEqual('complexField, 'complexField), + s"accepts ${TypeCollection.Ordered.simpleString} type") - assertError( - If('intField, 'stringField, 'stringField), + assertError(If('intField, 'stringField, 'stringField), "type of predicate expression in If should be boolean") assertErrorForDifferingTypes(If('booleanField, 'intField, 'booleanField)) @@ -171,4 +171,14 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { CreateNamedStruct(Seq('a.string.at(0), "a", "b", 2.0)), "Odd position only allow foldable and not-null 
StringType expressions") } + + test("check types for ROUND") { + assertSuccess(Round(Literal(null), Literal(null))) + assertSuccess(Round('intField, Literal(1))) + + assertError(Round('intField, 'intField), "Only foldable Expression is allowed") + assertError(Round('intField, 'booleanField), "expected to be of type int") + assertError(Round('intField, 'complexField), "expected to be of type int") + assertError(Round('booleanField, 'intField), "expected to be of type numeric") + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index acb9a433de903..835220c563f41 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LocalRelation, Project} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types._ @@ -37,7 +37,6 @@ class HiveTypeCoercionSuite extends PlanTest { shouldCast(NullType, IntegerType, IntegerType) shouldCast(NullType, DecimalType, DecimalType.Unlimited) - // TODO: write the entire implicit cast table out for test cases. shouldCast(ByteType, IntegerType, IntegerType) shouldCast(IntegerType, IntegerType, IntegerType) shouldCast(IntegerType, LongType, LongType) @@ -86,6 +85,16 @@ class HiveTypeCoercionSuite extends PlanTest { DecimalType.Unlimited, DecimalType(10, 2)).foreach { tpe => shouldCast(tpe, NumericType, tpe) } + + shouldCast( + ArrayType(StringType, false), + TypeCollection(ArrayType(StringType), StringType), + ArrayType(StringType, false)) + + shouldCast( + ArrayType(StringType, true), + TypeCollection(ArrayType(StringType), StringType), + ArrayType(StringType, true)) } test("ineligible implicit type cast") { @@ -194,6 +203,30 @@ class HiveTypeCoercionSuite extends PlanTest { Project(Seq(Alias(transformed, "a")()), testRelation)) } + test("cast NullType for expresions that implement ExpectsInputTypes") { + import HiveTypeCoercionSuite._ + + ruleTest(HiveTypeCoercion.ImplicitTypeCasts, + AnyTypeUnaryExpression(Literal.create(null, NullType)), + AnyTypeUnaryExpression(Literal.create(null, NullType))) + + ruleTest(HiveTypeCoercion.ImplicitTypeCasts, + NumericTypeUnaryExpression(Literal.create(null, NullType)), + NumericTypeUnaryExpression(Literal.create(null, DoubleType))) + } + + test("cast NullType for binary operators") { + import HiveTypeCoercionSuite._ + + ruleTest(HiveTypeCoercion.ImplicitTypeCasts, + AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)), + AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType))) + + ruleTest(HiveTypeCoercion.ImplicitTypeCasts, + NumericTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)), + NumericTypeBinaryOperator(Literal.create(null, DoubleType), Literal.create(null, DoubleType))) + } + test("coalesce casts") { ruleTest(HiveTypeCoercion.FunctionArgumentConversion, Coalesce(Literal(1.0) @@ -281,6 +314,93 @@ class HiveTypeCoercionSuite extends PlanTest { ) } + test("WidenTypes for union except and intersect") { + def 
checkOutput(logical: LogicalPlan, expectTypes: Seq[DataType]): Unit = { + logical.output.zip(expectTypes).foreach { case (attr, dt) => + assert(attr.dataType === dt) + } + } + + val left = LocalRelation( + AttributeReference("i", IntegerType)(), + AttributeReference("u", DecimalType.Unlimited)(), + AttributeReference("b", ByteType)(), + AttributeReference("d", DoubleType)()) + val right = LocalRelation( + AttributeReference("s", StringType)(), + AttributeReference("d", DecimalType(2, 1))(), + AttributeReference("f", FloatType)(), + AttributeReference("l", LongType)()) + + val wt = HiveTypeCoercion.WidenTypes + val expectedTypes = Seq(StringType, DecimalType.Unlimited, FloatType, DoubleType) + + val r1 = wt(Union(left, right)).asInstanceOf[Union] + val r2 = wt(Except(left, right)).asInstanceOf[Except] + val r3 = wt(Intersect(left, right)).asInstanceOf[Intersect] + checkOutput(r1.left, expectedTypes) + checkOutput(r1.right, expectedTypes) + checkOutput(r2.left, expectedTypes) + checkOutput(r2.right, expectedTypes) + checkOutput(r3.left, expectedTypes) + checkOutput(r3.right, expectedTypes) + } + + test("Transform Decimal precision/scale for union except and intersect") { + def checkOutput(logical: LogicalPlan, expectTypes: Seq[DataType]): Unit = { + logical.output.zip(expectTypes).foreach { case (attr, dt) => + assert(attr.dataType === dt) + } + } + + val dp = HiveTypeCoercion.DecimalPrecision + + val left1 = LocalRelation( + AttributeReference("l", DecimalType(10, 8))()) + val right1 = LocalRelation( + AttributeReference("r", DecimalType(5, 5))()) + val expectedType1 = Seq(DecimalType(math.max(8, 5) + math.max(10 - 8, 5 - 5), math.max(8, 5))) + + val r1 = dp(Union(left1, right1)).asInstanceOf[Union] + val r2 = dp(Except(left1, right1)).asInstanceOf[Except] + val r3 = dp(Intersect(left1, right1)).asInstanceOf[Intersect] + + checkOutput(r1.left, expectedType1) + checkOutput(r1.right, expectedType1) + checkOutput(r2.left, expectedType1) + checkOutput(r2.right, expectedType1) + checkOutput(r3.left, expectedType1) + checkOutput(r3.right, expectedType1) + + val plan1 = LocalRelation( + AttributeReference("l", DecimalType(10, 10))()) + + val rightTypes = Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType) + val expectedTypes = Seq(DecimalType(3, 0), DecimalType(5, 0), DecimalType(10, 0), + DecimalType(20, 0), DecimalType(7, 7), DecimalType(15, 15)) + + rightTypes.zip(expectedTypes).map { case (rType, expectedType) => + val plan2 = LocalRelation( + AttributeReference("r", rType)()) + + val r1 = dp(Union(plan1, plan2)).asInstanceOf[Union] + val r2 = dp(Except(plan1, plan2)).asInstanceOf[Except] + val r3 = dp(Intersect(plan1, plan2)).asInstanceOf[Intersect] + + checkOutput(r1.right, Seq(expectedType)) + checkOutput(r2.right, Seq(expectedType)) + checkOutput(r3.right, Seq(expectedType)) + + val r4 = dp(Union(plan2, plan1)).asInstanceOf[Union] + val r5 = dp(Except(plan2, plan1)).asInstanceOf[Except] + val r6 = dp(Intersect(plan2, plan1)).asInstanceOf[Intersect] + + checkOutput(r4.left, Seq(expectedType)) + checkOutput(r5.left, Seq(expectedType)) + checkOutput(r6.left, Seq(expectedType)) + } + } + /** * There are rules that need to not fire before child expressions get resolved. * We use this test to make sure those rules do not fire early. 
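The DecimalPrecision expectation in the test above follows a simple rule: keep the larger scale and the larger number of integer digits, so neither side loses digits when the two decimal types are widened. A worked sketch, not part of the patch, spelling out the arithmetic:

object DecimalWideningSketch {
  // (precision, scale) of the widened type of Decimal(p1, s1) and Decimal(p2, s2)
  def widen(p1: Int, s1: Int, p2: Int, s2: Int): (Int, Int) = {
    val scale = math.max(s1, s2)               // keep all fractional digits
    val intDigits = math.max(p1 - s1, p2 - s2) // keep all digits left of the point
    (intDigits + scale, scale)
  }

  def main(args: Array[String]): Unit = {
    // Decimal(10, 8) vs Decimal(5, 5): max(8, 5) = 8 fractional digits and
    // max(10 - 8, 5 - 5) = 2 integer digits, hence Decimal(10, 8) as asserted above.
    assert(widen(10, 8, 5, 5) == (10, 8))
  }
}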
@@ -302,3 +422,33 @@ class HiveTypeCoercionSuite extends PlanTest { ) } } + + +object HiveTypeCoercionSuite { + + case class AnyTypeUnaryExpression(child: Expression) + extends UnaryExpression with ExpectsInputTypes with Unevaluable { + override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) + override def dataType: DataType = NullType + } + + case class NumericTypeUnaryExpression(child: Expression) + extends UnaryExpression with ExpectsInputTypes with Unevaluable { + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) + override def dataType: DataType = NullType + } + + case class AnyTypeBinaryOperator(left: Expression, right: Expression) + extends BinaryOperator with Unevaluable { + override def dataType: DataType = NullType + override def inputType: AbstractDataType = AnyDataType + override def symbol: String = "anytype" + } + + case class NumericTypeBinaryOperator(left: Expression, right: Expression) + extends BinaryOperator with Unevaluable { + override def dataType: DataType = NullType + override def inputType: AbstractDataType = NumericType + override def symbol: String = "numerictype" + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index 6c93698f8017b..e7e5231d32c9e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -21,7 +21,6 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.types.Decimal - class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { /** @@ -158,4 +157,19 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(MinOf(Array(1.toByte, 2.toByte), Array(1.toByte, 3.toByte)), Array(1.toByte, 2.toByte)) } + + test("pmod") { + testNumericDataTypes { convert => + val left = Literal(convert(7)) + val right = Literal(convert(3)) + checkEvaluation(Pmod(left, right), convert(1)) + checkEvaluation(Pmod(Literal.create(null, left.dataType), right), null) + checkEvaluation(Pmod(left, Literal.create(null, right.dataType)), null) + checkEvaluation(Remainder(left, Literal(convert(0))), null) // mod by 0 + } + checkEvaluation(Pmod(-7, 3), 2) + checkEvaluation(Pmod(7.2D, 4.1D), 3.1000000000000005) + checkEvaluation(Pmod(Decimal(0.7), Decimal(0.2)), Decimal(0.1)) + checkEvaluation(Pmod(2L, Long.MaxValue), 2) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 1de161c367a1d..ccf448eee0688 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Timestamp, Date} +import java.util.{TimeZone, Calendar} import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow @@ -41,6 +42,137 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(v, Literal(expected).dataType), expected) } + test("cast string to date") { + var c = Calendar.getInstance() + c.set(2015, 0, 1, 0, 0, 
0) + c.set(Calendar.MILLISECOND, 0) + checkEvaluation(Cast(Literal("2015"), DateType), new Date(c.getTimeInMillis)) + c = Calendar.getInstance() + c.set(2015, 2, 1, 0, 0, 0) + c.set(Calendar.MILLISECOND, 0) + checkEvaluation(Cast(Literal("2015-03"), DateType), new Date(c.getTimeInMillis)) + c = Calendar.getInstance() + c.set(2015, 2, 18, 0, 0, 0) + c.set(Calendar.MILLISECOND, 0) + checkEvaluation(Cast(Literal("2015-03-18"), DateType), new Date(c.getTimeInMillis)) + checkEvaluation(Cast(Literal("2015-03-18 "), DateType), new Date(c.getTimeInMillis)) + checkEvaluation(Cast(Literal("2015-03-18 123142"), DateType), new Date(c.getTimeInMillis)) + checkEvaluation(Cast(Literal("2015-03-18T123123"), DateType), new Date(c.getTimeInMillis)) + checkEvaluation(Cast(Literal("2015-03-18T"), DateType), new Date(c.getTimeInMillis)) + + checkEvaluation(Cast(Literal("2015-03-18X"), DateType), null) + checkEvaluation(Cast(Literal("2015/03/18"), DateType), null) + checkEvaluation(Cast(Literal("2015.03.18"), DateType), null) + checkEvaluation(Cast(Literal("20150318"), DateType), null) + checkEvaluation(Cast(Literal("2015-031-8"), DateType), null) + } + + test("cast string to timestamp") { + checkEvaluation(Cast(Literal("123"), TimestampType), + null) + + var c = Calendar.getInstance() + c.set(2015, 0, 1, 0, 0, 0) + c.set(Calendar.MILLISECOND, 0) + checkEvaluation(Cast(Literal("2015"), TimestampType), + new Timestamp(c.getTimeInMillis)) + c = Calendar.getInstance() + c.set(2015, 2, 1, 0, 0, 0) + c.set(Calendar.MILLISECOND, 0) + checkEvaluation(Cast(Literal("2015-03"), TimestampType), + new Timestamp(c.getTimeInMillis)) + c = Calendar.getInstance() + c.set(2015, 2, 18, 0, 0, 0) + c.set(Calendar.MILLISECOND, 0) + checkEvaluation(Cast(Literal("2015-03-18"), TimestampType), + new Timestamp(c.getTimeInMillis)) + checkEvaluation(Cast(Literal("2015-03-18 "), TimestampType), + new Timestamp(c.getTimeInMillis)) + checkEvaluation(Cast(Literal("2015-03-18T"), TimestampType), + new Timestamp(c.getTimeInMillis)) + + c = Calendar.getInstance() + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 0) + checkEvaluation(Cast(Literal("2015-03-18 12:03:17"), TimestampType), + new Timestamp(c.getTimeInMillis)) + checkEvaluation(Cast(Literal("2015-03-18T12:03:17"), TimestampType), + new Timestamp(c.getTimeInMillis)) + + c = Calendar.getInstance(TimeZone.getTimeZone("UTC")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 0) + checkEvaluation(Cast(Literal("2015-03-18T12:03:17Z"), TimestampType), + new Timestamp(c.getTimeInMillis)) + checkEvaluation(Cast(Literal("2015-03-18 12:03:17Z"), TimestampType), + new Timestamp(c.getTimeInMillis)) + + c = Calendar.getInstance(TimeZone.getTimeZone("GMT-01:00")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 0) + checkEvaluation(Cast(Literal("2015-03-18T12:03:17-1:0"), TimestampType), + new Timestamp(c.getTimeInMillis)) + checkEvaluation(Cast(Literal("2015-03-18T12:03:17-01:00"), TimestampType), + new Timestamp(c.getTimeInMillis)) + + c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 0) + checkEvaluation(Cast(Literal("2015-03-18T12:03:17+07:30"), TimestampType), + new Timestamp(c.getTimeInMillis)) + + c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:03")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 0) + checkEvaluation(Cast(Literal("2015-03-18T12:03:17+7:3"), TimestampType), + new Timestamp(c.getTimeInMillis)) + + c = Calendar.getInstance() + c.set(2015, 2, 18, 12, 
3, 17) + c.set(Calendar.MILLISECOND, 123) + checkEvaluation(Cast(Literal("2015-03-18 12:03:17.123"), TimestampType), + new Timestamp(c.getTimeInMillis)) + checkEvaluation(Cast(Literal("2015-03-18T12:03:17.123"), TimestampType), + new Timestamp(c.getTimeInMillis)) + + c = Calendar.getInstance(TimeZone.getTimeZone("UTC")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 456) + checkEvaluation(Cast(Literal("2015-03-18T12:03:17.456Z"), TimestampType), + new Timestamp(c.getTimeInMillis)) + checkEvaluation(Cast(Literal("2015-03-18 12:03:17.456Z"), TimestampType), + new Timestamp(c.getTimeInMillis)) + + c = Calendar.getInstance(TimeZone.getTimeZone("GMT-01:00")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 123) + checkEvaluation(Cast(Literal("2015-03-18T12:03:17.123-1:0"), TimestampType), + new Timestamp(c.getTimeInMillis)) + checkEvaluation(Cast(Literal("2015-03-18T12:03:17.123-01:00"), TimestampType), + new Timestamp(c.getTimeInMillis)) + + c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 123) + checkEvaluation(Cast(Literal("2015-03-18T12:03:17.123+07:30"), TimestampType), + new Timestamp(c.getTimeInMillis)) + + c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:03")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 123) + checkEvaluation(Cast(Literal("2015-03-18T12:03:17.123+7:3"), TimestampType), + new Timestamp(c.getTimeInMillis)) + + checkEvaluation(Cast(Literal("2015-03-18 123142"), TimestampType), null) + checkEvaluation(Cast(Literal("2015-03-18T123123"), TimestampType), null) + checkEvaluation(Cast(Literal("2015-03-18X"), TimestampType), null) + checkEvaluation(Cast(Literal("2015/03/18"), TimestampType), null) + checkEvaluation(Cast(Literal("2015.03.18"), TimestampType), null) + checkEvaluation(Cast(Literal("20150318"), TimestampType), null) + checkEvaluation(Cast(Literal("2015-031-8"), TimestampType), null) + checkEvaluation(Cast(Literal("2015-03-18T12:03:17-0:70"), TimestampType), null) + } + test("cast from int") { checkCast(0, false) checkCast(1, true) @@ -149,6 +281,15 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { val nts = sts + ".1" val ts = Timestamp.valueOf(nts) + var c = Calendar.getInstance() + c.set(2015, 2, 8, 2, 30, 0) + checkEvaluation(cast(cast(new Timestamp(c.getTimeInMillis), StringType), TimestampType), + c.getTimeInMillis * 1000) + c = Calendar.getInstance() + c.set(2015, 10, 1, 2, 30, 0) + checkEvaluation(cast(cast(new Timestamp(c.getTimeInMillis), StringType), TimestampType), + c.getTimeInMillis * 1000) + checkEvaluation(cast("abdef", StringType), "abdef") checkEvaluation(cast("abdef", DecimalType.Unlimited), null) checkEvaluation(cast("abdef", TimestampType), null) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 481b335d15dfd..f4fbc49677ca3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -17,14 +17,19 @@ package org.apache.spark.sql.catalyst.expressions +import scala.math._ + import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.RandomDataGenerator +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.dsl.expressions._ import 
org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.types.{DataTypeTestUtils, NullType, StructField, StructType} /** * Additional tests for code generation. */ -class CodeGenerationSuite extends SparkFunSuite { +class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { test("multithreaded eval") { import scala.concurrent._ @@ -42,4 +47,50 @@ class CodeGenerationSuite extends SparkFunSuite { futures.foreach(Await.result(_, 10.seconds)) } + + // Test GenerateOrdering for all common types. For each type, we construct random input rows that + // contain two columns of that type, then for pairs of randomly-generated rows we check that + // GenerateOrdering agrees with RowOrdering. + (DataTypeTestUtils.atomicTypes ++ Set(NullType)).foreach { dataType => + test(s"GenerateOrdering with $dataType") { + val rowOrdering = RowOrdering.forSchema(Seq(dataType, dataType)) + val genOrdering = GenerateOrdering.generate( + BoundReference(0, dataType, nullable = true).asc :: + BoundReference(1, dataType, nullable = true).asc :: Nil) + val rowType = StructType( + StructField("a", dataType, nullable = true) :: + StructField("b", dataType, nullable = true) :: Nil) + val maybeDataGenerator = RandomDataGenerator.forType(rowType, nullable = false) + assume(maybeDataGenerator.isDefined) + val randGenerator = maybeDataGenerator.get + val toCatalyst = CatalystTypeConverters.createToCatalystConverter(rowType) + for (_ <- 1 to 50) { + val a = toCatalyst(randGenerator()).asInstanceOf[InternalRow] + val b = toCatalyst(randGenerator()).asInstanceOf[InternalRow] + withClue(s"a = $a, b = $b") { + assert(genOrdering.compare(a, a) === 0) + assert(genOrdering.compare(b, b) === 0) + assert(rowOrdering.compare(a, a) === 0) + assert(rowOrdering.compare(b, b) === 0) + assert(signum(genOrdering.compare(a, b)) === -1 * signum(genOrdering.compare(b, a))) + assert(signum(rowOrdering.compare(a, b)) === -1 * signum(rowOrdering.compare(b, a))) + assert( + signum(rowOrdering.compare(a, b)) === signum(genOrdering.compare(a, b)), + "Generated and non-generated orderings should agree") + } + } + } + } + + test("SPARK-8443: split wide projections into blocks due to JVM code size limit") { + val length = 5000 + val expressions = List.fill(length)(EqualTo(Literal(1), Literal(1))) + val plan = GenerateMutableProjection.generate(expressions)() + val actual = plan(new GenericMutableRow(length)).toSeq + val expected = Seq.fill(length)(true) + + if (!checkResult(actual, expected)) { + fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected") + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala new file mode 100644 index 0000000000000..28c41b57169f9 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types._ + + +class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { + + test("Array and Map Size") { + val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) + val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType)) + val a2 = Literal.create(Seq(1, 2), ArrayType(IntegerType)) + + checkEvaluation(Size(a0), 3) + checkEvaluation(Size(a1), 0) + checkEvaluation(Size(a2), 2) + + val m0 = Literal.create(Map("a" -> "a", "b" -> "b"), MapType(StringType, StringType)) + val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType)) + val m2 = Literal.create(Map("a" -> "a"), MapType(StringType, StringType)) + + checkEvaluation(Size(m0), 2) + checkEvaluation(Size(m1), 0) + checkEvaluation(Size(m2), 1) + + checkEvaluation(Literal.create(null, MapType(StringType, StringType)), null) + checkEvaluation(Literal.create(null, ArrayType(StringType)), null) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index e3042143632aa..a8aee8f634e03 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -117,6 +117,22 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(getArrayStructFields(nullArrayStruct, "a"), null) } + test("CreateArray") { + val intSeq = Seq(5, 10, 15, 20, 25) + val longSeq = intSeq.map(_.toLong) + val strSeq = intSeq.map(_.toString) + checkEvaluation(CreateArray(intSeq.map(Literal(_))), intSeq, EmptyRow) + checkEvaluation(CreateArray(longSeq.map(Literal(_))), longSeq, EmptyRow) + checkEvaluation(CreateArray(strSeq.map(Literal(_))), strSeq, EmptyRow) + + val intWithNull = intSeq.map(Literal(_)) :+ Literal.create(null, IntegerType) + val longWithNull = longSeq.map(Literal(_)) :+ Literal.create(null, LongType) + val strWithNull = strSeq.map(Literal(_)) :+ Literal.create(null, StringType) + checkEvaluation(CreateArray(intWithNull), intSeq :+ null, EmptyRow) + checkEvaluation(CreateArray(longWithNull), longSeq :+ null, EmptyRow) + checkEvaluation(CreateArray(strWithNull), strSeq :+ null, EmptyRow) + } + test("CreateStruct") { val row = create_row(1, 2, 3) val c1 = 'a.int.at(0) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala index aaf40cc83e762..afa143bd5f331 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala @@ -125,7 +125,7 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper val 
literalString = Literal("a") checkEvaluation(CaseKeyWhen(c1, Seq(c2, c4, c5)), "b", row) - checkEvaluation(CaseKeyWhen(c1, Seq(c2, c4, literalNull, c5, c6)), "b", row) + checkEvaluation(CaseKeyWhen(c1, Seq(c2, c4, literalNull, c5, c6)), "c", row) checkEvaluation(CaseKeyWhen(c2, Seq(literalInt, c4, c5)), "a", row) checkEvaluation(CaseKeyWhen(c2, Seq(c1, c4, c5)), "b", row) checkEvaluation(CaseKeyWhen(c4, Seq(literalString, c2, c3)), 1, row) @@ -134,7 +134,7 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(CaseKeyWhen(literalInt, Seq(c2, c4, c5)), "a", row) checkEvaluation(CaseKeyWhen(literalString, Seq(c5, c2, c4, c3)), 2, row) checkEvaluation(CaseKeyWhen(c6, Seq(c5, c2, c4, c3)), null, row) - checkEvaluation(CaseKeyWhen(literalNull, Seq(c2, c5, c1, c6)), "c", row) + checkEvaluation(CaseKeyWhen(literalNull, Seq(c2, c5, c1, c6)), null, row) } test("function least") { @@ -144,35 +144,35 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper val c3 = 'a.string.at(2) val c4 = 'a.string.at(3) val c5 = 'a.string.at(4) - checkEvaluation(Least(c4, c3, c5), "a", row) - checkEvaluation(Least(c1, c2), 1, row) - checkEvaluation(Least(c1, c2, Literal(-1)), -1, row) - checkEvaluation(Least(c4, c5, c3, c3, Literal("a")), "a", row) - - checkEvaluation(Least(Literal(null), Literal(null)), null, InternalRow.empty) - checkEvaluation(Least(Literal(-1.0), Literal(2.5)), -1.0, InternalRow.empty) - checkEvaluation(Least(Literal(-1), Literal(2)), -1, InternalRow.empty) + checkEvaluation(Least(Seq(c4, c3, c5)), "a", row) + checkEvaluation(Least(Seq(c1, c2)), 1, row) + checkEvaluation(Least(Seq(c1, c2, Literal(-1))), -1, row) + checkEvaluation(Least(Seq(c4, c5, c3, c3, Literal("a"))), "a", row) + + checkEvaluation(Least(Seq(Literal(null), Literal(null))), null, InternalRow.empty) + checkEvaluation(Least(Seq(Literal(-1.0), Literal(2.5))), -1.0, InternalRow.empty) + checkEvaluation(Least(Seq(Literal(-1), Literal(2))), -1, InternalRow.empty) checkEvaluation( - Least(Literal((-1.0).toFloat), Literal(2.5.toFloat)), (-1.0).toFloat, InternalRow.empty) + Least(Seq(Literal((-1.0).toFloat), Literal(2.5.toFloat))), (-1.0).toFloat, InternalRow.empty) checkEvaluation( - Least(Literal(Long.MaxValue), Literal(Long.MinValue)), Long.MinValue, InternalRow.empty) - checkEvaluation(Least(Literal(1.toByte), Literal(2.toByte)), 1.toByte, InternalRow.empty) + Least(Seq(Literal(Long.MaxValue), Literal(Long.MinValue))), Long.MinValue, InternalRow.empty) + checkEvaluation(Least(Seq(Literal(1.toByte), Literal(2.toByte))), 1.toByte, InternalRow.empty) checkEvaluation( - Least(Literal(1.toShort), Literal(2.toByte.toShort)), 1.toShort, InternalRow.empty) - checkEvaluation(Least(Literal("abc"), Literal("aaaa")), "aaaa", InternalRow.empty) - checkEvaluation(Least(Literal(true), Literal(false)), false, InternalRow.empty) + Least(Seq(Literal(1.toShort), Literal(2.toByte.toShort))), 1.toShort, InternalRow.empty) + checkEvaluation(Least(Seq(Literal("abc"), Literal("aaaa"))), "aaaa", InternalRow.empty) + checkEvaluation(Least(Seq(Literal(true), Literal(false))), false, InternalRow.empty) checkEvaluation( - Least( + Least(Seq( Literal(BigDecimal("1234567890987654321123456")), - Literal(BigDecimal("1234567890987654321123458"))), + Literal(BigDecimal("1234567890987654321123458")))), BigDecimal("1234567890987654321123456"), InternalRow.empty) checkEvaluation( - Least(Literal(Date.valueOf("2015-01-01")), Literal(Date.valueOf("2015-07-01"))), + 
Least(Seq(Literal(Date.valueOf("2015-01-01")), Literal(Date.valueOf("2015-07-01")))), Date.valueOf("2015-01-01"), InternalRow.empty) checkEvaluation( - Least( + Least(Seq( Literal(Timestamp.valueOf("2015-07-01 08:00:00")), - Literal(Timestamp.valueOf("2015-07-01 10:00:00"))), + Literal(Timestamp.valueOf("2015-07-01 10:00:00")))), Timestamp.valueOf("2015-07-01 08:00:00"), InternalRow.empty) } @@ -183,35 +183,36 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper val c3 = 'a.string.at(2) val c4 = 'a.string.at(3) val c5 = 'a.string.at(4) - checkEvaluation(Greatest(c4, c5, c3), "c", row) - checkEvaluation(Greatest(c2, c1), 2, row) - checkEvaluation(Greatest(c1, c2, Literal(2)), 2, row) - checkEvaluation(Greatest(c4, c5, c3, Literal("ccc")), "ccc", row) - - checkEvaluation(Greatest(Literal(null), Literal(null)), null, InternalRow.empty) - checkEvaluation(Greatest(Literal(-1.0), Literal(2.5)), 2.5, InternalRow.empty) - checkEvaluation(Greatest(Literal(-1), Literal(2)), 2, InternalRow.empty) + checkEvaluation(Greatest(Seq(c4, c5, c3)), "c", row) + checkEvaluation(Greatest(Seq(c2, c1)), 2, row) + checkEvaluation(Greatest(Seq(c1, c2, Literal(2))), 2, row) + checkEvaluation(Greatest(Seq(c4, c5, c3, Literal("ccc"))), "ccc", row) + + checkEvaluation(Greatest(Seq(Literal(null), Literal(null))), null, InternalRow.empty) + checkEvaluation(Greatest(Seq(Literal(-1.0), Literal(2.5))), 2.5, InternalRow.empty) + checkEvaluation(Greatest(Seq(Literal(-1), Literal(2))), 2, InternalRow.empty) checkEvaluation( - Greatest(Literal((-1.0).toFloat), Literal(2.5.toFloat)), 2.5.toFloat, InternalRow.empty) + Greatest(Seq(Literal((-1.0).toFloat), Literal(2.5.toFloat))), 2.5.toFloat, InternalRow.empty) + checkEvaluation(Greatest( + Seq(Literal(Long.MaxValue), Literal(Long.MinValue))), Long.MaxValue, InternalRow.empty) checkEvaluation( - Greatest(Literal(Long.MaxValue), Literal(Long.MinValue)), Long.MaxValue, InternalRow.empty) - checkEvaluation(Greatest(Literal(1.toByte), Literal(2.toByte)), 2.toByte, InternalRow.empty) + Greatest(Seq(Literal(1.toByte), Literal(2.toByte))), 2.toByte, InternalRow.empty) checkEvaluation( - Greatest(Literal(1.toShort), Literal(2.toByte.toShort)), 2.toShort, InternalRow.empty) - checkEvaluation(Greatest(Literal("abc"), Literal("aaaa")), "abc", InternalRow.empty) - checkEvaluation(Greatest(Literal(true), Literal(false)), true, InternalRow.empty) + Greatest(Seq(Literal(1.toShort), Literal(2.toByte.toShort))), 2.toShort, InternalRow.empty) + checkEvaluation(Greatest(Seq(Literal("abc"), Literal("aaaa"))), "abc", InternalRow.empty) + checkEvaluation(Greatest(Seq(Literal(true), Literal(false))), true, InternalRow.empty) checkEvaluation( - Greatest( + Greatest(Seq( Literal(BigDecimal("1234567890987654321123456")), - Literal(BigDecimal("1234567890987654321123458"))), + Literal(BigDecimal("1234567890987654321123458")))), BigDecimal("1234567890987654321123458"), InternalRow.empty) - checkEvaluation( - Greatest(Literal(Date.valueOf("2015-01-01")), Literal(Date.valueOf("2015-07-01"))), + checkEvaluation(Greatest( + Seq(Literal(Date.valueOf("2015-01-01")), Literal(Date.valueOf("2015-07-01")))), Date.valueOf("2015-07-01"), InternalRow.empty) checkEvaluation( - Greatest( + Greatest(Seq( Literal(Timestamp.valueOf("2015-07-01 08:00:00")), - Literal(Timestamp.valueOf("2015-07-01 10:00:00"))), + Literal(Timestamp.valueOf("2015-07-01 10:00:00")))), Timestamp.valueOf("2015-07-01 10:00:00"), InternalRow.empty) } diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala new file mode 100644 index 0000000000000..bdba6ce891386 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -0,0 +1,249 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import java.sql.{Timestamp, Date} +import java.text.SimpleDateFormat +import java.util.Calendar + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types.{StringType, TimestampType, DateType} + +class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { + + val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") + val sdfDate = new SimpleDateFormat("yyyy-MM-dd") + val d = new Date(sdf.parse("2015-04-08 13:10:15").getTime) + val ts = new Timestamp(sdf.parse("2013-11-08 13:10:15").getTime) + + test("DayOfYear") { + val sdfDay = new SimpleDateFormat("D") + (2002 to 2004).foreach { y => + (0 to 11).foreach { m => + (0 to 5).foreach { i => + val c = Calendar.getInstance() + c.set(y, m, 28, 0, 0, 0) + c.add(Calendar.DATE, i) + checkEvaluation(DayOfYear(Literal(new Date(c.getTimeInMillis))), + sdfDay.format(c.getTime).toInt) + } + } + } + + (1998 to 2002).foreach { y => + (0 to 11).foreach { m => + (0 to 5).foreach { i => + val c = Calendar.getInstance() + c.set(y, m, 28, 0, 0, 0) + c.add(Calendar.DATE, i) + checkEvaluation(DayOfYear(Literal(new Date(c.getTimeInMillis))), + sdfDay.format(c.getTime).toInt) + } + } + } + + (1969 to 1970).foreach { y => + (0 to 11).foreach { m => + (0 to 5).foreach { i => + val c = Calendar.getInstance() + c.set(y, m, 28, 0, 0, 0) + c.add(Calendar.DATE, i) + checkEvaluation(DayOfYear(Literal(new Date(c.getTimeInMillis))), + sdfDay.format(c.getTime).toInt) + } + } + } + + (2402 to 2404).foreach { y => + (0 to 11).foreach { m => + (0 to 5).foreach { i => + val c = Calendar.getInstance() + c.set(y, m, 28, 0, 0, 0) + c.add(Calendar.DATE, i) + checkEvaluation(DayOfYear(Literal(new Date(c.getTimeInMillis))), + sdfDay.format(c.getTime).toInt) + } + } + } + + (2398 to 2402).foreach { y => + (0 to 11).foreach { m => + (0 to 5).foreach { i => + val c = Calendar.getInstance() + c.set(y, m, 28, 0, 0, 0) + c.add(Calendar.DATE, i) + checkEvaluation(DayOfYear(Literal(new Date(c.getTimeInMillis))), + sdfDay.format(c.getTime).toInt) + } + } + } + } + + test("Year") { + checkEvaluation(Year(Literal.create(null, DateType)), null) + checkEvaluation(Year(Literal(d)), 2015) + checkEvaluation(Year(Cast(Literal(sdfDate.format(d)), DateType)), 2015) + checkEvaluation(Year(Cast(Literal(ts), DateType)), 2013) + + val c = 
Calendar.getInstance() + (2000 to 2010).foreach { y => + (0 to 11 by 11).foreach { m => + c.set(y, m, 28) + (0 to 5 * 24).foreach { i => + c.add(Calendar.HOUR_OF_DAY, 1) + checkEvaluation(Year(Literal(new Date(c.getTimeInMillis))), + c.get(Calendar.YEAR)) + } + } + } + } + + test("Quarter") { + checkEvaluation(Quarter(Literal.create(null, DateType)), null) + checkEvaluation(Quarter(Literal(d)), 2) + checkEvaluation(Quarter(Cast(Literal(sdfDate.format(d)), DateType)), 2) + checkEvaluation(Quarter(Cast(Literal(ts), DateType)), 4) + + val c = Calendar.getInstance() + (2003 to 2004).foreach { y => + (0 to 11 by 3).foreach { m => + c.set(y, m, 28, 0, 0, 0) + (0 to 5 * 24).foreach { i => + c.add(Calendar.HOUR_OF_DAY, 1) + checkEvaluation(Quarter(Literal(new Date(c.getTimeInMillis))), + c.get(Calendar.MONTH) / 3 + 1) + } + } + } + } + + test("Month") { + checkEvaluation(Month(Literal.create(null, DateType)), null) + checkEvaluation(Month(Literal(d)), 4) + checkEvaluation(Month(Cast(Literal(sdfDate.format(d)), DateType)), 4) + checkEvaluation(Month(Cast(Literal(ts), DateType)), 11) + + (2003 to 2004).foreach { y => + (0 to 11).foreach { m => + (0 to 5 * 24).foreach { i => + val c = Calendar.getInstance() + c.set(y, m, 28, 0, 0, 0) + c.add(Calendar.HOUR_OF_DAY, i) + checkEvaluation(Month(Literal(new Date(c.getTimeInMillis))), + c.get(Calendar.MONTH) + 1) + } + } + } + + (1999 to 2000).foreach { y => + (0 to 11).foreach { m => + (0 to 5 * 24).foreach { i => + val c = Calendar.getInstance() + c.set(y, m, 28, 0, 0, 0) + c.add(Calendar.HOUR_OF_DAY, i) + checkEvaluation(Month(Literal(new Date(c.getTimeInMillis))), + c.get(Calendar.MONTH) + 1) + } + } + } + } + + test("Day / DayOfMonth") { + checkEvaluation(DayOfMonth(Cast(Literal("2000-02-29"), DateType)), 29) + checkEvaluation(DayOfMonth(Literal.create(null, DateType)), null) + checkEvaluation(DayOfMonth(Literal(d)), 8) + checkEvaluation(DayOfMonth(Cast(Literal(sdfDate.format(d)), DateType)), 8) + checkEvaluation(DayOfMonth(Cast(Literal(ts), DateType)), 8) + + (1999 to 2000).foreach { y => + val c = Calendar.getInstance() + c.set(y, 0, 1, 0, 0, 0) + (0 to 365).foreach { d => + c.add(Calendar.DATE, 1) + checkEvaluation(DayOfMonth(Literal(new Date(c.getTimeInMillis))), + c.get(Calendar.DAY_OF_MONTH)) + } + } + } + + test("Seconds") { + checkEvaluation(Second(Literal.create(null, DateType)), null) + checkEvaluation(Second(Cast(Literal(d), TimestampType)), 0) + checkEvaluation(Second(Cast(Literal(sdf.format(d)), TimestampType)), 15) + checkEvaluation(Second(Literal(ts)), 15) + + val c = Calendar.getInstance() + (0 to 60 by 5).foreach { s => + c.set(2015, 18, 3, 3, 5, s) + checkEvaluation(Second(Literal(new Timestamp(c.getTimeInMillis))), + c.get(Calendar.SECOND)) + } + } + + test("WeekOfYear") { + checkEvaluation(WeekOfYear(Literal.create(null, DateType)), null) + checkEvaluation(WeekOfYear(Literal(d)), 15) + checkEvaluation(WeekOfYear(Cast(Literal(sdfDate.format(d)), DateType)), 15) + checkEvaluation(WeekOfYear(Cast(Literal(ts), DateType)), 45) + checkEvaluation(WeekOfYear(Cast(Literal("2011-05-06"), DateType)), 18) + } + + test("DateFormat") { + checkEvaluation(DateFormatClass(Literal.create(null, TimestampType), Literal("y")), null) + checkEvaluation(DateFormatClass(Cast(Literal(d), TimestampType), + Literal.create(null, StringType)), null) + checkEvaluation(DateFormatClass(Cast(Literal(d), TimestampType), + Literal("y")), "2015") + checkEvaluation(DateFormatClass(Literal(ts), Literal("y")), "2013") + } + + test("Hour") { + 
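+    // The DateType literal `d` carries no time-of-day, so casting it to a timestamp is
+    // expected to land at midnight (hour 0); a null input should simply yield null.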
checkEvaluation(Hour(Literal.create(null, DateType)), null) + checkEvaluation(Hour(Cast(Literal(d), TimestampType)), 0) + checkEvaluation(Hour(Cast(Literal(sdf.format(d)), TimestampType)), 13) + checkEvaluation(Hour(Literal(ts)), 13) + + val c = Calendar.getInstance() + (0 to 24).foreach { h => + (0 to 60 by 15).foreach { m => + (0 to 60 by 15).foreach { s => + c.set(2015, 18, 3, h, m, s) + checkEvaluation(Hour(Literal(new Timestamp(c.getTimeInMillis))), + c.get(Calendar.HOUR_OF_DAY)) + } + } + } + } + + test("Minute") { + checkEvaluation(Minute(Literal.create(null, DateType)), null) + checkEvaluation(Minute(Cast(Literal(d), TimestampType)), 0) + checkEvaluation(Minute(Cast(Literal(sdf.format(d)), TimestampType)), 10) + checkEvaluation(Minute(Literal(ts)), 10) + + val c = Calendar.getInstance() + (0 to 60 by 5).foreach { m => + (0 to 60 by 15).foreach { s => + c.set(2015, 18, 3, 3, m, s) + checkEvaluation(Minute(Literal(new Timestamp(c.getTimeInMillis))), + c.get(Calendar.MINUTE)) + } + } + } + +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 43392df4bec2e..6e17ffcda9dc4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -23,7 +23,7 @@ import org.scalatest.Matchers._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateProjection, GenerateMutableProjection} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.optimizer.DefaultOptimizer import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} @@ -38,22 +38,27 @@ trait ExpressionEvalHelper { } protected def checkEvaluation( - expression: Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { + expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { val catalystValue = CatalystTypeConverters.convertToCatalyst(expected) checkEvaluationWithoutCodegen(expression, catalystValue, inputRow) checkEvaluationWithGeneratedMutableProjection(expression, catalystValue, inputRow) checkEvaluationWithGeneratedProjection(expression, catalystValue, inputRow) + if (UnsafeColumnWriter.canEmbed(expression.dataType)) { + checkEvalutionWithUnsafeProjection(expression, catalystValue, inputRow) + } checkEvaluationWithOptimization(expression, catalystValue, inputRow) } /** * Check the equality between result of expression and expected value, it will handle - * Array[Byte]. + * Array[Byte] and Spread[Double]. 
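+   * A `Spread[Double]` (built with ScalaTest's `+-`, e.g. `0.7363 +- 0.001`) matches any
+   * double within that tolerance, which allows inexact results such as Rand to be checked.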
*/ protected def checkResult(result: Any, expected: Any): Boolean = { (result, expected) match { case (result: Array[Byte], expected: Array[Byte]) => java.util.Arrays.equals(result, expected) + case (result: Double, expected: Spread[Double]) => + expected.isWithin(result) case _ => result == expected } } @@ -62,10 +67,28 @@ trait ExpressionEvalHelper { expression.eval(inputRow) } + protected def generateProject( + generator: => Projection, + expression: Expression): Projection = { + try { + generator + } catch { + case e: Throwable => + val ctx = new CodeGenContext + val evaluated = expression.gen(ctx) + fail( + s""" + |Code generation of $expression failed: + |$e + """.stripMargin) + } + } + protected def checkEvaluationWithoutCodegen( expression: Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { + val actual = try evaluate(expression, inputRow) catch { case e: Exception => fail(s"Exception evaluating $expression", e) } @@ -82,21 +105,11 @@ trait ExpressionEvalHelper { expected: Any, inputRow: InternalRow = EmptyRow): Unit = { - val plan = try { - GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)() - } catch { - case e: Throwable => - val ctx = GenerateProjection.newCodeGenContext() - val evaluated = expression.gen(ctx) - fail( - s""" - |Code generation of $expression failed: - |${evaluated.code} - |$e - """.stripMargin) - } + val plan = generateProject( + GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)(), + expression) - val actual = plan(inputRow).apply(0) + val actual = plan(inputRow).get(0) if (!checkResult(actual, expected)) { val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") @@ -107,24 +120,19 @@ trait ExpressionEvalHelper { expression: Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { - val ctx = GenerateProjection.newCodeGenContext() - lazy val evaluated = expression.gen(ctx) - val plan = try { - GenerateProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil) - } catch { - case e: Throwable => - fail( - s""" - |Code generation of $expression failed: - |${evaluated.code} - |$e - """.stripMargin) - } + val plan = generateProject( + GenerateProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), + expression) val actual = plan(inputRow) val expectedRow = InternalRow(expected) + + // We reimplement hashCode in generated `SpecificRow`, make sure it's consistent with our + // interpreted version. 
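+    // Without this check, hash-based operations that key on rows (for example aggregation
+    // maps) could presumably behave differently with and without code generation.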
if (actual.hashCode() != expectedRow.hashCode()) { + val ctx = new CodeGenContext + val evaluated = expression.gen(ctx) fail( s""" |Mismatched hashCodes for values: $actual, $expectedRow @@ -133,15 +141,35 @@ trait ExpressionEvalHelper { |Code: $evaluated """.stripMargin) } + if (actual != expectedRow) { val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" - fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") + fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expectedRow$input") } if (actual.copy() != expectedRow) { fail(s"Copy of generated Row is wrong: actual: ${actual.copy()}, expected: $expectedRow") } } + protected def checkEvalutionWithUnsafeProjection( + expression: Expression, + expected: Any, + inputRow: InternalRow = EmptyRow): Unit = { + + val plan = generateProject( + GenerateUnsafeProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), + expression) + + val unsafeRow = plan(inputRow) + // UnsafeRow cannot be compared with GenericInternalRow directly + val actual = FromUnsafeProjection(expression.dataType :: Nil)(unsafeRow) + val expectedRow = InternalRow(expected) + if (actual != expectedRow) { + val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" + fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expectedRow$input") + } + } + protected def checkEvaluationWithOptimization( expression: Expression, expected: Any, @@ -152,12 +180,23 @@ trait ExpressionEvalHelper { } protected def checkDoubleEvaluation( - expression: Expression, + expression: => Expression, expected: Spread[Double], inputRow: InternalRow = EmptyRow): Unit = { - val actual = try evaluate(expression, inputRow) catch { - case e: Exception => fail(s"Exception evaluating $expression", e) - } - actual.asInstanceOf[Double] shouldBe expected + checkEvaluationWithoutCodegen(expression, expected) + checkEvaluationWithGeneratedMutableProjection(expression, expected) + checkEvaluationWithOptimization(expression, expected) + + var plan = generateProject( + GenerateProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), + expression) + var actual = plan(inputRow).get(0) + assert(checkResult(actual, expected)) + + plan = generateProject( + GenerateUnsafeProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), + expression) + actual = FromUnsafeProjection(expression.dataType :: Nil)(plan(inputRow)).get(0) + assert(checkResult(actual, expected)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 7ca9e30b2bcd5..a2b0fad7b7a04 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -21,8 +21,13 @@ import com.google.common.math.LongMath import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateProjection, GenerateMutableProjection} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.optimizer.DefaultOptimizer +import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} import org.apache.spark.sql.types._ + class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { /** @@ -46,6 +51,7 @@ 
class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { * @param f The functions in scala.math or elsewhere used to generate expected results * @param domain The set of values to run the function with * @param expectNull Whether the given values should return null or not + * @param expectNaN Whether the given values should eval to NaN or not * @tparam T Generic type for primitives * @tparam U Generic type for the output of the given function `f` */ @@ -54,11 +60,16 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { f: T => U, domain: Iterable[T] = (-20 to 20).map(_ * 0.1), expectNull: Boolean = false, + expectNaN: Boolean = false, evalType: DataType = DoubleType): Unit = { if (expectNull) { domain.foreach { value => checkEvaluation(c(Literal(value)), null, EmptyRow) } + } else if (expectNaN) { + domain.foreach { value => + checkNaN(c(Literal(value)), EmptyRow) + } } else { domain.foreach { value => checkEvaluation(c(Literal(value)), f(value), EmptyRow) @@ -73,16 +84,22 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { * @param c The DataFrame function * @param f The functions in scala.math * @param domain The set of values to run the function with + * @param expectNull Whether the given values should return null or not + * @param expectNaN Whether the given values should eval to NaN or not */ private def testBinary( c: (Expression, Expression) => Expression, f: (Double, Double) => Double, domain: Iterable[(Double, Double)] = (-20 to 20).map(v => (v * 0.1, v * -0.1)), - expectNull: Boolean = false): Unit = { + expectNull: Boolean = false, expectNaN: Boolean = false): Unit = { if (expectNull) { domain.foreach { case (v1, v2) => checkEvaluation(c(Literal(v1), Literal(v2)), null, create_row(null)) } + } else if (expectNaN) { + domain.foreach { case (v1, v2) => + checkNaN(c(Literal(v1), Literal(v2)), EmptyRow) + } } else { domain.foreach { case (v1, v2) => checkEvaluation(c(Literal(v1), Literal(v2)), f(v1 + 0.0, v2 + 0.0), EmptyRow) @@ -93,6 +110,68 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(c(Literal(1.0), Literal.create(null, DoubleType)), null, create_row(null)) } + test("conv") { + checkEvaluation(Conv(Literal("3"), Literal(10), Literal(2)), "11") + checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(-16)), "-F") + checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(16)), "FFFFFFFFFFFFFFF1") + checkEvaluation(Conv(Literal("big"), Literal(36), Literal(16)), "3A48") + checkEvaluation(Conv(Literal.create(null, StringType), Literal(36), Literal(16)), null) + checkEvaluation(Conv(Literal("3"), Literal.create(null, IntegerType), Literal(16)), null) + checkEvaluation( + Conv(Literal("1234"), Literal(10), Literal(37)), null) + checkEvaluation( + Conv(Literal(""), Literal(10), Literal(16)), null) + checkEvaluation( + Conv(Literal("9223372036854775807"), Literal(36), Literal(16)), "FFFFFFFFFFFFFFFF") + // If there is an invalid digit in the number, the longest valid prefix should be converted. 
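+    // e.g. only the leading "11" of "11abc" parses in base 10, and decimal 11 is "B" in base 16: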
+ checkEvaluation( + Conv(Literal("11abc"), Literal(10), Literal(16)), "B") + } + + private def checkNaN( + expression: Expression, inputRow: InternalRow = EmptyRow): Unit = { + checkNaNWithoutCodegen(expression, inputRow) + checkNaNWithGeneratedProjection(expression, inputRow) + checkNaNWithOptimization(expression, inputRow) + } + + private def checkNaNWithoutCodegen( + expression: Expression, + expected: Any, + inputRow: InternalRow = EmptyRow): Unit = { + val actual = try evaluate(expression, inputRow) catch { + case e: Exception => fail(s"Exception evaluating $expression", e) + } + if (!actual.asInstanceOf[Double].isNaN) { + fail(s"Incorrect evaluation (codegen off): $expression, " + + s"actual: $actual, " + + s"expected: NaN") + } + } + + + private def checkNaNWithGeneratedProjection( + expression: Expression, + inputRow: InternalRow = EmptyRow): Unit = { + + val plan = generateProject( + GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)(), + expression) + + val actual = plan(inputRow).apply(0) + if (!actual.asInstanceOf[Double].isNaN) { + fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: NaN") + } + } + + private def checkNaNWithOptimization( + expression: Expression, + inputRow: InternalRow = EmptyRow): Unit = { + val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation) + val optimizedPlan = DefaultOptimizer.execute(plan) + checkNaNWithoutCodegen(optimizedPlan.expressions.head, inputRow) + } + test("e") { testLeaf(EulerNumber, math.E) } @@ -107,7 +186,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("asin") { testUnary(Asin, math.asin, (-10 to 10).map(_ * 0.1)) - testUnary(Asin, math.asin, (11 to 20).map(_ * 0.1), expectNull = true) + testUnary(Asin, math.asin, (11 to 20).map(_ * 0.1), expectNaN = true) } test("sinh") { @@ -120,7 +199,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("acos") { testUnary(Acos, math.acos, (-10 to 10).map(_ * 0.1)) - testUnary(Acos, math.acos, (11 to 20).map(_ * 0.1), expectNull = true) + testUnary(Acos, math.acos, (11 to 20).map(_ * 0.1), expectNaN = true) } test("cosh") { @@ -185,18 +264,18 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("log") { - testUnary(Log, math.log, (0 to 20).map(_ * 0.1)) - testUnary(Log, math.log, (-5 to -1).map(_ * 0.1), expectNull = true) + testUnary(Log, math.log, (1 to 20).map(_ * 0.1)) + testUnary(Log, math.log, (-5 to 0).map(_ * 0.1), expectNull = true) } test("log10") { - testUnary(Log10, math.log10, (0 to 20).map(_ * 0.1)) - testUnary(Log10, math.log10, (-5 to -1).map(_ * 0.1), expectNull = true) + testUnary(Log10, math.log10, (1 to 20).map(_ * 0.1)) + testUnary(Log10, math.log10, (-5 to 0).map(_ * 0.1), expectNull = true) } test("log1p") { - testUnary(Log1p, math.log1p, (-1 to 20).map(_ * 0.1)) - testUnary(Log1p, math.log1p, (-10 to -2).map(_ * 1.0), expectNull = true) + testUnary(Log1p, math.log1p, (0 to 20).map(_ * 0.1)) + testUnary(Log1p, math.log1p, (-10 to -1).map(_ * 1.0), expectNull = true) } test("bin") { @@ -218,22 +297,22 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("log2") { def f: (Double) => Double = (x: Double) => math.log(x) / math.log(2) - testUnary(Log2, f, (0 to 20).map(_ * 0.1)) - testUnary(Log2, f, (-5 to -1).map(_ * 1.0), expectNull = true) + testUnary(Log2, f, (1 to 20).map(_ * 0.1)) + testUnary(Log2, f, (-5 to 0).map(_ * 1.0), expectNull = true) } test("sqrt") 
{ testUnary(Sqrt, math.sqrt, (0 to 20).map(_ * 0.1)) - testUnary(Sqrt, math.sqrt, (-5 to -1).map(_ * 1.0), expectNull = true) + testUnary(Sqrt, math.sqrt, (-5 to -1).map(_ * 1.0), expectNaN = true) checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, create_row(null)) - checkEvaluation(Sqrt(Literal(-1.0)), null, EmptyRow) - checkEvaluation(Sqrt(Literal(-1.5)), null, EmptyRow) + checkNaN(Sqrt(Literal(-1.0)), EmptyRow) + checkNaN(Sqrt(Literal(-1.5)), EmptyRow) } test("pow") { testBinary(Pow, math.pow, (-5 to 5).map(v => (v * 1.0, v * 1.0))) - testBinary(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNull = true) + testBinary(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNaN = true) } test("shift left") { @@ -336,4 +415,46 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { null, create_row(null)) } + + test("round") { + val domain = -6 to 6 + val doublePi: Double = math.Pi + val shortPi: Short = 31415 + val intPi: Int = 314159265 + val longPi: Long = 31415926535897932L + val bdPi: BigDecimal = BigDecimal(31415927L, 7) + + val doubleResults: Seq[Double] = Seq(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 3.1, 3.14, 3.142, + 3.1416, 3.14159, 3.141593) + + val shortResults: Seq[Short] = Seq[Short](0, 0, 30000, 31000, 31400, 31420) ++ + Seq.fill[Short](7)(31415) + + val intResults: Seq[Int] = Seq(314000000, 314200000, 314160000, 314159000, 314159300, + 314159270) ++ Seq.fill(7)(314159265) + + val longResults: Seq[Long] = Seq(31415926536000000L, 31415926535900000L, + 31415926535900000L, 31415926535898000L, 31415926535897900L, 31415926535897930L) ++ + Seq.fill(7)(31415926535897932L) + + val bdResults: Seq[BigDecimal] = Seq(BigDecimal(3.0), BigDecimal(3.1), BigDecimal(3.14), + BigDecimal(3.142), BigDecimal(3.1416), BigDecimal(3.14159), + BigDecimal(3.141593), BigDecimal(3.1415927)) + + domain.zipWithIndex.foreach { case (scale, i) => + checkEvaluation(Round(doublePi, scale), doubleResults(i), EmptyRow) + checkEvaluation(Round(shortPi, scale), shortResults(i), EmptyRow) + checkEvaluation(Round(intPi, scale), intResults(i), EmptyRow) + checkEvaluation(Round(longPi, scale), longResults(i), EmptyRow) + } + + // round_scale > current_scale would result in precision increase + // and not allowed by o.a.s.s.types.Decimal.changePrecision, therefore null + (0 to 7).foreach { i => + checkEvaluation(Round(bdPi, i), bdResults(i), EmptyRow) + } + (8 to 10).foreach { scale => + checkEvaluation(Round(bdPi, scale), null, EmptyRow) + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala index ccdada8b56f83..0728f6695c39d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala @@ -18,48 +18,89 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.types.{BooleanType, StringType, ShortType} +import org.apache.spark.sql.types._ class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { - test("null checking") { - val row = create_row("^Ba*n", null, true, null) - val c1 = 'a.string.at(0) - val c2 = 'a.string.at(1) - val c3 = 'a.boolean.at(2) - val c4 = 'a.boolean.at(3) + def testAllTypes(testFunc: (Any, DataType) => 
Unit): Unit = { + testFunc(false, BooleanType) + testFunc(1.toByte, ByteType) + testFunc(1.toShort, ShortType) + testFunc(1, IntegerType) + testFunc(1L, LongType) + testFunc(1.0F, FloatType) + testFunc(1.0, DoubleType) + testFunc(Decimal(1.5), DecimalType.Unlimited) + testFunc(new java.sql.Date(10), DateType) + testFunc(new java.sql.Timestamp(10), TimestampType) + testFunc("abcd", StringType) + } + + test("isnull and isnotnull") { + testAllTypes { (value: Any, tpe: DataType) => + checkEvaluation(IsNull(Literal.create(value, tpe)), false) + checkEvaluation(IsNotNull(Literal.create(value, tpe)), true) + checkEvaluation(IsNull(Literal.create(null, tpe)), true) + checkEvaluation(IsNotNull(Literal.create(null, tpe)), false) + } + } - checkEvaluation(c1.isNull, false, row) - checkEvaluation(c1.isNotNull, true, row) + test("IsNaN") { + checkEvaluation(IsNaN(Literal(Double.NaN)), true) + checkEvaluation(IsNaN(Literal(Float.NaN)), true) + checkEvaluation(IsNaN(Literal(math.log(-3))), true) + checkEvaluation(IsNaN(Literal.create(null, DoubleType)), false) + checkEvaluation(IsNaN(Literal(Double.PositiveInfinity)), false) + checkEvaluation(IsNaN(Literal(Float.MaxValue)), false) + checkEvaluation(IsNaN(Literal(5.5f)), false) + } - checkEvaluation(c2.isNull, true, row) - checkEvaluation(c2.isNotNull, false, row) + test("nanvl") { + checkEvaluation(NaNvl(Literal(5.0), Literal.create(null, DoubleType)), 5.0) + checkEvaluation(NaNvl(Literal.create(null, DoubleType), Literal(5.0)), null) + checkEvaluation(NaNvl(Literal.create(null, DoubleType), Literal(Double.NaN)), null) + checkEvaluation(NaNvl(Literal(Double.NaN), Literal(5.0)), 5.0) + checkEvaluation(NaNvl(Literal(Double.NaN), Literal.create(null, DoubleType)), null) + assert(NaNvl(Literal(Double.NaN), Literal(Double.NaN)). 
+ eval(EmptyRow).asInstanceOf[Double].isNaN) + } - checkEvaluation(Literal.create(1, ShortType).isNull, false) - checkEvaluation(Literal.create(1, ShortType).isNotNull, true) + test("coalesce") { + testAllTypes { (value: Any, tpe: DataType) => + val lit = Literal.create(value, tpe) + val nullLit = Literal.create(null, tpe) + checkEvaluation(Coalesce(Seq(nullLit)), null) + checkEvaluation(Coalesce(Seq(lit)), value) + checkEvaluation(Coalesce(Seq(nullLit, lit)), value) + checkEvaluation(Coalesce(Seq(nullLit, lit, lit)), value) + checkEvaluation(Coalesce(Seq(nullLit, nullLit, lit)), value) + } + } - checkEvaluation(Literal.create(null, ShortType).isNull, true) - checkEvaluation(Literal.create(null, ShortType).isNotNull, false) + test("AtLeastNNonNulls") { + val mix = Seq(Literal("x"), + Literal.create(null, StringType), + Literal.create(null, DoubleType), + Literal(Double.NaN), + Literal(5f)) - checkEvaluation(Coalesce(c1 :: c2 :: Nil), "^Ba*n", row) - checkEvaluation(Coalesce(Literal.create(null, StringType) :: Nil), null, row) - checkEvaluation(Coalesce(Literal.create(null, StringType) :: c1 :: c2 :: Nil), "^Ba*n", row) + val nanOnly = Seq(Literal("x"), + Literal(10.0), + Literal(Float.NaN), + Literal(math.log(-2)), + Literal(Double.MaxValue)) - checkEvaluation( - If(c3, Literal.create("a", StringType), Literal.create("b", StringType)), "a", row) - checkEvaluation(If(c3, c1, c2), "^Ba*n", row) - checkEvaluation(If(c4, c2, c1), "^Ba*n", row) - checkEvaluation(If(Literal.create(null, BooleanType), c2, c1), "^Ba*n", row) - checkEvaluation(If(Literal.create(true, BooleanType), c1, c2), "^Ba*n", row) - checkEvaluation(If(Literal.create(false, BooleanType), c2, c1), "^Ba*n", row) - checkEvaluation(If(Literal.create(false, BooleanType), - Literal.create("a", StringType), Literal.create("b", StringType)), "b", row) + val nullOnly = Seq(Literal("x"), + Literal.create(null, DoubleType), + Literal.create(null, DecimalType.Unlimited), + Literal(Float.MaxValue), + Literal(false)) - checkEvaluation(c1 in (c1, c2), true, row) - checkEvaluation( - Literal.create("^Ba*n", StringType) in (Literal.create("^Ba*n", StringType)), true, row) - checkEvaluation( - Literal.create("^Ba*n", StringType) in (Literal.create("^Ba*n", StringType), c2), true, row) + checkEvaluation(AtLeastNNonNulls(2, mix), true, EmptyRow) + checkEvaluation(AtLeastNNonNulls(3, mix), false, EmptyRow) + checkEvaluation(AtLeastNNonNulls(3, nanOnly), true, EmptyRow) + checkEvaluation(AtLeastNNonNulls(4, nanOnly), false, EmptyRow) + checkEvaluation(AtLeastNNonNulls(3, nullOnly), true, EmptyRow) + checkEvaluation(AtLeastNNonNulls(4, nullOnly), false, EmptyRow) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 188ecef9e7679..0bc2812a5dc83 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -21,7 +21,7 @@ import scala.collection.immutable.HashSet import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.types.{Decimal, IntegerType, BooleanType} +import org.apache.spark.sql.types.{Decimal, DoubleType, IntegerType, BooleanType} class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -114,6 +114,10 @@ class PredicateSuite extends SparkFunSuite with 
ExpressionEvalHelper { checkEvaluation( And(In(Literal(1), Seq(Literal(1), Literal(2))), In(Literal(2), Seq(Literal(1), Literal(2)))), true) + + checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("^Ba*n"))), true) + checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^Ba*n"))), true) + checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^n"))), false) } test("INSET") { @@ -132,11 +136,14 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(And(InSet(one, hS), InSet(two, hS)), true) } - private val smallValues = Seq(1, Decimal(1), Array(1.toByte), "a").map(Literal(_)) - private val largeValues = Seq(2, Decimal(2), Array(2.toByte), "b").map(Literal(_)) + private val smallValues = Seq(1, Decimal(1), Array(1.toByte), "a", 0f, 0d).map(Literal(_)) + private val largeValues = + Seq(2, Decimal(2), Array(2.toByte), "b", Float.NaN, Double.NaN).map(Literal(_)) - private val equalValues1 = smallValues - private val equalValues2 = Seq(1, Decimal(1), Array(1.toByte), "a").map(Literal(_)) + private val equalValues1 = + Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN).map(Literal(_)) + private val equalValues2 = + Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN).map(Literal(_)) test("BinaryComparison: <") { for (i <- 0 until smallValues.length) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala index 9be2b23a53f27..698c81ba24482 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala @@ -21,13 +21,13 @@ import org.scalatest.Matchers._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.types.{DoubleType, IntegerType} +import org.apache.spark.sql.types.DoubleType class RandomSuite extends SparkFunSuite with ExpressionEvalHelper { test("random") { - val row = create_row(1.1, 2.0, 3.1, null) - checkDoubleEvaluation(Rand(30), (0.7363714192755834 +- 0.001), row) + checkDoubleEvaluation(Rand(30), 0.7363714192755834 +- 0.001) + checkDoubleEvaluation(Randn(30), 0.5181478766595276 +- 0.001) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala similarity index 74% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index b19f4ee37a109..3d294fda5d103 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -19,10 +19,61 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.types.{BinaryType, IntegerType, StringType} +import org.apache.spark.sql.types._ -class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { +class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { + + test("concat") { + def testConcat(inputs: String*): Unit = { + val 
expected = if (inputs.contains(null)) null else inputs.mkString + checkEvaluation(Concat(inputs.map(Literal.create(_, StringType))), expected, EmptyRow) + } + + testConcat() + testConcat(null) + testConcat("") + testConcat("ab") + testConcat("a", "b") + testConcat("a", "b", "C") + testConcat("a", null, "C") + testConcat("a", null, null) + testConcat(null, null, null) + + // scalastyle:off + // non ascii characters are not allowed in the code, so we disable the scalastyle here. + testConcat("数据", null, "砖头") + // scalastyle:on + } + + test("concat_ws") { + def testConcatWs(expected: String, sep: String, inputs: Any*): Unit = { + val inputExprs = inputs.map { + case s: Seq[_] => Literal.create(s, ArrayType(StringType)) + case null => Literal.create(null, StringType) + case s: String => Literal.create(s, StringType) + } + val sepExpr = Literal.create(sep, StringType) + checkEvaluation(ConcatWs(sepExpr +: inputExprs), expected, EmptyRow) + } + + // scalastyle:off + // non ascii characters are not allowed in the code, so we disable the scalastyle here. + testConcatWs(null, null) + testConcatWs(null, null, "a", "b") + testConcatWs("", "") + testConcatWs("ab", "哈哈", "ab") + testConcatWs("a哈哈b", "哈哈", "a", "b") + testConcatWs("a哈哈b", "哈哈", "a", null, "b") + testConcatWs("a哈哈b哈哈c", "哈哈", null, "a", null, "b", "c") + + testConcatWs("ab", "哈哈", Seq("ab")) + testConcatWs("a哈哈b", "哈哈", Seq("a", "b")) + testConcatWs("a哈哈b哈哈c哈哈d", "哈哈", Seq("a", null, "b"), null, "c", Seq(null, "d")) + testConcatWs("a哈哈b哈哈c", "哈哈", Seq("a", null, "b"), null, "c", Seq.empty[String]) + testConcatWs("a哈哈b哈哈c", "哈哈", Seq("a", null, "b"), null, "c", Seq[String](null)) + // scalastyle:on + } test("StringComparison") { val row = create_row("abc", null) @@ -216,15 +267,6 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } - test("length for string") { - val a = 'a.string.at(0) - checkEvaluation(StringLength(Literal("abc")), 3, create_row("abdef")) - checkEvaluation(StringLength(a), 5, create_row("abdef")) - checkEvaluation(StringLength(a), 0, create_row("")) - checkEvaluation(StringLength(a), null, create_row(null)) - checkEvaluation(StringLength(Literal.create(null, StringType)), null, create_row("abdef")) - } - test("ascii for string") { val a = 'a.string.at(0) checkEvaluation(Ascii(Literal("efg")), 101, create_row("abdef")) @@ -248,7 +290,7 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Base64(b), "AQIDBA==", create_row(bytes)) checkEvaluation(Base64(b), "", create_row(Array[Byte]())) checkEvaluation(Base64(b), null, create_row(null)) - checkEvaluation(Base64(Literal.create(null, StringType)), null, create_row("abdef")) + checkEvaluation(Base64(Literal.create(null, BinaryType)), null, create_row("abdef")) checkEvaluation(UnBase64(a), null, create_row(null)) checkEvaluation(UnBase64(Literal.create(null, StringType)), null, create_row("abdef")) @@ -309,18 +351,16 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("FORMAT") { - val f = 'f.string.at(0) - val d1 = 'd.int.at(1) - val s1 = 's.int.at(2) - - val row1 = create_row("aa%d%s", 12, "cc") - val row2 = create_row(null, 12, "cc") - checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a", row1) - checkEvaluation(StringFormat(Literal("aa")), "aa", create_row(null)) - checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a", row1) - - checkEvaluation(StringFormat(f, d1, s1), "aa12cc", row1) - 
checkEvaluation(StringFormat(f, d1, s1), null, row2) + checkEvaluation(FormatString(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a") + checkEvaluation(FormatString(Literal("aa")), "aa", create_row(null)) + checkEvaluation(FormatString(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a") + checkEvaluation(FormatString(Literal("aa%d%s"), 12, "cc"), "aa12cc") + + checkEvaluation(FormatString(Literal.create(null, StringType), 12, "cc"), null) + checkEvaluation( + FormatString(Literal("aa%d%s"), Literal.create(null, IntegerType), "cc"), "aanullcc") + checkEvaluation( + FormatString(Literal("aa%d%s"), 12, Literal.create(null, StringType)), "aa12null") } test("INSTR") { @@ -371,18 +411,24 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { val row1 = create_row("hi", 5, "??") val row2 = create_row("hi", 1, "?") val row3 = create_row(null, 1, "?") + val row4 = create_row("hi", null, "?") + val row5 = create_row("hi", 1, null) checkEvaluation(StringLPad(Literal("hi"), Literal(5), Literal("??")), "???hi", row1) checkEvaluation(StringLPad(Literal("hi"), Literal(1), Literal("??")), "h", row1) checkEvaluation(StringLPad(s1, s2, s3), "???hi", row1) checkEvaluation(StringLPad(s1, s2, s3), "h", row2) checkEvaluation(StringLPad(s1, s2, s3), null, row3) + checkEvaluation(StringLPad(s1, s2, s3), null, row4) + checkEvaluation(StringLPad(s1, s2, s3), null, row5) checkEvaluation(StringRPad(Literal("hi"), Literal(5), Literal("??")), "hi???", row1) checkEvaluation(StringRPad(Literal("hi"), Literal(1), Literal("??")), "h", row1) checkEvaluation(StringRPad(s1, s2, s3), "hi???", row1) checkEvaluation(StringRPad(s1, s2, s3), "h", row2) checkEvaluation(StringRPad(s1, s2, s3), null, row3) + checkEvaluation(StringRPad(s1, s2, s3), null, row4) + checkEvaluation(StringRPad(s1, s2, s3), null, row5) } test("REPEAT") { @@ -416,6 +462,41 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(StringSpace(s1), null, row2) } + test("RegexReplace") { + val row1 = create_row("100-200", "(\\d+)", "num") + val row2 = create_row("100-200", "(\\d+)", "###") + val row3 = create_row("100-200", "(-)", "###") + + val s = 's.string.at(0) + val p = 'p.string.at(1) + val r = 'r.string.at(2) + + val expr = RegExpReplace(s, p, r) + checkEvaluation(expr, "num-num", row1) + checkEvaluation(expr, "###-###", row2) + checkEvaluation(expr, "100###200", row3) + } + + test("RegexExtract") { + val row1 = create_row("100-200", "(\\d+)-(\\d+)", 1) + val row2 = create_row("100-200", "(\\d+)-(\\d+)", 2) + val row3 = create_row("100-200", "(\\d+).*", 1) + val row4 = create_row("100-200", "([a-z])", 1) + + val s = 's.string.at(0) + val p = 'p.string.at(1) + val r = 'r.int.at(2) + + val expr = RegExpExtract(s, p, r) + checkEvaluation(expr, "100", row1) + checkEvaluation(expr, "200", row2) + checkEvaluation(expr, "100", row3) + checkEvaluation(expr, "", row4) // will not match anything, empty string get + + val expr1 = new RegExpExtract(s, p) + checkEvaluation(expr1, "100", row1) + } + test("SPLIT") { val s1 = 'a.string.at(0) val s2 = 'b.string.at(1) @@ -426,4 +507,46 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation( StringSplit(s1, s2), Seq("aa", "bb", "cc"), row1) } + + test("length for string / binary") { + val a = 'a.string.at(0) + val b = 'b.binary.at(0) + val bytes = Array[Byte](1, 2, 3, 1, 2) + val string = "abdef" + + // scalastyle:off + // non ascii characters are not allowed in the source code, so we disable the 
scalastyle. + checkEvaluation(Length(Literal("a花花c")), 4, create_row(string)) + // scalastyle:on + checkEvaluation(Length(Literal(bytes)), 5, create_row(Array[Byte]())) + + checkEvaluation(Length(a), 5, create_row(string)) + checkEvaluation(Length(b), 5, create_row(bytes)) + + checkEvaluation(Length(a), 0, create_row("")) + checkEvaluation(Length(b), 0, create_row(Array[Byte]())) + + checkEvaluation(Length(a), null, create_row(null)) + checkEvaluation(Length(b), null, create_row(null)) + + checkEvaluation(Length(Literal.create(null, StringType)), null, create_row(string)) + checkEvaluation(Length(Literal.create(null, BinaryType)), null, create_row(bytes)) + } + + test("number format") { + checkEvaluation(FormatNumber(Literal(4.asInstanceOf[Byte]), Literal(3)), "4.000") + checkEvaluation(FormatNumber(Literal(4.asInstanceOf[Short]), Literal(3)), "4.000") + checkEvaluation(FormatNumber(Literal(4.0f), Literal(3)), "4.000") + checkEvaluation(FormatNumber(Literal(4), Literal(3)), "4.000") + checkEvaluation(FormatNumber(Literal(12831273.23481d), Literal(3)), "12,831,273.235") + checkEvaluation(FormatNumber(Literal(12831273.83421d), Literal(0)), "12,831,274") + checkEvaluation(FormatNumber(Literal(123123324123L), Literal(3)), "123,123,324,123.000") + checkEvaluation(FormatNumber(Literal(123123324123L), Literal(-1)), null) + checkEvaluation( + FormatNumber( + Literal(Decimal(123123324123L) * Decimal(123123.21234d)), Literal(4)), + "15,159,339,180,002,773.2778") + checkEvaluation(FormatNumber(Literal.create(null, IntegerType), Literal(3)), null) + checkEvaluation(FormatNumber(Literal.create(null, NullType), Literal(3)), null) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala index c9667e90a0aaa..7566cb59e34ee 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala @@ -24,9 +24,8 @@ import org.scalatest.{BeforeAndAfterEach, Matchers} import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateProjection import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} +import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, TaskMemoryManager, MemoryAllocator} import org.apache.spark.unsafe.types.UTF8String @@ -35,10 +34,10 @@ class UnsafeFixedWidthAggregationMapSuite with Matchers with BeforeAndAfterEach { + import UnsafeFixedWidthAggregationMap._ + private val groupKeySchema = StructType(StructField("product", StringType) :: Nil) private val aggBufferSchema = StructType(StructField("salePrice", IntegerType) :: Nil) - private def emptyProjection: Projection = - GenerateProjection.generate(Seq(Literal(0)), Seq(AttributeReference("price", IntegerType)())) private def emptyAggregationBuffer: InternalRow = InternalRow(0) private var memoryManager: TaskMemoryManager = null @@ -54,11 +53,21 @@ class UnsafeFixedWidthAggregationMapSuite } } + test("supported schemas") { + assert(!supportsAggregationBufferSchema(StructType(StructField("x", StringType) :: Nil))) + assert(supportsGroupKeySchema(StructType(StructField("x", StringType) :: Nil))) + + assert( + 
!supportsAggregationBufferSchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil))) + assert( + !supportsGroupKeySchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil))) + } + test("empty map") { val map = new UnsafeFixedWidthAggregationMap( - emptyProjection, - new UnsafeRowConverter(groupKeySchema), - new UnsafeRowConverter(aggBufferSchema), + emptyAggregationBuffer, + aggBufferSchema, + groupKeySchema, memoryManager, 1024, // initial capacity false // disable perf metrics @@ -69,9 +78,9 @@ class UnsafeFixedWidthAggregationMapSuite test("updating values for a single key") { val map = new UnsafeFixedWidthAggregationMap( - emptyProjection, - new UnsafeRowConverter(groupKeySchema), - new UnsafeRowConverter(aggBufferSchema), + emptyAggregationBuffer, + aggBufferSchema, + groupKeySchema, memoryManager, 1024, // initial capacity false // disable perf metrics @@ -95,9 +104,9 @@ class UnsafeFixedWidthAggregationMapSuite test("inserting large random keys") { val map = new UnsafeFixedWidthAggregationMap( - emptyProjection, - new UnsafeRowConverter(groupKeySchema), - new UnsafeRowConverter(aggBufferSchema), + emptyAggregationBuffer, + aggBufferSchema, + groupKeySchema, memoryManager, 128, // initial capacity false // disable perf metrics @@ -112,36 +121,6 @@ class UnsafeFixedWidthAggregationMapSuite }.toSet seenKeys.size should be (groupKeys.size) seenKeys should be (groupKeys) - - map.free() - } - - test("with decimal in the key and values") { - val groupKeySchema = StructType(StructField("price", DecimalType(10, 0)) :: Nil) - val aggBufferSchema = StructType(StructField("amount", DecimalType.Unlimited) :: Nil) - val emptyProjection = GenerateProjection.generate(Seq(Literal(Decimal(0))), - Seq(AttributeReference("price", DecimalType.Unlimited)())) - val map = new UnsafeFixedWidthAggregationMap( - emptyProjection, - new UnsafeRowConverter(groupKeySchema), - new UnsafeRowConverter(aggBufferSchema), - memoryManager, - 1, // initial capacity - false // disable perf metrics - ) - - (0 until 100).foreach { i => - val groupKey = InternalRow(Decimal(i % 10)) - val row = map.getAggregationBuffer(groupKey) - row.update(0, Decimal(i)) - } - val seenKeys: Set[Int] = map.iterator().asScala.map { entry => - entry.key.getAs[Decimal](0).toInt - }.toSet - seenKeys.size should be (10) - seenKeys should be ((0 until 10).toSet) - - map.free() } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index d00aeb4dfbf47..8819234e78e60 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -24,7 +24,7 @@ import org.scalatest.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.{ObjectPool, DateTimeUtils} +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.PlatformDependent import org.apache.spark.unsafe.array.ByteArrayMethods @@ -45,12 +45,11 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(sizeRequired === 8 + (3 * 8)) val buffer: Array[Long] = new Array[Long](sizeRequired / 8) val numBytesWritten = - converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired, null) + 
converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired) assert(numBytesWritten === sizeRequired) val unsafeRow = new UnsafeRow() - unsafeRow.pointTo( - buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired, null) + unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired) assert(unsafeRow.getLong(0) === 0) assert(unsafeRow.getLong(1) === 1) assert(unsafeRow.getInt(2) === 2) @@ -87,67 +86,15 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { ByteArrayMethods.roundNumberOfBytesToNearestWord("World".getBytes.length)) val buffer: Array[Long] = new Array[Long](sizeRequired / 8) val numBytesWritten = converter.writeRow( - row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired, null) + row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired) assert(numBytesWritten === sizeRequired) val unsafeRow = new UnsafeRow() - val pool = new ObjectPool(10) unsafeRow.pointTo( - buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired, pool) + buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired) assert(unsafeRow.getLong(0) === 0) assert(unsafeRow.getString(1) === "Hello") - assert(unsafeRow.get(2) === "World".getBytes) - - unsafeRow.update(1, UTF8String.fromString("World")) - assert(unsafeRow.getString(1) === "World") - assert(pool.size === 0) - unsafeRow.update(1, UTF8String.fromString("Hello World")) - assert(unsafeRow.getString(1) === "Hello World") - assert(pool.size === 1) - - unsafeRow.update(2, "World".getBytes) - assert(unsafeRow.get(2) === "World".getBytes) - assert(pool.size === 1) - unsafeRow.update(2, "Hello World".getBytes) - assert(unsafeRow.get(2) === "Hello World".getBytes) - assert(pool.size === 2) - - // We do not support copy() for UnsafeRows that reference ObjectPools - intercept[UnsupportedOperationException] { - unsafeRow.copy() - } - } - - test("basic conversion with primitive, decimal and array") { - val fieldTypes: Array[DataType] = Array(LongType, DecimalType(10, 0), ArrayType(StringType)) - val converter = new UnsafeRowConverter(fieldTypes) - - val row = new SpecificMutableRow(fieldTypes) - row.setLong(0, 0) - row.update(1, Decimal(1)) - row.update(2, Array(2)) - - val pool = new ObjectPool(10) - val sizeRequired: Int = converter.getSizeRequirement(row) - assert(sizeRequired === 8 + (8 * 3)) - val buffer: Array[Long] = new Array[Long](sizeRequired / 8) - val numBytesWritten = - converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired, pool) - assert(numBytesWritten === sizeRequired) - assert(pool.size === 2) - - val unsafeRow = new UnsafeRow() - unsafeRow.pointTo( - buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired, pool) - assert(unsafeRow.getLong(0) === 0) - assert(unsafeRow.get(1) === Decimal(1)) - assert(unsafeRow.get(2) === Array(2)) - - unsafeRow.update(1, Decimal(2)) - assert(unsafeRow.get(1) === Decimal(2)) - unsafeRow.update(2, Array(3, 4)) - assert(unsafeRow.get(2) === Array(3, 4)) - assert(pool.size === 2) + assert(unsafeRow.getBinary(2) === "World".getBytes) } test("basic conversion with primitive, string, date and timestamp types") { @@ -165,25 +112,25 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length)) val buffer: Array[Long] = new Array[Long](sizeRequired / 8) val numBytesWritten = - converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, 
sizeRequired, null) + converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired) assert(numBytesWritten === sizeRequired) val unsafeRow = new UnsafeRow() unsafeRow.pointTo( - buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired, null) + buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired) assert(unsafeRow.getLong(0) === 0) assert(unsafeRow.getString(1) === "Hello") // Date is represented as Int in unsafeRow assert(DateTimeUtils.toJavaDate(unsafeRow.getInt(2)) === Date.valueOf("1970-01-01")) // Timestamp is represented as Long in unsafeRow DateTimeUtils.toJavaTimestamp(unsafeRow.getLong(3)) should be - (Timestamp.valueOf("2015-05-08 08:10:25")) + (Timestamp.valueOf("2015-05-08 08:10:25")) unsafeRow.setInt(2, DateTimeUtils.fromJavaDate(Date.valueOf("2015-06-22"))) assert(DateTimeUtils.toJavaDate(unsafeRow.getInt(2)) === Date.valueOf("2015-06-22")) unsafeRow.setLong(3, DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2015-06-22 08:10:25"))) DateTimeUtils.toJavaTimestamp(unsafeRow.getLong(3)) should be - (Timestamp.valueOf("2015-06-22 08:10:25")) + (Timestamp.valueOf("2015-06-22 08:10:25")) } test("null handling") { @@ -197,9 +144,9 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { FloatType, DoubleType, StringType, - BinaryType, - DecimalType.Unlimited, - ArrayType(IntegerType) + BinaryType + // DecimalType.Unlimited, + // ArrayType(IntegerType) ) val converter = new UnsafeRowConverter(fieldTypes) @@ -215,14 +162,13 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { val createdFromNullBuffer: Array[Long] = new Array[Long](sizeRequired / 8) val numBytesWritten = converter.writeRow( rowWithAllNullColumns, createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, - sizeRequired, null) + sizeRequired) assert(numBytesWritten === sizeRequired) val createdFromNull = new UnsafeRow() createdFromNull.pointTo( - createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, - sizeRequired, null) - for (i <- 0 to fieldTypes.length - 1) { + createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired) + for (i <- fieldTypes.indices) { assert(createdFromNull.isNullAt(i)) } assert(createdFromNull.getBoolean(1) === false) @@ -232,10 +178,10 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(createdFromNull.getLong(5) === 0) assert(java.lang.Float.isNaN(createdFromNull.getFloat(6))) assert(java.lang.Double.isNaN(createdFromNull.getDouble(7))) - assert(createdFromNull.getString(8) === null) - assert(createdFromNull.get(9) === null) - assert(createdFromNull.get(10) === null) - assert(createdFromNull.get(11) === null) + assert(createdFromNull.getUTF8String(8) === null) + assert(createdFromNull.getBinary(9) === null) + // assert(createdFromNull.get(10) === null) + // assert(createdFromNull.get(11) === null) // If we have an UnsafeRow with columns that are initially non-null and we null out those // columns, then the serialized row representation should be identical to what we would get by @@ -252,19 +198,18 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { r.setDouble(7, 700) r.update(8, UTF8String.fromString("hello")) r.update(9, "world".getBytes) - r.update(10, Decimal(10)) - r.update(11, Array(11)) + // r.update(10, Decimal(10)) + // r.update(11, Array(11)) r } - val pool = new ObjectPool(1) val setToNullAfterCreationBuffer: Array[Long] = new Array[Long](sizeRequired / 8 + 2) converter.writeRow( 
rowWithNoNullColumns, setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, - sizeRequired, pool) + sizeRequired) val setToNullAfterCreation = new UnsafeRow() setToNullAfterCreation.pointTo( setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, - sizeRequired, pool) + sizeRequired) assert(setToNullAfterCreation.isNullAt(0) === rowWithNoNullColumns.isNullAt(0)) assert(setToNullAfterCreation.getBoolean(1) === rowWithNoNullColumns.getBoolean(1)) @@ -275,14 +220,11 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(setToNullAfterCreation.getFloat(6) === rowWithNoNullColumns.getFloat(6)) assert(setToNullAfterCreation.getDouble(7) === rowWithNoNullColumns.getDouble(7)) assert(setToNullAfterCreation.getString(8) === rowWithNoNullColumns.getString(8)) - assert(setToNullAfterCreation.get(9) === rowWithNoNullColumns.get(9)) - assert(setToNullAfterCreation.get(10) === rowWithNoNullColumns.get(10)) - assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11)) + assert(setToNullAfterCreation.getBinary(9) === rowWithNoNullColumns.get(9)) + // assert(setToNullAfterCreation.get(10) === rowWithNoNullColumns.get(10)) + // assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11)) - for (i <- 0 to fieldTypes.length - 1) { - if (i >= 8) { - setToNullAfterCreation.update(i, null) - } + for (i <- fieldTypes.indices) { setToNullAfterCreation.setNullAt(i) } // There are some garbage left in the var-length area @@ -297,10 +239,10 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { setToNullAfterCreation.setLong(5, 500) setToNullAfterCreation.setFloat(6, 600) setToNullAfterCreation.setDouble(7, 700) - setToNullAfterCreation.update(8, UTF8String.fromString("hello")) - setToNullAfterCreation.update(9, "world".getBytes) - setToNullAfterCreation.update(10, Decimal(10)) - setToNullAfterCreation.update(11, Array(11)) + // setToNullAfterCreation.update(8, UTF8String.fromString("hello")) + // setToNullAfterCreation.update(9, "world".getBytes) + // setToNullAfterCreation.update(10, Decimal(10)) + // setToNullAfterCreation.update(11, Array(11)) assert(setToNullAfterCreation.isNullAt(0) === rowWithNoNullColumns.isNullAt(0)) assert(setToNullAfterCreation.getBoolean(1) === rowWithNoNullColumns.getBoolean(1)) @@ -310,10 +252,29 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(setToNullAfterCreation.getLong(5) === rowWithNoNullColumns.getLong(5)) assert(setToNullAfterCreation.getFloat(6) === rowWithNoNullColumns.getFloat(6)) assert(setToNullAfterCreation.getDouble(7) === rowWithNoNullColumns.getDouble(7)) - assert(setToNullAfterCreation.getString(8) === rowWithNoNullColumns.getString(8)) - assert(setToNullAfterCreation.get(9) === rowWithNoNullColumns.get(9)) - assert(setToNullAfterCreation.get(10) === rowWithNoNullColumns.get(10)) - assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11)) + // assert(setToNullAfterCreation.getString(8) === rowWithNoNullColumns.getString(8)) + // assert(setToNullAfterCreation.get(9) === rowWithNoNullColumns.get(9)) + // assert(setToNullAfterCreation.get(10) === rowWithNoNullColumns.get(10)) + // assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11)) } + test("NaN canonicalization") { + val fieldTypes: Array[DataType] = Array(FloatType, DoubleType) + + val row1 = new SpecificMutableRow(fieldTypes) + row1.setFloat(0, java.lang.Float.intBitsToFloat(0x7f800001)) + row1.setDouble(1, 
java.lang.Double.longBitsToDouble(0x7ff0000000000001L)) + + val row2 = new SpecificMutableRow(fieldTypes) + row2.setFloat(0, java.lang.Float.intBitsToFloat(0x7fffffff)) + row2.setDouble(1, java.lang.Double.longBitsToDouble(0x7fffffffffffffffL)) + + val converter = new UnsafeRowConverter(fieldTypes) + val row1Buffer = new Array[Byte](converter.getSizeRequirement(row1)) + val row2Buffer = new Array[Byte](converter.getSizeRequirement(row2)) + converter.writeRow(row1, row1Buffer, PlatformDependent.BYTE_ARRAY_OFFSET, row1Buffer.length) + converter.writeRow(row2, row2Buffer, PlatformDependent.BYTE_ARRAY_OFFSET, row2Buffer.length) + + assert(row1Buffer.toSeq === row2Buffer.toSeq) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala index 465a5e6914204..d4916ea8d273a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.analysis.{AnalysisSuite, EliminateSubQueries} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.PlanTest @@ -40,29 +40,11 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.string) - // The `foldLeft` is required to handle cases like comparing `a && (b && c)` and `(a && b) && c` - def compareConditions(e1: Expression, e2: Expression): Boolean = (e1, e2) match { - case (lhs: And, rhs: And) => - val lhsSet = splitConjunctivePredicates(lhs).toSet - val rhsSet = splitConjunctivePredicates(rhs).toSet - lhsSet.foldLeft(rhsSet) { (set, e) => - set.find(compareConditions(_, e)).map(set - _).getOrElse(set) - }.isEmpty - - case (lhs: Or, rhs: Or) => - val lhsSet = splitDisjunctivePredicates(lhs).toSet - val rhsSet = splitDisjunctivePredicates(rhs).toSet - lhsSet.foldLeft(rhsSet) { (set, e) => - set.find(compareConditions(_, e)).map(set - _).getOrElse(set) - }.isEmpty - - case (l, r) => l == r - } - - def checkCondition(input: Expression, expected: Expression): Unit = { + private def checkCondition(input: Expression, expected: Expression): Unit = { val plan = testRelation.where(input).analyze - val actual = Optimize.execute(plan).expressions.head - compareConditions(actual, expected) + val actual = Optimize.execute(plan) + val correctAnswer = testRelation.where(expected).analyze + comparePlans(actual, correctAnswer) } test("a && a => a") { @@ -86,10 +68,8 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { ('a === 'b && 'c < 1 && 'a === 5) || ('a === 'b && 'b < 5 && 'a > 1) - val expected = - (((('b > 3) && ('c > 2)) || - (('c < 1) && ('a === 5))) || - (('b < 5) && ('a > 1))) && ('a === 'b) + val expected = 'a === 'b && ( + ('b > 3 && 'c > 2) || ('c < 1 && 'a === 5) || ('b < 5 && 'a > 1)) checkCondition(input, expected) } @@ -101,10 +81,27 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { checkCondition('a < 2 && ('a < 2 || 'a > 3 || 'b > 5) , 'a < 2) - checkCondition(('a < 2 || 'b > 3) && ('a < 2 || 'c > 5), ('b > 3 && 'c > 5) || 'a < 2) + checkCondition(('a < 2 || 
'b > 3) && ('a < 2 || 'c > 5), 'a < 2 || ('b > 3 && 'c > 5)) checkCondition( ('a === 'b || 'b > 3) && ('a === 'b || 'a > 3) && ('a === 'b || 'a < 5), - ('b > 3 && 'a > 3 && 'a < 5) || 'a === 'b) + ('a === 'b || 'b > 3 && 'a > 3 && 'a < 5)) + } + + private def caseInsensitiveAnalyse(plan: LogicalPlan) = + AnalysisSuite.caseInsensitiveAnalyzer.execute(plan) + + test("(a && b) || (a && c) => a && (b || c) when case insensitive") { + val plan = caseInsensitiveAnalyse(testRelation.where(('a > 2 && 'b > 3) || ('A > 2 && 'b < 5))) + val actual = Optimize.execute(plan) + val expected = caseInsensitiveAnalyse(testRelation.where('a > 2 && ('b > 3 || 'b < 5))) + comparePlans(actual, expected) + } + + test("(a || b) && (a || c) => a || (b && c) when case insensitive") { + val plan = caseInsensitiveAnalyse(testRelation.where(('a > 2 || 'b > 3) && ('A > 2 || 'b < 5))) + val actual = Optimize.execute(plan) + val expected = caseInsensitiveAnalyse(testRelation.where('a > 2 || ('b > 3 && 'b < 5))) + comparePlans(actual, expected) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index dc28b3ffb59ee..0f1fde2fb0f67 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries -import org.apache.spark.sql.catalyst.expressions.{SortOrder, Ascending, Count, Explode} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{LeftSemi, PlanTest, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.rules._ @@ -146,6 +146,49 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("nondeterministic: can't push down filter through project") { + val originalQuery = testRelation + .select(Rand(10).as('rand), 'a) + .where('rand > 5 || 'a > 5) + .analyze + + val optimized = Optimize.execute(originalQuery) + + comparePlans(optimized, originalQuery) + } + + test("nondeterministic: push down part of filter through project") { + val originalQuery = testRelation + .select(Rand(10).as('rand), 'a) + .where('rand > 5 && 'a > 5) + .analyze + + val optimized = Optimize.execute(originalQuery) + + val correctAnswer = testRelation + .where('a > 5) + .select(Rand(10).as('rand), 'a) + .where('rand > 5) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("nondeterministic: push down filter through project") { + val originalQuery = testRelation + .select(Rand(10).as('rand), 'a) + .where('a > 5 && 'a < 10) + .analyze + + val optimized = Optimize.execute(originalQuery) + val correctAnswer = testRelation + .where('a > 5 && 'a < 10) + .select(Rand(10).as('rand), 'a) + .analyze + + comparePlans(optimized, correctAnswer) + } + test("filters: combines filters") { val originalQuery = testRelation .select('a) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala index 151654bffbd66..1aa89991cc698 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala @@ -70,4 +70,30 @@ class ProjectCollapsingSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("collapse two nondeterministic, independent projects into one") { + val query = testRelation + .select(Rand(10).as('rand)) + .select(Rand(20).as('rand2)) + + val optimized = Optimize.execute(query.analyze) + + val correctAnswer = testRelation + .select(Rand(20).as('rand2)).analyze + + comparePlans(optimized, correctAnswer) + } + + test("collapse one nondeterministic, one deterministic, independent projects into one") { + val query = testRelation + .select(Rand(10).as('rand), 'a) + .select(('a + 1).as('a_plus_1)) + + val optimized = Optimize.execute(query.analyze) + + val correctAnswer = testRelation + .select(('a + 1).as('a_plus_1)).analyze + + comparePlans(optimized, correctAnswer) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala new file mode 100644 index 0000000000000..49c979bc7d72c --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.dsl.expressions._ + +class SetOperationPushDownSuite extends PlanTest { + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Subqueries", Once, + EliminateSubQueries) :: + Batch("Union Pushdown", Once, + SetOperationPushDown) :: Nil + } + + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int) + val testUnion = Union(testRelation, testRelation2) + val testIntersect = Intersect(testRelation, testRelation2) + val testExcept = Except(testRelation, testRelation2) + + test("union/intersect/except: filter to each side") { + val unionQuery = testUnion.where('a === 1) + val intersectQuery = testIntersect.where('b < 10) + val exceptQuery = testExcept.where('c >= 5) + + val unionOptimized = Optimize.execute(unionQuery.analyze) + val intersectOptimized = Optimize.execute(intersectQuery.analyze) + val exceptOptimized = Optimize.execute(exceptQuery.analyze) + + val unionCorrectAnswer = + Union(testRelation.where('a === 1), testRelation2.where('d === 1)).analyze + val intersectCorrectAnswer = + Intersect(testRelation.where('b < 10), testRelation2.where('e < 10)).analyze + val exceptCorrectAnswer = + Except(testRelation.where('c >= 5), testRelation2.where('f >= 5)).analyze + + comparePlans(unionOptimized, unionCorrectAnswer) + comparePlans(intersectOptimized, intersectCorrectAnswer) + comparePlans(exceptOptimized, exceptCorrectAnswer) + } + + test("union/intersect/except: project to each side") { + val unionQuery = testUnion.select('a) + val intersectQuery = testIntersect.select('b, 'c) + val exceptQuery = testExcept.select('a, 'b, 'c) + + val unionOptimized = Optimize.execute(unionQuery.analyze) + val intersectOptimized = Optimize.execute(intersectQuery.analyze) + val exceptOptimized = Optimize.execute(exceptQuery.analyze) + + val unionCorrectAnswer = + Union(testRelation.select('a), testRelation2.select('d)).analyze + val intersectCorrectAnswer = + Intersect(testRelation.select('b, 'c), testRelation2.select('e, 'f)).analyze + val exceptCorrectAnswer = + Except(testRelation.select('a, 'b, 'c), testRelation2.select('d, 'e, 'f)).analyze + + comparePlans(unionOptimized, unionCorrectAnswer) + comparePlans(intersectOptimized, intersectCorrectAnswer) + comparePlans(exceptOptimized, exceptCorrectAnswer) } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala deleted file mode 100644 index ec379489a6d1e..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.optimizer - -import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries -import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.dsl.expressions._ - -class UnionPushDownSuite extends PlanTest { - object Optimize extends RuleExecutor[LogicalPlan] { - val batches = - Batch("Subqueries", Once, - EliminateSubQueries) :: - Batch("Union Pushdown", Once, - UnionPushDown) :: Nil - } - - val testRelation = LocalRelation('a.int, 'b.int, 'c.int) - val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int) - val testUnion = Union(testRelation, testRelation2) - - test("union: filter to each side") { - val query = testUnion.where('a === 1) - - val optimized = Optimize.execute(query.analyze) - - val correctAnswer = - Union(testRelation.where('a === 1), testRelation2.where('d === 1)).analyze - - comparePlans(optimized, correctAnswer) - } - - test("union: project to each side") { - val query = testUnion.select('b) - - val optimized = Optimize.execute(query.analyze) - - val correctAnswer = - Union(testRelation.select('b), testRelation2.select('e)).analyze - - comparePlans(optimized, correctAnswer) - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 1bd7d4e5cdf0f..8fff39906b342 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -22,9 +22,10 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.types.{IntegerType, StringType, NullType} -case class Dummy(optKey: Option[Expression]) extends Expression { +case class Dummy(optKey: Option[Expression]) extends Expression with CodegenFallback { override def children: Seq[Expression] = optKey.toSeq override def nullable: Boolean = true override def dataType: NullType = NullType diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index f63ac191e7366..fab9eb9cd4c9f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -19,11 +19,18 @@ package org.apache.spark.sql.catalyst.util import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat +import java.util.{TimeZone, Calendar} import org.apache.spark.SparkFunSuite +import org.apache.spark.unsafe.types.UTF8String class DateTimeUtilsSuite extends SparkFunSuite { + private[this] def 
getInUTCDays(timestamp: Long): Int = { + val tz = TimeZone.getDefault + ((timestamp + tz.getOffset(timestamp)) / DateTimeUtils.MILLIS_PER_DAY).toInt + } + test("timestamp and us") { val now = new Timestamp(System.currentTimeMillis()) now.setNanos(1000) @@ -86,4 +93,272 @@ class DateTimeUtilsSuite extends SparkFunSuite { checkFromToJavaDate(new Date(df1.parse("1776-07-04 10:30:00").getTime)) checkFromToJavaDate(new Date(df2.parse("1776-07-04 18:30:00 UTC").getTime)) } + + test("string to date") { + import DateTimeUtils.millisToDays + + var c = Calendar.getInstance() + c.set(2015, 0, 28, 0, 0, 0) + c.set(Calendar.MILLISECOND, 0) + assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-01-28")).get === + millisToDays(c.getTimeInMillis)) + c.set(2015, 0, 1, 0, 0, 0) + c.set(Calendar.MILLISECOND, 0) + assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015")).get === + millisToDays(c.getTimeInMillis)) + c = Calendar.getInstance() + c.set(2015, 2, 1, 0, 0, 0) + c.set(Calendar.MILLISECOND, 0) + assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-03")).get === + millisToDays(c.getTimeInMillis)) + c = Calendar.getInstance() + c.set(2015, 2, 18, 0, 0, 0) + c.set(Calendar.MILLISECOND, 0) + assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-03-18")).get === + millisToDays(c.getTimeInMillis)) + assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-03-18 ")).get === + millisToDays(c.getTimeInMillis)) + assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-03-18 123142")).get === + millisToDays(c.getTimeInMillis)) + assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-03-18T123123")).get === + millisToDays(c.getTimeInMillis)) + assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-03-18T")).get === + millisToDays(c.getTimeInMillis)) + + assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-03-18X")).isEmpty) + assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015/03/18")).isEmpty) + assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015.03.18")).isEmpty) + assert(DateTimeUtils.stringToDate(UTF8String.fromString("20150318")).isEmpty) + assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-031-8")).isEmpty) + } + + test("string to timestamp") { + var c = Calendar.getInstance() + c.set(1969, 11, 31, 16, 0, 0) + c.set(Calendar.MILLISECOND, 0) + assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("1969-12-31 16:00:00")).get === + c.getTimeInMillis * 1000) + c.set(2015, 0, 1, 0, 0, 0) + c.set(Calendar.MILLISECOND, 0) + assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015")).get === + c.getTimeInMillis * 1000) + c = Calendar.getInstance() + c.set(2015, 2, 1, 0, 0, 0) + c.set(Calendar.MILLISECOND, 0) + assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03")).get === + c.getTimeInMillis * 1000) + c = Calendar.getInstance() + c.set(2015, 2, 18, 0, 0, 0) + c.set(Calendar.MILLISECOND, 0) + assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18")).get === + c.getTimeInMillis * 1000) + assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18 ")).get === + c.getTimeInMillis * 1000) + assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18T")).get === + c.getTimeInMillis * 1000) + + c = Calendar.getInstance() + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 0) + assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18 12:03:17")).get === + c.getTimeInMillis * 1000) + 
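// Aside: a minimal, self-contained sketch (not part of the patch) of why every expected
// value in this test is `c.getTimeInMillis * 1000` -- the expected timestamps here are
// microseconds since the epoch, while java.util.Calendar counts milliseconds. Only
// java.util is used; the instant matches the 2015-03-18T12:03:17Z case exercised below.
{
  import java.util.{Calendar, TimeZone}
  val c = Calendar.getInstance(TimeZone.getTimeZone("UTC"))
  c.set(2015, 2, 18, 12, 3, 17)          // month is 0-based, so 2 == March
  c.set(Calendar.MILLISECOND, 0)         // set() does not clear the millisecond field
  val micros: Long = c.getTimeInMillis * 1000L
  assert(micros == 1426680197000000L)    // 2015-03-18T12:03:17Z expressed in microseconds
}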
assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18T12:03:17")).get === + c.getTimeInMillis * 1000) + + c = Calendar.getInstance(TimeZone.getTimeZone("GMT-13:53")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 0) + assert(DateTimeUtils.stringToTimestamp( + UTF8String.fromString("2015-03-18T12:03:17-13:53")).get === c.getTimeInMillis * 1000) + + c = Calendar.getInstance(TimeZone.getTimeZone("UTC")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 0) + assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18T12:03:17Z")).get === + c.getTimeInMillis * 1000) + assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18 12:03:17Z")).get === + c.getTimeInMillis * 1000) + + c = Calendar.getInstance(TimeZone.getTimeZone("GMT-01:00")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 0) + assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18T12:03:17-1:0")).get === + c.getTimeInMillis * 1000) + assert(DateTimeUtils.stringToTimestamp( + UTF8String.fromString("2015-03-18T12:03:17-01:00")).get === c.getTimeInMillis * 1000) + + c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 0) + assert(DateTimeUtils.stringToTimestamp( + UTF8String.fromString("2015-03-18T12:03:17+07:30")).get === c.getTimeInMillis * 1000) + + c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:03")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 0) + assert(DateTimeUtils.stringToTimestamp( + UTF8String.fromString("2015-03-18T12:03:17+07:03")).get === c.getTimeInMillis * 1000) + + c = Calendar.getInstance() + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 123) + assert(DateTimeUtils.stringToTimestamp( + UTF8String.fromString("2015-03-18 12:03:17.123")).get === c.getTimeInMillis * 1000) + assert(DateTimeUtils.stringToTimestamp( + UTF8String.fromString("2015-03-18T12:03:17.123")).get === c.getTimeInMillis * 1000) + + c = Calendar.getInstance(TimeZone.getTimeZone("UTC")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 456) + assert(DateTimeUtils.stringToTimestamp( + UTF8String.fromString("2015-03-18T12:03:17.456Z")).get === c.getTimeInMillis * 1000) + assert(DateTimeUtils.stringToTimestamp( + UTF8String.fromString("2015-03-18 12:03:17.456Z")).get === c.getTimeInMillis * 1000) + + c = Calendar.getInstance(TimeZone.getTimeZone("GMT-01:00")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 123) + assert(DateTimeUtils.stringToTimestamp( + UTF8String.fromString("2015-03-18T12:03:17.123-1:0")).get === c.getTimeInMillis * 1000) + assert(DateTimeUtils.stringToTimestamp( + UTF8String.fromString("2015-03-18T12:03:17.123-01:00")).get === c.getTimeInMillis * 1000) + + c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 123) + assert(DateTimeUtils.stringToTimestamp( + UTF8String.fromString("2015-03-18T12:03:17.123+07:30")).get === c.getTimeInMillis * 1000) + + c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 123) + assert(DateTimeUtils.stringToTimestamp( + UTF8String.fromString("2015-03-18T12:03:17.123+07:30")).get === c.getTimeInMillis * 1000) + + c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 123) + assert(DateTimeUtils.stringToTimestamp( + 
UTF8String.fromString("2015-03-18T12:03:17.123121+7:30")).get === + c.getTimeInMillis * 1000 + 121) + + c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 123) + assert(DateTimeUtils.stringToTimestamp( + UTF8String.fromString("2015-03-18T12:03:17.12312+7:30")).get === + c.getTimeInMillis * 1000 + 120) + + c = Calendar.getInstance() + c.set(Calendar.HOUR_OF_DAY, 18) + c.set(Calendar.MINUTE, 12) + c.set(Calendar.SECOND, 15) + c.set(Calendar.MILLISECOND, 0) + assert(DateTimeUtils.stringToTimestamp( + UTF8String.fromString("18:12:15")).get === + c.getTimeInMillis * 1000) + + c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30")) + c.set(Calendar.HOUR_OF_DAY, 18) + c.set(Calendar.MINUTE, 12) + c.set(Calendar.SECOND, 15) + c.set(Calendar.MILLISECOND, 123) + assert(DateTimeUtils.stringToTimestamp( + UTF8String.fromString("T18:12:15.12312+7:30")).get === + c.getTimeInMillis * 1000 + 120) + + c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30")) + c.set(Calendar.HOUR_OF_DAY, 18) + c.set(Calendar.MINUTE, 12) + c.set(Calendar.SECOND, 15) + c.set(Calendar.MILLISECOND, 123) + assert(DateTimeUtils.stringToTimestamp( + UTF8String.fromString("18:12:15.12312+7:30")).get === + c.getTimeInMillis * 1000 + 120) + + c = Calendar.getInstance() + c.set(2011, 4, 6, 7, 8, 9) + c.set(Calendar.MILLISECOND, 100) + assert(DateTimeUtils.stringToTimestamp( + UTF8String.fromString("2011-05-06 07:08:09.1000")).get === c.getTimeInMillis * 1000) + + assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("238")).isEmpty) + assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18 123142")).isEmpty) + assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18T123123")).isEmpty) + assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18X")).isEmpty) + assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015/03/18")).isEmpty) + assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015.03.18")).isEmpty) + assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("20150318")).isEmpty) + assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-031-8")).isEmpty) + assert(DateTimeUtils.stringToTimestamp( + UTF8String.fromString("2015-03-18T12:03.17-20:0")).isEmpty) + assert(DateTimeUtils.stringToTimestamp( + UTF8String.fromString("2015-03-18T12:03.17-0:70")).isEmpty) + assert(DateTimeUtils.stringToTimestamp( + UTF8String.fromString("2015-03-18T12:03.17-1:0:0")).isEmpty) + } + + test("hours") { + val c = Calendar.getInstance() + c.set(2015, 2, 18, 13, 2, 11) + assert(DateTimeUtils.getHours(c.getTimeInMillis * 1000) === 13) + c.set(2015, 12, 8, 2, 7, 9) + assert(DateTimeUtils.getHours(c.getTimeInMillis * 1000) === 2) + } + + test("minutes") { + val c = Calendar.getInstance() + c.set(2015, 2, 18, 13, 2, 11) + assert(DateTimeUtils.getMinutes(c.getTimeInMillis * 1000) === 2) + c.set(2015, 2, 8, 2, 7, 9) + assert(DateTimeUtils.getMinutes(c.getTimeInMillis * 1000) === 7) + } + + test("seconds") { + val c = Calendar.getInstance() + c.set(2015, 2, 18, 13, 2, 11) + assert(DateTimeUtils.getSeconds(c.getTimeInMillis * 1000) === 11) + c.set(2015, 2, 8, 2, 7, 9) + assert(DateTimeUtils.getSeconds(c.getTimeInMillis * 1000) === 9) + } + + test("get day in year") { + val c = Calendar.getInstance() + c.set(2015, 2, 18, 0, 0, 0) + assert(DateTimeUtils.getDayInYear(getInUTCDays(c.getTimeInMillis)) === 77) + c.set(2012, 2, 18, 0, 0, 0) + 
assert(DateTimeUtils.getDayInYear(getInUTCDays(c.getTimeInMillis)) === 78) + } + + test("get year") { + val c = Calendar.getInstance() + c.set(2015, 2, 18, 0, 0, 0) + assert(DateTimeUtils.getYear(getInUTCDays(c.getTimeInMillis)) === 2015) + c.set(2012, 2, 18, 0, 0, 0) + assert(DateTimeUtils.getYear(getInUTCDays(c.getTimeInMillis)) === 2012) + } + + test("get quarter") { + val c = Calendar.getInstance() + c.set(2015, 2, 18, 0, 0, 0) + assert(DateTimeUtils.getQuarter(getInUTCDays(c.getTimeInMillis)) === 1) + c.set(2012, 11, 18, 0, 0, 0) + assert(DateTimeUtils.getQuarter(getInUTCDays(c.getTimeInMillis)) === 4) + } + + test("get month") { + val c = Calendar.getInstance() + c.set(2015, 2, 18, 0, 0, 0) + assert(DateTimeUtils.getMonth(getInUTCDays(c.getTimeInMillis)) === 3) + c.set(2012, 11, 18, 0, 0, 0) + assert(DateTimeUtils.getMonth(getInUTCDays(c.getTimeInMillis)) === 12) + } + + test("get day of month") { + val c = Calendar.getInstance() + c.set(2015, 2, 18, 0, 0, 0) + assert(DateTimeUtils.getDayOfMonth(getInUTCDays(c.getTimeInMillis)) === 18) + c.set(2012, 11, 24, 0, 0, 0) + assert(DateTimeUtils.getDayOfMonth(getInUTCDays(c.getTimeInMillis)) === 24) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ObjectPoolSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/NumberConverterSuite.scala similarity index 51% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ObjectPoolSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/NumberConverterSuite.scala index 94764df4b9cdb..13265a1ff1c7f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ObjectPoolSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/NumberConverterSuite.scala @@ -17,41 +17,24 @@ package org.apache.spark.sql.catalyst.util -import org.scalatest.Matchers - import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.util.NumberConverter.convert +import org.apache.spark.unsafe.types.UTF8String -class ObjectPoolSuite extends SparkFunSuite with Matchers { - - test("pool") { - val pool = new ObjectPool(1) - assert(pool.put(1) === 0) - assert(pool.put("hello") === 1) - assert(pool.put(false) === 2) - - assert(pool.get(0) === 1) - assert(pool.get(1) === "hello") - assert(pool.get(2) === false) - assert(pool.size() === 3) +class NumberConverterSuite extends SparkFunSuite { - pool.replace(1, "world") - assert(pool.get(1) === "world") - assert(pool.size() === 3) + private[this] def checkConv(n: String, fromBase: Int, toBase: Int, expected: String): Unit = { + assert(convert(UTF8String.fromString(n).getBytes, fromBase, toBase) === + UTF8String.fromString(expected)) } - test("unique pool") { - val pool = new UniqueObjectPool(1) - assert(pool.put(1) === 0) - assert(pool.put("hello") === 1) - assert(pool.put(1) === 0) - assert(pool.put("hello") === 1) - - assert(pool.get(0) === 1) - assert(pool.get(1) === "hello") - assert(pool.size() === 2) - - intercept[UnsupportedOperationException] { - pool.replace(1, "world") - } + test("convert") { + checkConv("3", 10, 2, "11") + checkConv("-15", 10, -16, "-F") + checkConv("-15", 10, 16, "FFFFFFFFFFFFFFF1") + checkConv("big", 36, 16, "3A48") + checkConv("9223372036854775807", 36, 16, "FFFFFFFFFFFFFFFF") + checkConv("11abc", 10, 16, "B") } + } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala index 
5f312964e5bf7..1d297beb3868d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala @@ -24,14 +24,14 @@ import org.scalatest.PrivateMethodTester import scala.language.postfixOps class DecimalSuite extends SparkFunSuite with PrivateMethodTester { - test("creating decimals") { - /** Check that a Decimal has the given string representation, precision and scale */ - def checkDecimal(d: Decimal, string: String, precision: Int, scale: Int): Unit = { - assert(d.toString === string) - assert(d.precision === precision) - assert(d.scale === scale) - } + /** Check that a Decimal has the given string representation, precision and scale */ + private def checkDecimal(d: Decimal, string: String, precision: Int, scale: Int): Unit = { + assert(d.toString === string) + assert(d.precision === precision) + assert(d.scale === scale) + } + test("creating decimals") { checkDecimal(new Decimal(), "0", 1, 0) checkDecimal(Decimal(BigDecimal("10.030")), "10.030", 5, 3) checkDecimal(Decimal(BigDecimal("10.030"), 4, 1), "10.0", 4, 1) @@ -53,6 +53,15 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { intercept[IllegalArgumentException](Decimal(1e17.toLong, 17, 0)) } + test("creating decimals with negative scale") { + checkDecimal(Decimal(BigDecimal("98765"), 5, -3), "9.9E+4", 5, -3) + checkDecimal(Decimal(BigDecimal("314.159"), 6, -2), "3E+2", 6, -2) + checkDecimal(Decimal(BigDecimal(1.579e12), 4, -9), "1.579E+12", 4, -9) + checkDecimal(Decimal(BigDecimal(1.579e12), 4, -10), "1.58E+12", 4, -10) + checkDecimal(Decimal(103050709L, 9, -10), "1.03050709E+18", 9, -10) + checkDecimal(Decimal(1e8.toLong, 10, -10), "1.00000000E+18", 10, -10) + } + test("double and long values") { /** Check that a Decimal converts to the given double and long values */ def checkValues(d: Decimal, doubleValue: Double, longValue: Long): Unit = { @@ -162,14 +171,4 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { assert(new Decimal().set(100L, 10, 0).toUnscaledLong === 100L) assert(Decimal(Long.MaxValue, 100, 0).toUnscaledLong === Long.MaxValue) } - - test("accurate precision after multiplication") { - val decimal = (Decimal(Long.MaxValue, 38, 0) * Decimal(Long.MaxValue, 38, 0)).toJavaBigDecimal - assert(decimal.unscaledValue.toString === "85070591730234615847396907784232501249") - } - - test("fix non-terminating decimal expansion problem") { - val decimal = Decimal(1.0, 10, 3) / Decimal(3.0, 10, 3) - assert(decimal.toString === "0.333") - } } diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 2c03fee9dbd71..be0966641b5c4 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -123,7 +123,6 @@ - src/test/scala src/test/gen-java diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 10250264625b2..6e2a6525bf17e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -400,6 +400,14 @@ class Column(protected[sql] val expr: Expression) extends Logging { (this >= lowerBound) && (this <= upperBound) } + /** + * True if the current expression is NaN. + * + * @group expr_ops + * @since 1.5.0 + */ + def isNaN: Column = IsNaN(expr) + /** * True if the current expression is null. 
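// A minimal usage sketch for the new `isNaN` column expression (the DataFrame `df` and its
// double column "x" are hypothetical, not part of the patch). `isNaN` yields a boolean
// Column, so it composes with the usual column operators such as `!`.
import org.apache.spark.sql.functions.col

val realValued = df.filter(!col("x").isNaN)              // keep only rows where "x" is a real number
val flagged    = df.withColumn("x_is_nan", col("x").isNaN)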
* diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 830fba35bb7bc..fa942a1f8fd93 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql import java.io.CharArrayWriter import java.util.Properties +import org.apache.spark.unsafe.types.UTF8String + import scala.language.implicitConversions import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag @@ -38,8 +40,8 @@ import org.apache.spark.sql.catalyst.plans.logical.{Filter, _} import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD} +import org.apache.spark.sql.execution.datasources.CreateTableUsingAsSelect import org.apache.spark.sql.json.JacksonGenerator -import org.apache.spark.sql.sources.CreateTableUsingAsSelect import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils @@ -1282,7 +1284,7 @@ class DataFrame private[sql]( val outputCols = (if (cols.isEmpty) numericColumns.map(_.prettyString) else cols).toList - val ret: Seq[InternalRow] = if (outputCols.nonEmpty) { + val ret: Seq[Row] = if (outputCols.nonEmpty) { val aggExprs = statistics.flatMap { case (_, colToAgg) => outputCols.map(c => Column(Cast(colToAgg(Column(c).expr), StringType)).as(c)) } @@ -1290,19 +1292,18 @@ class DataFrame private[sql]( val row = agg(aggExprs.head, aggExprs.tail: _*).head().toSeq // Pivot the data so each summary is one row - row.grouped(outputCols.size).toSeq.zip(statistics).map { - case (aggregation, (statistic, _)) => - InternalRow(statistic :: aggregation.toList: _*) + row.grouped(outputCols.size).toSeq.zip(statistics).map { case (aggregation, (statistic, _)) => + Row(statistic :: aggregation.toList: _*) } } else { // If there are no output columns, just output a single column that contains the stats. - statistics.map { case (name, _) => InternalRow(name) } + statistics.map { case (name, _) => Row(name) } } // All columns are string type val schema = StructType( StructField("summary", StringType) :: outputCols.map(StructField(_, StringType))).toAttributes - LocalRelation(schema, ret) + LocalRelation.fromExternalRows(schema, ret) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 8681a56c82f1e..a4fd4cf3b330b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -37,24 +37,24 @@ import org.apache.spark.sql.types._ final class DataFrameNaFunctions private[sql](df: DataFrame) { /** - * Returns a new [[DataFrame]] that drops rows containing any null values. + * Returns a new [[DataFrame]] that drops rows containing any null or NaN values. * * @since 1.3.1 */ def drop(): DataFrame = drop("any", df.columns) /** - * Returns a new [[DataFrame]] that drops rows containing null values. + * Returns a new [[DataFrame]] that drops rows containing null or NaN values. * - * If `how` is "any", then drop rows containing any null values. - * If `how` is "all", then drop rows only if every column is null for that row. 
+ * If `how` is "any", then drop rows containing any null or NaN values. + * If `how` is "all", then drop rows only if every column is null or NaN for that row. * * @since 1.3.1 */ def drop(how: String): DataFrame = drop(how, df.columns) /** - * Returns a new [[DataFrame]] that drops rows containing any null values + * Returns a new [[DataFrame]] that drops rows containing any null or NaN values * in the specified columns. * * @since 1.3.1 @@ -62,7 +62,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { def drop(cols: Array[String]): DataFrame = drop(cols.toSeq) /** - * (Scala-specific) Returns a new [[DataFrame ]] that drops rows containing any null values + * (Scala-specific) Returns a new [[DataFrame]] that drops rows containing any null or NaN values * in the specified columns. * * @since 1.3.1 @@ -70,22 +70,22 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { def drop(cols: Seq[String]): DataFrame = drop(cols.size, cols) /** - * Returns a new [[DataFrame]] that drops rows containing null values + * Returns a new [[DataFrame]] that drops rows containing null or NaN values * in the specified columns. * - * If `how` is "any", then drop rows containing any null values in the specified columns. - * If `how` is "all", then drop rows only if every specified column is null for that row. + * If `how` is "any", then drop rows containing any null or NaN values in the specified columns. + * If `how` is "all", then drop rows only if every specified column is null or NaN for that row. * * @since 1.3.1 */ def drop(how: String, cols: Array[String]): DataFrame = drop(how, cols.toSeq) /** - * (Scala-specific) Returns a new [[DataFrame]] that drops rows containing null values + * (Scala-specific) Returns a new [[DataFrame]] that drops rows containing null or NaN values * in the specified columns. * - * If `how` is "any", then drop rows containing any null values in the specified columns. - * If `how` is "all", then drop rows only if every specified column is null for that row. + * If `how` is "any", then drop rows containing any null or NaN values in the specified columns. + * If `how` is "all", then drop rows only if every specified column is null or NaN for that row. * * @since 1.3.1 */ @@ -98,15 +98,16 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { } /** - * Returns a new [[DataFrame]] that drops rows containing less than `minNonNulls` non-null values. + * Returns a new [[DataFrame]] that drops rows containing + * less than `minNonNulls` non-null and non-NaN values. * * @since 1.3.1 */ def drop(minNonNulls: Int): DataFrame = drop(minNonNulls, df.columns) /** - * Returns a new [[DataFrame]] that drops rows containing less than `minNonNulls` non-null - * values in the specified columns. + * Returns a new [[DataFrame]] that drops rows containing + * less than `minNonNulls` non-null and non-NaN values in the specified columns. * * @since 1.3.1 */ @@ -114,32 +115,33 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { /** * (Scala-specific) Returns a new [[DataFrame]] that drops rows containing less than - * `minNonNulls` non-null values in the specified columns. + * `minNonNulls` non-null and non-NaN values in the specified columns. * * @since 1.3.1 */ def drop(minNonNulls: Int, cols: Seq[String]): DataFrame = { - // Filtering condition -- only keep the row if it has at least `minNonNulls` non-null values. + // Filtering condition: + // only keep the row if it has at least `minNonNulls` non-null and non-NaN values. 
val predicate = AtLeastNNonNulls(minNonNulls, cols.map(name => df.resolve(name))) df.filter(Column(predicate)) } /** - * Returns a new [[DataFrame]] that replaces null values in numeric columns with `value`. + * Returns a new [[DataFrame]] that replaces null or NaN values in numeric columns with `value`. * * @since 1.3.1 */ def fill(value: Double): DataFrame = fill(value, df.columns) /** - * Returns a new [[DataFrame ]] that replaces null values in string columns with `value`. + * Returns a new [[DataFrame]] that replaces null values in string columns with `value`. * * @since 1.3.1 */ def fill(value: String): DataFrame = fill(value, df.columns) /** - * Returns a new [[DataFrame]] that replaces null values in specified numeric columns. + * Returns a new [[DataFrame]] that replaces null or NaN values in specified numeric columns. * If a specified column is not a numeric column, it is ignored. * * @since 1.3.1 @@ -147,7 +149,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { def fill(value: Double, cols: Array[String]): DataFrame = fill(value, cols.toSeq) /** - * (Scala-specific) Returns a new [[DataFrame]] that replaces null values in specified + * (Scala-specific) Returns a new [[DataFrame]] that replaces null or NaN values in specified * numeric columns. If a specified column is not a numeric column, it is ignored. * * @since 1.3.1 @@ -391,7 +393,13 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * Returns a [[Column]] expression that replaces null value in `col` with `replacement`. */ private def fillCol[T](col: StructField, replacement: T): Column = { - coalesce(df.col("`" + col.name + "`"), lit(replacement).cast(col.dataType)).as(col.name) + col.dataType match { + case DoubleType | FloatType => + coalesce(nanvl(df.col("`" + col.name + "`"), lit(null)), + lit(replacement).cast(col.dataType)).as(col.name) + case _ => + coalesce(df.col("`" + col.name + "`"), lit(replacement).cast(col.dataType)).as(col.name) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 9ad6e21da7bf7..e9d782cdcd667 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -20,16 +20,16 @@ package org.apache.spark.sql import java.util.Properties import org.apache.hadoop.fs.Path -import org.apache.spark.Partition +import org.apache.spark.{Logging, Partition} import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.RDD +import org.apache.spark.sql.execution.datasources.{ResolvedDataSource, LogicalRelation} import org.apache.spark.sql.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation} -import org.apache.spark.sql.json.{JsonRDD, JSONRelation} +import org.apache.spark.sql.json.JSONRelation import org.apache.spark.sql.parquet.ParquetRelation2 -import org.apache.spark.sql.sources.{LogicalRelation, ResolvedDataSource} import org.apache.spark.sql.types.StructType /** @@ -40,7 +40,7 @@ import org.apache.spark.sql.types.StructType * @since 1.4.0 */ @Experimental -class DataFrameReader private[sql](sqlContext: SQLContext) { +class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { /** * Specifies the input data source format. 
@@ -236,17 +236,8 @@ class DataFrameReader private[sql](sqlContext: SQLContext) { */ def json(jsonRDD: RDD[String]): DataFrame = { val samplingRatio = extraOptions.getOrElse("samplingRatio", "1.0").toDouble - if (sqlContext.conf.useJacksonStreamingAPI) { - sqlContext.baseRelationToDataFrame( - new JSONRelation(() => jsonRDD, None, samplingRatio, userSpecifiedSchema)(sqlContext)) - } else { - val columnNameOfCorruptJsonRecord = sqlContext.conf.columnNameOfCorruptRecord - val appliedSchema = userSpecifiedSchema.getOrElse( - JsonRDD.nullTypeToStringType( - JsonRDD.inferSchema(jsonRDD, 1.0, columnNameOfCorruptJsonRecord))) - val rowRDD = JsonRDD.jsonStringToRow(jsonRDD, appliedSchema, columnNameOfCorruptJsonRecord) - sqlContext.internalCreateDataFrame(rowRDD, appliedSchema) - } + sqlContext.baseRelationToDataFrame( + new JSONRelation(() => jsonRDD, None, samplingRatio, userSpecifiedSchema)(sqlContext)) } /** @@ -260,13 +251,28 @@ class DataFrameReader private[sql](sqlContext: SQLContext) { if (paths.isEmpty) { sqlContext.emptyDataFrame } else { - val globbedPaths = paths.map(new Path(_)).flatMap(SparkHadoopUtil.get.globPath).toArray + val globbedPaths = paths.flatMap { path => + val hdfsPath = new Path(path) + val fs = hdfsPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) + val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + SparkHadoopUtil.get.globPathIfNecessary(qualified) + }.toArray + sqlContext.baseRelationToDataFrame( new ParquetRelation2( globbedPaths.map(_.toString), None, None, extraOptions.toMap)(sqlContext)) } } + /** + * Loads an ORC file and returns the result as a [[DataFrame]]. + * + * @param path input path + * @since 1.5.0 + * @note Currently, this method can only be used together with `HiveContext`. + */ + def orc(path: String): DataFrame = format("orc").load(path) + /** * Returns the specified table as a [[DataFrame]]. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 5548b26cb8f80..05da05d7b8050 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -22,8 +22,8 @@ import java.util.Properties import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.plans.logical.InsertIntoTable +import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, ResolvedDataSource} import org.apache.spark.sql.jdbc.{JDBCWriteDetails, JdbcUtils} -import org.apache.spark.sql.sources.{ResolvedDataSource, CreateTableUsingAsSelect} /** @@ -197,6 +197,9 @@ final class DataFrameWriter private[sql](df: DataFrame) { // the table. But, insertInto with Overwrite requires the schema of data be the same // the schema of the table. insertInto(tableName) + + case SaveMode.Overwrite => + throw new UnsupportedOperationException("overwrite mode unsupported.") } } else { val cmd = @@ -280,6 +283,18 @@ final class DataFrameWriter private[sql](df: DataFrame) { */ def parquet(path: String): Unit = format("parquet").save(path) + /** + * Saves the content of the [[DataFrame]] in ORC format at the specified path. + * This is equivalent to: + * {{{ + * format("orc").save(path) + * }}} + * + * @since 1.5.0 + * @note Currently, this method can only be used together with `HiveContext`. 
+ */ + def orc(path: String): Unit = format("orc").save(path) + /////////////////////////////////////////////////////////////////////////////////////// // Builder pattern config options /////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 6005d35f015a9..1474b170ba896 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -26,6 +26,11 @@ import org.apache.parquet.hadoop.ParquetOutputCommitter import org.apache.spark.sql.catalyst.CatalystConf +//////////////////////////////////////////////////////////////////////////////////////////////////// +// This file defines the configuration options for Spark SQL. +//////////////////////////////////////////////////////////////////////////////////////////////////// + + private[spark] object SQLConf { private val sqlConfEntries = java.util.Collections.synchronizedMap( @@ -184,17 +189,20 @@ private[spark] object SQLConf { val COMPRESS_CACHED = booleanConf("spark.sql.inMemoryColumnarStorage.compressed", defaultValue = Some(true), doc = "When set to true Spark SQL will automatically select a compression codec for each " + - "column based on statistics of the data.") + "column based on statistics of the data.", + isPublic = false) val COLUMN_BATCH_SIZE = intConf("spark.sql.inMemoryColumnarStorage.batchSize", defaultValue = Some(10000), doc = "Controls the size of batches for columnar caching. Larger batch sizes can improve " + - "memory utilization and compression, but risk OOMs when caching data.") + "memory utilization and compression, but risk OOMs when caching data.", + isPublic = false) val IN_MEMORY_PARTITION_PRUNING = booleanConf("spark.sql.inMemoryColumnarStorage.partitionPruning", defaultValue = Some(false), - doc = "") + doc = "When true, enable partition pruning for in-memory columnar tables.", + isPublic = false) val AUTO_BROADCASTJOIN_THRESHOLD = intConf("spark.sql.autoBroadcastJoinThreshold", defaultValue = Some(10 * 1024 * 1024), @@ -203,32 +211,38 @@ private[spark] object SQLConf { "Note that currently statistics are only supported for Hive Metastore tables where the " + "commandANALYZE TABLE <tableName> COMPUTE STATISTICS noscan has been run.") - val DEFAULT_SIZE_IN_BYTES = longConf("spark.sql.defaultSizeInBytes", isPublic = false) + val DEFAULT_SIZE_IN_BYTES = longConf( + "spark.sql.defaultSizeInBytes", + doc = "The default table size used in query planning. By default, it is set to a larger " + + "value than `spark.sql.autoBroadcastJoinThreshold` to be more conservative. That is to say " + + "by default the optimizer will not choose to broadcast a table unless it knows for sure its" + + "size is small enough.", + isPublic = false) val SHUFFLE_PARTITIONS = intConf("spark.sql.shuffle.partitions", defaultValue = Some(200), - doc = "Configures the number of partitions to use when shuffling data for joins or " + - "aggregations.") + doc = "The default number of partitions to use when shuffling data for joins or aggregations.") val CODEGEN_ENABLED = booleanConf("spark.sql.codegen", defaultValue = Some(true), doc = "When true, code will be dynamically generated at runtime for expression evaluation in" + - " a specific query. For some queries with complicated expression this option can lead to " + - "significant speed-ups. 
However, for simple queries this can actually slow down query " + - "execution.") + " a specific query.") val UNSAFE_ENABLED = booleanConf("spark.sql.unsafe.enabled", defaultValue = Some(false), - doc = "") + doc = "When true, use the new optimized Tungsten physical execution backend.") - val DIALECT = stringConf("spark.sql.dialect", defaultValue = Some("sql"), doc = "") + val DIALECT = stringConf( + "spark.sql.dialect", + defaultValue = Some("sql"), + doc = "The default SQL dialect to use.") val CASE_SENSITIVE = booleanConf("spark.sql.caseSensitive", defaultValue = Some(true), - doc = "") + doc = "Whether the query analyzer should be case sensitive or not.") val PARQUET_SCHEMA_MERGING_ENABLED = booleanConf("spark.sql.parquet.mergeSchema", - defaultValue = Some(true), + defaultValue = Some(false), doc = "When true, the Parquet data source merges schemas collected from all data files, " + "otherwise the schema is picked from the summary file or a random data file " + "if no summary file is available.") @@ -273,9 +287,8 @@ private[spark] object SQLConf { val PARQUET_FOLLOW_PARQUET_FORMAT_SPEC = booleanConf( key = "spark.sql.parquet.followParquetFormatSpec", defaultValue = Some(false), - doc = "Whether to stick to Parquet format specification when converting Parquet schema to " + - "Spark SQL schema and vice versa. Sticks to the specification if set to true; falls back " + - "to compatible mode if set to false.", + doc = "Whether to follow Parquet's format specification when converting Parquet schema to " + + "Spark SQL schema and vice versa.", isPublic = false) val PARQUET_OUTPUT_COMMITTER_CLASS = stringConf( @@ -290,7 +303,7 @@ private[spark] object SQLConf { val ORC_FILTER_PUSHDOWN_ENABLED = booleanConf("spark.sql.orc.filterPushdown", defaultValue = Some(false), - doc = "") + doc = "When true, enable filter pushdown for ORC files.") val HIVE_VERIFY_PARTITION_PATH = booleanConf("spark.sql.hive.verifyPartitionPath", defaultValue = Some(true), @@ -302,7 +315,7 @@ private[spark] object SQLConf { val BROADCAST_TIMEOUT = intConf("spark.sql.broadcastTimeout", defaultValue = Some(5 * 60), - doc = "") + doc = "Timeout in seconds for the broadcast wait time in broadcast joins.") // Options that control which operators can be chosen by the query planner. These should be // considered hints and may be ignored by future versions of Spark SQL. 
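// An illustrative sketch of the declaration/read pattern used for the entries in this object: a
// typed builder (booleanConf, intConf, longConf, stringConf) declares the entry inside object
// SQLConf, and class SQLConf reads it back through getConf. The entry name and getter below are
// hypothetical, shown only to make the builder and getConf signatures concrete.
val HYPOTHETICAL_FLAG = booleanConf("spark.sql.hypothetical.flag",
  defaultValue = Some(false),
  doc = "When true, enable a hypothetical feature.",
  isPublic = false)

// (in class SQLConf)
private[spark] def hypotheticalFlagEnabled: Boolean = getConf(HYPOTHETICAL_FLAG)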
@@ -313,7 +326,7 @@ private[spark] object SQLConf { val SORTMERGE_JOIN = booleanConf("spark.sql.planner.sortMergeJoin", defaultValue = Some(false), - doc = "") + doc = "When true, use sort merge join (as opposed to hash join) by default for large joins.") // This is only used for the thriftserver val THRIFTSERVER_POOL = stringConf("spark.sql.thriftserver.scheduler.pool", @@ -321,16 +334,16 @@ private[spark] object SQLConf { val THRIFTSERVER_UI_STATEMENT_LIMIT = intConf("spark.sql.thriftserver.ui.retainedStatements", defaultValue = Some(200), - doc = "") + doc = "The number of SQL statements kept in the JDBC/ODBC web UI history.") val THRIFTSERVER_UI_SESSION_LIMIT = intConf("spark.sql.thriftserver.ui.retainedSessions", defaultValue = Some(200), - doc = "") + doc = "The number of SQL client sessions kept in the JDBC/ODBC web UI history.") // This is used to set the default data source val DEFAULT_DATA_SOURCE_NAME = stringConf("spark.sql.sources.default", defaultValue = Some("org.apache.spark.sql.parquet"), - doc = "") + doc = "The default data source to use in input/output.") // This is used to control when we will split a schema's JSON string to multiple pieces // in order to fit the JSON string in metastore's table property (by default, the value has @@ -338,18 +351,20 @@ private[spark] object SQLConf { // to its length exceeds the threshold. val SCHEMA_STRING_LENGTH_THRESHOLD = intConf("spark.sql.sources.schemaStringLengthThreshold", defaultValue = Some(4000), - doc = "") + doc = "The maximum length allowed in a single cell when " + + "storing additional schema information in Hive's metastore.", + isPublic = false) // Whether to perform partition discovery when loading external data sources. Default to true. val PARTITION_DISCOVERY_ENABLED = booleanConf("spark.sql.sources.partitionDiscovery.enabled", defaultValue = Some(true), - doc = "") + doc = "When true, automatically discover data partitions.") // Whether to perform partition column type inference. Default to true. val PARTITION_COLUMN_TYPE_INFERENCE = booleanConf("spark.sql.sources.partitionColumnTypeInference.enabled", defaultValue = Some(true), - doc = "") + doc = "When true, automatically infer the data types for partitioned columns.") // The output committer class used by HadoopFsRelation. The specified class needs to be a // subclass of org.apache.hadoop.mapreduce.OutputCommitter. @@ -361,27 +376,38 @@ private[spark] object SQLConf { val OUTPUT_COMMITTER_CLASS = stringConf("spark.sql.sources.outputCommitterClass", isPublic = false) + val PARALLEL_PARTITION_DISCOVERY_THRESHOLD = intConf( + key = "spark.sql.sources.parallelPartitionDiscovery.threshold", + defaultValue = Some(32), + doc = "") + // Whether to perform eager analysis when constructing a dataframe. // Set to false when debugging requires the ability to look at invalid query plans. - val DATAFRAME_EAGER_ANALYSIS = booleanConf("spark.sql.eagerAnalysis", + val DATAFRAME_EAGER_ANALYSIS = booleanConf( + "spark.sql.eagerAnalysis", defaultValue = Some(true), - doc = "") + doc = "When true, eagerly applies query analysis on DataFrame operations.", + isPublic = false) // Whether to automatically resolve ambiguity in join conditions for self-joins. // See SPARK-6231.
- val DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY = - booleanConf("spark.sql.selfJoinAutoResolveAmbiguity", defaultValue = Some(true), doc = "") + val DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY = booleanConf( + "spark.sql.selfJoinAutoResolveAmbiguity", + defaultValue = Some(true), + isPublic = false) // Whether to retain group by columns or not in GroupedData.agg. - val DATAFRAME_RETAIN_GROUP_COLUMNS = booleanConf("spark.sql.retainGroupColumns", + val DATAFRAME_RETAIN_GROUP_COLUMNS = booleanConf( + "spark.sql.retainGroupColumns", defaultValue = Some(true), - doc = "") + isPublic = false) - val USE_SQL_SERIALIZER2 = booleanConf("spark.sql.useSerializer2", + val USE_SQL_AGGREGATE2 = booleanConf("spark.sql.useAggregate2", defaultValue = Some(true), doc = "") - val USE_JACKSON_STREAMING_API = booleanConf("spark.sql.json.useJacksonStreamingAPI", - defaultValue = Some(true), doc = "") + val USE_SQL_SERIALIZER2 = booleanConf( + "spark.sql.useSerializer2", + defaultValue = Some(true), isPublic = false) object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" @@ -422,112 +448,53 @@ private[sql] class SQLConf extends Serializable with CatalystConf { */ private[spark] def dialect: String = getConf(DIALECT) - /** When true tables cached using the in-memory columnar caching will be compressed. */ private[spark] def useCompression: Boolean = getConf(COMPRESS_CACHED) - /** The compression codec for writing to a Parquetfile */ private[spark] def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION) private[spark] def parquetCacheMetadata: Boolean = getConf(PARQUET_CACHE_METADATA) - /** The number of rows that will be */ private[spark] def columnBatchSize: Int = getConf(COLUMN_BATCH_SIZE) - /** Number of partitions to use for shuffle operators. */ private[spark] def numShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS) - /** When true predicates will be passed to the parquet record reader when possible. */ private[spark] def parquetFilterPushDown: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_ENABLED) - /** When true uses Parquet implementation based on data source API */ private[spark] def parquetUseDataSourceApi: Boolean = getConf(PARQUET_USE_DATA_SOURCE_API) private[spark] def orcFilterPushDown: Boolean = getConf(ORC_FILTER_PUSHDOWN_ENABLED) - /** When true uses verifyPartitionPath to prune the path which is not exists. */ private[spark] def verifyPartitionPath: Boolean = getConf(HIVE_VERIFY_PARTITION_PATH) - /** When true the planner will use the external sort, which may spill to disk. */ private[spark] def externalSortEnabled: Boolean = getConf(EXTERNAL_SORT) - /** - * Sort merge join would sort the two side of join first, and then iterate both sides together - * only once to get all matches. Using sort merge join can save a lot of memory usage compared - * to HashJoin. - */ private[spark] def sortMergeJoinEnabled: Boolean = getConf(SORTMERGE_JOIN) - /** - * When set to true, Spark SQL will use the Janino at runtime to generate custom bytecode - * that evaluates expressions found in queries. In general this custom code runs much faster - * than interpreted evaluation, but there are some start-up costs (5-10ms) due to compilation. - */ private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED) - /** - * caseSensitive analysis true by default - */ def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE) - /** - * When set to true, Spark SQL will use managed memory for certain operations. This option only - * takes effect if codegen is enabled. 
- * - * Defaults to false as this feature is currently experimental. - */ private[spark] def unsafeEnabled: Boolean = getConf(UNSAFE_ENABLED) - private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2) + private[spark] def useSqlAggregate2: Boolean = getConf(USE_SQL_AGGREGATE2) - /** - * Selects between the new (true) and old (false) JSON handlers, to be removed in Spark 1.5.0 - */ - private[spark] def useJacksonStreamingAPI: Boolean = getConf(USE_JACKSON_STREAMING_API) + private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2) - /** - * Upper bound on the sizes (in bytes) of the tables qualified for the auto conversion to - * a broadcast value during the physical executions of join operations. Setting this to -1 - * effectively disables auto conversion. - * - * Hive setting: hive.auto.convert.join.noconditionaltask.size, whose default value is 10000. - */ private[spark] def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD) - /** - * The default size in bytes to assign to a logical operator's estimation statistics. By default, - * it is set to a larger value than `autoBroadcastJoinThreshold`, hence any logical operator - * without a properly implemented estimation of this statistic will not be incorrectly broadcasted - * in joins. - */ private[spark] def defaultSizeInBytes: Long = getConf(DEFAULT_SIZE_IN_BYTES, autoBroadcastJoinThreshold + 1L) - /** - * When set to true, we always treat byte arrays in Parquet files as strings. - */ private[spark] def isParquetBinaryAsString: Boolean = getConf(PARQUET_BINARY_AS_STRING) - /** - * When set to true, we always treat INT96Values in Parquet files as timestamp. - */ private[spark] def isParquetINT96AsTimestamp: Boolean = getConf(PARQUET_INT96_AS_TIMESTAMP) - /** - * When set to true, sticks to Parquet format spec when converting Parquet schema to Spark SQL - * schema and vice versa. Otherwise, falls back to compatible mode. - */ private[spark] def followParquetFormatSpec: Boolean = getConf(PARQUET_FOLLOW_PARQUET_FORMAT_SPEC) - /** - * When set to true, partition pruning for in-memory columnar tables is enabled. - */ private[spark] def inMemoryPartitionPruning: Boolean = getConf(IN_MEMORY_PARTITION_PRUNING) private[spark] def columnNameOfCorruptRecord: String = getConf(COLUMN_NAME_OF_CORRUPT_RECORD) - /** - * Timeout in seconds for the broadcast wait time in hash join - */ private[spark] def broadcastTimeout: Int = getConf(BROADCAST_TIMEOUT) private[spark] def defaultDataSourceName: String = getConf(DEFAULT_DATA_SOURCE_NAME) @@ -538,6 +505,9 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def partitionColumnTypeInferenceEnabled(): Boolean = getConf(SQLConf.PARTITION_COLUMN_TYPE_INFERENCE) + private[spark] def parallelPartitionDiscoveryThreshold: Int = + getConf(SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD) + // Do not use a value larger than 4000 as the default value of this property. // See the comments of SCHEMA_STRING_LENGTH_THRESHOLD above for more information. 
private[spark] def schemaStringLengthThreshold: Int = getConf(SCHEMA_STRING_LENGTH_THRESHOLD) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 477dea9164726..49bfe74b680af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -39,8 +39,9 @@ import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, _} -import org.apache.spark.sql.execution.{Filter, _} -import org.apache.spark.sql.sources._ +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -146,11 +147,11 @@ class SQLContext(@transient val sparkContext: SparkContext) new Analyzer(catalog, functionRegistry, conf) { override val extendedResolutionRules = ExtractPythonUDFs :: - sources.PreInsertCastAndRename :: + PreInsertCastAndRename :: Nil override val extendedCheckRules = Seq( - sources.PreWriteCheck(catalog) + datasources.PreWriteCheck(catalog) ) } @@ -284,6 +285,9 @@ class SQLContext(@transient val sparkContext: SparkContext) @transient val udf: UDFRegistration = new UDFRegistration(this) + @transient + val udaf: UDAFRegistration = new UDAFRegistration(this) + /** * Returns true if the table is currently cached in-memory. * @group cachemgmt @@ -554,8 +558,7 @@ class SQLContext(@transient val sparkContext: SparkContext) val className = beanClass.getName val rowRdd = rdd.mapPartitions { iter => // BeanInfo is not serializable so we must rediscover it remotely for each partition. - val localBeanInfo = Introspector.getBeanInfo( - Class.forName(className, true, Utils.getContextOrSparkClassLoader)) + val localBeanInfo = Introspector.getBeanInfo(Utils.classForName(className)) val extractors = localBeanInfo.getPropertyDescriptors.filterNot(_.getName == "class").map(_.getReadMethod) val methodsToConverts = extractors.zip(attributeSeq).map { case (e, attr) => @@ -863,6 +866,7 @@ class SQLContext(@transient val sparkContext: SparkContext) DDLStrategy :: TakeOrderedAndProject :: HashAggregation :: + Aggregation :: LeftSemiJoin :: HashJoin :: InMemoryScans :: @@ -922,12 +926,15 @@ class SQLContext(@transient val sparkContext: SparkContext) protected[sql] lazy val emptyResult = sparkContext.parallelize(Seq.empty[InternalRow], 1) /** - * Prepares a planned SparkPlan for execution by inserting shuffle operations as needed. + * Prepares a planned SparkPlan for execution by inserting shuffle operations and internal + * row format conversions as needed. 
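   * An illustrative sketch of how this executor is applied to a physical plan downstream (the
   * value names are placeholders):
   * {{{
   *   val executedPlan: SparkPlan = prepareForExecution.execute(sparkPlan)
   * }}}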
*/ @transient protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] { - val batches = - Batch("Add exchange", Once, EnsureRequirements(self)) :: Nil + val batches = Seq( + Batch("Add exchange", Once, EnsureRequirements(self)), + Batch("Add row converters", Once, EnsureRowFormats) + ) } protected[sql] def openSession(): SQLSession = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala new file mode 100644 index 0000000000000..5b872f5e3eecd --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.expressions.{Expression} +import org.apache.spark.sql.expressions.aggregate.{ScalaUDAF, UserDefinedAggregateFunction} + +class UDAFRegistration private[sql] (sqlContext: SQLContext) extends Logging { + + private val functionRegistry = sqlContext.functionRegistry + + def register( + name: String, + func: UserDefinedAggregateFunction): UserDefinedAggregateFunction = { + def builder(children: Seq[Expression]) = ScalaUDAF(children, func) + functionRegistry.registerFunction(name, builder) + func + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index d35d37d017198..7cd7421a518c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -22,13 +22,10 @@ import java.util.{List => JList, Map => JMap} import scala.reflect.runtime.universe.TypeTag import scala.util.Try -import org.apache.spark.{Accumulator, Logging} -import org.apache.spark.api.python.PythonBroadcast -import org.apache.spark.broadcast.Broadcast +import org.apache.spark.Logging import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} -import org.apache.spark.sql.execution.PythonUDF import org.apache.spark.sql.types.DataType /** @@ -40,44 +37,19 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { private val functionRegistry = sqlContext.functionRegistry - protected[sql] def registerPython( - name: String, - command: Array[Byte], - envVars: JMap[String, String], - pythonIncludes: JList[String], - pythonExec: String, - pythonVer: String, - broadcastVars: JList[Broadcast[PythonBroadcast]], - accumulator: Accumulator[JList[Array[Byte]]], - stringDataType: String): Unit = { + protected[sql] def registerPython(name: String, udf: UserDefinedPythonFunction): 
Unit = { log.debug( s""" | Registering new PythonUDF: | name: $name - | command: ${command.toSeq} - | envVars: $envVars - | pythonIncludes: $pythonIncludes - | pythonExec: $pythonExec - | dataType: $stringDataType + | command: ${udf.command.toSeq} + | envVars: ${udf.envVars} + | pythonIncludes: ${udf.pythonIncludes} + | pythonExec: ${udf.pythonExec} + | dataType: ${udf.dataType} """.stripMargin) - - val dataType = sqlContext.parseDataType(stringDataType) - - def builder(e: Seq[Expression]): PythonUDF = - PythonUDF( - name, - command, - envVars, - pythonIncludes, - pythonExec, - pythonVer, - broadcastVars, - accumulator, - dataType, - e) - - functionRegistry.registerFunction(name, builder) + functionRegistry.registerFunction(name, udf.builder) } // scalastyle:off diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala index b14e00ab9b163..0f8cd280b5acb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala @@ -23,7 +23,7 @@ import org.apache.spark.Accumulator import org.apache.spark.annotation.Experimental import org.apache.spark.api.python.PythonBroadcast import org.apache.spark.broadcast.Broadcast -import org.apache.spark.sql.catalyst.expressions.ScalaUDF +import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} import org.apache.spark.sql.execution.PythonUDF import org.apache.spark.sql.types.DataType @@ -66,10 +66,14 @@ private[sql] case class UserDefinedPythonFunction( accumulator: Accumulator[JList[Array[Byte]]], dataType: DataType) { + def builder(e: Seq[Expression]): PythonUDF = { + PythonUDF(name, command, envVars, pythonIncludes, pythonExec, pythonVer, broadcastVars, + accumulator, dataType, e) + } + /** Returns a [[Column]] that will evaluate to calling this UDF with the given input. 
*/ def apply(exprs: Column*): Column = { - val udf = PythonUDF(name, command, envVars, pythonIncludes, pythonExec, pythonVer, - broadcastVars, accumulator, dataType, exprs.map(_.expr)) + val udf = builder(exprs.map(_.expr)) Column(udf) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index 43b62f0e822f8..92861ab038f19 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -47,6 +47,7 @@ private[r] object SQLUtils { dataType match { case "byte" => org.apache.spark.sql.types.ByteType case "integer" => org.apache.spark.sql.types.IntegerType + case "float" => org.apache.spark.sql.types.FloatType case "double" => org.apache.spark.sql.types.DoubleType case "numeric" => org.apache.spark.sql.types.DoubleType case "character" => org.apache.spark.sql.types.StringType @@ -68,7 +69,7 @@ private[r] object SQLUtils { def createDF(rdd: RDD[Array[Byte]], schema: StructType, sqlContext: SQLContext): DataFrame = { val num = schema.fields.size - val rowRDD = rdd.map(bytesToRow) + val rowRDD = rdd.map(bytesToRow(_, schema)) sqlContext.createDataFrame(rowRDD, schema) } @@ -76,12 +77,20 @@ private[r] object SQLUtils { df.map(r => rowToRBytes(r)) } - private[this] def bytesToRow(bytes: Array[Byte]): Row = { + private[this] def doConversion(data: Object, dataType: DataType): Object = { + data match { + case d: java.lang.Double if dataType == FloatType => + new java.lang.Float(d) + case _ => data + } + } + + private[this] def bytesToRow(bytes: Array[Byte], schema: StructType): Row = { val bis = new ByteArrayInputStream(bytes) val dis = new DataInputStream(bis) val num = SerDe.readInt(dis) Row.fromSeq((0 until num).map { i => - SerDe.readObject(dis) + doConversion(SerDe.readObject(dis), schema.fields(i).dataType) }.toSeq) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala index 3cd60a2aa55ed..c2c945321db95 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala @@ -68,14 +68,14 @@ case class Aggregate( * output. */ case class ComputedAggregate( - unbound: AggregateExpression, - aggregate: AggregateExpression, + unbound: AggregateExpression1, + aggregate: AggregateExpression1, resultAttribute: AttributeReference) /** A list of aggregates that need to be computed for each group. */ private[this] val computedAggregates = aggregateExpressions.flatMap { agg => agg.collect { - case a: AggregateExpression => + case a: AggregateExpression1 => ComputedAggregate( a, BindReferences.bindReference(a, child.output), @@ -87,8 +87,8 @@ case class Aggregate( private[this] val computedSchema = computedAggregates.map(_.resultAttribute) /** Creates a new aggregate buffer for a group. 
*/ - private[this] def newAggregateBuffer(): Array[AggregateFunction] = { - val buffer = new Array[AggregateFunction](computedAggregates.length) + private[this] def newAggregateBuffer(): Array[AggregateFunction1] = { + val buffer = new Array[AggregateFunction1](computedAggregates.length) var i = 0 while (i < computedAggregates.length) { buffer(i) = computedAggregates(i).aggregate.newInstance() @@ -146,7 +146,7 @@ case class Aggregate( } } else { child.execute().mapPartitions { iter => - val hashTable = new HashMap[InternalRow, Array[AggregateFunction]] + val hashTable = new HashMap[InternalRow, Array[AggregateFunction1]] val groupingProjection = new InterpretedMutableProjection(groupingExpressions, child.output) var currentRow: InternalRow = null diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 4b783e30d95e1..d31e265a293e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.rdd.{RDD, ShuffledRDD} +import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.shuffle.sort.SortShuffleManager @@ -29,29 +29,26 @@ import org.apache.spark.sql.catalyst.errors.attachTree import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.types.DataType import org.apache.spark.util.MutablePair import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner, SparkEnv} /** * :: DeveloperApi :: - * Performs a shuffle that will result in the desired `newPartitioning`. Optionally sorts each - * resulting partition based on expressions from the partition key. It is invalid to construct an - * exchange operator with a `newOrdering` that cannot be calculated using the partitioning key. + * Performs a shuffle that will result in the desired `newPartitioning`. */ @DeveloperApi -case class Exchange( - newPartitioning: Partitioning, - newOrdering: Seq[SortOrder], - child: SparkPlan) - extends UnaryNode { +case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode { override def outputPartitioning: Partitioning = newPartitioning - override def outputOrdering: Seq[SortOrder] = newOrdering - override def output: Seq[Attribute] = child.output + override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows + + override def canProcessSafeRows: Boolean = true + + override def canProcessUnsafeRows: Boolean = true + /** * Determines whether records must be defensively copied before being sent to the shuffle. * Several of Spark's shuffle components will buffer deserialized Java objects in memory. The @@ -120,109 +117,70 @@ case class Exchange( @transient private lazy val sparkConf = child.sqlContext.sparkContext.getConf - private def getSerializer( - keySchema: Array[DataType], - valueSchema: Array[DataType], - numPartitions: Int): Serializer = { + private val serializer: Serializer = { + val rowDataTypes = child.output.map(_.dataType).toArray // It is true when there is no field that needs to be write out. // For now, we will not use SparkSqlSerializer2 when noField is true. 
- val noField = - (keySchema == null || keySchema.length == 0) && - (valueSchema == null || valueSchema.length == 0) + val noField = rowDataTypes == null || rowDataTypes.length == 0 val useSqlSerializer2 = child.sqlContext.conf.useSqlSerializer2 && // SparkSqlSerializer2 is enabled. - SparkSqlSerializer2.support(keySchema) && // The schema of key is supported. - SparkSqlSerializer2.support(valueSchema) && // The schema of value is supported. + SparkSqlSerializer2.support(rowDataTypes) && // The schema of row is supported. !noField - val serializer = if (useSqlSerializer2) { + if (child.outputsUnsafeRows) { + logInfo("Using UnsafeRowSerializer.") + new UnsafeRowSerializer(child.output.size) + } else if (useSqlSerializer2) { logInfo("Using SparkSqlSerializer2.") - new SparkSqlSerializer2(keySchema, valueSchema) + new SparkSqlSerializer2(rowDataTypes) } else { logInfo("Using SparkSqlSerializer.") new SparkSqlSerializer(sparkConf) } - - serializer } protected override def doExecute(): RDD[InternalRow] = attachTree(this , "execute") { - newPartitioning match { - case HashPartitioning(expressions, numPartitions) => - val keySchema = expressions.map(_.dataType).toArray - val valueSchema = child.output.map(_.dataType).toArray - val serializer = getSerializer(keySchema, valueSchema, numPartitions) - val part = new HashPartitioner(numPartitions) - - val rdd = if (needToCopyObjectsBeforeShuffle(part, serializer)) { - child.execute().mapPartitions { iter => - val hashExpressions = newMutableProjection(expressions, child.output)() - iter.map(r => (hashExpressions(r).copy(), r.copy())) - } - } else { - child.execute().mapPartitions { iter => - val hashExpressions = newMutableProjection(expressions, child.output)() - val mutablePair = new MutablePair[InternalRow, InternalRow]() - iter.map(r => mutablePair.update(hashExpressions(r), r)) - } - } - val shuffled = new ShuffledRDD[InternalRow, InternalRow, InternalRow](rdd, part) - shuffled.setSerializer(serializer) - shuffled.map(_._2) - + val rdd = child.execute() + val part: Partitioner = newPartitioning match { + case HashPartitioning(expressions, numPartitions) => new HashPartitioner(numPartitions) case RangePartitioning(sortingExpressions, numPartitions) => - val keySchema = child.output.map(_.dataType).toArray - val serializer = getSerializer(keySchema, null, numPartitions) - - val childRdd = child.execute() - val part: Partitioner = { - // Internally, RangePartitioner runs a job on the RDD that samples keys to compute - // partition bounds. To get accurate samples, we need to copy the mutable keys. - val rddForSampling = childRdd.mapPartitions { iter => - val mutablePair = new MutablePair[InternalRow, Null]() - iter.map(row => mutablePair.update(row.copy(), null)) - } - // TODO: RangePartitioner should take an Ordering. - implicit val ordering = new RowOrdering(sortingExpressions, child.output) - new RangePartitioner(numPartitions, rddForSampling, ascending = true) + // Internally, RangePartitioner runs a job on the RDD that samples keys to compute + // partition bounds. To get accurate samples, we need to copy the mutable keys. 
+ val rddForSampling = rdd.mapPartitions { iter => + val mutablePair = new MutablePair[InternalRow, Null]() + iter.map(row => mutablePair.update(row.copy(), null)) } - - val rdd = if (needToCopyObjectsBeforeShuffle(part, serializer)) { - childRdd.mapPartitions { iter => iter.map(row => (row.copy(), null))} - } else { - childRdd.mapPartitions { iter => - val mutablePair = new MutablePair[InternalRow, Null]() - iter.map(row => mutablePair.update(row, null)) - } - } - - val shuffled = new ShuffledRDD[InternalRow, Null, Null](rdd, part) - shuffled.setSerializer(serializer) - shuffled.map(_._1) - + implicit val ordering = new RowOrdering(sortingExpressions, child.output) + new RangePartitioner(numPartitions, rddForSampling, ascending = true) case SinglePartition => - val valueSchema = child.output.map(_.dataType).toArray - val serializer = getSerializer(null, valueSchema, numPartitions = 1) - val partitioner = new HashPartitioner(1) - - val rdd = if (needToCopyObjectsBeforeShuffle(partitioner, serializer)) { - child.execute().mapPartitions { - iter => iter.map(r => (null, r.copy())) - } - } else { - child.execute().mapPartitions { iter => - val mutablePair = new MutablePair[Null, InternalRow]() - iter.map(r => mutablePair.update(null, r)) - } + new Partitioner { + override def numPartitions: Int = 1 + override def getPartition(key: Any): Int = 0 } - val shuffled = new ShuffledRDD[Null, InternalRow, InternalRow](rdd, partitioner) - shuffled.setSerializer(serializer) - shuffled.map(_._2) - case _ => sys.error(s"Exchange not implemented for $newPartitioning") // TODO: Handle BroadcastPartitioning. } + def getPartitionKeyExtractor(): InternalRow => InternalRow = newPartitioning match { + case HashPartitioning(expressions, _) => newMutableProjection(expressions, child.output)() + case RangePartitioning(_, _) | SinglePartition => identity + case _ => sys.error(s"Exchange not implemented for $newPartitioning") + } + val rddWithPartitionIds: RDD[Product2[Int, InternalRow]] = { + if (needToCopyObjectsBeforeShuffle(part, serializer)) { + rdd.mapPartitions { iter => + val getPartitionKey = getPartitionKeyExtractor() + iter.map { row => (part.getPartition(getPartitionKey(row)), row.copy()) } + } + } else { + rdd.mapPartitions { iter => + val getPartitionKey = getPartitionKeyExtractor() + val mutablePair = new MutablePair[Int, InternalRow]() + iter.map { row => mutablePair.update(part.getPartition(getPartitionKey(row)), row) } + } + } + } + new ShuffledRowRDD(rddWithPartitionIds, serializer, part.numPartitions) } } @@ -279,23 +237,31 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ partitioning: Partitioning, rowOrdering: Seq[SortOrder], child: SparkPlan): SparkPlan = { - val needSort = rowOrdering.nonEmpty && child.outputOrdering != rowOrdering - val needsShuffle = child.outputPartitioning != partitioning - val withShuffle = if (needsShuffle) { - Exchange(partitioning, Nil, child) - } else { - child + def addShuffleIfNecessary(child: SparkPlan): SparkPlan = { + if (child.outputPartitioning != partitioning) { + Exchange(partitioning, child) + } else { + child + } } - val withSort = if (needSort) { - sqlContext.planner.BasicOperators.getSortOperator( - rowOrdering, global = false, withShuffle) - } else { - withShuffle + def addSortIfNecessary(child: SparkPlan): SparkPlan = { + + if (rowOrdering.nonEmpty) { + // If child.outputOrdering is [a, b] and rowOrdering is [a], we do not need to sort. 
+ val minSize = Seq(rowOrdering.size, child.outputOrdering.size).min + if (minSize == 0 || rowOrdering.take(minSize) != child.outputOrdering.take(minSize)) { + sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, global = false, child) + } else { + child + } + } else { + child + } } - withSort + addSortIfNecessary(addShuffleIfNecessary(child)) } if (meetsRequirements && compatible && !needsAnySort) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index c069da016f9f0..16176abe3a51d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -69,7 +69,7 @@ case class GeneratedAggregate( protected override def doExecute(): RDD[InternalRow] = { val aggregatesToCompute = aggregateExpressions.flatMap { a => - a.collect { case agg: AggregateExpression => agg} + a.collect { case agg: AggregateExpression1 => agg} } // If you add any new function support, please add tests in org.apache.spark.sql.SQLQuerySuite @@ -239,6 +239,11 @@ case class GeneratedAggregate( StructType(fields) } + val schemaSupportsUnsafe: Boolean = { + UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) && + UnsafeFixedWidthAggregationMap.supportsGroupKeySchema(groupKeySchema) + } + child.execute().mapPartitions { iter => // Builds a new custom class for holding the results of aggregation for a group. val initialValues = computeFunctions.flatMap(_.initialValues) @@ -266,7 +271,18 @@ case class GeneratedAggregate( val joinedRow = new JoinedRow3 - if (groupingExpressions.isEmpty) { + if (!iter.hasNext) { + // This is an empty input, so return early so that we do not allocate data structures + // that won't be cleaned up (see SPARK-8357). + if (groupingExpressions.isEmpty) { + // This is a global aggregate, so return an empty aggregation buffer. + val resultProjection = resultProjectionBuilder() + Iterator(resultProjection(newAggregationBuffer(EmptyRow))) + } else { + // This is a grouped aggregate, so return an empty iterator. + Iterator[InternalRow]() + } + } else if (groupingExpressions.isEmpty) { // TODO: Codegening anything other than the updateProjection is probably over kill. 
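// No grouping expressions and non-empty input: this is a global aggregate, so a single mutable
// buffer is updated with every input row and exactly one result row is emitted per partition.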
val buffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow] var currentRow: InternalRow = null @@ -279,12 +295,14 @@ case class GeneratedAggregate( val resultProjection = resultProjectionBuilder() Iterator(resultProjection(buffer)) - } else if (unsafeEnabled) { + + } else if (unsafeEnabled && schemaSupportsUnsafe) { + assert(iter.hasNext, "There should be at least one row for this path") log.info("Using Unsafe-based aggregator") val aggregationMap = new UnsafeFixedWidthAggregationMap( - newAggregationBuffer, - new UnsafeRowConverter(groupKeySchema), - new UnsafeRowConverter(aggregationBufferSchema), + newAggregationBuffer(EmptyRow), + aggregationBufferSchema, + groupKeySchema, TaskContext.get.taskMemoryManager(), 1024 * 16, // initial capacity false // disable tracking of performance metrics @@ -319,6 +337,9 @@ case class GeneratedAggregate( } } } else { + if (unsafeEnabled) { + log.info("Not using Unsafe-based aggregator because it is not supported for this schema") + } val buffers = new java.util.HashMap[InternalRow, MutableRow]() var currentRow: InternalRow = null diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala index cd341180b6100..34e926e4582be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala @@ -34,13 +34,11 @@ private[sql] case class LocalTableScan( protected override def doExecute(): RDD[InternalRow] = rdd - override def executeCollect(): Array[Row] = { val converter = CatalystTypeConverters.createToScalaConverter(schema) rows.map(converter(_).asInstanceOf[Row]).toArray } - override def executeTake(limit: Int): Array[Row] = { val converter = CatalystTypeConverters.createToScalaConverter(schema) rows.map(converter(_).asInstanceOf[Row]).take(limit).toArray diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala new file mode 100644 index 0000000000000..88f5b13c8f248 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import org.apache.spark._ +import org.apache.spark.rdd.RDD +import org.apache.spark.serializer.Serializer +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types.DataType + +private class ShuffledRowRDDPartition(val idx: Int) extends Partition { + override val index: Int = idx + override def hashCode(): Int = idx +} + +/** + * A dummy partitioner for use with records whose partition ids have been pre-computed (i.e. for + * use on RDDs of (Int, Row) pairs where the Int is a partition id in the expected range). + */ +private class PartitionIdPassthrough(override val numPartitions: Int) extends Partitioner { + override def getPartition(key: Any): Int = key.asInstanceOf[Int] +} + +/** + * This is a specialized version of [[org.apache.spark.rdd.ShuffledRDD]] that is optimized for + * shuffling rows instead of Java key-value pairs. Note that something like this should eventually + * be implemented in Spark core, but that is blocked by some more general refactorings to shuffle + * interfaces / internals. + * + * @param prev the RDD being shuffled. Elements of this RDD are (partitionId, Row) pairs. + * Partition ids should be in the range [0, numPartitions - 1]. + * @param serializer the serializer used during the shuffle. + * @param numPartitions the number of post-shuffle partitions. + */ +class ShuffledRowRDD( + @transient var prev: RDD[Product2[Int, InternalRow]], + serializer: Serializer, + numPartitions: Int) + extends RDD[InternalRow](prev.context, Nil) { + + private val part: Partitioner = new PartitionIdPassthrough(numPartitions) + + override def getDependencies: Seq[Dependency[_]] = { + List(new ShuffleDependency[Int, InternalRow, InternalRow](prev, part, Some(serializer))) + } + + override val partitioner = Some(part) + + override def getPartitions: Array[Partition] = { + Array.tabulate[Partition](part.numPartitions)(i => new ShuffledRowRDDPartition(i)) + } + + override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { + val dep = dependencies.head.asInstanceOf[ShuffleDependency[Int, InternalRow, InternalRow]] + SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context) + .read() + .asInstanceOf[Iterator[Product2[Int, InternalRow]]] + .map(_._2) + } + + override def clearDependencies() { + super.clearDependencies() + prev = null + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 4d7d8626a0ecc..50c27def8ea54 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -17,20 +17,20 @@ package org.apache.spark.sql.execution -import org.apache.spark.annotation.DeveloperApi +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.Logging +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.{RDD, RDDOperationScope} import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, trees} +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical._ -import 
scala.collection.mutable.ArrayBuffer - object SparkPlan { protected[sql] val currentContext = new ThreadLocal[SQLContext]() } @@ -40,7 +40,6 @@ object SparkPlan { */ @DeveloperApi abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable { - self: Product => /** * A handle to the SQL Context that was used to create this plan. Since many operators need @@ -80,12 +79,36 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ /** Specifies sort order for each partition requirements on the input data for this operator. */ def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil) + /** Specifies whether this operator outputs UnsafeRows */ + def outputsUnsafeRows: Boolean = false + + /** Specifies whether this operator is capable of processing UnsafeRows */ + def canProcessUnsafeRows: Boolean = false + + /** + * Specifies whether this operator is capable of processing Java-object-based Rows (i.e. rows + * that are not UnsafeRows). + */ + def canProcessSafeRows: Boolean = true + /** * Returns the result of this query as an RDD[InternalRow] by delegating to doExecute * after adding query plan information to created RDDs for visualization. * Concrete implementations of SparkPlan should override doExecute instead. */ final def execute(): RDD[InternalRow] = { + if (children.nonEmpty) { + val hasUnsafeInputs = children.exists(_.outputsUnsafeRows) + val hasSafeInputs = children.exists(!_.outputsUnsafeRows) + assert(!(hasSafeInputs && hasUnsafeInputs), + "Child operators should output rows in the same format") + assert(canProcessSafeRows || canProcessUnsafeRows, + "Operator must be able to process at least one row format") + assert(!hasSafeInputs || canProcessSafeRows, + "Operator will receive safe rows as input but cannot process safe rows") + assert(!hasUnsafeInputs || canProcessUnsafeRows, + "Operator will receive unsafe rows as input but cannot process unsafe rows") + } RDDOperationScope.withScope(sparkContext, nodeName, false, true) { doExecute() } @@ -142,8 +165,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts) val sc = sqlContext.sparkContext val res = - sc.runJob(childRDD, (it: Iterator[InternalRow]) => it.take(left).toArray, p, - allowLocal = false) + sc.runJob(childRDD, (it: Iterator[InternalRow]) => it.take(left).toArray, p) res.foreach(buf ++= _.take(n - buf.size)) partsScanned += numPartsToTry @@ -238,15 +260,21 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ } } -private[sql] trait LeafNode extends SparkPlan with trees.LeafNode[SparkPlan] { - self: Product => +private[sql] trait LeafNode extends SparkPlan { + override def children: Seq[SparkPlan] = Nil } -private[sql] trait UnaryNode extends SparkPlan with trees.UnaryNode[SparkPlan] { - self: Product => +private[sql] trait UnaryNode extends SparkPlan { + def child: SparkPlan + + override def children: Seq[SparkPlan] = child :: Nil + override def outputPartitioning: Partitioning = child.outputPartitioning } -private[sql] trait BinaryNode extends SparkPlan with trees.BinaryNode[SparkPlan] { - self: Product => +private[sql] trait BinaryNode extends SparkPlan { + def left: SparkPlan + def right: SparkPlan + + override def children: Seq[SparkPlan] = Seq(left, right) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala index 6ed822dc70d68..c87e2064a8f33 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala @@ -45,14 +45,12 @@ import org.apache.spark.unsafe.types.UTF8String * the comment of the `serializer` method in [[Exchange]] for more information on it. */ private[sql] class Serializer2SerializationStream( - keySchema: Array[DataType], - valueSchema: Array[DataType], + rowSchema: Array[DataType], out: OutputStream) extends SerializationStream with Logging { private val rowOut = new DataOutputStream(new BufferedOutputStream(out)) - private val writeKeyFunc = SparkSqlSerializer2.createSerializationFunction(keySchema, rowOut) - private val writeValueFunc = SparkSqlSerializer2.createSerializationFunction(valueSchema, rowOut) + private val writeRowFunc = SparkSqlSerializer2.createSerializationFunction(rowSchema, rowOut) override def writeObject[T: ClassTag](t: T): SerializationStream = { val kv = t.asInstanceOf[Product2[Row, Row]] @@ -63,12 +61,12 @@ private[sql] class Serializer2SerializationStream( } override def writeKey[T: ClassTag](t: T): SerializationStream = { - writeKeyFunc(t.asInstanceOf[Row]) + // No-op. this } override def writeValue[T: ClassTag](t: T): SerializationStream = { - writeValueFunc(t.asInstanceOf[Row]) + writeRowFunc(t.asInstanceOf[Row]) this } @@ -85,8 +83,7 @@ private[sql] class Serializer2SerializationStream( * The corresponding deserialization stream for [[Serializer2SerializationStream]]. */ private[sql] class Serializer2DeserializationStream( - keySchema: Array[DataType], - valueSchema: Array[DataType], + rowSchema: Array[DataType], in: InputStream) extends DeserializationStream with Logging { @@ -103,22 +100,20 @@ private[sql] class Serializer2DeserializationStream( } // Functions used to return rows for key and value. - private val getKey = rowGenerator(keySchema) - private val getValue = rowGenerator(valueSchema) + private val getRow = rowGenerator(rowSchema) // Functions used to read a serialized row from the InputStream and deserialize it. - private val readKeyFunc = SparkSqlSerializer2.createDeserializationFunction(keySchema, rowIn) - private val readValueFunc = SparkSqlSerializer2.createDeserializationFunction(valueSchema, rowIn) + private val readRowFunc = SparkSqlSerializer2.createDeserializationFunction(rowSchema, rowIn) override def readObject[T: ClassTag](): T = { - (readKeyFunc(getKey()), readValueFunc(getValue())).asInstanceOf[T] + readValue() } override def readKey[T: ClassTag](): T = { - readKeyFunc(getKey()).asInstanceOf[T] + null.asInstanceOf[T] // intentionally left blank. 
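// Rows are now shuffled as values only: writeKey is a no-op on the serialization side, so there
// is no key to deserialize here.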
} override def readValue[T: ClassTag](): T = { - readValueFunc(getValue()).asInstanceOf[T] + readRowFunc(getRow()).asInstanceOf[T] } override def close(): Unit = { @@ -127,8 +122,7 @@ private[sql] class Serializer2DeserializationStream( } private[sql] class SparkSqlSerializer2Instance( - keySchema: Array[DataType], - valueSchema: Array[DataType]) + rowSchema: Array[DataType]) extends SerializerInstance { def serialize[T: ClassTag](t: T): ByteBuffer = @@ -141,30 +135,25 @@ private[sql] class SparkSqlSerializer2Instance( throw new UnsupportedOperationException("Not supported.") def serializeStream(s: OutputStream): SerializationStream = { - new Serializer2SerializationStream(keySchema, valueSchema, s) + new Serializer2SerializationStream(rowSchema, s) } def deserializeStream(s: InputStream): DeserializationStream = { - new Serializer2DeserializationStream(keySchema, valueSchema, s) + new Serializer2DeserializationStream(rowSchema, s) } } /** * SparkSqlSerializer2 is a special serializer that creates serialization function and * deserialization function based on the schema of data. It assumes that values passed in - * are key/value pairs and values returned from it are also key/value pairs. - * The schema of keys is represented by `keySchema` and that of values is represented by - * `valueSchema`. + * are Rows. */ -private[sql] class SparkSqlSerializer2( - keySchema: Array[DataType], - valueSchema: Array[DataType]) +private[sql] class SparkSqlSerializer2(rowSchema: Array[DataType]) extends Serializer with Logging with Serializable{ - def newInstance(): SerializerInstance = - new SparkSqlSerializer2Instance(keySchema, valueSchema) + def newInstance(): SerializerInstance = new SparkSqlSerializer2Instance(rowSchema) override def supportsRelocationOfSerializedObjects: Boolean = { // SparkSqlSerializer2 is stateless and writes no stream headers diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index ce25af58b6cab..f54aa2027f6a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -17,18 +17,19 @@ package org.apache.spark.sql.execution +import org.apache.spark.sql.{SQLContext, Strategy, execution} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression2} import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation} import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand} +import org.apache.spark.sql.execution.datasources.{CreateTableUsing, CreateTempTableUsing, DescribeCommand => LogicalDescribeCommand, _} import org.apache.spark.sql.parquet._ -import org.apache.spark.sql.sources.{CreateTableUsing, CreateTempTableUsing, DescribeCommand => LogicalDescribeCommand, _} import org.apache.spark.sql.types._ -import org.apache.spark.sql.{SQLContext, Strategy, execution} private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { self: SQLContext#SparkPlanner => @@ -38,14 +39,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { 
case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) if sqlContext.conf.autoBroadcastJoinThreshold > 0 && right.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold => - val semiJoin = joins.BroadcastLeftSemiJoinHash( - leftKeys, rightKeys, planLater(left), planLater(right)) - condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil + joins.BroadcastLeftSemiJoinHash( + leftKeys, rightKeys, planLater(left), planLater(right), condition) :: Nil // Find left semi joins where at least some predicates can be evaluated by matching join keys case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) => - val semiJoin = joins.LeftSemiJoinHash( - leftKeys, rightKeys, planLater(left), planLater(right)) - condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil + joins.LeftSemiJoinHash( + leftKeys, rightKeys, planLater(left), planLater(right), condition) :: Nil // no predicate can be evaluated by matching hash keys case logical.Join(left, right, LeftSemi, condition) => joins.LeftSemiJoinBNL(planLater(left), planLater(right), condition) :: Nil @@ -150,7 +149,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { if canBeCodeGened( allAggregates(partialComputation) ++ allAggregates(rewrittenAggregateExpressions)) && - codegenEnabled => + codegenEnabled && + !canBeConvertedToNewAggregation(plan) => execution.GeneratedAggregate( partial = false, namedGroupingAttributes, @@ -169,7 +169,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { rewrittenAggregateExpressions, groupingExpressions, partialComputation, - child) => + child) if !canBeConvertedToNewAggregation(plan) => execution.Aggregate( partial = false, namedGroupingAttributes, @@ -183,7 +183,14 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case _ => Nil } - def canBeCodeGened(aggs: Seq[AggregateExpression]): Boolean = !aggs.exists { + def canBeConvertedToNewAggregation(plan: LogicalPlan): Boolean = { + aggregate.Utils.tryConvert( + plan, + sqlContext.conf.useSqlAggregate2, + sqlContext.conf.codegenEnabled).isDefined + } + + def canBeCodeGened(aggs: Seq[AggregateExpression1]): Boolean = !aggs.exists { case _: CombineSum | _: Sum | _: Count | _: Max | _: Min | _: CombineSetsAndCount => false // The generated set implementation is pretty limited ATM. case CollectHashSet(exprs) if exprs.size == 1 && @@ -191,10 +198,74 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case _ => true } - def allAggregates(exprs: Seq[Expression]): Seq[AggregateExpression] = - exprs.flatMap(_.collect { case a: AggregateExpression => a }) + def allAggregates(exprs: Seq[Expression]): Seq[AggregateExpression1] = + exprs.flatMap(_.collect { case a: AggregateExpression1 => a }) } + /** + * Used to plan the aggregate operator for expressions based on the AggregateFunction2 interface. + */ + object Aggregation extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case p: logical.Aggregate => + val converted = + aggregate.Utils.tryConvert( + p, + sqlContext.conf.useSqlAggregate2, + sqlContext.conf.codegenEnabled) + converted match { + case None => Nil // Cannot convert to new aggregation code path. + case Some(logical.Aggregate(groupingExpressions, resultExpressions, child)) => + // Extracts all distinct aggregate expressions from the resultExpressions. 
+ val aggregateExpressions = resultExpressions.flatMap { expr => + expr.collect { + case agg: AggregateExpression2 => agg + } + }.toSet.toSeq + // For those distinct aggregate expressions, we create a map from the + // aggregate function to the corresponding attribute of the function. + val aggregateFunctionMap = aggregateExpressions.map { agg => + val aggregateFunction = agg.aggregateFunction + (aggregateFunction, agg.isDistinct) -> + Alias(aggregateFunction, aggregateFunction.toString)().toAttribute + }.toMap + + val (functionsWithDistinct, functionsWithoutDistinct) = + aggregateExpressions.partition(_.isDistinct) + if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) { + // This is a sanity check. We should not reach here when we have multiple distinct + // column sets (aggregate.NewAggregation will not match). + sys.error( + "Multiple distinct column sets are not supported by the new aggregation" + + "code path.") + } + + val aggregateOperator = + if (functionsWithDistinct.isEmpty) { + aggregate.Utils.planAggregateWithoutDistinct( + groupingExpressions, + aggregateExpressions, + aggregateFunctionMap, + resultExpressions, + planLater(child)) + } else { + aggregate.Utils.planAggregateWithOneDistinct( + groupingExpressions, + functionsWithDistinct, + functionsWithoutDistinct, + aggregateFunctionMap, + resultExpressions, + planLater(child)) + } + + aggregateOperator + } + + case _ => Nil + } + } + + object BroadcastNestedLoopJoin extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.Join(left, right, joinType, condition) => @@ -338,8 +409,21 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.Filter(condition, planLater(child)) :: Nil case e @ logical.Expand(_, _, _, child) => execution.Expand(e.projections, e.output, planLater(child)) :: Nil - case logical.Aggregate(group, agg, child) => - execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil + case a @ logical.Aggregate(group, agg, child) => { + val useNewAggregation = + aggregate.Utils.tryConvert( + a, + sqlContext.conf.useSqlAggregate2, + sqlContext.conf.codegenEnabled).isDefined + if (useNewAggregation) { + // If this logical.Aggregate can be planned to use new aggregation code path + // (i.e. it can be planned by the Strategy Aggregation), we will not use the old + // aggregation code path. 
+ Nil + } else { + execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil + } + } case logical.Window(projectList, windowExpressions, spec, child) => execution.Window(projectList, windowExpressions, spec, planLater(child)) :: Nil case logical.Sample(lb, ub, withReplacement, seed, child) => @@ -360,8 +444,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.OneRowRelation => execution.PhysicalRDD(Nil, singleRowRdd) :: Nil case logical.RepartitionByExpression(expressions, child) => - execution.Exchange( - HashPartitioning(expressions, numPartitions), Nil, planLater(child)) :: Nil + execution.Exchange(HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil case e @ EvaluatePython(udf, child, _) => BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SqlNewHadoopRDD.scala similarity index 99% rename from sql/core/src/main/scala/org/apache/spark/sql/sources/SqlNewHadoopRDD.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/SqlNewHadoopRDD.scala index 2bdc341021256..e1c1a6c06268f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/SqlNewHadoopRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SqlNewHadoopRDD.scala @@ -15,24 +15,23 @@ * limitations under the License. */ -package org.apache.spark.sql.sources +package org.apache.spark.sql.execution import java.text.SimpleDateFormat import java.util.Date +import org.apache.spark.{Partition => SparkPartition, _} import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, FileSplit} -import org.apache.spark.broadcast.Broadcast - -import org.apache.spark.{Partition => SparkPartition, _} import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.DataReadMethod import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil -import org.apache.spark.rdd.{RDD, HadoopRDD} import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD +import org.apache.spark.rdd.{HadoopRDD, RDD} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.{SerializableConfiguration, Utils} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala new file mode 100644 index 0000000000000..16498da080c88 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import java.io.{DataInputStream, DataOutputStream, OutputStream, InputStream} +import java.nio.ByteBuffer + +import scala.reflect.ClassTag + +import com.google.common.io.ByteStreams + +import org.apache.spark.serializer.{SerializationStream, DeserializationStream, SerializerInstance, Serializer} +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.unsafe.PlatformDependent + +/** + * Serializer for serializing [[UnsafeRow]]s during shuffle. Since UnsafeRows are already stored as + * bytes, this serializer simply copies those bytes to the underlying output stream. When + * deserializing a stream of rows, instances of this serializer mutate and return a single UnsafeRow + * instance that is backed by an on-heap byte array. + * + * Note that this serializer implements only the [[Serializer]] methods that are used during + * shuffle, so certain [[SerializerInstance]] methods will throw UnsupportedOperationException. + * + * @param numFields the number of fields in the row being serialized. + */ +private[sql] class UnsafeRowSerializer(numFields: Int) extends Serializer with Serializable { + override def newInstance(): SerializerInstance = new UnsafeRowSerializerInstance(numFields) + override private[spark] def supportsRelocationOfSerializedObjects: Boolean = true +} + +private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInstance { + + /** + * Marks the end of a stream written with [[serializeStream()]]. + */ + private[this] val EOF: Int = -1 + + /** + * Serializes a stream of UnsafeRows. Within the stream, each record consists of a record + * length (stored as a 4-byte integer, written high byte first), followed by the record's bytes. + * The end of the stream is denoted by a record with the special length `EOF` (-1). + */ + override def serializeStream(out: OutputStream): SerializationStream = new SerializationStream { + private[this] var writeBuffer: Array[Byte] = new Array[Byte](4096) + private[this] val dOut: DataOutputStream = new DataOutputStream(out) + + override def writeValue[T: ClassTag](value: T): SerializationStream = { + val row = value.asInstanceOf[UnsafeRow] + dOut.writeInt(row.getSizeInBytes) + row.writeToStream(out, writeBuffer) + this + } + + override def writeKey[T: ClassTag](key: T): SerializationStream = { + // The key is only needed on the map side when computing partition ids. It does not need to + // be shuffled. + assert(key.isInstanceOf[Int]) + this + } + + override def writeAll[T: ClassTag](iter: Iterator[T]): SerializationStream = { + // This method is never called by shuffle code. + throw new UnsupportedOperationException + } + + override def writeObject[T: ClassTag](t: T): SerializationStream = { + // This method is never called by shuffle code. 
+ throw new UnsupportedOperationException + } + + override def flush(): Unit = { + dOut.flush() + } + + override def close(): Unit = { + writeBuffer = null + dOut.writeInt(EOF) + dOut.close() + } + } + + override def deserializeStream(in: InputStream): DeserializationStream = { + new DeserializationStream { + private[this] val dIn: DataInputStream = new DataInputStream(in) + // 1024 is a default buffer size; this buffer will grow to accommodate larger rows + private[this] var rowBuffer: Array[Byte] = new Array[Byte](1024) + private[this] var row: UnsafeRow = new UnsafeRow() + private[this] var rowTuple: (Int, UnsafeRow) = (0, row) + + override def asKeyValueIterator: Iterator[(Int, UnsafeRow)] = { + new Iterator[(Int, UnsafeRow)] { + private[this] var rowSize: Int = dIn.readInt() + + override def hasNext: Boolean = rowSize != EOF + + override def next(): (Int, UnsafeRow) = { + if (rowBuffer.length < rowSize) { + rowBuffer = new Array[Byte](rowSize) + } + ByteStreams.readFully(in, rowBuffer, 0, rowSize) + row.pointTo(rowBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, rowSize) + rowSize = dIn.readInt() // read the next row's size + if (rowSize == EOF) { // We are returning the last row in this stream + val _rowTuple = rowTuple + // Null these out so that the byte array can be garbage collected once the entire + // iterator has been consumed + row = null + rowBuffer = null + rowTuple = null + _rowTuple + } else { + rowTuple + } + } + } + } + + override def asIterator: Iterator[Any] = { + // This method is never called by shuffle code. + throw new UnsupportedOperationException + } + + override def readKey[T: ClassTag](): T = { + // We skipped serialization of the key in writeKey(), so just return a dummy value since + // this is going to be discarded anyways. + null.asInstanceOf[T] + } + + override def readValue[T: ClassTag](): T = { + val rowSize = dIn.readInt() + if (rowBuffer.length < rowSize) { + rowBuffer = new Array[Byte](rowSize) + } + ByteStreams.readFully(in, rowBuffer, 0, rowSize) + row.pointTo(rowBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, rowSize) + row.asInstanceOf[T] + } + + override def readObject[T: ClassTag](): T = { + // This method is never called by shuffle code. + throw new UnsupportedOperationException + } + + override def close(): Unit = { + dIn.close() + } + } + } + + // These methods are never called by shuffle code. 
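The record format described above is straightforward to reproduce in isolation. The following self-contained sketch (plain JDK streams and hypothetical helper names, not Spark code) writes length-prefixed records terminated by the -1 sentinel and reads them back through one reusable buffer that only grows, mirroring the mutate-and-reuse pattern of the deserializer:

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

object LengthPrefixedFramingSketch {
  val EOF = -1

  // Each record is a 4-byte big-endian length followed by the raw bytes; the stream
  // ends with the sentinel length -1 (written by close() in the serializer above).
  def writeAll(records: Seq[Array[Byte]]): Array[Byte] = {
    val bytes = new ByteArrayOutputStream()
    val out = new DataOutputStream(bytes)
    records.foreach { r => out.writeInt(r.length); out.write(r) }
    out.writeInt(EOF)
    out.flush()
    bytes.toByteArray
  }

  def readAll(data: Array[Byte]): Seq[Array[Byte]] = {
    val in = new DataInputStream(new ByteArrayInputStream(data))
    // One reusable buffer that only grows, like rowBuffer in the deserializer.
    var buffer = new Array[Byte](16)
    val records = Seq.newBuilder[Array[Byte]]
    var size = in.readInt()
    while (size != EOF) {
      if (buffer.length < size) buffer = new Array[Byte](size)
      in.readFully(buffer, 0, size)
      records += java.util.Arrays.copyOf(buffer, size) // copy only for this demo
      size = in.readInt()
    }
    records.result()
  }

  def main(args: Array[String]): Unit = {
    val roundTripped = readAll(writeAll(Seq("ab".getBytes, "cde".getBytes)))
    println(roundTripped.map(new String(_)).mkString(", ")) // ab, cde
  }
}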
+ override def serialize[T: ClassTag](t: T): ByteBuffer = throw new UnsupportedOperationException + override def deserialize[T: ClassTag](bytes: ByteBuffer): T = + throw new UnsupportedOperationException + override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = + throw new UnsupportedOperationException +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index 6e127e548a120..de04132eb1104 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -19,18 +19,64 @@ package org.apache.spark.sql.execution import java.util -import org.apache.spark.rdd.RDD +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.types.IntegerType +import org.apache.spark.rdd.RDD import org.apache.spark.util.collection.CompactBuffer +import scala.collection.mutable /** * :: DeveloperApi :: - * For every row, evaluates `windowExpression` containing Window Functions and attaches - * the results with other regular expressions (presented by `projectList`). - * Evert operator handles a single Window Specification, `windowSpec`. + * This class calculates and outputs (windowed) aggregates over the rows in a single (sorted) + * partition. The aggregates are calculated for each row in the group. Special processing + * instructions, frames, are used to calculate these aggregates. Frames are processed in the order + * specified in the window specification (the ORDER BY ... clause). There are four different frame + * types: + * - Entire partition: The frame is the entire partition, i.e. + * UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING. For this case, window function will take all + * rows as inputs and be evaluated once. + * - Growing frame: We only add new rows into the frame, i.e. UNBOUNDED PRECEDING AND .... + * Every time we move to a new row to process, we add some rows to the frame. We do not remove + * rows from this frame. + * - Shrinking frame: We only remove rows from the frame, i.e. ... AND UNBOUNDED FOLLOWING. + * Every time we move to a new row to process, we remove some rows from the frame. We do not add + * rows to this frame. + * - Moving frame: Every time we move to a new row to process, we remove some rows from the frame + * and we add some rows to the frame. Examples are: + * 1 PRECEDING AND CURRENT ROW and 1 FOLLOWING AND 2 FOLLOWING. + * + * Different frame boundaries can be used in Growing, Shrinking and Moving frames. A frame + * boundary can be either Row or Range based: + * - Row Based: A row based boundary is based on the position of the row within the partition. + * An offset indicates the number of rows above or below the current row, the frame for the + * current row starts or ends. For instance, given a row based sliding frame with a lower bound + * offset of -1 and a upper bound offset of +2. The frame for row with index 5 would range from + * index 4 to index 6. + * - Range based: A range based boundary is based on the actual value of the ORDER BY + * expression(s). 
An offset is used to alter the value of the ORDER BY expression, for + * instance if the current order by expression has a value of 10 and the lower bound offset + * is -3, the resulting lower bound for the current row will be 10 - 3 = 7. This however puts a + * number of constraints on the ORDER BY expressions: there can be only one expression and this + * expression must have a numerical data type. An exception can be made when the offset is 0, + * because no value modification is needed, in this case multiple and non-numeric ORDER BY + * expression are allowed. + * + * This is quite an expensive operator because every row for a single group must be in the same + * partition and partitions must be sorted according to the grouping and sort order. The operator + * requires the planner to take care of the partitioning and sorting. + * + * The operator is semi-blocking. The window functions and aggregates are calculated one group at + * a time, the result will only be made available after the processing for the entire group has + * finished. The operator is able to process different frame configurations at the same time. This + * is done by delegating the actual frame processing (i.e. calculation of the window functions) to + * specialized classes, see [[WindowFunctionFrame]], which take care of their own frame type: + * Entire Partition, Sliding, Growing & Shrinking. Boundary evaluation is also delegated to a pair + * of specialized classes: [[RowBoundOrdering]] & [[RangeBoundOrdering]]. */ +@DeveloperApi case class Window( projectList: Seq[Attribute], windowExpression: Seq[NamedExpression], @@ -38,443 +84,669 @@ case class Window( child: SparkPlan) extends UnaryNode { - override def output: Seq[Attribute] = - (projectList ++ windowExpression).map(_.toAttribute) + override def output: Seq[Attribute] = projectList ++ windowExpression.map(_.toAttribute) - override def requiredChildDistribution: Seq[Distribution] = + override def requiredChildDistribution: Seq[Distribution] = { if (windowSpec.partitionSpec.isEmpty) { - // This operator will be very expensive. + // Only show warning when the number of bytes is larger than 100 MB? + logWarning("No Partition Defined for Window operation! Moving all data to a single " + + "partition, this can cause serious performance degradation.") AllTuples :: Nil - } else { - ClusteredDistribution(windowSpec.partitionSpec) :: Nil - } - - // Since window functions are adding columns to the input rows, the child's outputPartitioning - // is preserved. - override def outputPartitioning: Partitioning = child.outputPartitioning - - override def requiredChildOrdering: Seq[Seq[SortOrder]] = { - // The required child ordering has two parts. - // The first part is the expressions in the partition specification. - // We add these expressions to the required ordering to make sure input rows are grouped - // based on the partition specification. So, we only need to process a single partition - // at a time. - // The second part is the expressions specified in the ORDER BY cluase. - // Basically, we first use sort to group rows based on partition specifications and then sort - // Rows in a group based on the order specification. - (windowSpec.partitionSpec.map(SortOrder(_, Ascending)) ++ windowSpec.orderSpec) :: Nil + } else ClusteredDistribution(windowSpec.partitionSpec) :: Nil } - // Since window functions basically add columns to input rows, this operator - // will not change the ordering of input rows. 
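The ROWS/RANGE distinction described above can be illustrated with a small standalone sketch (plain Scala, illustrative names only): a row-based frame is resolved purely from positions, while a range-based frame is resolved from the ORDER BY values of a sorted partition.

object FrameBoundariesSketch {
  // ROWS BETWEEN 1 PRECEDING AND 2 FOLLOWING for the row at `index`, clipped to the partition.
  def rowFrame(size: Int, index: Int): Range =
    math.max(0, index - 1) to math.min(size - 1, index + 2)

  // RANGE BETWEEN 3 PRECEDING AND CURRENT ROW over sorted ORDER BY values.
  def rangeFrame(sortedValues: IndexedSeq[Int], index: Int): Range = {
    val current = sortedValues(index)
    val low = sortedValues.indexWhere(_ >= current - 3)
    val high = sortedValues.lastIndexWhere(_ <= current)
    low to high
  }

  def main(args: Array[String]): Unit = {
    val values = IndexedSeq(1, 4, 7, 10, 10, 15)
    println(rowFrame(values.size, 5)) // rows 4..5 (the +2 offset is clipped at the partition end)
    println(rangeFrame(values, 3))    // rows 2..4, the indices whose value lies in [7, 10]
  }
}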
+ override def requiredChildOrdering: Seq[Seq[SortOrder]] = + Seq(windowSpec.partitionSpec.map(SortOrder(_, Ascending)) ++ windowSpec.orderSpec) + override def outputOrdering: Seq[SortOrder] = child.outputOrdering - case class ComputedWindow( - unbound: WindowExpression, - windowFunction: WindowFunction, - resultAttribute: AttributeReference) - - // A list of window functions that need to be computed for each group. - private[this] val computedWindowExpressions = windowExpression.flatMap { window => - window.collect { - case w: WindowExpression => - ComputedWindow( - w, - BindReferences.bindReference(w.windowFunction, child.output), - AttributeReference(s"windowResult:$w", w.dataType, w.nullable)()) + /** + * Create a bound ordering object for a given frame type and offset. A bound ordering object is + * used to determine which input row lies within the frame boundaries of an output row. + * + * This method uses Code Generation. It can only be used on the executor side. + * + * @param frameType to evaluate. This can either be Row or Range based. + * @param offset with respect to the row. + * @return a bound ordering object. + */ + private[this] def createBoundOrdering(frameType: FrameType, offset: Int): BoundOrdering = { + frameType match { + case RangeFrame => + val (exprs, current, bound) = if (offset == 0) { + // Use the entire order expression when the offset is 0. + val exprs = windowSpec.orderSpec.map(_.child) + val projection = newMutableProjection(exprs, child.output) + (windowSpec.orderSpec, projection(), projection()) + } else if (windowSpec.orderSpec.size == 1) { + // Use only the first order expression when the offset is non-null. + val sortExpr = windowSpec.orderSpec.head + val expr = sortExpr.child + // Create the projection which returns the current 'value'. + val current = newMutableProjection(expr :: Nil, child.output)() + // Flip the sign of the offset when processing the order is descending + val boundOffset = + if (sortExpr.direction == Descending) { + -offset + } else { + offset + } + // Create the projection which returns the current 'value' modified by adding the offset. + val boundExpr = Add(expr, Cast(Literal.create(boundOffset, IntegerType), expr.dataType)) + val bound = newMutableProjection(boundExpr :: Nil, child.output)() + (sortExpr :: Nil, current, bound) + } else { + sys.error("Non-Zero range offsets are not supported for windows " + + "with multiple order expressions.") + } + // Construct the ordering. This is used to compare the result of current value projection + // to the result of bound value projection. This is done manually because we want to use + // Code Generation (if it is enabled). + val (sortExprs, schema) = exprs.map { case e => + val ref = AttributeReference("ordExpr", e.dataType, e.nullable)() + (SortOrder(ref, e.direction), ref) + }.unzip + val ordering = newOrdering(sortExprs, schema) + RangeBoundOrdering(ordering, current, bound) + case RowFrame => RowBoundOrdering(offset) } - }.toArray + } - private[this] val windowFrame = - windowSpec.frameSpecification.asInstanceOf[SpecifiedWindowFrame] + /** + * Create a frame processor. + * + * This method uses Code Generation. It can only be used on the executor side. + * + * @param frame boundaries. + * @param functions to process in the frame. + * @param ordinal at which the processor starts writing to the output. + * @return a frame processor. 
+ */ + private[this] def createFrameProcessor( + frame: WindowFrame, + functions: Array[WindowFunction], + ordinal: Int): WindowFunctionFrame = frame match { + // Growing Frame. + case SpecifiedWindowFrame(frameType, UnboundedPreceding, FrameBoundaryExtractor(high)) => + val uBoundOrdering = createBoundOrdering(frameType, high) + new UnboundedPrecedingWindowFunctionFrame(ordinal, functions, uBoundOrdering) + + // Shrinking Frame. + case SpecifiedWindowFrame(frameType, FrameBoundaryExtractor(low), UnboundedFollowing) => + val lBoundOrdering = createBoundOrdering(frameType, low) + new UnboundedFollowingWindowFunctionFrame(ordinal, functions, lBoundOrdering) + + // Moving Frame. + case SpecifiedWindowFrame(frameType, + FrameBoundaryExtractor(low), FrameBoundaryExtractor(high)) => + val lBoundOrdering = createBoundOrdering(frameType, low) + val uBoundOrdering = createBoundOrdering(frameType, high) + new SlidingWindowFunctionFrame(ordinal, functions, lBoundOrdering, uBoundOrdering) + + // Entire Partition Frame. + case SpecifiedWindowFrame(_, UnboundedPreceding, UnboundedFollowing) => + new UnboundedWindowFunctionFrame(ordinal, functions) + + // Error + case fr => + sys.error(s"Unsupported Frame $fr for functions: $functions") + } - // Create window functions. - private[this] def windowFunctions(): Array[WindowFunction] = { - val functions = new Array[WindowFunction](computedWindowExpressions.length) - var i = 0 - while (i < computedWindowExpressions.length) { - functions(i) = computedWindowExpressions(i).windowFunction.newInstance() - functions(i).init() - i += 1 + /** + * Create the resulting projection. + * + * This method uses Code Generation. It can only be used on the executor side. + * + * @param expressions unbound ordered function expressions. + * @return the final resulting projection. + */ + private[this] def createResultProjection( + expressions: Seq[Expression]): MutableProjection = { + val unboundToAttr = expressions.map { + e => (e, AttributeReference("windowResult", e.dataType, e.nullable)()) } - functions + val unboundToAttrMap = unboundToAttr.toMap + val patchedWindowExpression = windowExpression.map(_.transform(unboundToAttrMap)) + newMutableProjection( + projectList ++ patchedWindowExpression, + child.output ++ unboundToAttr.map(_._2))() } - // The schema of the result of all window function evaluations - private[this] val computedSchema = computedWindowExpressions.map(_.resultAttribute) - - private[this] val computedResultMap = - computedWindowExpressions.map { w => w.unbound -> w.resultAttribute }.toMap + protected override def doExecute(): RDD[InternalRow] = { + // Prepare processing. + // Group the window expression by their processing frame. + val windowExprs = windowExpression.flatMap { + _.collect { + case e: WindowExpression => e + } + } - private[this] val windowExpressionResult = windowExpression.map { window => - window.transform { - case w: WindowExpression if computedResultMap.contains(w) => computedResultMap(w) + // Create Frame processor factories and order the unbound window expressions by the frame they + // are processed in; this is the order in which their results will be written to window + // function result buffer. + val framedWindowExprs = windowExprs.groupBy(_.windowSpec.frameSpecification) + val factories = Array.ofDim[() => WindowFunctionFrame](framedWindowExprs.size) + val unboundExpressions = mutable.Buffer.empty[Expression] + framedWindowExprs.zipWithIndex.foreach { + case ((frame, unboundFrameExpressions), index) => + // Track the ordinal. 
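The dispatch in createFrameProcessor reduces to a four-way match on the frame boundaries. A toy version (simplified Bound ADT rather than Catalyst's frame-boundary hierarchy) that returns the names of the frame classes defined later in this file:

sealed trait Bound
case object UnboundedPrecedingBound extends Bound
case object UnboundedFollowingBound extends Bound
case class Offset(n: Int) extends Bound

object FrameDispatchSketch {
  def processorFor(lower: Bound, upper: Bound): String = (lower, upper) match {
    case (UnboundedPrecedingBound, UnboundedFollowingBound) => "UnboundedWindowFunctionFrame"          // entire partition
    case (UnboundedPrecedingBound, Offset(_))               => "UnboundedPrecedingWindowFunctionFrame" // growing frame
    case (Offset(_), UnboundedFollowingBound)               => "UnboundedFollowingWindowFunctionFrame" // shrinking frame
    case (Offset(_), Offset(_))                             => "SlidingWindowFunctionFrame"            // moving frame
    case other                                              => sys.error(s"Unsupported frame $other")
  }

  def main(args: Array[String]): Unit = {
    println(processorFor(Offset(-1), Offset(1))) // SlidingWindowFunctionFrame
  }
}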
+ val ordinal = unboundExpressions.size + + // Track the unbound expressions + unboundExpressions ++= unboundFrameExpressions + + // Bind the expressions. + val functions = unboundFrameExpressions.map { e => + BindReferences.bindReference(e.windowFunction, child.output) + }.toArray + + // Create the frame processor factory. + factories(index) = () => createFrameProcessor(frame, functions, ordinal) } - } - protected override def doExecute(): RDD[InternalRow] = { - child.execute().mapPartitions { iter => + // Start processing. + child.execute().mapPartitions { stream => new Iterator[InternalRow] { - // Although input rows are grouped based on windowSpec.partitionSpec, we need to - // know when we have a new partition. - // This is to manually construct an ordering that can be used to compare rows. - // TODO: We may want to have a newOrdering that takes BoundReferences. - // So, we can take advantave of code gen. - private val partitionOrdering: Ordering[InternalRow] = - RowOrdering.forSchema(windowSpec.partitionSpec.map(_.dataType)) - - // This is used to project expressions for the partition specification. - protected val partitionGenerator = - newMutableProjection(windowSpec.partitionSpec, child.output)() - - // This is ued to project expressions for the order specification. - protected val rowOrderGenerator = - newMutableProjection(windowSpec.orderSpec.map(_.child), child.output)() - - // The position of next output row in the inputRowBuffer. - var rowPosition: Int = 0 - // The number of buffered rows in the inputRowBuffer (the size of the current partition). - var partitionSize: Int = 0 - // The buffer used to buffer rows in a partition. - var inputRowBuffer: CompactBuffer[InternalRow] = _ - // The partition key of the current partition. - var currentPartitionKey: InternalRow = _ - // The partition key of next partition. - var nextPartitionKey: InternalRow = _ - // The first row of next partition. - var firstRowInNextPartition: InternalRow = _ - // Indicates if this partition is the last one in the iter. - var lastPartition: Boolean = false - - def createBoundaryEvaluator(): () => Unit = { - def findPhysicalBoundary( - boundary: FrameBoundary): () => Int = boundary match { - case UnboundedPreceding => () => 0 - case UnboundedFollowing => () => partitionSize - 1 - case CurrentRow => () => rowPosition - case ValuePreceding(value) => - () => - val newPosition = rowPosition - value - if (newPosition > 0) newPosition else 0 - case ValueFollowing(value) => - () => - val newPosition = rowPosition + value - if (newPosition < partitionSize) newPosition else partitionSize - 1 + // Get all relevant projections. + val result = createResultProjection(unboundExpressions) + val grouping = newProjection(windowSpec.partitionSpec, child.output) + + // Manage the stream and the grouping. + var nextRow: InternalRow = EmptyRow + var nextGroup: InternalRow = EmptyRow + var nextRowAvailable: Boolean = false + private[this] def fetchNextRow() { + nextRowAvailable = stream.hasNext + if (nextRowAvailable) { + nextRow = stream.next() + nextGroup = grouping(nextRow) + } else { + nextRow = EmptyRow + nextGroup = EmptyRow } - - def findLogicalBoundary( - boundary: FrameBoundary, - searchDirection: Int, - evaluator: Expression, - joinedRow: JoinedRow): () => Int = boundary match { - case UnboundedPreceding => () => 0 - case UnboundedFollowing => () => partitionSize - 1 - case other => - () => { - // CurrentRow, ValuePreceding, or ValueFollowing. 
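The partition handling in doExecute follows a common look-ahead pattern: buffer one row, compute its grouping key, and collect rows until the key changes, keeping the first row of the next group around. A self-contained sketch of that pattern over an ordinary sorted iterator (generic key/value pairs instead of InternalRows):

class GroupedIterator[K, V](stream: Iterator[(K, V)]) extends Iterator[(K, Seq[V])] {
  private var nextRowAvailable: Boolean = false
  private var nextRow: (K, V) = _

  private def fetchNextRow(): Unit = {
    nextRowAvailable = stream.hasNext
    if (nextRowAvailable) nextRow = stream.next()
  }
  fetchNextRow()

  override def hasNext: Boolean = nextRowAvailable

  override def next(): (K, Seq[V]) = {
    val currentKey = nextRow._1
    val rows = Seq.newBuilder[V]
    // Consume rows until the grouping key changes; the first row of the next group
    // stays buffered in nextRow, like firstRowInNextPartition above.
    while (nextRowAvailable && nextRow._1 == currentKey) {
      rows += nextRow._2
      fetchNextRow()
    }
    (currentKey, rows.result())
  }
}

object GroupedIteratorSketch {
  def main(args: Array[String]): Unit = {
    val sorted = Iterator("a" -> 1, "a" -> 2, "b" -> 3)
    new GroupedIterator(sorted).foreach(println) // one (key, rows) pair per group
  }
}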
- var newPosition = rowPosition + searchDirection - var stopSearch = false - // rowOrderGenerator is a mutable projection. - // We need to make a copy of the returned by rowOrderGenerator since we will - // compare searched row with this currentOrderByValue. - val currentOrderByValue = rowOrderGenerator(inputRowBuffer(rowPosition)).copy() - while (newPosition >= 0 && newPosition < partitionSize && !stopSearch) { - val r = rowOrderGenerator(inputRowBuffer(newPosition)) - stopSearch = - !(evaluator.eval(joinedRow(currentOrderByValue, r)).asInstanceOf[Boolean]) - if (!stopSearch) { - newPosition += searchDirection - } - } - newPosition -= searchDirection - - if (newPosition < 0) { - 0 - } else if (newPosition >= partitionSize) { - partitionSize - 1 - } else { - newPosition - } - } + } + fetchNextRow() + + // Manage the current partition. + var rows: CompactBuffer[InternalRow] = _ + val frames: Array[WindowFunctionFrame] = factories.map(_()) + val numFrames = frames.length + private[this] def fetchNextPartition() { + // Collect all the rows in the current partition. + val currentGroup = nextGroup + rows = new CompactBuffer + while (nextRowAvailable && nextGroup == currentGroup) { + rows += nextRow.copy() + fetchNextRow() } - windowFrame.frameType match { - case RowFrame => - val findStart = findPhysicalBoundary(windowFrame.frameStart) - val findEnd = findPhysicalBoundary(windowFrame.frameEnd) - () => { - frameStart = findStart() - frameEnd = findEnd() - } - case RangeFrame => - val joinedRowForBoundaryEvaluation: JoinedRow = new JoinedRow() - val orderByExpr = windowSpec.orderSpec.head - val currentRowExpr = - BoundReference(0, orderByExpr.dataType, orderByExpr.nullable) - val examedRowExpr = - BoundReference(1, orderByExpr.dataType, orderByExpr.nullable) - val differenceExpr = Abs(Subtract(currentRowExpr, examedRowExpr)) - - val frameStartEvaluator = windowFrame.frameStart match { - case CurrentRow => EqualTo(currentRowExpr, examedRowExpr) - case ValuePreceding(value) => - LessThanOrEqual(differenceExpr, Cast(Literal(value), orderByExpr.dataType)) - case ValueFollowing(value) => - GreaterThanOrEqual(differenceExpr, Cast(Literal(value), orderByExpr.dataType)) - case o => Literal(true) // This is just a dummy expression, we will not use it. - } - - val frameEndEvaluator = windowFrame.frameEnd match { - case CurrentRow => EqualTo(currentRowExpr, examedRowExpr) - case ValuePreceding(value) => - GreaterThanOrEqual(differenceExpr, Cast(Literal(value), orderByExpr.dataType)) - case ValueFollowing(value) => - LessThanOrEqual(differenceExpr, Cast(Literal(value), orderByExpr.dataType)) - case o => Literal(true) // This is just a dummy expression, we will not use it. - } - - val findStart = - findLogicalBoundary( - boundary = windowFrame.frameStart, - searchDirection = -1, - evaluator = frameStartEvaluator, - joinedRow = joinedRowForBoundaryEvaluation) - val findEnd = - findLogicalBoundary( - boundary = windowFrame.frameEnd, - searchDirection = 1, - evaluator = frameEndEvaluator, - joinedRow = joinedRowForBoundaryEvaluation) - () => { - frameStart = findStart() - frameEnd = findEnd() - } + // Setup the frames. + var i = 0 + while (i < numFrames) { + frames(i).prepare(rows) + i += 1 } + + // Setup iteration + rowIndex = 0 + rowsSize = rows.size } - val boundaryEvaluator = createBoundaryEvaluator() - // Indicates if we the specified window frame requires us to maintain a sliding frame - // (e.g. RANGES BETWEEN 1 PRECEDING AND CURRENT ROW) or the window frame - // is the entire partition (e.g. 
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING). - val requireUpdateFrame: Boolean = { - def requireUpdateBoundary(boundary: FrameBoundary): Boolean = boundary match { - case UnboundedPreceding => false - case UnboundedFollowing => false - case _ => true - } + // Iteration + var rowIndex = 0 + var rowsSize = 0 + override final def hasNext: Boolean = rowIndex < rowsSize || nextRowAvailable - requireUpdateBoundary(windowFrame.frameStart) || - requireUpdateBoundary(windowFrame.frameEnd) - } - // The start position of the current frame in the partition. - var frameStart: Int = 0 - // The end position of the current frame in the partition. - var frameEnd: Int = -1 - // Window functions. - val functions: Array[WindowFunction] = windowFunctions() - // Buffers used to store input parameters for window functions. Because we may need to - // maintain a sliding frame, we use this buffer to avoid evaluate the parameters from - // the same row multiple times. - val windowFunctionParameterBuffers: Array[util.LinkedList[AnyRef]] = - functions.map(_ => new util.LinkedList[AnyRef]()) - - // The projection used to generate the final result rows of this operator. - private[this] val resultProjection = - newMutableProjection( - projectList ++ windowExpressionResult, - projectList ++ computedSchema)() - - // The row used to hold results of window functions. - private[this] val windowExpressionResultRow = - new GenericMutableRow(computedSchema.length) - - private[this] val joinedRow = new JoinedRow6 - - // Initialize this iterator. - initialize() - - private def initialize(): Unit = { - if (iter.hasNext) { - val currentRow = iter.next().copy() - // partitionGenerator is a mutable projection. Since we need to track nextPartitionKey, - // we are making a copy of the returned partitionKey at here. - nextPartitionKey = partitionGenerator(currentRow).copy() - firstRowInNextPartition = currentRow + val join = new JoinedRow6 + val windowFunctionResult = new GenericMutableRow(unboundExpressions.size) + override final def next(): InternalRow = { + // Load the next partition if we need to. + if (rowIndex >= rowsSize && nextRowAvailable) { fetchNextPartition() - } else { - // The iter is an empty one. So, we set all of the following variables - // to make sure hasNext will return false. - lastPartition = true - rowPosition = 0 - partitionSize = 0 } - } - - // Indicates if we will have new output row. - override final def hasNext: Boolean = { - !lastPartition || (rowPosition < partitionSize) - } - override final def next(): InternalRow = { - if (hasNext) { - if (rowPosition == partitionSize) { - // All rows of this buffer have been consumed. - // We will move to next partition. - fetchNextPartition() - } - // Get the input row for the current output row. - val inputRow = inputRowBuffer(rowPosition) - // Get all results of the window functions for this output row. + if (rowIndex < rowsSize) { + // Get the results for the window frames. var i = 0 - while (i < functions.length) { - windowExpressionResultRow.update(i, functions(i).get(rowPosition)) + while (i < numFrames) { + frames(i).write(windowFunctionResult) i += 1 } - // Construct the output row. - val outputRow = resultProjection(joinedRow(inputRow, windowExpressionResultRow)) - // We will move to the next one. - rowPosition += 1 - if (requireUpdateFrame && rowPosition < partitionSize) { - // If we need to maintain a sliding frame and - // we will still work on this partition when next is called next time, do the update. 
- updateFrame() - } + // 'Merge' the input row with the window function result + join(rows(rowIndex), windowFunctionResult) + rowIndex += 1 - // Return the output row. - outputRow - } else { - // no more result - throw new NoSuchElementException - } + // Return the projection. + result(join) + } else throw new NoSuchElementException } + } + } + } +} - // Fetch the next partition. - private def fetchNextPartition(): Unit = { - // Create a new buffer for input rows. - inputRowBuffer = new CompactBuffer[InternalRow]() - // We already have the first row for this partition - // (recorded in firstRowInNextPartition). Add it back. - inputRowBuffer += firstRowInNextPartition - // Set the current partition key. - currentPartitionKey = nextPartitionKey - // Now, we will start to find all rows belonging to this partition. - // Create a variable to track if we see the next partition. - var findNextPartition = false - // The search will stop when we see the next partition or there is no - // input row left in the iter. - while (iter.hasNext && !findNextPartition) { - // Make a copy of the input row since we will put it in the buffer. - val currentRow = iter.next().copy() - // Get the partition key based on the partition specification. - // For the below compare method, we do not need to make a copy of partitionKey. - val partitionKey = partitionGenerator(currentRow) - // Check if the current row belongs the current input row. - val comparing = partitionOrdering.compare(currentPartitionKey, partitionKey) - if (comparing == 0) { - // This row is still in the current partition. - inputRowBuffer += currentRow - } else { - // The current input row is in a different partition. - findNextPartition = true - // partitionGenerator is a mutable projection. - // Since we need to track nextPartitionKey and we determine that it should be set - // as partitionKey, we are making a copy of the partitionKey at here. - nextPartitionKey = partitionKey.copy() - firstRowInNextPartition = currentRow - } - } +/** + * Function for comparing boundary values. + */ +private[execution] abstract class BoundOrdering { + def compare(input: Seq[InternalRow], inputIndex: Int, outputIndex: Int): Int +} - // We have not seen a new partition. It means that there is no new row in the - // iter. The current partition is the last partition of the iter. - if (!findNextPartition) { - lastPartition = true - } +/** + * Compare the input index to the bound of the output index. + */ +private[execution] final case class RowBoundOrdering(offset: Int) extends BoundOrdering { + override def compare(input: Seq[InternalRow], inputIndex: Int, outputIndex: Int): Int = + inputIndex - (outputIndex + offset) +} - // We have got all rows for the current partition. - // Set rowPosition to 0 (the next output row will be based on the first - // input row of this partition). - rowPosition = 0 - // The size of this partition. - partitionSize = inputRowBuffer.size - // Reset all parameter buffers of window functions. - var i = 0 - while (i < windowFunctionParameterBuffers.length) { - windowFunctionParameterBuffers(i).clear() - i += 1 - } - frameStart = 0 - frameEnd = -1 - // Create the first window frame for this partition. - // If we do not need to maintain a sliding frame, this frame will - // have the entire partition. - updateFrame() - } +/** + * Compare the value of the input index to the value bound of the output index. 
+ */ +private[execution] final case class RangeBoundOrdering( + ordering: Ordering[InternalRow], + current: Projection, + bound: Projection) extends BoundOrdering { + override def compare(input: Seq[InternalRow], inputIndex: Int, outputIndex: Int): Int = + ordering.compare(current(input(inputIndex)), bound(input(outputIndex))) +} - /** The function used to maintain the sliding frame. */ - private def updateFrame(): Unit = { - // Based on the difference between the new frame and old frame, - // updates the buffers holding input parameters of window functions. - // We will start to prepare input parameters starting from the row - // indicated by offset in the input row buffer. - def updateWindowFunctionParameterBuffers( - numToRemove: Int, - numToAdd: Int, - offset: Int): Unit = { - // First, remove unneeded entries from the head of every buffer. - var i = 0 - while (i < numToRemove) { - var j = 0 - while (j < windowFunctionParameterBuffers.length) { - windowFunctionParameterBuffers(j).remove() - j += 1 - } - i += 1 - } - // Then, add needed entries to the tail of every buffer. - i = 0 - while (i < numToAdd) { - var j = 0 - while (j < windowFunctionParameterBuffers.length) { - // Ask the function to prepare the input parameters. - val parameters = functions(j).prepareInputParameters(inputRowBuffer(i + offset)) - windowFunctionParameterBuffers(j).add(parameters) - j += 1 - } - i += 1 - } - } +/** + * A window function calculates the results of a number of window functions for a window frame. + * Before use a frame must be prepared by passing it all the rows in the current partition. After + * preparation the update method can be called to fill the output rows. + * + * TODO How to improve performance? A few thoughts: + * - Window functions are expensive due to its distribution and ordering requirements. + * Unfortunately it is up to the Spark engine to solve this. Improvements in the form of project + * Tungsten are on the way. + * - The window frame processing bit can be improved though. But before we start doing that we + * need to see how much of the time and resources are spent on partitioning and ordering, and + * how much time and resources are spent processing the partitions. There are a couple ways to + * improve on the current situation: + * - Reduce memory footprint by performing streaming calculations. This can only be done when + * there are no Unbound/Unbounded Following calculations present. + * - Use Tungsten style memory usage. + * - Use code generation in general, and use the approach to aggregation taken in the + * GeneratedAggregate class in specific. + * + * @param ordinal of the first column written by this frame. + * @param functions to calculate the row values with. + */ +private[execution] abstract class WindowFunctionFrame( + ordinal: Int, + functions: Array[WindowFunction]) { + + // Make sure functions are initialized. + functions.foreach(_.init()) + + /** Number of columns the window function frame is managing */ + val numColumns = functions.length + + /** + * Create a fresh thread safe copy of the frame. + * + * @return the copied frame. + */ + def copy: WindowFunctionFrame + + /** + * Create new instances of the functions. + * + * @return an array containing copies of the current window functions. + */ + protected final def copyFunctions: Array[WindowFunction] = functions.map(_.newInstance()) + + /** + * Prepare the frame for calculating the results for a partition. + * + * @param rows to calculate the frame results for. 
+ */ + def prepare(rows: CompactBuffer[InternalRow]): Unit + + /** + * Write the result for the current row to the given target row. + * + * @param target row to write the result for the current row to. + */ + def write(target: GenericMutableRow): Unit + + /** Reset the current window functions. */ + protected final def reset(): Unit = { + var i = 0 + while (i < numColumns) { + functions(i).reset() + i += 1 + } + } - // Record the current frame start point and end point before - // we update them. - val previousFrameStart = frameStart - val previousFrameEnd = frameEnd - boundaryEvaluator() - updateWindowFunctionParameterBuffers( - frameStart - previousFrameStart, - frameEnd - previousFrameEnd, - previousFrameEnd + 1) - // Evaluate the current frame. - evaluateCurrentFrame() - } + /** Prepare an input row for processing. */ + protected final def prepare(input: InternalRow): Array[AnyRef] = { + val prepared = new Array[AnyRef](numColumns) + var i = 0 + while (i < numColumns) { + prepared(i) = functions(i).prepareInputParameters(input) + i += 1 + } + prepared + } - /** Evaluate the current window frame. */ - private def evaluateCurrentFrame(): Unit = { - var i = 0 - while (i < functions.length) { - // Reset the state of the window function. - functions(i).reset() - // Get all buffered input parameters based on rows of this window frame. - val inputParameters = windowFunctionParameterBuffers(i).toArray() - // Send these input parameters to the window function. - functions(i).batchUpdate(inputParameters) - // Ask the function to evaluate based on this window frame. - functions(i).evaluate() - i += 1 - } - } + /** Evaluate a prepared buffer (iterator). */ + protected final def evaluatePrepared(iterator: java.util.Iterator[Array[AnyRef]]): Unit = { + reset() + while (iterator.hasNext) { + val prepared = iterator.next() + var i = 0 + while (i < numColumns) { + functions(i).update(prepared(i)) + i += 1 + } + } + evaluate() + } + + /** Evaluate a prepared buffer (array). */ + protected final def evaluatePrepared(prepared: Array[Array[AnyRef]], + fromIndex: Int, toIndex: Int): Unit = { + var i = 0 + while (i < numColumns) { + val function = functions(i) + function.reset() + var j = fromIndex + while (j < toIndex) { + function.update(prepared(j)(i)) + j += 1 } + function.evaluate() + i += 1 + } + } + + /** Update an array of window functions. */ + protected final def update(input: InternalRow): Unit = { + var i = 0 + while (i < numColumns) { + val aggregate = functions(i) + val preparedInput = aggregate.prepareInputParameters(input) + aggregate.update(preparedInput) + i += 1 + } + } + + /** Evaluate the window functions. */ + protected final def evaluate(): Unit = { + var i = 0 + while (i < numColumns) { + functions(i).evaluate() + i += 1 + } + } + + /** Fill a target row with the current window function results. */ + protected final def fill(target: GenericMutableRow, rowIndex: Int): Unit = { + var i = 0 + while (i < numColumns) { + target.update(ordinal + i, functions(i).get(rowIndex)) + i += 1 + } + } +} + +/** + * The sliding window frame calculates frames with the following SQL form: + * ... BETWEEN 1 PRECEDING AND 1 FOLLOWING + * + * @param ordinal of the first column written by this frame. + * @param functions to calculate the row values with. + * @param lbound comparator used to identify the lower bound of an output row. + * @param ubound comparator used to identify the upper bound of an output row. 
+ */ +private[execution] final class SlidingWindowFunctionFrame( + ordinal: Int, + functions: Array[WindowFunction], + lbound: BoundOrdering, + ubound: BoundOrdering) extends WindowFunctionFrame(ordinal, functions) { + + /** Rows of the partition currently being processed. */ + private[this] var input: CompactBuffer[InternalRow] = null + + /** Index of the first input row with a value greater than the upper bound of the current + * output row. */ + private[this] var inputHighIndex = 0 + + /** Index of the first input row with a value equal to or greater than the lower bound of the + * current output row. */ + private[this] var inputLowIndex = 0 + + /** Buffer used for storing prepared input for the window functions. */ + private[this] val buffer = new util.ArrayDeque[Array[AnyRef]] + + /** Index of the row we are currently writing. */ + private[this] var outputIndex = 0 + + /** Prepare the frame for calculating a new partition. Reset all variables. */ + override def prepare(rows: CompactBuffer[InternalRow]): Unit = { + input = rows + inputHighIndex = 0 + inputLowIndex = 0 + outputIndex = 0 + buffer.clear() + } + + /** Write the frame columns for the current row to the given target row. */ + override def write(target: GenericMutableRow): Unit = { + var bufferUpdated = outputIndex == 0 + + // Add all rows to the buffer for which the input row value is equal to or less than + // the output row upper bound. + while (inputHighIndex < input.size && + ubound.compare(input, inputHighIndex, outputIndex) <= 0) { + buffer.offer(prepare(input(inputHighIndex))) + inputHighIndex += 1 + bufferUpdated = true + } + + // Drop all rows from the buffer for which the input row value is smaller than + // the output row lower bound. + while (inputLowIndex < inputHighIndex && + lbound.compare(input, inputLowIndex, outputIndex) < 0) { + buffer.pop() + inputLowIndex += 1 + bufferUpdated = true + } + + // Only recalculate and update when the buffer changes. + if (bufferUpdated) { + evaluatePrepared(buffer.iterator()) + fill(target, outputIndex) + } + + // Move to the next row. + outputIndex += 1 + } + + /** Copy the frame. */ + override def copy: SlidingWindowFunctionFrame = + new SlidingWindowFunctionFrame(ordinal, copyFunctions, lbound, ubound) +} + +/** + * The unbounded window frame calculates frames with the following SQL forms: + * ... (No Frame Definition) + * ... BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING + * + * Its results are the same for each and every row in the partition. This class can be seen as a + * special case of a sliding window, but is optimized for the unbound case. + * + * @param ordinal of the first column written by this frame. + * @param functions to calculate the row values with. + */ +private[execution] final class UnboundedWindowFunctionFrame( + ordinal: Int, + functions: Array[WindowFunction]) extends WindowFunctionFrame(ordinal, functions) { + + /** Index of the row we are currently writing. */ + private[this] var outputIndex = 0 + + /** Prepare the frame for calculating a new partition. Process all rows eagerly. */ + override def prepare(rows: CompactBuffer[InternalRow]): Unit = { + reset() + outputIndex = 0 + val iterator = rows.iterator + while (iterator.hasNext) { + update(iterator.next()) } + evaluate() + } + + /** Write the frame columns for the current row to the given target row. */ + override def write(target: GenericMutableRow): Unit = { + fill(target, outputIndex) + outputIndex += 1 + } + + /** Copy the frame. 
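The sliding frame above is a two-pointer algorithm: rows enter the frame through inputHighIndex and leave it through inputLowIndex. The standalone sketch below applies the same pointer movement to a plain sum over ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING; note that the real frame re-evaluates the buffered, prepared rows whenever the frame changes, whereas this sketch can update the sum incrementally.

object SlidingFrameSketch {
  // Sum over ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING for every row of `input`.
  def slidingSums(input: IndexedSeq[Int]): IndexedSeq[Int] = {
    val result = Array.ofDim[Int](input.size)
    var low = 0   // first index still inside the frame
    var high = 0  // first index not yet added to the frame
    var sum = 0
    for (output <- input.indices) {
      // Add rows whose index is at or below the upper bound of the current output row.
      while (high < input.size && high <= output + 1) { sum += input(high); high += 1 }
      // Drop rows whose index is below the lower bound of the current output row.
      while (low < high && low < output - 1) { sum -= input(low); low += 1 }
      result(output) = sum
    }
    result.toIndexedSeq
  }

  def main(args: Array[String]): Unit = {
    println(slidingSums(IndexedSeq(1, 2, 3, 4)).mkString(", ")) // 3, 6, 9, 7
  }
}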
*/ + override def copy: UnboundedWindowFunctionFrame = + new UnboundedWindowFunctionFrame(ordinal, copyFunctions) +} + +/** + * The UnboundedPreceding window frame calculates frames with the following SQL form: + * ... BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW + * + * There is only an upper bound. Very common use cases are for instance running sums or counts + * (row_number). Technically this is a special case of a sliding window. However, a sliding window + * has to maintain a buffer, and it must do a full evaluation every time the buffer changes. This + * is not the case when there is no lower bound: given the additive nature of most aggregates, + * streaming updates and partial evaluation suffice and no buffering is needed. + * + * @param ordinal of the first column written by this frame. + * @param functions to calculate the row values with. + * @param ubound comparator used to identify the upper bound of an output row. + */ +private[execution] final class UnboundedPrecedingWindowFunctionFrame( + ordinal: Int, + functions: Array[WindowFunction], + ubound: BoundOrdering) extends WindowFunctionFrame(ordinal, functions) { + + /** Rows of the partition currently being processed. */ + private[this] var input: CompactBuffer[InternalRow] = null + + /** Index of the first input row with a value greater than the upper bound of the current + * output row. */ + private[this] var inputIndex = 0 + + /** Index of the row we are currently writing. */ + private[this] var outputIndex = 0 + + /** Prepare the frame for calculating a new partition. */ + override def prepare(rows: CompactBuffer[InternalRow]): Unit = { + reset() + input = rows + inputIndex = 0 + outputIndex = 0 + } + + /** Write the frame columns for the current row to the given target row. */ + override def write(target: GenericMutableRow): Unit = { + var bufferUpdated = outputIndex == 0 + + // Add all rows to the aggregates for which the input row value is equal to or less than + // the output row upper bound. + while (inputIndex < input.size && ubound.compare(input, inputIndex, outputIndex) <= 0) { + update(input(inputIndex)) + inputIndex += 1 + bufferUpdated = true + } + + // Only recalculate and update when the buffer changes. + if (bufferUpdated) { + evaluate() + fill(target, outputIndex) + } + + // Move to the next row. + outputIndex += 1 } + + /** Copy the frame. */ + override def copy: UnboundedPrecedingWindowFunctionFrame = + new UnboundedPrecedingWindowFunctionFrame(ordinal, copyFunctions, ubound) +} + +/** + * The UnboundedFollowing window frame calculates frames with the following SQL form: + * ... BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING + * + * There is only a lower bound. This is a slightly modified version of the sliding window. The + * sliding window operator has to check if both the upper and the lower bound change when a new row + * gets processed, whereas the unbounded following only has to check the lower bound. + * + * This is a very expensive operator to use, O(n * (n - 1) / 2), because we need to maintain a + * buffer and must do full recalculation after each row. Reverse iteration would be possible if + * the commutativity of the used window functions can be guaranteed. + * + * @param ordinal of the first column written by this frame. + * @param functions to calculate the row values with. + * @param lbound comparator used to identify the lower bound of an output row.
+ */ +private[execution] final class UnboundedFollowingWindowFunctionFrame( + ordinal: Int, + functions: Array[WindowFunction], + lbound: BoundOrdering) extends WindowFunctionFrame(ordinal, functions) { + + /** Buffer used for storing prepared input for the window functions. */ + private[this] var buffer: Array[Array[AnyRef]] = _ + + /** Rows of the partition currently being processed. */ + private[this] var input: CompactBuffer[InternalRow] = null + + /** Index of the first input row with a value equal to or greater than the lower bound of the + * current output row. */ + private[this] var inputIndex = 0 + + /** Index of the row we are currently writing. */ + private[this] var outputIndex = 0 + + /** Prepare the frame for calculating a new partition. */ + override def prepare(rows: CompactBuffer[InternalRow]): Unit = { + input = rows + inputIndex = 0 + outputIndex = 0 + val size = input.size + buffer = Array.ofDim(size) + var i = 0 + while (i < size) { + buffer(i) = prepare(input(i)) + i += 1 + } + evaluatePrepared(buffer, 0, buffer.length) + } + + /** Write the frame columns for the current row to the given target row. */ + override def write(target: GenericMutableRow): Unit = { + var bufferUpdated = outputIndex == 0 + + // Drop all rows from the buffer for which the input row value is smaller than + // the output row lower bound. + while (inputIndex < input.size && lbound.compare(input, inputIndex, outputIndex) < 0) { + inputIndex += 1 + bufferUpdated = true + } + + // Only recalculate and update when the buffer changes. + if (bufferUpdated) { + evaluatePrepared(buffer, inputIndex, buffer.length) + fill(target, outputIndex) + } + + // Move to the next row. + outputIndex += 1 + } + + /** Copy the frame. */ + override def copy: UnboundedFollowingWindowFunctionFrame = + new UnboundedFollowingWindowFunctionFrame(ordinal, copyFunctions, lbound) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala new file mode 100644 index 0000000000000..0c9082897f390 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.errors._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, UnspecifiedDistribution} +import org.apache.spark.sql.execution.{SparkPlan, UnaryNode} + +case class Aggregate2Sort( + requiredChildDistributionExpressions: Option[Seq[Expression]], + groupingExpressions: Seq[NamedExpression], + aggregateExpressions: Seq[AggregateExpression2], + aggregateAttributes: Seq[Attribute], + resultExpressions: Seq[NamedExpression], + child: SparkPlan) + extends UnaryNode { + + override def canProcessUnsafeRows: Boolean = true + + override def references: AttributeSet = { + val referencesInResults = + AttributeSet(resultExpressions.flatMap(_.references)) -- AttributeSet(aggregateAttributes) + + AttributeSet( + groupingExpressions.flatMap(_.references) ++ + aggregateExpressions.flatMap(_.references) ++ + referencesInResults) + } + + override def requiredChildDistribution: List[Distribution] = { + requiredChildDistributionExpressions match { + case Some(exprs) if exprs.length == 0 => AllTuples :: Nil + case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil + case None => UnspecifiedDistribution :: Nil + } + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = { + // TODO: We should not sort the input rows if they are just in reversed order. + groupingExpressions.map(SortOrder(_, Ascending)) :: Nil + } + + override def outputOrdering: Seq[SortOrder] = { + // It is possible that the child.outputOrdering starts with the required + // ordering expressions (e.g. we require [a] as the sort expression and the + // child's outputOrdering is [a, b]). We can only guarantee the output rows + // are sorted by values of groupingExpressions. 
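Aggregate2Sort instances run in partial and final modes (the dispatch appears in doExecute below): partial aggregation folds each partition's sorted rows into per-key buffers, and a final pass merges those buffers after the shuffle. A standalone sketch of that two-phase flow for a per-key sum, with plain Scala collections standing in for aggregation buffers (the PartialMerge mode is omitted for brevity):

object TwoPhaseAggregationSketch {
  type Key = String

  // Partial mode: fold the raw values of one partition into per-key partial buffers.
  def partial(rows: Seq[(Key, Int)]): Map[Key, Int] =
    rows.groupBy(_._1).map { case (k, vs) => k -> vs.map(_._2).sum }

  // Final mode: merge the partial buffers produced by all partitions.
  def finalMerge(partials: Seq[Map[Key, Int]]): Map[Key, Int] =
    partials.flatten.groupBy(_._1).map { case (k, vs) => k -> vs.map(_._2).sum }

  def main(args: Array[String]): Unit = {
    val partition1 = Seq("a" -> 1, "b" -> 2)
    val partition2 = Seq("a" -> 3)
    println(finalMerge(Seq(partial(partition1), partial(partition2)))) // a -> 4, b -> 2
  }
}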
+ groupingExpressions.map(SortOrder(_, Ascending)) + } + + override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) + + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { + child.execute().mapPartitions { iter => + if (aggregateExpressions.length == 0) { + new GroupingIterator( + groupingExpressions, + resultExpressions, + newMutableProjection, + child.output, + iter) + } else { + val aggregationIterator: SortAggregationIterator = { + aggregateExpressions.map(_.mode).distinct.toList match { + case Partial :: Nil => + new PartialSortAggregationIterator( + groupingExpressions, + aggregateExpressions, + newMutableProjection, + child.output, + iter) + case PartialMerge :: Nil => + new PartialMergeSortAggregationIterator( + groupingExpressions, + aggregateExpressions, + newMutableProjection, + child.output, + iter) + case Final :: Nil => + new FinalSortAggregationIterator( + groupingExpressions, + aggregateExpressions, + aggregateAttributes, + resultExpressions, + newMutableProjection, + child.output, + iter) + case other => + sys.error( + s"Could not evaluate ${aggregateExpressions} because we do not support evaluate " + + s"modes $other in this operator.") + } + } + + aggregationIterator + } + } + } +} + +case class FinalAndCompleteAggregate2Sort( + previousGroupingExpressions: Seq[NamedExpression], + groupingExpressions: Seq[NamedExpression], + finalAggregateExpressions: Seq[AggregateExpression2], + finalAggregateAttributes: Seq[Attribute], + completeAggregateExpressions: Seq[AggregateExpression2], + completeAggregateAttributes: Seq[Attribute], + resultExpressions: Seq[NamedExpression], + child: SparkPlan) + extends UnaryNode { + override def references: AttributeSet = { + val referencesInResults = + AttributeSet(resultExpressions.flatMap(_.references)) -- + AttributeSet(finalAggregateExpressions) -- + AttributeSet(completeAggregateExpressions) + + AttributeSet( + groupingExpressions.flatMap(_.references) ++ + finalAggregateExpressions.flatMap(_.references) ++ + completeAggregateExpressions.flatMap(_.references) ++ + referencesInResults) + } + + override def requiredChildDistribution: List[Distribution] = { + if (groupingExpressions.isEmpty) { + AllTuples :: Nil + } else { + ClusteredDistribution(groupingExpressions) :: Nil + } + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + groupingExpressions.map(SortOrder(_, Ascending)) :: Nil + + override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) + + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { + child.execute().mapPartitions { iter => + + new FinalAndCompleteSortAggregationIterator( + previousGroupingExpressions.length, + groupingExpressions, + finalAggregateExpressions, + finalAggregateAttributes, + completeAggregateExpressions, + completeAggregateAttributes, + resultExpressions, + newMutableProjection, + child.output, + iter) + } + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala new file mode 100644 index 0000000000000..ce1cbdc9cb090 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala @@ -0,0 +1,749 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.types.NullType + +import scala.collection.mutable.ArrayBuffer + +/** + * An iterator used to evaluate aggregate functions. It assumes that input rows + * are already grouped by values of `groupingExpressions`. + */ +private[sql] abstract class SortAggregationIterator( + groupingExpressions: Seq[NamedExpression], + aggregateExpressions: Seq[AggregateExpression2], + newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), + inputAttributes: Seq[Attribute], + inputIter: Iterator[InternalRow]) + extends Iterator[InternalRow] { + + /////////////////////////////////////////////////////////////////////////// + // Static fields for this iterator + /////////////////////////////////////////////////////////////////////////// + + protected val aggregateFunctions: Array[AggregateFunction2] = { + var bufferOffset = initialBufferOffset + val functions = new Array[AggregateFunction2](aggregateExpressions.length) + var i = 0 + while (i < aggregateExpressions.length) { + val func = aggregateExpressions(i).aggregateFunction + val funcWithBoundReferences = aggregateExpressions(i).mode match { + case Partial | Complete if !func.isInstanceOf[AlgebraicAggregate] => + // We need to create BoundReferences if the function is not an + // AlgebraicAggregate (it does not support code-gen) and the mode of + // this function is Partial or Complete because we will call eval of this + // function's children in the update method of this aggregate function. + // Those eval calls require BoundReferences to work. + BindReferences.bindReference(func, inputAttributes) + case _ => func + } + // Set bufferOffset for this function. It is important that setting bufferOffset + // happens after all potential bindReference operations because bindReference + // will create a new instance of the function. + funcWithBoundReferences.bufferOffset = bufferOffset + bufferOffset += funcWithBoundReferences.bufferSchema.length + functions(i) = funcWithBoundReferences + i += 1 + } + functions + } + + // All non-algebraic aggregate functions. + protected val nonAlgebraicAggregateFunctions: Array[AggregateFunction2] = { + aggregateFunctions.collect { + case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func + }.toArray + } + + // Positions of those non-algebraic aggregate functions in aggregateFunctions. + // For example, we have func1, func2, func3, func4 in aggregateFunctions, and + // func2 and func3 are non-algebraic aggregate functions. + // nonAlgebraicAggregateFunctionPositions will be [1, 2]. 
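One detail worth keeping in mind about the shared buffer set up here: every AggregateFunction2 owns a contiguous slice of a single mutable row, and its starting slot (bufferOffset) is just the running total of the buffer widths of the functions placed before it, on top of initialBufferOffset. A minimal sketch of that assignment, with hypothetical per-function widths standing in for bufferSchema.length:

// Sketch: assign starting offsets into one shared buffer.
// bufferWidths are hypothetical sizes (e.g. a SUM needs 1 slot, an AVG needs 2:
// sum and count). initialOffset plays the role of initialBufferOffset.
def bufferOffsets(initialOffset: Int, bufferWidths: Seq[Int]): Seq[Int] =
  bufferWidths.scanLeft(initialOffset)(_ + _).dropRight(1)

// bufferOffsets(2, Seq(1, 2, 1)) == Seq(2, 3, 5)
// i.e. function 0 starts at slot 2, function 1 at slot 3, function 2 at slot 5,
// and the shared buffer needs 2 + 1 + 2 + 1 = 6 slots in total.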
+ protected val nonAlgebraicAggregateFunctionPositions: Array[Int] = { + val positions = new ArrayBuffer[Int]() + var i = 0 + while (i < aggregateFunctions.length) { + aggregateFunctions(i) match { + case agg: AlgebraicAggregate => + case _ => positions += i + } + i += 1 + } + positions.toArray + } + + // This is used to project expressions for the grouping expressions. + protected val groupGenerator = + newMutableProjection(groupingExpressions, inputAttributes)() + + // The underlying buffer shared by all aggregate functions. + protected val buffer: MutableRow = { + // The number of elements of the underlying buffer of this operator. + // All aggregate functions are sharing this underlying buffer and they find their + // buffer values through bufferOffset. + var size = initialBufferOffset + var i = 0 + while (i < aggregateFunctions.length) { + size += aggregateFunctions(i).bufferSchema.length + i += 1 + } + new GenericMutableRow(size) + } + + protected val joinedRow = new JoinedRow4 + + protected val placeholderExpressions = Seq.fill(initialBufferOffset)(NoOp) + + // This projection is used to initialize buffer values for all AlgebraicAggregates. + protected val algebraicInitialProjection = { + val initExpressions = placeholderExpressions ++ aggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.initialValues + case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) + } + newMutableProjection(initExpressions, Nil)().target(buffer) + } + + /////////////////////////////////////////////////////////////////////////// + // Mutable states + /////////////////////////////////////////////////////////////////////////// + + // The partition key of the current partition. + protected var currentGroupingKey: InternalRow = _ + // The partition key of next partition. + protected var nextGroupingKey: InternalRow = _ + // The first row of next partition. + protected var firstRowInNextGroup: InternalRow = _ + // Indicates if we has new group of rows to process. + protected var hasNewGroup: Boolean = true + + /////////////////////////////////////////////////////////////////////////// + // Private methods + /////////////////////////////////////////////////////////////////////////// + + /** Initializes buffer values for all aggregate functions. */ + protected def initializeBuffer(): Unit = { + algebraicInitialProjection(EmptyRow) + var i = 0 + while (i < nonAlgebraicAggregateFunctions.length) { + nonAlgebraicAggregateFunctions(i).initialize(buffer) + i += 1 + } + } + + protected def initialize(): Unit = { + if (inputIter.hasNext) { + initializeBuffer() + val currentRow = inputIter.next().copy() + // partitionGenerator is a mutable projection. Since we need to track nextGroupingKey, + // we are making a copy at here. + nextGroupingKey = groupGenerator(currentRow).copy() + firstRowInNextGroup = currentRow + } else { + // This iter is an empty one. + hasNewGroup = false + } + } + + /** Processes rows in the current group. It will stop when it find a new group. */ + private def processCurrentGroup(): Unit = { + currentGroupingKey = nextGroupingKey + // Now, we will start to find all rows belonging to this group. + // We create a variable to track if we see the next group. + var findNextPartition = false + // firstRowInNextGroup is the first row of this group. We first process it. + processRow(firstRowInNextGroup) + // The search will stop when we see the next group or there is no + // input row left in the iter. 
+    while (inputIter.hasNext && !findNextPartition) {
+      val currentRow = inputIter.next()
+      // Get the grouping key based on the grouping expressions.
+      // For the comparison below, we do not need to make a copy of groupingKey.
+      val groupingKey = groupGenerator(currentRow)
+      // Check if the current row belongs to the current group.
+      if (currentGroupingKey == groupingKey) {
+        processRow(currentRow)
+      } else {
+        // We have found a new group.
+        findNextPartition = true
+        nextGroupingKey = groupingKey.copy()
+        firstRowInNextGroup = currentRow.copy()
+      }
+    }
+    // We have not seen a new group. It means that there is no new row in the input
+    // iter. The current group is the last group of the iter.
+    if (!findNextPartition) {
+      hasNewGroup = false
+    }
+  }
+
+  ///////////////////////////////////////////////////////////////////////////
+  // Public methods
+  ///////////////////////////////////////////////////////////////////////////
+
+  override final def hasNext: Boolean = hasNewGroup
+
+  override final def next(): InternalRow = {
+    if (hasNext) {
+      // Process the current group.
+      processCurrentGroup()
+      // Generate the output row for the current group.
+      val outputRow = generateOutput()
+      // Initialize buffer values for the next group.
+      initializeBuffer()
+
+      outputRow
+    } else {
+      // no more result
+      throw new NoSuchElementException
+    }
+  }
+
+  ///////////////////////////////////////////////////////////////////////////
+  // Methods that need to be implemented
+  ///////////////////////////////////////////////////////////////////////////
+
+  protected def initialBufferOffset: Int
+
+  protected def processRow(row: InternalRow): Unit
+
+  protected def generateOutput(): InternalRow
+
+  ///////////////////////////////////////////////////////////////////////////
+  // Initialize this iterator
+  ///////////////////////////////////////////////////////////////////////////
+
+  initialize()
+}
+
+/**
+ * An iterator only used to group input rows according to values of `groupingExpressions`.
+ * It assumes that input rows are already grouped by values of `groupingExpressions`.
+ */
+class GroupingIterator(
+    groupingExpressions: Seq[NamedExpression],
+    resultExpressions: Seq[NamedExpression],
+    newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+    inputAttributes: Seq[Attribute],
+    inputIter: Iterator[InternalRow])
+  extends SortAggregationIterator(
+    groupingExpressions,
+    Nil,
+    newMutableProjection,
+    inputAttributes,
+    inputIter) {
+
+  private val resultProjection =
+    newMutableProjection(resultExpressions, groupingExpressions.map(_.toAttribute))()
+
+  override protected def initialBufferOffset: Int = 0
+
+  override protected def processRow(row: InternalRow): Unit = {
+    // Since we only do grouping, there is nothing to do here.
+  }
+
+  override protected def generateOutput(): InternalRow = {
+    resultProjection(currentGroupingKey)
+  }
+}
+
+/**
+ * An iterator used to do partial aggregations (for those aggregate functions with mode Partial).
+ * It assumes that input rows are already grouped by values of `groupingExpressions`.
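The hasNext/next contract of SortAggregationIterator depends on processCurrentGroup stashing the first row of the following group (nextGroupingKey and firstRowInNextGroup) before it can know the current group has ended. The standalone iterator below mirrors only that hand-off, counting a run of equal keys in a key-sorted stream; it is a sketch of the control flow, not of the Catalyst classes in this patch:

// Sketch: one output row per run of equal keys in a key-sorted iterator.
// Mirrors the nextGroupingKey / firstRowInNextGroup hand-off above.
class GroupCountIterator[K](input: Iterator[K]) extends Iterator[(K, Long)] {
  private var nextKey: Option[K] = if (input.hasNext) Some(input.next()) else None

  override def hasNext: Boolean = nextKey.isDefined

  override def next(): (K, Long) = {
    val currentKey = nextKey.getOrElse(throw new NoSuchElementException)
    var count = 1L                 // counts the stashed first row of this group
    nextKey = None
    var done = false
    while (!done && input.hasNext) {
      val k = input.next()
      if (k == currentKey) {
        count += 1
      } else {
        nextKey = Some(k)          // first row of the next group; keep it for later
        done = true
      }
    }
    (currentKey, count)
  }
}

// new GroupCountIterator(Iterator("a", "a", "b")).toList == List(("a", 2L), ("b", 1L))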
+ * The format of its output rows is: + * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN| + */ +class PartialSortAggregationIterator( + groupingExpressions: Seq[NamedExpression], + aggregateExpressions: Seq[AggregateExpression2], + newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), + inputAttributes: Seq[Attribute], + inputIter: Iterator[InternalRow]) + extends SortAggregationIterator( + groupingExpressions, + aggregateExpressions, + newMutableProjection, + inputAttributes, + inputIter) { + + // This projection is used to update buffer values for all AlgebraicAggregates. + private val algebraicUpdateProjection = { + val bufferSchema = aggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.bufferAttributes + case agg: AggregateFunction2 => agg.bufferAttributes + } + val updateExpressions = aggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.updateExpressions + case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) + } + newMutableProjection(updateExpressions, bufferSchema ++ inputAttributes)().target(buffer) + } + + override protected def initialBufferOffset: Int = 0 + + override protected def processRow(row: InternalRow): Unit = { + // Process all algebraic aggregate functions. + algebraicUpdateProjection(joinedRow(buffer, row)) + // Process all non-algebraic aggregate functions. + var i = 0 + while (i < nonAlgebraicAggregateFunctions.length) { + nonAlgebraicAggregateFunctions(i).update(buffer, row) + i += 1 + } + } + + override protected def generateOutput(): InternalRow = { + // We just output the grouping expressions and the underlying buffer. + joinedRow(currentGroupingKey, buffer).copy() + } +} + +/** + * An iterator used to do partial merge aggregations (for those aggregate functions with mode + * PartialMerge). It assumes that input rows are already grouped by values of + * `groupingExpressions`. + * The format of its input rows is: + * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN| + * + * The format of its internal buffer is: + * |placeholder1|...|placeholderN|aggregationBuffer1|...|aggregationBufferN| + * Every placeholder is for a grouping expression. + * The actual buffers are stored after placeholderN. + * The reason that we have placeholders at here is to make our underlying buffer have the same + * length with a input row. + * + * The format of its output rows is: + * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN| + */ +class PartialMergeSortAggregationIterator( + groupingExpressions: Seq[NamedExpression], + aggregateExpressions: Seq[AggregateExpression2], + newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), + inputAttributes: Seq[Attribute], + inputIter: Iterator[InternalRow]) + extends SortAggregationIterator( + groupingExpressions, + aggregateExpressions, + newMutableProjection, + inputAttributes, + inputIter) { + + private val placeholderAttribtues = + Seq.fill(initialBufferOffset)(AttributeReference("placeholder", NullType)()) + + // This projection is used to merge buffer values for all AlgebraicAggregates. 
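The placeholders described above exist so that the merge projection, which is evaluated against the concatenation of the operator's buffer and an input row, can address the aggregation slots with the same ordinals on both sides. A plain-Scala caricature of that alignment, with a hypothetical two-column grouping key and a sum/count buffer:

// Sketch: pad the operator's buffer so aggregation slots share ordinals with input rows.
// Input rows to the partial-merge operator look like: [group cols ...][agg buffer ...];
// the internal buffer mirrors that shape with null placeholders for the group cols.
val numGroupingCols = 2                          // hypothetical
val inputRow = Array[Any]("a", "b", 10L, 3L)     // 2 group cols + sum + count
val buffer   = Array[Any](null, null, 4L, 1L)    // placeholders + sum + count

// Because both sides share the layout, slot i means the same thing in each:
val sumSlot = numGroupingCols                    // 2
val cntSlot = numGroupingCols + 1                // 3
buffer(sumSlot) = buffer(sumSlot).asInstanceOf[Long] + inputRow(sumSlot).asInstanceOf[Long]
buffer(cntSlot) = buffer(cntSlot).asInstanceOf[Long] + inputRow(cntSlot).asInstanceOf[Long]
// buffer is now Array(null, null, 14L, 4L)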
+ private val algebraicMergeProjection = { + val bufferSchemata = + placeholderAttribtues ++ aggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.bufferAttributes + case agg: AggregateFunction2 => agg.bufferAttributes + } ++ placeholderAttribtues ++ aggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.cloneBufferAttributes + case agg: AggregateFunction2 => agg.cloneBufferAttributes + } + val mergeExpressions = placeholderExpressions ++ aggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.mergeExpressions + case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) + } + + newMutableProjection(mergeExpressions, bufferSchemata)() + } + + // This projection is used to extract aggregation buffers from the underlying buffer. + // We need it because the underlying buffer has placeholders at its beginning. + private val extractsBufferValues = { + val expressions = aggregateFunctions.flatMap { + case agg => agg.bufferAttributes + } + + newMutableProjection(expressions, inputAttributes)() + } + + override protected def initialBufferOffset: Int = groupingExpressions.length + + override protected def processRow(row: InternalRow): Unit = { + // Process all algebraic aggregate functions. + algebraicMergeProjection.target(buffer)(joinedRow(buffer, row)) + // Process all non-algebraic aggregate functions. + var i = 0 + while (i < nonAlgebraicAggregateFunctions.length) { + nonAlgebraicAggregateFunctions(i).merge(buffer, row) + i += 1 + } + } + + override protected def generateOutput(): InternalRow = { + // We output grouping expressions and aggregation buffers. + joinedRow(currentGroupingKey, extractsBufferValues(buffer)) + } +} + +/** + * An iterator used to do final aggregations (for those aggregate functions with mode + * Final). It assumes that input rows are already grouped by values of + * `groupingExpressions`. + * The format of its input rows is: + * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN| + * + * The format of its internal buffer is: + * |placeholder1|...|placeholder N|aggregationBuffer1|...|aggregationBufferN| + * Every placeholder is for a grouping expression. + * The actual buffers are stored after placeholderN. + * The reason that we have placeholders at here is to make our underlying buffer have the same + * length with a input row. + * + * The format of its output rows is represented by the schema of `resultExpressions`. + */ +class FinalSortAggregationIterator( + groupingExpressions: Seq[NamedExpression], + aggregateExpressions: Seq[AggregateExpression2], + aggregateAttributes: Seq[Attribute], + resultExpressions: Seq[NamedExpression], + newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), + inputAttributes: Seq[Attribute], + inputIter: Iterator[InternalRow]) + extends SortAggregationIterator( + groupingExpressions, + aggregateExpressions, + newMutableProjection, + inputAttributes, + inputIter) { + + // The result of aggregate functions. + private val aggregateResult: MutableRow = new GenericMutableRow(aggregateAttributes.length) + + // The projection used to generate the output rows of this operator. + // This is only used when we are generating final results of aggregate functions. 
+ private val resultProjection = + newMutableProjection( + resultExpressions, groupingExpressions.map(_.toAttribute) ++ aggregateAttributes)() + + private val offsetAttributes = + Seq.fill(initialBufferOffset)(AttributeReference("placeholder", NullType)()) + + // This projection is used to merge buffer values for all AlgebraicAggregates. + private val algebraicMergeProjection = { + val bufferSchemata = + offsetAttributes ++ aggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.bufferAttributes + case agg: AggregateFunction2 => agg.bufferAttributes + } ++ offsetAttributes ++ aggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.cloneBufferAttributes + case agg: AggregateFunction2 => agg.cloneBufferAttributes + } + val mergeExpressions = placeholderExpressions ++ aggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.mergeExpressions + case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) + } + + newMutableProjection(mergeExpressions, bufferSchemata)() + } + + // This projection is used to evaluate all AlgebraicAggregates. + private val algebraicEvalProjection = { + val bufferSchemata = + offsetAttributes ++ aggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.bufferAttributes + case agg: AggregateFunction2 => agg.bufferAttributes + } ++ offsetAttributes ++ aggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.cloneBufferAttributes + case agg: AggregateFunction2 => agg.cloneBufferAttributes + } + val evalExpressions = aggregateFunctions.map { + case ae: AlgebraicAggregate => ae.evaluateExpression + case agg: AggregateFunction2 => NoOp + } + + newMutableProjection(evalExpressions, bufferSchemata)() + } + + override protected def initialBufferOffset: Int = groupingExpressions.length + + override def initialize(): Unit = { + if (inputIter.hasNext) { + initializeBuffer() + val currentRow = inputIter.next().copy() + // partitionGenerator is a mutable projection. Since we need to track nextGroupingKey, + // we are making a copy at here. + nextGroupingKey = groupGenerator(currentRow).copy() + firstRowInNextGroup = currentRow + } else { + if (groupingExpressions.isEmpty) { + // If there is no grouping expression, we need to generate a single row as the output. + initializeBuffer() + // Right now, the buffer only contains initial buffer values. Because + // merging two buffers with initial values will generate a row that + // still store initial values. We set the currentRow as the copy of the current buffer. + val currentRow = buffer.copy() + nextGroupingKey = groupGenerator(currentRow).copy() + firstRowInNextGroup = currentRow + } else { + // This iter is an empty one. + hasNewGroup = false + } + } + } + + override protected def processRow(row: InternalRow): Unit = { + // Process all algebraic aggregate functions. + algebraicMergeProjection.target(buffer)(joinedRow(buffer, row)) + // Process all non-algebraic aggregate functions. + var i = 0 + while (i < nonAlgebraicAggregateFunctions.length) { + nonAlgebraicAggregateFunctions(i).merge(buffer, row) + i += 1 + } + } + + override protected def generateOutput(): InternalRow = { + // Generate results for all algebraic aggregate functions. + algebraicEvalProjection.target(aggregateResult)(buffer) + // Generate results for all non-algebraic aggregate functions. 
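The empty-input branch in the initialize() override of FinalSortAggregationIterator is what makes a global aggregate over an empty relation return exactly one row built from the initial buffer values, for example COUNT(*) = 0, while a grouped aggregate over empty input returns no rows. A tiny standalone sketch of that rule:

// Sketch: a global aggregate over empty input still produces one row (from the
// initial buffer), while a grouped aggregate over empty input produces none.
def countRows(rows: Seq[Int], grouped: Boolean): Seq[Long] =
  if (rows.nonEmpty) Seq(rows.size.toLong)
  else if (grouped) Seq.empty
  else Seq(0L)

// countRows(Nil, grouped = false) == Seq(0L)    // SELECT COUNT(*) FROM empty_table
// countRows(Nil, grouped = true)  == Seq.empty  // ... GROUP BY key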
+ var i = 0 + while (i < nonAlgebraicAggregateFunctions.length) { + aggregateResult.update( + nonAlgebraicAggregateFunctionPositions(i), + nonAlgebraicAggregateFunctions(i).eval(buffer)) + i += 1 + } + resultProjection(joinedRow(currentGroupingKey, aggregateResult)) + } +} + +/** + * An iterator used to do both final aggregations (for those aggregate functions with mode + * Final) and complete aggregations (for those aggregate functions with mode Complete). + * It assumes that input rows are already grouped by values of `groupingExpressions`. + * The format of its input rows is: + * |groupingExpr1|...|groupingExprN|col1|...|colM|aggregationBuffer1|...|aggregationBufferN| + * col1 to colM are columns used by aggregate functions with Complete mode. + * aggregationBuffer1 to aggregationBufferN are buffers used by aggregate functions with + * Final mode. + * + * The format of its internal buffer is: + * |placeholder1|...|placeholder(N+M)|aggregationBuffer1|...|aggregationBuffer(N+M)| + * The first N placeholders represent slots of grouping expressions. + * Then, next M placeholders represent slots of col1 to colM. + * For aggregation buffers, first N aggregation buffers are used by N aggregate functions with + * mode Final. Then, the last M aggregation buffers are used by M aggregate functions with mode + * Complete. The reason that we have placeholders at here is to make our underlying buffer + * have the same length with a input row. + * + * The format of its output rows is represented by the schema of `resultExpressions`. + */ +class FinalAndCompleteSortAggregationIterator( + override protected val initialBufferOffset: Int, + groupingExpressions: Seq[NamedExpression], + finalAggregateExpressions: Seq[AggregateExpression2], + finalAggregateAttributes: Seq[Attribute], + completeAggregateExpressions: Seq[AggregateExpression2], + completeAggregateAttributes: Seq[Attribute], + resultExpressions: Seq[NamedExpression], + newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), + inputAttributes: Seq[Attribute], + inputIter: Iterator[InternalRow]) + extends SortAggregationIterator( + groupingExpressions, + // TODO: document the ordering + finalAggregateExpressions ++ completeAggregateExpressions, + newMutableProjection, + inputAttributes, + inputIter) { + + // The result of aggregate functions. + private val aggregateResult: MutableRow = + new GenericMutableRow(completeAggregateAttributes.length + finalAggregateAttributes.length) + + // The projection used to generate the output rows of this operator. + // This is only used when we are generating final results of aggregate functions. + private val resultProjection = { + val inputSchema = + groupingExpressions.map(_.toAttribute) ++ + finalAggregateAttributes ++ + completeAggregateAttributes + newMutableProjection(resultExpressions, inputSchema)() + } + + private val offsetAttributes = + Seq.fill(initialBufferOffset)(AttributeReference("placeholder", NullType)()) + + // All aggregate functions with mode Final. + private val finalAggregateFunctions: Array[AggregateFunction2] = { + val functions = new Array[AggregateFunction2](finalAggregateExpressions.length) + var i = 0 + while (i < finalAggregateExpressions.length) { + functions(i) = aggregateFunctions(i) + i += 1 + } + functions + } + + // All non-algebraic aggregate functions with mode Final. 
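FinalAndCompleteSortAggregationIterator concatenates the Final-mode functions first and the Complete-mode functions after them, and then recovers each half purely by position, as the completeAggregateFunctions construction just below shows (functions(i) versus aggregateFunctions(finalAggregateFunctions.length + i)). A compact sketch of that index arithmetic, with strings standing in for hypothetical functions:

// Sketch: the combined function array is [final functions ++ complete functions],
// and each half is recovered purely by position.
val finalFuncs    = Seq("SUM(other)")     // merged from partial aggregation buffers
val completeFuncs = Seq("COUNT(value)")   // evaluated directly from input columns
val allFuncs      = finalFuncs ++ completeFuncs

val recoveredFinal    = allFuncs.take(finalFuncs.length)
val recoveredComplete = allFuncs.drop(finalFuncs.length)
// recoveredFinal == finalFuncs && recoveredComplete == completeFuncs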
+ private val finalNonAlgebraicAggregateFunctions: Array[AggregateFunction2] = { + finalAggregateFunctions.collect { + case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func + }.toArray + } + + // All aggregate functions with mode Complete. + private val completeAggregateFunctions: Array[AggregateFunction2] = { + val functions = new Array[AggregateFunction2](completeAggregateExpressions.length) + var i = 0 + while (i < completeAggregateExpressions.length) { + functions(i) = aggregateFunctions(finalAggregateFunctions.length + i) + i += 1 + } + functions + } + + // All non-algebraic aggregate functions with mode Complete. + private val completeNonAlgebraicAggregateFunctions: Array[AggregateFunction2] = { + completeAggregateFunctions.collect { + case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func + }.toArray + } + + // This projection is used to merge buffer values for all AlgebraicAggregates with mode + // Final. + private val finalAlgebraicMergeProjection = { + val numCompleteOffsetAttributes = + completeAggregateFunctions.map(_.bufferAttributes.length).sum + val completeOffsetAttributes = + Seq.fill(numCompleteOffsetAttributes)(AttributeReference("placeholder", NullType)()) + val completeOffsetExpressions = Seq.fill(numCompleteOffsetAttributes)(NoOp) + + val bufferSchemata = + offsetAttributes ++ finalAggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.bufferAttributes + case agg: AggregateFunction2 => agg.bufferAttributes + } ++ completeOffsetAttributes ++ offsetAttributes ++ finalAggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.cloneBufferAttributes + case agg: AggregateFunction2 => agg.cloneBufferAttributes + } ++ completeOffsetAttributes + val mergeExpressions = + placeholderExpressions ++ finalAggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.mergeExpressions + case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) + } ++ completeOffsetExpressions + + newMutableProjection(mergeExpressions, bufferSchemata)() + } + + // This projection is used to update buffer values for all AlgebraicAggregates with mode + // Complete. + private val completeAlgebraicUpdateProjection = { + val numFinalOffsetAttributes = finalAggregateFunctions.map(_.bufferAttributes.length).sum + val finalOffsetAttributes = + Seq.fill(numFinalOffsetAttributes)(AttributeReference("placeholder", NullType)()) + val finalOffsetExpressions = Seq.fill(numFinalOffsetAttributes)(NoOp) + + val bufferSchema = + offsetAttributes ++ finalOffsetAttributes ++ completeAggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.bufferAttributes + case agg: AggregateFunction2 => agg.bufferAttributes + } + val updateExpressions = + placeholderExpressions ++ finalOffsetExpressions ++ completeAggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.updateExpressions + case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) + } + newMutableProjection(updateExpressions, bufferSchema ++ inputAttributes)().target(buffer) + } + + // This projection is used to evaluate all AlgebraicAggregates. 
+ private val algebraicEvalProjection = { + val bufferSchemata = + offsetAttributes ++ aggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.bufferAttributes + case agg: AggregateFunction2 => agg.bufferAttributes + } ++ offsetAttributes ++ aggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.cloneBufferAttributes + case agg: AggregateFunction2 => agg.cloneBufferAttributes + } + val evalExpressions = aggregateFunctions.map { + case ae: AlgebraicAggregate => ae.evaluateExpression + case agg: AggregateFunction2 => NoOp + } + + newMutableProjection(evalExpressions, bufferSchemata)() + } + + override def initialize(): Unit = { + if (inputIter.hasNext) { + initializeBuffer() + val currentRow = inputIter.next().copy() + // partitionGenerator is a mutable projection. Since we need to track nextGroupingKey, + // we are making a copy at here. + nextGroupingKey = groupGenerator(currentRow).copy() + firstRowInNextGroup = currentRow + } else { + if (groupingExpressions.isEmpty) { + // If there is no grouping expression, we need to generate a single row as the output. + initializeBuffer() + // Right now, the buffer only contains initial buffer values. Because + // merging two buffers with initial values will generate a row that + // still store initial values. We set the currentRow as the copy of the current buffer. + val currentRow = buffer.copy() + nextGroupingKey = groupGenerator(currentRow).copy() + firstRowInNextGroup = currentRow + } else { + // This iter is an empty one. + hasNewGroup = false + } + } + } + + override protected def processRow(row: InternalRow): Unit = { + val input = joinedRow(buffer, row) + // For all aggregate functions with mode Complete, update buffers. + completeAlgebraicUpdateProjection(input) + var i = 0 + while (i < completeNonAlgebraicAggregateFunctions.length) { + completeNonAlgebraicAggregateFunctions(i).update(buffer, row) + i += 1 + } + + // For all aggregate functions with mode Final, merge buffers. + finalAlgebraicMergeProjection.target(buffer)(input) + i = 0 + while (i < finalNonAlgebraicAggregateFunctions.length) { + finalNonAlgebraicAggregateFunctions(i).merge(buffer, row) + i += 1 + } + } + + override protected def generateOutput(): InternalRow = { + // Generate results for all algebraic aggregate functions. + algebraicEvalProjection.target(aggregateResult)(buffer) + // Generate results for all non-algebraic aggregate functions. + var i = 0 + while (i < nonAlgebraicAggregateFunctions.length) { + aggregateResult.update( + nonAlgebraicAggregateFunctionPositions(i), + nonAlgebraicAggregateFunctions(i).eval(buffer)) + i += 1 + } + + resultProjection(joinedRow(currentGroupingKey, aggregateResult)) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala new file mode 100644 index 0000000000000..1cb27710e0480 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala @@ -0,0 +1,364 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.types.{StructType, MapType, ArrayType} + +/** + * Utility functions used by the query planner to convert our plan to new aggregation code path. + */ +object Utils { + // Right now, we do not support complex types in the grouping key schema. + private def supportsGroupingKeySchema(aggregate: Aggregate): Boolean = { + val hasComplexTypes = aggregate.groupingExpressions.map(_.dataType).exists { + case array: ArrayType => true + case map: MapType => true + case struct: StructType => true + case _ => false + } + + !hasComplexTypes + } + + private def tryConvert(plan: LogicalPlan): Option[Aggregate] = plan match { + case p: Aggregate if supportsGroupingKeySchema(p) => + val converted = p.transformExpressionsDown { + case expressions.Average(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Average(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.Count(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Count(child), + mode = aggregate.Complete, + isDistinct = false) + + // We do not support multiple COUNT DISTINCT columns for now. + case expressions.CountDistinct(children) if children.length == 1 => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Count(children.head), + mode = aggregate.Complete, + isDistinct = true) + + case expressions.First(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.First(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.Last(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Last(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.Max(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Max(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.Min(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Min(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.Sum(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Sum(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.SumDistinct(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Sum(child), + mode = aggregate.Complete, + isDistinct = true) + } + // Check if there is any expressions.AggregateExpression1 left. + // If so, we cannot convert this plan. + val hasAggregateExpression1 = converted.aggregateExpressions.exists { expr => + // For every expressions, check if it contains AggregateExpression1. 
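tryConvert above rewrites each supported old-style aggregate into an AggregateExpression2 and then gives up (returns None) if anything old survives the rewrite, which is what the find just below checks. The overall convert-then-verify shape, on a toy expression type rather than Catalyst's hierarchy:

// Sketch: rewrite old aggregate calls into new ones, then refuse the plan if any
// old-style call survived (toy ADT; not Catalyst's expression classes).
sealed trait Expr
case class OldAgg(name: String) extends Expr
case class NewAgg(name: String, distinct: Boolean) extends Expr

def tryConvertAll(exprs: Seq[Expr]): Option[Seq[Expr]] = {
  val converted = exprs.map {
    case OldAgg("sum")            => NewAgg("sum", distinct = false)
    case OldAgg("count_distinct") => NewAgg("count", distinct = true)
    case other                    => other   // unknown: left in place, detected below
  }
  val hasOldLeft = converted.exists { case _: OldAgg => true; case _ => false }
  if (hasOldLeft) None else Some(converted)
}

// tryConvertAll(Seq(OldAgg("sum")))    == Some(Seq(NewAgg("sum", false)))
// tryConvertAll(Seq(OldAgg("stddev"))) == None   (no new-path implementation yet)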
+ expr.find { + case agg: expressions.AggregateExpression1 => true + case other => false + }.isDefined + } + + // Check if there are multiple distinct columns. + val aggregateExpressions = converted.aggregateExpressions.flatMap { expr => + expr.collect { + case agg: AggregateExpression2 => agg + } + }.toSet.toSeq + val functionsWithDistinct = aggregateExpressions.filter(_.isDistinct) + val hasMultipleDistinctColumnSets = + if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) { + true + } else { + false + } + + if (!hasAggregateExpression1 && !hasMultipleDistinctColumnSets) Some(converted) else None + + case other => None + } + + private def checkInvalidAggregateFunction2(aggregate: Aggregate): Unit = { + // If the plan cannot be converted, we will do a final round check to if the original + // logical.Aggregate contains both AggregateExpression1 and AggregateExpression2. If so, + // we need to throw an exception. + val aggregateFunction2s = aggregate.aggregateExpressions.flatMap { expr => + expr.collect { + case agg: AggregateExpression2 => agg.aggregateFunction + } + }.distinct + if (aggregateFunction2s.nonEmpty) { + // For functions implemented based on the new interface, prepare a list of function names. + val invalidFunctions = { + if (aggregateFunction2s.length > 1) { + s"${aggregateFunction2s.tail.map(_.nodeName).mkString(",")} " + + s"and ${aggregateFunction2s.head.nodeName} are" + } else { + s"${aggregateFunction2s.head.nodeName} is" + } + } + val errorMessage = + s"${invalidFunctions} implemented based on the new Aggregate Function " + + s"interface and it cannot be used with functions implemented based on " + + s"the old Aggregate Function interface." + throw new AnalysisException(errorMessage) + } + } + + def tryConvert( + plan: LogicalPlan, + useNewAggregation: Boolean, + codeGenEnabled: Boolean): Option[Aggregate] = plan match { + case p: Aggregate if useNewAggregation && codeGenEnabled => + val converted = tryConvert(p) + if (converted.isDefined) { + converted + } else { + checkInvalidAggregateFunction2(p) + None + } + case p: Aggregate => + checkInvalidAggregateFunction2(p) + None + case other => None + } + + def planAggregateWithoutDistinct( + groupingExpressions: Seq[Expression], + aggregateExpressions: Seq[AggregateExpression2], + aggregateFunctionMap: Map[(AggregateFunction2, Boolean), Attribute], + resultExpressions: Seq[NamedExpression], + child: SparkPlan): Seq[SparkPlan] = { + // 1. Create an Aggregate Operator for partial aggregations. + val namedGroupingExpressions = groupingExpressions.map { + case ne: NamedExpression => ne -> ne + // If the expression is not a NamedExpressions, we add an alias. + // So, when we generate the result of the operator, the Aggregate Operator + // can directly get the Seq of attributes representing the grouping expressions. 
+ case other => + val withAlias = Alias(other, other.toString)() + other -> withAlias + } + val groupExpressionMap = namedGroupingExpressions.toMap + val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute) + val partialAggregateExpressions = aggregateExpressions.map { + case AggregateExpression2(aggregateFunction, mode, isDistinct) => + AggregateExpression2(aggregateFunction, Partial, isDistinct) + } + val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg => + agg.aggregateFunction.bufferAttributes + } + val partialAggregate = + Aggregate2Sort( + None: Option[Seq[Expression]], + namedGroupingExpressions.map(_._2), + partialAggregateExpressions, + partialAggregateAttributes, + namedGroupingAttributes ++ partialAggregateAttributes, + child) + + // 2. Create an Aggregate Operator for final aggregations. + val finalAggregateExpressions = aggregateExpressions.map { + case AggregateExpression2(aggregateFunction, mode, isDistinct) => + AggregateExpression2(aggregateFunction, Final, isDistinct) + } + val finalAggregateAttributes = + finalAggregateExpressions.map { + expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct) + } + val rewrittenResultExpressions = resultExpressions.map { expr => + expr.transformDown { + case agg: AggregateExpression2 => + aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct).toAttribute + case expression => + // We do not rely on the equality check at here since attributes may + // different cosmetically. Instead, we use semanticEquals. + groupExpressionMap.collectFirst { + case (expr, ne) if expr semanticEquals expression => ne.toAttribute + }.getOrElse(expression) + }.asInstanceOf[NamedExpression] + } + val finalAggregate = Aggregate2Sort( + Some(namedGroupingAttributes), + namedGroupingAttributes, + finalAggregateExpressions, + finalAggregateAttributes, + rewrittenResultExpressions, + partialAggregate) + + finalAggregate :: Nil + } + + def planAggregateWithOneDistinct( + groupingExpressions: Seq[Expression], + functionsWithDistinct: Seq[AggregateExpression2], + functionsWithoutDistinct: Seq[AggregateExpression2], + aggregateFunctionMap: Map[(AggregateFunction2, Boolean), Attribute], + resultExpressions: Seq[NamedExpression], + child: SparkPlan): Seq[SparkPlan] = { + + // 1. Create an Aggregate Operator for partial aggregations. + // The grouping expressions are original groupingExpressions and + // distinct columns. For example, for avg(distinct value) ... group by key + // the grouping expressions of this Aggregate Operator will be [key, value]. + val namedGroupingExpressions = groupingExpressions.map { + case ne: NamedExpression => ne -> ne + // If the expression is not a NamedExpressions, we add an alias. + // So, when we generate the result of the operator, the Aggregate Operator + // can directly get the Seq of attributes representing the grouping expressions. + case other => + val withAlias = Alias(other, other.toString)() + other -> withAlias + } + val groupExpressionMap = namedGroupingExpressions.toMap + val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute) + + // It is safe to call head at here since functionsWithDistinct has at least one + // AggregateExpression2. 
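planAggregateWithoutDistinct above always yields a two-operator shape: a Partial Aggregate2Sort that emits grouping columns plus intermediate buffers, followed by a Final Aggregate2Sort that merges those buffers and evaluates resultExpressions. The same shape acted out on plain Scala collections for avg(value) grouped by key, with collections standing in for partitions and the shuffle:

// Sketch: two-phase aggregation for AVG(value) GROUP BY key.
val partitions: Seq[Seq[(String, Double)]] =
  Seq(Seq("a" -> 1.0, "a" -> 3.0), Seq("a" -> 5.0, "b" -> 2.0))

// Partial: within each partition, emit (key, (sum, count)) -- the aggregation buffer.
val partial: Seq[(String, (Double, Long))] =
  partitions.flatMap { part =>
    part.groupBy(_._1).toSeq.map { case (k, rs) => (k, (rs.map(_._2).sum, rs.size.toLong)) }
  }

// Final: after grouping by key again (the shuffle), merge buffers and evaluate the result.
val avg: Map[String, Double] =
  partial.groupBy(_._1).map { case (k, buffers) =>
    val (sum, count) = buffers.map(_._2).reduce((x, y) => (x._1 + y._1, x._2 + y._2))
    k -> (sum / count)
  }
// avg == Map("a" -> 3.0, "b" -> 2.0)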
+    val distinctColumnExpressions =
+      functionsWithDistinct.head.aggregateFunction.children
+    val namedDistinctColumnExpressions = distinctColumnExpressions.map {
+      case ne: NamedExpression => ne -> ne
+      case other =>
+        val withAlias = Alias(other, other.toString)()
+        other -> withAlias
+    }
+    val distinctColumnExpressionMap = namedDistinctColumnExpressions.toMap
+    val distinctColumnAttributes = namedDistinctColumnExpressions.map(_._2.toAttribute)
+
+    val partialAggregateExpressions = functionsWithoutDistinct.map {
+      case AggregateExpression2(aggregateFunction, mode, _) =>
+        AggregateExpression2(aggregateFunction, Partial, false)
+    }
+    val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg =>
+      agg.aggregateFunction.bufferAttributes
+    }
+    val partialAggregate =
+      Aggregate2Sort(
+        None: Option[Seq[Expression]],
+        (namedGroupingExpressions ++ namedDistinctColumnExpressions).map(_._2),
+        partialAggregateExpressions,
+        partialAggregateAttributes,
+        namedGroupingAttributes ++ distinctColumnAttributes ++ partialAggregateAttributes,
+        child)
+
+    // 2. Create an Aggregate Operator for partial merge aggregations.
+    val partialMergeAggregateExpressions = functionsWithoutDistinct.map {
+      case AggregateExpression2(aggregateFunction, mode, _) =>
+        AggregateExpression2(aggregateFunction, PartialMerge, false)
+    }
+    val partialMergeAggregateAttributes =
+      partialMergeAggregateExpressions.map {
+        expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)
+      }
+    val partialMergeAggregate =
+      Aggregate2Sort(
+        Some(namedGroupingAttributes),
+        namedGroupingAttributes ++ distinctColumnAttributes,
+        partialMergeAggregateExpressions,
+        partialMergeAggregateAttributes,
+        namedGroupingAttributes ++ distinctColumnAttributes ++ partialMergeAggregateAttributes,
+        partialAggregate)
+
+    // 3. Create an Aggregate Operator for the final and complete aggregations.
+    val finalAggregateExpressions = functionsWithoutDistinct.map {
+      case AggregateExpression2(aggregateFunction, mode, _) =>
+        AggregateExpression2(aggregateFunction, Final, false)
+    }
+    val finalAggregateAttributes =
+      finalAggregateExpressions.map {
+        expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)
+      }
+    val (completeAggregateExpressions, completeAggregateAttributes) = functionsWithDistinct.map {
+      // Children of an AggregateFunction with the DISTINCT keyword have already
+      // been evaluated. Here, we need to replace the original children
+      // with AttributeReferences.
+      case agg @ AggregateExpression2(aggregateFunction, mode, isDistinct) =>
+        val rewrittenAggregateFunction = aggregateFunction.transformDown {
+          case expr if distinctColumnExpressionMap.contains(expr) =>
+            distinctColumnExpressionMap(expr).toAttribute
+        }.asInstanceOf[AggregateFunction2]
+        // We rewrite the aggregate function to a non-distinct aggregation because
+        // its input will have distinct arguments.
+        val rewrittenAggregateExpression =
+          AggregateExpression2(rewrittenAggregateFunction, Complete, false)
+
+        val aggregateFunctionAttribute = aggregateFunctionMap(agg.aggregateFunction, isDistinct)
+        (rewrittenAggregateExpression -> aggregateFunctionAttribute)
+    }.unzip
+
+    val rewrittenResultExpressions = resultExpressions.map { expr =>
+      expr.transform {
+        case agg: AggregateExpression2 =>
+          aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct).toAttribute
+        case expression =>
+          // We do not rely on the equality check here since attributes may
+          // differ cosmetically. Instead, we use semanticEquals.
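For a single DISTINCT column, planAggregateWithOneDistinct above stacks three operators: partial aggregation grouped by the original keys plus the distinct column (which de-duplicates it), a partial merge on the same extended keys, and a final operator in which the non-distinct functions finish in Final mode while the rewritten distinct function runs in Complete mode over the now-distinct values. The same shape acted out on plain collections for COUNT(DISTINCT value) and SUM(other) grouped by key:

// Sketch of the plan shape for:
//   SELECT key, COUNT(DISTINCT value), SUM(other) FROM t GROUP BY key
val rows = Seq(("a", 1, 10L), ("a", 1, 20L), ("a", 2, 30L), ("b", 5, 40L))

// Steps 1-2: aggregate the non-distinct function while grouping by (key, value),
// which also de-duplicates the distinct column.
val step2: Map[(String, Int), Long] =
  rows.groupBy(r => (r._1, r._2)).map { case (k, rs) => k -> rs.map(_._3).sum }

// Step 3: group by key alone; SUM merges partial sums (Final mode) while
// COUNT(value) now runs as a plain aggregate over distinct values (Complete mode).
val step3: Map[String, (Int, Long)] =
  step2.toSeq.groupBy(_._1._1).map { case (key, grp) =>
    (key, (grp.size, grp.map(_._2).sum))
  }
// step3 == Map("a" -> (2, 60L), "b" -> (1, 40L))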
+ groupExpressionMap.collectFirst { + case (expr, ne) if expr semanticEquals expression => ne.toAttribute + }.getOrElse(expression) + }.asInstanceOf[NamedExpression] + } + val finalAndCompleteAggregate = FinalAndCompleteAggregate2Sort( + namedGroupingAttributes ++ distinctColumnAttributes, + namedGroupingAttributes, + finalAggregateExpressions, + finalAggregateAttributes, + completeAggregateExpressions, + completeAggregateAttributes, + rewrittenResultExpressions, + partialMergeAggregate) + + finalAndCompleteAggregate :: Nil + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 4c063c299ba53..fdd7ad59aba50 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -56,14 +56,17 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output - @transient lazy val conditionEvaluator: (InternalRow) => Boolean = - newPredicate(condition, child.output) - protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter => - iter.filter(conditionEvaluator) + iter.filter(newPredicate(condition, child.output)) } override def outputOrdering: Seq[SortOrder] = child.outputOrdering + + override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows + + override def canProcessUnsafeRows: Boolean = true + + override def canProcessSafeRows: Boolean = true } /** @@ -104,6 +107,9 @@ case class Sample( case class Union(children: Seq[SparkPlan]) extends SparkPlan { // TODO: attributes output by union should be distinct for nullability purposes override def output: Seq[Attribute] = children.head.output + override def outputsUnsafeRows: Boolean = children.forall(_.outputsUnsafeRows) + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = true protected override def doExecute(): RDD[InternalRow] = sparkContext.union(children.map(_.execute())) } @@ -306,6 +312,8 @@ case class UnsafeExternalSort( override def output: Seq[Attribute] = child.output override def outputOrdering: Seq[SortOrder] = sortOrder + + override def outputsUnsafeRows: Boolean = true } @DeveloperApi diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 5e9951f248ff2..bace3f8a9c8d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -35,8 +35,6 @@ import org.apache.spark.sql.{DataFrame, Row, SQLConf, SQLContext} * wrapped in `ExecutedCommand` during execution. 
*/ private[sql] trait RunnableCommand extends LogicalPlan with logical.Command { - self: Product => - override def output: Seq[Attribute] = Seq.empty override def children: Seq[LogicalPlan] = Seq.empty def run(sqlContext: SQLContext): Seq[Row] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 70c9e06927582..2b400926177fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -15,22 +15,21 @@ * limitations under the License. */ -package org.apache.spark.sql.sources +package org.apache.spark.sql.execution.datasources import org.apache.spark.{Logging, TaskContext} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.{MapPartitionsRDD, RDD, UnionRDD} -import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.{InternalRow, expressions} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{StringType, StructType} -import org.apache.spark.sql.{SaveMode, Strategy, execution, sources} -import org.apache.spark.util.{SerializableConfiguration, Utils} +import org.apache.spark.sql.{SaveMode, Strategy, execution, sources, _} import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.{SerializableConfiguration, Utils} /** * A Strategy for planning scans over data sources defined using the sources API. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala similarity index 88% rename from sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala index f374abffdd505..a7123dc845fa2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala @@ -14,11 +14,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.sources +package org.apache.spark.sql.execution.datasources import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeMap} -import org.apache.spark.sql.catalyst.plans.logical.{Statistics, LeafNode, LogicalPlan} +import org.apache.spark.sql.catalyst.expressions.{AttributeMap, AttributeReference} +import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} +import org.apache.spark.sql.sources.BaseRelation /** * Used to link a [[BaseRelation]] in to a logical query plan. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala similarity index 99% rename from sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index 8b2a45d8e970a..6b4a359db22d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -15,9 +15,9 @@ * limitations under the License. */ -package org.apache.spark.sql.sources +package org.apache.spark.sql.execution.datasources -import java.lang.{Double => JDouble, Float => JFloat, Integer => JInteger, Long => JLong} +import java.lang.{Double => JDouble, Long => JLong} import java.math.{BigDecimal => JBigDecimal} import scala.collection.mutable.ArrayBuffer @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} import org.apache.spark.sql.types._ + private[sql] case class Partition(values: InternalRow, path: String) private[sql] case class PartitionSpec(partitionColumns: StructType, partitions: Seq[Partition]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/commands.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/commands.scala index 5c6ef2dc90c73..cd2aa7f7433c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/commands.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.sources +package org.apache.spark.sql.execution.datasources import java.util.{Date, UUID} @@ -24,7 +24,6 @@ import scala.collection.JavaConversions.asScalaIterator import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter => MapReduceFileOutputCommitter, FileOutputFormat} - import org.apache.spark._ import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil @@ -35,9 +34,11 @@ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateProjection import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.execution.RunnableCommand +import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StringType import org.apache.spark.util.SerializableConfiguration + private[sql] case class InsertIntoDataSource( logicalRelation: LogicalRelation, query: LogicalPlan, @@ -99,7 +100,7 @@ private[sql] case class InsertIntoHadoopFsRelation( val pathExists = fs.exists(qualifiedOutputPath) val doInsertion = (mode, pathExists) match { case (SaveMode.ErrorIfExists, true) => - sys.error(s"path $qualifiedOutputPath already exists.") + throw new AnalysisException(s"path $qualifiedOutputPath already exists.") case (SaveMode.Overwrite, true) => fs.delete(qualifiedOutputPath, true) true @@ -107,6 +108,8 @@ private[sql] case class InsertIntoHadoopFsRelation( true case (SaveMode.Ignore, exists) => !exists + case (s, exists) => + throw new IllegalStateException(s"unsupported save mode $s ($exists)") } // If we are appending data to an existing dir. val isAppend = pathExists && (mode == SaveMode.Append) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala similarity index 96% rename from sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index d7440c55bd4a6..c8033d3c0470a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -15,23 +15,22 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.sources +package org.apache.spark.sql.execution.datasources import scala.language.{existentials, implicitConversions} import scala.util.matching.Regex import org.apache.hadoop.fs.Path - import org.apache.spark.Logging import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.sql.catalyst.AbstractSparkSQLParser +import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext, SaveMode} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.{AbstractSparkSQLParser, InternalRow} import org.apache.spark.sql.execution.RunnableCommand +import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext, SaveMode} import org.apache.spark.util.Utils /** @@ -247,7 +246,9 @@ private[sql] object ResolvedDataSource { val caseInsensitiveOptions = new CaseInsensitiveMap(options) val paths = { val patternPath = new Path(caseInsensitiveOptions("path")) - SparkHadoopUtil.get.globPath(patternPath).map(_.toString).toArray + val fs = patternPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) + val qualifiedPattern = patternPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + SparkHadoopUtil.get.globPathIfNecessary(qualifiedPattern).map(_.toString).toArray } val dataSchema = @@ -272,7 +273,9 @@ private[sql] object ResolvedDataSource { val caseInsensitiveOptions = new CaseInsensitiveMap(options) val paths = { val patternPath = new Path(caseInsensitiveOptions("path")) - SparkHadoopUtil.get.globPath(patternPath).map(_.toString).toArray + val fs = patternPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) + val qualifiedPattern = patternPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + SparkHadoopUtil.get.globPathIfNecessary(qualifiedPattern).map(_.toString).toArray } dataSource.createRelation(sqlContext, paths, None, None, caseInsensitiveOptions) case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala similarity index 90% rename from sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index a3fd7f13b3db7..11bb49b8d83de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -15,15 +15,15 @@ * limitations under the License. 
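The ResolvedDataSource change above qualifies the user-supplied pattern against that path's own FileSystem before expanding it, so scheme-less or relative paths resolve against the intended filesystem rather than the default one. A standalone sketch of the qualify-then-glob sequence using the plain Hadoop API; the patch itself routes the qualified pattern through SparkHadoopUtil.globPathIfNecessary, and the fallback below is a simplification of that helper's behavior:

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path

// Sketch: resolve a possibly relative or glob path the way the patched code does.
def resolveInputPaths(pattern: String, conf: Configuration): Seq[String] = {
  val patternPath = new Path(pattern)
  val fs = patternPath.getFileSystem(conf)   // filesystem owning this path's scheme
  val qualified = patternPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
  // globStatus may return null when nothing matches, so guard it.
  Option(fs.globStatus(qualified))
    .map(_.map(_.getPath.toString).toSeq)
    .getOrElse(Seq(qualified.toString))
}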
*/ -package org.apache.spark.sql.sources +package org.apache.spark.sql.execution.datasources -import org.apache.spark.sql.{SaveMode, AnalysisException} -import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, Catalog} -import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Alias} +import org.apache.spark.sql.{AnalysisException, SaveMode} +import org.apache.spark.sql.catalyst.analysis.{Catalog, EliminateSubQueries} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast} import org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.sources.{BaseRelation, HadoopFsRelation, InsertableRelation} /** * A rule to do pre-insert data type casting and field renaming. Before we insert into @@ -119,6 +119,13 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => // The relation in l is not an InsertableRelation. failAnalysis(s"$l does not allow insertion.") + case logical.InsertIntoTable(t, _, _, _, _) => + if (!t.isInstanceOf[LeafNode] || t == OneRowRelation || t.isInstanceOf[LocalRelation]) { + failAnalysis(s"Inserting into an RDD-based table is not allowed.") + } else { + // OK + } + case CreateTableUsingAsSelect(tableName, _, _, _, SaveMode.Overwrite, _, query) => // When the SaveMode is Overwrite, we need to check if the table is an input table of // the query. If so, we will throw an AnalysisException to let users know it is not allowed. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala index 437d143e53f3f..2645eb1854bce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala @@ -19,7 +19,8 @@ package org.apache.spark.sql.execution.expressions import org.apache.spark.TaskContext import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.LeafExpression +import org.apache.spark.sql.catalyst.expressions.{Nondeterministic, LeafExpression} +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.types.{LongType, DataType} /** @@ -32,7 +33,7 @@ import org.apache.spark.sql.types.{LongType, DataType} * * Since this expression is stateful, it cannot be a case object. */ -private[sql] case class MonotonicallyIncreasingID() extends LeafExpression { +private[sql] case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterministic { /** * Record ID within each partition. 
By being transient, count's value is reset to 0 every time @@ -40,6 +41,8 @@ private[sql] case class MonotonicallyIncreasingID() extends LeafExpression { */ @transient private[this] var count: Long = 0L + @transient private lazy val partitionMask = TaskContext.getPartitionId().toLong << 33 + override def nullable: Boolean = false override def dataType: DataType = LongType @@ -47,6 +50,20 @@ private[sql] case class MonotonicallyIncreasingID() extends LeafExpression { override def eval(input: InternalRow): Long = { val currentCount = count count += 1 - (TaskContext.get().partitionId().toLong << 33) + currentCount + partitionMask + currentCount + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val countTerm = ctx.freshName("count") + val partitionMaskTerm = ctx.freshName("partitionMask") + ctx.addMutableState(ctx.JAVA_LONG, countTerm, s"$countTerm = 0L;") + ctx.addMutableState(ctx.JAVA_LONG, partitionMaskTerm, + s"$partitionMaskTerm = ((long) org.apache.spark.TaskContext.getPartitionId()) << 33;") + + ev.isNull = "false" + s""" + final ${ctx.javaType(dataType)} ${ev.primitive} = $partitionMaskTerm + $countTerm; + $countTerm++; + """ } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala index 822d3d8c9108d..53ddd47e3e0c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala @@ -19,18 +19,29 @@ package org.apache.spark.sql.execution.expressions import org.apache.spark.TaskContext import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.LeafExpression +import org.apache.spark.sql.catalyst.expressions.{Nondeterministic, LeafExpression} +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.types.{IntegerType, DataType} /** * Expression that returns the current partition id of the Spark task. 
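As a reading aid (not part of the patch), a small sketch of how the 64-bit ids generated above decompose: the upper bits carry the partition id shifted left by 33, and the lower 33 bits carry the per-partition record count. MonotonicIdLayout is a hypothetical helper that mirrors the layout used by the expression.

object MonotonicIdLayout {
  // Same layout as MonotonicallyIncreasingID: partition id in the upper bits, count below.
  def compose(partitionId: Int, recordNumber: Long): Long =
    (partitionId.toLong << 33) + recordNumber

  // Recover (partitionId, recordNumber) from a generated id.
  def decompose(id: Long): (Int, Long) =
    ((id >>> 33).toInt, id & ((1L << 33) - 1))
}

// e.g. MonotonicIdLayout.compose(3, 5) == 25769803781L
//      MonotonicIdLayout.decompose(25769803781L) == (3, 5)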
*/ -private[sql] case object SparkPartitionID extends LeafExpression { +private[sql] case object SparkPartitionID extends LeafExpression with Nondeterministic { override def nullable: Boolean = false override def dataType: DataType = IntegerType - override def eval(input: InternalRow): Int = TaskContext.get().partitionId() + @transient private lazy val partitionId = TaskContext.getPartitionId() + + override def eval(input: InternalRow): Int = partitionId + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val idTerm = ctx.freshName("partitionId") + ctx.addMutableState(ctx.JAVA_INT, idTerm, + s"$idTerm = org.apache.spark.TaskContext.getPartitionId();") + ev.isNull = "false" + s"final ${ctx.javaType(dataType)} ${ev.primitive} = $idTerm;" + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index 7ffdce60d2955..abaa4a6ce86a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -62,7 +62,7 @@ case class BroadcastHashJoin( private val broadcastFuture = future { // Note that we use .execute().collect() because we don't want to convert data to Scala types val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect() - val hashed = HashedRelation(input.iterator, buildSideKeyGenerator, input.length) + val hashed = buildHashRelation(input.iterator) sparkContext.broadcast(hashed) }(BroadcastHashJoin.broadcastHashJoinExecutionContext) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala index ab757fc7de6cd..c9d1a880f4ef4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.execution.joins +import scala.concurrent._ +import scala.concurrent.duration._ + import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -26,10 +29,6 @@ import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} import org.apache.spark.util.ThreadUtils -import scala.collection.JavaConversions._ -import scala.concurrent._ -import scala.concurrent.duration._ - /** * :: DeveloperApi :: * Performs a outer hash join for two child relations. 
When the output RDD of this operator is @@ -58,28 +57,11 @@ case class BroadcastHashOuterJoin( override def requiredChildDistribution: Seq[Distribution] = UnspecifiedDistribution :: UnspecifiedDistribution :: Nil - private[this] lazy val (buildPlan, streamedPlan) = joinType match { - case RightOuter => (left, right) - case LeftOuter => (right, left) - case x => - throw new IllegalArgumentException( - s"BroadcastHashOuterJoin should not take $x as the JoinType") - } - - private[this] lazy val (buildKeys, streamedKeys) = joinType match { - case RightOuter => (leftKeys, rightKeys) - case LeftOuter => (rightKeys, leftKeys) - case x => - throw new IllegalArgumentException( - s"BroadcastHashOuterJoin should not take $x as the JoinType") - } - @transient private val broadcastFuture = future { // Note that we use .execute().collect() because we don't want to convert data to Scala types val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect() - // buildHashTable uses code-generated rows as keys, which are not serializable - val hashed = buildHashTable(input.iterator, newProjection(buildKeys, buildPlan.output)) + val hashed = buildHashRelation(input.iterator) sparkContext.broadcast(hashed) }(BroadcastHashOuterJoin.broadcastHashOuterJoinExecutionContext) @@ -89,21 +71,21 @@ case class BroadcastHashOuterJoin( streamedPlan.execute().mapPartitions { streamedIter => val joinedRow = new JoinedRow() val hashTable = broadcastRelation.value - val keyGenerator = newProjection(streamedKeys, streamedPlan.output) + val keyGenerator = streamedKeyGenerator joinType match { case LeftOuter => streamedIter.flatMap(currentRow => { val rowKey = keyGenerator(currentRow) joinedRow.withLeft(currentRow) - leftOuterIterator(rowKey, joinedRow, hashTable.getOrElse(rowKey, EMPTY_LIST)) + leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey)) }) case RightOuter => streamedIter.flatMap(currentRow => { val rowKey = keyGenerator(currentRow) joinedRow.withRight(currentRow) - rightOuterIterator(rowKey, hashTable.getOrElse(rowKey, EMPTY_LIST), joinedRow) + rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow) }) case x => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala index f7b46d6888d7d..f71c0ce352904 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala @@ -33,37 +33,26 @@ case class BroadcastLeftSemiJoinHash( leftKeys: Seq[Expression], rightKeys: Seq[Expression], left: SparkPlan, - right: SparkPlan) extends BinaryNode with HashJoin { - - override val buildSide: BuildSide = BuildRight - - override def output: Seq[Attribute] = left.output + right: SparkPlan, + condition: Option[Expression]) extends BinaryNode with HashSemiJoin { protected override def doExecute(): RDD[InternalRow] = { - val buildIter = buildPlan.execute().map(_.copy()).collect().toIterator - val hashSet = new java.util.HashSet[InternalRow]() - var currentRow: InternalRow = null + val buildIter = right.execute().map(_.copy()).collect().toIterator - // Create a Hash set of buildKeys - while (buildIter.hasNext) { - currentRow = buildIter.next() - val rowKey = buildSideKeyGenerator(currentRow) - if (!rowKey.anyNull) { - val keyExists = hashSet.contains(rowKey) - if (!keyExists) { - // rowKey may be not serializable (from codegen) - 
hashSet.add(rowKey.copy()) - } - } - } + if (condition.isEmpty) { + val hashSet = buildKeyHashSet(buildIter) + val broadcastedRelation = sparkContext.broadcast(hashSet) - val broadcastedRelation = sparkContext.broadcast(hashSet) + left.execute().mapPartitions { streamIter => + hashSemiJoin(streamIter, broadcastedRelation.value) + } + } else { + val hashRelation = buildHashRelation(buildIter) + val broadcastedRelation = sparkContext.broadcast(hashRelation) - streamedPlan.execute().mapPartitions { streamIter => - val joinKeys = streamSideKeyGenerator() - streamIter.filter(current => { - !joinKeys(current).anyNull && broadcastedRelation.value.contains(joinKeys.currentValue) - }) + left.execute().mapPartitions { streamIter => + hashSemiJoin(streamIter, broadcastedRelation.value) + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index 60b4266fad8b1..700636966f8be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -44,6 +44,19 @@ case class BroadcastNestedLoopJoin( case BuildLeft => (right, left) } + override def outputsUnsafeRows: Boolean = left.outputsUnsafeRows || right.outputsUnsafeRows + override def canProcessUnsafeRows: Boolean = true + + @transient private[this] lazy val resultProjection: Projection = { + if (outputsUnsafeRows) { + UnsafeProjection.create(schema) + } else { + new Projection { + override def apply(r: InternalRow): InternalRow = r + } + } + } + override def outputPartitioning: Partitioning = streamed.outputPartitioning override def output: Seq[Attribute] = { @@ -74,6 +87,7 @@ case class BroadcastNestedLoopJoin( val includedBroadcastTuples = new scala.collection.mutable.BitSet(broadcastedRelation.value.size) val joinedRow = new JoinedRow + val leftNulls = new GenericMutableRow(left.output.size) val rightNulls = new GenericMutableRow(right.output.size) @@ -86,11 +100,11 @@ case class BroadcastNestedLoopJoin( val broadcastedRow = broadcastedRelation.value(i) buildSide match { case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) => - matchedRows += joinedRow(streamedRow, broadcastedRow).copy() + matchedRows += resultProjection(joinedRow(streamedRow, broadcastedRow)).copy() streamRowMatched = true includedBroadcastTuples += i case BuildLeft if boundCondition(joinedRow(broadcastedRow, streamedRow)) => - matchedRows += joinedRow(broadcastedRow, streamedRow).copy() + matchedRows += resultProjection(joinedRow(broadcastedRow, streamedRow)).copy() streamRowMatched = true includedBroadcastTuples += i case _ => @@ -100,9 +114,9 @@ case class BroadcastNestedLoopJoin( (streamRowMatched, joinType, buildSide) match { case (false, LeftOuter | FullOuter, BuildRight) => - matchedRows += joinedRow(streamedRow, rightNulls).copy() + matchedRows += resultProjection(joinedRow(streamedRow, rightNulls)).copy() case (false, RightOuter | FullOuter, BuildLeft) => - matchedRows += joinedRow(leftNulls, streamedRow).copy() + matchedRows += resultProjection(joinedRow(leftNulls, streamedRow)).copy() case _ => } } @@ -110,12 +124,9 @@ case class BroadcastNestedLoopJoin( } val includedBroadcastTuples = matchesOrStreamedRowsWithNulls.map(_._2) - val allIncludedBroadcastTuples = - if (includedBroadcastTuples.count == 0) { - new 
scala.collection.mutable.BitSet(broadcastedRelation.value.size) - } else { - includedBroadcastTuples.reduce(_ ++ _) - } + val allIncludedBroadcastTuples = includedBroadcastTuples.fold( + new scala.collection.mutable.BitSet(broadcastedRelation.value.size) + )(_ ++ _) val leftNulls = new GenericMutableRow(left.output.size) val rightNulls = new GenericMutableRow(right.output.size) @@ -127,8 +138,10 @@ case class BroadcastNestedLoopJoin( while (i < rel.length) { if (!allIncludedBroadcastTuples.contains(i)) { (joinType, buildSide) match { - case (RightOuter | FullOuter, BuildRight) => buf += new JoinedRow(leftNulls, rel(i)) - case (LeftOuter | FullOuter, BuildLeft) => buf += new JoinedRow(rel(i), rightNulls) + case (RightOuter | FullOuter, BuildRight) => + buf += resultProjection(new JoinedRow(leftNulls, rel(i))) + case (LeftOuter | FullOuter, BuildLeft) => + buf += resultProjection(new JoinedRow(rel(i), rightNulls)) case _ => } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index ff85ea3f6a410..ae34409bcfcca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -44,11 +44,20 @@ trait HashJoin { override def output: Seq[Attribute] = left.output ++ right.output - @transient protected lazy val buildSideKeyGenerator: Projection = - newProjection(buildKeys, buildPlan.output) + protected[this] def supportUnsafe: Boolean = { + (self.codegenEnabled && UnsafeProjection.canSupport(buildKeys) + && UnsafeProjection.canSupport(self.schema)) + } + + override def outputsUnsafeRows: Boolean = supportUnsafe + override def canProcessUnsafeRows: Boolean = supportUnsafe - @transient protected lazy val streamSideKeyGenerator: () => MutableProjection = - newMutableProjection(streamedKeys, streamedPlan.output) + @transient protected lazy val streamSideKeyGenerator: Projection = + if (supportUnsafe) { + UnsafeProjection.create(streamedKeys, streamedPlan.output) + } else { + newMutableProjection(streamedKeys, streamedPlan.output)() + } protected def hashJoin( streamIter: Iterator[InternalRow], @@ -61,8 +70,17 @@ trait HashJoin { // Mutable per row objects. 
private[this] val joinRow = new JoinedRow2 + private[this] val resultProjection: Projection = { + if (supportUnsafe) { + UnsafeProjection.create(self.schema) + } else { + new Projection { + override def apply(r: InternalRow): InternalRow = r + } + } + } - private[this] val joinKeys = streamSideKeyGenerator() + private[this] val joinKeys = streamSideKeyGenerator override final def hasNext: Boolean = (currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) || @@ -74,7 +92,7 @@ trait HashJoin { case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow) } currentMatchPosition += 1 - ret + resultProjection(ret) } /** @@ -89,8 +107,9 @@ trait HashJoin { while (currentHashMatches == null && streamIter.hasNext) { currentStreamedRow = streamIter.next() - if (!joinKeys(currentStreamedRow).anyNull) { - currentHashMatches = hashedRelation.get(joinKeys.currentValue) + val key = joinKeys(currentStreamedRow) + if (!key.anyNull) { + currentHashMatches = hashedRelation.get(key) } } @@ -103,4 +122,12 @@ trait HashJoin { } } } + + protected[this] def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = { + if (supportUnsafe) { + UnsafeHashedRelation(buildIter, buildKeys, buildPlan) + } else { + HashedRelation(buildIter, newProjection(buildKeys, buildPlan.output)) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index 0522ee85eeb8a..6bf2f82954046 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -23,7 +23,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} -import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.util.collection.CompactBuffer @@ -38,7 +38,7 @@ trait HashOuterJoin { val left: SparkPlan val right: SparkPlan -override def outputPartitioning: Partitioning = joinType match { + override def outputPartitioning: Partitioning = joinType match { case LeftOuter => left.outputPartitioning case RightOuter => right.outputPartitioning case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) @@ -59,14 +59,56 @@ override def outputPartitioning: Partitioning = joinType match { } } + protected[this] lazy val (buildPlan, streamedPlan) = joinType match { + case RightOuter => (left, right) + case LeftOuter => (right, left) + case x => + throw new IllegalArgumentException( + s"HashOuterJoin should not take $x as the JoinType") + } + + protected[this] lazy val (buildKeys, streamedKeys) = joinType match { + case RightOuter => (leftKeys, rightKeys) + case LeftOuter => (rightKeys, leftKeys) + case x => + throw new IllegalArgumentException( + s"HashOuterJoin should not take $x as the JoinType") + } + + protected[this] def supportUnsafe: Boolean = { + (self.codegenEnabled && joinType != FullOuter + && UnsafeProjection.canSupport(buildKeys) + && UnsafeProjection.canSupport(self.schema)) + } + + override def outputsUnsafeRows: Boolean = supportUnsafe + override def canProcessUnsafeRows: Boolean = supportUnsafe + + protected[this] def 
streamedKeyGenerator(): Projection = { + if (supportUnsafe) { + UnsafeProjection.create(streamedKeys, streamedPlan.output) + } else { + newProjection(streamedKeys, streamedPlan.output) + } + } + + @transient private[this] lazy val resultProjection: Projection = { + if (supportUnsafe) { + UnsafeProjection.create(self.schema) + } else { + new Projection { + override def apply(r: InternalRow): InternalRow = r + } + } + } + @transient private[this] lazy val DUMMY_LIST = CompactBuffer[InternalRow](null) @transient protected[this] lazy val EMPTY_LIST = CompactBuffer[InternalRow]() @transient private[this] lazy val leftNullRow = new GenericInternalRow(left.output.length) @transient private[this] lazy val rightNullRow = new GenericInternalRow(right.output.length) @transient private[this] lazy val boundCondition = - condition.map( - newPredicate(_, left.output ++ right.output)).getOrElse((row: InternalRow) => true) + newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala // iterator for performance purpose. @@ -77,16 +119,20 @@ override def outputPartitioning: Partitioning = joinType match { rightIter: Iterable[InternalRow]): Iterator[InternalRow] = { val ret: Iterable[InternalRow] = { if (!key.anyNull) { - val temp = rightIter.collect { - case r if boundCondition(joinedRow.withRight(r)) => joinedRow.copy() + val temp = if (rightIter != null) { + rightIter.collect { + case r if boundCondition(joinedRow.withRight(r)) => resultProjection(joinedRow).copy() + } + } else { + List.empty } if (temp.isEmpty) { - joinedRow.withRight(rightNullRow).copy :: Nil + resultProjection(joinedRow.withRight(rightNullRow)).copy :: Nil } else { temp } } else { - joinedRow.withRight(rightNullRow).copy :: Nil + resultProjection(joinedRow.withRight(rightNullRow)).copy :: Nil } } ret.iterator @@ -98,17 +144,21 @@ override def outputPartitioning: Partitioning = joinType match { joinedRow: JoinedRow): Iterator[InternalRow] = { val ret: Iterable[InternalRow] = { if (!key.anyNull) { - val temp = leftIter.collect { - case l if boundCondition(joinedRow.withLeft(l)) => - joinedRow.copy() + val temp = if (leftIter != null) { + leftIter.collect { + case l if boundCondition(joinedRow.withLeft(l)) => + resultProjection(joinedRow).copy() + } + } else { + List.empty } if (temp.isEmpty) { - joinedRow.withLeft(leftNullRow).copy :: Nil + resultProjection(joinedRow.withLeft(leftNullRow)).copy :: Nil } else { temp } } else { - joinedRow.withLeft(leftNullRow).copy :: Nil + resultProjection(joinedRow.withLeft(leftNullRow)).copy :: Nil } } ret.iterator @@ -160,6 +210,7 @@ override def outputPartitioning: Partitioning = joinType match { } } + // This is only used by FullOuter protected[this] def buildHashTable( iter: Iterator[InternalRow], keyGenerator: Projection): JavaHashMap[InternalRow, CompactBuffer[InternalRow]] = { @@ -179,4 +230,12 @@ override def outputPartitioning: Partitioning = joinType match { hashTable } + + protected[this] def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = { + if (supportUnsafe) { + UnsafeHashedRelation(buildIter, buildKeys, buildPlan) + } else { + HashedRelation(buildIter, newProjection(buildKeys, buildPlan.output)) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala new file mode 100644 index 0000000000000..7f49264d40354 --- /dev/null +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.SparkPlan + + +trait HashSemiJoin { + self: SparkPlan => + val leftKeys: Seq[Expression] + val rightKeys: Seq[Expression] + val left: SparkPlan + val right: SparkPlan + val condition: Option[Expression] + + override def output: Seq[Attribute] = left.output + + protected[this] def supportUnsafe: Boolean = { + (self.codegenEnabled && UnsafeProjection.canSupport(leftKeys) + && UnsafeProjection.canSupport(rightKeys) + && UnsafeProjection.canSupport(left.schema)) + } + + override def outputsUnsafeRows: Boolean = right.outputsUnsafeRows + override def canProcessUnsafeRows: Boolean = supportUnsafe + + @transient protected lazy val leftKeyGenerator: Projection = + if (supportUnsafe) { + UnsafeProjection.create(leftKeys, left.output) + } else { + newMutableProjection(leftKeys, left.output)() + } + + @transient protected lazy val rightKeyGenerator: Projection = + if (supportUnsafe) { + UnsafeProjection.create(rightKeys, right.output) + } else { + newMutableProjection(rightKeys, right.output)() + } + + @transient private lazy val boundCondition = + newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) + + protected def buildKeyHashSet(buildIter: Iterator[InternalRow]): java.util.Set[InternalRow] = { + val hashSet = new java.util.HashSet[InternalRow]() + var currentRow: InternalRow = null + + // Create a Hash set of buildKeys + val rightKey = rightKeyGenerator + while (buildIter.hasNext) { + currentRow = buildIter.next() + val rowKey = rightKey(currentRow) + if (!rowKey.anyNull) { + val keyExists = hashSet.contains(rowKey) + if (!keyExists) { + hashSet.add(rowKey.copy()) + } + } + } + hashSet + } + + protected def hashSemiJoin( + streamIter: Iterator[InternalRow], + hashSet: java.util.Set[InternalRow]): Iterator[InternalRow] = { + val joinKeys = leftKeyGenerator + streamIter.filter(current => { + val key = joinKeys(current) + !key.anyNull && hashSet.contains(key) + }) + } + + protected def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = { + if (supportUnsafe) { + UnsafeHashedRelation(buildIter, rightKeys, right) + } else { + HashedRelation(buildIter, newProjection(rightKeys, right.output)) + } + } + + protected def hashSemiJoin( + streamIter: Iterator[InternalRow], + hashedRelation: HashedRelation): Iterator[InternalRow] = { + val joinKeys = leftKeyGenerator + val joinedRow = new JoinedRow + streamIter.filter { current => + val key = joinKeys(current) + lazy val rowBuffer = 
hashedRelation.get(key) + !key.anyNull && rowBuffer != null && rowBuffer.exists { + (row: InternalRow) => boundCondition(joinedRow(current, row)) + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 6b51f5d4151d3..8d5731afd59b8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -17,12 +17,13 @@ package org.apache.spark.sql.execution.joins -import java.io.{ObjectInput, ObjectOutput, Externalizable} +import java.io.{Externalizable, ObjectInput, ObjectOutput} import java.util.{HashMap => JavaHashMap} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Projection -import org.apache.spark.sql.execution.SparkSqlSerializer +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.{SparkPlan, SparkSqlSerializer} +import org.apache.spark.sql.types.StructType import org.apache.spark.util.collection.CompactBuffer @@ -98,7 +99,6 @@ final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[InternalR } } - // TODO(rxin): a version of [[HashedRelation]] backed by arrays for consecutive integer keys. @@ -148,3 +148,80 @@ private[joins] object HashedRelation { } } } + + +/** + * A HashedRelation for UnsafeRow, which is backed by BytesToBytesMap that maps the key into a + * sequence of values. + * + * TODO(davies): use BytesToBytesMap + */ +private[joins] final class UnsafeHashedRelation( + private var hashTable: JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]]) + extends HashedRelation with Externalizable { + + def this() = this(null) // Needed for serialization + + override def get(key: InternalRow): CompactBuffer[InternalRow] = { + val unsafeKey = key.asInstanceOf[UnsafeRow] + // Thanks to type eraser + hashTable.get(unsafeKey).asInstanceOf[CompactBuffer[InternalRow]] + } + + override def writeExternal(out: ObjectOutput): Unit = { + writeBytes(out, SparkSqlSerializer.serialize(hashTable)) + } + + override def readExternal(in: ObjectInput): Unit = { + hashTable = SparkSqlSerializer.deserialize(readBytes(in)) + } +} + +private[joins] object UnsafeHashedRelation { + + def apply( + input: Iterator[InternalRow], + buildKeys: Seq[Expression], + buildPlan: SparkPlan, + sizeEstimate: Int = 64): HashedRelation = { + val boundedKeys = buildKeys.map(BindReferences.bindReference(_, buildPlan.output)) + apply(input, boundedKeys, buildPlan.schema, sizeEstimate) + } + + // Used for tests + def apply( + input: Iterator[InternalRow], + buildKeys: Seq[Expression], + rowSchema: StructType, + sizeEstimate: Int): HashedRelation = { + + // TODO: Use BytesToBytesMap. 
+ val hashTable = new JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]](sizeEstimate) + val toUnsafe = UnsafeProjection.create(rowSchema) + val keyGenerator = UnsafeProjection.create(buildKeys) + + // Create a mapping of buildKeys -> rows + while (input.hasNext) { + val currentRow = input.next() + val unsafeRow = if (currentRow.isInstanceOf[UnsafeRow]) { + currentRow.asInstanceOf[UnsafeRow] + } else { + toUnsafe(currentRow) + } + val rowKey = keyGenerator(unsafeRow) + if (!rowKey.anyNull) { + val existingMatchList = hashTable.get(rowKey) + val matchList = if (existingMatchList == null) { + val newMatchList = new CompactBuffer[UnsafeRow]() + hashTable.put(rowKey.copy(), newMatchList) + newMatchList + } else { + existingMatchList + } + matchList += unsafeRow.copy() + } + } + + new UnsafeHashedRelation(hashTable) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala index db5be9f453674..4443455ef11fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala @@ -39,6 +39,9 @@ case class LeftSemiJoinBNL( override def output: Seq[Attribute] = left.output + override def outputsUnsafeRows: Boolean = streamed.outputsUnsafeRows + override def canProcessUnsafeRows: Boolean = true + /** The Streamed Relation */ override def left: SparkPlan = streamed diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala index 611ba928a16ec..874712a4e739f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.ClusteredDistribution import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} @@ -34,36 +34,21 @@ case class LeftSemiJoinHash( leftKeys: Seq[Expression], rightKeys: Seq[Expression], left: SparkPlan, - right: SparkPlan) extends BinaryNode with HashJoin { - - override val buildSide: BuildSide = BuildRight + right: SparkPlan, + condition: Option[Expression]) extends BinaryNode with HashSemiJoin { override def requiredChildDistribution: Seq[ClusteredDistribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - override def output: Seq[Attribute] = left.output - protected override def doExecute(): RDD[InternalRow] = { - buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => - val hashSet = new java.util.HashSet[InternalRow]() - var currentRow: InternalRow = null - - // Create a Hash set of buildKeys - while (buildIter.hasNext) { - currentRow = buildIter.next() - val rowKey = buildSideKeyGenerator(currentRow) - if (!rowKey.anyNull) { - val keyExists = hashSet.contains(rowKey) - if (!keyExists) { - hashSet.add(rowKey) - } - } + right.execute().zipPartitions(left.execute()) { (buildIter, streamIter) => + if (condition.isEmpty) { + val hashSet = buildKeyHashSet(buildIter) + 
hashSemiJoin(streamIter, hashSet) + } else { + val hashRelation = buildHashRelation(buildIter) + hashSemiJoin(streamIter, hashRelation) } - - val joinKeys = streamSideKeyGenerator() - streamIter.filter(current => { - !joinKeys(current).anyNull && hashSet.contains(joinKeys.currentValue) - }) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala index 5439e10a60b2a..948d0ccebceb0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala @@ -45,7 +45,7 @@ case class ShuffledHashJoin( protected override def doExecute(): RDD[InternalRow] = { buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => - val hashed = HashedRelation(buildIter, buildSideKeyGenerator) + val hashed = buildHashRelation(buildIter) hashJoin(streamIter, hashed) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala index ab0a6ad56acde..f54f1edd38ec8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala @@ -50,24 +50,25 @@ case class ShuffledHashOuterJoin( // TODO this probably can be replaced by external sort (sort merged join?) joinType match { case LeftOuter => - val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) - val keyGenerator = newProjection(leftKeys, left.output) + val hashed = buildHashRelation(rightIter) + val keyGenerator = streamedKeyGenerator() leftIter.flatMap( currentRow => { val rowKey = keyGenerator(currentRow) joinedRow.withLeft(currentRow) - leftOuterIterator(rowKey, joinedRow, rightHashTable.getOrElse(rowKey, EMPTY_LIST)) + leftOuterIterator(rowKey, joinedRow, hashed.get(rowKey)) }) case RightOuter => - val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) - val keyGenerator = newProjection(rightKeys, right.output) + val hashed = buildHashRelation(leftIter) + val keyGenerator = streamedKeyGenerator() rightIter.flatMap ( currentRow => { val rowKey = keyGenerator(currentRow) joinedRow.withRight(currentRow) - rightOuterIterator(rowKey, leftHashTable.getOrElse(rowKey, EMPTY_LIST), joinedRow) + rightOuterIterator(rowKey, hashed.get(rowKey), joinedRow) }) case FullOuter => + // TODO(davies): use UnsafeRow val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala index 6d6e67dace177..e6e27a87c7151 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala @@ -51,15 +51,11 @@ private[spark] case class PythonUDF( broadcastVars: JList[Broadcast[PythonBroadcast]], accumulator: Accumulator[JList[Array[Byte]]], dataType: DataType, - children: Seq[Expression]) extends Expression with SparkLogging { + children: Seq[Expression]) extends Expression 
with Unevaluable with SparkLogging { override def toString: String = s"PythonUDF#$name(${children.mkString(",")})" override def nullable: Boolean = true - - override def eval(input: InternalRow): Any = { - throw new UnsupportedOperationException("PythonUDFs can not be directly evaluated.") - } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala new file mode 100644 index 0000000000000..29f3beb3cb3c8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.rules.Rule + +/** + * :: DeveloperApi :: + * Converts Java-object-based rows into [[UnsafeRow]]s. + */ +@DeveloperApi +case class ConvertToUnsafe(child: SparkPlan) extends UnaryNode { + + require(UnsafeProjection.canSupport(child.schema), s"Cannot convert ${child.schema} to Unsafe") + + override def output: Seq[Attribute] = child.output + override def outputsUnsafeRows: Boolean = true + override def canProcessUnsafeRows: Boolean = false + override def canProcessSafeRows: Boolean = true + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitions { iter => + val convertToUnsafe = UnsafeProjection.create(child.schema) + iter.map(convertToUnsafe) + } + } +} + +/** + * :: DeveloperApi :: + * Converts [[UnsafeRow]]s back into Java-object-based rows. 
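A minimal sketch, not part of the patch, of the per-row work the two converters perform, assuming a hypothetical two-column schema; it builds the same projections that ConvertToUnsafe and ConvertToSafe create in doExecute above.

import org.apache.spark.sql.catalyst.expressions.{FromUnsafeProjection, UnsafeProjection}
import org.apache.spark.sql.types.{IntegerType, LongType, StructField, StructType}

val schema = StructType(Seq(StructField("a", IntegerType), StructField("b", LongType)))
// ConvertToUnsafe maps each object-based InternalRow through this projection,
// producing an UnsafeRow backed by raw bytes:
val toUnsafe = UnsafeProjection.create(schema)
// ConvertToSafe maps each UnsafeRow back into an object-based row:
val toSafe = FromUnsafeProjection(schema.map(_.dataType))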
+ */ +@DeveloperApi +case class ConvertToSafe(child: SparkPlan) extends UnaryNode { + override def output: Seq[Attribute] = child.output + override def outputsUnsafeRows: Boolean = false + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = false + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitions { iter => + val convertToSafe = FromUnsafeProjection(child.output.map(_.dataType)) + iter.map(convertToSafe) + } + } +} + +private[sql] object EnsureRowFormats extends Rule[SparkPlan] { + + private def onlyHandlesSafeRows(operator: SparkPlan): Boolean = + operator.canProcessSafeRows && !operator.canProcessUnsafeRows + + private def onlyHandlesUnsafeRows(operator: SparkPlan): Boolean = + operator.canProcessUnsafeRows && !operator.canProcessSafeRows + + private def handlesBothSafeAndUnsafeRows(operator: SparkPlan): Boolean = + operator.canProcessSafeRows && operator.canProcessUnsafeRows + + override def apply(operator: SparkPlan): SparkPlan = operator.transformUp { + case operator: SparkPlan if onlyHandlesSafeRows(operator) => + if (operator.children.exists(_.outputsUnsafeRows)) { + operator.withNewChildren { + operator.children.map { + c => if (c.outputsUnsafeRows) ConvertToSafe(c) else c + } + } + } else { + operator + } + case operator: SparkPlan if onlyHandlesUnsafeRows(operator) => + if (operator.children.exists(!_.outputsUnsafeRows)) { + operator.withNewChildren { + operator.children.map { + c => if (!c.outputsUnsafeRows) ConvertToUnsafe(c) else c + } + } + } else { + operator + } + case operator: SparkPlan if handlesBothSafeAndUnsafeRows(operator) => + if (operator.children.map(_.outputsUnsafeRows).toSet.size != 1) { + // If this operator's children produce both unsafe and safe rows, + // convert everything unsafe rows if all the schema of them are support by UnsafeRow + if (operator.children.forall(c => UnsafeProjection.canSupport(c.schema))) { + operator.withNewChildren { + operator.children.map { + c => if (!c.outputsUnsafeRows) ConvertToUnsafe(c) else c + } + } + } else { + operator.withNewChildren { + operator.children.map { + c => if (c.outputsUnsafeRows) ConvertToSafe(c) else c + } + } + } + } else { + operator + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala new file mode 100644 index 0000000000000..6c49a906c848a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala @@ -0,0 +1,280 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.expressions.aggregate + +import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection +import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction2 +import org.apache.spark.sql.types._ +import org.apache.spark.sql.Row + +/** + * The abstract class for implementing user-defined aggregate function. + */ +abstract class UserDefinedAggregateFunction extends Serializable { + + /** + * A [[StructType]] represents data types of input arguments of this aggregate function. + * For example, if a [[UserDefinedAggregateFunction]] expects two input arguments + * with type of [[DoubleType]] and [[LongType]], the returned [[StructType]] will look like + * + * ``` + * StructType(Seq(StructField("doubleInput", DoubleType), StructField("longInput", LongType))) + * ``` + * + * The name of a field of this [[StructType]] is only used to identify the corresponding + * input argument. Users can choose names to identify the input arguments. + */ + def inputSchema: StructType + + /** + * A [[StructType]] represents data types of values in the aggregation buffer. + * For example, if a [[UserDefinedAggregateFunction]]'s buffer has two values + * (i.e. two intermediate values) with type of [[DoubleType]] and [[LongType]], + * the returned [[StructType]] will look like + * + * ``` + * StructType(Seq(StructField("doubleInput", DoubleType), StructField("longInput", LongType))) + * ``` + * + * The name of a field of this [[StructType]] is only used to identify the corresponding + * buffer value. Users can choose names to identify the input arguments. + */ + def bufferSchema: StructType + + /** + * The [[DataType]] of the returned value of this [[UserDefinedAggregateFunction]]. + */ + def returnDataType: DataType + + /** Indicates if this function is deterministic. */ + def deterministic: Boolean + + /** + * Initializes the given aggregation buffer. Initial values set by this method should satisfy + * the condition that when merging two buffers with initial values, the new buffer should + * still store initial values. + */ + def initialize(buffer: MutableAggregationBuffer): Unit + + /** Updates the given aggregation buffer `buffer` with new input data from `input`. */ + def update(buffer: MutableAggregationBuffer, input: Row): Unit + + /** Merges two aggregation buffers and stores the updated buffer values back in `buffer1`. */ + def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit + + /** + * Calculates the final result of this [[UserDefinedAggregateFunction]] based on the given + * aggregation buffer. + */ + def evaluate(buffer: Row): Any +} + +private[sql] abstract class AggregationBuffer( + toCatalystConverters: Array[Any => Any], + toScalaConverters: Array[Any => Any], + bufferOffset: Int) + extends Row { + + override def length: Int = toCatalystConverters.length + + protected val offsets: Array[Int] = { + val newOffsets = new Array[Int](length) + var i = 0 + while (i < newOffsets.length) { + newOffsets(i) = bufferOffset + i + i += 1 + } + newOffsets + } +} + +/** + * A Mutable [[Row]] representing an mutable aggregation buffer. 
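As a reading aid (not part of the patch), here is a minimal sketch of a user-defined aggregate written against the contract above. MyAverage and its field names are hypothetical, and registration with a SQLContext is omitted.

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.aggregate.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class MyAverage extends UserDefinedAggregateFunction {
  // One double input column.
  def inputSchema: StructType = StructType(StructField("value", DoubleType) :: Nil)
  // Running sum and count kept in the aggregation buffer.
  def bufferSchema: StructType =
    StructType(StructField("sum", DoubleType) :: StructField("count", LongType) :: Nil)
  def returnDataType: DataType = DoubleType
  def deterministic: Boolean = true

  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, 0.0)   // sum
    buffer.update(1, 0L)    // count
  }

  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer.update(0, buffer.getDouble(0) + input.getDouble(0))
      buffer.update(1, buffer.getLong(1) + 1L)
    }
  }

  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1.update(0, buffer1.getDouble(0) + buffer2.getDouble(0))
    buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1))
  }

  def evaluate(buffer: Row): Any =
    if (buffer.getLong(1) == 0L) null else buffer.getDouble(0) / buffer.getLong(1)
}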
+ */ +class MutableAggregationBuffer private[sql] ( + toCatalystConverters: Array[Any => Any], + toScalaConverters: Array[Any => Any], + bufferOffset: Int, + var underlyingBuffer: MutableRow) + extends AggregationBuffer(toCatalystConverters, toScalaConverters, bufferOffset) { + + override def get(i: Int): Any = { + if (i >= length || i < 0) { + throw new IllegalArgumentException( + s"Could not access ${i}th value in this buffer because it only has $length values.") + } + toScalaConverters(i)(underlyingBuffer(offsets(i))) + } + + def update(i: Int, value: Any): Unit = { + if (i >= length || i < 0) { + throw new IllegalArgumentException( + s"Could not update ${i}th value in this buffer because it only has $length values.") + } + underlyingBuffer.update(offsets(i), toCatalystConverters(i)(value)) + } + + override def copy(): MutableAggregationBuffer = { + new MutableAggregationBuffer( + toCatalystConverters, + toScalaConverters, + bufferOffset, + underlyingBuffer) + } +} + +/** + * A [[Row]] representing an immutable aggregation buffer. + */ +class InputAggregationBuffer private[sql] ( + toCatalystConverters: Array[Any => Any], + toScalaConverters: Array[Any => Any], + bufferOffset: Int, + var underlyingInputBuffer: Row) + extends AggregationBuffer(toCatalystConverters, toScalaConverters, bufferOffset) { + + override def get(i: Int): Any = { + if (i >= length || i < 0) { + throw new IllegalArgumentException( + s"Could not access ${i}th value in this buffer because it only has $length values.") + } + toScalaConverters(i)(underlyingInputBuffer(offsets(i))) + } + + override def copy(): InputAggregationBuffer = { + new InputAggregationBuffer( + toCatalystConverters, + toScalaConverters, + bufferOffset, + underlyingInputBuffer) + } +} + +/** + * The internal wrapper used to hook a [[UserDefinedAggregateFunction]] `udaf` in the + * internal aggregation code path. 
+ * @param children + * @param udaf + */ +case class ScalaUDAF( + children: Seq[Expression], + udaf: UserDefinedAggregateFunction) + extends AggregateFunction2 with Logging { + + require( + children.length == udaf.inputSchema.length, + s"$udaf only accepts ${udaf.inputSchema.length} arguments, " + + s"but ${children.length} are provided.") + + override def nullable: Boolean = true + + override def dataType: DataType = udaf.returnDataType + + override def deterministic: Boolean = udaf.deterministic + + override val inputTypes: Seq[DataType] = udaf.inputSchema.map(_.dataType) + + override val bufferSchema: StructType = udaf.bufferSchema + + override val bufferAttributes: Seq[AttributeReference] = bufferSchema.toAttributes + + override lazy val cloneBufferAttributes = bufferAttributes.map(_.newInstance()) + + val childrenSchema: StructType = { + val inputFields = children.zipWithIndex.map { + case (child, index) => + StructField(s"input$index", child.dataType, child.nullable, Metadata.empty) + } + StructType(inputFields) + } + + lazy val inputProjection = { + val inputAttributes = childrenSchema.toAttributes + log.debug( + s"Creating MutableProj: $children, inputSchema: $inputAttributes.") + try { + GenerateMutableProjection.generate(children, inputAttributes)() + } catch { + case e: Exception => + log.error("Failed to generate mutable projection, fallback to interpreted", e) + new InterpretedMutableProjection(children, inputAttributes) + } + } + + val inputToScalaConverters: Any => Any = + CatalystTypeConverters.createToScalaConverter(childrenSchema) + + val bufferValuesToCatalystConverters: Array[Any => Any] = bufferSchema.fields.map { field => + CatalystTypeConverters.createToCatalystConverter(field.dataType) + } + + val bufferValuesToScalaConverters: Array[Any => Any] = bufferSchema.fields.map { field => + CatalystTypeConverters.createToScalaConverter(field.dataType) + } + + lazy val inputAggregateBuffer: InputAggregationBuffer = + new InputAggregationBuffer( + bufferValuesToCatalystConverters, + bufferValuesToScalaConverters, + bufferOffset, + null) + + lazy val mutableAggregateBuffer: MutableAggregationBuffer = + new MutableAggregationBuffer( + bufferValuesToCatalystConverters, + bufferValuesToScalaConverters, + bufferOffset, + null) + + + override def initialize(buffer: MutableRow): Unit = { + mutableAggregateBuffer.underlyingBuffer = buffer + + udaf.initialize(mutableAggregateBuffer) + } + + override def update(buffer: MutableRow, input: InternalRow): Unit = { + mutableAggregateBuffer.underlyingBuffer = buffer + + udaf.update( + mutableAggregateBuffer, + inputToScalaConverters(inputProjection(input)).asInstanceOf[Row]) + } + + override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { + mutableAggregateBuffer.underlyingBuffer = buffer1 + inputAggregateBuffer.underlyingInputBuffer = buffer2 + + udaf.merge(mutableAggregateBuffer, inputAggregateBuffer) + } + + override def eval(buffer: InternalRow = null): Any = { + inputAggregateBuffer.underlyingInputBuffer = buffer + + udaf.evaluate(inputAggregateBuffer) + } + + override def toString: String = { + s"""${udaf.getClass.getSimpleName}(${children.mkString(",")})""" + } + + override def nodeName: String = udaf.getClass.getSimpleName +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index ffa52f62588dc..bfeecbe8b2ab5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -42,6 +42,7 @@ import org.apache.spark.util.Utils * @groupname misc_funcs Misc functions * @groupname window_funcs Window functions * @groupname string_funcs String functions + * @groupname collection_funcs Collection functions * @groupname Ungrouped Support functions for DataFrames. * @since 1.3.0 */ @@ -68,6 +69,15 @@ object functions { */ def column(colName: String): Column = Column(colName) + /** + * Convert a number in string format from one base to another. + * + * @group math_funcs + * @since 1.5.0 + */ + def conv(num: Column, fromBase: Int, toBase: Int): Column = + Conv(num.expr, lit(fromBase).expr, lit(toBase).expr) + /** * Creates a [[Column]] of literal value. * @@ -585,7 +595,7 @@ object functions { } /** - * Returns the first column that is not null. + * Returns the first column that is not null and not NaN. * {{{ * df.select(coalesce(df("a"), df("b"))) * }}} @@ -602,12 +612,12 @@ object functions { def explode(e: Column): Column = Explode(e.expr) /** - * Converts a string exprsesion to lower case. + * Return true iff the column is NaN. * * @group normal_funcs - * @since 1.3.0 + * @since 1.5.0 */ - def lower(e: Column): Column = Lower(e.expr) + def isNaN(e: Column): Column = IsNaN(e.expr) /** * A column expression that generates monotonically increasing 64-bit integers. @@ -626,6 +636,15 @@ object functions { */ def monotonicallyIncreasingId(): Column = execution.expressions.MonotonicallyIncreasingID() + /** + * Return an alternative value `r` if `l` is NaN. + * This function is useful for mapping NaN values to null. + * + * @group normal_funcs + * @since 1.5.0 + */ + def nanvl(l: Column, r: Column): Column = NaNvl(l.expr, r.expr) + /** * Unary minus, i.e. negate the expression. * {{{ @@ -765,14 +784,6 @@ object functions { struct((colName +: colNames).map(col) : _*) } - /** - * Converts a string expression to upper case. - * - * @group normal_funcs - * @since 1.3.0 - */ - def upper(e: Column): Column = Upper(e.expr) - /** * Computes bitwise NOT. * @@ -1073,28 +1084,27 @@ object functions { def floor(columnName: String): Column = floor(Column(columnName)) /** - * Returns the greatest value of the list of values. + * Returns the greatest value of the list of values, skipping null values. + * This function takes at least 2 parameters. It will return null iff all parameters are null. * * @group normal_funcs * @since 1.5.0 */ @scala.annotation.varargs - def greatest(exprs: Column*): Column = if (exprs.length < 2) { - sys.error("GREATEST takes at least 2 parameters") - } else { - Greatest(exprs.map(_.expr): _*) + def greatest(exprs: Column*): Column = { + require(exprs.length > 1, "greatest requires at least 2 arguments.") + Greatest(exprs.map(_.expr)) } /** - * Returns the greatest value of the list of column names. + * Returns the greatest value of the list of column names, skipping null values. + * This function takes at least 2 parameters. It will return null iff all parameters are null. * * @group normal_funcs * @since 1.5.0 */ @scala.annotation.varargs - def greatest(columnName: String, columnNames: String*): Column = if (columnNames.isEmpty) { - sys.error("GREATEST takes at least 2 parameters") - } else { + def greatest(columnName: String, columnNames: String*): Column = { greatest((columnName +: columnNames).map(Column.apply): _*) } @@ -1106,14 +1116,6 @@ object functions { */ def hex(column: Column): Column = Hex(column.expr) - /** - * Computes hex value of the given input. 
- * - * @group math_funcs - * @since 1.5.0 - */ - def hex(colName: String): Column = hex(Column(colName)) - /** * Inverse of hex. Interprets each pair of characters as a hexadecimal number * and converts to the byte representation of number. @@ -1123,15 +1125,6 @@ object functions { */ def unhex(column: Column): Column = Unhex(column.expr) - /** - * Inverse of hex. Interprets each pair of characters as a hexadecimal number - * and converts to the byte representation of number. - * - * @group math_funcs - * @since 1.5.0 - */ - def unhex(colName: String): Column = unhex(Column(colName)) - /** * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. * @@ -1198,28 +1191,27 @@ object functions { def hypot(l: Double, rightName: String): Column = hypot(l, Column(rightName)) /** - * Returns the least value of the list of values. + * Returns the least value of the list of values, skipping null values. + * This function takes at least 2 parameters. It will return null iff all parameters are null. * * @group normal_funcs * @since 1.5.0 */ @scala.annotation.varargs - def least(exprs: Column*): Column = if (exprs.length < 2) { - sys.error("LEAST takes at least 2 parameters") - } else { - Least(exprs.map(_.expr): _*) + def least(exprs: Column*): Column = { + require(exprs.length > 1, "least requires at least 2 arguments.") + Least(exprs.map(_.expr)) } /** - * Returns the least value of the list of column names. + * Returns the least value of the list of column names, skipping null values. + * This function takes at least 2 parameters. It will return null iff all parameters are null. * * @group normal_funcs * @since 1.5.0 */ @scala.annotation.varargs - def least(columnName: String, columnNames: String*): Column = if (columnNames.isEmpty) { - sys.error("LEAST takes at least 2 parameters") - } else { + def least(columnName: String, columnNames: String*): Column = { least((columnName +: columnNames).map(Column.apply): _*) } @@ -1367,6 +1359,23 @@ object functions { */ def pow(l: Double, rightName: String): Column = pow(l, Column(rightName)) + /** + * Returns the positive value of dividend mod divisor. + * + * @group math_funcs + * @since 1.5.0 + */ + def pmod(dividend: Column, divisor: Column): Column = Pmod(dividend.expr, divisor.expr) + + /** + * Returns the positive value of dividend mod divisor. + * + * @group math_funcs + * @since 1.5.0 + */ + def pmod(dividendColName: String, divisorColName: String): Column = + pmod(Column(dividendColName), Column(divisorColName)) + /** * Returns the double value that is closest in value to the argument and * is equal to a mathematical integer. @@ -1385,6 +1394,38 @@ object functions { */ def rint(columnName: String): Column = rint(Column(columnName)) + /** + * Returns the value of the column `e` rounded to 0 decimal places. + * + * @group math_funcs + * @since 1.5.0 + */ + def round(e: Column): Column = round(e.expr, 0) + + /** + * Returns the value of the given column rounded to 0 decimal places. + * + * @group math_funcs + * @since 1.5.0 + */ + def round(columnName: String): Column = round(Column(columnName), 0) + + /** + * Returns the value of `e` rounded to `scale` decimal places. + * + * @group math_funcs + * @since 1.5.0 + */ + def round(e: Column, scale: Int): Column = Round(e.expr, Literal(scale)) + + /** + * Returns the value of the given column rounded to `scale` decimal places. 
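A brief usage sketch, not part of the patch, for some of the math helpers touched above; sqlContext is assumed to be an existing SQLContext.

import org.apache.spark.sql.functions.{conv, greatest, lit, pmod, round}

val df = sqlContext.range(0, 3)        // DataFrame with a single LongType column "id"
df.select(
  round(lit(2.5)),                     // 3.0: rounds half up to zero decimal places
  pmod(lit(-7), lit(3)),               // 2: always non-negative, unlike -7 % 3
  conv(lit("1111"), 2, 10),            // "15": binary to decimal, returned as a string
  greatest(df("id"), lit(1L))          // larger of the two values, skipping nulls
).show()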
+ * + * @group math_funcs + * @since 1.5.0 + */ + def round(columnName: String, scale: Int): Column = round(Column(columnName), scale) + /** * Shift the the given value numBits left. If the given value is a long value, this function * will return a long value else it will return an integer value. @@ -1560,7 +1601,8 @@ object functions { ////////////////////////////////////////////////////////////////////////////////////////////// /** - * Calculates the MD5 digest and returns the value as a 32 character hex string. + * Calculates the MD5 digest of a binary column and returns the value + * as a 32 character hex string. * * @group misc_funcs * @since 1.5.0 @@ -1568,15 +1610,8 @@ object functions { def md5(e: Column): Column = Md5(e.expr) /** - * Calculates the MD5 digest and returns the value as a 32 character hex string. - * - * @group misc_funcs - * @since 1.5.0 - */ - def md5(columnName: String): Column = md5(Column(columnName)) - - /** - * Calculates the SHA-1 digest and returns the value as a 40 character hex string. + * Calculates the SHA-1 digest of a binary column and returns the value + * as a 40 character hex string. * * @group misc_funcs * @since 1.5.0 @@ -1584,15 +1619,11 @@ object functions { def sha1(e: Column): Column = Sha1(e.expr) /** - * Calculates the SHA-1 digest and returns the value as a 40 character hex string. + * Calculates the SHA-2 family of hash functions of a binary column and + * returns the value as a hex string. * - * @group misc_funcs - * @since 1.5.0 - */ - def sha1(columnName: String): Column = sha1(Column(columnName)) - - /** - * Calculates the SHA-2 family of hash functions and returns the value as a hex string. + * @param e column to compute SHA-2 on. + * @param numBits one of 224, 256, 384, or 512. * * @group misc_funcs * @since 1.5.0 @@ -1604,153 +1635,125 @@ object functions { } /** - * Calculates the SHA-2 family of hash functions and returns the value as a hex string. - * - * @group misc_funcs - * @since 1.5.0 - */ - def sha2(columnName: String, numBits: Int): Column = sha2(Column(columnName), numBits) - - /** - * Calculates the cyclic redundancy check value and returns the value as a bigint. + * Calculates the cyclic redundancy check value (CRC32) of a binary column and + * returns the value as a bigint. * * @group misc_funcs * @since 1.5.0 */ def crc32(e: Column): Column = Crc32(e.expr) - /** - * Calculates the cyclic redundancy check value and returns the value as a bigint. - * - * @group misc_funcs - * @since 1.5.0 - */ - def crc32(columnName: String): Column = crc32(Column(columnName)) - ////////////////////////////////////////////////////////////////////////////////////////////// // String functions ////////////////////////////////////////////////////////////////////////////////////////////// /** - * Computes the length of a given string value. + * Concatenates input strings together into a single string. * * @group string_funcs * @since 1.5.0 */ - def strlen(e: Column): Column = StringLength(e.expr) + @scala.annotation.varargs + def concat(exprs: Column*): Column = Concat(exprs.map(_.expr)) /** - * Computes the length of a given string column. + * Concatenates input strings together into a single string, using the given separator. * * @group string_funcs * @since 1.5.0 */ - def strlen(columnName: String): Column = strlen(Column(columnName)) - - /** - * Computes the Levenshtein distance of the two given strings. 
- * @group string_funcs - * @since 1.5.0 - */ - def levenshtein(l: Column, r: Column): Column = Levenshtein(l.expr, r.expr) - - /** - * Computes the Levenshtein distance of the two given strings. - * @group string_funcs - * @since 1.5.0 - */ - def levenshtein(leftColumnName: String, rightColumnName: String): Column = - levenshtein(Column(leftColumnName), Column(rightColumnName)) + @scala.annotation.varargs + def concat_ws(sep: String, exprs: Column*): Column = { + ConcatWs(Literal.create(sep, StringType) +: exprs.map(_.expr)) + } /** - * Computes the numeric value of the first character of the specified string value. + * Computes the length of a given string / binary value. * * @group string_funcs * @since 1.5.0 */ - def ascii(e: Column): Column = Ascii(e.expr) + def length(e: Column): Column = Length(e.expr) /** - * Computes the numeric value of the first character of the specified string column. + * Converts a string expression to lower case. * * @group string_funcs - * @since 1.5.0 + * @since 1.3.0 */ - def ascii(columnName: String): Column = ascii(Column(columnName)) + def lower(e: Column): Column = Lower(e.expr) /** - * Trim the spaces from both ends for the specified string value. + * Converts a string expression to upper case. * * @group string_funcs - * @since 1.5.0 + * @since 1.3.0 */ - def trim(e: Column): Column = StringTrim(e.expr) + def upper(e: Column): Column = Upper(e.expr) /** - * Trim the spaces from both ends for the specified column. + * Formats the number X to a format like '#,###,###.##', rounded to d decimal places, + * and returns the result as a string. + * If d is 0, the result has no decimal point or fractional part. + * If d < 0, the result will be null. * * @group string_funcs * @since 1.5.0 */ - def trim(columnName: String): Column = trim(Column(columnName)) + def format_number(x: Column, d: Int): Column = FormatNumber(x.expr, lit(d).expr) /** - * Trim the spaces from left end for the specified string value. - * + * Computes the Levenshtein distance of the two given string columns. * @group string_funcs * @since 1.5.0 */ - def ltrim(e: Column): Column = StringTrimLeft(e.expr) + def levenshtein(l: Column, r: Column): Column = Levenshtein(l.expr, r.expr) /** - * Trim the spaces from left end for the specified column. + * Computes the numeric value of the first character of the specified string column. * * @group string_funcs * @since 1.5.0 */ - def ltrim(columnName: String): Column = ltrim(Column(columnName)) + def ascii(e: Column): Column = Ascii(e.expr) /** - * Trim the spaces from right end for the specified string value. + * Trim the spaces from both ends for the specified string column. * * @group string_funcs * @since 1.5.0 */ - def rtrim(e: Column): Column = StringTrimRight(e.expr) + def trim(e: Column): Column = StringTrim(e.expr) /** - * Trim the spaces from right end for the specified column. + * Trim the spaces from left end for the specified string value. * * @group string_funcs * @since 1.5.0 */ - def rtrim(columnName: String): Column = rtrim(Column(columnName)) + def ltrim(e: Column): Column = StringTrimLeft(e.expr) /** - * Format strings in printf-style. + * Trim the spaces from right end for the specified string value. * * @group string_funcs * @since 1.5.0 */ - @scala.annotation.varargs - def formatString(format: Column, arguments: Column*): Column = { - StringFormat((format +: arguments).map(_.expr): _*) - } + def rtrim(e: Column): Column = StringTrimRight(e.expr) /** - * Format strings in printf-style. 
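// Quick usage sketch for the reworked string functions above (concat, concat_ws, length,
// format_number, levenshtein, trim). It assumes an existing DataFrame `df` with string
// columns "first" and "last" and a numeric column "amount"; the column names are
// illustrative only.
import org.apache.spark.sql.functions._

df.select(
  concat(col("first"), col("last")),
  concat_ws(" ", col("first"), col("last")),
  length(trim(col("first"))),             // length replaces the old strlen
  format_number(col("amount"), 2),        // '#,###,###.##'-style formatting, 2 decimal places
  levenshtein(col("first"), col("last")))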
- * NOTE: `format` is the string value of the formatter, not column name. + * Formats the arguments in printf-style and returns the result as a string column. * * @group string_funcs * @since 1.5.0 */ @scala.annotation.varargs - def formatString(format: String, arguNames: String*): Column = { - StringFormat(lit(format).expr +: arguNames.map(Column(_).expr): _*) + def format_string(format: String, arguments: Column*): Column = { + FormatString((lit(format) +: arguments).map(_.expr): _*) } /** - * Locate the position of the first occurrence of substr value in the given string. + * Locate the position of the first occurrence of substr column in the given string. * Returns null if either of the arguments are null. * * NOTE: The position is not zero based, but 1 based index, returns 0 if substr @@ -1759,11 +1762,10 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def instr(substr: String, sub: String): Column = instr(Column(substr), Column(sub)) + def instr(str: Column, substring: String): Column = StringInstr(str.expr, lit(substring).expr) /** - * Locate the position of the first occurrence of substr column in the given string. - * Returns null if either of the arguments are null. + * Locate the position of the first occurrence of substr in a string column. * * NOTE: The position is not zero based, but 1 based index, returns 0 if substr * could not be found in str. @@ -1771,342 +1773,313 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def instr(substr: Column, sub: Column): Column = StringInstr(substr.expr, sub.expr) + def locate(substr: String, str: Column): Column = { + new StringLocate(lit(substr).expr, str.expr) + } /** - * Locate the position of the first occurrence of substr. + * Locate the position of the first occurrence of substr in a string column, after position pos. * - * NOTE: The position is not zero based, but 1 based index, returns 0 if substr + * NOTE: The position is not zero based, but 1 based index. returns 0 if substr * could not be found in str. * * @group string_funcs * @since 1.5.0 */ - def locate(substr: String, str: String): Column = { - locate(Column(substr), Column(str)) + def locate(substr: String, str: Column, pos: Int): Column = { + StringLocate(lit(substr).expr, str.expr, lit(pos).expr) } + /** - * Locate the position of the first occurrence of substr. - * - * NOTE: The position is not zero based, but 1 based index, returns 0 if substr - * could not be found in str. + * Extract a specific(idx) group identified by a java regex, from the specified string column. * * @group string_funcs * @since 1.5.0 */ - def locate(substr: Column, str: Column): Column = { - new StringLocate(substr.expr, str.expr) + def regexp_extract(e: Column, exp: String, groupIdx: Int): Column = { + RegExpExtract(e.expr, lit(exp).expr, lit(groupIdx).expr) } /** - * Locate the position of the first occurrence of substr in a given string after position pos. - * - * NOTE: The position is not zero based, but 1 based index, returns 0 if substr - * could not be found in str. + * Replace all substrings of the specified string value that match regexp with rep. * * @group string_funcs * @since 1.5.0 */ - def locate(substr: String, str: String, pos: String): Column = { - locate(Column(substr), Column(str), Column(pos)) + def regexp_replace(e: Column, pattern: String, replacement: String): Column = { + RegExpReplace(e.expr, lit(pattern).expr, lit(replacement).expr) } /** - * Locate the position of the first occurrence of substr in a given string after position pos. 
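// Usage sketch for the printf-style and position/regexp helpers above (format_string, instr,
// locate, regexp_extract, regexp_replace). Assumes an existing DataFrame `df` with a string
// column "name"; the column name is a placeholder.
import org.apache.spark.sql.functions._

df.select(
  format_string("name=%s", col("name")),
  instr(col("name"), "abc"),                  // 1-based position, 0 if not found
  locate("abc", col("name"), 2),              // search starting from position 2
  regexp_extract(col("name"), "(\\d+)", 1),   // first capture group of the Java regex
  regexp_replace(col("name"), "\\s+", "_"))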
- * - * NOTE: The position is not zero based, but 1 based index, returns 0 if substr - * could not be found in str. + * Computes the BASE64 encoding of a binary column and returns it as a string column. + * This is the reverse of unbase64. * * @group string_funcs * @since 1.5.0 */ - def locate(substr: Column, str: Column, pos: Column): Column = { - StringLocate(substr.expr, str.expr, pos.expr) - } + def base64(e: Column): Column = Base64(e.expr) /** - * Locate the position of the first occurrence of substr in a given string after position pos. - * - * NOTE: The position is not zero based, but 1 based index, returns 0 if substr - * could not be found in str. + * Decodes a BASE64 encoded string column and returns it as a binary column. + * This is the reverse of base64. * * @group string_funcs * @since 1.5.0 */ - def locate(substr: Column, str: Column, pos: Int): Column = { - StringLocate(substr.expr, str.expr, lit(pos).expr) - } + def unbase64(e: Column): Column = UnBase64(e.expr) /** - * Locate the position of the first occurrence of substr in a given string after position pos. - * - * NOTE: The position is not zero based, but 1 based index, returns 0 if substr - * could not be found in str. + * Left-padded with pad to a length of len. * * @group string_funcs * @since 1.5.0 */ - def locate(substr: String, str: String, pos: Int): Column = { - locate(Column(substr), Column(str), lit(pos)) + def lpad(str: Column, len: Int, pad: String): Column = { + StringLPad(str.expr, lit(len).expr, lit(pad).expr) } /** - * Computes the specified value from binary to a base64 string. + * Computes the first argument into a binary from a string using the provided character set + * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). + * If either argument is null, the result will also be null. * * @group string_funcs * @since 1.5.0 */ - def base64(e: Column): Column = Base64(e.expr) + def encode(value: Column, charset: String): Column = Encode(value.expr, lit(charset).expr) /** - * Computes the specified column from binary to a base64 string. + * Computes the first argument into a string from a binary using the provided character set + * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). + * If either argument is null, the result will also be null. * * @group string_funcs * @since 1.5.0 */ - def base64(columnName: String): Column = base64(Column(columnName)) + def decode(value: Column, charset: String): Column = Decode(value.expr, lit(charset).expr) /** - * Computes the specified value from a base64 string to binary. + * Right-padded with pad to a length of len. * * @group string_funcs * @since 1.5.0 */ - def unbase64(e: Column): Column = UnBase64(e.expr) + def rpad(str: Column, len: Int, pad: String): Column = { + StringRPad(str.expr, lit(len).expr, lit(pad).expr) + } /** - * Computes the specified column from a base64 string to binary. + * Repeats a string column n times, and returns it as a new string column. * * @group string_funcs * @since 1.5.0 */ - def unbase64(columnName: String): Column = unbase64(Column(columnName)) + def repeat(str: Column, n: Int): Column = { + StringRepeat(str.expr, lit(n).expr) + } /** - * Left-padded with pad to a length of len. + * Splits str around pattern (pattern is a regular expression). + * NOTE: pattern is a string represent the regular expression. 
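// Usage sketch for the padding, repetition, charset and BASE64 helpers above (lpad, rpad,
// repeat, encode, decode, base64, unbase64). Assumes an existing DataFrame `df` with a
// string column "s"; the column name is a placeholder.
import org.apache.spark.sql.functions._

df.select(
  lpad(col("s"), 10, "*"),
  rpad(col("s"), 10, "*"),
  repeat(col("s"), 3),
  base64(encode(col("s"), "UTF-8")),          // string -> binary -> BASE64 string
  decode(encode(col("s"), "UTF-8"), "UTF-8")) // round-trips through a binary column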
* * @group string_funcs * @since 1.5.0 */ - def lpad(str: String, len: String, pad: String): Column = { - lpad(Column(str), Column(len), Column(pad)) + def split(str: Column, pattern: String): Column = { + StringSplit(str.expr, lit(pattern).expr) } /** - * Left-padded with pad to a length of len. + * Reversed the string for the specified value. * * @group string_funcs * @since 1.5.0 */ - def lpad(str: Column, len: Column, pad: Column): Column = { - StringLPad(str.expr, len.expr, pad.expr) + def reverse(str: Column): Column = { + StringReverse(str.expr) } + ////////////////////////////////////////////////////////////////////////////////////////////// + // DateTime functions + ////////////////////////////////////////////////////////////////////////////////////////////// + /** - * Left-padded with pad to a length of len. + * Converts a date/timestamp/string to a value of string in the format specified by the date + * format given by the second argument. * - * @group string_funcs + * A pattern could be for instance `dd.MM.yyyy` and could return a string like '18.03.1993'. All + * pattern letters of [[java.text.SimpleDateFormat]] can be used. + * + * NOTE: Use when ever possible specialized functions like [[year]]. These benefit from a + * specialized implementation. + * + * @group datetime_funcs * @since 1.5.0 */ - def lpad(str: Column, len: Int, pad: Column): Column = { - StringLPad(str.expr, lit(len).expr, pad.expr) - } + def date_format(dateExpr: Column, format: String): Column = + DateFormatClass(dateExpr.expr, Literal(format)) /** - * Left-padded with pad to a length of len. + * Converts a date/timestamp/string to a value of string in the format specified by the date + * format given by the second argument. * - * @group string_funcs + * A pattern could be for instance `dd.MM.yyyy` and could return a string like '18.03.1993'. All + * pattern letters of [[java.text.SimpleDateFormat]] can be used. + * + * NOTE: Use when ever possible specialized functions like [[year]]. These benefit from a + * specialized implementation. + * + * @group datetime_funcs * @since 1.5.0 */ - def lpad(str: String, len: Int, pad: String): Column = { - lpad(Column(str), len, Column(pad)) - } + def date_format(dateColumnName: String, format: String): Column = + date_format(Column(dateColumnName), format) /** - * Computes the first argument into a binary from a string using the provided character set - * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). - * If either argument is null, the result will also be null. - * - * @group string_funcs + * Extracts the year as an integer from a given date/timestamp/string. + * @group datetime_funcs * @since 1.5.0 */ - def encode(value: Column, charset: String): Column = Encode(value.expr, lit(charset).expr) + def year(e: Column): Column = Year(e.expr) /** - * Computes the first argument into a binary from a string using the provided character set - * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). - * If either argument is null, the result will also be null. - * NOTE: charset represents the string value of the character set, not the column name. - * - * @group string_funcs + * Extracts the year as an integer from a given date/timestamp/string. 
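// Usage sketch for split, reverse and the new date_format function above. Assumes an
// existing DataFrame `df` with a string column "csv" and a timestamp column "ts"; both
// column names are placeholders.
import org.apache.spark.sql.functions._

df.select(
  split(col("csv"), ","),                  // the pattern argument is a regular expression
  reverse(col("csv")),
  date_format(col("ts"), "dd.MM.yyyy"))    // java.text.SimpleDateFormat pattern letters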
+ * @group datetime_funcs * @since 1.5.0 */ - def encode(columnName: String, charset: String): Column = - encode(Column(columnName), charset) + def year(columnName: String): Column = year(Column(columnName)) /** - * Computes the first argument into a string from a binary using the provided character set - * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). - * If either argument is null, the result will also be null. - * - * @group string_funcs + * Extracts the quarter as an integer from a given date/timestamp/string. + * @group datetime_funcs * @since 1.5.0 */ - def decode(value: Column, charset: String): Column = Decode(value.expr, lit(charset).expr) + def quarter(e: Column): Column = Quarter(e.expr) /** - * Computes the first argument into a string from a binary using the provided character set - * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). - * If either argument is null, the result will also be null. - * NOTE: charset represents the string value of the character set, not the column name. - * - * @group string_funcs + * Extracts the quarter as an integer from a given date/timestamp/string. + * @group datetime_funcs * @since 1.5.0 */ - def decode(columnName: String, charset: String): Column = - decode(Column(columnName), charset) + def quarter(columnName: String): Column = quarter(Column(columnName)) /** - * Right-padded with pad to a length of len. - * - * @group string_funcs + * Extracts the month as an integer from a given date/timestamp/string. + * @group datetime_funcs * @since 1.5.0 */ - def rpad(str: String, len: String, pad: String): Column = { - rpad(Column(str), Column(len), Column(pad)) - } + def month(e: Column): Column = Month(e.expr) /** - * Right-padded with pad to a length of len. - * - * @group string_funcs + * Extracts the month as an integer from a given date/timestamp/string. + * @group datetime_funcs * @since 1.5.0 */ - def rpad(str: Column, len: Column, pad: Column): Column = { - StringRPad(str.expr, len.expr, pad.expr) - } + def month(columnName: String): Column = month(Column(columnName)) /** - * Right-padded with pad to a length of len. - * - * @group string_funcs + * Extracts the day of the month as an integer from a given date/timestamp/string. + * @group datetime_funcs * @since 1.5.0 */ - def rpad(str: String, len: Int, pad: String): Column = { - rpad(Column(str), len, Column(pad)) - } + def dayofmonth(e: Column): Column = DayOfMonth(e.expr) /** - * Right-padded with pad to a length of len. - * - * @group string_funcs + * Extracts the day of the month as an integer from a given date/timestamp/string. + * @group datetime_funcs * @since 1.5.0 */ - def rpad(str: Column, len: Int, pad: Column): Column = { - StringRPad(str.expr, lit(len).expr, pad.expr) - } + def dayofmonth(columnName: String): Column = dayofmonth(Column(columnName)) /** - * Repeat the string value of the specified column n times. - * - * @group string_funcs + * Extracts the day of the year as an integer from a given date/timestamp/string. + * @group datetime_funcs * @since 1.5.0 */ - def repeat(strColumn: String, timesColumn: String): Column = { - repeat(Column(strColumn), Column(timesColumn)) - } + def dayofyear(e: Column): Column = DayOfYear(e.expr) /** - * Repeat the string expression value n times. - * - * @group string_funcs + * Extracts the day of the year as an integer from a given date/timestamp/string. 
+ * @group datetime_funcs * @since 1.5.0 */ - def repeat(str: Column, times: Column): Column = { - StringRepeat(str.expr, times.expr) - } + def dayofyear(columnName: String): Column = dayofyear(Column(columnName)) /** - * Repeat the string value of the specified column n times. - * - * @group string_funcs + * Extracts the hours as an integer from a given date/timestamp/string. + * @group datetime_funcs * @since 1.5.0 */ - def repeat(strColumn: String, times: Int): Column = { - repeat(Column(strColumn), times) - } + def hour(e: Column): Column = Hour(e.expr) /** - * Repeat the string expression value n times. - * - * @group string_funcs + * Extracts the hours as an integer from a given date/timestamp/string. + * @group datetime_funcs * @since 1.5.0 */ - def repeat(str: Column, times: Int): Column = { - StringRepeat(str.expr, lit(times).expr) - } + def hour(columnName: String): Column = hour(Column(columnName)) /** - * Splits str around pattern (pattern is a regular expression). - * - * @group string_funcs + * Extracts the minutes as an integer from a given date/timestamp/string. + * @group datetime_funcs * @since 1.5.0 */ - def split(strColumnName: String, pattern: String): Column = { - split(Column(strColumnName), pattern) - } + def minute(e: Column): Column = Minute(e.expr) /** - * Splits str around pattern (pattern is a regular expression). - * NOTE: pattern is a string represent the regular expression. - * - * @group string_funcs + * Extracts the minutes as an integer from a given date/timestamp/string. + * @group datetime_funcs * @since 1.5.0 */ - def split(str: Column, pattern: String): Column = { - StringSplit(str.expr, lit(pattern).expr) - } + def minute(columnName: String): Column = minute(Column(columnName)) /** - * Reversed the string for the specified column. - * - * @group string_funcs + * Extracts the seconds as an integer from a given date/timestamp/string. + * @group datetime_funcs * @since 1.5.0 */ - def reverse(str: String): Column = { - reverse(Column(str)) - } + def second(e: Column): Column = Second(e.expr) /** - * Reversed the string for the specified value. - * - * @group string_funcs + * Extracts the seconds as an integer from a given date/timestamp/string. + * @group datetime_funcs * @since 1.5.0 */ - def reverse(str: Column): Column = { - StringReverse(str.expr) - } + def second(columnName: String): Column = second(Column(columnName)) /** - * Make a n spaces of string. - * - * @group string_funcs + * Extracts the week number as an integer from a given date/timestamp/string. + * @group datetime_funcs * @since 1.5.0 */ - def space(n: String): Column = { - space(Column(n)) - } + def weekofyear(e: Column): Column = WeekOfYear(e.expr) /** - * Make a n spaces of string. - * - * @group string_funcs + * Extracts the week number as an integer from a given date/timestamp/string. 
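// Usage sketch for the field-extraction functions above (year through weekofyear). Assumes
// an existing DataFrame `df` with a date/timestamp/string column "ts"; the column name is a
// placeholder.
import org.apache.spark.sql.functions._

df.select(
  year(col("ts")), quarter(col("ts")), month(col("ts")),
  dayofmonth(col("ts")), dayofyear(col("ts")),
  hour(col("ts")), minute(col("ts")), second(col("ts")),
  weekofyear(col("ts")))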
+ * @group datetime_funcs * @since 1.5.0 */ - def space(n: Column): Column = { - StringSpace(n.expr) - } + def weekofyear(columnName: String): Column = weekofyear(Column(columnName)) + + ////////////////////////////////////////////////////////////////////////////////////////////// + // Collection functions + ////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Returns length of array or map + * @group collection_funcs + * @since 1.5.0 + */ + def size(columnName: String): Column = size(Column(columnName)) + + /** + * Returns length of array or map + * @group collection_funcs + * @since 1.5.0 + */ + def size(column: Column): Column = Size(column.expr) + ////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////// @@ -2447,7 +2420,7 @@ object functions { * @since 1.5.0 */ def callUDF(udfName: String, cols: Column*): Column = { - UnresolvedFunction(udfName, cols.map(_.expr)) + UnresolvedFunction(udfName, cols.map(_.expr), isDistinct = false) } /** @@ -2476,7 +2449,7 @@ object functions { exprs(i) = cols(i).expr i += 1 } - UnresolvedFunction(udfName, exprs) + UnresolvedFunction(udfName, exprs, isDistinct = false) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala index 2361d3bf52d2b..922794ac9aac5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.json import java.io.IOException -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.rdd.RDD import org.apache.spark.sql.AnalysisException @@ -87,20 +87,7 @@ private[sql] class DefaultSource case SaveMode.Append => sys.error(s"Append mode is not supported by ${this.getClass.getCanonicalName}") case SaveMode.Overwrite => { - var success: Boolean = false - try { - success = fs.delete(filesystemPath, true) - } catch { - case e: IOException => - throw new IOException( - s"Unable to clear output directory ${filesystemPath.toString} prior" - + s" to writing to JSON table:\n${e.toString}") - } - if (!success) { - throw new IOException( - s"Unable to clear output directory ${filesystemPath.toString} prior" - + s" to writing to JSON table.") - } + JSONRelation.delete(filesystemPath, fs) true } case SaveMode.ErrorIfExists => @@ -157,51 +144,27 @@ private[sql] class JSONRelation( } } - private val useJacksonStreamingAPI: Boolean = sqlContext.conf.useJacksonStreamingAPI - override val needConversion: Boolean = false override lazy val schema = userSpecifiedSchema.getOrElse { - if (useJacksonStreamingAPI) { - InferSchema( - baseRDD(), - samplingRatio, - sqlContext.conf.columnNameOfCorruptRecord) - } else { - JsonRDD.nullTypeToStringType( - JsonRDD.inferSchema( - baseRDD(), - samplingRatio, - sqlContext.conf.columnNameOfCorruptRecord)) - } + InferSchema( + baseRDD(), + samplingRatio, + sqlContext.conf.columnNameOfCorruptRecord) } override def buildScan(): RDD[Row] = { - if (useJacksonStreamingAPI) { - JacksonParser( - baseRDD(), - schema, - sqlContext.conf.columnNameOfCorruptRecord).map(_.asInstanceOf[Row]) - } else { - JsonRDD.jsonStringToRow( - baseRDD(), - schema, - sqlContext.conf.columnNameOfCorruptRecord).map(_.asInstanceOf[Row]) - } + JacksonParser( + baseRDD(), + schema, + 
sqlContext.conf.columnNameOfCorruptRecord).map(_.asInstanceOf[Row]) } override def buildScan(requiredColumns: Seq[Attribute], filters: Seq[Expression]): RDD[Row] = { - if (useJacksonStreamingAPI) { - JacksonParser( - baseRDD(), - StructType.fromAttributes(requiredColumns), - sqlContext.conf.columnNameOfCorruptRecord).map(_.asInstanceOf[Row]) - } else { - JsonRDD.jsonStringToRow( - baseRDD(), - StructType.fromAttributes(requiredColumns), - sqlContext.conf.columnNameOfCorruptRecord).map(_.asInstanceOf[Row]) - } + JacksonParser( + baseRDD(), + StructType.fromAttributes(requiredColumns), + sqlContext.conf.columnNameOfCorruptRecord).map(_.asInstanceOf[Row]) } override def insert(data: DataFrame, overwrite: Boolean): Unit = { @@ -219,20 +182,7 @@ private[sql] class JSONRelation( if (overwrite) { if (fs.exists(filesystemPath)) { - var success: Boolean = false - try { - success = fs.delete(filesystemPath, true) - } catch { - case e: IOException => - throw new IOException( - s"Unable to clear output directory ${filesystemPath.toString} prior" - + s" to writing to JSON table:\n${e.toString}") - } - if (!success) { - throw new IOException( - s"Unable to clear output directory ${filesystemPath.toString} prior" - + s" to writing to JSON table.") - } + JSONRelation.delete(filesystemPath, fs) } // Write the data. data.toJSON.saveAsTextFile(filesystemPath.toString) @@ -252,3 +202,21 @@ private[sql] class JSONRelation( case _ => false } } + +private object JSONRelation { + + /** Delete the specified directory to overwrite it with new JSON data. */ + def delete(dir: Path, fs: FileSystem): Unit = { + var success: Boolean = false + val failMessage = s"Unable to clear output directory $dir prior to writing to JSON table" + try { + success = fs.delete(dir, true /* recursive */) + } catch { + case e: IOException => + throw new IOException(s"$failMessage\n${e.toString}") + } + if (!success) { + throw new IOException(failMessage) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala deleted file mode 100644 index b392a51bf7dce..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ /dev/null @@ -1,449 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
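// Usage sketch for the new collection function `size` and the updated `callUDF` overloads
// introduced a few hunks above. Assumes an existing DataFrame `df` with an array column
// "tags" and a string column "name", and a UDF already registered under the name "myUpper";
// all of these names are placeholders.
import org.apache.spark.sql.functions._

df.select(
  size(col("tags")),               // length of an array or map column
  callUDF("myUpper", col("name"))) // now resolved explicitly as a non-distinct function call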
- */ - -package org.apache.spark.sql.json - -import scala.collection.Map -import scala.collection.convert.Wrappers.{JListWrapper, JMapWrapper} - -import com.fasterxml.jackson.core.JsonProcessingException -import com.fasterxml.jackson.databind.ObjectMapper - -import org.apache.spark.Logging -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String - - -private[sql] object JsonRDD extends Logging { - - private[sql] def jsonStringToRow( - json: RDD[String], - schema: StructType, - columnNameOfCorruptRecords: String): RDD[InternalRow] = { - parseJson(json, columnNameOfCorruptRecords).map(parsed => asRow(parsed, schema)) - } - - private[sql] def inferSchema( - json: RDD[String], - samplingRatio: Double = 1.0, - columnNameOfCorruptRecords: String): StructType = { - require(samplingRatio > 0, s"samplingRatio ($samplingRatio) should be greater than 0") - val schemaData = if (samplingRatio > 0.99) json else json.sample(false, samplingRatio, 1) - val allKeys = - if (schemaData.isEmpty()) { - Set.empty[(String, DataType)] - } else { - parseJson(schemaData, columnNameOfCorruptRecords).map(allKeysWithValueTypes).reduce(_ ++ _) - } - createSchema(allKeys) - } - - private def createSchema(allKeys: Set[(String, DataType)]): StructType = { - // Resolve type conflicts - val resolved = allKeys.groupBy { - case (key, dataType) => key - }.map { - // Now, keys and types are organized in the format of - // key -> Set(type1, type2, ...). - case (key, typeSet) => { - val fieldName = key.substring(1, key.length - 1).split("`.`").toSeq - val dataType = typeSet.map { - case (_, dataType) => dataType - }.reduce((type1: DataType, type2: DataType) => compatibleType(type1, type2)) - - (fieldName, dataType) - } - } - - def makeStruct(values: Seq[Seq[String]], prefix: Seq[String]): StructType = { - val (topLevel, structLike) = values.partition(_.size == 1) - - val topLevelFields = topLevel.filter { - name => resolved.get(prefix ++ name).get match { - case ArrayType(elementType, _) => { - def hasInnerStruct(t: DataType): Boolean = t match { - case s: StructType => true - case ArrayType(t1, _) => hasInnerStruct(t1) - case o => false - } - - // Check if this array has inner struct. - !hasInnerStruct(elementType) - } - case struct: StructType => false - case _ => true - } - }.map { - a => StructField(a.head, resolved.get(prefix ++ a).get, nullable = true) - } - val topLevelFieldNameSet = topLevelFields.map(_.name) - - val structFields: Seq[StructField] = structLike.groupBy(_(0)).filter { - case (name, _) => !topLevelFieldNameSet.contains(name) - }.map { - case (name, fields) => { - val nestedFields = fields.map(_.tail) - val structType = makeStruct(nestedFields, prefix :+ name) - val dataType = resolved.get(prefix :+ name).get - dataType match { - case array: ArrayType => - // The pattern of this array is ArrayType(...(ArrayType(StructType))). - // Since the inner struct of array is a placeholder (StructType(Nil)), - // we need to replace this placeholder with the actual StructType (structType). 
- def getActualArrayType( - innerStruct: StructType, - currentArray: ArrayType): ArrayType = currentArray match { - case ArrayType(s: StructType, containsNull) => - ArrayType(innerStruct, containsNull) - case ArrayType(a: ArrayType, containsNull) => - ArrayType(getActualArrayType(innerStruct, a), containsNull) - } - Some(StructField(name, getActualArrayType(structType, array), nullable = true)) - case struct: StructType => Some(StructField(name, structType, nullable = true)) - // dataType is StringType means that we have resolved type conflicts involving - // primitive types and complex types. So, the type of name has been relaxed to - // StringType. Also, this field should have already been put in topLevelFields. - case StringType => None - } - } - }.flatMap(field => field).toSeq - - StructType((topLevelFields ++ structFields).sortBy(_.name)) - } - - makeStruct(resolved.keySet.toSeq, Nil) - } - - private[sql] def nullTypeToStringType(struct: StructType): StructType = { - val fields = struct.fields.map { - case StructField(fieldName, dataType, nullable, _) => { - val newType = dataType match { - case NullType => StringType - case ArrayType(NullType, containsNull) => ArrayType(StringType, containsNull) - case ArrayType(struct: StructType, containsNull) => - ArrayType(nullTypeToStringType(struct), containsNull) - case struct: StructType => nullTypeToStringType(struct) - case other: DataType => other - } - StructField(fieldName, newType, nullable) - } - } - - StructType(fields) - } - - /** - * Returns the most general data type for two given data types. - */ - private[json] def compatibleType(t1: DataType, t2: DataType): DataType = { - HiveTypeCoercion.findTightestCommonTypeOfTwo(t1, t2) match { - case Some(commonType) => commonType - case None => - // t1 or t2 is a StructType, ArrayType, or an unexpected type. - (t1, t2) match { - case (other: DataType, NullType) => other - case (NullType, other: DataType) => other - case (StructType(fields1), StructType(fields2)) => { - val newFields = (fields1 ++ fields2).groupBy(field => field.name).map { - case (name, fieldTypes) => { - val dataType = fieldTypes.map(field => field.dataType).reduce( - (type1: DataType, type2: DataType) => compatibleType(type1, type2)) - StructField(name, dataType, true) - } - } - StructType(newFields.toSeq.sortBy(_.name)) - } - case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) => - ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2) - // TODO: We should use JsonObjectStringType to mark that values of field will be - // strings and every string is a Json object. - case (_, _) => StringType - } - } - } - - private def typeOfPrimitiveValue: PartialFunction[Any, DataType] = { - // For Integer values, use LongType by default. - val useLongType: PartialFunction[Any, DataType] = { - case value: IntegerType.InternalType => LongType - } - - useLongType orElse ScalaReflection.typeOfObject orElse { - // Since we do not have a data type backed by BigInteger, - // when we see a Java BigInteger, we use DecimalType. - case value: java.math.BigInteger => DecimalType.Unlimited - // DecimalType's JVMType is scala BigDecimal. - case value: java.math.BigDecimal => DecimalType.Unlimited - // Unexpected data type. - case _ => StringType - } - } - - /** - * Returns the element type of an JSON array. We go through all elements of this array - * to detect any possible type conflict. We use [[compatibleType]] to resolve - * type conflicts. 
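// The removed JsonRDD.compatibleType above resolved conflicting JSON field types by
// widening; the Jackson-based InferSchema that replaces it applies the same kind of
// widening. A deliberately simplified, standalone illustration of that rule (not the
// removed implementation, which delegates to HiveTypeCoercion.findTightestCommonTypeOfTwo):
import org.apache.spark.sql.types._

def widen(t1: DataType, t2: DataType): DataType = (t1, t2) match {
  case (t, NullType) => t
  case (NullType, t) => t
  case (a, b) if a == b => a
  case (IntegerType, LongType) | (LongType, IntegerType) => LongType
  case (LongType, DoubleType) | (DoubleType, LongType) => DoubleType
  case _ => StringType   // irreconcilable primitive/complex types are relaxed to string
}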
- */ - private def typeOfArray(l: Seq[Any]): ArrayType = { - val elements = l.flatMap(v => Option(v)) - if (elements.isEmpty) { - // If this JSON array is empty, we use NullType as a placeholder. - // If this array is not empty in other JSON objects, we can resolve - // the type after we have passed through all JSON objects. - ArrayType(NullType, containsNull = true) - } else { - val elementType = elements.map { - e => e match { - case map: Map[_, _] => StructType(Nil) - // We have an array of arrays. If those element arrays do not have the same - // element types, we will return ArrayType[StringType]. - case seq: Seq[_] => typeOfArray(seq) - case value => typeOfPrimitiveValue(value) - } - }.reduce((type1: DataType, type2: DataType) => compatibleType(type1, type2)) - - ArrayType(elementType, containsNull = true) - } - } - - /** - * Figures out all key names and data types of values from a parsed JSON object - * (in the format of Map[Stirng, Any]). When the value of a key is an JSON object, we - * only use a placeholder (StructType(Nil)) to mark that it should be a struct - * instead of getting all fields of this struct because a field does not appear - * in this JSON object can appear in other JSON objects. - */ - private def allKeysWithValueTypes(m: Map[String, Any]): Set[(String, DataType)] = { - val keyValuePairs = m.map { - // Quote the key with backticks to handle cases which have dots - // in the field name. - case (key, value) => (s"`$key`", value) - }.toSet - keyValuePairs.flatMap { - case (key: String, struct: Map[_, _]) => { - // The value associated with the key is an JSON object. - allKeysWithValueTypes(struct.asInstanceOf[Map[String, Any]]).map { - case (k, dataType) => (s"$key.$k", dataType) - } ++ Set((key, StructType(Nil))) - } - case (key: String, array: Seq[_]) => { - // The value associated with the key is an array. - // Handle inner structs of an array. - def buildKeyPathForInnerStructs(v: Any, t: DataType): Seq[(String, DataType)] = t match { - case ArrayType(e: StructType, _) => { - // The elements of this arrays are structs. - v.asInstanceOf[Seq[Map[String, Any]]].flatMap(Option(_)).flatMap { - element => allKeysWithValueTypes(element) - }.map { - case (k, t) => (s"$key.$k", t) - } - } - case ArrayType(t1, _) => - v.asInstanceOf[Seq[Any]].flatMap(Option(_)).flatMap { - element => buildKeyPathForInnerStructs(element, t1) - } - case other => Nil - } - val elementType = typeOfArray(array) - buildKeyPathForInnerStructs(array, elementType) :+ (key, elementType) - } - // we couldn't tell what the type is if the value is null or empty string - case (key: String, value) if value == "" || value == null => (key, NullType) :: Nil - case (key: String, value) => (key, typeOfPrimitiveValue(value)) :: Nil - } - } - - /** - * Converts a Java Map/List to a Scala Map/Seq. - * We do not use Jackson's scala module at here because - * DefaultScalaModule in jackson-module-scala will make - * the parsing very slow. - */ - private def scalafy(obj: Any): Any = obj match { - case map: java.util.Map[_, _] => - // .map(identity) is used as a workaround of non-serializable Map - // generated by .mapValues. 
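// Background for the `.map(identity)` workaround mentioned in the comment above: in
// Scala 2.10/2.11, `Map#mapValues` returns a lazy, non-serializable view (SI-7005), which
// breaks when Spark ships it to executors; forcing it with `.map(identity)` materializes a
// plain, serializable Map. A minimal standalone illustration, independent of the removed code:
val m = Map("a" -> 1, "b" -> 2)
val view = m.mapValues(_ + 1)                         // lazy MapLike view, not serializable
val materialized = m.mapValues(_ + 1).map(identity)   // strict Map, safe to serialize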
- // This issue is documented at https://issues.scala-lang.org/browse/SI-7005 - JMapWrapper(map).mapValues(scalafy).map(identity) - case list: java.util.List[_] => - JListWrapper(list).map(scalafy) - case atom => atom - } - - private def parseJson( - json: RDD[String], - columnNameOfCorruptRecords: String): RDD[Map[String, Any]] = { - // According to [Jackson-72: https://jira.codehaus.org/browse/JACKSON-72], - // ObjectMapper will not return BigDecimal when - // "DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS" is disabled - // (see NumberDeserializer.deserialize for the logic). - // But, we do not want to enable this feature because it will use BigDecimal - // for every float number, which will be slow. - // So, right now, we will have Infinity for those BigDecimal number. - // TODO: Support BigDecimal. - json.mapPartitions(iter => { - // When there is a key appearing multiple times (a duplicate key), - // the ObjectMapper will take the last value associated with this duplicate key. - // For example: for {"key": 1, "key":2}, we will get "key"->2. - val mapper = new ObjectMapper() - iter.flatMap { record => - try { - val parsed = mapper.readValue(record, classOf[Object]) match { - case map: java.util.Map[_, _] => scalafy(map).asInstanceOf[Map[String, Any]] :: Nil - case list: java.util.List[_] => scalafy(list).asInstanceOf[Seq[Map[String, Any]]] - case _ => - sys.error( - s"Failed to parse record $record. Please make sure that each line of the file " + - "(or each string in the RDD) is a valid JSON object or an array of JSON objects.") - } - - parsed - } catch { - case e: JsonProcessingException => - Map(columnNameOfCorruptRecords -> UTF8String.fromString(record)) :: Nil - } - } - }) - } - - private def toLong(value: Any): Long = { - value match { - case value: java.lang.Integer => value.asInstanceOf[Int].toLong - case value: java.lang.Long => value.asInstanceOf[Long] - } - } - - private def toDouble(value: Any): Double = { - value match { - case value: java.lang.Integer => value.asInstanceOf[Int].toDouble - case value: java.lang.Long => value.asInstanceOf[Long].toDouble - case value: java.lang.Double => value.asInstanceOf[Double] - } - } - - private def toDecimal(value: Any): Decimal = { - value match { - case value: java.lang.Integer => Decimal(value) - case value: java.lang.Long => Decimal(value) - case value: java.math.BigInteger => Decimal(new java.math.BigDecimal(value)) - case value: java.lang.Double => Decimal(value) - case value: java.math.BigDecimal => Decimal(value) - } - } - - private def toJsonArrayString(seq: Seq[Any]): String = { - val builder = new StringBuilder - builder.append("[") - var count = 0 - seq.foreach { - element => - if (count > 0) builder.append(",") - count += 1 - builder.append(toString(element)) - } - builder.append("]") - - builder.toString() - } - - private def toJsonObjectString(map: Map[String, Any]): String = { - val builder = new StringBuilder - builder.append("{") - var count = 0 - map.foreach { - case (key, value) => - if (count > 0) builder.append(",") - count += 1 - val stringValue = if (value.isInstanceOf[String]) s"""\"$value\"""" else toString(value) - builder.append(s"""\"${key}\":${stringValue}""") - } - builder.append("}") - - builder.toString() - } - - private def toString(value: Any): String = { - value match { - case value: Map[_, _] => toJsonObjectString(value.asInstanceOf[Map[String, Any]]) - case value: Seq[_] => toJsonArrayString(value) - case value => Option(value).map(_.toString).orNull - } - } - - private def toDate(value: Any): 
Int = { - value match { - // only support string as date - case value: java.lang.String => - DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(value).getTime) - case value: java.sql.Date => DateTimeUtils.fromJavaDate(value) - } - } - - private def toTimestamp(value: Any): Long = { - value match { - case value: java.lang.Integer => value.asInstanceOf[Int].toLong * 1000L - case value: java.lang.Long => value * 1000L - case value: java.lang.String => DateTimeUtils.stringToTime(value).getTime * 1000L - } - } - - private[json] def enforceCorrectType(value: Any, desiredType: DataType): Any = { - if (value == null) { - null - } else { - desiredType match { - case StringType => UTF8String.fromString(toString(value)) - case _ if value == null || value == "" => null // guard the non string type - case IntegerType => value.asInstanceOf[IntegerType.InternalType] - case LongType => toLong(value) - case DoubleType => toDouble(value) - case DecimalType() => toDecimal(value) - case BooleanType => value.asInstanceOf[BooleanType.InternalType] - case NullType => null - case ArrayType(elementType, _) => - value.asInstanceOf[Seq[Any]].map(enforceCorrectType(_, elementType)) - case MapType(StringType, valueType, _) => - val map = value.asInstanceOf[Map[String, Any]] - map.map { - case (k, v) => - (UTF8String.fromString(k), enforceCorrectType(v, valueType)) - }.map(identity) - case struct: StructType => asRow(value.asInstanceOf[Map[String, Any]], struct) - case DateType => toDate(value) - case TimestampType => toTimestamp(value) - } - } - } - - private def asRow(json: Map[String, Any], schema: StructType): InternalRow = { - // TODO: Reuse the row instead of creating a new one for every record. - val row = new GenericMutableRow(schema.fields.length) - schema.fields.zipWithIndex.foreach { - case (StructField(name, dataType, _, _), i) => - row.update(i, json.get(name).flatMap(v => Option(v)).map( - enforceCorrectType(_, dataType)).orNull) - } - - row - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index 704cf56f38265..086559e9f7658 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.util.Utils /** * Relation that consists of data stored in a Parquet columnar format. @@ -53,8 +54,6 @@ private[sql] case class ParquetRelation( partitioningAttributes: Seq[Attribute] = Nil) extends LeafNode with MultiInstanceRelation { - self: Product => - /** Schema derived from ParquetFile */ def parquetSchema: MessageType = ParquetTypesConverter @@ -108,7 +107,7 @@ private[sql] object ParquetRelation { // // Therefore we need to force the class to be loaded. // This should really be resolved by Parquet. - Class.forName(classOf[ParquetLog].getName) + Utils.classForName(classOf[ParquetLog].getName) // Note: Logger.getLogger("parquet") has a default logger // that appends to Console which needs to be cleared. @@ -119,12 +118,12 @@ private[sql] object ParquetRelation { // Disables a WARN log message in ParquetOutputCommitter. 
We first ensure that // ParquetOutputCommitter is loaded and the static LOG field gets initialized. // See https://issues.apache.org/jira/browse/SPARK-5968 for details - Class.forName(classOf[ParquetOutputCommitter].getName) + Utils.classForName(classOf[ParquetOutputCommitter].getName) JLogger.getLogger(classOf[ParquetOutputCommitter].getName).setLevel(Level.OFF) // Similar as above, disables a unnecessary WARN log message in ParquetRecordReader. // See https://issues.apache.org/jira/browse/PARQUET-220 for details - Class.forName(classOf[ParquetRecordReader[_]].getName) + Utils.classForName(classOf[ParquetRecordReader[_]].getName) JLogger.getLogger(classOf[ParquetRecordReader[_]].getName).setLevel(Level.OFF) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index 9058b09375291..28cba5e54d69e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -426,6 +426,7 @@ private[parquet] class AppendingParquetOutputFormat(offset: Int) } } +// TODO Removes this class after removing old Parquet support code /** * We extend ParquetInputFormat in order to have more control over which * RecordFilter we want to use. @@ -433,8 +434,6 @@ private[parquet] class AppendingParquetOutputFormat(offset: Int) private[parquet] class FilteringParquetRowInputFormat extends org.apache.parquet.hadoop.ParquetInputFormat[InternalRow] with Logging { - private var fileStatuses = Map.empty[Path, FileStatus] - override def createRecordReader( inputSplit: InputSplit, taskAttemptContext: TaskAttemptContext): RecordReader[Void, InternalRow] = { @@ -455,17 +454,6 @@ private[parquet] class FilteringParquetRowInputFormat } -private[parquet] object FilteringParquetRowInputFormat { - private val footerCache = CacheBuilder.newBuilder() - .maximumSize(20000) - .build[FileStatus, Footer]() - - private val blockLocationCache = CacheBuilder.newBuilder() - .maximumSize(20000) - .expireAfterWrite(15, TimeUnit.MINUTES) // Expire locations since HDFS files might move - .build[FileStatus, Array[BlockLocation]]() -} - private[parquet] object FileSystemHelper { def listFiles(pathStr: String, conf: Configuration): Seq[Path] = { val origPath = new Path(pathStr) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index 01dd6f471bd7c..2f9f880c70690 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -22,7 +22,7 @@ import java.util.{List => JList} import scala.collection.JavaConversions._ import scala.collection.mutable -import scala.util.Try +import scala.util.{Failure, Try} import com.google.common.base.Objects import org.apache.hadoop.fs.{FileStatus, Path} @@ -31,20 +31,22 @@ import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.parquet.filter2.predicate.FilterApi import org.apache.parquet.hadoop._ -import org.apache.parquet.hadoop.metadata.{FileMetaData, CompressionCodecName} +import org.apache.parquet.hadoop.metadata.CompressionCodecName import org.apache.parquet.hadoop.util.ContextUtil import org.apache.parquet.schema.MessageType +import org.apache.spark.{Logging, Partition => SparkPartition, SparkException} import 
org.apache.spark.broadcast.Broadcast -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.RDD import org.apache.spark.rdd.RDD._ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.{SqlNewHadoopPartition, SqlNewHadoopRDD} +import org.apache.spark.sql.execution.datasources.PartitionSpec import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.{SerializableConfiguration, Utils} -import org.apache.spark.{Logging, Partition => SparkPartition, SparkException} + private[sql] class DefaultSource extends HadoopFsRelationProvider { override def createRelation( @@ -278,19 +280,13 @@ private[sql] class ParquetRelation2( // Create the function to set input paths at the driver side. val setInputPaths = ParquetRelation2.initializeDriverSideJobFunc(inputFiles) _ - val footers = inputFiles.map(f => metadataCache.footers(f.getPath)) - Utils.withDummyCallSite(sqlContext.sparkContext) { - // TODO Stop using `FilteringParquetRowInputFormat` and overriding `getPartition`. - // After upgrading to Parquet 1.6.0, we should be able to stop caching `FileStatus` objects - // and footers. Especially when a global arbitrative schema (either from metastore or data - // source DDL) is available. new SqlNewHadoopRDD( sc = sqlContext.sparkContext, broadcastedConf = broadcastedConf, initDriverSideJobFuncOpt = Some(setInputPaths), initLocalJobFuncOpt = Some(initLocalJobFuncOpt), - inputFormatClass = classOf[FilteringParquetRowInputFormat], + inputFormatClass = classOf[ParquetInputFormat[InternalRow]], keyClass = classOf[Void], valueClass = classOf[InternalRow]) { @@ -306,12 +302,6 @@ private[sql] class ParquetRelation2( f.getAccessTime, f.getPermission, f.getOwner, f.getGroup, pathWithEscapedAuthority) }.toSeq - @transient val cachedFooters = footers.map { f => - // In order to encode the authority of a Path containing special characters such as /, - // we need to use the string returned by the URI of the path to create a new Path. - new Footer(escapePathUserInfo(f.getFile), f.getParquetMetadata) - }.toSeq - private def escapePathUserInfo(path: Path): Path = { val uri = path.toUri new Path(new URI( @@ -321,13 +311,10 @@ private[sql] class ParquetRelation2( // Overridden so we can inject our own cached files statuses. override def getPartitions: Array[SparkPartition] = { - val inputFormat = if (cacheMetadata) { - new FilteringParquetRowInputFormat { - override def listStatus(jobContext: JobContext): JList[FileStatus] = cachedStatuses - override def getFooters(jobContext: JobContext): JList[Footer] = cachedFooters + val inputFormat = new ParquetInputFormat[InternalRow] { + override def listStatus(jobContext: JobContext): JList[FileStatus] = { + if (cacheMetadata) cachedStatuses else super.listStatus(jobContext) } - } else { - new FilteringParquetRowInputFormat } val jobContext = newJobContext(getConf(isDriverSide = true), jobId) @@ -348,9 +335,6 @@ private[sql] class ParquetRelation2( // `FileStatus` objects of all "_common_metadata" files. private var commonMetadataStatuses: Array[FileStatus] = _ - // Parquet footer cache. - var footers: Map[Path, Footer] = _ - // `FileStatus` objects of all data files (Parquet part-files). 
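// The hunk above replaces FilteringParquetRowInputFormat with a plain ParquetInputFormat
// whose listStatus serves driver-side cached FileStatuses when metadata caching is enabled.
// A stripped-down sketch of that pattern (the `cached` parameter is a placeholder for the
// relation's cachedStatuses):
import java.util.{List => JList}
import org.apache.hadoop.fs.FileStatus
import org.apache.hadoop.mapreduce.JobContext
import org.apache.parquet.hadoop.ParquetInputFormat
import org.apache.spark.sql.catalyst.InternalRow

def cachingInputFormat(cached: JList[FileStatus]): ParquetInputFormat[InternalRow] =
  new ParquetInputFormat[InternalRow] {
    override def listStatus(jobContext: JobContext): JList[FileStatus] = cached
  }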
var dataStatuses: Array[FileStatus] = _ @@ -376,20 +360,6 @@ private[sql] class ParquetRelation2( commonMetadataStatuses = leaves.filter(_.getPath.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE) - footers = { - val conf = SparkHadoopUtil.get.conf - val taskSideMetaData = conf.getBoolean(ParquetInputFormat.TASK_SIDE_METADATA, true) - val rawFooters = if (shouldMergeSchemas) { - ParquetFileReader.readAllFootersInParallel( - conf, seqAsJavaList(leaves), taskSideMetaData) - } else { - ParquetFileReader.readAllFootersInParallelUsingSummaryFiles( - conf, seqAsJavaList(leaves), taskSideMetaData) - } - - rawFooters.map(footer => footer.getFile -> footer).toMap - } - // If we already get the schema, don't need to re-compute it since the schema merging is // time-consuming. if (dataSchema == null) { @@ -422,7 +392,7 @@ private[sql] class ParquetRelation2( // Always tries the summary files first if users don't require a merged schema. In this case, // "_common_metadata" is more preferable than "_metadata" because it doesn't contain row // groups information, and could be much smaller for large Parquet files with lots of row - // groups. + // groups. If no summary file is available, falls back to some random part-file. // // NOTE: Metadata stored in the summary files are merged from all part-files. However, for // user defined key-value metadata (in which we store Spark SQL schema), Parquet doesn't know @@ -457,10 +427,10 @@ private[sql] class ParquetRelation2( assert( filesToTouch.nonEmpty || maybeDataSchema.isDefined || maybeMetastoreSchema.isDefined, - "No schema defined, " + - s"and no Parquet data file or summary file found under ${paths.mkString(", ")}.") + "No predefined schema found, " + + s"and no Parquet data files or summary files found under ${paths.mkString(", ")}.") - ParquetRelation2.readSchema(filesToTouch.map(f => footers.apply(f.getPath)), sqlContext) + ParquetRelation2.mergeSchemasInParallel(filesToTouch, sqlContext) } } } @@ -519,6 +489,7 @@ private[sql] object ParquetRelation2 extends Logging { private[parquet] def initializeDriverSideJobFunc( inputFiles: Array[FileStatus])(job: Job): Unit = { // We side the input paths at the driver side. + logInfo(s"Reading Parquet file(s) from ${inputFiles.map(_.getPath).mkString(", ")}") if (inputFiles.nonEmpty) { FileInputFormat.setInputPaths(job, inputFiles.map(_.getPath): _*) } @@ -543,7 +514,7 @@ private[sql] object ParquetRelation2 extends Logging { .getKeyValueMetaData .toMap .get(RowReadSupport.SPARK_METADATA_KEY) - if (serializedSchema == None) { + if (serializedSchema.isEmpty) { // Falls back to Parquet schema if no Spark SQL schema found. Some(parseParquetSchema(metadata.getSchema)) } else if (!seen.contains(serializedSchema.get)) { @@ -646,4 +617,106 @@ private[sql] object ParquetRelation2 extends Logging { .filter(_.nullable) StructType(parquetSchema ++ missingFields) } + + /** + * Figures out a merged Parquet schema with a distributed Spark job. + * + * Note that locality is not taken into consideration here because: + * + * 1. For a single Parquet part-file, in most cases the footer only resides in the last block of + * that file. Thus we only need to retrieve the location of the last block. However, Hadoop + * `FileSystem` only provides API to retrieve locations of all blocks, which can be + * potentially expensive. + * + * 2. This optimization is mainly useful for S3, where file metadata operations can be pretty + * slow. 
And basically locality is not available when using S3 (you can't run computation on + * S3 nodes). + */ + def mergeSchemasInParallel( + filesToTouch: Seq[FileStatus], sqlContext: SQLContext): Option[StructType] = { + val assumeBinaryIsString = sqlContext.conf.isParquetBinaryAsString + val assumeInt96IsTimestamp = sqlContext.conf.isParquetINT96AsTimestamp + val followParquetFormatSpec = sqlContext.conf.followParquetFormatSpec + val serializedConf = new SerializableConfiguration(sqlContext.sparkContext.hadoopConfiguration) + + // HACK ALERT: + // + // Parquet requires `FileStatus`es to read footers. Here we try to send cached `FileStatus`es + // to executor side to avoid fetching them again. However, `FileStatus` is not `Serializable` + // but only `Writable`. What makes it worth, for some reason, `FileStatus` doesn't play well + // with `SerializableWritable[T]` and always causes a weird `IllegalStateException`. These + // facts virtually prevents us to serialize `FileStatus`es. + // + // Since Parquet only relies on path and length information of those `FileStatus`es to read + // footers, here we just extract them (which can be easily serialized), send them to executor + // side, and resemble fake `FileStatus`es there. + val partialFileStatusInfo = filesToTouch.map(f => (f.getPath.toString, f.getLen)) + + // Issues a Spark job to read Parquet schema in parallel. + val partiallyMergedSchemas = + sqlContext + .sparkContext + .parallelize(partialFileStatusInfo) + .mapPartitions { iterator => + // Resembles fake `FileStatus`es with serialized path and length information. + val fakeFileStatuses = iterator.map { case (path, length) => + new FileStatus(length, false, 0, 0, 0, 0, null, null, null, new Path(path)) + }.toSeq + + // Skips row group information since we only need the schema + val skipRowGroups = true + + // Reads footers in multi-threaded manner within each task + val footers = + ParquetFileReader.readAllFootersInParallel( + serializedConf.value, fakeFileStatuses, skipRowGroups) + + // Converter used to convert Parquet `MessageType` to Spark SQL `StructType` + val converter = + new CatalystSchemaConverter( + assumeBinaryIsString = assumeBinaryIsString, + assumeInt96IsTimestamp = assumeInt96IsTimestamp, + followParquetFormatSpec = followParquetFormatSpec) + + footers.map { footer => + ParquetRelation2.readSchemaFromFooter(footer, converter) + }.reduceOption(_ merge _).iterator + }.collect() + + partiallyMergedSchemas.reduceOption(_ merge _) + } + + /** + * Reads Spark SQL schema from a Parquet footer. If a valid serialized Spark SQL schema string + * can be found in the file metadata, returns the deserialized [[StructType]], otherwise, returns + * a [[StructType]] converted from the [[MessageType]] stored in this footer. + */ + def readSchemaFromFooter( + footer: Footer, converter: CatalystSchemaConverter): StructType = { + val fileMetaData = footer.getParquetMetadata.getFileMetaData + fileMetaData + .getKeyValueMetaData + .toMap + .get(RowReadSupport.SPARK_METADATA_KEY) + .flatMap(deserializeSchemaString) + .getOrElse(converter.convert(fileMetaData.getSchema)) + } + + private def deserializeSchemaString(schemaString: String): Option[StructType] = { + // Tries to deserialize the schema string as JSON first, then falls back to the case class + // string parser (data generated by older versions of Spark SQL uses this format). 
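// mergeSchemasInParallel above is a two-level reduce: footer schemas are merged within each
// task, and the per-task partial results are merged once more on the driver. A local
// analogue of that shape with a simplistic stand-in merge (StructType.merge itself is
// private[sql]); it assumes an existing `sqlContext` and that `schemas` holds the per-file
// schemas:
import org.apache.spark.sql.types.StructType

def mergeStructs(a: StructType, b: StructType): StructType =
  StructType(a.fields ++ b.fields.filterNot(f => a.fieldNames.contains(f.name)))

val schemas: Seq[StructType] = Seq.empty   // placeholder input
val merged = sqlContext.sparkContext
  .parallelize(schemas)
  .mapPartitions(iter => iter.reduceOption(mergeStructs).iterator)  // task-side partial merge
  .collect()
  .reduceOption(mergeStructs)                                       // driver-side final merge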
+ Try(DataType.fromJson(schemaString).asInstanceOf[StructType]).recover { + case _: Throwable => + logInfo( + s"Serialized Spark schema in Parquet key-value metadata is not in JSON format, " + + "falling back to the deprecated DataType.fromCaseClassString parser.") + DataType.fromCaseClassString(schemaString).asInstanceOf[StructType] + }.recoverWith { + case cause: Throwable => + logWarning( + "Failed to parse and ignored serialized Spark schema in " + + s"Parquet key-value metadata:\n\t$schemaString", cause) + Failure(cause) + }.toOption + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala index 24e86ca415c51..4d942e4f9287a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala @@ -17,6 +17,10 @@ package org.apache.spark.sql.sources +//////////////////////////////////////////////////////////////////////////////////////////////////// +// This file defines all the filters that we can push down to the data sources. +//////////////////////////////////////////////////////////////////////////////////////////////////// + /** * A filter predicate for data sources. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index b13c5313b25c9..7cd005b959488 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -24,15 +24,17 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} +import org.apache.spark.{Logging, SparkContext} import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row, SaveMode, SQLContext} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.RDDConversions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection +import org.apache.spark.sql.execution.RDDConversions +import org.apache.spark.sql.execution.datasources.{PartitioningUtils, PartitionSpec, Partition} import org.apache.spark.sql.types.StructType +import org.apache.spark.sql._ import org.apache.spark.util.SerializableConfiguration /** @@ -75,7 +77,7 @@ trait RelationProvider { * A new instance of this class with be instantiated each time a DDL call is made. * * The difference between a [[RelationProvider]] and a [[SchemaRelationProvider]] is that - * users need to provide a schema when using a SchemaRelationProvider. + * users need to provide a schema when using a [[SchemaRelationProvider]]. * A relation provider can inherits both [[RelationProvider]] and [[SchemaRelationProvider]] * if it can support both schema inference and user-specified schemas. * @@ -111,7 +113,7 @@ trait SchemaRelationProvider { * * The difference between a [[RelationProvider]] and a [[HadoopFsRelationProvider]] is * that users need to provide a schema and a (possibly empty) list of partition columns when - * using a SchemaRelationProvider. A relation provider can inherits both [[RelationProvider]], + * using a [[HadoopFsRelationProvider]]. 
A relation provider can inherit both [[RelationProvider]], * and [[HadoopFsRelationProvider]] if it can support schema inference, user-specified * schemas, and accessing partitioned relations. * @@ -367,7 +369,9 @@ abstract class OutputWriter { */ @Experimental abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[PartitionSpec]) - extends BaseRelation { + extends BaseRelation with Logging { + + logInfo("Constructing HadoopFsRelation") def this() = this(None) @@ -382,36 +386,40 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio var leafDirToChildrenFiles = mutable.Map.empty[Path, Array[FileStatus]] - def refresh(): Unit = { - // We don't filter files/directories whose name start with "_" except "_temporary" here, as - // specific data sources may take advantages over them (e.g. Parquet _metadata and - // _common_metadata files). "_temporary" directories are explicitly ignored since failed - // tasks/jobs may leave partial/corrupted data files there. - def listLeafFilesAndDirs(fs: FileSystem, status: FileStatus): Set[FileStatus] = { - if (status.getPath.getName.toLowerCase == "_temporary") { - Set.empty + private def listLeafFiles(paths: Array[String]): Set[FileStatus] = { + if (paths.length >= sqlContext.conf.parallelPartitionDiscoveryThreshold) { + HadoopFsRelation.listLeafFilesInParallel(paths, hadoopConf, sqlContext.sparkContext) + } else { + val statuses = paths.flatMap { path => + val hdfsPath = new Path(path) + val fs = hdfsPath.getFileSystem(hadoopConf) + val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + + logInfo(s"Listing $qualified on driver") + Try(fs.listStatus(qualified)).getOrElse(Array.empty) + }.filterNot { status => + val name = status.getPath.getName + name.toLowerCase == "_temporary" || name.startsWith(".") + } + + val (dirs, files) = statuses.partition(_.isDir) + + if (dirs.isEmpty) { + files.toSet } else { - val (dirs, files) = fs.listStatus(status.getPath).partition(_.isDir) - val leafDirs = if (dirs.isEmpty) Set(status) else Set.empty[FileStatus] - files.toSet ++ leafDirs ++ dirs.flatMap(dir => listLeafFilesAndDirs(fs, dir)) + files.toSet ++ listLeafFiles(dirs.map(_.getPath.toString)) } } + } - leafFiles.clear() + def refresh(): Unit = { + val files = listLeafFiles(paths) - val statuses = paths.flatMap { path => - val hdfsPath = new Path(path) - val fs = hdfsPath.getFileSystem(hadoopConf) - val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - Try(fs.getFileStatus(qualified)).toOption.toArray.flatMap(listLeafFilesAndDirs(fs, _)) - }.filterNot { status => - // SPARK-8037: Ignores files like ".DS_Store" and other hidden files/directories - status.getPath.getName.startsWith(".") - } + leafFiles.clear() + leafDirToChildrenFiles.clear() - val files = statuses.filterNot(_.isDir) leafFiles ++= files.map(f => f.getPath -> f).toMap - leafDirToChildrenFiles ++= files.groupBy(_.getPath.getParent) + leafDirToChildrenFiles ++= files.toArray.groupBy(_.getPath.getParent) } } @@ -515,7 +523,7 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio }) } - private[sources] final def buildScan( + private[sql] final def buildScan( requiredColumns: Array[String], filters: Array[Filter], inputPaths: Array[String], @@ -666,3 +674,63 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio */ def prepareJobForWrite(job: Job): OutputWriterFactory } + +private[sql] object HadoopFsRelation extends Logging { + // We don't filter 
files/directories whose names start with "_" except "_temporary" here, as + // specific data sources may take advantage of them (e.g. Parquet _metadata and + // _common_metadata files). "_temporary" directories are explicitly ignored since failed + // tasks/jobs may leave partial/corrupted data files there. Files and directories whose names + // start with "." are also ignored. + def listLeafFiles(fs: FileSystem, status: FileStatus): Array[FileStatus] = { + logInfo(s"Listing ${status.getPath}") + val name = status.getPath.getName.toLowerCase + if (name == "_temporary" || name.startsWith(".")) { + Array.empty + } else { + val (dirs, files) = fs.listStatus(status.getPath).partition(_.isDir) + files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) + } + } + + // `FileStatus` is Writable but not serializable. What makes it worse, somehow it doesn't play + // well with `SerializableWritable`. So there seems to be no way to serialize a `FileStatus`. + // Here we use `FakeFileStatus` to extract key components of a `FileStatus` to serialize it on the + // executor side and reconstruct it on the driver side. + case class FakeFileStatus( + path: String, + length: Long, + isDir: Boolean, + blockReplication: Short, + blockSize: Long, + modificationTime: Long, + accessTime: Long) + + def listLeafFilesInParallel( + paths: Array[String], + hadoopConf: Configuration, + sparkContext: SparkContext): Set[FileStatus] = { + logInfo(s"Listing leaf files and directories in parallel under: ${paths.mkString(", ")}") + + val serializableConfiguration = new SerializableConfiguration(hadoopConf) + val fakeStatuses = sparkContext.parallelize(paths).flatMap { path => + val hdfsPath = new Path(path) + val fs = hdfsPath.getFileSystem(serializableConfiguration.value) + val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + Try(listLeafFiles(fs, fs.getFileStatus(qualified))).getOrElse(Array.empty) + }.map { status => + FakeFileStatus( + status.getPath.toString, + status.getLen, + status.isDir, + status.getReplication, + status.getBlockSize, + status.getModificationTime, + status.getAccessTime) + }.collect() + + fakeStatuses.map { f => + new FileStatus( + f.length, f.isDir, f.blockReplication, f.blockSize, f.modificationTime, new Path(f.path)) + }.toSet + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 88bb743ab0bc9..1f9f7118c3f04 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -201,6 +201,46 @@ class ColumnExpressionSuite extends QueryTest { Row(false, true)) } + test("isNaN") { + val testData = ctx.createDataFrame(ctx.sparkContext.parallelize( + Row(Double.NaN, Float.NaN) :: + Row(math.log(-1), math.log(-3).toFloat) :: + Row(null, null) :: + Row(Double.MaxValue, Float.MinValue):: Nil), + StructType(Seq(StructField("a", DoubleType), StructField("b", FloatType)))) + + checkAnswer( + testData.select($"a".isNaN, $"b".isNaN), + Row(true, true) :: Row(true, true) :: Row(false, false) :: Row(false, false) :: Nil) + + checkAnswer( + testData.select(isNaN($"a"), isNaN($"b")), + Row(true, true) :: Row(true, true) :: Row(false, false) :: Row(false, false) :: Nil) + + checkAnswer( + ctx.sql("select isnan(15), isnan('invalid')"), + Row(false, false)) + } + + test("nanvl") { + val testData = ctx.createDataFrame(ctx.sparkContext.parallelize( + Row(null, 3.0, Double.NaN, 
Double.PositiveInfinity) :: Nil), + StructType(Seq(StructField("a", DoubleType), StructField("b", DoubleType), + StructField("c", DoubleType), StructField("d", DoubleType)))) + + checkAnswer( + testData.select( + nanvl($"a", lit(5)), nanvl($"b", lit(10)), + nanvl($"c", lit(null).cast(DoubleType)), nanvl($"d", lit(10))), + Row(null, 3.0, null, Double.PositiveInfinity) + ) + testData.registerTempTable("t") + checkAnswer( + ctx.sql("select nanvl(a, 5), nanvl(b, 10), nanvl(c, null), nanvl(d, 10) from t"), + Row(null, 3.0, null, Double.PositiveInfinity) + ) + } + test("===") { checkAnswer( testData2.filter($"a" === 1), @@ -429,7 +469,7 @@ class ColumnExpressionSuite extends QueryTest { test("monotonicallyIncreasingId") { // Make sure we have 2 partitions, each with 2 records. - val df = ctx.sparkContext.parallelize(1 to 2, 2).mapPartitions { iter => + val df = ctx.sparkContext.parallelize(Seq[Int](), 2).mapPartitions { _ => Iterator(Tuple1(1), Tuple1(2)) }.toDF("a") checkAnswer( @@ -439,10 +479,13 @@ class ColumnExpressionSuite extends QueryTest { } test("sparkPartitionId") { - val df = ctx.sparkContext.parallelize(1 to 1, 1).map(i => (i, i)).toDF("a", "b") + // Make sure we have 2 partitions, each with 2 records. + val df = ctx.sparkContext.parallelize(Seq[Int](), 2).mapPartitions { _ => + Iterator(Tuple1(1), Tuple1(2)) + }.toDF("a") checkAnswer( df.select(sparkPartitionId()), - Row(0) + Row(0) :: Row(0) :: Row(1) :: Row(1) :: Nil ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameDateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameDateSuite.scala deleted file mode 100644 index a4719a38de1d4..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameDateSuite.scala +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql - -import java.sql.{Date, Timestamp} - -class DataFrameDateTimeSuite extends QueryTest { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - - test("timestamp comparison with date strings") { - val df = Seq( - (1, Timestamp.valueOf("2015-01-01 00:00:00")), - (2, Timestamp.valueOf("2014-01-01 00:00:00"))).toDF("i", "t") - - checkAnswer( - df.select("t").filter($"t" <= "2014-06-01"), - Row(Timestamp.valueOf("2014-01-01 00:00:00")) :: Nil) - - - checkAnswer( - df.select("t").filter($"t" >= "2014-06-01"), - Row(Timestamp.valueOf("2015-01-01 00:00:00")) :: Nil) - } - - test("date comparison with date strings") { - val df = Seq( - (1, Date.valueOf("2015-01-01")), - (2, Date.valueOf("2014-01-01"))).toDF("i", "t") - - checkAnswer( - df.select("t").filter($"t" <= "2014-06-01"), - Row(Date.valueOf("2014-01-01")) :: Nil) - - - checkAnswer( - df.select("t").filter($"t" >= "2015"), - Row(Date.valueOf("2015-01-01")) :: Nil) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 6cebec95d2850..1baec5d37699d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -160,7 +160,7 @@ class DataFrameFunctionsSuite extends QueryTest { test("misc md5 function") { val df = Seq(("ABC", Array[Byte](1, 2, 3, 4, 5, 6))).toDF("a", "b") checkAnswer( - df.select(md5($"a"), md5("b")), + df.select(md5($"a"), md5($"b")), Row("902fbdd2b1df0c4f70b4a5d23525e932", "6ac1e56bc78f031059be7be854522c4c")) checkAnswer( @@ -171,7 +171,7 @@ class DataFrameFunctionsSuite extends QueryTest { test("misc sha1 function") { val df = Seq(("ABC", "ABC".getBytes)).toDF("a", "b") checkAnswer( - df.select(sha1($"a"), sha1("b")), + df.select(sha1($"a"), sha1($"b")), Row("3c01bdbb26f358bab27f267924aa2c9a03fcfdb8", "3c01bdbb26f358bab27f267924aa2c9a03fcfdb8")) val dfEmpty = Seq(("", "".getBytes)).toDF("a", "b") @@ -183,7 +183,7 @@ class DataFrameFunctionsSuite extends QueryTest { test("misc sha2 function") { val df = Seq(("ABC", Array[Byte](1, 2, 3, 4, 5, 6))).toDF("a", "b") checkAnswer( - df.select(sha2($"a", 256), sha2("b", 256)), + df.select(sha2($"a", 256), sha2($"b", 256)), Row("b5d4045c3f466fa91fe2cc6abe79232a1a57cdf104f7a26e716e0a1e2789df78", "7192385c3c0605de55bb9476ce1d90748190ecb32a8eed7f5207b30cf6a1fe89")) @@ -200,7 +200,7 @@ class DataFrameFunctionsSuite extends QueryTest { test("misc crc32 function") { val df = Seq(("ABC", Array[Byte](1, 2, 3, 4, 5, 6))).toDF("a", "b") checkAnswer( - df.select(crc32($"a"), crc32("b")), + df.select(crc32($"a"), crc32($"b")), Row(2743272264L, 2180413220L)) checkAnswer( @@ -208,199 +208,94 @@ class DataFrameFunctionsSuite extends QueryTest { Row(2743272264L, 2180413220L)) } - test("string length function") { - val df = Seq(("abc", "")).toDF("a", "b") - checkAnswer( - df.select(strlen($"a"), strlen("b")), - Row(3, 0)) - - checkAnswer( - df.selectExpr("length(a)", "length(b)"), - Row(3, 0)) - } - - test("Levenshtein distance") { - val df = Seq(("kitten", "sitting"), ("frog", "fog")).toDF("l", "r") - checkAnswer(df.select(levenshtein("l", "r")), Seq(Row(3), Row(1))) - checkAnswer(df.selectExpr("levenshtein(l, r)"), Seq(Row(3), Row(1))) - } - - test("string ascii function") { - val df = Seq(("abc", "")).toDF("a", "b") - checkAnswer( - df.select(ascii($"a"), ascii("b")), - Row(97, 0)) - - 
checkAnswer( - df.selectExpr("ascii(a)", "ascii(b)"), - Row(97, 0)) - } - - test("string base64/unbase64 function") { - val bytes = Array[Byte](1, 2, 3, 4) - val df = Seq((bytes, "AQIDBA==")).toDF("a", "b") - checkAnswer( - df.select(base64("a"), base64($"a"), unbase64("b"), unbase64($"b")), - Row("AQIDBA==", "AQIDBA==", bytes, bytes)) - - checkAnswer( - df.selectExpr("base64(a)", "unbase64(b)"), - Row("AQIDBA==", bytes)) - } - - test("string encode/decode function") { - val bytes = Array[Byte](-27, -92, -89, -27, -115, -125, -28, -72, -106, -25, -107, -116) - // scalastyle:off - // non ascii characters are not allowed in the code, so we disable the scalastyle here. - val df = Seq(("大千世界", "utf-8", bytes)).toDF("a", "b", "c") - checkAnswer( - df.select( - encode($"a", "utf-8"), - encode("a", "utf-8"), - decode($"c", "utf-8"), - decode("c", "utf-8")), - Row(bytes, bytes, "大千世界", "大千世界")) - - checkAnswer( - df.selectExpr("encode(a, 'utf-8')", "decode(c, 'utf-8')"), - Row(bytes, "大千世界")) - // scalastyle:on - } - - test("string trim functions") { - val df = Seq((" example ", "")).toDF("a", "b") - - checkAnswer( - df.select(ltrim($"a"), rtrim($"a"), trim($"a")), - Row("example ", " example", "example")) - - checkAnswer( - df.selectExpr("ltrim(a)", "rtrim(a)", "trim(a)"), - Row("example ", " example", "example")) - } - - test("string formatString function") { - val df = Seq(("aa%d%s", 123, "cc")).toDF("a", "b", "c") - - checkAnswer( - df.select(formatString($"a", $"b", $"c"), formatString("aa%d%s", "b", "c")), - Row("aa123cc", "aa123cc")) - - checkAnswer( - df.selectExpr("printf(a, b, c)"), - Row("aa123cc")) - } - - test("string instr function") { - val df = Seq(("aaads", "aa", "zz")).toDF("a", "b", "c") - - checkAnswer( - df.select(instr($"a", $"b"), instr("a", "b")), - Row(1, 1)) - - checkAnswer( - df.selectExpr("instr(a, b)"), - Row(1)) - } - - test("string locate function") { - val df = Seq(("aaads", "aa", "zz", 1)).toDF("a", "b", "c", "d") - + test("conditional function: least") { checkAnswer( - df.select( - locate($"b", $"a"), locate("b", "a"), locate($"b", $"a", 1), - locate("b", "a", 1), locate($"b", $"a", $"d"), locate("b", "a", "d")), - Row(1, 1, 2, 2, 2, 2)) - + testData2.select(least(lit(-1), lit(0), col("a"), col("b"))).limit(1), + Row(-1) + ) checkAnswer( - df.selectExpr("locate(b, a)", "locate(b, a, d)"), - Row(1, 2)) + ctx.sql("SELECT least(a, 2) as l from testData2 order by l"), + Seq(Row(1), Row(1), Row(2), Row(2), Row(2), Row(2)) + ) } - test("string padding functions") { - val df = Seq(("hi", 5, "??")).toDF("a", "b", "c") - + test("conditional function: greatest") { checkAnswer( - df.select( - lpad($"a", $"b", $"c"), rpad("a", "b", "c"), - lpad($"a", 1, $"c"), rpad("a", 1, "c")), - Row("???hi", "hi???", "h", "h")) - + testData2.select(greatest(lit(2), lit(3), col("a"), col("b"))).limit(1), + Row(3) + ) checkAnswer( - df.selectExpr("lpad(a, b, c)", "rpad(a, b, c)", "lpad(a, 1, c)", "rpad(a, 1, c)"), - Row("???hi", "hi???", "h", "h")) + ctx.sql("SELECT greatest(a, 2) as g from testData2 order by g"), + Seq(Row(2), Row(2), Row(2), Row(2), Row(3), Row(3)) + ) } - test("string repeat function") { - val df = Seq(("hi", 2)).toDF("a", "b") - + test("pmod") { + val intData = Seq((7, 3), (-7, 3)).toDF("a", "b") checkAnswer( - df.select( - repeat($"a", 2), repeat("a", 2), repeat($"a", $"b"), repeat("a", "b")), - Row("hihi", "hihi", "hihi", "hihi")) - + intData.select(pmod('a, 'b)), + Seq(Row(1), Row(2)) + ) checkAnswer( - df.selectExpr("repeat(a, 2)", "repeat(a, b)"), - Row("hihi", 
"hihi")) - } - - test("string reverse function") { - val df = Seq(("hi", "hhhi")).toDF("a", "b") - + intData.select(pmod('a, lit(3))), + Seq(Row(1), Row(2)) + ) checkAnswer( - df.select(reverse($"a"), reverse("b")), - Row("ih", "ihhh")) - + intData.select(pmod(lit(-7), 'b)), + Seq(Row(2), Row(2)) + ) checkAnswer( - df.selectExpr("reverse(b)"), - Row("ihhh")) - } - - test("string space function") { - val df = Seq((2, 3)).toDF("a", "b") - + intData.selectExpr("pmod(a, b)"), + Seq(Row(1), Row(2)) + ) checkAnswer( - df.select(space($"a"), space("b")), - Row(" ", " ")) - + intData.selectExpr("pmod(a, 3)"), + Seq(Row(1), Row(2)) + ) checkAnswer( - df.selectExpr("space(b)"), - Row(" ")) - } - - test("string split function") { - val df = Seq(("aa2bb3cc", "[1-9]+")).toDF("a", "b") - + intData.selectExpr("pmod(-7, b)"), + Seq(Row(2), Row(2)) + ) + val doubleData = Seq((7.2, 4.1)).toDF("a", "b") checkAnswer( - df.select( - split($"a", "[1-9]+"), - split("a", "[1-9]+")), - Row(Seq("aa", "bb", "cc"), Seq("aa", "bb", "cc"))) - + doubleData.select(pmod('a, 'b)), + Seq(Row(3.1000000000000005)) // same as hive + ) checkAnswer( - df.selectExpr("split(a, '[1-9]+')"), - Row(Seq("aa", "bb", "cc"))) + doubleData.select(pmod(lit(2), lit(Int.MaxValue))), + Seq(Row(2)) + ) } - test("conditional function: least") { + test("array size function") { + val df = Seq( + (Array[Int](1, 2), "x"), + (Array[Int](), "y"), + (Array[Int](1, 2, 3), "z") + ).toDF("a", "b") checkAnswer( - testData2.select(least(lit(-1), lit(0), col("a"), col("b"))).limit(1), - Row(-1) + df.select(size("a")), + Seq(Row(2), Row(0), Row(3)) ) checkAnswer( - ctx.sql("SELECT least(a, 2) as l from testData2 order by l"), - Seq(Row(1), Row(1), Row(2), Row(2), Row(2), Row(2)) + df.selectExpr("size(a)"), + Seq(Row(2), Row(0), Row(3)) ) } - test("conditional function: greatest") { + test("map size function") { + val df = Seq( + (Map[Int, Int](1 -> 1, 2 -> 2), "x"), + (Map[Int, Int](), "y"), + (Map[Int, Int](1 -> 1, 2 -> 2, 3 -> 3), "z") + ).toDF("a", "b") checkAnswer( - testData2.select(greatest(lit(2), lit(3), col("a"), col("b"))).limit(1), - Row(3) + df.select(size("a")), + Seq(Row(2), Row(0), Row(3)) ) checkAnswer( - ctx.sql("SELECT greatest(a, 2) as g from testData2 order by g"), - Seq(Row(2), Row(2), Row(2), Row(2), Row(3), Row(3)) + df.selectExpr("size(a)"), + Seq(Row(2), Row(0), Row(3)) ) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index 495701d4f616c..dbe3b44ee2c79 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -30,8 +30,10 @@ class DataFrameNaFunctionsSuite extends QueryTest { ("Bob", 16, 176.5), ("Alice", null, 164.3), ("David", 60, null), + ("Nina", 25, Double.NaN), ("Amy", null, null), - (null, null, null)).toDF("name", "age", "height") + (null, null, null) + ).toDF("name", "age", "height") } test("drop") { @@ -39,12 +41,12 @@ class DataFrameNaFunctionsSuite extends QueryTest { val rows = input.collect() checkAnswer( - input.na.drop("name" :: Nil), - rows(0) :: rows(1) :: rows(2) :: rows(3) :: Nil) + input.na.drop("name" :: Nil).select("name"), + Row("Bob") :: Row("Alice") :: Row("David") :: Row("Nina") :: Row("Amy") :: Nil) checkAnswer( - input.na.drop("age" :: Nil), - rows(0) :: rows(2) :: Nil) + input.na.drop("age" :: Nil).select("name"), + Row("Bob") :: Row("David") :: Row("Nina") :: Nil) 
checkAnswer( input.na.drop("age" :: "height" :: Nil), @@ -67,8 +69,8 @@ class DataFrameNaFunctionsSuite extends QueryTest { val rows = input.collect() checkAnswer( - input.na.drop("all"), - rows(0) :: rows(1) :: rows(2) :: rows(3) :: Nil) + input.na.drop("all").select("name"), + Row("Bob") :: Row("Alice") :: Row("David") :: Row("Nina") :: Row("Amy") :: Nil) checkAnswer( input.na.drop("any"), @@ -79,8 +81,8 @@ class DataFrameNaFunctionsSuite extends QueryTest { rows(0) :: Nil) checkAnswer( - input.na.drop("all", Seq("age", "height")), - rows(0) :: rows(1) :: rows(2) :: Nil) + input.na.drop("all", Seq("age", "height")).select("name"), + Row("Bob") :: Row("Alice") :: Row("David") :: Row("Nina") :: Nil) } test("drop with threshold") { @@ -108,6 +110,7 @@ class DataFrameNaFunctionsSuite extends QueryTest { Row("Bob", 16, 176.5) :: Row("Alice", 50, 164.3) :: Row("David", 60, 50.6) :: + Row("Nina", 25, 50.6) :: Row("Amy", 50, 50.6) :: Row(null, 50, 50.6) :: Nil) @@ -117,17 +120,19 @@ class DataFrameNaFunctionsSuite extends QueryTest { // string checkAnswer( input.na.fill("unknown").select("name"), - Row("Bob") :: Row("Alice") :: Row("David") :: Row("Amy") :: Row("unknown") :: Nil) + Row("Bob") :: Row("Alice") :: Row("David") :: + Row("Nina") :: Row("Amy") :: Row("unknown") :: Nil) assert(input.na.fill("unknown").columns.toSeq === input.columns.toSeq) // fill double with subset columns checkAnswer( - input.na.fill(50.6, "age" :: Nil), - Row("Bob", 16, 176.5) :: - Row("Alice", 50, 164.3) :: - Row("David", 60, null) :: - Row("Amy", 50, null) :: - Row(null, 50, null) :: Nil) + input.na.fill(50.6, "age" :: Nil).select("name", "age"), + Row("Bob", 16) :: + Row("Alice", 50) :: + Row("David", 60) :: + Row("Nina", 25) :: + Row("Amy", 50) :: + Row(null, 50) :: Nil) // fill string with subset columns checkAnswer( @@ -164,29 +169,27 @@ class DataFrameNaFunctionsSuite extends QueryTest { 16 -> 61, 60 -> 6, 164.3 -> 461.3 // Alice is really tall - )) + )).collect() - checkAnswer( - out, - Row("Bob", 61, 176.5) :: - Row("Alice", null, 461.3) :: - Row("David", 6, null) :: - Row("Amy", null, null) :: - Row(null, null, null) :: Nil) + assert(out(0) === Row("Bob", 61, 176.5)) + assert(out(1) === Row("Alice", null, 461.3)) + assert(out(2) === Row("David", 6, null)) + assert(out(3).get(2).asInstanceOf[Double].isNaN) + assert(out(4) === Row("Amy", null, null)) + assert(out(5) === Row(null, null, null)) // Replace only the age column val out1 = input.na.replace("age", Map( 16 -> 61, 60 -> 6, 164.3 -> 461.3 // Alice is really tall - )) - - checkAnswer( - out1, - Row("Bob", 61, 176.5) :: - Row("Alice", null, 164.3) :: - Row("David", 6, null) :: - Row("Amy", null, null) :: - Row(null, null, null) :: Nil) + )).collect() + + assert(out1(0) === Row("Bob", 61, 176.5)) + assert(out1(1) === Row("Alice", null, 164.3)) + assert(out1(2) === Row("David", 6, null)) + assert(out1(3).get(2).asInstanceOf[Double].isNaN) + assert(out1(4) === Row("Amy", null, null)) + assert(out1(5) === Row(null, null, null)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index f592a9934d0e6..f67f2c60c0e16 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -17,19 +17,24 @@ package org.apache.spark.sql +import java.io.File + import scala.language.postfixOps +import scala.util.Random +import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation 
import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint} - +import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, SQLTestUtils} -class DataFrameSuite extends QueryTest { +class DataFrameSuite extends QueryTest with SQLTestUtils { import org.apache.spark.sql.TestData._ lazy val ctx = org.apache.spark.sql.test.TestSQLContext import ctx.implicits._ + def sqlContext: SQLContext = ctx + test("analysis error should be eagerly reported") { val oldSetting = ctx.conf.dataFrameEagerAnalysis // Eager analysis. @@ -738,11 +743,32 @@ class DataFrameSuite extends QueryTest { df.col("t.``") } + test("SPARK-8797: sort by float column containing NaN should not crash") { + val inputData = Seq.fill(10)(Tuple1(Float.NaN)) ++ (1 to 1000).map(x => Tuple1(x.toFloat)) + val df = Random.shuffle(inputData).toDF("a") + df.orderBy("a").collect() + } + + test("SPARK-8797: sort by double column containing NaN should not crash") { + val inputData = Seq.fill(10)(Tuple1(Double.NaN)) ++ (1 to 1000).map(x => Tuple1(x.toDouble)) + val df = Random.shuffle(inputData).toDF("a") + df.orderBy("a").collect() + } + + test("NaN is greater than all other non-NaN numeric values") { + val maxDouble = Seq(Double.NaN, Double.PositiveInfinity, Double.MaxValue) + .map(Tuple1.apply).toDF("a").selectExpr("max(a)").first() + assert(java.lang.Double.isNaN(maxDouble.getDouble(0))) + val maxFloat = Seq(Float.NaN, Float.PositiveInfinity, Float.MaxValue) + .map(Tuple1.apply).toDF("a").selectExpr("max(a)").first() + assert(java.lang.Float.isNaN(maxFloat.getFloat(0))) + } + test("SPARK-8072: Better Exception for Duplicate Columns") { // only one duplicate column present val e = intercept[org.apache.spark.sql.AnalysisException] { - val df1 = Seq((1, 2, 3), (2, 3, 4), (3, 4, 5)).toDF("column1", "column2", "column1") - .write.format("parquet").save("temp") + Seq((1, 2, 3), (2, 3, 4), (3, 4, 5)).toDF("column1", "column2", "column1") + .write.format("parquet").save("temp") } assert(e.getMessage.contains("Duplicate column(s)")) assert(e.getMessage.contains("parquet")) @@ -751,9 +777,9 @@ class DataFrameSuite extends QueryTest { // multiple duplicate columns present val f = intercept[org.apache.spark.sql.AnalysisException] { - val df2 = Seq((1, 2, 3, 4, 5), (2, 3, 4, 5, 6), (3, 4, 5, 6, 7)) - .toDF("column1", "column2", "column3", "column1", "column3") - .write.format("json").save("temp") + Seq((1, 2, 3, 4, 5), (2, 3, 4, 5, 6), (3, 4, 5, 6, 7)) + .toDF("column1", "column2", "column3", "column1", "column3") + .write.format("json").save("temp") } assert(f.getMessage.contains("Duplicate column(s)")) assert(f.getMessage.contains("JSON")) @@ -761,4 +787,49 @@ class DataFrameSuite extends QueryTest { assert(f.getMessage.contains("column3")) assert(!f.getMessage.contains("column2")) } + + test("SPARK-6941: Better error message for inserting into RDD-based Table") { + withTempDir { dir => + + val tempParquetFile = new File(dir, "tmp_parquet") + val tempJsonFile = new File(dir, "tmp_json") + + val df = Seq(Tuple1(1)).toDF() + val insertion = Seq(Tuple1(2)).toDF("col") + + // pass case: parquet table (HadoopFsRelation) + df.write.mode(SaveMode.Overwrite).parquet(tempParquetFile.getCanonicalPath) + val pdf = ctx.read.parquet(tempParquetFile.getCanonicalPath) + pdf.registerTempTable("parquet_base") + insertion.write.insertInto("parquet_base") + + // pass case: json table (InsertableRelation) + df.write.mode(SaveMode.Overwrite).json(tempJsonFile.getCanonicalPath) + val jdf = 
ctx.read.json(tempJsonFile.getCanonicalPath) + jdf.registerTempTable("json_base") + insertion.write.mode(SaveMode.Overwrite).insertInto("json_base") + + // error cases: insert into an RDD + df.registerTempTable("rdd_base") + val e1 = intercept[AnalysisException] { + insertion.write.insertInto("rdd_base") + } + assert(e1.getMessage.contains("Inserting into an RDD-based table is not allowed.")) + + // error case: insert into a logical plan that is not a LeafNode + val indirectDS = pdf.select("_1").filter($"_1" > 5) + indirectDS.registerTempTable("indirect_ds") + val e2 = intercept[AnalysisException] { + insertion.write.insertInto("indirect_ds") + } + assert(e2.getMessage.contains("Inserting into an RDD-based table is not allowed.")) + + // error case: insert into an OneRowRelation + new DataFrame(ctx, OneRowRelation).registerTempTable("one_row") + val e3 = intercept[AnalysisException] { + insertion.write.insertInto("one_row") + } + assert(e3.getMessage.contains("Inserting into an RDD-based table is not allowed.")) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala new file mode 100644 index 0000000000000..9e80ae86920d9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import java.sql.{Timestamp, Date} +import java.text.SimpleDateFormat + +import org.apache.spark.sql.functions._ + +class DateFunctionsSuite extends QueryTest { + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + + import ctx.implicits._ + + val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") + val sdfDate = new SimpleDateFormat("yyyy-MM-dd") + val d = new Date(sdf.parse("2015-04-08 13:10:15").getTime) + val ts = new Timestamp(sdf.parse("2013-04-08 13:10:15").getTime) + + test("timestamp comparison with date strings") { + val df = Seq( + (1, Timestamp.valueOf("2015-01-01 00:00:00")), + (2, Timestamp.valueOf("2014-01-01 00:00:00"))).toDF("i", "t") + + checkAnswer( + df.select("t").filter($"t" <= "2014-06-01"), + Row(Timestamp.valueOf("2014-01-01 00:00:00")) :: Nil) + + + checkAnswer( + df.select("t").filter($"t" >= "2014-06-01"), + Row(Timestamp.valueOf("2015-01-01 00:00:00")) :: Nil) + } + + test("date comparison with date strings") { + val df = Seq( + (1, Date.valueOf("2015-01-01")), + (2, Date.valueOf("2014-01-01"))).toDF("i", "t") + + checkAnswer( + df.select("t").filter($"t" <= "2014-06-01"), + Row(Date.valueOf("2014-01-01")) :: Nil) + + + checkAnswer( + df.select("t").filter($"t" >= "2015"), + Row(Date.valueOf("2015-01-01")) :: Nil) + } + + test("date format") { + val df = Seq((d, sdf.format(d), ts)).toDF("a", "b", "c") + + checkAnswer( + df.select(date_format("a", "y"), date_format("b", "y"), date_format("c", "y")), + Row("2015", "2015", "2013")) + + checkAnswer( + df.selectExpr("date_format(a, 'y')", "date_format(b, 'y')", "date_format(c, 'y')"), + Row("2015", "2015", "2013")) + } + + test("year") { + val df = Seq((d, sdfDate.format(d), ts)).toDF("a", "b", "c") + + checkAnswer( + df.select(year("a"), year("b"), year("c")), + Row(2015, 2015, 2013)) + + checkAnswer( + df.selectExpr("year(a)", "year(b)", "year(c)"), + Row(2015, 2015, 2013)) + } + + test("quarter") { + val ts = new Timestamp(sdf.parse("2013-11-08 13:10:15").getTime) + + val df = Seq((d, sdfDate.format(d), ts)).toDF("a", "b", "c") + + checkAnswer( + df.select(quarter("a"), quarter("b"), quarter("c")), + Row(2, 2, 4)) + + checkAnswer( + df.selectExpr("quarter(a)", "quarter(b)", "quarter(c)"), + Row(2, 2, 4)) + } + + test("month") { + val df = Seq((d, sdfDate.format(d), ts)).toDF("a", "b", "c") + + checkAnswer( + df.select(month("a"), month("b"), month("c")), + Row(4, 4, 4)) + + checkAnswer( + df.selectExpr("month(a)", "month(b)", "month(c)"), + Row(4, 4, 4)) + } + + test("dayofmonth") { + val df = Seq((d, sdfDate.format(d), ts)).toDF("a", "b", "c") + + checkAnswer( + df.select(dayofmonth("a"), dayofmonth("b"), dayofmonth("c")), + Row(8, 8, 8)) + + checkAnswer( + df.selectExpr("day(a)", "day(b)", "dayofmonth(c)"), + Row(8, 8, 8)) + } + + test("dayofyear") { + val df = Seq((d, sdfDate.format(d), ts)).toDF("a", "b", "c") + + checkAnswer( + df.select(dayofyear("a"), dayofyear("b"), dayofyear("c")), + Row(98, 98, 98)) + + checkAnswer( + df.selectExpr("dayofyear(a)", "dayofyear(b)", "dayofyear(c)"), + Row(98, 98, 98)) + } + + test("hour") { + val df = Seq((d, sdf.format(d), ts)).toDF("a", "b", "c") + + checkAnswer( + df.select(hour("a"), hour("b"), hour("c")), + Row(0, 13, 13)) + + checkAnswer( + df.selectExpr("hour(a)", "hour(b)", "hour(c)"), + Row(0, 13, 13)) + } + + test("minute") { + val df = Seq((d, sdf.format(d), ts)).toDF("a", "b", "c") + + checkAnswer( + df.select(minute("a"), minute("b"), minute("c")), + Row(0, 10, 10)) + + checkAnswer( + 
df.selectExpr("minute(a)", "minute(b)", "minute(c)"), + Row(0, 10, 10)) + } + + test("second") { + val df = Seq((d, sdf.format(d), ts)).toDF("a", "b", "c") + + checkAnswer( + df.select(second("a"), second("b"), second("c")), + Row(0, 15, 15)) + + checkAnswer( + df.selectExpr("second(a)", "second(b)", "second(c)"), + Row(0, 15, 15)) + } + + test("weekofyear") { + val df = Seq((d, sdfDate.format(d), ts)).toDF("a", "b", "c") + + checkAnswer( + df.select(weekofyear("a"), weekofyear("b"), weekofyear("c")), + Row(15, 15, 15)) + + checkAnswer( + df.selectExpr("weekofyear(a)", "weekofyear(b)", "weekofyear(c)"), + Row(15, 15, 15)) + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index 24bef21b999ea..21256704a5b16 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -68,12 +68,7 @@ class MathExpressionsSuite extends QueryTest { if (f(-1) === math.log1p(-1)) { checkAnswer( nnDoubleData.select(c('b)), - (1 to 9).map(n => Row(f(n * -0.1))) :+ Row(Double.NegativeInfinity) - ) - } else { - checkAnswer( - nnDoubleData.select(c('b)), - (1 to 10).map(n => Row(null)) + (1 to 9).map(n => Row(f(n * -0.1))) :+ Row(null) ) } @@ -178,6 +173,18 @@ class MathExpressionsSuite extends QueryTest { Row(0.0, 1.0, 2.0)) } + test("conv") { + val df = Seq(("333", 10, 2)).toDF("num", "fromBase", "toBase") + checkAnswer(df.select(conv('num, 10, 16)), Row("14D")) + checkAnswer(df.select(conv(lit(100), 2, 16)), Row("4")) + checkAnswer(df.select(conv(lit(3122234455L), 10, 16)), Row("BA198457")) + checkAnswer(df.selectExpr("conv(num, fromBase, toBase)"), Row("101001101")) + checkAnswer(df.selectExpr("""conv("100", 2, 10)"""), Row("4")) + checkAnswer(df.selectExpr("""conv("-10", 16, -10)"""), Row("-16")) + checkAnswer( + df.selectExpr("""conv("9223372036854775807", 36, -16)"""), Row("-1")) // for overflow + } + test("floor") { testOneToOneMathFunction(floor, math.floor) } @@ -198,6 +205,21 @@ class MathExpressionsSuite extends QueryTest { testOneToOneMathFunction(rint, math.rint) } + test("round") { + val df = Seq(5, 55, 555).map(Tuple1(_)).toDF("a") + checkAnswer( + df.select(round('a), round('a, -1), round('a, -2)), + Seq(Row(5, 10, 0), Row(55, 60, 100), Row(555, 560, 600)) + ) + + val pi = 3.1415 + checkAnswer( + ctx.sql(s"SELECT round($pi, -3), round($pi, -2), round($pi, -1), " + + s"round($pi, 0), round($pi, 1), round($pi, 2), round($pi, 3)"), + Seq(Row(0.0, 0.0, 0.0, 3.0, 3.1, 3.14, 3.142)) + ) + } + test("exp") { testOneToOneMathFunction(exp, math.exp) } @@ -375,6 +397,5 @@ class MathExpressionsSuite extends QueryTest { val df = Seq((1, -1, "abc")).toDF("a", "b", "c") checkAnswer(df.selectExpr("positive(a)"), Row(1)) checkAnswer(df.selectExpr("positive(b)"), Row(-1)) - checkAnswer(df.selectExpr("positive(c)"), Row("abc")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala index d84b57af9c882..7cc6ffd7548d0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -73,4 +73,16 @@ class RowSuite extends SparkFunSuite { row.getAs[Int]("c") } } + + test("float NaN == NaN") { + val r1 = Row(Float.NaN) + val r2 = Row(Float.NaN) + assert(r1 === r2) + } + + test("double NaN == NaN") { + val r1 = Row(Double.NaN) + val r2 = 
Row(Double.NaN) + assert(r1 === r2) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 231440892bf0b..ab8dce603c117 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -23,6 +23,7 @@ import java.sql.Timestamp import org.apache.spark.sql.catalyst.DefaultParserDialect import org.apache.spark.sql.catalyst.errors.DialectException +import org.apache.spark.sql.execution.aggregate.Aggregate2Sort import org.apache.spark.sql.execution.GeneratedAggregate import org.apache.spark.sql.functions._ import org.apache.spark.sql.TestData._ @@ -204,6 +205,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { var hasGeneratedAgg = false df.queryExecution.executedPlan.foreach { case generatedAgg: GeneratedAggregate => hasGeneratedAgg = true + case newAggregate: Aggregate2Sort => hasGeneratedAgg = true case _ => } if (!hasGeneratedAgg) { @@ -285,7 +287,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { // Aggregate with Code generation handling all null values testCodeGen( "SELECT sum('a'), avg('a'), count(null) FROM testData", - Row(0, null, 0) :: Nil) + Row(null, null, 0) :: Nil) } finally { sqlContext.dropTempTable("testData3x") sqlContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue) @@ -395,6 +397,18 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { ) } + test("left semi greater than predicate and equal operator") { + checkAnswer( + sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.b = y.b and x.a >= y.a + 2"), + Seq(Row(3, 1), Row(3, 2)) + ) + + checkAnswer( + sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.b = y.a and x.a >= y.b + 1"), + Seq(Row(2, 1), Row(2, 2), Row(3, 1), Row(3, 2)) + ) + } + test("index into array of arrays") { checkAnswer( sql( @@ -636,6 +650,15 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { Row(2, 1, 2, 2, 1)) } + test("count of empty table") { + withTempTable("t") { + Seq.empty[(Int, Int)].toDF("a", "b").registerTempTable("t") + checkAnswer( + sql("select count(a) from t"), + Row(0)) + } + } + test("inner join where, one match per row") { checkAnswer( sql("SELECT * FROM upperCaseData JOIN lowerCaseData WHERE n = N"), @@ -1492,4 +1515,21 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { // Currently we don't yet support nanosecond checkIntervalParseError("select interval 23 nanosecond") } + + test("SPARK-8945: add and subtract expressions for interval type") { + import org.apache.spark.unsafe.types.Interval + + val df = sql("select interval 3 years -3 month 7 week 123 microseconds as i") + checkAnswer(df, Row(new Interval(12 * 3 - 3, 7L * 1000 * 1000 * 3600 * 24 * 7 + 123))) + + checkAnswer(df.select(df("i") + new Interval(2, 123)), + Row(new Interval(12 * 3 - 3 + 2, 7L * 1000 * 1000 * 3600 * 24 * 7 + 123 + 123))) + + checkAnswer(df.select(df("i") - new Interval(2, 123)), + Row(new Interval(12 * 3 - 3 - 2, 7L * 1000 * 1000 * 3600 * 24 * 7 + 123 - 123))) + + // unary minus + checkAnswer(df.select(-df("i")), + Row(new Interval(-(12 * 3 - 3), -(7L * 1000 * 1000 * 3600 * 24 * 7 + 123)))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala new file mode 100644 index 
0000000000000..0f9c986f649a1 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -0,0 +1,294 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.Decimal + + +class StringFunctionsSuite extends QueryTest { + + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + + test("string concat") { + val df = Seq[(String, String, String)](("a", "b", null)).toDF("a", "b", "c") + + checkAnswer( + df.select(concat($"a", $"b"), concat($"a", $"b", $"c")), + Row("ab", null)) + + checkAnswer( + df.selectExpr("concat(a, b)", "concat(a, b, c)"), + Row("ab", null)) + } + + test("string concat_ws") { + val df = Seq[(String, String, String)](("a", "b", null)).toDF("a", "b", "c") + + checkAnswer( + df.select(concat_ws("||", $"a", $"b", $"c")), + Row("a||b")) + + checkAnswer( + df.selectExpr("concat_ws('||', a, b, c)"), + Row("a||b")) + } + + test("string Levenshtein distance") { + val df = Seq(("kitten", "sitting"), ("frog", "fog")).toDF("l", "r") + checkAnswer(df.select(levenshtein($"l", $"r")), Seq(Row(3), Row(1))) + checkAnswer(df.selectExpr("levenshtein(l, r)"), Seq(Row(3), Row(1))) + } + + test("string regex_replace / regex_extract") { + val df = Seq(("100-200", "")).toDF("a", "b") + + checkAnswer( + df.select( + regexp_replace($"a", "(\\d+)", "num"), + regexp_extract($"a", "(\\d+)-(\\d+)", 1)), + Row("num-num", "100")) + + checkAnswer( + df.selectExpr( + "regexp_replace(a, '(\\d+)', 'num')", + "regexp_extract(a, '(\\d+)-(\\d+)', 2)"), + Row("num-num", "200")) + } + + test("string ascii function") { + val df = Seq(("abc", "")).toDF("a", "b") + checkAnswer( + df.select(ascii($"a"), ascii($"b")), + Row(97, 0)) + + checkAnswer( + df.selectExpr("ascii(a)", "ascii(b)"), + Row(97, 0)) + } + + test("string base64/unbase64 function") { + val bytes = Array[Byte](1, 2, 3, 4) + val df = Seq((bytes, "AQIDBA==")).toDF("a", "b") + checkAnswer( + df.select(base64($"a"), unbase64($"b")), + Row("AQIDBA==", bytes)) + + checkAnswer( + df.selectExpr("base64(a)", "unbase64(b)"), + Row("AQIDBA==", bytes)) + } + + test("string encode/decode function") { + val bytes = Array[Byte](-27, -92, -89, -27, -115, -125, -28, -72, -106, -25, -107, -116) + // scalastyle:off + // non ascii characters are not allowed in the code, so we disable the scalastyle here. 
+ val df = Seq(("大千世界", "utf-8", bytes)).toDF("a", "b", "c") + checkAnswer( + df.select(encode($"a", "utf-8"), decode($"c", "utf-8")), + Row(bytes, "大千世界")) + + checkAnswer( + df.selectExpr("encode(a, 'utf-8')", "decode(c, 'utf-8')"), + Row(bytes, "大千世界")) + // scalastyle:on + } + + test("string trim functions") { + val df = Seq((" example ", "")).toDF("a", "b") + + checkAnswer( + df.select(ltrim($"a"), rtrim($"a"), trim($"a")), + Row("example ", " example", "example")) + + checkAnswer( + df.selectExpr("ltrim(a)", "rtrim(a)", "trim(a)"), + Row("example ", " example", "example")) + } + + test("string formatString function") { + val df = Seq(("aa%d%s", 123, "cc")).toDF("a", "b", "c") + + checkAnswer( + df.select(format_string("aa%d%s", $"b", $"c")), + Row("aa123cc")) + + checkAnswer( + df.selectExpr("printf(a, b, c)"), + Row("aa123cc")) + } + + test("string instr function") { + val df = Seq(("aaads", "aa", "zz")).toDF("a", "b", "c") + + checkAnswer( + df.select(instr($"a", "aa")), + Row(1)) + + checkAnswer( + df.selectExpr("instr(a, b)"), + Row(1)) + } + + test("string locate function") { + val df = Seq(("aaads", "aa", "zz", 1)).toDF("a", "b", "c", "d") + + checkAnswer( + df.select(locate("aa", $"a"), locate("aa", $"a", 1)), + Row(1, 2)) + + checkAnswer( + df.selectExpr("locate(b, a)", "locate(b, a, d)"), + Row(1, 2)) + } + + test("string padding functions") { + val df = Seq(("hi", 5, "??")).toDF("a", "b", "c") + + checkAnswer( + df.select(lpad($"a", 1, "c"), lpad($"a", 5, "??"), rpad($"a", 1, "c"), rpad($"a", 5, "??")), + Row("h", "???hi", "h", "hi???")) + + checkAnswer( + df.selectExpr("lpad(a, b, c)", "rpad(a, b, c)", "lpad(a, 1, c)", "rpad(a, 1, c)"), + Row("???hi", "hi???", "h", "h")) + } + + test("string repeat function") { + val df = Seq(("hi", 2)).toDF("a", "b") + + checkAnswer( + df.select(repeat($"a", 2)), + Row("hihi")) + + checkAnswer( + df.selectExpr("repeat(a, 2)", "repeat(a, b)"), + Row("hihi", "hihi")) + } + + test("string reverse function") { + val df = Seq(("hi", "hhhi")).toDF("a", "b") + + checkAnswer( + df.select(reverse($"a"), reverse($"b")), + Row("ih", "ihhh")) + + checkAnswer( + df.selectExpr("reverse(b)"), + Row("ihhh")) + } + + test("string space function") { + val df = Seq((2, 3)).toDF("a", "b") + + checkAnswer( + df.selectExpr("space(b)"), + Row(" ")) + } + + test("string split function") { + val df = Seq(("aa2bb3cc", "[1-9]+")).toDF("a", "b") + + checkAnswer( + df.select(split($"a", "[1-9]+")), + Row(Seq("aa", "bb", "cc"))) + + checkAnswer( + df.selectExpr("split(a, '[1-9]+')"), + Row(Seq("aa", "bb", "cc"))) + } + + test("string / binary length function") { + val df = Seq(("123", Array[Byte](1, 2, 3, 4), 123)).toDF("a", "b", "c") + checkAnswer( + df.select(length($"a"), length($"b")), + Row(3, 4)) + + checkAnswer( + df.selectExpr("length(a)", "length(b)"), + Row(3, 4)) + + intercept[AnalysisException] { + checkAnswer( + df.selectExpr("length(c)"), // int type of the argument is unacceptable + Row("5.0000")) + } + } + + test("number format function") { + val tuple = + ("aa", 1.asInstanceOf[Byte], 2.asInstanceOf[Short], + 3.13223f, 4, 5L, 6.48173d, Decimal(7.128381)) + val df = + Seq(tuple) + .toDF( + "a", // string "aa" + "b", // byte 1 + "c", // short 2 + "d", // float 3.13223f + "e", // integer 4 + "f", // long 5L + "g", // double 6.48173d + "h") // decimal 7.128381 + + checkAnswer( + df.select(format_number($"f", 4)), + Row("5.0000")) + + checkAnswer( + df.selectExpr("format_number(b, e)"), // convert the 1st argument to integer + Row("1.0000")) + + 
checkAnswer( + df.selectExpr("format_number(c, e)"), // convert the 1st argument to integer + Row("2.0000")) + + checkAnswer( + df.selectExpr("format_number(d, e)"), // convert the 1st argument to double + Row("3.1322")) + + checkAnswer( + df.selectExpr("format_number(e, e)"), // not convert anything + Row("4.0000")) + + checkAnswer( + df.selectExpr("format_number(f, e)"), // not convert anything + Row("5.0000")) + + checkAnswer( + df.selectExpr("format_number(g, e)"), // not convert anything + Row("6.4817")) + + checkAnswer( + df.selectExpr("format_number(h, e)"), // not convert anything + Row("7.1284")) + + intercept[AnalysisException] { + checkAnswer( + df.selectExpr("format_number(a, e)"), // string type of the 1st argument is unacceptable + Row("5.0000")) + } + + intercept[AnalysisException] { + checkAnswer( + df.selectExpr("format_number(e, g)"), // decimal type of the 2nd argument is unacceptable + Row("5.0000")) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala new file mode 100644 index 0000000000000..ad3bb1744cb3c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import java.io.ByteArrayOutputStream + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection} +import org.apache.spark.sql.types.{DataType, IntegerType, StringType} +import org.apache.spark.unsafe.PlatformDependent +import org.apache.spark.unsafe.memory.MemoryAllocator +import org.apache.spark.unsafe.types.UTF8String + +class UnsafeRowSuite extends SparkFunSuite { + test("writeToStream") { + val row = InternalRow.apply(UTF8String.fromString("hello"), UTF8String.fromString("world"), 123) + val arrayBackedUnsafeRow: UnsafeRow = + UnsafeProjection.create(Array[DataType](StringType, StringType, IntegerType)).apply(row) + assert(arrayBackedUnsafeRow.getBaseObject.isInstanceOf[Array[Byte]]) + val bytesFromArrayBackedRow: Array[Byte] = { + val baos = new ByteArrayOutputStream() + arrayBackedUnsafeRow.writeToStream(baos, null) + baos.toByteArray + } + val bytesFromOffheapRow: Array[Byte] = { + val offheapRowPage = MemoryAllocator.UNSAFE.allocate(arrayBackedUnsafeRow.getSizeInBytes) + try { + PlatformDependent.copyMemory( + arrayBackedUnsafeRow.getBaseObject, + arrayBackedUnsafeRow.getBaseOffset, + offheapRowPage.getBaseObject, + offheapRowPage.getBaseOffset, + arrayBackedUnsafeRow.getSizeInBytes + ) + val offheapUnsafeRow: UnsafeRow = new UnsafeRow() + offheapUnsafeRow.pointTo( + offheapRowPage.getBaseObject, + offheapRowPage.getBaseOffset, + 3, // num fields + arrayBackedUnsafeRow.getSizeInBytes + ) + assert(offheapUnsafeRow.getBaseObject === null) + val baos = new ByteArrayOutputStream() + val writeBuffer = new Array[Byte](1024) + offheapUnsafeRow.writeToStream(baos, writeBuffer) + baos.toByteArray + } finally { + MemoryAllocator.UNSAFE.free(offheapRowPage) + } + } + + assert(bytesFromArrayBackedRow === bytesFromOffheapRow) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregateSuite.scala new file mode 100644 index 0000000000000..20def6bef0c17 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregateSuite.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.test.TestSQLContext + +class AggregateSuite extends SparkPlanTest { + + test("SPARK-8357 unsafe aggregation path should not leak memory with empty input") { + val codegenDefault = TestSQLContext.getConf(SQLConf.CODEGEN_ENABLED) + val unsafeDefault = TestSQLContext.getConf(SQLConf.UNSAFE_ENABLED) + try { + TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, true) + TestSQLContext.setConf(SQLConf.UNSAFE_ENABLED, true) + val df = Seq.empty[(Int, Int)].toDF("a", "b") + checkAnswer( + df, + GeneratedAggregate( + partial = true, + Seq(df.col("b").expr), + Seq(Alias(Count(df.col("a").expr), "cnt")()), + unsafeEnabled = true, + _: SparkPlan), + Seq.empty + ) + } finally { + TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, codegenDefault) + TestSQLContext.setConf(SQLConf.UNSAFE_ENABLED, unsafeDefault) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala new file mode 100644 index 0000000000000..79e903c2bbd40 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
+
+class ExchangeSuite extends SparkPlanTest {
+  test("shuffling UnsafeRows in exchange") {
+    val input = (1 to 1000).map(Tuple1.apply)
+    checkAnswer(
+      input.toDF(),
+      plan => ConvertToSafe(Exchange(SinglePartition, ConvertToUnsafe(plan))),
+      input.map(Row.fromTuple)
+    )
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 3dd24130af81a..3d71deb13e884 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.TestData._
 import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.TestSQLContext._
@@ -30,6 +31,20 @@ import org.apache.spark.sql.{Row, SQLConf, execution}

 class PlannerSuite extends SparkFunSuite {

+  private def testPartialAggregationPlan(query: LogicalPlan): Unit = {
+    val plannedOption = HashAggregation(query).headOption.orElse(Aggregation(query).headOption)
+    val planned =
+      plannedOption.getOrElse(
+        fail(s"Could not plan aggregation query $query. Is it an aggregation query?"))
+    val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n }
+
+    // For the new aggregation code path, there will be three aggregate operators for
+    // distinct aggregations.
+    assert(
+      aggregations.size == 2 || aggregations.size == 3,
+      s"The plan of query $query does not have partial aggregations.")
+  }
+
   test("unions are collapsed") {
     val query = testData.unionAll(testData).unionAll(testData).logicalPlan
     val planned = BasicOperators(query).head
@@ -42,23 +57,18 @@ class PlannerSuite extends SparkFunSuite {

   test("count is partially aggregated") {
     val query = testData.groupBy('value).agg(count('key)).queryExecution.analyzed
-    val planned = HashAggregation(query).head
-    val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n }
-
-    assert(aggregations.size === 2)
+    testPartialAggregationPlan(query)
   }

   test("count distinct is partially aggregated") {
     val query = testData.groupBy('value).agg(countDistinct('key)).queryExecution.analyzed
-    val planned = HashAggregation(query)
-    assert(planned.nonEmpty)
+    testPartialAggregationPlan(query)
   }

   test("mixed aggregates are partially aggregated") {
     val query =
       testData.groupBy('value).agg(count('value), countDistinct('key)).queryExecution.analyzed
-    val planned = HashAggregation(query)
-    assert(planned.nonEmpty)
+    testPartialAggregationPlan(query)
   }

   test("sizeInBytes estimation of limit operator for broadcast hash join optimization") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala
new file mode 100644
index 0000000000000..7b75f755918c1
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.expressions.IsNull +import org.apache.spark.sql.test.TestSQLContext + +class RowFormatConvertersSuite extends SparkPlanTest { + + private def getConverters(plan: SparkPlan): Seq[SparkPlan] = plan.collect { + case c: ConvertToUnsafe => c + case c: ConvertToSafe => c + } + + private val outputsSafe = ExternalSort(Nil, false, PhysicalRDD(Seq.empty, null)) + assert(!outputsSafe.outputsUnsafeRows) + private val outputsUnsafe = UnsafeExternalSort(Nil, false, PhysicalRDD(Seq.empty, null)) + assert(outputsUnsafe.outputsUnsafeRows) + + test("planner should insert unsafe->safe conversions when required") { + val plan = Limit(10, outputsUnsafe) + val preparedPlan = TestSQLContext.prepareForExecution.execute(plan) + assert(preparedPlan.children.head.isInstanceOf[ConvertToSafe]) + } + + test("filter can process unsafe rows") { + val plan = Filter(IsNull(null), outputsUnsafe) + val preparedPlan = TestSQLContext.prepareForExecution.execute(plan) + assert(getConverters(preparedPlan).isEmpty) + assert(preparedPlan.outputsUnsafeRows) + } + + test("filter can process safe rows") { + val plan = Filter(IsNull(null), outputsSafe) + val preparedPlan = TestSQLContext.prepareForExecution.execute(plan) + assert(getConverters(preparedPlan).isEmpty) + assert(!preparedPlan.outputsUnsafeRows) + } + + test("execute() fails an assertion if inputs rows are of different formats") { + val e = intercept[AssertionError] { + Union(Seq(outputsSafe, outputsUnsafe)).execute() + } + assert(e.getMessage.contains("format")) + } + + test("union requires all of its input rows' formats to agree") { + val plan = Union(Seq(outputsSafe, outputsUnsafe)) + assert(plan.canProcessSafeRows && plan.canProcessUnsafeRows) + val preparedPlan = TestSQLContext.prepareForExecution.execute(plan) + assert(preparedPlan.outputsUnsafeRows) + } + + test("union can process safe rows") { + val plan = Union(Seq(outputsSafe, outputsSafe)) + val preparedPlan = TestSQLContext.prepareForExecution.execute(plan) + assert(!preparedPlan.outputsUnsafeRows) + } + + test("union can process unsafe rows") { + val plan = Union(Seq(outputsUnsafe, outputsUnsafe)) + val preparedPlan = TestSQLContext.prepareForExecution.execute(plan) + assert(preparedPlan.outputsUnsafeRows) + } + + test("round trip with ConvertToUnsafe and ConvertToSafe") { + val input = Seq(("hello", 1), ("world", 2)) + checkAnswer( + TestSQLContext.createDataFrame(input), + plan => ConvertToSafe(ConvertToUnsafe(plan)), + input.map(Row.fromTuple) + ) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala index 
71f6b26bcd01a..4a53fadd7e099 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala @@ -132,8 +132,8 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll expectedSerializerClass: Class[T]): Unit = { executedPlan.foreach { case exchange: Exchange => - val shuffledRDD = exchange.execute().firstParent.asInstanceOf[ShuffledRDD[_, _, _]] - val dependency = shuffledRDD.getDependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] + val shuffledRDD = exchange.execute() + val dependency = shuffledRDD.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] val serializerNotSetMessage = s"Expected $expectedSerializerClass as the serializer of Exchange. " + s"However, the serializer was not set." diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala index 4f4c1f28564cb..7a4baa9e4a49d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala @@ -39,7 +39,7 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { ignore("sort followed by limit should not leak memory") { // TODO: this test is going to fail until we implement a proper iterator interface // with a close() method. - TestSQLContext.sparkContext.conf.set("spark.unsafe.exceptionOnMemoryLeak", "true") + TestSQLContext.sparkContext.conf.set("spark.unsafe.exceptionOnMemoryLeak", "false") checkThatPlansAgree( (1 to 100).map(v => Tuple1(v)).toDF("a"), (child: SparkPlan) => Limit(10, UnsafeExternalSort('a.asc :: Nil, true, child)), @@ -58,7 +58,7 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { sortAnswers = false ) } finally { - TestSQLContext.sparkContext.conf.set("spark.unsafe.exceptionOnMemoryLeak", "true") + TestSQLContext.sparkContext.conf.set("spark.unsafe.exceptionOnMemoryLeak", "false") } } @@ -83,11 +83,7 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { randomDataGenerator <- RandomDataGenerator.forType(dataType, nullable) ) { test(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder") { - val inputData = Seq.fill(1000)(randomDataGenerator()).filter { - case d: Double => !d.isNaN - case f: Float => !java.lang.Float.isNaN(f) - case x => true - } + val inputData = Seq.fill(1000)(randomDataGenerator()) val inputDf = TestSQLContext.createDataFrame( TestSQLContext.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))), StructType(StructField("a", dataType, nullable = true) :: Nil) @@ -95,7 +91,8 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { assert(UnsafeExternalSort.supportsSchema(inputDf.schema)) checkThatPlansAgree( inputDf, - UnsafeExternalSort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 23), + plan => ConvertToSafe( + UnsafeExternalSort(sortOrder, global = true, plan: SparkPlan, testSpillFrequency = 23)), Sort(sortOrder, global = true, _: SparkPlan), sortAnswers = false ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala new file mode 100644 index 0000000000000..a1e1695717e23 --- /dev/null +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
+import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeRowConverter}
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.PlatformDependent
+
+class UnsafeRowSerializerSuite extends SparkFunSuite {
+
+  private def toUnsafeRow(row: Row, schema: Array[DataType]): UnsafeRow = {
+    val internalRow = CatalystTypeConverters.convertToCatalyst(row).asInstanceOf[InternalRow]
+    val rowConverter = new UnsafeRowConverter(schema)
+    val rowSizeInBytes = rowConverter.getSizeRequirement(internalRow)
+    val byteArray = new Array[Byte](rowSizeInBytes)
+    rowConverter.writeRow(
+      internalRow, byteArray, PlatformDependent.BYTE_ARRAY_OFFSET, rowSizeInBytes)
+    val unsafeRow = new UnsafeRow()
+    unsafeRow.pointTo(byteArray, PlatformDependent.BYTE_ARRAY_OFFSET, row.length, rowSizeInBytes)
+    unsafeRow
+  }
+
+  ignore("toUnsafeRow() test helper method") {
+    // This currently doesn't work because the generic getter throws an exception.
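+    // (Editor's illustrative note, not part of the original change.) The typed accessors
+    // exercised by the "basic row serialization" test below bypass the generic getter, so the
+    // same check could also be written as:
+    //   assert(row.getString(0) === unsafeRow.getString(0))
+    //   assert(row.getInt(1) === unsafeRow.getInt(1))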
+ val row = Row("Hello", 123) + val unsafeRow = toUnsafeRow(row, Array(StringType, IntegerType)) + assert(row.getString(0) === unsafeRow.get(0).toString) + assert(row.getInt(1) === unsafeRow.getInt(1)) + } + + test("basic row serialization") { + val rows = Seq(Row("Hello", 1), Row("World", 2)) + val unsafeRows = rows.map(row => toUnsafeRow(row, Array(StringType, IntegerType))) + val serializer = new UnsafeRowSerializer(numFields = 2).newInstance() + val baos = new ByteArrayOutputStream() + val serializerStream = serializer.serializeStream(baos) + for (unsafeRow <- unsafeRows) { + serializerStream.writeKey(0) + serializerStream.writeValue(unsafeRow) + } + serializerStream.close() + val deserializerIter = serializer.deserializeStream( + new ByteArrayInputStream(baos.toByteArray)).asKeyValueIterator + for (expectedRow <- unsafeRows) { + val actualRow = deserializerIter.next().asInstanceOf[(Integer, UnsafeRow)]._2 + assert(expectedRow.getSizeInBytes === actualRow.getSizeInBytes) + assert(expectedRow.getString(0) === actualRow.getString(0)) + assert(expectedRow.getInt(1) === actualRow.getInt(1)) + } + assert(!deserializerIter.hasNext) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/expression/NondeterministicSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/expression/NondeterministicSuite.scala new file mode 100644 index 0000000000000..99e11fd64b2b9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/expression/NondeterministicSuite.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.expression + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions. 
ExpressionEvalHelper +import org.apache.spark.sql.execution.expressions.{SparkPartitionID, MonotonicallyIncreasingID} + +class NondeterministicSuite extends SparkFunSuite with ExpressionEvalHelper { + test("MonotonicallyIncreasingID") { + checkEvaluation(MonotonicallyIncreasingID(), 0) + } + + test("SparkPartitionID") { + checkEvaluation(SparkPartitionID, 0) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 9d9858b1c6151..9dd2220f0967e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -19,7 +19,9 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Projection +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.SparkSqlSerializer +import org.apache.spark.sql.types.{StructField, StructType, IntegerType} import org.apache.spark.util.collection.CompactBuffer @@ -35,13 +37,13 @@ class HashedRelationSuite extends SparkFunSuite { val hashed = HashedRelation(data.iterator, keyProjection) assert(hashed.isInstanceOf[GeneralHashedRelation]) - assert(hashed.get(data(0)) == CompactBuffer[InternalRow](data(0))) - assert(hashed.get(data(1)) == CompactBuffer[InternalRow](data(1))) + assert(hashed.get(data(0)) === CompactBuffer[InternalRow](data(0))) + assert(hashed.get(data(1)) === CompactBuffer[InternalRow](data(1))) assert(hashed.get(InternalRow(10)) === null) val data2 = CompactBuffer[InternalRow](data(2)) data2 += data(2) - assert(hashed.get(data(2)) == data2) + assert(hashed.get(data(2)) === data2) } test("UniqueKeyHashedRelation") { @@ -49,15 +51,40 @@ class HashedRelationSuite extends SparkFunSuite { val hashed = HashedRelation(data.iterator, keyProjection) assert(hashed.isInstanceOf[UniqueKeyHashedRelation]) - assert(hashed.get(data(0)) == CompactBuffer[InternalRow](data(0))) - assert(hashed.get(data(1)) == CompactBuffer[InternalRow](data(1))) - assert(hashed.get(data(2)) == CompactBuffer[InternalRow](data(2))) + assert(hashed.get(data(0)) === CompactBuffer[InternalRow](data(0))) + assert(hashed.get(data(1)) === CompactBuffer[InternalRow](data(1))) + assert(hashed.get(data(2)) === CompactBuffer[InternalRow](data(2))) assert(hashed.get(InternalRow(10)) === null) val uniqHashed = hashed.asInstanceOf[UniqueKeyHashedRelation] - assert(uniqHashed.getValue(data(0)) == data(0)) - assert(uniqHashed.getValue(data(1)) == data(1)) - assert(uniqHashed.getValue(data(2)) == data(2)) - assert(uniqHashed.getValue(InternalRow(10)) == null) + assert(uniqHashed.getValue(data(0)) === data(0)) + assert(uniqHashed.getValue(data(1)) === data(1)) + assert(uniqHashed.getValue(data(2)) === data(2)) + assert(uniqHashed.getValue(InternalRow(10)) === null) + } + + test("UnsafeHashedRelation") { + val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2)) + val buildKey = Seq(BoundReference(0, IntegerType, false)) + val schema = StructType(StructField("a", IntegerType, true) :: Nil) + val hashed = UnsafeHashedRelation(data.iterator, buildKey, schema, 1) + assert(hashed.isInstanceOf[UnsafeHashedRelation]) + + val toUnsafeKey = UnsafeProjection.create(schema) + val unsafeData = data.map(toUnsafeKey(_).copy()).toArray + 
assert(hashed.get(unsafeData(0)) === CompactBuffer[InternalRow](unsafeData(0))) + assert(hashed.get(unsafeData(1)) === CompactBuffer[InternalRow](unsafeData(1))) + assert(hashed.get(toUnsafeKey(InternalRow(10))) === null) + + val data2 = CompactBuffer[InternalRow](unsafeData(2).copy()) + data2 += unsafeData(2).copy() + assert(hashed.get(unsafeData(2)) === data2) + + val hashed2 = SparkSqlSerializer.deserialize(SparkSqlSerializer.serialize(hashed)) + .asInstanceOf[UnsafeHashedRelation] + assert(hashed2.get(unsafeData(0)) === CompactBuffer[InternalRow](unsafeData(0))) + assert(hashed2.get(unsafeData(1)) === CompactBuffer[InternalRow](unsafeData(1))) + assert(hashed2.get(toUnsafeKey(InternalRow(10))) === null) + assert(hashed2.get(unsafeData(2)) === data2) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala new file mode 100644 index 0000000000000..927e85a7db3dc --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.{LessThan, Expression} +import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest} + + +class SemiJoinSuite extends SparkPlanTest{ + val left = Seq( + (1, 2.0), + (1, 2.0), + (2, 1.0), + (2, 1.0), + (3, 3.0) + ).toDF("a", "b") + + val right = Seq( + (2, 3.0), + (2, 3.0), + (3, 2.0), + (4, 1.0) + ).toDF("c", "d") + + val leftKeys: List[Expression] = 'a :: Nil + val rightKeys: List[Expression] = 'c :: Nil + val condition = Some(LessThan('b, 'd)) + + test("left semi join hash") { + checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => + LeftSemiJoinHash(leftKeys, rightKeys, left, right, condition), + Seq( + (2, 1.0), + (2, 1.0) + ).map(Row.fromTuple)) + } + + test("left semi join BNL") { + checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => + LeftSemiJoinBNL(left, right, condition), + Seq( + (1, 2.0), + (1, 2.0), + (2, 1.0), + (2, 1.0) + ).map(Row.fromTuple)) + } + + test("broadcast left semi join hash") { + checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => + BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, condition), + Seq( + (2, 1.0), + (2, 1.0) + ).map(Row.fromTuple)) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 566a52dc1b784..0f82f13088d39 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -26,6 +26,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils class JDBCSuite extends SparkFunSuite with BeforeAndAfter { val url = "jdbc:h2:mem:testdb0" @@ -46,7 +47,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { import ctx.sql before { - Class.forName("org.h2.Driver") + Utils.classForName("org.h2.Driver") // Extra properties that will be specified for our database. We need these to test // usage of parameters from OPTIONS clause in queries. 
val properties = new Properties() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index d949ef42267ec..84b52ca2c733c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -25,6 +25,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{SaveMode, Row} import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter { val url = "jdbc:h2:mem:testdb2" @@ -41,7 +42,7 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter { import ctx.sql before { - Class.forName("org.h2.Driver") + Utils.classForName("org.h2.Driver") conn = DriverManager.getConnection(url) conn.prepareStatement("create schema test").executeUpdate() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 8204a584179bb..1d04513a44672 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -26,8 +26,8 @@ import org.scalactic.Tolerance._ import org.apache.spark.sql.{QueryTest, Row, SQLConf} import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.json.InferSchema.compatibleType -import org.apache.spark.sql.sources.LogicalRelation import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -1079,28 +1079,23 @@ class JsonSuite extends QueryTest with TestJsonData { } test("SPARK-7565 MapType in JsonRDD") { - val useStreaming = ctx.conf.useJacksonStreamingAPI val oldColumnNameOfCorruptRecord = ctx.conf.columnNameOfCorruptRecord ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed") val schemaWithSimpleMap = StructType( StructField("map", MapType(StringType, IntegerType, true), false) :: Nil) - try{ - for (useStreaming <- List(true, false)) { - ctx.setConf(SQLConf.USE_JACKSON_STREAMING_API, useStreaming) - val temp = Utils.createTempDir().getPath - - val df = ctx.read.schema(schemaWithSimpleMap).json(mapType1) - df.write.mode("overwrite").parquet(temp) - // order of MapType is not defined - assert(ctx.read.parquet(temp).count() == 5) - - val df2 = ctx.read.json(corruptRecords) - df2.write.mode("overwrite").parquet(temp) - checkAnswer(ctx.read.parquet(temp), df2.collect()) - } + try { + val temp = Utils.createTempDir().getPath + + val df = ctx.read.schema(schemaWithSimpleMap).json(mapType1) + df.write.mode("overwrite").parquet(temp) + // order of MapType is not defined + assert(ctx.read.parquet(temp).count() == 5) + + val df2 = ctx.read.json(corruptRecords) + df2.write.mode("overwrite").parquet(temp) + checkAnswer(ctx.read.parquet(temp), df2.collect()) } finally { - ctx.setConf(SQLConf.USE_JACKSON_STREAMING_API, useStreaming) ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala index a2763c78b6450..23df102cd951d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala @@ -24,7 +24,7 @@ import org.apache.parquet.filter2.predicate.{FilterPredicate, Operators} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation -import org.apache.spark.sql.sources.LogicalRelation +import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.types._ import org.apache.spark.sql.{Column, DataFrame, QueryTest, Row, SQLConf} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala index d0ebb11b063f0..4f98776b91160 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala @@ -28,11 +28,11 @@ import org.apache.hadoop.fs.Path import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Literal -import org.apache.spark.sql.sources.PartitioningUtils._ -import org.apache.spark.sql.sources.{LogicalRelation, Partition, PartitionSpec} +import org.apache.spark.sql.execution.datasources.{LogicalRelation, PartitionSpec, Partition, PartitioningUtils} import org.apache.spark.sql.types._ import org.apache.spark.sql._ import org.apache.spark.unsafe.types.UTF8String +import PartitioningUtils._ // The data where the partitioning key exists only in the directory structure. case class ParquetData(intField: Int, stringField: String) @@ -447,7 +447,12 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { (1 to 10).map(i => (i, i.toString)).toDF("intField", "stringField"), makePartitionDir(base, defaultPartitionName, "pi" -> 2)) - sqlContext.read.format("parquet").load(base.getCanonicalPath).registerTempTable("t") + sqlContext + .read + .option("mergeSchema", "true") + .format("parquet") + .load(base.getCanonicalPath) + .registerTempTable("t") withTempTable("t") { checkAnswer( @@ -583,4 +588,15 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { Seq("a", "a, b"), Seq("file:/tmp/foo/a=1", "file:/tmp/foo/a=1/b=foo"))) } + + test("Parallel partition discovery") { + withTempPath { dir => + withSQLConf(SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD.key -> "1") { + val path = dir.getCanonicalPath + val df = sqlContext.range(5).select('id as 'a, 'id as 'b, 'id as 'c).coalesce(1) + df.write.partitionBy("b", "c").parquet(path) + checkAnswer(sqlContext.read.parquet(path), df) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala index a71088430bfd5..1907e643c85dd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala @@ -22,6 +22,7 @@ import java.io.{File, IOException} import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.execution.datasources.DDLException import org.apache.spark.util.Utils class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala
index 296b0d6f74a0c..3cbf5467b253a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.sources

 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.execution.datasources.ResolvedDataSource

 class ResolvedDataSourceSuite extends SparkFunSuite {

diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
index 301aa5a6411e2..39b31523e07cb 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
@@ -417,7 +417,7 @@ object ServerMode extends Enumeration {
 }

 abstract class HiveThriftJdbcTest extends HiveThriftServer2Test {
-  Class.forName(classOf[HiveDriver].getCanonicalName)
+  Utils.classForName(classOf[HiveDriver].getCanonicalName)

   private def jdbcUri = if (mode == ServerMode.http) {
     s"""jdbc:hive2://localhost:$serverPort/
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index c884c399281a8..b12b3838e615c 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -115,6 +115,13 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
    // This test is totally fine except that it includes wrong queries and expects errors, but error
    // message format in Hive and Spark SQL differ. Should workaround this later.
    "udf_to_unix_timestamp",
+    // we can cast dates like '2015-03-18' to a timestamp and extract the seconds.
+    // Hive returns null for second('2015-03-18')
+    "udf_second",
+    // we can cast dates like '2015-03-18' to a timestamp and extract the minutes.
+    // Hive returns null for minute('2015-03-18')
+    "udf_minute",
+
    // Cant run without local map/reduce.
    "index_auto_update",

@@ -221,9 +228,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
    "udf_when",
    "udf_case",

-    // Needs constant object inspectors
-    "udf_round",
-
    // the table src(key INT, value STRING) is not the same as HIVE unittest. In Hive
    // is src(key STRING, value STRING), and in the reflect.q, it failed in
    // Integer.valueOf, which expect the first argument passed as STRING type not INT.
@@ -257,7 +261,10 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
    // Spark SQL use Long for TimestampType, lose the precision under 1us
    "timestamp_1",
    "timestamp_2",
-    "timestamp_udf"
+    "timestamp_udf",
+
+    // Unlike Hive, we do support log base in (0, 1.0], therefore disable this
+    "udf7"
  )

  /**
@@ -819,19 +826,18 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
    "udf2",
    "udf5",
    "udf6",
-    // "udf7", turn this on after we figure out null vs nan vs infinity
    "udf8",
    "udf9",
    "udf_10_trims",
    "udf_E",
    "udf_PI",
    "udf_abs",
-    // "udf_acos", turn this on after we figure out null vs nan vs infinity
+    "udf_acos",
    "udf_add",
    "udf_array",
    "udf_array_contains",
    "udf_ascii",
-    // "udf_asin", turn this on after we figure out null vs nan vs infinity
+    "udf_asin",
    "udf_atan",
    "udf_avg",
    "udf_bigint",
@@ -895,7 +901,6 @@
    "udf_lpad",
    "udf_ltrim",
    "udf_map",
-    "udf_minute",
    "udf_modulo",
    "udf_month",
    "udf_named_struct",
@@ -919,10 +924,9 @@
    "udf_repeat",
    "udf_rlike",
    "udf_round",
-    // "udf_round_3", TODO: FIX THIS failed due to cast exception
+    "udf_round_3",
    "udf_rpad",
    "udf_rtrim",
-    "udf_second",
    "udf_sign",
    "udf_sin",
    "udf_smallint",
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala
index 31a49a3683338..24a758f53170a 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala
@@ -833,6 +833,7 @@
    "windowing_adjust_rowcontainer_sz"
  )

+  // Only run those query tests in the realWhiteList (do not try other ignored query files).
  override def testCases: Seq[(String, File)] = super.testCases.filter {
    case (name, _) => realWhiteList.contains(name)
  }
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala
index f458567e5d7ea..1fe4fe9629c02 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala
@@ -17,6 +17,8 @@

 package org.apache.spark.sql.hive.execution

+import java.io.File
+
 import org.apache.spark.sql.SQLConf
 import org.apache.spark.sql.hive.test.TestHive

@@ -159,4 +161,9 @@ class SortMergeCompatibilitySuite extends HiveCompatibilitySuite {
    "join_reorder4",
    "join_star"
  )
+
+  // Only run those query tests in the realWhiteList (do not try other ignored query files).
+ override def testCases: Seq[(String, File)] = super.testCases.filter { + case (name, _) => realWhiteList.contains(name) + } } diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index a17546d706248..b00f320318be0 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -133,7 +133,6 @@ - src/test/scala compatibility/src/test/scala diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 4684d48aff889..4cdb83c5116f9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -44,9 +44,9 @@ import org.apache.spark.sql.catalyst.ParserDialect import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUDFs, SetCommand} +import org.apache.spark.sql.execution.datasources.{PreWriteCheck, PreInsertCastAndRename, DataSourceStrategy} import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.hive.execution.{DescribeHiveTableCommand, HiveNativeCommand} -import org.apache.spark.sql.sources.DataSourceStrategy import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -384,11 +384,11 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { catalog.PreInsertionCasts :: ExtractPythonUDFs :: ResolveHiveWindowFunction :: - sources.PreInsertCastAndRename :: + PreInsertCastAndRename :: Nil override val extendedCheckRules = Seq( - sources.PreWriteCheck(catalog) + PreWriteCheck(catalog) ) } @@ -451,6 +451,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { DataSinks, Scripts, HashAggregation, + Aggregation, LeftSemiJoin, HashJoin, BasicOperators, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 4b7a782c805a0..0a2121c955871 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive +import scala.collection.JavaConversions._ + import com.google.common.base.Objects import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache} @@ -28,6 +30,7 @@ import org.apache.hadoop.hive.ql.metadata._ import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.spark.Logging +import org.apache.spark.sql.{AnalysisException, SQLContext, SaveMode} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{Catalog, MultiInstanceRelation, OverrideCatalog} import org.apache.spark.sql.catalyst.expressions._ @@ -35,14 +38,12 @@ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.execution.datasources +import org.apache.spark.sql.execution.datasources.{Partition => ParquetPartition, PartitionSpec, CreateTableUsingAsSelect, ResolvedDataSource, LogicalRelation} import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.parquet.ParquetRelation2 -import org.apache.spark.sql.sources.{CreateTableUsingAsSelect, LogicalRelation, Partition => ParquetPartition, PartitionSpec, ResolvedDataSource} import org.apache.spark.sql.types._ -import 
org.apache.spark.sql.{AnalysisException, SQLContext, SaveMode, sources} -/* Implicit conversions */ -import scala.collection.JavaConversions._ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: HiveContext) extends Catalog with Logging { @@ -278,7 +279,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive parquetRelation.paths.toSet == pathsInMetastore.toSet && logical.schema.sameType(metastoreSchema) && parquetRelation.partitionSpec == partitionSpecInMetastore.getOrElse { - PartitionSpec(StructType(Nil), Array.empty[sources.Partition]) + PartitionSpec(StructType(Nil), Array.empty[datasources.Partition]) } if (useCached) { @@ -301,7 +302,9 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive val result = if (metastoreRelation.hiveQlTable.isPartitioned) { val partitionSchema = StructType.fromAttributes(metastoreRelation.partitionKeys) val partitionColumnDataTypes = partitionSchema.map(_.dataType) - val partitions = metastoreRelation.hiveQlPartitions.map { p => + // We're converting the entire table into ParquetRelation, so predicates to Hive metastore + // are empty. + val partitions = metastoreRelation.getHiveQlPartitions().map { p => val location = p.getLocation val values = InternalRow.fromSeq(p.getValues.zip(partitionColumnDataTypes).map { case (rawValue, dataType) => Cast(Literal(rawValue), dataType).eval(null) @@ -596,8 +599,6 @@ private[hive] case class MetastoreRelation (@transient sqlContext: SQLContext) extends LeafNode with MultiInstanceRelation { - self: Product => - override def equals(other: Any): Boolean = other match { case relation: MetastoreRelation => databaseName == relation.databaseName && @@ -644,32 +645,6 @@ private[hive] case class MetastoreRelation new Table(tTable) } - @transient val hiveQlPartitions: Seq[Partition] = table.getAllPartitions.map { p => - val tPartition = new org.apache.hadoop.hive.metastore.api.Partition - tPartition.setDbName(databaseName) - tPartition.setTableName(tableName) - tPartition.setValues(p.values) - - val sd = new org.apache.hadoop.hive.metastore.api.StorageDescriptor() - tPartition.setSd(sd) - sd.setCols(table.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment))) - - sd.setLocation(p.storage.location) - sd.setInputFormat(p.storage.inputFormat) - sd.setOutputFormat(p.storage.outputFormat) - - val serdeInfo = new org.apache.hadoop.hive.metastore.api.SerDeInfo - sd.setSerdeInfo(serdeInfo) - serdeInfo.setSerializationLib(p.storage.serde) - - val serdeParameters = new java.util.HashMap[String, String]() - serdeInfo.setParameters(serdeParameters) - table.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) } - p.storage.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) } - - new Partition(hiveQlTable, tPartition) - } - @transient override lazy val statistics: Statistics = Statistics( sizeInBytes = { val totalSize = hiveQlTable.getParameters.get(StatsSetupConst.TOTAL_SIZE) @@ -690,6 +665,34 @@ private[hive] case class MetastoreRelation } ) + def getHiveQlPartitions(predicates: Seq[Expression] = Nil): Seq[Partition] = { + table.getPartitions(predicates).map { p => + val tPartition = new org.apache.hadoop.hive.metastore.api.Partition + tPartition.setDbName(databaseName) + tPartition.setTableName(tableName) + tPartition.setValues(p.values) + + val sd = new org.apache.hadoop.hive.metastore.api.StorageDescriptor() + tPartition.setSd(sd) + sd.setCols(table.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment))) + 
+ sd.setLocation(p.storage.location) + sd.setInputFormat(p.storage.inputFormat) + sd.setOutputFormat(p.storage.outputFormat) + + val serdeInfo = new org.apache.hadoop.hive.metastore.api.SerDeInfo + sd.setSerdeInfo(serdeInfo) + serdeInfo.setSerializationLib(p.storage.serde) + + val serdeParameters = new java.util.HashMap[String, String]() + serdeInfo.setParameters(serdeParameters) + table.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) } + p.storage.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) } + + new Partition(hiveQlTable, tPartition) + } + } + /** Only compare database and tablename, not alias. */ override def sameResult(plan: LogicalPlan): Boolean = { plan match { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 7fc517b646b20..8518e333e8058 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.trees.CurrentOrigin import org.apache.spark.sql.execution.ExplainCommand -import org.apache.spark.sql.sources.DescribeCommand +import org.apache.spark.sql.execution.datasources.DescribeCommand import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.hive.execution.{HiveNativeCommand, DropTable, AnalyzeTable, HiveScriptIOSchema} @@ -1464,9 +1464,12 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C /* UDFs - Must be last otherwise will preempt built in functions */ case Token("TOK_FUNCTION", Token(name, Nil) :: args) => - UnresolvedFunction(name, args.map(nodeToExpr)) + UnresolvedFunction(name, args.map(nodeToExpr), isDistinct = false) + // Aggregate function with DISTINCT keyword. 
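+    // (Editor's illustrative note, not part of the original change.) An aggregate call written
+    // as count(DISTINCT key) arrives from the Hive parser as a TOK_FUNCTIONDI node, so the case
+    // below resolves it with isDistinct = true, unlike the plain TOK_FUNCTION case above.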
+ case Token("TOK_FUNCTIONDI", Token(name, Nil) :: args) => + UnresolvedFunction(name, args.map(nodeToExpr), isDistinct = true) case Token("TOK_FUNCTIONSTAR", Token(name, Nil) :: args) => - UnresolvedFunction(name, UnresolvedStar(None) :: Nil) + UnresolvedFunction(name, UnresolvedStar(None) :: Nil, isDistinct = false) /* Literals */ case Token("TOK_NULL", Nil) => Literal.create(null, NullType) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala index d08c594151654..a357bb39ca7fd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala @@ -27,6 +27,7 @@ import scala.reflect.ClassTag import com.esotericsoftware.kryo.Kryo import com.esotericsoftware.kryo.io.{Input, Output} + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.ql.exec.{UDF, Utilities} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index ed359620a5f7f..a22c3292eff94 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -30,9 +30,9 @@ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand, _} +import org.apache.spark.sql.execution.datasources.{CreateTableUsing, CreateTableUsingAsSelect, DescribeCommand} import org.apache.spark.sql.hive.execution._ import org.apache.spark.sql.parquet.ParquetRelation -import org.apache.spark.sql.sources.{CreateTableUsing, CreateTableUsingAsSelect, DescribeCommand} import org.apache.spark.sql.types.StringType @@ -125,7 +125,7 @@ private[hive] trait HiveStrategies { InterpretedPredicate.create(castedPredicate) } - val partitions = relation.hiveQlPartitions.filter { part => + val partitions = relation.getHiveQlPartitions(pruningPredicates).filter { part => val partitionValues = part.getValues var i = 0 while (i < partitionValues.size()) { @@ -213,7 +213,7 @@ private[hive] trait HiveStrategies { projectList, otherPredicates, identity[Seq[Expression]], - HiveTableScan(_, relation, pruningPredicates.reduceLeftOption(And))(hiveContext)) :: Nil + HiveTableScan(_, relation, pruningPredicates)(hiveContext)) :: Nil case _ => Nil } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index d65d29daacf31..dc355690852bd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -78,9 +78,7 @@ class HadoopTableReader( override def makeRDDForTable(hiveTable: HiveTable): RDD[InternalRow] = makeRDDForTable( hiveTable, - Class.forName( - relation.tableDesc.getSerdeClassName, true, Utils.getContextOrSparkClassLoader) - .asInstanceOf[Class[Deserializer]], + Utils.classForName(relation.tableDesc.getSerdeClassName).asInstanceOf[Class[Deserializer]], filterOpt = None) /** diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala index 0a1d761a52f88..1656587d14835 100644 --- 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala @@ -21,6 +21,7 @@ import java.io.PrintStream import java.util.{Map => JMap} import org.apache.spark.sql.catalyst.analysis.{NoSuchDatabaseException, NoSuchTableException} +import org.apache.spark.sql.catalyst.expressions.Expression private[hive] case class HiveDatabase( name: String, @@ -71,7 +72,12 @@ private[hive] case class HiveTable( def isPartitioned: Boolean = partitionColumns.nonEmpty - def getAllPartitions: Seq[HivePartition] = client.getAllPartitions(this) + def getPartitions(predicates: Seq[Expression]): Seq[HivePartition] = { + predicates match { + case Nil => client.getAllPartitions(this) + case _ => client.getPartitionsByFilter(this, predicates) + } + } // Hive does not support backticks when passing names to the client. def qualifiedName: String = s"$database.$name" @@ -132,6 +138,9 @@ private[hive] trait ClientInterface { /** Returns all partitions for the given table. */ def getAllPartitions(hTable: HiveTable): Seq[HivePartition] + /** Returns partitions filtered by predicates for the given table. */ + def getPartitionsByFilter(hTable: HiveTable, predicates: Seq[Expression]): Seq[HivePartition] + /** Loads a static partition into an existing table. */ def loadPartition( loadPath: String, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala index 9d83ca6c113dc..8adda54754230 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala @@ -17,30 +17,26 @@ package org.apache.spark.sql.hive.client -import java.io.{BufferedReader, InputStreamReader, File, PrintStream} -import java.net.URI -import java.util.{ArrayList => JArrayList, Map => JMap, List => JList, Set => JSet} +import java.io.{File, PrintStream} +import java.util.{Map => JMap} import javax.annotation.concurrent.GuardedBy -import org.apache.spark.util.CircularBuffer - import scala.collection.JavaConversions._ import scala.language.reflectiveCalls import org.apache.hadoop.fs.Path -import org.apache.hadoop.hive.metastore.api.Database import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.metastore.api.{Database, FieldSchema} import org.apache.hadoop.hive.metastore.{TableType => HTableType} -import org.apache.hadoop.hive.metastore.api -import org.apache.hadoop.hive.metastore.api.FieldSchema -import org.apache.hadoop.hive.ql.metadata import org.apache.hadoop.hive.ql.metadata.Hive -import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.ql.processors._ -import org.apache.hadoop.hive.ql.Driver +import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.hadoop.hive.ql.{Driver, metadata} import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.execution.QueryExecutionException +import org.apache.spark.util.{CircularBuffer, Utils} /** @@ -252,10 +248,10 @@ private[hive] class ClientWrapper( } private def toInputFormat(name: String) = - Class.forName(name).asInstanceOf[Class[_ <: org.apache.hadoop.mapred.InputFormat[_, _]]] + Utils.classForName(name).asInstanceOf[Class[_ <: org.apache.hadoop.mapred.InputFormat[_, _]]] private def toOutputFormat(name: String) = - Class.forName(name) + Utils.classForName(name) 
.asInstanceOf[Class[_ <: org.apache.hadoop.hive.ql.io.HiveOutputFormat[_, _]]] private def toQlTable(table: HiveTable): metadata.Table = { @@ -316,6 +312,13 @@ private[hive] class ClientWrapper( shim.getAllPartitions(client, qlTable).map(toHivePartition) } + override def getPartitionsByFilter( + hTable: HiveTable, + predicates: Seq[Expression]): Seq[HivePartition] = withHiveState { + val qlTable = toQlTable(hTable) + shim.getPartitionsByFilter(client, qlTable, predicates).map(toHivePartition) + } + override def listTables(dbName: String): Seq[String] = withHiveState { client.getAllTables(dbName) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 1fa9d278e2a57..956997e5f9dce 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -31,6 +31,11 @@ import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table} import org.apache.hadoop.hive.ql.processors.{CommandProcessor, CommandProcessorFactory} import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.hadoop.hive.serde.serdeConstants + +import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.{StringType, IntegralType} /** * A shim that defines the interface between ClientWrapper and the underlying Hive library used to @@ -61,6 +66,8 @@ private[client] sealed abstract class Shim { def getAllPartitions(hive: Hive, table: Table): Seq[Partition] + def getPartitionsByFilter(hive: Hive, table: Table, predicates: Seq[Expression]): Seq[Partition] + def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor def getDriverResults(driver: Driver): Seq[String] @@ -109,7 +116,7 @@ private[client] sealed abstract class Shim { } -private[client] class Shim_v0_12 extends Shim { +private[client] class Shim_v0_12 extends Shim with Logging { private lazy val startMethod = findStaticMethod( @@ -196,6 +203,17 @@ private[client] class Shim_v0_12 extends Shim { override def getAllPartitions(hive: Hive, table: Table): Seq[Partition] = getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]].toSeq + override def getPartitionsByFilter( + hive: Hive, + table: Table, + predicates: Seq[Expression]): Seq[Partition] = { + // getPartitionsByFilter() doesn't support binary comparison ops in Hive 0.12. + // See HIVE-4888. + logDebug("Hive 0.12 doesn't support predicate pushdown to metastore. 
" + + "Please use Hive 0.13 or higher.") + getAllPartitions(hive, table) + } + override def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor = getCommandProcessorMethod.invoke(null, token, conf).asInstanceOf[CommandProcessor] @@ -267,6 +285,12 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { classOf[Hive], "getAllPartitionsOf", classOf[Table]) + private lazy val getPartitionsByFilterMethod = + findMethod( + classOf[Hive], + "getPartitionsByFilter", + classOf[Table], + classOf[String]) private lazy val getCommandProcessorMethod = findStaticMethod( classOf[CommandProcessorFactory], @@ -288,6 +312,51 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { override def getAllPartitions(hive: Hive, table: Table): Seq[Partition] = getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]].toSeq + /** + * Converts catalyst expression to the format that Hive's getPartitionsByFilter() expects, i.e. + * a string that represents partition predicates like "str_key=\"value\" and int_key=1 ...". + * + * Unsupported predicates are skipped. + */ + def convertFilters(table: Table, filters: Seq[Expression]): String = { + // hive varchar is treated as catalyst string, but hive varchar can't be pushed down. + val varcharKeys = table.getPartitionKeys + .filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME)) + .map(col => col.getName).toSet + + filters.collect { + case op @ BinaryComparison(a: Attribute, Literal(v, _: IntegralType)) => + s"${a.name} ${op.symbol} $v" + case op @ BinaryComparison(Literal(v, _: IntegralType), a: Attribute) => + s"$v ${op.symbol} ${a.name}" + case op @ BinaryComparison(a: Attribute, Literal(v, _: StringType)) + if !varcharKeys.contains(a.name) => + s"""${a.name} ${op.symbol} "$v"""" + case op @ BinaryComparison(Literal(v, _: StringType), a: Attribute) + if !varcharKeys.contains(a.name) => + s""""$v" ${op.symbol} ${a.name}""" + }.mkString(" and ") + } + + override def getPartitionsByFilter( + hive: Hive, + table: Table, + predicates: Seq[Expression]): Seq[Partition] = { + + // Hive getPartitionsByFilter() takes a string that represents partition + // predicates like "str_key=\"value\" and int_key=1 ..." 
+ val filter = convertFilters(table, predicates) + val partitions = + if (filter.isEmpty) { + getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]] + } else { + logDebug(s"Hive metastore filter is '$filter'.") + getPartitionsByFilterMethod.invoke(hive, table, filter).asInstanceOf[JArrayList[Partition]] + } + + partitions.toSeq + } + override def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor = getCommandProcessorMethod.invoke(null, Array(token), conf).asInstanceOf[CommandProcessor] diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index 3d609a66f3664..97fb98199991b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -125,7 +125,7 @@ private[hive] class IsolatedClientLoader( name.contains("log4j") || name.startsWith("org.apache.spark.") || name.startsWith("scala.") || - name.startsWith("com.google") || + (name.startsWith("com.google") && !name.startsWith("com.google.cloud")) || name.startsWith("java.lang.") || name.startsWith("java.net") || sharedPrefixes.exists(name.startsWith) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala index d33da8242cc1d..ba7eb15a1c0c6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala @@ -44,7 +44,7 @@ private[hive] case class HiveTableScan( requestedAttributes: Seq[Attribute], relation: MetastoreRelation, - partitionPruningPred: Option[Expression])( + partitionPruningPred: Seq[Expression])( @transient val context: HiveContext) extends LeafNode { @@ -56,7 +56,7 @@ case class HiveTableScan( // Bind all partition key attribute references in the partition pruning predicate for later // evaluation. 
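    // reduceLeftOption(And) (used in the new version below) folds the predicate sequence into a single
    // conjunction, e.g. Seq(p1, p2, p3) becomes Some(And(And(p1, p2), p3)); it yields None when no
    // pruning predicates were pushed down, in which case no predicate is bound or evaluated at all.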
- private[this] val boundPruningPred = partitionPruningPred.map { pred => + private[this] val boundPruningPred = partitionPruningPred.reduceLeftOption(And).map { pred => require( pred.dataType == BooleanType, s"Data type of predicate $pred must be BooleanType rather than ${pred.dataType}.") @@ -133,7 +133,8 @@ case class HiveTableScan( protected override def doExecute(): RDD[InternalRow] = if (!relation.hiveQlTable.isPartitioned) { hadoopReader.makeRDDForTable(relation.hiveQlTable) } else { - hadoopReader.makeRDDForPartitionedTable(prunePartitions(relation.hiveQlPartitions)) + hadoopReader.makeRDDForPartitionedTable( + prunePartitions(relation.getHiveQlPartitions(partitionPruningPred))) } override def output: Seq[Attribute] = attributes diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index 71fa3e9c33ad9..a47f9a4feb21b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.RunnableCommand +import org.apache.spark.sql.execution.datasources.{ResolvedDataSource, LogicalRelation} import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 0bc8adb16afc0..3259b50acc765 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -36,8 +36,8 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.hive.HiveShim._ @@ -81,7 +81,7 @@ private[hive] class HiveFunctionRegistry(underlying: analysis.FunctionRegistry) } private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) - extends Expression with HiveInspectors with Logging { + extends Expression with HiveInspectors with CodegenFallback with Logging { type UDFType = UDF @@ -146,7 +146,7 @@ private[hive] class DeferredObjectAdapter(oi: ObjectInspector) } private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) - extends Expression with HiveInspectors with Logging { + extends Expression with HiveInspectors with CodegenFallback with Logging { type UDFType = GenericUDF override def deterministic: Boolean = isUDFDeterministic @@ -166,8 +166,8 @@ private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, childr @transient protected lazy val isUDFDeterministic = { - val udfType = function.getClass().getAnnotation(classOf[HiveUDFType]) - (udfType != null && udfType.deterministic()) + val udfType = 
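      // Only UDFs whose class carries the UDFType annotation (aliased as HiveUDFType here) with
      // deterministic = true are treated as deterministic; unannotated UDFs are conservatively
      // treated as non-deterministic, which in turn feeds into foldable below.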
function.getClass.getAnnotation(classOf[HiveUDFType]) + udfType != null && udfType.deterministic() } override def foldable: Boolean = @@ -301,7 +301,7 @@ private[hive] case class HiveWindowFunction( pivotResult: Boolean, isUDAFBridgeRequired: Boolean, children: Seq[Expression]) extends WindowFunction - with HiveInspectors { + with HiveInspectors with Unevaluable { // Hive window functions are based on GenericUDAFResolver2. type UDFType = GenericUDAFResolver2 @@ -330,7 +330,7 @@ private[hive] case class HiveWindowFunction( evaluator.init(GenericUDAFEvaluator.Mode.COMPLETE, inputInspectors) } - def dataType: DataType = + override def dataType: DataType = if (!pivotResult) { inspectorToDataType(returnInspector) } else { @@ -344,10 +344,7 @@ private[hive] case class HiveWindowFunction( } } - def nullable: Boolean = true - - override def eval(input: InternalRow): Any = - throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") + override def nullable: Boolean = true @transient lazy val inputProjection = new InterpretedProjection(children) @@ -406,13 +403,13 @@ private[hive] case class HiveWindowFunction( s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" } - override def newInstance: WindowFunction = + override def newInstance(): WindowFunction = new HiveWindowFunction(funcWrapper, pivotResult, isUDAFBridgeRequired, children) } private[hive] case class HiveGenericUDAF( funcWrapper: HiveFunctionWrapper, - children: Seq[Expression]) extends AggregateExpression + children: Seq[Expression]) extends AggregateExpression1 with HiveInspectors { type UDFType = AbstractGenericUDAFResolver @@ -444,7 +441,7 @@ private[hive] case class HiveGenericUDAF( /** It is used as a wrapper for the hive functions which uses UDAF interface */ private[hive] case class HiveUDAF( funcWrapper: HiveFunctionWrapper, - children: Seq[Expression]) extends AggregateExpression + children: Seq[Expression]) extends AggregateExpression1 with HiveInspectors { type UDFType = UDAF @@ -476,7 +473,7 @@ private[hive] case class HiveUDAF( /** * Converts a Hive Generic User Defined Table Generating Function (UDTF) to a - * [[catalyst.expressions.Generator Generator]]. Note that the semantics of Generators do not allow + * [[Generator]]. Note that the semantics of Generators do not allow * Generators to maintain state in between input rows. Thus UDTFs that rely on partitioning * dependent operations like calls to `close()` before producing output will not operate the same as * in Hive. 
However, in practice this should not affect compatibility for most sane UDTFs @@ -488,7 +485,7 @@ private[hive] case class HiveUDAF( private[hive] case class HiveGenericUDTF( funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) - extends Generator with HiveInspectors { + extends Generator with HiveInspectors with CodegenFallback { @transient protected lazy val function: GenericUDTF = { @@ -553,9 +550,9 @@ private[hive] case class HiveGenericUDTF( private[hive] case class HiveUDAFFunction( funcWrapper: HiveFunctionWrapper, exprs: Seq[Expression], - base: AggregateExpression, + base: AggregateExpression1, isUDAFBridgeRequired: Boolean = false) - extends AggregateFunction + extends AggregateFunction1 with HiveInspectors { def this() = this(null, null, null) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala index 250e73a4dba92..ddd5d24717add 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala @@ -41,10 +41,10 @@ private[orc] object OrcFilters extends Logging { private def buildSearchArgument(expression: Filter, builder: Builder): Option[Builder] = { def newBuilder = SearchArgument.FACTORY.newBuilder() - def isSearchableLiteral(value: Any) = value match { + def isSearchableLiteral(value: Any): Boolean = value match { // These are types recognized by the `SearchArgumentImpl.BuilderImpl.boxLiteral()` method. - case _: String | _: Long | _: Double | _: DateWritable | _: HiveDecimal | _: HiveChar | - _: HiveVarchar | _: Byte | _: Short | _: Integer | _: Float => true + case _: String | _: Long | _: Double | _: Byte | _: Short | _: Integer | _: Float => true + case _: DateWritable | _: HiveDecimal | _: HiveChar | _: HiveVarchar => true case _ => false } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 48d35a60a759b..de63ee56dd8e6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -37,6 +37,7 @@ import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.rdd.{HadoopRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.datasources.PartitionSpec import org.apache.spark.sql.hive.{HiveContext, HiveInspectors, HiveMetastoreTypes, HiveShim} import org.apache.spark.sql.sources.{Filter, _} import org.apache.spark.sql.types.StructType diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 0f217bc66869f..3662a4352f55d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -21,6 +21,7 @@ import java.io.File import java.util.{Set => JavaSet} import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.ql.exec.FunctionRegistry import org.apache.hadoop.hive.ql.io.avro.{AvroContainerInputFormat, AvroContainerOutputFormat} import org.apache.hadoop.hive.ql.metadata.Table @@ -87,7 +88,9 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { /** Sets up the system initially 
or after a RESET command */ protected override def configure(): Map[String, String] = - temporaryConfig ++ Map("hive.metastore.warehouse.dir" -> warehousePath.toString) + temporaryConfig ++ Map( + ConfVars.METASTOREWAREHOUSE.varname -> warehousePath.toString, + ConfVars.METASTORE_INTEGER_JDO_PUSHDOWN.varname -> "true") val testTempDir = Utils.createTempDir() diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java index c4828c4717643..741a3cd31c603 100644 --- a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java +++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java @@ -61,7 +61,9 @@ public void setUp() throws IOException { @After public void tearDown() throws IOException { // Clean up tables. - hc.sql("DROP TABLE IF EXISTS window_table"); + if (hc != null) { + hc.sql("DROP TABLE IF EXISTS window_table"); + } } @Test diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java new file mode 100644 index 0000000000000..5c9d0e97a99c6 --- /dev/null +++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package test.org.apache.spark.sql.hive.aggregate; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.spark.sql.Row; +import org.apache.spark.sql.expressions.aggregate.MutableAggregationBuffer; +import org.apache.spark.sql.expressions.aggregate.UserDefinedAggregateFunction; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +public class MyDoubleAvg extends UserDefinedAggregateFunction { + + private StructType _inputDataType; + + private StructType _bufferSchema; + + private DataType _returnDataType; + + public MyDoubleAvg() { + List inputfields = new ArrayList(); + inputfields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true)); + _inputDataType = DataTypes.createStructType(inputfields); + + List bufferFields = new ArrayList(); + bufferFields.add(DataTypes.createStructField("bufferSum", DataTypes.DoubleType, true)); + bufferFields.add(DataTypes.createStructField("bufferCount", DataTypes.LongType, true)); + _bufferSchema = DataTypes.createStructType(bufferFields); + + _returnDataType = DataTypes.DoubleType; + } + + @Override public StructType inputSchema() { + return _inputDataType; + } + + @Override public StructType bufferSchema() { + return _bufferSchema; + } + + @Override public DataType returnDataType() { + return _returnDataType; + } + + @Override public boolean deterministic() { + return true; + } + + @Override public void initialize(MutableAggregationBuffer buffer) { + buffer.update(0, null); + buffer.update(1, 0L); + } + + @Override public void update(MutableAggregationBuffer buffer, Row input) { + if (!input.isNullAt(0)) { + if (buffer.isNullAt(0)) { + buffer.update(0, input.getDouble(0)); + buffer.update(1, 1L); + } else { + Double newValue = input.getDouble(0) + buffer.getDouble(0); + buffer.update(0, newValue); + buffer.update(1, buffer.getLong(1) + 1L); + } + } + } + + @Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) { + if (!buffer2.isNullAt(0)) { + if (buffer1.isNullAt(0)) { + buffer1.update(0, buffer2.getDouble(0)); + buffer1.update(1, buffer2.getLong(1)); + } else { + Double newValue = buffer2.getDouble(0) + buffer1.getDouble(0); + buffer1.update(0, newValue); + buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1)); + } + } + } + + @Override public Object evaluate(Row buffer) { + if (buffer.isNullAt(0)) { + return null; + } else { + return buffer.getDouble(0) / buffer.getLong(1) + 100.0; + } + } +} + diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java new file mode 100644 index 0000000000000..1d4587a27c787 --- /dev/null +++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql.hive.aggregate; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.spark.sql.expressions.aggregate.MutableAggregationBuffer; +import org.apache.spark.sql.expressions.aggregate.UserDefinedAggregateFunction; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.Row; + +public class MyDoubleSum extends UserDefinedAggregateFunction { + + private StructType _inputDataType; + + private StructType _bufferSchema; + + private DataType _returnDataType; + + public MyDoubleSum() { + List inputfields = new ArrayList(); + inputfields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true)); + _inputDataType = DataTypes.createStructType(inputfields); + + List bufferFields = new ArrayList(); + bufferFields.add(DataTypes.createStructField("bufferDouble", DataTypes.DoubleType, true)); + _bufferSchema = DataTypes.createStructType(bufferFields); + + _returnDataType = DataTypes.DoubleType; + } + + @Override public StructType inputSchema() { + return _inputDataType; + } + + @Override public StructType bufferSchema() { + return _bufferSchema; + } + + @Override public DataType returnDataType() { + return _returnDataType; + } + + @Override public boolean deterministic() { + return true; + } + + @Override public void initialize(MutableAggregationBuffer buffer) { + buffer.update(0, null); + } + + @Override public void update(MutableAggregationBuffer buffer, Row input) { + if (!input.isNullAt(0)) { + if (buffer.isNullAt(0)) { + buffer.update(0, input.getDouble(0)); + } else { + Double newValue = input.getDouble(0) + buffer.getDouble(0); + buffer.update(0, newValue); + } + } + } + + @Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) { + if (!buffer2.isNullAt(0)) { + if (buffer1.isNullAt(0)) { + buffer1.update(0, buffer2.getDouble(0)); + } else { + Double newValue = buffer2.getDouble(0) + buffer1.getDouble(0); + buffer1.update(0, newValue); + } + } + } + + @Override public Object evaluate(Row buffer) { + if (buffer.isNullAt(0)) { + return null; + } else { + return buffer.getDouble(0); + } + } +} diff --git a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for CUBE #1-0-63b61fb3f0e74226001ad279be440864 b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for CUBE #1-0-63b61fb3f0e74226001ad279be440864 new file mode 100644 index 0000000000000..dac1b84b916d7 --- /dev/null +++ b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for CUBE #1-0-63b61fb3f0e74226001ad279be440864 @@ -0,0 +1,6 @@ +500 NULL 0 +91 0 1 +84 1 1 +105 2 1 +113 3 1 +107 4 1 diff --git a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for CUBE #2-0-7a511f02a16f0af4f810b1666cfcd896 b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for CUBE #2-0-7a511f02a16f0af4f810b1666cfcd896 new file mode 100644 index 0000000000000..c7cb747c0a659 --- /dev/null +++ b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for CUBE 
#2-0-7a511f02a16f0af4f810b1666cfcd896 @@ -0,0 +1,10 @@ +1 NULL -3 2 +1 NULL -1 2 +1 NULL 3 2 +1 NULL 4 2 +1 NULL 5 2 +1 NULL 6 2 +1 NULL 12 2 +1 NULL 14 2 +1 NULL 15 2 +1 NULL 22 2 diff --git a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for GroupingSet-0-8c14c24670a4b06c440346277ce9cf1c b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for GroupingSet-0-8c14c24670a4b06c440346277ce9cf1c new file mode 100644 index 0000000000000..c7cb747c0a659 --- /dev/null +++ b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for GroupingSet-0-8c14c24670a4b06c440346277ce9cf1c @@ -0,0 +1,10 @@ +1 NULL -3 2 +1 NULL -1 2 +1 NULL 3 2 +1 NULL 4 2 +1 NULL 5 2 +1 NULL 6 2 +1 NULL 12 2 +1 NULL 14 2 +1 NULL 15 2 +1 NULL 22 2 diff --git a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #1-0-a78e3dbf242f240249e36b3d3fd0926a b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #1-0-a78e3dbf242f240249e36b3d3fd0926a new file mode 100644 index 0000000000000..dac1b84b916d7 --- /dev/null +++ b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #1-0-a78e3dbf242f240249e36b3d3fd0926a @@ -0,0 +1,6 @@ +500 NULL 0 +91 0 1 +84 1 1 +105 2 1 +113 3 1 +107 4 1 diff --git a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #2-0-bf180c9d1a18f61b9d9f31bb0115cf89 b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #2-0-bf180c9d1a18f61b9d9f31bb0115cf89 new file mode 100644 index 0000000000000..1eea4a9b23687 --- /dev/null +++ b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #2-0-bf180c9d1a18f61b9d9f31bb0115cf89 @@ -0,0 +1,10 @@ +1 0 5 3 +1 0 15 3 +1 0 25 3 +1 0 60 3 +1 0 75 3 +1 0 80 3 +1 0 100 3 +1 0 140 3 +1 0 145 3 +1 0 150 3 diff --git a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #3-0-9257085d123728730be96b6d9fbb84ce b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #3-0-9257085d123728730be96b6d9fbb84ce new file mode 100644 index 0000000000000..1eea4a9b23687 --- /dev/null +++ b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #3-0-9257085d123728730be96b6d9fbb84ce @@ -0,0 +1,10 @@ +1 0 5 3 +1 0 15 3 +1 0 25 3 +1 0 60 3 +1 0 75 3 +1 0 80 3 +1 0 100 3 +1 0 140 3 +1 0 145 3 +1 0 150 3 diff --git a/sql/hive/src/test/resources/golden/udf_unhex-0-50131c0ba7b7a6b65c789a5a8497bada b/sql/hive/src/test/resources/golden/udf_unhex-0-50131c0ba7b7a6b65c789a5a8497bada new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_unhex-0-50131c0ba7b7a6b65c789a5a8497bada @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/udf_unhex-1-11eb3cc5216d5446f4165007203acc47 b/sql/hive/src/test/resources/golden/udf_unhex-1-11eb3cc5216d5446f4165007203acc47 new file mode 100644 index 0000000000000..44b2a42cc26c5 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_unhex-1-11eb3cc5216d5446f4165007203acc47 @@ -0,0 +1 @@ +unhex(str) - Converts hexadecimal argument to binary diff --git a/sql/hive/src/test/resources/golden/udf_unhex-2-a660886085b8651852b9b77934848ae4 b/sql/hive/src/test/resources/golden/udf_unhex-2-a660886085b8651852b9b77934848ae4 new file mode 100644 index 0000000000000..97af3b812a429 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_unhex-2-a660886085b8651852b9b77934848ae4 @@ -0,0 +1,14 @@ +unhex(str) - Converts hexadecimal argument to binary +Performs the inverse operation of HEX(str). 
That is, it interprets +each pair of hexadecimal digits in the argument as a number and +converts it to the byte representation of the number. The +resulting characters are returned as a binary string. + +Example: +> SELECT DECODE(UNHEX('4D7953514C'), 'UTF-8') from src limit 1; +'MySQL' + +The characters in the argument string must be legal hexadecimal +digits: '0' .. '9', 'A' .. 'F', 'a' .. 'f'. If UNHEX() encounters +any nonhexadecimal digits in the argument, it returns NULL. Also, +if there are an odd number of characters a leading 0 is appended. diff --git a/sql/hive/src/test/resources/golden/udf_unhex-3-4b2cf4050af229fde91ab53fd9f3af3e b/sql/hive/src/test/resources/golden/udf_unhex-3-4b2cf4050af229fde91ab53fd9f3af3e new file mode 100644 index 0000000000000..b4a6f2b692227 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_unhex-3-4b2cf4050af229fde91ab53fd9f3af3e @@ -0,0 +1 @@ +MySQL 1267 a -4 diff --git a/sql/hive/src/test/resources/golden/udf_unhex-4-7d3e094f139892ecef17de3fd63ca3c3 b/sql/hive/src/test/resources/golden/udf_unhex-4-7d3e094f139892ecef17de3fd63ca3c3 new file mode 100644 index 0000000000000..3a67adaf0a9a8 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_unhex-4-7d3e094f139892ecef17de3fd63ca3c3 @@ -0,0 +1 @@ +NULL NULL NULL diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala index efb3f2545db84..c177cbdd991cf 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala @@ -183,13 +183,13 @@ class HiveDataFrameWindowSuite extends QueryTest { } test("aggregation and range betweens with unbounded") { - val df = Seq((1, "1"), (2, "2"), (2, "2"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + val df = Seq((5, "1"), (5, "2"), (4, "2"), (6, "2"), (3, "1"), (2, "2")).toDF("key", "value") df.registerTempTable("window_table") checkAnswer( df.select( $"key", last("value").over( - Window.partitionBy($"value").orderBy($"key").rangeBetween(1, Long.MaxValue)) + Window.partitionBy($"value").orderBy($"key").rangeBetween(-2, -1)) .equalTo("2") .as("last_v"), avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(Long.MinValue, 1)) @@ -203,7 +203,7 @@ class HiveDataFrameWindowSuite extends QueryTest { """SELECT | key, | last_value(value) OVER - | (PARTITION BY value ORDER BY key RANGE 1 preceding) == "2", + | (PARTITION BY value ORDER BY key RANGE BETWEEN 2 preceding and 1 preceding) == "2", | avg(key) OVER | (PARTITION BY value ORDER BY key RANGE BETWEEN unbounded preceding and 1 following), | avg(key) OVER @@ -212,4 +212,47 @@ class HiveDataFrameWindowSuite extends QueryTest { | (PARTITION BY value ORDER BY key RANGE BETWEEN 1 preceding and current row) | FROM window_table""".stripMargin).collect()) } + + test("reverse sliding range frame") { + val df = Seq( + (1, "Thin", "Cell Phone", 6000), + (2, "Normal", "Tablet", 1500), + (3, "Mini", "Tablet", 5500), + (4, "Ultra thin", "Cell Phone", 5500), + (5, "Very thin", "Cell Phone", 6000), + (6, "Big", "Tablet", 2500), + (7, "Bendable", "Cell Phone", 3000), + (8, "Foldable", "Cell Phone", 3000), + (9, "Pro", "Tablet", 4500), + (10, "Pro2", "Tablet", 6500)). + toDF("id", "product", "category", "revenue") + val window = Window. + partitionBy($"category"). + orderBy($"revenue".desc). 
+ rangeBetween(-2000L, 1000L) + checkAnswer( + df.select( + $"id", + avg($"revenue").over(window).cast("int")), + Row(1, 5833) :: Row(2, 2000) :: Row(3, 5500) :: + Row(4, 5833) :: Row(5, 5833) :: Row(6, 2833) :: + Row(7, 3000) :: Row(8, 3000) :: Row(9, 5500) :: + Row(10, 6000) :: Nil) + } + + // This is here to illustrate the fact that reverse order also reverses offsets. + test("reverse unbounded range frame") { + val df = Seq(1, 2, 4, 3, 2, 1). + map(Tuple1.apply). + toDF("value") + val window = Window.orderBy($"value".desc) + checkAnswer( + df.select( + $"value", + sum($"value").over(window.rangeBetween(Long.MinValue, 1)), + sum($"value").over(window.rangeBetween(1, Long.MaxValue))), + Row(1, 13, null) :: Row(2, 13, 2) :: Row(4, 7, 9) :: + Row(3, 11, 6) :: Row(2, 13, 2) :: Row(1, 13, null) :: Nil) + + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 917900e5f46dc..72b35959a491b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -53,7 +53,7 @@ class HiveSparkSubmitSuite val args = Seq( "--class", SparkSubmitClassLoaderTest.getClass.getName.stripSuffix("$"), "--name", "SparkSubmitClassLoaderTest", - "--master", "local-cluster[2,1,512]", + "--master", "local-cluster[2,1,1024]", "--jars", jarsString, unusedJar.toString, "SparkSubmitClassA", "SparkSubmitClassB") runSparkSubmit(args) @@ -64,7 +64,7 @@ class HiveSparkSubmitSuite val args = Seq( "--class", SparkSQLConfTest.getClass.getName.stripSuffix("$"), "--name", "SparkSQLConfTest", - "--master", "local-cluster[2,1,512]", + "--master", "local-cluster[2,1,1024]", unusedJar.toString) runSparkSubmit(args) } @@ -120,8 +120,8 @@ object SparkSubmitClassLoaderTest extends Logging { logInfo("Testing load classes at the driver side.") // First, we load classes at driver side. 
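    // Utils.classForName is assumed here to wrap Class.forName with Spark's context-or-Spark class
    // loader (its usual implementation, which this diff does not show), so the replacement keeps the
    // same class-loading behaviour as the explicit calls it replaces while staying consistent with the
    // rest of the code base.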
try { - Class.forName(args(0), true, Thread.currentThread().getContextClassLoader) - Class.forName(args(1), true, Thread.currentThread().getContextClassLoader) + Utils.classForName(args(0)) + Utils.classForName(args(1)) } catch { case t: Throwable => throw new Exception("Could not load user class from jar:\n", t) @@ -131,8 +131,8 @@ object SparkSubmitClassLoaderTest extends Logging { val result = df.mapPartitions { x => var exception: String = null try { - Class.forName(args(0), true, Thread.currentThread().getContextClassLoader) - Class.forName(args(1), true, Thread.currentThread().getContextClassLoader) + Utils.classForName(args(0)) + Utils.classForName(args(1)) } catch { case t: Throwable => exception = t + "\n" + t.getStackTraceString diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index d910af22c3dd1..e403f32efaf91 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -28,12 +28,12 @@ import org.apache.hadoop.mapred.InvalidInputException import org.apache.spark.Logging import org.apache.spark.sql._ +import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.hive.client.{HiveTable, ManagedTable} import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ import org.apache.spark.sql.parquet.ParquetRelation2 -import org.apache.spark.sql.sources.LogicalRelation import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ import org.apache.spark.util.Utils diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala new file mode 100644 index 0000000000000..0efcf80bd4ea7 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.client + +import scala.collection.JavaConversions._ + +import org.apache.hadoop.hive.metastore.api.FieldSchema +import org.apache.hadoop.hive.serde.serdeConstants + +import org.apache.spark.{Logging, SparkFunSuite} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ + +/** + * A set of tests for the filter conversion logic used when pushing partition pruning into the + * metastore + */ +class FiltersSuite extends SparkFunSuite with Logging { + private val shim = new Shim_v0_13 + + private val testTable = new org.apache.hadoop.hive.ql.metadata.Table("default", "test") + private val varCharCol = new FieldSchema() + varCharCol.setName("varchar") + varCharCol.setType(serdeConstants.VARCHAR_TYPE_NAME) + testTable.setPartCols(varCharCol :: Nil) + + filterTest("string filter", + (a("stringcol", StringType) > Literal("test")) :: Nil, + "stringcol > \"test\"") + + filterTest("string filter backwards", + (Literal("test") > a("stringcol", StringType)) :: Nil, + "\"test\" > stringcol") + + filterTest("int filter", + (a("intcol", IntegerType) === Literal(1)) :: Nil, + "intcol = 1") + + filterTest("int filter backwards", + (Literal(1) === a("intcol", IntegerType)) :: Nil, + "1 = intcol") + + filterTest("int and string filter", + (Literal(1) === a("intcol", IntegerType)) :: (Literal("a") === a("strcol", IntegerType)) :: Nil, + "1 = intcol and \"a\" = strcol") + + filterTest("skip varchar", + (Literal("") === a("varchar", StringType)) :: Nil, + "") + + private def filterTest(name: String, filters: Seq[Expression], result: String) = { + test(name){ + val converted = shim.convertFilters(testTable, filters) + if (converted != result) { + fail( + s"Expected filters ${filters.mkString(",")} to convert to '$result' but got '$converted'") + } + } + } + + private def a(name: String, dataType: DataType) = AttributeReference(name, dataType)() +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index d52e162acbd04..3eb127e23d486 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -20,7 +20,9 @@ package org.apache.spark.sql.hive.client import java.io.File import org.apache.spark.{Logging, SparkFunSuite} +import org.apache.spark.sql.catalyst.expressions.{NamedExpression, Literal, AttributeReference, EqualTo} import org.apache.spark.sql.catalyst.util.quietly +import org.apache.spark.sql.types.IntegerType import org.apache.spark.util.Utils /** @@ -151,6 +153,12 @@ class VersionsSuite extends SparkFunSuite with Logging { client.getAllPartitions(client.getTable("default", "src_part")) } + test(s"$version: getPartitionsByFilter") { + client.getPartitionsByFilter(client.getTable("default", "src_part"), Seq(EqualTo( + AttributeReference("key", IntegerType, false)(NamedExpression.newExprId), + Literal(1)))) + } + test(s"$version: loadPartition") { client.loadPartition( emptyDir, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala new file mode 100644 index 0000000000000..0375eb79add95 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -0,0 +1,507 @@ +/* + * 
Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import org.apache.spark.sql.execution.aggregate.Aggregate2Sort +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.scalatest.BeforeAndAfterAll +import test.org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum} + +class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll { + + override val sqlContext = TestHive + import sqlContext.implicits._ + + var originalUseAggregate2: Boolean = _ + + override def beforeAll(): Unit = { + originalUseAggregate2 = sqlContext.conf.useSqlAggregate2 + sqlContext.sql("set spark.sql.useAggregate2=true") + val data1 = Seq[(Integer, Integer)]( + (1, 10), + (null, -60), + (1, 20), + (1, 30), + (2, 0), + (null, -10), + (2, -1), + (2, null), + (2, null), + (null, 100), + (3, null), + (null, null), + (3, null)).toDF("key", "value") + data1.write.saveAsTable("agg1") + + val data2 = Seq[(Integer, Integer, Integer)]( + (1, 10, -10), + (null, -60, 60), + (1, 30, -30), + (1, 30, 30), + (2, 1, 1), + (null, -10, 10), + (2, -1, null), + (2, 1, 1), + (2, null, 1), + (null, 100, -10), + (3, null, 3), + (null, null, null), + (3, null, null)).toDF("key", "value1", "value2") + data2.write.saveAsTable("agg2") + + val emptyDF = sqlContext.createDataFrame( + sqlContext.sparkContext.emptyRDD[Row], + StructType(StructField("key", StringType) :: StructField("value", IntegerType) :: Nil)) + emptyDF.registerTempTable("emptyTable") + + // Register UDAFs + sqlContext.udaf.register("mydoublesum", new MyDoubleSum) + sqlContext.udaf.register("mydoubleavg", new MyDoubleAvg) + } + + override def afterAll(): Unit = { + sqlContext.sql("DROP TABLE IF EXISTS agg1") + sqlContext.sql("DROP TABLE IF EXISTS agg2") + sqlContext.dropTempTable("emptyTable") + sqlContext.sql(s"set spark.sql.useAggregate2=$originalUseAggregate2") + } + + test("empty table") { + // If there is no GROUP BY clause and the table is empty, we will generate a single row. 
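    // e.g. with no GROUP BY, aggregating the empty table yields exactly one row:
    // COUNT(*) and COUNT(col) come back as 0 and every other aggregate as null.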
+ checkAnswer( + sqlContext.sql( + """ + |SELECT + | AVG(value), + | COUNT(*), + | COUNT(key), + | COUNT(value), + | FIRST(key), + | LAST(value), + | MAX(key), + | MIN(value), + | SUM(key) + |FROM emptyTable + """.stripMargin), + Row(null, 0, 0, 0, null, null, null, null, null) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT + | AVG(value), + | COUNT(*), + | COUNT(key), + | COUNT(value), + | FIRST(key), + | LAST(value), + | MAX(key), + | MIN(value), + | SUM(key), + | COUNT(DISTINCT value) + |FROM emptyTable + """.stripMargin), + Row(null, 0, 0, 0, null, null, null, null, null, 0) :: Nil) + + // If there is a GROUP BY clause and the table is empty, there is no output. + checkAnswer( + sqlContext.sql( + """ + |SELECT + | AVG(value), + | COUNT(*), + | COUNT(value), + | FIRST(value), + | LAST(value), + | MAX(value), + | MIN(value), + | SUM(value), + | COUNT(DISTINCT value) + |FROM emptyTable + |GROUP BY key + """.stripMargin), + Nil) + } + + test("only do grouping") { + checkAnswer( + sqlContext.sql( + """ + |SELECT key + |FROM agg1 + |GROUP BY key + """.stripMargin), + Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT DISTINCT value1, key + |FROM agg2 + """.stripMargin), + Row(10, 1) :: + Row(-60, null) :: + Row(30, 1) :: + Row(1, 2) :: + Row(-10, null) :: + Row(-1, 2) :: + Row(null, 2) :: + Row(100, null) :: + Row(null, 3) :: + Row(null, null) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT value1, key + |FROM agg2 + |GROUP BY key, value1 + """.stripMargin), + Row(10, 1) :: + Row(-60, null) :: + Row(30, 1) :: + Row(1, 2) :: + Row(-10, null) :: + Row(-1, 2) :: + Row(null, 2) :: + Row(100, null) :: + Row(null, 3) :: + Row(null, null) :: Nil) + } + + test("case in-sensitive resolution") { + checkAnswer( + sqlContext.sql( + """ + |SELECT avg(value), kEY - 100 + |FROM agg1 + |GROUP BY Key - 100 + """.stripMargin), + Row(20.0, -99) :: Row(-0.5, -98) :: Row(null, -97) :: Row(10.0, null) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT sum(distinct value1), kEY - 100, count(distinct value1) + |FROM agg2 + |GROUP BY Key - 100 + """.stripMargin), + Row(40, -99, 2) :: Row(0, -98, 2) :: Row(null, -97, 0) :: Row(30, null, 3) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT valUe * key - 100 + |FROM agg1 + |GROUP BY vAlue * keY - 100 + """.stripMargin), + Row(-90) :: + Row(-80) :: + Row(-70) :: + Row(-100) :: + Row(-102) :: + Row(null) :: Nil) + } + + test("test average no key in output") { + checkAnswer( + sqlContext.sql( + """ + |SELECT avg(value) + |FROM agg1 + |GROUP BY key + """.stripMargin), + Row(-0.5) :: Row(20.0) :: Row(null) :: Row(10.0) :: Nil) + } + + test("test average") { + checkAnswer( + sqlContext.sql( + """ + |SELECT key, avg(value) + |FROM agg1 + |GROUP BY key + """.stripMargin), + Row(1, 20.0) :: Row(2, -0.5) :: Row(3, null) :: Row(null, 10.0) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT avg(value), key + |FROM agg1 + |GROUP BY key + """.stripMargin), + Row(20.0, 1) :: Row(-0.5, 2) :: Row(null, 3) :: Row(10.0, null) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT avg(value) + 1.5, key + 10 + |FROM agg1 + |GROUP BY key + 10 + """.stripMargin), + Row(21.5, 11) :: Row(1.0, 12) :: Row(null, 13) :: Row(11.5, null) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT avg(value) FROM agg1 + """.stripMargin), + Row(11.125) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT avg(null) + """.stripMargin), + Row(null) :: Nil) + } + + test("udaf") { + checkAnswer( + 
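      // mydoublesum/mydoubleavg are the MyDoubleSum/MyDoubleAvg test UDAFs registered in beforeAll;
      // MyDoubleAvg deliberately returns avg(input) + 100.0, which is why key 1 shows 120.0 for
      // mydoubleavg(value) next to 20.0 for the built-in avg(value).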
sqlContext.sql( + """ + |SELECT + | key, + | mydoublesum(value + 1.5 * key), + | mydoubleavg(value), + | avg(value - key), + | mydoublesum(value - 1.5 * key), + | avg(value) + |FROM agg1 + |GROUP BY key + """.stripMargin), + Row(1, 64.5, 120.0, 19.0, 55.5, 20.0) :: + Row(2, 5.0, 99.5, -2.5, -7.0, -0.5) :: + Row(3, null, null, null, null, null) :: + Row(null, null, 110.0, null, null, 10.0) :: Nil) + } + + test("non-AlgebraicAggregate aggreguate function") { + checkAnswer( + sqlContext.sql( + """ + |SELECT mydoublesum(value), key + |FROM agg1 + |GROUP BY key + """.stripMargin), + Row(60.0, 1) :: Row(-1.0, 2) :: Row(null, 3) :: Row(30.0, null) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT mydoublesum(value) FROM agg1 + """.stripMargin), + Row(89.0) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT mydoublesum(null) + """.stripMargin), + Row(null) :: Nil) + } + + test("non-AlgebraicAggregate and AlgebraicAggregate aggreguate function") { + checkAnswer( + sqlContext.sql( + """ + |SELECT mydoublesum(value), key, avg(value) + |FROM agg1 + |GROUP BY key + """.stripMargin), + Row(60.0, 1, 20.0) :: + Row(-1.0, 2, -0.5) :: + Row(null, 3, null) :: + Row(30.0, null, 10.0) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT + | mydoublesum(value + 1.5 * key), + | avg(value - key), + | key, + | mydoublesum(value - 1.5 * key), + | avg(value) + |FROM agg1 + |GROUP BY key + """.stripMargin), + Row(64.5, 19.0, 1, 55.5, 20.0) :: + Row(5.0, -2.5, 2, -7.0, -0.5) :: + Row(null, null, 3, null, null) :: + Row(null, null, null, null, 10.0) :: Nil) + } + + test("single distinct column set") { + // DISTINCT is not meaningful with Max and Min, so we just ignore the DISTINCT keyword. + checkAnswer( + sqlContext.sql( + """ + |SELECT + | min(distinct value1), + | sum(distinct value1), + | avg(value1), + | avg(value2), + | max(distinct value1) + |FROM agg2 + """.stripMargin), + Row(-60, 70.0, 101.0/9.0, 5.6, 100.0)) + + checkAnswer( + sqlContext.sql( + """ + |SELECT + | mydoubleavg(distinct value1), + | avg(value1), + | avg(value2), + | key, + | mydoubleavg(value1 - 1), + | mydoubleavg(distinct value1) * 0.1, + | avg(value1 + value2) + |FROM agg2 + |GROUP BY key + """.stripMargin), + Row(120.0, 70.0/3.0, -10.0/3.0, 1, 67.0/3.0 + 100.0, 12.0, 20.0) :: + Row(100.0, 1.0/3.0, 1.0, 2, -2.0/3.0 + 100.0, 10.0, 2.0) :: + Row(null, null, 3.0, 3, null, null, null) :: + Row(110.0, 10.0, 20.0, null, 109.0, 11.0, 30.0) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT + | key, + | mydoubleavg(distinct value1), + | mydoublesum(value2), + | mydoublesum(distinct value1), + | mydoubleavg(distinct value1), + | mydoubleavg(value1) + |FROM agg2 + |GROUP BY key + """.stripMargin), + Row(1, 120.0, -10.0, 40.0, 120.0, 70.0/3.0 + 100.0) :: + Row(2, 100.0, 3.0, 0.0, 100.0, 1.0/3.0 + 100.0) :: + Row(3, null, 3.0, null, null, null) :: + Row(null, 110.0, 60.0, 30.0, 110.0, 110.0) :: Nil) + } + + test("test count") { + checkAnswer( + sqlContext.sql( + """ + |SELECT + | count(value2), + | value1, + | count(*), + | count(1), + | key + |FROM agg2 + |GROUP BY key, value1 + """.stripMargin), + Row(1, 10, 1, 1, 1) :: + Row(1, -60, 1, 1, null) :: + Row(2, 30, 2, 2, 1) :: + Row(2, 1, 2, 2, 2) :: + Row(1, -10, 1, 1, null) :: + Row(0, -1, 1, 1, 2) :: + Row(1, null, 1, 1, 2) :: + Row(1, 100, 1, 1, null) :: + Row(1, null, 2, 2, 3) :: + Row(0, null, 1, 1, null) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT + | count(value2), + | value1, + | count(*), + | count(1), + | key, + | count(DISTINCT abs(value2)) + |FROM 
agg2 + |GROUP BY key, value1 + """.stripMargin), + Row(1, 10, 1, 1, 1, 1) :: + Row(1, -60, 1, 1, null, 1) :: + Row(2, 30, 2, 2, 1, 1) :: + Row(2, 1, 2, 2, 2, 1) :: + Row(1, -10, 1, 1, null, 1) :: + Row(0, -1, 1, 1, 2, 0) :: + Row(1, null, 1, 1, 2, 1) :: + Row(1, 100, 1, 1, null, 1) :: + Row(1, null, 2, 2, 3, 1) :: + Row(0, null, 1, 1, null, 0) :: Nil) + } + + test("error handling") { + sqlContext.sql(s"set spark.sql.useAggregate2=false") + var errorMessage = intercept[AnalysisException] { + sqlContext.sql( + """ + |SELECT + | key, + | sum(value + 1.5 * key), + | mydoublesum(value), + | mydoubleavg(value) + |FROM agg1 + |GROUP BY key + """.stripMargin).collect() + }.getMessage + assert(errorMessage.contains("implemented based on the new Aggregate Function interface")) + + // TODO: once we support Hive UDAF in the new interface, + // we can remove the following two tests. + sqlContext.sql(s"set spark.sql.useAggregate2=true") + errorMessage = intercept[AnalysisException] { + sqlContext.sql( + """ + |SELECT + | key, + | mydoublesum(value + 1.5 * key), + | stddev_samp(value) + |FROM agg1 + |GROUP BY key + """.stripMargin).collect() + }.getMessage + assert(errorMessage.contains("implemented based on the new Aggregate Function interface")) + + // This will fall back to the old aggregate + val newAggregateOperators = sqlContext.sql( + """ + |SELECT + | key, + | sum(value + 1.5 * key), + | stddev_samp(value) + |FROM agg1 + |GROUP BY key + """.stripMargin).queryExecution.executedPlan.collect { + case agg: Aggregate2Sort => agg + } + val message = + "We should fallback to the old aggregation code path if there is any aggregate function " + + "that cannot be converted to the new interface." + assert(newAggregateOperators.isEmpty, message) + + sqlContext.sql(s"set spark.sql.useAggregate2=true") + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index c9dd4c0935a72..efb04bf3d5097 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -22,11 +22,11 @@ import java.io._ import org.scalatest.{BeforeAndAfterAll, GivenWhenThen} import org.apache.spark.{Logging, SparkFunSuite} -import org.apache.spark.sql.sources.DescribeCommand -import org.apache.spark.sql.execution.{SetCommand, ExplainCommand} import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.execution.{SetCommand, ExplainCommand} +import org.apache.spark.sql.execution.datasources.DescribeCommand import org.apache.spark.sql.hive.test.TestHive /** diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 991da2f829ae5..11a843becce69 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -85,6 +85,60 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } } + createQueryTest("SPARK-8976 Wrong Result for Rollup #1", + """ + SELECT count(*) AS cnt, key % 5,GROUPING__ID FROM src group by key%5 WITH ROLLUP + """.stripMargin) + + createQueryTest("SPARK-8976 Wrong Result for Rollup #2", + 
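    // In the matching golden outputs, GROUPING__ID 0 marks the grand-total row produced by
    // ROLLUP/CUBE (e.g. "500 NULL 0" in #1, the sum of the per-(key % 5) counts), while non-zero
    // ids mark the ordinary grouped rows.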
""" + SELECT + count(*) AS cnt, + key % 5 as k1, + key-5 as k2, + GROUPING__ID as k3 + FROM src group by key%5, key-5 + WITH ROLLUP ORDER BY cnt, k1, k2, k3 LIMIT 10 + """.stripMargin) + + createQueryTest("SPARK-8976 Wrong Result for Rollup #3", + """ + SELECT + count(*) AS cnt, + key % 5 as k1, + key-5 as k2, + GROUPING__ID as k3 + FROM (SELECT key, key%2, key - 5 FROM src) t group by key%5, key-5 + WITH ROLLUP ORDER BY cnt, k1, k2, k3 LIMIT 10 + """.stripMargin) + + createQueryTest("SPARK-8976 Wrong Result for CUBE #1", + """ + SELECT count(*) AS cnt, key % 5,GROUPING__ID FROM src group by key%5 WITH CUBE + """.stripMargin) + + createQueryTest("SPARK-8976 Wrong Result for CUBE #2", + """ + SELECT + count(*) AS cnt, + key % 5 as k1, + key-5 as k2, + GROUPING__ID as k3 + FROM (SELECT key, key%2, key - 5 FROM src) t group by key%5, key-5 + WITH CUBE ORDER BY cnt, k1, k2, k3 LIMIT 10 + """.stripMargin) + + createQueryTest("SPARK-8976 Wrong Result for GroupingSet", + """ + SELECT + count(*) AS cnt, + key % 5 as k1, + key-5 as k2, + GROUPING__ID as k3 + FROM (SELECT key, key%2, key - 5 FROM src) t group by key%5, key-5 + GROUPING SETS (key%5, key-5) ORDER BY cnt, k1, k2, k3 LIMIT 10 + """.stripMargin) + createQueryTest("insert table with generator with column name", """ | CREATE TABLE gen_tmp (key Int); diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala index de6a41ce5bfcb..e83a7dc77e329 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala @@ -151,7 +151,7 @@ class PruningSuite extends HiveComparisonTest with BeforeAndAfter { case p @ HiveTableScan(columns, relation, _) => val columnNames = columns.map(_.name) val partValues = if (relation.table.isPartitioned) { - p.prunePartitions(relation.hiveQlPartitions).map(_.getValues) + p.prunePartitions(relation.getHiveQlPartitions()).map(_.getValues) } else { Seq.empty } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 05a1f0094e5e1..03428265422e6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -23,12 +23,12 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.DefaultParserDialect import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries import org.apache.spark.sql.catalyst.errors.DialectException +import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ import org.apache.spark.sql.hive.{HiveContext, HiveQLDialect, MetastoreRelation} import org.apache.spark.sql.parquet.ParquetRelation2 -import org.apache.spark.sql.sources.LogicalRelation import org.apache.spark.sql.types._ case class Nested1(f1: Nested2) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala index 080af5bb23c16..af3f468aaa5e9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala +++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala @@ -41,8 +41,7 @@ class OrcHadoopFsRelationSuite extends HadoopFsRelationTest { .parallelize(for (i <- 1 to 3) yield (i, s"val_$i", p1)) .toDF("a", "b", "p1") .write - .format("orc") - .save(partitionDir.toString) + .orc(partitionDir.toString) } val dataSchemaWithPartition = diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala index 3c2efe329bfd5..d463e8fd626f9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala @@ -49,13 +49,13 @@ class OrcPartitionDiscoverySuite extends QueryTest with BeforeAndAfterAll { def makeOrcFile[T <: Product: ClassTag: TypeTag]( data: Seq[T], path: File): Unit = { - data.toDF().write.format("orc").mode("overwrite").save(path.getCanonicalPath) + data.toDF().write.mode("overwrite").orc(path.getCanonicalPath) } def makeOrcFile[T <: Product: ClassTag: TypeTag]( df: DataFrame, path: File): Unit = { - df.write.format("orc").mode("overwrite").save(path.getCanonicalPath) + df.write.mode("overwrite").orc(path.getCanonicalPath) } protected def withTempTable(tableName: String)(f: => Unit): Unit = { @@ -90,7 +90,7 @@ class OrcPartitionDiscoverySuite extends QueryTest with BeforeAndAfterAll { makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - read.format("orc").load(base.getCanonicalPath).registerTempTable("t") + read.orc(base.getCanonicalPath).registerTempTable("t") withTempTable("t") { checkAnswer( @@ -137,7 +137,7 @@ class OrcPartitionDiscoverySuite extends QueryTest with BeforeAndAfterAll { makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - read.format("orc").load(base.getCanonicalPath).registerTempTable("t") + read.orc(base.getCanonicalPath).registerTempTable("t") withTempTable("t") { checkAnswer( @@ -187,9 +187,8 @@ class OrcPartitionDiscoverySuite extends QueryTest with BeforeAndAfterAll { } read - .format("orc") .option(ConfVars.DEFAULTPARTITIONNAME.varname, defaultPartitionName) - .load(base.getCanonicalPath) + .orc(base.getCanonicalPath) .registerTempTable("t") withTempTable("t") { @@ -230,9 +229,8 @@ class OrcPartitionDiscoverySuite extends QueryTest with BeforeAndAfterAll { } read - .format("orc") .option(ConfVars.DEFAULTPARTITIONNAME.varname, defaultPartitionName) - .load(base.getCanonicalPath) + .orc(base.getCanonicalPath) .registerTempTable("t") withTempTable("t") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index ca131faaeef05..744d462938141 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -63,14 +63,14 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { withOrcFile(data) { file => checkAnswer( - sqlContext.read.format("orc").load(file), + sqlContext.read.orc(file), data.toDF().collect()) } } test("Read/write binary data") { withOrcFile(BinaryData("test".getBytes("utf8")) :: Nil) { file => - val bytes = read.format("orc").load(file).head().getAs[Array[Byte]](0) + val bytes = read.orc(file).head().getAs[Array[Byte]](0) assert(new String(bytes, "utf8") === "test") } } @@ -88,7 +88,7 @@ class 
OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { withOrcFile(data) { file => checkAnswer( - read.format("orc").load(file), + read.orc(file), data.toDF().collect()) } } @@ -158,7 +158,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { withOrcFile(data) { file => checkAnswer( - read.format("orc").load(file), + read.orc(file), Row(Seq.fill(5)(null): _*)) } } @@ -310,7 +310,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { """.stripMargin) val errorMessage = intercept[AnalysisException] { - sqlContext.read.format("orc").load(path) + sqlContext.read.orc(path) }.getMessage assert(errorMessage.contains("Failed to discover schema from ORC files")) @@ -323,7 +323,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { |SELECT key, value FROM single """.stripMargin) - val df = sqlContext.read.format("orc").load(path) + val df = sqlContext.read.orc(path) assert(df.schema === singleRowDF.schema.asNullable) checkAnswer(df, singleRowDF) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala index 5daf691aa8c53..9d76d6503a3e6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala @@ -39,7 +39,7 @@ private[sql] trait OrcTest extends SQLTestUtils { (data: Seq[T]) (f: String => Unit): Unit = { withTempPath { file => - sparkContext.parallelize(data).toDF().write.format("orc").save(file.getCanonicalPath) + sparkContext.parallelize(data).toDF().write.orc(file.getCanonicalPath) f(file.getCanonicalPath) } } @@ -51,7 +51,7 @@ private[sql] trait OrcTest extends SQLTestUtils { protected def withOrcDataFrame[T <: Product: ClassTag: TypeTag] (data: Seq[T]) (f: DataFrame => Unit): Unit = { - withOrcFile(data)(path => f(sqlContext.read.format("orc").load(path))) + withOrcFile(data)(path => f(sqlContext.read.orc(path))) } /** @@ -70,11 +70,11 @@ private[sql] trait OrcTest extends SQLTestUtils { protected def makeOrcFile[T <: Product: ClassTag: TypeTag]( data: Seq[T], path: File): Unit = { - data.toDF().write.format("orc").mode(SaveMode.Overwrite).save(path.getCanonicalPath) + data.toDF().write.mode(SaveMode.Overwrite).orc(path.getCanonicalPath) } protected def makeOrcFile[T <: Product: ClassTag: TypeTag]( df: DataFrame, path: File): Unit = { - df.write.format("orc").mode(SaveMode.Overwrite).save(path.getCanonicalPath) + df.write.mode(SaveMode.Overwrite).orc(path.getCanonicalPath) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index 9d79a4b007d66..82a8daf8b4b09 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -23,12 +23,12 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql._ import org.apache.spark.sql.execution.{ExecutedCommand, PhysicalRDD} +import org.apache.spark.sql.execution.datasources.{InsertIntoDataSource, InsertIntoHadoopFsRelation, LogicalRelation} import org.apache.spark.sql.hive.execution.HiveTableScan import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ import org.apache.spark.sql.parquet.{ParquetRelation2, ParquetTableScan} -import 
org.apache.spark.sql.sources.{InsertIntoDataSource, InsertIntoHadoopFsRelation, LogicalRelation} import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ import org.apache.spark.util.Utils diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index afecf9675e11f..2a8748d913569 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.sources -import scala.collection.JavaConversions._ - import java.io.File +import scala.collection.JavaConversions._ + import com.google.common.io.Files import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path @@ -31,10 +31,12 @@ import org.apache.parquet.hadoop.ParquetOutputCommitter import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql._ +import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ + abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { override lazy val sqlContext: SQLContext = TestHive @@ -132,7 +134,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { test("save()/load() - non-partitioned table - ErrorIfExists") { withTempDir { file => - intercept[RuntimeException] { + intercept[AnalysisException] { testDF.write.format(dataSourceName).mode(SaveMode.ErrorIfExists).save(file.getCanonicalPath) } } @@ -231,7 +233,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { test("save()/load() - partitioned table - ErrorIfExists") { withTempDir { file => - intercept[RuntimeException] { + intercept[AnalysisException] { partitionedTestDF.write .format(dataSourceName) .mode(SaveMode.ErrorIfExists) @@ -694,7 +696,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { // This should only complain that the destination directory already exists, rather than file // "empty" is not a Parquet file. assert { - intercept[RuntimeException] { + intercept[AnalysisException] { df.write.format("parquet").mode(SaveMode.ErrorIfExists).save(path) }.getMessage.contains("already exists") } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index 5279331c9e122..65d4e933bf8e9 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -48,6 +48,8 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) // Reload properties for the checkpoint application since user wants to set a reload property // or spark had changed its value and user wants to set it back. 
val propertiesToReload = List( + "spark.driver.host", + "spark.driver.port", "spark.master", "spark.yarn.keytab", "spark.yarn.principal") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index ec49d0f42d122..92438f1b1fbf7 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -192,11 +192,8 @@ class StreamingContext private[streaming] ( None } - /** Register streaming source to metrics system */ + /* Initializing a streamingSource to register metrics */ private val streamingSource = new StreamingSource(this) - assert(env != null) - assert(env.metricsSystem != null) - env.metricsSystem.registerSource(streamingSource) private var state: StreamingContextState = INITIALIZED @@ -204,6 +201,8 @@ class StreamingContext private[streaming] ( private var shutdownHookRef: AnyRef = _ + conf.getOption("spark.streaming.checkpoint.directory").foreach(checkpoint) + /** * Return the associated Spark context */ @@ -606,6 +605,9 @@ class StreamingContext private[streaming] ( } shutdownHookRef = Utils.addShutdownHook( StreamingContext.SHUTDOWN_HOOK_PRIORITY)(stopOnShutdown) + // Registering Streaming Metrics at the start of the StreamingContext + assert(env.metricsSystem != null) + env.metricsSystem.registerSource(streamingSource) uiTab.foreach(_.attach()) logInfo("StreamingContext started") case ACTIVE => @@ -682,6 +684,8 @@ class StreamingContext private[streaming] ( logWarning("StreamingContext has already been stopped") case ACTIVE => scheduler.stop(stopGracefully) + // Removing the streamingSource to de-register the metrics on stop() + env.metricsSystem.removeSource(streamingSource) uiTab.foreach(_.detach()) StreamingContext.setActiveContext(null) waiter.notifyStop() diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala index 18a5bd7519fef..a7c220f426ecf 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala @@ -22,6 +22,7 @@ import java.util.concurrent.CountDownLatch import scala.collection.mutable.ArrayBuffer import scala.concurrent._ +import scala.util.control.NonFatal import org.apache.spark.{Logging, SparkConf} import org.apache.spark.storage.StreamBlockId @@ -36,7 +37,7 @@ private[streaming] abstract class ReceiverSupervisor( conf: SparkConf ) extends Logging { - /** Enumeration to identify current state of the StreamingContext */ + /** Enumeration to identify current state of the Receiver */ object ReceiverState extends Enumeration { type CheckpointState = Value val Initialized, Started, Stopped = Value @@ -100,8 +101,8 @@ private[streaming] abstract class ReceiverSupervisor( /** Called when supervisor is stopped */ protected def onStop(message: String, error: Option[Throwable]) { } - /** Called when receiver is started */ - protected def onReceiverStart() { } + /** Called when receiver is started. 
Return true if the driver accepts us */ + protected def onReceiverStart(): Boolean /** Called when receiver is stopped */ protected def onReceiverStop(message: String, error: Option[Throwable]) { } @@ -124,13 +125,17 @@ private[streaming] abstract class ReceiverSupervisor( /** Start receiver */ def startReceiver(): Unit = synchronized { try { - logInfo("Starting receiver") - receiver.onStart() - logInfo("Called receiver onStart") - onReceiverStart() - receiverState = Started + if (onReceiverStart()) { + logInfo("Starting receiver") + receiverState = Started + receiver.onStart() + logInfo("Called receiver onStart") + } else { + // The driver refused us + stop("Registered unsuccessfully because Driver refused to start receiver " + streamId, None) + } } catch { - case t: Throwable => + case NonFatal(t) => stop("Error starting receiver " + streamId, Some(t)) } } @@ -139,12 +144,19 @@ private[streaming] abstract class ReceiverSupervisor( def stopReceiver(message: String, error: Option[Throwable]): Unit = synchronized { try { logInfo("Stopping receiver with message: " + message + ": " + error.getOrElse("")) - receiverState = Stopped - receiver.onStop() - logInfo("Called receiver onStop") - onReceiverStop(message, error) + receiverState match { + case Initialized => + logWarning("Skip stopping receiver because it has not yet stared") + case Started => + receiverState = Stopped + receiver.onStop() + logInfo("Called receiver onStop") + onReceiverStop(message, error) + case Stopped => + logWarning("Receiver has been stopped") + } } catch { - case t: Throwable => + case NonFatal(t) => logError("Error stopping receiver " + streamId + t.getStackTraceString) } } @@ -170,7 +182,7 @@ private[streaming] abstract class ReceiverSupervisor( }(futureExecutionContext) } - /** Check if receiver has been marked for stopping */ + /** Check if receiver has been marked for starting */ def isReceiverStarted(): Boolean = { logDebug("state = " + receiverState) receiverState == Started diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index 0ada69b6e1aa1..2f6841ee8879c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -168,7 +168,7 @@ private[streaming] class ReceiverSupervisorImpl( env.rpcEnv.stop(endpoint) } - override protected def onReceiverStart() { + override protected def onReceiverStart(): Boolean = { val msg = RegisterReceiver( streamId, receiver.getClass.getSimpleName, Utils.localHostName(), endpoint) trackerEndpoint.askWithRetry[Boolean](msg) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index f5d41858646e4..9f2117ada61c0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -22,7 +22,7 @@ import scala.util.{Failure, Success, Try} import org.apache.spark.{SparkEnv, Logging} import org.apache.spark.streaming.{Checkpoint, CheckpointWriter, Time} import org.apache.spark.streaming.util.RecurringTimer -import org.apache.spark.util.{Clock, EventLoop, ManualClock} +import org.apache.spark.util.{Utils, Clock, EventLoop, ManualClock} /** Event classes for JobGenerator */ 
private[scheduler] sealed trait JobGeneratorEvent @@ -47,11 +47,11 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { val clockClass = ssc.sc.conf.get( "spark.streaming.clock", "org.apache.spark.util.SystemClock") try { - Class.forName(clockClass).newInstance().asInstanceOf[Clock] + Utils.classForName(clockClass).newInstance().asInstanceOf[Clock] } catch { case e: ClassNotFoundException if clockClass.startsWith("org.apache.spark.streaming") => val newClockClass = clockClass.replace("org.apache.spark.streaming", "org.apache.spark") - Class.forName(newClockClass).newInstance().asInstanceOf[Clock] + Utils.classForName(newClockClass).newInstance().asInstanceOf[Clock] } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index b0469ebccecc2..9cc6ffcd12f61 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -20,7 +20,6 @@ package org.apache.spark.streaming.scheduler import scala.collection.mutable.{ArrayBuffer, HashMap, SynchronizedMap} import scala.language.existentials import scala.math.max -import org.apache.spark.rdd._ import org.apache.spark.streaming.util.WriteAheadLogUtils import org.apache.spark.{Logging, SparkEnv, SparkException} @@ -47,6 +46,8 @@ private[streaming] case class ReportError(streamId: Int, message: String, error: private[streaming] case class DeregisterReceiver(streamId: Int, msg: String, error: String) extends ReceiverTrackerMessage +private[streaming] case object StopAllReceivers extends ReceiverTrackerMessage + /** * This class manages the execution of the receivers of ReceiverInputDStreams. Instance of * this class must be created after all input streams have been added and StreamingContext.start() @@ -71,13 +72,23 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false ) private val listenerBus = ssc.scheduler.listenerBus + /** Enumeration to identify current state of the ReceiverTracker */ + object TrackerState extends Enumeration { + type TrackerState = Value + val Initialized, Started, Stopping, Stopped = Value + } + import TrackerState._ + + /** State of the tracker. Protected by "trackerStateLock" */ + @volatile private var trackerState = Initialized + // endpoint is created when generator starts. // This not being null means the tracker has been started and not stopped private var endpoint: RpcEndpointRef = null /** Start the endpoint and receiver execution thread. */ def start(): Unit = synchronized { - if (endpoint != null) { + if (isTrackerStarted) { throw new SparkException("ReceiverTracker already started") } @@ -86,20 +97,46 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false "ReceiverTracker", new ReceiverTrackerEndpoint(ssc.env.rpcEnv)) if (!skipReceiverLaunch) receiverExecutor.start() logInfo("ReceiverTracker started") + trackerState = Started } } /** Stop the receiver execution thread. 
*/ def stop(graceful: Boolean): Unit = synchronized { - if (!receiverInputStreams.isEmpty && endpoint != null) { + if (isTrackerStarted) { // First, stop the receivers - if (!skipReceiverLaunch) receiverExecutor.stop(graceful) + trackerState = Stopping + if (!skipReceiverLaunch) { + // Send the stop signal to all the receivers + endpoint.askWithRetry[Boolean](StopAllReceivers) + + // Wait for the Spark job that runs the receivers to be over + // That is, for the receivers to quit gracefully. + receiverExecutor.awaitTermination(10000) + + if (graceful) { + val pollTime = 100 + logInfo("Waiting for receiver job to terminate gracefully") + while (receiverInfo.nonEmpty || receiverExecutor.running) { + Thread.sleep(pollTime) + } + logInfo("Waited for receiver job to terminate gracefully") + } + + // Check if all the receivers have been deregistered or not + if (receiverInfo.nonEmpty) { + logWarning("Not all of the receivers have deregistered, " + receiverInfo) + } else { + logInfo("All of the receivers have deregistered successfully") + } + } // Finally, stop the endpoint ssc.env.rpcEnv.stop(endpoint) endpoint = null receivedBlockTracker.stop() logInfo("ReceiverTracker stopped") + trackerState = Stopped } } @@ -145,14 +182,23 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false host: String, receiverEndpoint: RpcEndpointRef, senderAddress: RpcAddress - ) { + ): Boolean = { if (!receiverInputStreamIds.contains(streamId)) { throw new SparkException("Register received for unexpected id " + streamId) } - receiverInfo(streamId) = ReceiverInfo( - streamId, s"${typ}-${streamId}", receiverEndpoint, true, host) - listenerBus.post(StreamingListenerReceiverStarted(receiverInfo(streamId))) - logInfo("Registered receiver for stream " + streamId + " from " + senderAddress) + + if (isTrackerStopping || isTrackerStopped) { + false + } else { + // "stopReceivers" won't happen at the same time because both "registerReceiver" and are + // called in the event loop. So here we can assume "stopReceivers" has not yet been called. If + // "stopReceivers" is called later, it should be able to see this receiver. + receiverInfo(streamId) = ReceiverInfo( + streamId, s"${typ}-${streamId}", receiverEndpoint, true, host) + listenerBus.post(StreamingListenerReceiverStarted(receiverInfo(streamId))) + logInfo("Registered receiver for stream " + streamId + " from " + senderAddress) + true + } } /** Deregister a receiver */ @@ -227,20 +273,33 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case RegisterReceiver(streamId, typ, host, receiverEndpoint) => - registerReceiver(streamId, typ, host, receiverEndpoint, context.sender.address) - context.reply(true) + val successful = + registerReceiver(streamId, typ, host, receiverEndpoint, context.sender.address) + context.reply(successful) case AddBlock(receivedBlockInfo) => context.reply(addBlock(receivedBlockInfo)) case DeregisterReceiver(streamId, message, error) => deregisterReceiver(streamId, message, error) context.reply(true) + case StopAllReceivers => + assert(isTrackerStopping || isTrackerStopped) + stopReceivers() + context.reply(true) + } + + /** Send stop signal to the receivers. 
*/ + private def stopReceivers() { + // Signal the receivers to stop + receiverInfo.values.flatMap { info => Option(info.endpoint)} + .foreach { _.send(StopReceiver) } + logInfo("Sent stop signal to all " + receiverInfo.size + " receivers") } } /** This thread class runs all the receivers on the cluster. */ class ReceiverLauncher { @transient val env = ssc.env - @volatile @transient private var running = false + @volatile @transient var running = false @transient val thread = new Thread() { override def run() { try { @@ -256,31 +315,6 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false thread.start() } - def stop(graceful: Boolean) { - // Send the stop signal to all the receivers - stopReceivers() - - // Wait for the Spark job that runs the receivers to be over - // That is, for the receivers to quit gracefully. - thread.join(10000) - - if (graceful) { - val pollTime = 100 - logInfo("Waiting for receiver job to terminate gracefully") - while (receiverInfo.nonEmpty || running) { - Thread.sleep(pollTime) - } - logInfo("Waited for receiver job to terminate gracefully") - } - - // Check if all the receivers have been deregistered or not - if (receiverInfo.nonEmpty) { - logWarning("Not all of the receivers have deregistered, " + receiverInfo) - } else { - logInfo("All of the receivers have deregistered successfully") - } - } - /** * Get the list of executors excluding driver */ @@ -365,17 +399,30 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false // Distribute the receivers and start them logInfo("Starting " + receivers.length + " receivers") running = true - ssc.sparkContext.runJob(tempRDD, ssc.sparkContext.clean(startReceiver)) - running = false - logInfo("All of the receivers have been terminated") + try { + ssc.sparkContext.runJob(tempRDD, ssc.sparkContext.clean(startReceiver)) + logInfo("All of the receivers have been terminated") + } finally { + running = false + } } - /** Stops the receivers. */ - private def stopReceivers() { - // Signal the receivers to stop - receiverInfo.values.flatMap { info => Option(info.endpoint)} - .foreach { _.send(StopReceiver) } - logInfo("Sent stop signal to all " + receiverInfo.size + " receivers") + /** + * Wait until the Spark job that runs the receivers is terminated, or return when + * `milliseconds` elapses + */ + def awaitTermination(milliseconds: Long): Unit = { + thread.join(milliseconds) } } + + /** Check if tracker has been marked for starting */ + private def isTrackerStarted(): Boolean = trackerState == Started + + /** Check if tracker has been marked for stopping */ + private def isTrackerStopping(): Boolean = trackerState == Stopping + + /** Check if tracker has been marked for stopped */ + private def isTrackerStopped(): Boolean = trackerState == Stopped + } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 6a94928076236..d308ac05a54fe 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -191,8 +191,51 @@ class CheckpointSuite extends TestSuiteBase { } } + // This tests if "spark.driver.host" and "spark.driver.port" is set by user, can be recovered + // with correct value. 
+ test("get correct spark.driver.[host|port] from checkpoint") { + val conf = Map("spark.driver.host" -> "localhost", "spark.driver.port" -> "9999") + conf.foreach(kv => System.setProperty(kv._1, kv._2)) + ssc = new StreamingContext(master, framework, batchDuration) + val originalConf = ssc.conf + assert(originalConf.get("spark.driver.host") === "localhost") + assert(originalConf.get("spark.driver.port") === "9999") + + val cp = new Checkpoint(ssc, Time(1000)) + ssc.stop() + + // Serialize/deserialize to simulate write to storage and reading it back + val newCp = Utils.deserialize[Checkpoint](Utils.serialize(cp)) + + val newCpConf = newCp.createSparkConf() + assert(newCpConf.contains("spark.driver.host")) + assert(newCpConf.contains("spark.driver.port")) + assert(newCpConf.get("spark.driver.host") === "localhost") + assert(newCpConf.get("spark.driver.port") === "9999") + + // Check if all the parameters have been restored + ssc = new StreamingContext(null, newCp, null) + val restoredConf = ssc.conf + assert(restoredConf.get("spark.driver.host") === "localhost") + assert(restoredConf.get("spark.driver.port") === "9999") + ssc.stop() + + // If spark.driver.host and spark.driver.host is not set in system property, these two + // parameters should not be presented in the newly recovered conf. + conf.foreach(kv => System.clearProperty(kv._1)) + val newCpConf1 = newCp.createSparkConf() + assert(!newCpConf1.contains("spark.driver.host")) + assert(!newCpConf1.contains("spark.driver.port")) + + // Spark itself will dispatch a random, not-used port for spark.driver.port if it is not set + // explicitly. + ssc = new StreamingContext(null, newCp, null) + val restoredConf1 = ssc.conf + assert(restoredConf1.get("spark.driver.host") === "localhost") + assert(restoredConf1.get("spark.driver.port") !== "9999") + } - // This tests whether the systm can recover from a master failure with simple + // This tests whether the system can recover from a master failure with simple // non-stateful operations. This assumes as reliable, replayable input // source - TestInputDStream. 
test("recovery with map and reduceByKey operations") { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala index 5d7127627eea5..13b4d17c86183 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala @@ -346,6 +346,8 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { def reportError(message: String, throwable: Throwable) { errors += throwable } + + override protected def onReceiverStart(): Boolean = true } /** diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 56b4ce5638a51..4bba9691f8aa5 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -20,20 +20,23 @@ package org.apache.spark.streaming import java.io.{File, NotSerializableException} import java.util.concurrent.atomic.AtomicInteger +import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.Queue import org.apache.commons.io.FileUtils +import org.scalatest.{Assertions, BeforeAndAfter, PrivateMethodTester} import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.Timeouts import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.scalatest.time.SpanSugar._ -import org.scalatest.{Assertions, BeforeAndAfter} +import org.apache.spark.{Logging, SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.metrics.MetricsSystem +import org.apache.spark.metrics.source.Source import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.receiver.Receiver import org.apache.spark.util.Utils -import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException, SparkFunSuite} class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeouts with Logging { @@ -112,6 +115,15 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo assert(ssc.conf.getTimeAsSeconds("spark.cleaner.ttl", "-1") === 10) } + test("checkPoint from conf") { + val checkpointDirectory = Utils.createTempDir().getAbsolutePath() + + val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName) + myConf.set("spark.streaming.checkpoint.directory", checkpointDirectory) + val ssc = new StreamingContext(myConf, batchDuration) + assert(ssc.checkpointDir != null) + } + test("state matching") { import StreamingContextState._ assert(INITIALIZED === INITIALIZED) @@ -273,6 +285,21 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo } } + test("stop gracefully even if a receiver misses StopReceiver") { + // This is not a deterministic unit. But if this unit test is flaky, then there is definitely + // something wrong. 
See SPARK-5681 + val conf = new SparkConf().setMaster(master).setAppName(appName) + sc = new SparkContext(conf) + ssc = new StreamingContext(sc, Milliseconds(100)) + val input = ssc.receiverStream(new TestReceiver) + input.foreachRDD(_ => {}) + ssc.start() + // Call `ssc.stop` at once so that it's possible that the receiver will miss "StopReceiver" + failAfter(30000 millis) { + ssc.stop(stopSparkContext = true, stopGracefully = true) + } + } + test("stop slow receiver gracefully") { val conf = new SparkConf().setMaster(master).setAppName(appName) conf.set("spark.streaming.gracefulStopTimeout", "20000s") @@ -299,6 +326,25 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo Thread.sleep(100) } + test ("registering and de-registering of streamingSource") { + val conf = new SparkConf().setMaster(master).setAppName(appName) + ssc = new StreamingContext(conf, batchDuration) + assert(ssc.getState() === StreamingContextState.INITIALIZED) + addInputStream(ssc).register() + ssc.start() + + val sources = StreamingContextSuite.getSources(ssc.env.metricsSystem) + val streamingSource = StreamingContextSuite.getStreamingSource(ssc) + assert(sources.contains(streamingSource)) + assert(ssc.getState() === StreamingContextState.ACTIVE) + + ssc.stop() + val sourcesAfterStop = StreamingContextSuite.getSources(ssc.env.metricsSystem) + val streamingSourceAfterStop = StreamingContextSuite.getStreamingSource(ssc) + assert(ssc.getState() === StreamingContextState.STOPPED) + assert(!sourcesAfterStop.contains(streamingSourceAfterStop)) + } + test("awaitTermination") { ssc = new StreamingContext(master, appName, batchDuration) val inputStream = addInputStream(ssc) @@ -811,3 +857,18 @@ package object testPackage extends Assertions { } } } + +/** + * Helper methods for testing StreamingContextSuite + * This includes methods to access private methods and fields in StreamingContext and MetricsSystem + */ +private object StreamingContextSuite extends PrivateMethodTester { + private val _sources = PrivateMethod[ArrayBuffer[Source]]('sources) + private def getSources(metricsSystem: MetricsSystem): ArrayBuffer[Source] = { + metricsSystem.invokePrivate(_sources()) + } + private val _streamingSource = PrivateMethod[StreamingSource]('streamingSource) + private def getStreamingSource(streamingContext: StreamingContext): StreamingSource = { + streamingContext.invokePrivate(_streamingSource()) + } +} diff --git a/tools/pom.xml b/tools/pom.xml index feffde4c857eb..298ee2348b58e 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -76,10 +76,6 @@ org.apache.maven.plugins maven-source-plugin - - org.codehaus.mojo - build-helper-maven-plugin - diff --git a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala index 9483d2b692ab5..9418beb6b3e3a 100644 --- a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala +++ b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala @@ -15,6 +15,7 @@ * limitations under the License. 
*/ +// scalastyle:off classforname package org.apache.spark.tools import java.io.File @@ -188,3 +189,4 @@ object GenerateMIMAIgnore { classes } } +// scalastyle:on classforname diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSet.java b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSet.java index 28e23da108ebe..7c124173b0bbb 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSet.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSet.java @@ -90,7 +90,7 @@ public boolean isSet(int index) { * To iterate over the true bits in a BitSet, use the following loop: *
        * 
    -   *  for (long i = bs.nextSetBit(0); i >= 0; i = bs.nextSetBit(i + 1)) {
    +   *  for (long i = bs.nextSetBit(0); i >= 0; i = bs.nextSetBit(i + 1)) {
        *    // operate on index i here
        *  }
        * 
    diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java
    index 0987191c1c636..27462c7fa5e62 100644
    --- a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java
    +++ b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java
    @@ -87,7 +87,7 @@ public static boolean anySet(Object baseObject, long baseOffset, long bitSetWidt
        * To iterate over the true bits in a BitSet, use the following loop:
        * 
        * 
    -   *  for (long i = bs.nextSetBit(0, sizeInWords); i >= 0; i = bs.nextSetBit(i + 1, sizeInWords)) {
    +   *  for (long i = bs.nextSetBit(0, sizeInWords); i >= 0; i = bs.nextSetBit(i + 1, sizeInWords)) {
        *    // operate on index i here
        *  }
        * 
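The two javadoc blocks above describe the same word-aligned bit set, exposed both as the BitSet wrapper and as the static BitSetMethods helpers. A rough Scala sketch of the static form follows; only anySet's signature is visible in the hunk header above, so the set/isSet calls (and their exact index parameter types) are assumptions for illustration, not a definitive API reference:

    import org.apache.spark.unsafe.PlatformDependent
    import org.apache.spark.unsafe.bitset.BitSetMethods

    // Back a 256-bit set with an on-heap long[] and address it via the array's base offset.
    val words = new Array[Long](4)
    val base = PlatformDependent.LONG_ARRAY_OFFSET
    BitSetMethods.set(words, base, 7)                       // mark bit 7 (assumed signature)
    assert(BitSetMethods.isSet(words, base, 7))             // assumed signature
    assert(BitSetMethods.anySet(words, base, words.length)) // width in words, per the hunk header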
    diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java b/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java
    index 85cd02469adb7..61f483ced3217 100644
    --- a/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java
    +++ b/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java
    @@ -44,12 +44,16 @@ public int hashInt(int input) {
         return fmix(h1, 4);
       }
     
    -  public int hashUnsafeWords(Object baseObject, long baseOffset, int lengthInBytes) {
    +  public int hashUnsafeWords(Object base, long offset, int lengthInBytes) {
    +    return hashUnsafeWords(base, offset, lengthInBytes, seed);
    +  }
    +
    +  public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, int seed) {
         // This is based on Guava's `Murmur32_Hasher.processRemaining(ByteBuffer)` method.
         assert (lengthInBytes % 8 == 0): "lengthInBytes must be a multiple of 8 (word-aligned)";
         int h1 = seed;
    -    for (int offset = 0; offset < lengthInBytes; offset += 4) {
    -      int halfWord = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset + offset);
    +    for (int i = 0; i < lengthInBytes; i += 4) {
    +      int halfWord = PlatformDependent.UNSAFE.getInt(base, offset + i);
           int k1 = mixK1(halfWord);
           h1 = mixH1(h1, k1);
         }
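The hunk above adds a static hashUnsafeWords overload that takes the seed explicitly, with the instance method now simply forwarding to it. A minimal sketch of calling both forms on an on-heap, word-aligned byte array (the single-int-seed constructor is assumed from the surrounding class; 42 is an arbitrary seed):

    import org.apache.spark.unsafe.PlatformDependent
    import org.apache.spark.unsafe.hash.Murmur3_x86_32

    val bytes = new Array[Byte](16)                    // length must be a multiple of 8
    val offset = PlatformDependent.BYTE_ARRAY_OFFSET   // start of the array data on the heap
    // Static form: the seed is passed explicitly.
    val h1 = Murmur3_x86_32.hashUnsafeWords(bytes, offset, bytes.length, 42)
    // Instance form: forwards to the static overload using the instance's own seed.
    val h2 = new Murmur3_x86_32(42).hashUnsafeWords(bytes, offset, bytes.length)
    assert(h1 == h2)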
    diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
    index 8d8c08de52b84..d0bde69cc1068 100644
    --- a/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
    +++ b/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
    @@ -404,14 +404,17 @@ public int getValueLength() {
          * at the value address.
          * 

    * It is only valid to call this method immediately after calling `lookup()` using the same key. + *

    *

    * The key and value must be word-aligned (that is, their sizes must multiples of 8). + *

    *

    * After calling this method, calls to `get[Key|Value]Address()` and `get[Key|Value]Length` * will return information on the data stored by this `putNewKey` call. + *

    *

    * As an example usage, here's the proper way to store a new key: - *

    + *

    *
          *   Location loc = map.lookup(keyBaseObject, keyBaseOffset, keyLengthInBytes);
          *   if (!loc.isDefined()) {
    @@ -420,6 +423,7 @@ public int getValueLength() {
          * 
    *

    * Unspecified behavior if the key is not defined. + *

    */ public void putNewKey( Object keyBaseObject, diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java index eb7475e9df869..71b1a85a818ea 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java @@ -62,6 +62,7 @@ public static Interval fromString(String s) { if (s == null) { return null; } + s = s.trim(); Matcher m = p.matcher(s); if (!m.matches() || s.equals("interval")) { return null; @@ -86,6 +87,22 @@ public Interval(int months, long microseconds) { this.microseconds = microseconds; } + public Interval add(Interval that) { + int months = this.months + that.months; + long microseconds = this.microseconds + that.microseconds; + return new Interval(months, microseconds); + } + + public Interval subtract(Interval that) { + int months = this.months - that.months; + long microseconds = this.microseconds - that.microseconds; + return new Interval(months, microseconds); + } + + public Interval negate() { + return new Interval(-this.months, -this.microseconds); + } + @Override public boolean equals(Object other) { if (this == other) return true; diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 60d050b0a0c97..946d355f1fc28 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -20,7 +20,9 @@ import javax.annotation.Nonnull; import java.io.Serializable; import java.io.UnsupportedEncodingException; +import java.util.Arrays; +import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.array.ByteArrayMethods; import static org.apache.spark.unsafe.PlatformDependent.*; @@ -48,6 +50,8 @@ public final class UTF8String implements Comparable, Serializable { 5, 5, 5, 5, 6, 6, 6, 6}; + public static final UTF8String EMPTY_UTF8 = UTF8String.fromString(""); + /** * Creates an UTF8String from byte array, which should be encoded in UTF-8. * @@ -76,6 +80,15 @@ public static UTF8String fromString(String str) { } } + /** + * Creates an UTF8String that contains `length` spaces. + */ + public static UTF8String blankString(int length) { + byte[] spaces = new byte[length]; + Arrays.fill(spaces, (byte) ' '); + return fromBytes(spaces); + } + protected UTF8String(Object base, long offset, int size) { this.base = base; this.offset = offset; @@ -152,6 +165,18 @@ public UTF8String substring(final int start, final int until) { return fromBytes(bytes); } + public UTF8String substringSQL(int pos, int length) { + // Information regarding the pos calculation: + // Hive and SQL use one-based indexing for SUBSTR arguments but also accept zero and + // negative indices for start positions. If a start index i is greater than 0, it + // refers to element i-1 in the sequence. If a start index i is less than 0, it refers + // to the -ith element before the end of the sequence. If a start index i is 0, it + // refers to the first element. + int start = (pos > 0) ? pos -1 : ((pos < 0) ? numChars() + pos : 0); + int end = (length == Integer.MAX_VALUE) ? Integer.MAX_VALUE : start + length; + return substring(start, end); + } + /** * Returns whether this contains `substring` or not. 
*/ @@ -322,7 +347,7 @@ public int indexOf(UTF8String v, int start) { } i += numBytesForFirstByte(getByte(i)); c += 1; - } while(i < numBytes); + } while (i < numBytes); return -1; } @@ -330,8 +355,8 @@ public int indexOf(UTF8String v, int start) { /** * Returns str, right-padded with pad to a length of len * For example: - * ('hi', 5, '??') => 'hi???' - * ('hi', 1, '??') => 'h' + * ('hi', 5, '??') => 'hi???' + * ('hi', 1, '??') => 'h' */ public UTF8String rpad(int len, UTF8String pad) { int spaces = len - this.numChars(); // number of char need to pad @@ -363,8 +388,8 @@ public UTF8String rpad(int len, UTF8String pad) { /** * Returns str, left-padded with pad to a length of len. * For example: - * ('hi', 5, '??') => '???hi' - * ('hi', 1, '??') => 'h' + * ('hi', 5, '??') => '???hi' + * ('hi', 1, '??') => 'h' */ public UTF8String lpad(int len, UTF8String pad) { int spaces = len - this.numChars(); // number of char need to pad @@ -395,6 +420,94 @@ public UTF8String lpad(int len, UTF8String pad) { } } + /** + * Concatenates input strings together into a single string. Returns null if any input is null. + */ + public static UTF8String concat(UTF8String... inputs) { + // Compute the total length of the result. + int totalLength = 0; + for (int i = 0; i < inputs.length; i++) { + if (inputs[i] != null) { + totalLength += inputs[i].numBytes; + } else { + return null; + } + } + + // Allocate a new byte array, and copy the inputs one by one into it. + final byte[] result = new byte[totalLength]; + int offset = 0; + for (int i = 0; i < inputs.length; i++) { + int len = inputs[i].numBytes; + PlatformDependent.copyMemory( + inputs[i].base, inputs[i].offset, + result, PlatformDependent.BYTE_ARRAY_OFFSET + offset, + len); + offset += len; + } + return fromBytes(result); + } + + /** + * Concatenates input strings together into a single string using the separator. + * A null input is skipped. For example, concat(",", "a", null, "c") would yield "a,c". + */ + public static UTF8String concatWs(UTF8String separator, UTF8String... inputs) { + if (separator == null) { + return null; + } + + int numInputBytes = 0; // total number of bytes from the inputs + int numInputs = 0; // number of non-null inputs + for (int i = 0; i < inputs.length; i++) { + if (inputs[i] != null) { + numInputBytes += inputs[i].numBytes; + numInputs++; + } + } + + if (numInputs == 0) { + // Return an empty string if there is no input, or all the inputs are null. + return fromBytes(new byte[0]); + } + + // Allocate a new byte array, and copy the inputs one by one into it. + // The size of the new array is the size of all inputs, plus the separators. + final byte[] result = new byte[numInputBytes + (numInputs - 1) * separator.numBytes]; + int offset = 0; + + for (int i = 0, j = 0; i < inputs.length; i++) { + if (inputs[i] != null) { + int len = inputs[i].numBytes; + PlatformDependent.copyMemory( + inputs[i].base, inputs[i].offset, + result, PlatformDependent.BYTE_ARRAY_OFFSET + offset, + len); + offset += len; + + j++; + // Add separator if this is not the last input. 
+ if (j < numInputs) { + PlatformDependent.copyMemory( + separator.base, separator.offset, + result, PlatformDependent.BYTE_ARRAY_OFFSET + offset, + separator.numBytes); + offset += separator.numBytes; + } + } + } + return fromBytes(result); + } + + public UTF8String[] split(UTF8String pattern, int limit) { + String[] splits = toString().split(pattern.toString(), limit); + UTF8String[] res = new UTF8String[splits.length]; + for (int i = 0; i < res.length; i++) { + res[i] = fromString(splits[i]); + } + return res; + } + @Override public String toString() { try { @@ -413,7 +526,7 @@ public UTF8String clone() { } @Override - public int compareTo(final UTF8String other) { + public int compareTo(@Nonnull final UTF8String other) { int len = Math.min(numBytes, other.numBytes); // TODO: compare 8 bytes as unsigned long for (int i = 0; i < len; i ++) { @@ -434,7 +547,7 @@ public int compare(final UTF8String other) { public boolean equals(final Object other) { if (other instanceof UTF8String) { UTF8String o = (UTF8String) other; - if (numBytes != o.numBytes){ + if (numBytes != o.numBytes) { return false; } return ByteArrayMethods.arrayEquals(base, offset, o.base, o.offset, numBytes); diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java index 44a949a371f2b..d29517cda66a3 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java @@ -75,6 +75,12 @@ public void fromStringTest() { Interval result = new Interval(-5 * 12 + 23, 0); assertEquals(Interval.fromString(input), result); + input = "interval -5 years 23 month "; + assertEquals(Interval.fromString(input), result); + + input = " interval -5 years 23 month "; + assertEquals(Interval.fromString(input), result); + // Error cases input = "interval 3month 1 hour"; assertEquals(Interval.fromString(input), null); @@ -95,6 +101,44 @@ public void fromStringTest() { assertEquals(Interval.fromString(input), null); } + @Test + public void addTest() { + String input = "interval 3 month 1 hour"; + String input2 = "interval 2 month 100 hour"; + + Interval interval = Interval.fromString(input); + Interval interval2 = Interval.fromString(input2); + + assertEquals(interval.add(interval2), new Interval(5, 101 * MICROS_PER_HOUR)); + + input = "interval -10 month -81 hour"; + input2 = "interval 75 month 200 hour"; + + interval = Interval.fromString(input); + interval2 = Interval.fromString(input2); + + assertEquals(interval.add(interval2), new Interval(65, 119 * MICROS_PER_HOUR)); + } + + @Test + public void subtractTest() { + String input = "interval 3 month 1 hour"; + String input2 = "interval 2 month 100 hour"; + + Interval interval = Interval.fromString(input); + Interval interval2 = Interval.fromString(input2); + + assertEquals(interval.subtract(interval2), new Interval(1, -99 * MICROS_PER_HOUR)); + + input = "interval -10 month -81 hour"; + input2 = "interval 75 month 200 hour"; + + interval = Interval.fromString(input); + interval2 = Interval.fromString(input2); + + assertEquals(interval.subtract(interval2), new Interval(-85, -281 * MICROS_PER_HOUR)); + } + private void testSingleUnit(String unit, int number, int months, long microseconds) { String input1 = "interval " + number + " " + unit; String input2 = "interval " + number + " " + unit + "s"; diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java 
b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 694bdc29f39d1..e2a5628ff4d93 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -18,6 +18,7 @@ package org.apache.spark.unsafe.types; import java.io.UnsupportedEncodingException; +import java.util.Arrays; import org.junit.Test; @@ -54,6 +55,14 @@ public void basicTest() throws UnsupportedEncodingException { checkBasic("大 千 世 界", 7); } + @Test + public void emptyStringTest() { + assertEquals(fromString(""), EMPTY_UTF8); + assertEquals(fromBytes(new byte[0]), EMPTY_UTF8); + assertEquals(0, EMPTY_UTF8.numChars()); + assertEquals(0, EMPTY_UTF8.numBytes()); + } + @Test public void compareTo() { assertTrue(fromString("abc").compareTo(fromString("ABC")) > 0); @@ -86,9 +95,57 @@ public void upperAndLower() { testUpperandLower("大千世界 数据砖头", "大千世界 数据砖头"); } + @Test + public void concatTest() { + assertEquals(EMPTY_UTF8, concat()); + assertEquals(null, concat((UTF8String) null)); + assertEquals(EMPTY_UTF8, concat(EMPTY_UTF8)); + assertEquals(fromString("ab"), concat(fromString("ab"))); + assertEquals(fromString("ab"), concat(fromString("a"), fromString("b"))); + assertEquals(fromString("abc"), concat(fromString("a"), fromString("b"), fromString("c"))); + assertEquals(null, concat(fromString("a"), null, fromString("c"))); + assertEquals(null, concat(fromString("a"), null, null)); + assertEquals(null, concat(null, null, null)); + assertEquals(fromString("数据砖头"), concat(fromString("数据"), fromString("砖头"))); + } + + @Test + public void concatWsTest() { + // Returns null if the separator is null + assertEquals(null, concatWs(null, (UTF8String)null)); + assertEquals(null, concatWs(null, fromString("a"))); + + // If separator is null, concatWs should skip all null inputs and never return null. 
+ UTF8String sep = fromString("哈哈"); + assertEquals( + EMPTY_UTF8, + concatWs(sep, EMPTY_UTF8)); + assertEquals( + fromString("ab"), + concatWs(sep, fromString("ab"))); + assertEquals( + fromString("a哈哈b"), + concatWs(sep, fromString("a"), fromString("b"))); + assertEquals( + fromString("a哈哈b哈哈c"), + concatWs(sep, fromString("a"), fromString("b"), fromString("c"))); + assertEquals( + fromString("a哈哈c"), + concatWs(sep, fromString("a"), null, fromString("c"))); + assertEquals( + fromString("a"), + concatWs(sep, fromString("a"), null, null)); + assertEquals( + EMPTY_UTF8, + concatWs(sep, null, null, null)); + assertEquals( + fromString("数据哈哈砖头"), + concatWs(sep, fromString("数据"), fromString("砖头"))); + } + @Test public void contains() { - assertTrue(fromString("").contains(fromString(""))); + assertTrue(EMPTY_UTF8.contains(EMPTY_UTF8)); assertTrue(fromString("hello").contains(fromString("ello"))); assertFalse(fromString("hello").contains(fromString("vello"))); assertFalse(fromString("hello").contains(fromString("hellooo"))); @@ -99,7 +156,7 @@ public void contains() { @Test public void startsWith() { - assertTrue(fromString("").startsWith(fromString(""))); + assertTrue(EMPTY_UTF8.startsWith(EMPTY_UTF8)); assertTrue(fromString("hello").startsWith(fromString("hell"))); assertFalse(fromString("hello").startsWith(fromString("ell"))); assertFalse(fromString("hello").startsWith(fromString("hellooo"))); @@ -110,7 +167,7 @@ public void startsWith() { @Test public void endsWith() { - assertTrue(fromString("").endsWith(fromString(""))); + assertTrue(EMPTY_UTF8.endsWith(EMPTY_UTF8)); assertTrue(fromString("hello").endsWith(fromString("ello"))); assertFalse(fromString("hello").endsWith(fromString("ellov"))); assertFalse(fromString("hello").endsWith(fromString("hhhello"))); @@ -121,7 +178,7 @@ public void endsWith() { @Test public void substring() { - assertEquals(fromString(""), fromString("hello").substring(0, 0)); + assertEquals(EMPTY_UTF8, fromString("hello").substring(0, 0)); assertEquals(fromString("el"), fromString("hello").substring(1, 3)); assertEquals(fromString("数"), fromString("数据砖头").substring(0, 1)); assertEquals(fromString("据砖"), fromString("数据砖头").substring(1, 3)); @@ -135,9 +192,9 @@ public void trims() { assertEquals(fromString("hello "), fromString(" hello ").trimLeft()); assertEquals(fromString(" hello"), fromString(" hello ").trimRight()); - assertEquals(fromString(""), fromString(" ").trim()); - assertEquals(fromString(""), fromString(" ").trimLeft()); - assertEquals(fromString(""), fromString(" ").trimRight()); + assertEquals(EMPTY_UTF8, fromString(" ").trim()); + assertEquals(EMPTY_UTF8, fromString(" ").trimLeft()); + assertEquals(EMPTY_UTF8, fromString(" ").trimRight()); assertEquals(fromString("数据砖头"), fromString(" 数据砖头 ").trim()); assertEquals(fromString("数据砖头 "), fromString(" 数据砖头 ").trimLeft()); @@ -150,9 +207,9 @@ public void trims() { @Test public void indexOf() { - assertEquals(0, fromString("").indexOf(fromString(""), 0)); - assertEquals(-1, fromString("").indexOf(fromString("l"), 0)); - assertEquals(0, fromString("hello").indexOf(fromString(""), 0)); + assertEquals(0, EMPTY_UTF8.indexOf(EMPTY_UTF8, 0)); + assertEquals(-1, EMPTY_UTF8.indexOf(fromString("l"), 0)); + assertEquals(0, fromString("hello").indexOf(EMPTY_UTF8, 0)); assertEquals(2, fromString("hello").indexOf(fromString("l"), 0)); assertEquals(3, fromString("hello").indexOf(fromString("l"), 3)); assertEquals(-1, fromString("hello").indexOf(fromString("a"), 0)); @@ -167,7 +224,7 @@ public void indexOf() { @Test 
public void reverse() { assertEquals(fromString("olleh"), fromString("hello").reverse()); - assertEquals(fromString(""), fromString("").reverse()); + assertEquals(EMPTY_UTF8, EMPTY_UTF8.reverse()); assertEquals(fromString("者行孙"), fromString("孙行者").reverse()); assertEquals(fromString("者行孙 olleh"), fromString("hello 孙行者").reverse()); } @@ -176,7 +233,7 @@ public void reverse() { public void repeat() { assertEquals(fromString("数d数d数d数d数d"), fromString("数d").repeat(5)); assertEquals(fromString("数d"), fromString("数d").repeat(1)); - assertEquals(fromString(""), fromString("数d").repeat(-1)); + assertEquals(EMPTY_UTF8, fromString("数d").repeat(-1)); } @Test @@ -186,14 +243,14 @@ public void pad() { assertEquals(fromString("?hello"), fromString("hello").lpad(6, fromString("????"))); assertEquals(fromString("???????hello"), fromString("hello").lpad(12, fromString("????"))); assertEquals(fromString("?????hello"), fromString("hello").lpad(10, fromString("?????"))); - assertEquals(fromString("???????"), fromString("").lpad(7, fromString("?????"))); + assertEquals(fromString("???????"), EMPTY_UTF8.lpad(7, fromString("?????"))); assertEquals(fromString("hel"), fromString("hello").rpad(3, fromString("????"))); assertEquals(fromString("hello"), fromString("hello").rpad(5, fromString("????"))); assertEquals(fromString("hello?"), fromString("hello").rpad(6, fromString("????"))); assertEquals(fromString("hello???????"), fromString("hello").rpad(12, fromString("????"))); assertEquals(fromString("hello?????"), fromString("hello").rpad(10, fromString("?????"))); - assertEquals(fromString("???????"), fromString("").rpad(7, fromString("?????"))); + assertEquals(fromString("???????"), EMPTY_UTF8.rpad(7, fromString("?????"))); assertEquals(fromString("数据砖"), fromString("数据砖头").lpad(3, fromString("????"))); @@ -201,37 +258,68 @@ public void pad() { assertEquals(fromString("??数据砖头"), fromString("数据砖头").lpad(6, fromString("????"))); assertEquals(fromString("孙行数据砖头"), fromString("数据砖头").lpad(6, fromString("孙行者"))); assertEquals(fromString("孙行者数据砖头"), fromString("数据砖头").lpad(7, fromString("孙行者"))); - assertEquals(fromString("孙行者孙行者孙行数据砖头"), fromString("数据砖头").lpad(12, fromString("孙行者"))); + assertEquals( + fromString("孙行者孙行者孙行数据砖头"), + fromString("数据砖头").lpad(12, fromString("孙行者"))); assertEquals(fromString("数据砖"), fromString("数据砖头").rpad(3, fromString("????"))); assertEquals(fromString("数据砖头?"), fromString("数据砖头").rpad(5, fromString("????"))); assertEquals(fromString("数据砖头??"), fromString("数据砖头").rpad(6, fromString("????"))); assertEquals(fromString("数据砖头孙行"), fromString("数据砖头").rpad(6, fromString("孙行者"))); assertEquals(fromString("数据砖头孙行者"), fromString("数据砖头").rpad(7, fromString("孙行者"))); - assertEquals(fromString("数据砖头孙行者孙行者孙行"), fromString("数据砖头").rpad(12, fromString("孙行者"))); + assertEquals( + fromString("数据砖头孙行者孙行者孙行"), + fromString("数据砖头").rpad(12, fromString("孙行者"))); + } + + @Test + public void substringSQL() { + UTF8String e = fromString("example"); + assertEquals(e.substringSQL(0, 2), fromString("ex")); + assertEquals(e.substringSQL(1, 2), fromString("ex")); + assertEquals(e.substringSQL(0, 7), fromString("example")); + assertEquals(e.substringSQL(1, 2), fromString("ex")); + assertEquals(e.substringSQL(0, 100), fromString("example")); + assertEquals(e.substringSQL(1, 100), fromString("example")); + assertEquals(e.substringSQL(2, 2), fromString("xa")); + assertEquals(e.substringSQL(1, 6), fromString("exampl")); + assertEquals(e.substringSQL(2, 100), fromString("xample")); + assertEquals(e.substringSQL(0, 0), 
fromString("")); + assertEquals(e.substringSQL(100, 4), EMPTY_UTF8); + assertEquals(e.substringSQL(0, Integer.MAX_VALUE), fromString("example")); + assertEquals(e.substringSQL(1, Integer.MAX_VALUE), fromString("example")); + assertEquals(e.substringSQL(2, Integer.MAX_VALUE), fromString("xample")); + } + + @Test + public void split() { + assertTrue(Arrays.equals(fromString("ab,def,ghi").split(fromString(","), -1), + new UTF8String[]{fromString("ab"), fromString("def"), fromString("ghi")})); + assertTrue(Arrays.equals(fromString("ab,def,ghi").split(fromString(","), 2), + new UTF8String[]{fromString("ab"), fromString("def,ghi")})); + assertTrue(Arrays.equals(fromString("ab,def,ghi").split(fromString(","), 2), + new UTF8String[]{fromString("ab"), fromString("def,ghi")})); } @Test public void levenshteinDistance() { - assertEquals( - UTF8String.fromString("").levenshteinDistance(UTF8String.fromString("")), 0); - assertEquals( - UTF8String.fromString("").levenshteinDistance(UTF8String.fromString("a")), 1); - assertEquals( - UTF8String.fromString("aaapppp").levenshteinDistance(UTF8String.fromString("")), 7); - assertEquals( - UTF8String.fromString("frog").levenshteinDistance(UTF8String.fromString("fog")), 1); - assertEquals( - UTF8String.fromString("fly").levenshteinDistance(UTF8String.fromString("ant")),3); - assertEquals( - UTF8String.fromString("elephant").levenshteinDistance(UTF8String.fromString("hippo")), 7); - assertEquals( - UTF8String.fromString("hippo").levenshteinDistance(UTF8String.fromString("elephant")), 7); - assertEquals( - UTF8String.fromString("hippo").levenshteinDistance(UTF8String.fromString("zzzzzzzz")), 8); - assertEquals( - UTF8String.fromString("hello").levenshteinDistance(UTF8String.fromString("hallo")),1); - assertEquals( - UTF8String.fromString("世界千世").levenshteinDistance(UTF8String.fromString("千a世b")),4); + assertEquals(EMPTY_UTF8.levenshteinDistance(EMPTY_UTF8), 0); + assertEquals(EMPTY_UTF8.levenshteinDistance(fromString("a")), 1); + assertEquals(fromString("aaapppp").levenshteinDistance(EMPTY_UTF8), 7); + assertEquals(fromString("frog").levenshteinDistance(fromString("fog")), 1); + assertEquals(fromString("fly").levenshteinDistance(fromString("ant")),3); + assertEquals(fromString("elephant").levenshteinDistance(fromString("hippo")), 7); + assertEquals(fromString("hippo").levenshteinDistance(fromString("elephant")), 7); + assertEquals(fromString("hippo").levenshteinDistance(fromString("zzzzzzzz")), 8); + assertEquals(fromString("hello").levenshteinDistance(fromString("hallo")),1); + assertEquals(fromString("世界千世").levenshteinDistance(fromString("千a世b")),4); + } + + @Test + public void createBlankString() { + assertEquals(fromString(" "), blankString(1)); + assertEquals(fromString(" "), blankString(2)); + assertEquals(fromString(" "), blankString(3)); + assertEquals(fromString(""), blankString(0)); } } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index f0af6f875f523..bc28ce5eeae72 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -80,10 +80,12 @@ private[spark] class Client( private val isClusterMode = args.isClusterMode private var loginFromKeytab = false + private var principal: String = null + private var keytab: String = null + private val fireAndForget = isClusterMode && !sparkConf.getBoolean("spark.yarn.submit.waitAppCompletion", true) - def stop(): Unit = yarnClient.stop() 
/** @@ -339,7 +341,7 @@ private[spark] class Client( if (loginFromKeytab) { logInfo("To enable the AM to login from keytab, credentials are being copied over to the AM" + " via the YARN Secure Distributed Cache.") - val (_, localizedPath) = distribute(args.keytab, + val (_, localizedPath) = distribute(keytab, destName = Some(sparkConf.get("spark.yarn.keytab")), appMasterOnly = true) require(localizedPath != null, "Keytab file already distributed.") @@ -616,7 +618,7 @@ private[spark] class Client( val appId = newAppResponse.getApplicationId val appStagingDir = getAppStagingDir(appId) val pySparkArchives = - if (sys.props.getOrElse("spark.yarn.isPython", "false").toBoolean) { + if (sparkConf.getBoolean("spark.yarn.isPython", false)) { findPySparkArchives() } else { Nil @@ -732,9 +734,9 @@ private[spark] class Client( } val amClass = if (isClusterMode) { - Class.forName("org.apache.spark.deploy.yarn.ApplicationMaster").getName + Utils.classForName("org.apache.spark.deploy.yarn.ApplicationMaster").getName } else { - Class.forName("org.apache.spark.deploy.yarn.ExecutorLauncher").getName + Utils.classForName("org.apache.spark.deploy.yarn.ExecutorLauncher").getName } if (args.primaryRFile != null && args.primaryRFile.endsWith(".R")) { args.userArgs = ArrayBuffer(args.primaryRFile) ++ args.userArgs @@ -785,19 +787,27 @@ private[spark] class Client( } def setupCredentials(): Unit = { - if (args.principal != null) { - require(args.keytab != null, "Keytab must be specified when principal is specified.") + loginFromKeytab = args.principal != null || sparkConf.contains("spark.yarn.principal") + if (loginFromKeytab) { + principal = + if (args.principal != null) args.principal else sparkConf.get("spark.yarn.principal") + keytab = { + if (args.keytab != null) { + args.keytab + } else { + sparkConf.getOption("spark.yarn.keytab").orNull + } + } + + require(keytab != null, "Keytab must be specified when principal is specified.") logInfo("Attempting to login to the Kerberos" + - s" using principal: ${args.principal} and keytab: ${args.keytab}") - val f = new File(args.keytab) + s" using principal: $principal and keytab: $keytab") + val f = new File(keytab) // Generate a file name that can be used for the keytab file, that does not conflict // with any user file. val keytabFileName = f.getName + "-" + UUID.randomUUID().toString - UserGroupInformation.loginUserFromKeytab(args.principal, args.keytab) - loginFromKeytab = true sparkConf.set("spark.yarn.keytab", keytabFileName) - sparkConf.set("spark.yarn.principal", args.principal) - logInfo("Successfully logged into the KDC.") + sparkConf.set("spark.yarn.principal", principal) } credentials = UserGroupInformation.getCurrentUser.getCredentials } @@ -1162,7 +1172,7 @@ object Client extends Logging { * * If not a "local:" file and no alternate name, the environment is not modified. * - * @parma conf Spark configuration. + * @param conf Spark configuration. * @param uri URI to add to classpath (optional). * @param fileName Alternate name for the file (optional). * @param env Map holding the environment variables. 
diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index 3a0b9443d2d7b..d97fa2e2151bc 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -20,10 +20,9 @@ package org.apache.spark.scheduler.cluster import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.yarn.api.records.{ApplicationId, YarnApplicationState} -import org.apache.hadoop.yarn.exceptions.ApplicationNotFoundException import org.apache.spark.{SparkException, Logging, SparkContext} -import org.apache.spark.deploy.yarn.{Client, ClientArguments} +import org.apache.spark.deploy.yarn.{Client, ClientArguments, YarnSparkHadoopUtil} import org.apache.spark.scheduler.TaskSchedulerImpl private[spark] class YarnClientSchedulerBackend( @@ -62,6 +61,13 @@ private[spark] class YarnClientSchedulerBackend( super.start() waitForApplication() + + // SPARK-8851: In yarn-client mode, the AM still does the credentials refresh. The driver + // reads the credentials from HDFS, just like the executors and updates its own credentials + // cache. + if (conf.contains("spark.yarn.credentials.file")) { + YarnSparkHadoopUtil.get.startExecutorDelegationTokenRenewer(conf) + } monitorThread = asyncMonitorApplication() monitorThread.start() } @@ -158,6 +164,7 @@ private[spark] class YarnClientSchedulerBackend( } super.stop() client.stop() + YarnSparkHadoopUtil.get.stopExecutorDelegationTokenRenewer() logInfo("Stopped") }
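To illustrate the Client and YarnClientSchedulerBackend changes above: the Kerberos principal and keytab may now be supplied through the Spark configuration instead of only the --principal/--keytab arguments, and in yarn-client mode the backend starts the executor delegation token renewer whenever spark.yarn.credentials.file is present (per the SPARK-8851 comment). A minimal sketch with placeholder principal, keytab path, and app name:

    import org.apache.spark.SparkConf

    val conf = new SparkConf()
      .setAppName("kerberized-streaming-app")                            // placeholder app name
      .set("spark.yarn.principal", "someuser@EXAMPLE.COM")               // placeholder principal
      .set("spark.yarn.keytab", "/etc/security/keytabs/someuser.keytab") // placeholder keytab path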