diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 0cb3a80a6e892..cc471edc376b3 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -1,6 +1,6 @@ Package: SparkR Type: Package -Version: 2.1.0 +Version: 2.2.0 Title: R Frontend for Apache Spark Description: The SparkR package provides an R Frontend for Apache Spark. Authors@R: c(person("Shivaram", "Venkataraman", role = c("aut", "cre"), @@ -41,7 +41,13 @@ Collate: 'functions.R' 'install.R' 'jvm.R' - 'mllib.R' + 'mllib_classification.R' + 'mllib_clustering.R' + 'mllib_recommendation.R' + 'mllib_regression.R' + 'mllib_stat.R' + 'mllib_tree.R' + 'mllib_utils.R' 'serialize.R' 'sparkR.R' 'stats.R' diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 7737ffe4ed43b..c56648a8c4fba 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2313,9 +2313,9 @@ setMethod("dropDuplicates", #' @param joinExpr (Optional) The expression used to perform the join. joinExpr must be a #' Column expression. If joinExpr is omitted, the default, inner join is attempted and an error is #' thrown if it would be a Cartesian Product. For Cartesian join, use crossJoin instead. -#' @param joinType The type of join to perform. The following join types are available: -#' 'inner', 'outer', 'full', 'fullouter', leftouter', 'left_outer', 'left', -#' 'right_outer', 'rightouter', 'right', and 'leftsemi'. The default joinType is "inner". +#' @param joinType The type of join to perform, default 'inner'. +#' Must be one of: 'inner', 'cross', 'outer', 'full', 'full_outer', +#' 'left', 'left_outer', 'right', 'right_outer', 'left_semi', or 'left_anti'. #' @return A SparkDataFrame containing the result of the join operation. #' @family SparkDataFrame functions #' @aliases join,SparkDataFrame,SparkDataFrame-method @@ -2344,15 +2344,18 @@ setMethod("join", if (is.null(joinType)) { sdf <- callJMethod(x@sdf, "join", y@sdf, joinExpr@jc) } else { - if (joinType %in% c("inner", "outer", "full", "fullouter", - "leftouter", "left_outer", "left", - "rightouter", "right_outer", "right", "leftsemi")) { + if (joinType %in% c("inner", "cross", + "outer", "full", "fullouter", "full_outer", + "left", "leftouter", "left_outer", + "right", "rightouter", "right_outer", + "left_semi", "leftsemi", "left_anti", "leftanti")) { joinType <- gsub("_", "", joinType) sdf <- callJMethod(x@sdf, "join", y@sdf, joinExpr@jc, joinType) } else { stop("joinType must be one of the following types: ", - "'inner', 'outer', 'full', 'fullouter', 'leftouter', 'left_outer', 'left', - 'rightouter', 'right_outer', 'right', 'leftsemi'") + "'inner', 'cross', 'outer', 'full', 'full_outer', ", + "'left', 'left_outer', 'right', 'right_outer', ", + "'left_semi', or 'left_anti'.") } } } diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index bf5c96373c632..6ffa0f5481c65 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -3150,7 +3150,8 @@ setMethod("cume_dist", #' The difference between rank and dense_rank is that dense_rank leaves no gaps in ranking #' sequence when there are ties. That is, if you were ranking a competition using dense_rank #' and had three people tie for second place, you would say that all three were in second -#' place and that the next person came in third. +#' place and that the next person came in third. Rank would give sequential numbers, so that +#' the person who came in third place (after the ties) would register as coming in fifth. #' #' This is equivalent to the \code{DENSE_RANK} function in SQL.
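Illustration (not part of the diff): a minimal, hypothetical SparkR sketch of the behaviour described in the reworded rank/dense_rank docs above, plus one of the newly documented join types. Names such as `scores` and `seen` are made up for this example, and an active SparkR session is assumed.

library(SparkR)

# Three rows tied for second place: rank() skips ahead afterwards, dense_rank() does not.
scores <- createDataFrame(data.frame(name  = c("a", "b", "c", "d", "e"),
                                     score = c(90, 80, 80, 80, 70)))
ws <- windowOrderBy(desc(scores$score))  # a single, unpartitioned window ordered by score
head(select(scores, scores$name, scores$score,
            alias(over(rank(), ws), "rank"),               # ties at 80 all get 2; 70 gets 5
            alias(over(dense_rank(), ws), "dense_rank")))  # ties at 80 all get 2; 70 gets 3

# 'left_anti' (newly listed in the join docs above) keeps rows of 'scores' with no match in 'seen'.
seen <- createDataFrame(data.frame(name = c("a", "b")))
head(join(scores, seen, scores$name == seen$name, "left_anti"))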
#' @@ -3321,10 +3322,11 @@ setMethod("percent_rank", #' #' Window function: returns the rank of rows within a window partition. #' -#' The difference between rank and denseRank is that denseRank leaves no gaps in ranking -#' sequence when there are ties. That is, if you were ranking a competition using denseRank +#' The difference between rank and dense_rank is that dense_rank leaves no gaps in ranking +#' sequence when there are ties. That is, if you were ranking a competition using dense_rank #' and had three people tie for second place, you would say that all three were in second -#' place and that the next person came in third. +#' place and that the next person came in third. Rank would give sequential numbers, so that +#' the person who came in third place (after the ties) would register as coming in fifth. #' #' This is equivalent to the RANK function in SQL. #' diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R deleted file mode 100644 index d736bbb5e9113..0000000000000 --- a/R/pkg/R/mllib.R +++ /dev/null @@ -1,2114 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# mllib.R: Provides methods for MLlib integration - -# Integration with R's standard functions. -# Most of MLlib's argorithms are provided in two flavours: -# - a specialization of the default R methods (glm). These methods try to respect -# the inputs and the outputs of R's method to the largest extent, but some small differences -# may exist. -# - a set of methods that reflect the arguments of the other languages supported by Spark. These -# methods are prefixed with the `spark.` prefix: spark.glm, spark.kmeans, etc.
- -#' S4 class that represents a generalized linear model -#' -#' @param jobj a Java object reference to the backing Scala GeneralizedLinearRegressionWrapper -#' @export -#' @note GeneralizedLinearRegressionModel since 2.0.0 -setClass("GeneralizedLinearRegressionModel", representation(jobj = "jobj")) - -#' S4 class that represents a NaiveBayesModel -#' -#' @param jobj a Java object reference to the backing Scala NaiveBayesWrapper -#' @export -#' @note NaiveBayesModel since 2.0.0 -setClass("NaiveBayesModel", representation(jobj = "jobj")) - -#' S4 class that represents an LDAModel -#' -#' @param jobj a Java object reference to the backing Scala LDAWrapper -#' @export -#' @note LDAModel since 2.1.0 -setClass("LDAModel", representation(jobj = "jobj")) - -#' S4 class that represents a AFTSurvivalRegressionModel -#' -#' @param jobj a Java object reference to the backing Scala AFTSurvivalRegressionWrapper -#' @export -#' @note AFTSurvivalRegressionModel since 2.0.0 -setClass("AFTSurvivalRegressionModel", representation(jobj = "jobj")) - -#' S4 class that represents a KMeansModel -#' -#' @param jobj a Java object reference to the backing Scala KMeansModel -#' @export -#' @note KMeansModel since 2.0.0 -setClass("KMeansModel", representation(jobj = "jobj")) - -#' S4 class that represents a MultilayerPerceptronClassificationModel -#' -#' @param jobj a Java object reference to the backing Scala MultilayerPerceptronClassifierWrapper -#' @export -#' @note MultilayerPerceptronClassificationModel since 2.1.0 -setClass("MultilayerPerceptronClassificationModel", representation(jobj = "jobj")) - -#' S4 class that represents an IsotonicRegressionModel -#' -#' @param jobj a Java object reference to the backing Scala IsotonicRegressionModel -#' @export -#' @note IsotonicRegressionModel since 2.1.0 -setClass("IsotonicRegressionModel", representation(jobj = "jobj")) - -#' S4 class that represents a GaussianMixtureModel -#' -#' @param jobj a Java object reference to the backing Scala GaussianMixtureModel -#' @export -#' @note GaussianMixtureModel since 2.1.0 -setClass("GaussianMixtureModel", representation(jobj = "jobj")) - -#' S4 class that represents an ALSModel -#' -#' @param jobj a Java object reference to the backing Scala ALSWrapper -#' @export -#' @note ALSModel since 2.1.0 -setClass("ALSModel", representation(jobj = "jobj")) - -#' S4 class that represents an KSTest -#' -#' @param jobj a Java object reference to the backing Scala KSTestWrapper -#' @export -#' @note KSTest since 2.1.0 -setClass("KSTest", representation(jobj = "jobj")) - -#' S4 class that represents an LogisticRegressionModel -#' -#' @param jobj a Java object reference to the backing Scala LogisticRegressionModel -#' @export -#' @note LogisticRegressionModel since 2.1.0 -setClass("LogisticRegressionModel", representation(jobj = "jobj")) - -#' S4 class that represents a RandomForestRegressionModel -#' -#' @param jobj a Java object reference to the backing Scala RandomForestRegressionModel -#' @export -#' @note RandomForestRegressionModel since 2.1.0 -setClass("RandomForestRegressionModel", representation(jobj = "jobj")) - -#' S4 class that represents a RandomForestClassificationModel -#' -#' @param jobj a Java object reference to the backing Scala RandomForestClassificationModel -#' @export -#' @note RandomForestClassificationModel since 2.1.0 -setClass("RandomForestClassificationModel", representation(jobj = "jobj")) - -#' S4 class that represents a GBTRegressionModel -#' -#' @param jobj a Java object reference to the backing Scala 
GBTRegressionModel -#' @export -#' @note GBTRegressionModel since 2.1.0 -setClass("GBTRegressionModel", representation(jobj = "jobj")) - -#' S4 class that represents a GBTClassificationModel -#' -#' @param jobj a Java object reference to the backing Scala GBTClassificationModel -#' @export -#' @note GBTClassificationModel since 2.1.0 -setClass("GBTClassificationModel", representation(jobj = "jobj")) - -#' Saves the MLlib model to the input path -#' -#' Saves the MLlib model to the input path. For more information, see the specific -#' MLlib model below. -#' @rdname write.ml -#' @name write.ml -#' @export -#' @seealso \link{spark.glm}, \link{glm}, -#' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.gbt}, \link{spark.isoreg}, -#' @seealso \link{spark.kmeans}, -#' @seealso \link{spark.lda}, \link{spark.logit}, \link{spark.mlp}, \link{spark.naiveBayes}, -#' @seealso \link{spark.randomForest}, \link{spark.survreg}, -#' @seealso \link{read.ml} -NULL - -#' Makes predictions from a MLlib model -#' -#' Makes predictions from a MLlib model. For more information, see the specific -#' MLlib model below. -#' @rdname predict -#' @name predict -#' @export -#' @seealso \link{spark.glm}, \link{glm}, -#' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.gbt}, \link{spark.isoreg}, -#' @seealso \link{spark.kmeans}, -#' @seealso \link{spark.logit}, \link{spark.mlp}, \link{spark.naiveBayes}, -#' @seealso \link{spark.randomForest}, \link{spark.survreg} -NULL - -write_internal <- function(object, path, overwrite = FALSE) { - writer <- callJMethod(object@jobj, "write") - if (overwrite) { - writer <- callJMethod(writer, "overwrite") - } - invisible(callJMethod(writer, "save", path)) -} - -predict_internal <- function(object, newData) { - dataFrame(callJMethod(object@jobj, "transform", newData@sdf)) -} - -#' Generalized Linear Models -#' -#' Fits generalized linear model against a Spark DataFrame. -#' Users can call \code{summary} to print a summary of the fitted model, \code{predict} to make -#' predictions on new data, and \code{write.ml}/\code{read.ml} to save/load fitted models. -#' -#' @param data a SparkDataFrame for training. -#' @param formula a symbolic description of the model to be fitted. Currently only a few formula -#' operators are supported, including '~', '.', ':', '+', and '-'. -#' @param family a description of the error distribution and link function to be used in the model. -#' This can be a character string naming a family function, a family function or -#' the result of a call to a family function. Refer R family at -#' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}. -#' @param tol positive convergence tolerance of iterations. -#' @param maxIter integer giving the maximal number of IRLS iterations. -#' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance -#' weights as 1.0. -#' @param regParam regularization parameter for L2 regularization. -#' @param ... additional arguments passed to the method. -#' @aliases spark.glm,SparkDataFrame,formula-method -#' @return \code{spark.glm} returns a fitted generalized linear model. 
-#' @rdname spark.glm -#' @name spark.glm -#' @export -#' @examples -#' \dontrun{ -#' sparkR.session() -#' data(iris) -#' df <- createDataFrame(iris) -#' model <- spark.glm(df, Sepal_Length ~ Sepal_Width, family = "gaussian") -#' summary(model) -#' -#' # fitted values on training data -#' fitted <- predict(model, df) -#' head(select(fitted, "Sepal_Length", "prediction")) -#' -#' # save fitted model to input path -#' path <- "path/to/model" -#' write.ml(model, path) -#' -#' # can also read back the saved model and print -#' savedModel <- read.ml(path) -#' summary(savedModel) -#' } -#' @note spark.glm since 2.0.0 -#' @seealso \link{glm}, \link{read.ml} -setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), - function(data, formula, family = gaussian, tol = 1e-6, maxIter = 25, weightCol = NULL, - regParam = 0.0) { - if (is.character(family)) { - family <- get(family, mode = "function", envir = parent.frame()) - } - if (is.function(family)) { - family <- family() - } - if (is.null(family$family)) { - print(family) - stop("'family' not recognized") - } - - formula <- paste(deparse(formula), collapse = "") - if (is.null(weightCol)) { - weightCol <- "" - } - - jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper", - "fit", formula, data@sdf, family$family, family$link, - tol, as.integer(maxIter), as.character(weightCol), regParam) - new("GeneralizedLinearRegressionModel", jobj = jobj) - }) - -#' Generalized Linear Models (R-compliant) -#' -#' Fits a generalized linear model, similarly to R's glm(). -#' @param formula a symbolic description of the model to be fitted. Currently only a few formula -#' operators are supported, including '~', '.', ':', '+', and '-'. -#' @param data a SparkDataFrame or R's glm data for training. -#' @param family a description of the error distribution and link function to be used in the model. -#' This can be a character string naming a family function, a family function or -#' the result of a call to a family function. Refer R family at -#' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}. -#' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance -#' weights as 1.0. -#' @param epsilon positive convergence tolerance of iterations. -#' @param maxit integer giving the maximal number of IRLS iterations. -#' @return \code{glm} returns a fitted generalized linear model. -#' @rdname glm -#' @export -#' @examples -#' \dontrun{ -#' sparkR.session() -#' data(iris) -#' df <- createDataFrame(iris) -#' model <- glm(Sepal_Length ~ Sepal_Width, df, family = "gaussian") -#' summary(model) -#' } -#' @note glm since 1.5.0 -#' @seealso \link{spark.glm} -setMethod("glm", signature(formula = "formula", family = "ANY", data = "SparkDataFrame"), - function(formula, family = gaussian, data, epsilon = 1e-6, maxit = 25, weightCol = NULL) { - spark.glm(data, formula, family, tol = epsilon, maxIter = maxit, weightCol = weightCol) - }) - -# Returns the summary of a model produced by glm() or spark.glm(), similarly to R's summary(). - -#' @param object a fitted generalized linear model. -#' @return \code{summary} returns summary information of the fitted model, which is a list. 
-#' The list of components includes at least the \code{coefficients} (coefficients matrix, which includes -#' coefficients, standard error of coefficients, t value and p value), -#' \code{null.deviance} (null/residual degrees of freedom), \code{aic} (AIC) -#' and \code{iter} (number of iterations IRLS takes). If there are collinear columns in the data, -#' the coefficients matrix only provides coefficients. -#' @rdname spark.glm -#' @export -#' @note summary(GeneralizedLinearRegressionModel) since 2.0.0 -setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"), - function(object) { - jobj <- object@jobj - is.loaded <- callJMethod(jobj, "isLoaded") - features <- callJMethod(jobj, "rFeatures") - coefficients <- callJMethod(jobj, "rCoefficients") - dispersion <- callJMethod(jobj, "rDispersion") - null.deviance <- callJMethod(jobj, "rNullDeviance") - deviance <- callJMethod(jobj, "rDeviance") - df.null <- callJMethod(jobj, "rResidualDegreeOfFreedomNull") - df.residual <- callJMethod(jobj, "rResidualDegreeOfFreedom") - aic <- callJMethod(jobj, "rAic") - iter <- callJMethod(jobj, "rNumIterations") - family <- callJMethod(jobj, "rFamily") - deviance.resid <- if (is.loaded) { - NULL - } else { - dataFrame(callJMethod(jobj, "rDevianceResiduals")) - } - # If the underlying WeightedLeastSquares using "normal" solver, we can provide - # coefficients, standard error of coefficients, t value and p value. Otherwise, - # it will be fitted by local "l-bfgs", we can only provide coefficients. - if (length(features) == length(coefficients)) { - coefficients <- matrix(coefficients, ncol = 1) - colnames(coefficients) <- c("Estimate") - rownames(coefficients) <- unlist(features) - } else { - coefficients <- matrix(coefficients, ncol = 4) - colnames(coefficients) <- c("Estimate", "Std. Error", "t value", "Pr(>|t|)") - rownames(coefficients) <- unlist(features) - } - ans <- list(deviance.resid = deviance.resid, coefficients = coefficients, - dispersion = dispersion, null.deviance = null.deviance, - deviance = deviance, df.null = df.null, df.residual = df.residual, - aic = aic, iter = iter, family = family, is.loaded = is.loaded) - class(ans) <- "summary.GeneralizedLinearRegressionModel" - ans - }) - -# Prints the summary of GeneralizedLinearRegressionModel - -#' @rdname spark.glm -#' @param x summary object of fitted generalized linear model returned by \code{summary} function. -#' @export -#' @note print.summary.GeneralizedLinearRegressionModel since 2.0.0 -print.summary.GeneralizedLinearRegressionModel <- function(x, ...) 
{ - if (x$is.loaded) { - cat("\nSaved-loaded model does not support output 'Deviance Residuals'.\n") - } else { - x$deviance.resid <- setNames(unlist(approxQuantile(x$deviance.resid, "devianceResiduals", - c(0.0, 0.25, 0.5, 0.75, 1.0), 0.01)), c("Min", "1Q", "Median", "3Q", "Max")) - x$deviance.resid <- zapsmall(x$deviance.resid, 5L) - cat("\nDeviance Residuals: \n") - cat("(Note: These are approximate quantiles with relative error <= 0.01)\n") - print.default(x$deviance.resid, digits = 5L, na.print = "", print.gap = 2L) - } - - cat("\nCoefficients:\n") - print.default(x$coefficients, digits = 5L, na.print = "", print.gap = 2L) - - cat("\n(Dispersion parameter for ", x$family, " family taken to be ", format(x$dispersion), - ")\n\n", apply(cbind(paste(format(c("Null", "Residual"), justify = "right"), "deviance:"), - format(unlist(x[c("null.deviance", "deviance")]), digits = 5L), - " on", format(unlist(x[c("df.null", "df.residual")])), " degrees of freedom\n"), - 1L, paste, collapse = " "), sep = "") - cat("AIC: ", format(x$aic, digits = 4L), "\n\n", - "Number of Fisher Scoring iterations: ", x$iter, "\n\n", sep = "") - invisible(x) - } - -# Makes predictions from a generalized linear model produced by glm() or spark.glm(), -# similarly to R's predict(). - -#' @param newData a SparkDataFrame for testing. -#' @return \code{predict} returns a SparkDataFrame containing predicted labels in a column named -#' "prediction". -#' @rdname spark.glm -#' @export -#' @note predict(GeneralizedLinearRegressionModel) since 1.5.0 -setMethod("predict", signature(object = "GeneralizedLinearRegressionModel"), - function(object, newData) { - predict_internal(object, newData) - }) - -# Makes predictions from a naive Bayes model or a model produced by spark.naiveBayes(), -# similarly to R package e1071's predict. - -#' @param newData a SparkDataFrame for testing. -#' @return \code{predict} returns a SparkDataFrame containing predicted labeled in a column named -#' "prediction". -#' @rdname spark.naiveBayes -#' @export -#' @note predict(NaiveBayesModel) since 2.0.0 -setMethod("predict", signature(object = "NaiveBayesModel"), - function(object, newData) { - predict_internal(object, newData) - }) - -# Returns the summary of a naive Bayes model produced by \code{spark.naiveBayes} - -#' @param object a naive Bayes model fitted by \code{spark.naiveBayes}. -#' @return \code{summary} returns summary information of the fitted model, which is a list. -#' The list includes \code{apriori} (the label distribution) and -#' \code{tables} (conditional probabilities given the target label). -#' @rdname spark.naiveBayes -#' @export -#' @note summary(NaiveBayesModel) since 2.0.0 -setMethod("summary", signature(object = "NaiveBayesModel"), - function(object) { - jobj <- object@jobj - features <- callJMethod(jobj, "features") - labels <- callJMethod(jobj, "labels") - apriori <- callJMethod(jobj, "apriori") - apriori <- t(as.matrix(unlist(apriori))) - colnames(apriori) <- unlist(labels) - tables <- callJMethod(jobj, "tables") - tables <- matrix(tables, nrow = length(labels)) - rownames(tables) <- unlist(labels) - colnames(tables) <- unlist(features) - list(apriori = apriori, tables = tables) - }) - -# Returns posterior probabilities from a Latent Dirichlet Allocation model produced by spark.lda() - -#' @param newData A SparkDataFrame for testing. -#' @return \code{spark.posterior} returns a SparkDataFrame containing posterior probabilities -#' vectors named "topicDistribution". 
-#' @rdname spark.lda -#' @aliases spark.posterior,LDAModel,SparkDataFrame-method -#' @export -#' @note spark.posterior(LDAModel) since 2.1.0 -setMethod("spark.posterior", signature(object = "LDAModel", newData = "SparkDataFrame"), - function(object, newData) { - predict_internal(object, newData) - }) - -# Returns the summary of a Latent Dirichlet Allocation model produced by \code{spark.lda} - -#' @param object A Latent Dirichlet Allocation model fitted by \code{spark.lda}. -#' @param maxTermsPerTopic Maximum number of terms to collect for each topic. Default value of 10. -#' @return \code{summary} returns summary information of the fitted model, which is a list. -#' The list includes -#' \item{\code{docConcentration}}{concentration parameter commonly named \code{alpha} for -#' the prior placed on documents distributions over topics \code{theta}} -#' \item{\code{topicConcentration}}{concentration parameter commonly named \code{beta} or -#' \code{eta} for the prior placed on topic distributions over terms} -#' \item{\code{logLikelihood}}{log likelihood of the entire corpus} -#' \item{\code{logPerplexity}}{log perplexity} -#' \item{\code{isDistributed}}{TRUE for distributed model while FALSE for local model} -#' \item{\code{vocabSize}}{number of terms in the corpus} -#' \item{\code{topics}}{top 10 terms and their weights of all topics} -#' \item{\code{vocabulary}}{whole terms of the training corpus, NULL if libsvm format file -#' used as training set} -#' @rdname spark.lda -#' @aliases summary,LDAModel-method -#' @export -#' @note summary(LDAModel) since 2.1.0 -setMethod("summary", signature(object = "LDAModel"), - function(object, maxTermsPerTopic) { - maxTermsPerTopic <- as.integer(ifelse(missing(maxTermsPerTopic), 10, maxTermsPerTopic)) - jobj <- object@jobj - docConcentration <- callJMethod(jobj, "docConcentration") - topicConcentration <- callJMethod(jobj, "topicConcentration") - logLikelihood <- callJMethod(jobj, "logLikelihood") - logPerplexity <- callJMethod(jobj, "logPerplexity") - isDistributed <- callJMethod(jobj, "isDistributed") - vocabSize <- callJMethod(jobj, "vocabSize") - topics <- dataFrame(callJMethod(jobj, "topics", maxTermsPerTopic)) - vocabulary <- callJMethod(jobj, "vocabulary") - list(docConcentration = unlist(docConcentration), - topicConcentration = topicConcentration, - logLikelihood = logLikelihood, logPerplexity = logPerplexity, - isDistributed = isDistributed, vocabSize = vocabSize, - topics = topics, vocabulary = unlist(vocabulary)) - }) - -# Returns the log perplexity of a Latent Dirichlet Allocation model produced by \code{spark.lda} - -#' @return \code{spark.perplexity} returns the log perplexity of given SparkDataFrame, or the log -#' perplexity of the training data if missing argument "data". -#' @rdname spark.lda -#' @aliases spark.perplexity,LDAModel-method -#' @export -#' @note spark.perplexity(LDAModel) since 2.1.0 -setMethod("spark.perplexity", signature(object = "LDAModel", data = "SparkDataFrame"), - function(object, data) { - ifelse(missing(data), callJMethod(object@jobj, "logPerplexity"), - callJMethod(object@jobj, "computeLogPerplexity", data@sdf)) - }) - -# Saves the Latent Dirichlet Allocation model to the input path. - -#' @param path The directory where the model is saved. -#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE -#' which means throw exception if the output path exists. 
-#' -#' @rdname spark.lda -#' @aliases write.ml,LDAModel,character-method -#' @export -#' @seealso \link{read.ml} -#' @note write.ml(LDAModel, character) since 2.1.0 -setMethod("write.ml", signature(object = "LDAModel", path = "character"), - function(object, path, overwrite = FALSE) { - write_internal(object, path, overwrite) - }) - -#' Isotonic Regression Model -#' -#' Fits an Isotonic Regression model against a Spark DataFrame, similarly to R's isoreg(). -#' Users can print, make predictions on the produced model and save the model to the input path. -#' -#' @param data SparkDataFrame for training. -#' @param formula A symbolic description of the model to be fitted. Currently only a few formula -#' operators are supported, including '~', '.', ':', '+', and '-'. -#' @param isotonic Whether the output sequence should be isotonic/increasing (TRUE) or -#' antitonic/decreasing (FALSE). -#' @param featureIndex The index of the feature if \code{featuresCol} is a vector column -#' (default: 0), no effect otherwise. -#' @param weightCol The weight column name. -#' @param ... additional arguments passed to the method. -#' @return \code{spark.isoreg} returns a fitted Isotonic Regression model. -#' @rdname spark.isoreg -#' @aliases spark.isoreg,SparkDataFrame,formula-method -#' @name spark.isoreg -#' @export -#' @examples -#' \dontrun{ -#' sparkR.session() -#' data <- list(list(7.0, 0.0), list(5.0, 1.0), list(3.0, 2.0), -#' list(5.0, 3.0), list(1.0, 4.0)) -#' df <- createDataFrame(data, c("label", "feature")) -#' model <- spark.isoreg(df, label ~ feature, isotonic = FALSE) -#' # return model boundaries and prediction as lists -#' result <- summary(model, df) -#' # prediction based on fitted model -#' predict_data <- list(list(-2.0), list(-1.0), list(0.5), -#' list(0.75), list(1.0), list(2.0), list(9.0)) -#' predict_df <- createDataFrame(predict_data, c("feature")) -#' # get prediction column -#' predict_result <- collect(select(predict(model, predict_df), "prediction")) -#' -#' # save fitted model to input path -#' path <- "path/to/model" -#' write.ml(model, path) -#' -#' # can also read back the saved model and print -#' savedModel <- read.ml(path) -#' summary(savedModel) -#' } -#' @note spark.isoreg since 2.1.0 -setMethod("spark.isoreg", signature(data = "SparkDataFrame", formula = "formula"), - function(data, formula, isotonic = TRUE, featureIndex = 0, weightCol = NULL) { - formula <- paste(deparse(formula), collapse = "") - - if (is.null(weightCol)) { - weightCol <- "" - } - - jobj <- callJStatic("org.apache.spark.ml.r.IsotonicRegressionWrapper", "fit", - data@sdf, formula, as.logical(isotonic), as.integer(featureIndex), - as.character(weightCol)) - new("IsotonicRegressionModel", jobj = jobj) - }) - -# Predicted values based on an isotonicRegression model - -#' @param object a fitted IsotonicRegressionModel. -#' @param newData SparkDataFrame for testing. -#' @return \code{predict} returns a SparkDataFrame containing predicted values. -#' @rdname spark.isoreg -#' @aliases predict,IsotonicRegressionModel,SparkDataFrame-method -#' @export -#' @note predict(IsotonicRegressionModel) since 2.1.0 -setMethod("predict", signature(object = "IsotonicRegressionModel"), - function(object, newData) { - predict_internal(object, newData) - }) - -# Get the summary of an IsotonicRegressionModel model - -#' @return \code{summary} returns summary information of the fitted model, which is a list. 
-#' The list includes model's \code{boundaries} (boundaries in increasing order) -#' and \code{predictions} (predictions associated with the boundaries at the same index). -#' @rdname spark.isoreg -#' @aliases summary,IsotonicRegressionModel-method -#' @export -#' @note summary(IsotonicRegressionModel) since 2.1.0 -setMethod("summary", signature(object = "IsotonicRegressionModel"), - function(object) { - jobj <- object@jobj - boundaries <- callJMethod(jobj, "boundaries") - predictions <- callJMethod(jobj, "predictions") - list(boundaries = boundaries, predictions = predictions) - }) - -#' K-Means Clustering Model -#' -#' Fits a k-means clustering model against a Spark DataFrame, similarly to R's kmeans(). -#' Users can call \code{summary} to print a summary of the fitted model, \code{predict} to make -#' predictions on new data, and \code{write.ml}/\code{read.ml} to save/load fitted models. -#' -#' @param data a SparkDataFrame for training. -#' @param formula a symbolic description of the model to be fitted. Currently only a few formula -#' operators are supported, including '~', '.', ':', '+', and '-'. -#' Note that the response variable of formula is empty in spark.kmeans. -#' @param k number of centers. -#' @param maxIter maximum iteration number. -#' @param initMode the initialization algorithm choosen to fit the model. -#' @param ... additional argument(s) passed to the method. -#' @return \code{spark.kmeans} returns a fitted k-means model. -#' @rdname spark.kmeans -#' @aliases spark.kmeans,SparkDataFrame,formula-method -#' @name spark.kmeans -#' @export -#' @examples -#' \dontrun{ -#' sparkR.session() -#' data(iris) -#' df <- createDataFrame(iris) -#' model <- spark.kmeans(df, Sepal_Length ~ Sepal_Width, k = 4, initMode = "random") -#' summary(model) -#' -#' # fitted values on training data -#' fitted <- predict(model, df) -#' head(select(fitted, "Sepal_Length", "prediction")) -#' -#' # save fitted model to input path -#' path <- "path/to/model" -#' write.ml(model, path) -#' -#' # can also read back the saved model and print -#' savedModel <- read.ml(path) -#' summary(savedModel) -#' } -#' @note spark.kmeans since 2.0.0 -#' @seealso \link{predict}, \link{read.ml}, \link{write.ml} -setMethod("spark.kmeans", signature(data = "SparkDataFrame", formula = "formula"), - function(data, formula, k = 2, maxIter = 20, initMode = c("k-means||", "random")) { - formula <- paste(deparse(formula), collapse = "") - initMode <- match.arg(initMode) - jobj <- callJStatic("org.apache.spark.ml.r.KMeansWrapper", "fit", data@sdf, formula, - as.integer(k), as.integer(maxIter), initMode) - new("KMeansModel", jobj = jobj) - }) - -#' Get fitted result from a k-means model -#' -#' Get fitted result from a k-means model, similarly to R's fitted(). -#' Note: A saved-loaded model does not support this method. -#' -#' @param object a fitted k-means model. -#' @param method type of fitted results, \code{"centers"} for cluster centers -#' or \code{"classes"} for assigned classes. -#' @param ... additional argument(s) passed to the method. -#' @return \code{fitted} returns a SparkDataFrame containing fitted values. 
-#' @rdname fitted -#' @export -#' @examples -#' \dontrun{ -#' model <- spark.kmeans(trainingData, ~ ., 2) -#' fitted.model <- fitted(model) -#' showDF(fitted.model) -#'} -#' @note fitted since 2.0.0 -setMethod("fitted", signature(object = "KMeansModel"), - function(object, method = c("centers", "classes")) { - method <- match.arg(method) - jobj <- object@jobj - is.loaded <- callJMethod(jobj, "isLoaded") - if (is.loaded) { - stop("Saved-loaded k-means model does not support 'fitted' method") - } else { - dataFrame(callJMethod(jobj, "fitted", method)) - } - }) - -# Get the summary of a k-means model - -#' @param object a fitted k-means model. -#' @return \code{summary} returns summary information of the fitted model, which is a list. -#' The list includes the model's \code{k} (number of cluster centers), -#' \code{coefficients} (model cluster centers), -#' \code{size} (number of data points in each cluster), and \code{cluster} -#' (cluster centers of the transformed data). -#' @rdname spark.kmeans -#' @export -#' @note summary(KMeansModel) since 2.0.0 -setMethod("summary", signature(object = "KMeansModel"), - function(object) { - jobj <- object@jobj - is.loaded <- callJMethod(jobj, "isLoaded") - features <- callJMethod(jobj, "features") - coefficients <- callJMethod(jobj, "coefficients") - k <- callJMethod(jobj, "k") - size <- callJMethod(jobj, "size") - coefficients <- t(matrix(coefficients, ncol = k)) - colnames(coefficients) <- unlist(features) - rownames(coefficients) <- 1:k - cluster <- if (is.loaded) { - NULL - } else { - dataFrame(callJMethod(jobj, "cluster")) - } - list(k = k, coefficients = coefficients, size = size, - cluster = cluster, is.loaded = is.loaded) - }) - -# Predicted values based on a k-means model - -#' @param newData a SparkDataFrame for testing. -#' @return \code{predict} returns the predicted values based on a k-means model. -#' @rdname spark.kmeans -#' @export -#' @note predict(KMeansModel) since 2.0.0 -setMethod("predict", signature(object = "KMeansModel"), - function(object, newData) { - predict_internal(object, newData) - }) - -#' Logistic Regression Model -#' -#' Fits an logistic regression model against a Spark DataFrame. It supports "binomial": Binary logistic regression -#' with pivoting; "multinomial": Multinomial logistic (softmax) regression without pivoting, similar to glmnet. -#' Users can print, make predictions on the produced model and save the model to the input path. -#' -#' @param data SparkDataFrame for training. -#' @param formula A symbolic description of the model to be fitted. Currently only a few formula -#' operators are supported, including '~', '.', ':', '+', and '-'. -#' @param regParam the regularization parameter. -#' @param elasticNetParam the ElasticNet mixing parameter. For alpha = 0.0, the penalty is an L2 penalty. -#' For alpha = 1.0, it is an L1 penalty. For 0.0 < alpha < 1.0, the penalty is a combination -#' of L1 and L2. Default is 0.0 which is an L2 penalty. -#' @param maxIter maximum iteration number. -#' @param tol convergence tolerance of iterations. -#' @param family the name of family which is a description of the label distribution to be used in the model. -#' Supported options: -#' \itemize{ -#' \item{"auto": Automatically select the family based on the number of classes: -#' If number of classes == 1 || number of classes == 2, set to "binomial". 
-#' Else, set to "multinomial".} -#' \item{"binomial": Binary logistic regression with pivoting.} -#' \item{"multinomial": Multinomial logistic (softmax) regression without pivoting.} -#' } -#' @param standardization whether to standardize the training features before fitting the model. The coefficients -#' of models will be always returned on the original scale, so it will be transparent for -#' users. Note that with/without standardization, the models should be always converged -#' to the same solution when no regularization is applied. Default is TRUE, same as glmnet. -#' @param thresholds in binary classification, in range [0, 1]. If the estimated probability of class label 1 -#' is > threshold, then predict 1, else 0. A high threshold encourages the model to predict 0 -#' more often; a low threshold encourages the model to predict 1 more often. Note: Setting this with -#' threshold p is equivalent to setting thresholds c(1-p, p). In multiclass (or binary) classification to adjust the probability of -#' predicting each class. Array must have length equal to the number of classes, with values > 0, -#' excepting that at most one value may be 0. The class with largest value p/t is predicted, where p -#' is the original probability of that class and t is the class's threshold. -#' @param weightCol The weight column name. -#' @param ... additional arguments passed to the method. -#' @return \code{spark.logit} returns a fitted logistic regression model. -#' @rdname spark.logit -#' @aliases spark.logit,SparkDataFrame,formula-method -#' @name spark.logit -#' @export -#' @examples -#' \dontrun{ -#' sparkR.session() -#' # binary logistic regression -#' df <- createDataFrame(iris) -#' training <- df[df$Species %in% c("versicolor", "virginica"), ] -#' model <- spark.logit(training, Species ~ ., regParam = 0.5) -#' summary <- summary(model) -#' -#' # fitted values on training data -#' fitted <- predict(model, training) -#' -#' # save fitted model to input path -#' path <- "path/to/model" -#' write.ml(model, path) -#' -#' # can also read back the saved model and predict -#' # Note that summary deos not work on loaded model -#' savedModel <- read.ml(path) -#' summary(savedModel) -#' -#' # multinomial logistic regression -#' -#' df <- createDataFrame(iris) -#' model <- spark.logit(df, Species ~ ., regParam = 0.5) -#' summary <- summary(model) -#' -#' } -#' @note spark.logit since 2.1.0 -setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula"), - function(data, formula, regParam = 0.0, elasticNetParam = 0.0, maxIter = 100, - tol = 1E-6, family = "auto", standardization = TRUE, - thresholds = 0.5, weightCol = NULL) { - formula <- paste(deparse(formula), collapse = "") - - if (is.null(weightCol)) { - weightCol <- "" - } - - jobj <- callJStatic("org.apache.spark.ml.r.LogisticRegressionWrapper", "fit", - data@sdf, formula, as.numeric(regParam), - as.numeric(elasticNetParam), as.integer(maxIter), - as.numeric(tol), as.character(family), - as.logical(standardization), as.array(thresholds), - as.character(weightCol)) - new("LogisticRegressionModel", jobj = jobj) - }) - -# Predicted values based on an LogisticRegressionModel model - -#' @param newData a SparkDataFrame for testing. -#' @return \code{predict} returns the predicted values based on an LogisticRegressionModel. 
-#' @rdname spark.logit -#' @aliases predict,LogisticRegressionModel,SparkDataFrame-method -#' @export -#' @note predict(LogisticRegressionModel) since 2.1.0 -setMethod("predict", signature(object = "LogisticRegressionModel"), - function(object, newData) { - predict_internal(object, newData) - }) - -# Get the summary of an LogisticRegressionModel - -#' @param object an LogisticRegressionModel fitted by \code{spark.logit}. -#' @return \code{summary} returns summary information of the fitted model, which is a list. -#' The list includes \code{coefficients} (coefficients matrix of the fitted model). -#' @rdname spark.logit -#' @aliases summary,LogisticRegressionModel-method -#' @export -#' @note summary(LogisticRegressionModel) since 2.1.0 -setMethod("summary", signature(object = "LogisticRegressionModel"), - function(object) { - jobj <- object@jobj - features <- callJMethod(jobj, "rFeatures") - labels <- callJMethod(jobj, "labels") - coefficients <- callJMethod(jobj, "rCoefficients") - nCol <- length(coefficients) / length(features) - coefficients <- matrix(coefficients, ncol = nCol) - # If nCol == 1, means this is a binomial logistic regression model with pivoting. - # Otherwise, it's a multinomial logistic regression model without pivoting. - if (nCol == 1) { - colnames(coefficients) <- c("Estimate") - } else { - colnames(coefficients) <- unlist(labels) - } - rownames(coefficients) <- unlist(features) - - list(coefficients = coefficients) - }) - -#' Multilayer Perceptron Classification Model -#' -#' \code{spark.mlp} fits a multi-layer perceptron neural network model against a SparkDataFrame. -#' Users can call \code{summary} to print a summary of the fitted model, \code{predict} to make -#' predictions on new data, and \code{write.ml}/\code{read.ml} to save/load fitted models. -#' Only categorical data is supported. -#' For more details, see -#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html}{ -#' Multilayer Perceptron} -#' -#' @param data a \code{SparkDataFrame} of observations and labels for model fitting. -#' @param formula a symbolic description of the model to be fitted. Currently only a few formula -#' operators are supported, including '~', '.', ':', '+', and '-'. -#' @param blockSize blockSize parameter. -#' @param layers integer vector containing the number of nodes for each layer. -#' @param solver solver parameter, supported options: "gd" (minibatch gradient descent) or "l-bfgs". -#' @param maxIter maximum iteration number. -#' @param tol convergence tolerance of iterations. -#' @param stepSize stepSize parameter. -#' @param seed seed parameter for weights initialization. -#' @param initialWeights initialWeights parameter for weights initialization, it should be a -#' numeric vector. -#' @param ... additional arguments passed to the method. -#' @return \code{spark.mlp} returns a fitted Multilayer Perceptron Classification Model. 
-#' @rdname spark.mlp -#' @aliases spark.mlp,SparkDataFrame,formula-method -#' @name spark.mlp -#' @seealso \link{read.ml} -#' @export -#' @examples -#' \dontrun{ -#' df <- read.df("data/mllib/sample_multiclass_classification_data.txt", source = "libsvm") -#' -#' # fit a Multilayer Perceptron Classification Model -#' model <- spark.mlp(df, label ~ features, blockSize = 128, layers = c(4, 3), solver = "l-bfgs", -#' maxIter = 100, tol = 0.5, stepSize = 1, seed = 1, -#' initialWeights = c(0, 0, 0, 0, 0, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9)) -#' -#' # get the summary of the model -#' summary(model) -#' -#' # make predictions -#' predictions <- predict(model, df) -#' -#' # save and load the model -#' path <- "path/to/model" -#' write.ml(model, path) -#' savedModel <- read.ml(path) -#' summary(savedModel) -#' } -#' @note spark.mlp since 2.1.0 -setMethod("spark.mlp", signature(data = "SparkDataFrame", formula = "formula"), - function(data, formula, layers, blockSize = 128, solver = "l-bfgs", maxIter = 100, - tol = 1E-6, stepSize = 0.03, seed = NULL, initialWeights = NULL) { - formula <- paste(deparse(formula), collapse = "") - if (is.null(layers)) { - stop ("layers must be a integer vector with length > 1.") - } - layers <- as.integer(na.omit(layers)) - if (length(layers) <= 1) { - stop ("layers must be a integer vector with length > 1.") - } - if (!is.null(seed)) { - seed <- as.character(as.integer(seed)) - } - if (!is.null(initialWeights)) { - initialWeights <- as.array(as.numeric(na.omit(initialWeights))) - } - jobj <- callJStatic("org.apache.spark.ml.r.MultilayerPerceptronClassifierWrapper", - "fit", data@sdf, formula, as.integer(blockSize), as.array(layers), - as.character(solver), as.integer(maxIter), as.numeric(tol), - as.numeric(stepSize), seed, initialWeights) - new("MultilayerPerceptronClassificationModel", jobj = jobj) - }) - -# Makes predictions from a model produced by spark.mlp(). - -#' @param newData a SparkDataFrame for testing. -#' @return \code{predict} returns a SparkDataFrame containing predicted labeled in a column named -#' "prediction". -#' @rdname spark.mlp -#' @aliases predict,MultilayerPerceptronClassificationModel-method -#' @export -#' @note predict(MultilayerPerceptronClassificationModel) since 2.1.0 -setMethod("predict", signature(object = "MultilayerPerceptronClassificationModel"), - function(object, newData) { - predict_internal(object, newData) - }) - -# Returns the summary of a Multilayer Perceptron Classification Model produced by \code{spark.mlp} - -#' @param object a Multilayer Perceptron Classification Model fitted by \code{spark.mlp} -#' @return \code{summary} returns summary information of the fitted model, which is a list. -#' The list includes \code{numOfInputs} (number of inputs), \code{numOfOutputs} -#' (number of outputs), \code{layers} (array of layer sizes including input -#' and output layers), and \code{weights} (the weights of layers). -#' For \code{weights}, it is a numeric vector with length equal to the expected -#' given the architecture (i.e., for 8-10-2 network, 112 connection weights). 
-#' @rdname spark.mlp -#' @export -#' @aliases summary,MultilayerPerceptronClassificationModel-method -#' @note summary(MultilayerPerceptronClassificationModel) since 2.1.0 -setMethod("summary", signature(object = "MultilayerPerceptronClassificationModel"), - function(object) { - jobj <- object@jobj - layers <- unlist(callJMethod(jobj, "layers")) - numOfInputs <- head(layers, n = 1) - numOfOutputs <- tail(layers, n = 1) - weights <- callJMethod(jobj, "weights") - list(numOfInputs = numOfInputs, numOfOutputs = numOfOutputs, - layers = layers, weights = weights) - }) - -#' Naive Bayes Models -#' -#' \code{spark.naiveBayes} fits a Bernoulli naive Bayes model against a SparkDataFrame. -#' Users can call \code{summary} to print a summary of the fitted model, \code{predict} to make -#' predictions on new data, and \code{write.ml}/\code{read.ml} to save/load fitted models. -#' Only categorical data is supported. -#' -#' @param data a \code{SparkDataFrame} of observations and labels for model fitting. -#' @param formula a symbolic description of the model to be fitted. Currently only a few formula -#' operators are supported, including '~', '.', ':', '+', and '-'. -#' @param smoothing smoothing parameter. -#' @param ... additional argument(s) passed to the method. Currently only \code{smoothing}. -#' @return \code{spark.naiveBayes} returns a fitted naive Bayes model. -#' @rdname spark.naiveBayes -#' @aliases spark.naiveBayes,SparkDataFrame,formula-method -#' @name spark.naiveBayes -#' @seealso e1071: \url{https://cran.r-project.org/package=e1071} -#' @export -#' @examples -#' \dontrun{ -#' data <- as.data.frame(UCBAdmissions) -#' df <- createDataFrame(data) -#' -#' # fit a Bernoulli naive Bayes model -#' model <- spark.naiveBayes(df, Admit ~ Gender + Dept, smoothing = 0) -#' -#' # get the summary of the model -#' summary(model) -#' -#' # make predictions -#' predictions <- predict(model, df) -#' -#' # save and load the model -#' path <- "path/to/model" -#' write.ml(model, path) -#' savedModel <- read.ml(path) -#' summary(savedModel) -#' } -#' @note spark.naiveBayes since 2.0.0 -setMethod("spark.naiveBayes", signature(data = "SparkDataFrame", formula = "formula"), - function(data, formula, smoothing = 1.0) { - formula <- paste(deparse(formula), collapse = "") - jobj <- callJStatic("org.apache.spark.ml.r.NaiveBayesWrapper", "fit", - formula, data@sdf, smoothing) - new("NaiveBayesModel", jobj = jobj) - }) - -# Saves the Bernoulli naive Bayes model to the input path. - -#' @param path the directory where the model is saved. -#' @param overwrite overwrites or not if the output path already exists. Default is FALSE -#' which means throw exception if the output path exists. -#' -#' @rdname spark.naiveBayes -#' @export -#' @seealso \link{write.ml} -#' @note write.ml(NaiveBayesModel, character) since 2.0.0 -setMethod("write.ml", signature(object = "NaiveBayesModel", path = "character"), - function(object, path, overwrite = FALSE) { - write_internal(object, path, overwrite) - }) - -# Saves the AFT survival regression model to the input path. - -#' @param path the directory where the model is saved. -#' @param overwrite overwrites or not if the output path already exists. Default is FALSE -#' which means throw exception if the output path exists. 
-#' @rdname spark.survreg -#' @export -#' @note write.ml(AFTSurvivalRegressionModel, character) since 2.0.0 -#' @seealso \link{write.ml} -setMethod("write.ml", signature(object = "AFTSurvivalRegressionModel", path = "character"), - function(object, path, overwrite = FALSE) { - write_internal(object, path, overwrite) - }) - -# Saves the generalized linear model to the input path. - -#' @param path the directory where the model is saved. -#' @param overwrite overwrites or not if the output path already exists. Default is FALSE -#' which means throw exception if the output path exists. -#' -#' @rdname spark.glm -#' @export -#' @note write.ml(GeneralizedLinearRegressionModel, character) since 2.0.0 -setMethod("write.ml", signature(object = "GeneralizedLinearRegressionModel", path = "character"), - function(object, path, overwrite = FALSE) { - write_internal(object, path, overwrite) - }) - -# Save fitted MLlib model to the input path - -#' @param path the directory where the model is saved. -#' @param overwrite overwrites or not if the output path already exists. Default is FALSE -#' which means throw exception if the output path exists. -#' -#' @rdname spark.kmeans -#' @export -#' @note write.ml(KMeansModel, character) since 2.0.0 -setMethod("write.ml", signature(object = "KMeansModel", path = "character"), - function(object, path, overwrite = FALSE) { - write_internal(object, path, overwrite) - }) - -# Saves the Multilayer Perceptron Classification Model to the input path. - -#' @param path the directory where the model is saved. -#' @param overwrite overwrites or not if the output path already exists. Default is FALSE -#' which means throw exception if the output path exists. -#' -#' @rdname spark.mlp -#' @aliases write.ml,MultilayerPerceptronClassificationModel,character-method -#' @export -#' @seealso \link{write.ml} -#' @note write.ml(MultilayerPerceptronClassificationModel, character) since 2.1.0 -setMethod("write.ml", signature(object = "MultilayerPerceptronClassificationModel", - path = "character"), - function(object, path, overwrite = FALSE) { - write_internal(object, path, overwrite) - }) - -# Save fitted IsotonicRegressionModel to the input path - -#' @param path The directory where the model is saved. -#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE -#' which means throw exception if the output path exists. -#' -#' @rdname spark.isoreg -#' @aliases write.ml,IsotonicRegressionModel,character-method -#' @export -#' @note write.ml(IsotonicRegression, character) since 2.1.0 -setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "character"), - function(object, path, overwrite = FALSE) { - write_internal(object, path, overwrite) - }) - -# Save fitted LogisticRegressionModel to the input path - -#' @param path The directory where the model is saved. -#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE -#' which means throw exception if the output path exists. -#' -#' @rdname spark.logit -#' @aliases write.ml,LogisticRegressionModel,character-method -#' @export -#' @note write.ml(LogisticRegression, character) since 2.1.0 -setMethod("write.ml", signature(object = "LogisticRegressionModel", path = "character"), - function(object, path, overwrite = FALSE) { - write_internal(object, path, overwrite) - }) - -# Save fitted MLlib model to the input path - -#' @param path the directory where the model is saved. -#' @param overwrite overwrites or not if the output path already exists. 
Default is FALSE -#' which means throw exception if the output path exists. -#' -#' @aliases write.ml,GaussianMixtureModel,character-method -#' @rdname spark.gaussianMixture -#' @export -#' @note write.ml(GaussianMixtureModel, character) since 2.1.0 -setMethod("write.ml", signature(object = "GaussianMixtureModel", path = "character"), - function(object, path, overwrite = FALSE) { - write_internal(object, path, overwrite) - }) - -#' Load a fitted MLlib model from the input path. -#' -#' @param path path of the model to read. -#' @return A fitted MLlib model. -#' @rdname read.ml -#' @name read.ml -#' @export -#' @seealso \link{write.ml} -#' @examples -#' \dontrun{ -#' path <- "path/to/model" -#' model <- read.ml(path) -#' } -#' @note read.ml since 2.0.0 -read.ml <- function(path) { - path <- suppressWarnings(normalizePath(path)) - jobj <- callJStatic("org.apache.spark.ml.r.RWrappers", "load", path) - if (isInstanceOf(jobj, "org.apache.spark.ml.r.NaiveBayesWrapper")) { - new("NaiveBayesModel", jobj = jobj) - } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.AFTSurvivalRegressionWrapper")) { - new("AFTSurvivalRegressionModel", jobj = jobj) - } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper")) { - new("GeneralizedLinearRegressionModel", jobj = jobj) - } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.KMeansWrapper")) { - new("KMeansModel", jobj = jobj) - } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.LDAWrapper")) { - new("LDAModel", jobj = jobj) - } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.MultilayerPerceptronClassifierWrapper")) { - new("MultilayerPerceptronClassificationModel", jobj = jobj) - } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.IsotonicRegressionWrapper")) { - new("IsotonicRegressionModel", jobj = jobj) - } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GaussianMixtureWrapper")) { - new("GaussianMixtureModel", jobj = jobj) - } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.ALSWrapper")) { - new("ALSModel", jobj = jobj) - } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.LogisticRegressionWrapper")) { - new("LogisticRegressionModel", jobj = jobj) - } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.RandomForestRegressorWrapper")) { - new("RandomForestRegressionModel", jobj = jobj) - } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.RandomForestClassifierWrapper")) { - new("RandomForestClassificationModel", jobj = jobj) - } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GBTRegressorWrapper")) { - new("GBTRegressionModel", jobj = jobj) - } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GBTClassifierWrapper")) { - new("GBTClassificationModel", jobj = jobj) - } else { - stop("Unsupported model: ", jobj) - } -} - -#' Accelerated Failure Time (AFT) Survival Regression Model -#' -#' \code{spark.survreg} fits an accelerated failure time (AFT) survival regression model on -#' a SparkDataFrame. Users can call \code{summary} to get a summary of the fitted AFT model, -#' \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} to -#' save/load fitted models. -#' -#' @param data a SparkDataFrame for training. -#' @param formula a symbolic description of the model to be fitted. Currently only a few formula -#' operators are supported, including '~', ':', '+', and '-'. -#' Note that operator '.' is not supported currently. -#' @return \code{spark.survreg} returns a fitted AFT survival regression model. 
-#' @rdname spark.survreg -#' @seealso survival: \url{https://cran.r-project.org/package=survival} -#' @export -#' @examples -#' \dontrun{ -#' df <- createDataFrame(ovarian) -#' model <- spark.survreg(df, Surv(futime, fustat) ~ ecog_ps + rx) -#' -#' # get a summary of the model -#' summary(model) -#' -#' # make predictions -#' predicted <- predict(model, df) -#' showDF(predicted) -#' -#' # save and load the model -#' path <- "path/to/model" -#' write.ml(model, path) -#' savedModel <- read.ml(path) -#' summary(savedModel) -#' } -#' @note spark.survreg since 2.0.0 -setMethod("spark.survreg", signature(data = "SparkDataFrame", formula = "formula"), - function(data, formula) { - formula <- paste(deparse(formula), collapse = "") - jobj <- callJStatic("org.apache.spark.ml.r.AFTSurvivalRegressionWrapper", - "fit", formula, data@sdf) - new("AFTSurvivalRegressionModel", jobj = jobj) - }) - -#' Latent Dirichlet Allocation -#' -#' \code{spark.lda} fits a Latent Dirichlet Allocation model on a SparkDataFrame. Users can call -#' \code{summary} to get a summary of the fitted LDA model, \code{spark.posterior} to compute -#' posterior probabilities on new data, \code{spark.perplexity} to compute log perplexity on new -#' data and \code{write.ml}/\code{read.ml} to save/load fitted models. -#' -#' @param data A SparkDataFrame for training. -#' @param features Features column name. Either libSVM-format column or character-format column is -#' valid. -#' @param k Number of topics. -#' @param maxIter Maximum iterations. -#' @param optimizer Optimizer to train an LDA model, "online" or "em", default is "online". -#' @param subsamplingRate (For online optimizer) Fraction of the corpus to be sampled and used in -#' each iteration of mini-batch gradient descent, in range (0, 1]. -#' @param topicConcentration concentration parameter (commonly named \code{beta} or \code{eta}) for -#' the prior placed on topic distributions over terms, default -1 to set automatically on the -#' Spark side. Use \code{summary} to retrieve the effective topicConcentration. Only 1-size -#' numeric is accepted. -#' @param docConcentration concentration parameter (commonly named \code{alpha}) for the -#' prior placed on documents distributions over topics (\code{theta}), default -1 to set -#' automatically on the Spark side. Use \code{summary} to retrieve the effective -#' docConcentration. Only 1-size or \code{k}-size numeric is accepted. -#' @param customizedStopWords stopwords that need to be removed from the given corpus. Ignore the -#' parameter if libSVM-format column is used as the features column. -#' @param maxVocabSize maximum vocabulary size, default 1 << 18 -#' @param ... additional argument(s) passed to the method. -#' @return \code{spark.lda} returns a fitted Latent Dirichlet Allocation model. 
-#' @rdname spark.lda -#' @aliases spark.lda,SparkDataFrame-method -#' @seealso topicmodels: \url{https://cran.r-project.org/package=topicmodels} -#' @export -#' @examples -#' \dontrun{ -#' # nolint start -#' # An example "path/to/file" can be -#' # paste0(Sys.getenv("SPARK_HOME"), "/data/mllib/sample_lda_libsvm_data.txt") -#' # nolint end -#' text <- read.df("path/to/file", source = "libsvm") -#' model <- spark.lda(data = text, optimizer = "em") -#' -#' # get a summary of the model -#' summary(model) -#' -#' # compute posterior probabilities -#' posterior <- spark.posterior(model, text) -#' showDF(posterior) -#' -#' # compute perplexity -#' perplexity <- spark.perplexity(model, text) -#' -#' # save and load the model -#' path <- "path/to/model" -#' write.ml(model, path) -#' savedModel <- read.ml(path) -#' summary(savedModel) -#' } -#' @note spark.lda since 2.1.0 -setMethod("spark.lda", signature(data = "SparkDataFrame"), - function(data, features = "features", k = 10, maxIter = 20, optimizer = c("online", "em"), - subsamplingRate = 0.05, topicConcentration = -1, docConcentration = -1, - customizedStopWords = "", maxVocabSize = bitwShiftL(1, 18)) { - optimizer <- match.arg(optimizer) - jobj <- callJStatic("org.apache.spark.ml.r.LDAWrapper", "fit", data@sdf, features, - as.integer(k), as.integer(maxIter), optimizer, - as.numeric(subsamplingRate), topicConcentration, - as.array(docConcentration), as.array(customizedStopWords), - maxVocabSize) - new("LDAModel", jobj = jobj) - }) - -# Returns a summary of the AFT survival regression model produced by spark.survreg, -# similarly to R's summary(). - -#' @param object a fitted AFT survival regression model. -#' @return \code{summary} returns summary information of the fitted model, which is a list. -#' The list includes the model's \code{coefficients} (features, coefficients, -#' intercept and log(scale)). -#' @rdname spark.survreg -#' @export -#' @note summary(AFTSurvivalRegressionModel) since 2.0.0 -setMethod("summary", signature(object = "AFTSurvivalRegressionModel"), - function(object) { - jobj <- object@jobj - features <- callJMethod(jobj, "rFeatures") - coefficients <- callJMethod(jobj, "rCoefficients") - coefficients <- as.matrix(unlist(coefficients)) - colnames(coefficients) <- c("Value") - rownames(coefficients) <- unlist(features) - list(coefficients = coefficients) - }) - -# Makes predictions from an AFT survival regression model or a model produced by -# spark.survreg, similarly to R package survival's predict. - -#' @param newData a SparkDataFrame for testing. -#' @return \code{predict} returns a SparkDataFrame containing predicted values -#' on the original scale of the data (mean predicted value at scale = 1.0). -#' @rdname spark.survreg -#' @export -#' @note predict(AFTSurvivalRegressionModel) since 2.0.0 -setMethod("predict", signature(object = "AFTSurvivalRegressionModel"), - function(object, newData) { - predict_internal(object, newData) - }) - -#' Multivariate Gaussian Mixture Model (GMM) -#' -#' Fits multivariate gaussian mixture model against a Spark DataFrame, similarly to R's -#' mvnormalmixEM(). Users can call \code{summary} to print a summary of the fitted model, -#' \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} -#' to save/load fitted models. -#' -#' @param data a SparkDataFrame for training. -#' @param formula a symbolic description of the model to be fitted. Currently only a few formula -#' operators are supported, including '~', '.', ':', '+', and '-'. 
-#' Note that the response variable of formula is empty in spark.gaussianMixture. -#' @param k number of independent Gaussians in the mixture model. -#' @param maxIter maximum iteration number. -#' @param tol the convergence tolerance. -#' @param ... additional arguments passed to the method. -#' @aliases spark.gaussianMixture,SparkDataFrame,formula-method -#' @return \code{spark.gaussianMixture} returns a fitted multivariate gaussian mixture model. -#' @rdname spark.gaussianMixture -#' @name spark.gaussianMixture -#' @seealso mixtools: \url{https://cran.r-project.org/package=mixtools} -#' @export -#' @examples -#' \dontrun{ -#' sparkR.session() -#' library(mvtnorm) -#' set.seed(100) -#' a <- rmvnorm(4, c(0, 0)) -#' b <- rmvnorm(6, c(3, 4)) -#' data <- rbind(a, b) -#' df <- createDataFrame(as.data.frame(data)) -#' model <- spark.gaussianMixture(df, ~ V1 + V2, k = 2) -#' summary(model) -#' -#' # fitted values on training data -#' fitted <- predict(model, df) -#' head(select(fitted, "V1", "prediction")) -#' -#' # save fitted model to input path -#' path <- "path/to/model" -#' write.ml(model, path) -#' -#' # can also read back the saved model and print -#' savedModel <- read.ml(path) -#' summary(savedModel) -#' } -#' @note spark.gaussianMixture since 2.1.0 -#' @seealso \link{predict}, \link{read.ml}, \link{write.ml} -setMethod("spark.gaussianMixture", signature(data = "SparkDataFrame", formula = "formula"), - function(data, formula, k = 2, maxIter = 100, tol = 0.01) { - formula <- paste(deparse(formula), collapse = "") - jobj <- callJStatic("org.apache.spark.ml.r.GaussianMixtureWrapper", "fit", data@sdf, - formula, as.integer(k), as.integer(maxIter), as.numeric(tol)) - new("GaussianMixtureModel", jobj = jobj) - }) - -# Get the summary of a multivariate gaussian mixture model - -#' @param object a fitted gaussian mixture model. -#' @return \code{summary} returns summary of the fitted model, which is a list. -#' The list includes the model's \code{lambda} (lambda), \code{mu} (mu), -#' \code{sigma} (sigma), and \code{posterior} (posterior). -#' @aliases spark.gaussianMixture,SparkDataFrame,formula-method -#' @rdname spark.gaussianMixture -#' @export -#' @note summary(GaussianMixtureModel) since 2.1.0 -setMethod("summary", signature(object = "GaussianMixtureModel"), - function(object) { - jobj <- object@jobj - is.loaded <- callJMethod(jobj, "isLoaded") - lambda <- unlist(callJMethod(jobj, "lambda")) - muList <- callJMethod(jobj, "mu") - sigmaList <- callJMethod(jobj, "sigma") - k <- callJMethod(jobj, "k") - dim <- callJMethod(jobj, "dim") - mu <- c() - for (i in 1 : k) { - start <- (i - 1) * dim + 1 - end <- i * dim - mu[[i]] <- unlist(muList[start : end]) - } - sigma <- c() - for (i in 1 : k) { - start <- (i - 1) * dim * dim + 1 - end <- i * dim * dim - sigma[[i]] <- t(matrix(sigmaList[start : end], ncol = dim)) - } - posterior <- if (is.loaded) { - NULL - } else { - dataFrame(callJMethod(jobj, "posterior")) - } - list(lambda = lambda, mu = mu, sigma = sigma, - posterior = posterior, is.loaded = is.loaded) - }) - -# Predicted values based on a gaussian mixture model - -#' @param newData a SparkDataFrame for testing. -#' @return \code{predict} returns a SparkDataFrame containing predicted labels in a column named -#' "prediction". 
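The \code{summary} method above rebuilds the per-component means and covariances from the flat vectors returned by the JVM wrapper. A standalone base-R sketch of that reshaping, using invented numbers for a \code{k = 2} mixture over 2-dimensional data:

# Standalone sketch of the reshaping done in summary(GaussianMixtureModel),
# with invented values for a k = 2 mixture over 2-dimensional data.
k <- 2
dim <- 2
muList <- list(0.1, 0.2, 3.0, 4.1)              # k * dim values, concatenated
sigmaList <- as.list(c(1.0, 0.1, 0.1, 1.0,      # k * dim * dim values, concatenated
                       2.0, 0.3, 0.3, 2.0))

mu <- list()
sigma <- list()
for (i in 1:k) {
  mu[[i]] <- unlist(muList[((i - 1) * dim + 1):(i * dim)])
  sigma[[i]] <- t(matrix(unlist(sigmaList[((i - 1) * dim * dim + 1):(i * dim * dim)]),
                         ncol = dim))
}
mu     # list of two length-2 mean vectors
sigma  # list of two 2 x 2 covariance matrices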
-#' @aliases predict,GaussianMixtureModel,SparkDataFrame-method -#' @rdname spark.gaussianMixture -#' @export -#' @note predict(GaussianMixtureModel) since 2.1.0 -setMethod("predict", signature(object = "GaussianMixtureModel"), - function(object, newData) { - predict_internal(object, newData) - }) - -#' Alternating Least Squares (ALS) for Collaborative Filtering -#' -#' \code{spark.als} learns latent factors in collaborative filtering via alternating least -#' squares. Users can call \code{summary} to obtain fitted latent factors, \code{predict} -#' to make predictions on new data, and \code{write.ml}/\code{read.ml} to save/load fitted models. -#' -#' For more details, see -#' \href{http://spark.apache.org/docs/latest/ml-collaborative-filtering.html}{MLlib: -#' Collaborative Filtering}. -#' -#' @param data a SparkDataFrame for training. -#' @param ratingCol column name for ratings. -#' @param userCol column name for user ids. Ids must be (or can be coerced into) integers. -#' @param itemCol column name for item ids. Ids must be (or can be coerced into) integers. -#' @param rank rank of the matrix factorization (> 0). -#' @param regParam regularization parameter (>= 0). -#' @param maxIter maximum number of iterations (>= 0). -#' @param nonnegative logical value indicating whether to apply nonnegativity constraints. -#' @param implicitPrefs logical value indicating whether to use implicit preference. -#' @param alpha alpha parameter in the implicit preference formulation (>= 0). -#' @param seed integer seed for random number generation. -#' @param numUserBlocks number of user blocks used to parallelize computation (> 0). -#' @param numItemBlocks number of item blocks used to parallelize computation (> 0). -#' @param checkpointInterval number of checkpoint intervals (>= 1) or disable checkpoint (-1). -#' @param ... additional argument(s) passed to the method. -#' @return \code{spark.als} returns a fitted ALS model. 
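Because \code{userCol} and \code{itemCol} must hold ids that are (or can be coerced into) integers, string keys need to be re-indexed before calling \code{spark.als}. One possible base-R re-indexing, with invented column names and data:

# Invented example data with string keys; only base R is used here.
ratings <- data.frame(user = c("alice", "alice", "bob", "carol"),
                      item = c("movie1", "movie2", "movie2", "movie3"),
                      rating = c(4.0, 2.0, 3.0, 5.0),
                      stringsAsFactors = FALSE)

# Map each distinct key to a 0-based integer index that spark.als can consume.
ratings$userId <- as.integer(factor(ratings$user)) - 1L
ratings$itemId <- as.integer(factor(ratings$item)) - 1L

# With a running Spark session, the re-indexed frame could then be used as:
# df <- createDataFrame(ratings)
# model <- spark.als(df, "rating", "userId", "itemId")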
-#' @rdname spark.als -#' @aliases spark.als,SparkDataFrame-method -#' @name spark.als -#' @export -#' @examples -#' \dontrun{ -#' ratings <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(1, 2, 4.0), -#' list(2, 1, 1.0), list(2, 2, 5.0)) -#' df <- createDataFrame(ratings, c("user", "item", "rating")) -#' model <- spark.als(df, "rating", "user", "item") -#' -#' # extract latent factors -#' stats <- summary(model) -#' userFactors <- stats$userFactors -#' itemFactors <- stats$itemFactors -#' -#' # make predictions -#' predicted <- predict(model, df) -#' showDF(predicted) -#' -#' # save and load the model -#' path <- "path/to/model" -#' write.ml(model, path) -#' savedModel <- read.ml(path) -#' summary(savedModel) -#' -#' # set other arguments -#' modelS <- spark.als(df, "rating", "user", "item", rank = 20, -#' regParam = 0.1, nonnegative = TRUE) -#' statsS <- summary(modelS) -#' } -#' @note spark.als since 2.1.0 -setMethod("spark.als", signature(data = "SparkDataFrame"), - function(data, ratingCol = "rating", userCol = "user", itemCol = "item", - rank = 10, regParam = 0.1, maxIter = 10, nonnegative = FALSE, - implicitPrefs = FALSE, alpha = 1.0, numUserBlocks = 10, numItemBlocks = 10, - checkpointInterval = 10, seed = 0) { - - if (!is.numeric(rank) || rank <= 0) { - stop("rank should be a positive number.") - } - if (!is.numeric(regParam) || regParam < 0) { - stop("regParam should be a nonnegative number.") - } - if (!is.numeric(maxIter) || maxIter <= 0) { - stop("maxIter should be a positive number.") - } - - jobj <- callJStatic("org.apache.spark.ml.r.ALSWrapper", - "fit", data@sdf, ratingCol, userCol, itemCol, as.integer(rank), - regParam, as.integer(maxIter), implicitPrefs, alpha, nonnegative, - as.integer(numUserBlocks), as.integer(numItemBlocks), - as.integer(checkpointInterval), as.integer(seed)) - new("ALSModel", jobj = jobj) - }) - -# Returns a summary of the ALS model produced by spark.als. - -#' @param object a fitted ALS model. -#' @return \code{summary} returns summary information of the fitted model, which is a list. -#' The list includes \code{user} (the names of the user column), -#' \code{item} (the item column), \code{rating} (the rating column), \code{userFactors} -#' (the estimated user factors), \code{itemFactors} (the estimated item factors), -#' and \code{rank} (rank of the matrix factorization model). -#' @rdname spark.als -#' @aliases summary,ALSModel-method -#' @export -#' @note summary(ALSModel) since 2.1.0 -setMethod("summary", signature(object = "ALSModel"), - function(object) { - jobj <- object@jobj - user <- callJMethod(jobj, "userCol") - item <- callJMethod(jobj, "itemCol") - rating <- callJMethod(jobj, "ratingCol") - userFactors <- dataFrame(callJMethod(jobj, "userFactors")) - itemFactors <- dataFrame(callJMethod(jobj, "itemFactors")) - rank <- callJMethod(jobj, "rank") - list(user = user, item = item, rating = rating, userFactors = userFactors, - itemFactors = itemFactors, rank = rank) - }) - - -# Makes predictions from an ALS model or a model produced by spark.als. - -#' @param newData a SparkDataFrame for testing. -#' @return \code{predict} returns a SparkDataFrame containing predicted values. -#' @rdname spark.als -#' @aliases predict,ALSModel-method -#' @export -#' @note predict(ALSModel) since 2.1.0 -setMethod("predict", signature(object = "ALSModel"), - function(object, newData) { - predict_internal(object, newData) - }) - - -# Saves the ALS model to the input path. - -#' @param path the directory where the model is saved. 
-#' @param overwrite logical value indicating whether to overwrite if the output path -#' already exists. Default is FALSE which means throw exception -#' if the output path exists. -#' -#' @rdname spark.als -#' @aliases write.ml,ALSModel,character-method -#' @export -#' @seealso \link{read.ml} -#' @note write.ml(ALSModel, character) since 2.1.0 -setMethod("write.ml", signature(object = "ALSModel", path = "character"), - function(object, path, overwrite = FALSE) { - write_internal(object, path, overwrite) - }) - -#' (One-Sample) Kolmogorov-Smirnov Test -#' -#' @description -#' \code{spark.kstest} Conduct the two-sided Kolmogorov-Smirnov (KS) test for data sampled from a -#' continuous distribution. -#' -#' By comparing the largest difference between the empirical cumulative -#' distribution of the sample data and the theoretical distribution we can provide a test for the -#' the null hypothesis that the sample data comes from that theoretical distribution. -#' -#' Users can call \code{summary} to obtain a summary of the test, and \code{print.summary.KSTest} -#' to print out a summary result. -#' -#' @param data a SparkDataFrame of user data. -#' @param testCol column name where the test data is from. It should be a column of double type. -#' @param nullHypothesis name of the theoretical distribution tested against. Currently only -#' \code{"norm"} for normal distribution is supported. -#' @param distParams parameters(s) of the distribution. For \code{nullHypothesis = "norm"}, -#' we can provide as a vector the mean and standard deviation of -#' the distribution. If none is provided, then standard normal will be used. -#' If only one is provided, then the standard deviation will be set to be one. -#' @param ... additional argument(s) passed to the method. -#' @return \code{spark.kstest} returns a test result object. -#' @rdname spark.kstest -#' @aliases spark.kstest,SparkDataFrame-method -#' @name spark.kstest -#' @seealso \href{http://spark.apache.org/docs/latest/mllib-statistics.html#hypothesis-testing}{ -#' MLlib: Hypothesis Testing} -#' @export -#' @examples -#' \dontrun{ -#' data <- data.frame(test = c(0.1, 0.15, 0.2, 0.3, 0.25)) -#' df <- createDataFrame(data) -#' test <- spark.kstest(df, "test", "norm", c(0, 1)) -#' -#' # get a summary of the test result -#' testSummary <- summary(test) -#' testSummary -#' -#' # print out the summary in an organized way -#' print.summary.KSTest(testSummary) -#' } -#' @note spark.kstest since 2.1.0 -setMethod("spark.kstest", signature(data = "SparkDataFrame"), - function(data, testCol = "test", nullHypothesis = c("norm"), distParams = c(0, 1)) { - tryCatch(match.arg(nullHypothesis), - error = function(e) { - msg <- paste("Distribution", nullHypothesis, "is not supported.") - stop(msg) - }) - if (nullHypothesis == "norm") { - distParams <- as.numeric(distParams) - mu <- ifelse(length(distParams) < 1, 0, distParams[1]) - sigma <- ifelse(length(distParams) < 2, 1, distParams[2]) - jobj <- callJStatic("org.apache.spark.ml.r.KSTestWrapper", - "test", data@sdf, testCol, nullHypothesis, - as.array(c(mu, sigma))) - new("KSTest", jobj = jobj) - } -}) - -# Get the summary of Kolmogorov-Smirnov (KS) Test. -#' @param object test result object of KSTest by \code{spark.kstest}. -#' @return \code{summary} returns summary information of KSTest object, which is a list. 
-#' The list includes the \code{p.value} (p-value), \code{statistic} (test statistic -#' computed for the test), \code{nullHypothesis} (the null hypothesis with its -#' parameters tested against) and \code{degreesOfFreedom} (degrees of freedom of the test). -#' @rdname spark.kstest -#' @aliases summary,KSTest-method -#' @export -#' @note summary(KSTest) since 2.1.0 -setMethod("summary", signature(object = "KSTest"), - function(object) { - jobj <- object@jobj - pValue <- callJMethod(jobj, "pValue") - statistic <- callJMethod(jobj, "statistic") - nullHypothesis <- callJMethod(jobj, "nullHypothesis") - distName <- callJMethod(jobj, "distName") - distParams <- unlist(callJMethod(jobj, "distParams")) - degreesOfFreedom <- callJMethod(jobj, "degreesOfFreedom") - - ans <- list(p.value = pValue, statistic = statistic, nullHypothesis = nullHypothesis, - nullHypothesis.name = distName, nullHypothesis.parameters = distParams, - degreesOfFreedom = degreesOfFreedom, jobj = jobj) - class(ans) <- "summary.KSTest" - ans - }) - -# Prints the summary of KSTest - -#' @rdname spark.kstest -#' @param x summary object of KSTest returned by \code{summary}. -#' @export -#' @note print.summary.KSTest since 2.1.0 -print.summary.KSTest <- function(x, ...) { - jobj <- x$jobj - summaryStr <- callJMethod(jobj, "summary") - cat(summaryStr, "\n") - invisible(x) -} - -#' Random Forest Model for Regression and Classification -#' -#' \code{spark.randomForest} fits a Random Forest Regression model or Classification model on -#' a SparkDataFrame. Users can call \code{summary} to get a summary of the fitted Random Forest -#' model, \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} to -#' save/load fitted models. -#' For more details, see -#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#random-forest-regression}{ -#' Random Forest Regression} and -#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#random-forest-classifier}{ -#' Random Forest Classification} -#' -#' @param data a SparkDataFrame for training. -#' @param formula a symbolic description of the model to be fitted. Currently only a few formula -#' operators are supported, including '~', ':', '+', and '-'. -#' @param type type of model, one of "regression" or "classification", to fit -#' @param maxDepth Maximum depth of the tree (>= 0). -#' @param maxBins Maximum number of bins used for discretizing continuous features and for choosing -#' how to split on features at each node. More bins give higher granularity. Must be -#' >= 2 and >= number of categories in any categorical feature. -#' @param numTrees Number of trees to train (>= 1). -#' @param impurity Criterion used for information gain calculation. -#' For regression, must be "variance". For classification, must be one of -#' "entropy" and "gini", default is "gini". -#' @param featureSubsetStrategy The number of features to consider for splits at each tree node. -#' Supported options: "auto", "all", "onethird", "sqrt", "log2", (0.0-1.0], [1-n]. -#' @param seed integer seed for random number generation. -#' @param subsamplingRate Fraction of the training data used for learning each decision tree, in -#' range (0, 1]. -#' @param minInstancesPerNode Minimum number of instances each child must have after split. -#' @param minInfoGain Minimum information gain for a split to be considered at a tree node. -#' @param checkpointInterval Param for set checkpoint interval (>= 1) or disable checkpoint (-1). 
-#' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. -#' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with -#' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching -#' can speed up training of deeper trees. Users can set how often should the -#' cache be checkpointed or disable it by setting checkpointInterval. -#' @param ... additional arguments passed to the method. -#' @aliases spark.randomForest,SparkDataFrame,formula-method -#' @return \code{spark.randomForest} returns a fitted Random Forest model. -#' @rdname spark.randomForest -#' @name spark.randomForest -#' @export -#' @examples -#' \dontrun{ -#' # fit a Random Forest Regression Model -#' df <- createDataFrame(longley) -#' model <- spark.randomForest(df, Employed ~ ., type = "regression", maxDepth = 5, maxBins = 16) -#' -#' # get the summary of the model -#' summary(model) -#' -#' # make predictions -#' predictions <- predict(model, df) -#' -#' # save and load the model -#' path <- "path/to/model" -#' write.ml(model, path) -#' savedModel <- read.ml(path) -#' summary(savedModel) -#' -#' # fit a Random Forest Classification Model -#' df <- createDataFrame(iris) -#' model <- spark.randomForest(df, Species ~ Petal_Length + Petal_Width, "classification") -#' } -#' @note spark.randomForest since 2.1.0 -setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "formula"), - function(data, formula, type = c("regression", "classification"), - maxDepth = 5, maxBins = 32, numTrees = 20, impurity = NULL, - featureSubsetStrategy = "auto", seed = NULL, subsamplingRate = 1.0, - minInstancesPerNode = 1, minInfoGain = 0.0, checkpointInterval = 10, - maxMemoryInMB = 256, cacheNodeIds = FALSE) { - type <- match.arg(type) - formula <- paste(deparse(formula), collapse = "") - if (!is.null(seed)) { - seed <- as.character(as.integer(seed)) - } - switch(type, - regression = { - if (is.null(impurity)) impurity <- "variance" - impurity <- match.arg(impurity, "variance") - jobj <- callJStatic("org.apache.spark.ml.r.RandomForestRegressorWrapper", - "fit", data@sdf, formula, as.integer(maxDepth), - as.integer(maxBins), as.integer(numTrees), - impurity, as.integer(minInstancesPerNode), - as.numeric(minInfoGain), as.integer(checkpointInterval), - as.character(featureSubsetStrategy), seed, - as.numeric(subsamplingRate), - as.integer(maxMemoryInMB), as.logical(cacheNodeIds)) - new("RandomForestRegressionModel", jobj = jobj) - }, - classification = { - if (is.null(impurity)) impurity <- "gini" - impurity <- match.arg(impurity, c("gini", "entropy")) - jobj <- callJStatic("org.apache.spark.ml.r.RandomForestClassifierWrapper", - "fit", data@sdf, formula, as.integer(maxDepth), - as.integer(maxBins), as.integer(numTrees), - impurity, as.integer(minInstancesPerNode), - as.numeric(minInfoGain), as.integer(checkpointInterval), - as.character(featureSubsetStrategy), seed, - as.numeric(subsamplingRate), - as.integer(maxMemoryInMB), as.logical(cacheNodeIds)) - new("RandomForestClassificationModel", jobj = jobj) - } - ) - }) - -# Makes predictions from a Random Forest Regression model or Classification model - -#' @param newData a SparkDataFrame for testing. -#' @return \code{predict} returns a SparkDataFrame containing predicted labeled in a column named -#' "prediction". 
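When \code{impurity} is omitted, the implementation above falls back to a type-specific default ("variance" for regression, "gini" for classification), and \code{featureSubsetStrategy} accepts either a named strategy or a number passed as a string. A usage sketch along the lines of the examples above, assuming a running Spark session:

# Sketch only: assumes sparkR.session() has been called.
df <- createDataFrame(longley)

# impurity omitted: the regression branch falls back to "variance".
m1 <- spark.randomForest(df, Employed ~ ., type = "regression")

# featureSubsetStrategy may be a keyword ("sqrt") or a fraction of the features ("0.5").
m2 <- spark.randomForest(df, Employed ~ ., type = "regression",
                         featureSubsetStrategy = "sqrt")

df2 <- createDataFrame(iris)
# impurity omitted: the classification branch falls back to "gini".
m3 <- spark.randomForest(df2, Species ~ Petal_Length + Petal_Width, "classification")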
-#' @rdname spark.randomForest -#' @aliases predict,RandomForestRegressionModel-method -#' @export -#' @note predict(RandomForestRegressionModel) since 2.1.0 -setMethod("predict", signature(object = "RandomForestRegressionModel"), - function(object, newData) { - predict_internal(object, newData) - }) - -#' @rdname spark.randomForest -#' @aliases predict,RandomForestClassificationModel-method -#' @export -#' @note predict(RandomForestClassificationModel) since 2.1.0 -setMethod("predict", signature(object = "RandomForestClassificationModel"), - function(object, newData) { - predict_internal(object, newData) - }) - -# Save the Random Forest Regression or Classification model to the input path. - -#' @param object A fitted Random Forest regression model or classification model. -#' @param path The directory where the model is saved. -#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE -#' which means throw exception if the output path exists. -#' -#' @aliases write.ml,RandomForestRegressionModel,character-method -#' @rdname spark.randomForest -#' @export -#' @note write.ml(RandomForestRegressionModel, character) since 2.1.0 -setMethod("write.ml", signature(object = "RandomForestRegressionModel", path = "character"), - function(object, path, overwrite = FALSE) { - write_internal(object, path, overwrite) - }) - -#' @aliases write.ml,RandomForestClassificationModel,character-method -#' @rdname spark.randomForest -#' @export -#' @note write.ml(RandomForestClassificationModel, character) since 2.1.0 -setMethod("write.ml", signature(object = "RandomForestClassificationModel", path = "character"), - function(object, path, overwrite = FALSE) { - write_internal(object, path, overwrite) - }) - -# Create the summary of a tree ensemble model (eg. Random Forest, GBT) -summary.treeEnsemble <- function(model) { - jobj <- model@jobj - formula <- callJMethod(jobj, "formula") - numFeatures <- callJMethod(jobj, "numFeatures") - features <- callJMethod(jobj, "features") - featureImportances <- callJMethod(callJMethod(jobj, "featureImportances"), "toString") - numTrees <- callJMethod(jobj, "numTrees") - treeWeights <- callJMethod(jobj, "treeWeights") - list(formula = formula, - numFeatures = numFeatures, - features = features, - featureImportances = featureImportances, - numTrees = numTrees, - treeWeights = treeWeights, - jobj = jobj) -} - -# Get the summary of a Random Forest Regression Model - -#' @return \code{summary} returns summary information of the fitted model, which is a list. -#' The list of components includes \code{formula} (formula), -#' \code{numFeatures} (number of features), \code{features} (list of features), -#' \code{featureImportances} (feature importances), \code{numTrees} (number of trees), -#' and \code{treeWeights} (tree weights). 
-#' @rdname spark.randomForest -#' @aliases summary,RandomForestRegressionModel-method -#' @export -#' @note summary(RandomForestRegressionModel) since 2.1.0 -setMethod("summary", signature(object = "RandomForestRegressionModel"), - function(object) { - ans <- summary.treeEnsemble(object) - class(ans) <- "summary.RandomForestRegressionModel" - ans - }) - -# Get the summary of a Random Forest Classification Model - -#' @rdname spark.randomForest -#' @aliases summary,RandomForestClassificationModel-method -#' @export -#' @note summary(RandomForestClassificationModel) since 2.1.0 -setMethod("summary", signature(object = "RandomForestClassificationModel"), - function(object) { - ans <- summary.treeEnsemble(object) - class(ans) <- "summary.RandomForestClassificationModel" - ans - }) - -# Prints the summary of tree ensemble models (eg. Random Forest, GBT) -print.summary.treeEnsemble <- function(x) { - jobj <- x$jobj - cat("Formula: ", x$formula) - cat("\nNumber of features: ", x$numFeatures) - cat("\nFeatures: ", unlist(x$features)) - cat("\nFeature importances: ", x$featureImportances) - cat("\nNumber of trees: ", x$numTrees) - cat("\nTree weights: ", unlist(x$treeWeights)) - - summaryStr <- callJMethod(jobj, "summary") - cat("\n", summaryStr, "\n") - invisible(x) -} - -# Prints the summary of Random Forest Regression Model - -#' @param x summary object of Random Forest regression model or classification model -#' returned by \code{summary}. -#' @rdname spark.randomForest -#' @export -#' @note print.summary.RandomForestRegressionModel since 2.1.0 -print.summary.RandomForestRegressionModel <- function(x, ...) { - print.summary.treeEnsemble(x) -} - -# Prints the summary of Random Forest Classification Model - -#' @rdname spark.randomForest -#' @export -#' @note print.summary.RandomForestClassificationModel since 2.1.0 -print.summary.RandomForestClassificationModel <- function(x, ...) { - print.summary.treeEnsemble(x) -} - -#' Gradient Boosted Tree Model for Regression and Classification -#' -#' \code{spark.gbt} fits a Gradient Boosted Tree Regression model or Classification model on a -#' SparkDataFrame. Users can call \code{summary} to get a summary of the fitted -#' Gradient Boosted Tree model, \code{predict} to make predictions on new data, and -#' \code{write.ml}/\code{read.ml} to save/load fitted models. -#' For more details, see -#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#gradient-boosted-tree-regression}{ -#' GBT Regression} and -#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#gradient-boosted-tree-classifier}{ -#' GBT Classification} -#' -#' @param data a SparkDataFrame for training. -#' @param formula a symbolic description of the model to be fitted. Currently only a few formula -#' operators are supported, including '~', ':', '+', and '-'. -#' @param type type of model, one of "regression" or "classification", to fit -#' @param maxDepth Maximum depth of the tree (>= 0). -#' @param maxBins Maximum number of bins used for discretizing continuous features and for choosing -#' how to split on features at each node. More bins give higher granularity. Must be -#' >= 2 and >= number of categories in any categorical feature. -#' @param maxIter Param for maximum number of iterations (>= 0). -#' @param stepSize Param for Step size to be used for each iteration of optimization. -#' @param lossType Loss function which GBT tries to minimize. -#' For classification, must be "logistic". 
For regression, must be one of -#' "squared" (L2) and "absolute" (L1), default is "squared". -#' @param seed integer seed for random number generation. -#' @param subsamplingRate Fraction of the training data used for learning each decision tree, in -#' range (0, 1]. -#' @param minInstancesPerNode Minimum number of instances each child must have after split. If a -#' split causes the left or right child to have fewer than -#' minInstancesPerNode, the split will be discarded as invalid. Should be -#' >= 1. -#' @param minInfoGain Minimum information gain for a split to be considered at a tree node. -#' @param checkpointInterval Param for set checkpoint interval (>= 1) or disable checkpoint (-1). -#' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. -#' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with -#' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching -#' can speed up training of deeper trees. Users can set how often should the -#' cache be checkpointed or disable it by setting checkpointInterval. -#' @param ... additional arguments passed to the method. -#' @aliases spark.gbt,SparkDataFrame,formula-method -#' @return \code{spark.gbt} returns a fitted Gradient Boosted Tree model. -#' @rdname spark.gbt -#' @name spark.gbt -#' @export -#' @examples -#' \dontrun{ -#' # fit a Gradient Boosted Tree Regression Model -#' df <- createDataFrame(longley) -#' model <- spark.gbt(df, Employed ~ ., type = "regression", maxDepth = 5, maxBins = 16) -#' -#' # get the summary of the model -#' summary(model) -#' -#' # make predictions -#' predictions <- predict(model, df) -#' -#' # save and load the model -#' path <- "path/to/model" -#' write.ml(model, path) -#' savedModel <- read.ml(path) -#' summary(savedModel) -#' -#' # fit a Gradient Boosted Tree Classification Model -#' # label must be binary - Only binary classification is supported for GBT. 
-#' df <- createDataFrame(iris[iris$Species != "virginica", ]) -#' model <- spark.gbt(df, Species ~ Petal_Length + Petal_Width, "classification") -#' -#' # numeric label is also supported -#' iris2 <- iris[iris$Species != "virginica", ] -#' iris2$NumericSpecies <- ifelse(iris2$Species == "setosa", 0, 1) -#' df <- createDataFrame(iris2) -#' model <- spark.gbt(df, NumericSpecies ~ ., type = "classification") -#' } -#' @note spark.gbt since 2.1.0 -setMethod("spark.gbt", signature(data = "SparkDataFrame", formula = "formula"), - function(data, formula, type = c("regression", "classification"), - maxDepth = 5, maxBins = 32, maxIter = 20, stepSize = 0.1, lossType = NULL, - seed = NULL, subsamplingRate = 1.0, minInstancesPerNode = 1, minInfoGain = 0.0, - checkpointInterval = 10, maxMemoryInMB = 256, cacheNodeIds = FALSE) { - type <- match.arg(type) - formula <- paste(deparse(formula), collapse = "") - if (!is.null(seed)) { - seed <- as.character(as.integer(seed)) - } - switch(type, - regression = { - if (is.null(lossType)) lossType <- "squared" - lossType <- match.arg(lossType, c("squared", "absolute")) - jobj <- callJStatic("org.apache.spark.ml.r.GBTRegressorWrapper", - "fit", data@sdf, formula, as.integer(maxDepth), - as.integer(maxBins), as.integer(maxIter), - as.numeric(stepSize), as.integer(minInstancesPerNode), - as.numeric(minInfoGain), as.integer(checkpointInterval), - lossType, seed, as.numeric(subsamplingRate), - as.integer(maxMemoryInMB), as.logical(cacheNodeIds)) - new("GBTRegressionModel", jobj = jobj) - }, - classification = { - if (is.null(lossType)) lossType <- "logistic" - lossType <- match.arg(lossType, "logistic") - jobj <- callJStatic("org.apache.spark.ml.r.GBTClassifierWrapper", - "fit", data@sdf, formula, as.integer(maxDepth), - as.integer(maxBins), as.integer(maxIter), - as.numeric(stepSize), as.integer(minInstancesPerNode), - as.numeric(minInfoGain), as.integer(checkpointInterval), - lossType, seed, as.numeric(subsamplingRate), - as.integer(maxMemoryInMB), as.logical(cacheNodeIds)) - new("GBTClassificationModel", jobj = jobj) - } - ) - }) - -# Makes predictions from a Gradient Boosted Tree Regression model or Classification model - -#' @param newData a SparkDataFrame for testing. -#' @return \code{predict} returns a SparkDataFrame containing predicted labeled in a column named -#' "prediction". -#' @rdname spark.gbt -#' @aliases predict,GBTRegressionModel-method -#' @export -#' @note predict(GBTRegressionModel) since 2.1.0 -setMethod("predict", signature(object = "GBTRegressionModel"), - function(object, newData) { - predict_internal(object, newData) - }) - -#' @rdname spark.gbt -#' @aliases predict,GBTClassificationModel-method -#' @export -#' @note predict(GBTClassificationModel) since 2.1.0 -setMethod("predict", signature(object = "GBTClassificationModel"), - function(object, newData) { - predict_internal(object, newData) - }) - -# Save the Gradient Boosted Tree Regression or Classification model to the input path. - -#' @param object A fitted Gradient Boosted Tree regression model or classification model. -#' @param path The directory where the model is saved. -#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE -#' which means throw exception if the output path exists. 
-#' @aliases write.ml,GBTRegressionModel,character-method -#' @rdname spark.gbt -#' @export -#' @note write.ml(GBTRegressionModel, character) since 2.1.0 -setMethod("write.ml", signature(object = "GBTRegressionModel", path = "character"), - function(object, path, overwrite = FALSE) { - write_internal(object, path, overwrite) - }) - -#' @aliases write.ml,GBTClassificationModel,character-method -#' @rdname spark.gbt -#' @export -#' @note write.ml(GBTClassificationModel, character) since 2.1.0 -setMethod("write.ml", signature(object = "GBTClassificationModel", path = "character"), - function(object, path, overwrite = FALSE) { - write_internal(object, path, overwrite) - }) - -# Get the summary of a Gradient Boosted Tree Regression Model - -#' @return \code{summary} returns summary information of the fitted model, which is a list. -#' The list of components includes \code{formula} (formula), -#' \code{numFeatures} (number of features), \code{features} (list of features), -#' \code{featureImportances} (feature importances), \code{numTrees} (number of trees), -#' and \code{treeWeights} (tree weights). -#' @rdname spark.gbt -#' @aliases summary,GBTRegressionModel-method -#' @export -#' @note summary(GBTRegressionModel) since 2.1.0 -setMethod("summary", signature(object = "GBTRegressionModel"), - function(object) { - ans <- summary.treeEnsemble(object) - class(ans) <- "summary.GBTRegressionModel" - ans - }) - -# Get the summary of a Gradient Boosted Tree Classification Model - -#' @rdname spark.gbt -#' @aliases summary,GBTClassificationModel-method -#' @export -#' @note summary(GBTClassificationModel) since 2.1.0 -setMethod("summary", signature(object = "GBTClassificationModel"), - function(object) { - ans <- summary.treeEnsemble(object) - class(ans) <- "summary.GBTClassificationModel" - ans - }) - -# Prints the summary of Gradient Boosted Tree Regression Model - -#' @param x summary object of Gradient Boosted Tree regression model or classification model -#' returned by \code{summary}. -#' @rdname spark.gbt -#' @export -#' @note print.summary.GBTRegressionModel since 2.1.0 -print.summary.GBTRegressionModel <- function(x, ...) { - print.summary.treeEnsemble(x) -} - -# Prints the summary of Gradient Boosted Tree Classification Model - -#' @rdname spark.gbt -#' @export -#' @note print.summary.GBTClassificationModel since 2.1.0 -print.summary.GBTClassificationModel <- function(x, ...) { - print.summary.treeEnsemble(x) -} diff --git a/R/pkg/R/mllib_classification.R b/R/pkg/R/mllib_classification.R new file mode 100644 index 0000000000000..8da84499dfc09 --- /dev/null +++ b/R/pkg/R/mllib_classification.R @@ -0,0 +1,417 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+
+# mllib_classification.R: Provides methods for MLlib classification algorithms
+# (except for tree-based algorithms) integration
+
+#' S4 class that represents a LogisticRegressionModel
+#'
+#' @param jobj a Java object reference to the backing Scala LogisticRegressionModel
+#' @export
+#' @note LogisticRegressionModel since 2.1.0
+setClass("LogisticRegressionModel", representation(jobj = "jobj"))
+
+#' S4 class that represents a MultilayerPerceptronClassificationModel
+#'
+#' @param jobj a Java object reference to the backing Scala MultilayerPerceptronClassifierWrapper
+#' @export
+#' @note MultilayerPerceptronClassificationModel since 2.1.0
+setClass("MultilayerPerceptronClassificationModel", representation(jobj = "jobj"))
+
+#' S4 class that represents a NaiveBayesModel
+#'
+#' @param jobj a Java object reference to the backing Scala NaiveBayesWrapper
+#' @export
+#' @note NaiveBayesModel since 2.0.0
+setClass("NaiveBayesModel", representation(jobj = "jobj"))
+
+#' Logistic Regression Model
+#'
+#' Fits a logistic regression model against a Spark DataFrame. It supports "binomial": binary
+#' logistic regression with pivoting; and "multinomial": multinomial logistic (softmax) regression
+#' without pivoting, similar to glmnet. Users can print, make predictions on the produced model,
+#' and save the model to the input path.
+#'
+#' @param data SparkDataFrame for training.
+#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
+#' operators are supported, including '~', '.', ':', '+', and '-'.
+#' @param regParam the regularization parameter.
+#' @param elasticNetParam the ElasticNet mixing parameter. For alpha = 0.0, the penalty is an L2 penalty.
+#' For alpha = 1.0, it is an L1 penalty. For 0.0 < alpha < 1.0, the penalty is a combination
+#' of L1 and L2. Default is 0.0, which is an L2 penalty.
+#' @param maxIter maximum iteration number.
+#' @param tol convergence tolerance of iterations.
+#' @param family the name of the family, which is a description of the label distribution to be used
+#' in the model. Supported options:
+#' \itemize{
+#' \item{"auto": Automatically select the family based on the number of classes:
+#' If number of classes == 1 || number of classes == 2, set to "binomial".
+#' Else, set to "multinomial".}
+#' \item{"binomial": Binary logistic regression with pivoting.}
+#' \item{"multinomial": Multinomial logistic (softmax) regression without pivoting.}
+#' }
+#' @param standardization whether to standardize the training features before fitting the model. The
+#' coefficients of models are always returned on the original scale, so it is transparent to
+#' users. Note that with or without standardization, the models should always converge to the
+#' same solution when no regularization is applied. Default is TRUE, same as glmnet.
+#' @param thresholds in binary classification, in range [0, 1]. If the estimated probability of class
+#' label 1 is > threshold, then predict 1, else 0. A high threshold encourages the model to
+#' predict 0 more often; a low threshold encourages the model to predict 1 more often. Note:
+#' Setting this with a single threshold p is equivalent to setting thresholds c(1-p, p). In
+#' multiclass (or binary) classification, thresholds is used to adjust the probability of
+#' predicting each class. The array must have length equal to the number of classes, with
+#' values > 0, excepting that at most one value may be 0. The class with the largest value
+#' p/t is predicted, where p is the original probability of that class and t is the class's
+#' threshold.
+#' @param weightCol The weight column name.
+#' @param ... additional arguments passed to the method.
+#' @return \code{spark.logit} returns a fitted logistic regression model.
+#' @rdname spark.logit
+#' @aliases spark.logit,SparkDataFrame,formula-method
+#' @name spark.logit
+#' @export
+#' @examples
+#' \dontrun{
+#' sparkR.session()
+#' # binary logistic regression
+#' df <- createDataFrame(iris)
+#' training <- df[df$Species %in% c("versicolor", "virginica"), ]
+#' model <- spark.logit(training, Species ~ ., regParam = 0.5)
+#' summary <- summary(model)
+#'
+#' # fitted values on training data
+#' fitted <- predict(model, training)
+#'
+#' # save fitted model to input path
+#' path <- "path/to/model"
+#' write.ml(model, path)
+#'
+#' # can also read back the saved model and predict
+#' # Note that summary does not work on a loaded model
+#' savedModel <- read.ml(path)
+#' summary(savedModel)
+#'
+#' # multinomial logistic regression
+#'
+#' df <- createDataFrame(iris)
+#' model <- spark.logit(df, Species ~ ., regParam = 0.5)
+#' summary <- summary(model)
+#'
+#' }
+#' @note spark.logit since 2.1.0
+setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula"),
+          function(data, formula, regParam = 0.0, elasticNetParam = 0.0, maxIter = 100,
+                   tol = 1E-6, family = "auto", standardization = TRUE,
+                   thresholds = 0.5, weightCol = NULL) {
+            formula <- paste(deparse(formula), collapse = "")
+
+            if (is.null(weightCol)) {
+              weightCol <- ""
+            }
+
+            jobj <- callJStatic("org.apache.spark.ml.r.LogisticRegressionWrapper", "fit",
+                                data@sdf, formula, as.numeric(regParam),
+                                as.numeric(elasticNetParam), as.integer(maxIter),
+                                as.numeric(tol), as.character(family),
+                                as.logical(standardization), as.array(thresholds),
+                                as.character(weightCol))
+            new("LogisticRegressionModel", jobj = jobj)
+          })
+
+# Get the summary of a LogisticRegressionModel
+
+#' @param object a LogisticRegressionModel fitted by \code{spark.logit}.
+#' @return \code{summary} returns summary information of the fitted model, which is a list.
+#' The list includes \code{coefficients} (coefficients matrix of the fitted model).
+#' @rdname spark.logit
+#' @aliases summary,LogisticRegressionModel-method
+#' @export
+#' @note summary(LogisticRegressionModel) since 2.1.0
+setMethod("summary", signature(object = "LogisticRegressionModel"),
+          function(object) {
+            jobj <- object@jobj
+            features <- callJMethod(jobj, "rFeatures")
+            labels <- callJMethod(jobj, "labels")
+            coefficients <- callJMethod(jobj, "rCoefficients")
+            nCol <- length(coefficients) / length(features)
+            coefficients <- matrix(coefficients, ncol = nCol)
+            # If nCol == 1, this is a binomial logistic regression model with pivoting.
+            # Otherwise, it is a multinomial logistic regression model without pivoting.
+            if (nCol == 1) {
+              colnames(coefficients) <- c("Estimate")
+            } else {
+              colnames(coefficients) <- unlist(labels)
+            }
+            rownames(coefficients) <- unlist(features)
+
+            list(coefficients = coefficients)
+          })
+
+# Predicted values based on a LogisticRegressionModel
+
+#' @param newData a SparkDataFrame for testing.
+#' @return \code{predict} returns the predicted values based on a LogisticRegressionModel.
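The multiclass \code{thresholds} rule described above (predict the class whose p/t is largest) can be traced with a few lines of base R; the class probabilities and thresholds below are invented for illustration:

# Invented class probabilities for one row, and per-class thresholds.
probs <- c(setosa = 0.20, versicolor = 0.45, virginica = 0.35)
thresholds <- c(0.50, 0.50, 0.25)

# The class with the largest p/t wins, so a low threshold favors its class:
# here 0.35 / 0.25 = 1.4 beats 0.45 / 0.50 = 0.9, and "virginica" is predicted.
names(probs)[which.max(probs / thresholds)]

# In the binary case, a single threshold p behaves like thresholds c(1 - p, p).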
+#' @rdname spark.logit +#' @aliases predict,LogisticRegressionModel,SparkDataFrame-method +#' @export +#' @note predict(LogisticRegressionModel) since 2.1.0 +setMethod("predict", signature(object = "LogisticRegressionModel"), + function(object, newData) { + predict_internal(object, newData) + }) + +# Save fitted LogisticRegressionModel to the input path + +#' @param path The directory where the model is saved. +#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE +#' which means throw exception if the output path exists. +#' +#' @rdname spark.logit +#' @aliases write.ml,LogisticRegressionModel,character-method +#' @export +#' @note write.ml(LogisticRegression, character) since 2.1.0 +setMethod("write.ml", signature(object = "LogisticRegressionModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) + +#' Multilayer Perceptron Classification Model +#' +#' \code{spark.mlp} fits a multi-layer perceptron neural network model against a SparkDataFrame. +#' Users can call \code{summary} to print a summary of the fitted model, \code{predict} to make +#' predictions on new data, and \code{write.ml}/\code{read.ml} to save/load fitted models. +#' Only categorical data is supported. +#' For more details, see +#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html}{ +#' Multilayer Perceptron} +#' +#' @param data a \code{SparkDataFrame} of observations and labels for model fitting. +#' @param formula a symbolic description of the model to be fitted. Currently only a few formula +#' operators are supported, including '~', '.', ':', '+', and '-'. +#' @param blockSize blockSize parameter. +#' @param layers integer vector containing the number of nodes for each layer. +#' @param solver solver parameter, supported options: "gd" (minibatch gradient descent) or "l-bfgs". +#' @param maxIter maximum iteration number. +#' @param tol convergence tolerance of iterations. +#' @param stepSize stepSize parameter. +#' @param seed seed parameter for weights initialization. +#' @param initialWeights initialWeights parameter for weights initialization, it should be a +#' numeric vector. +#' @param ... additional arguments passed to the method. +#' @return \code{spark.mlp} returns a fitted Multilayer Perceptron Classification Model. 
+#' @rdname spark.mlp
+#' @aliases spark.mlp,SparkDataFrame,formula-method
+#' @name spark.mlp
+#' @seealso \link{read.ml}
+#' @export
+#' @examples
+#' \dontrun{
+#' df <- read.df("data/mllib/sample_multiclass_classification_data.txt", source = "libsvm")
+#'
+#' # fit a Multilayer Perceptron Classification Model
+#' model <- spark.mlp(df, label ~ features, blockSize = 128, layers = c(4, 3), solver = "l-bfgs",
+#'                    maxIter = 100, tol = 0.5, stepSize = 1, seed = 1,
+#'                    initialWeights = c(0, 0, 0, 0, 0, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9))
+#'
+#' # get the summary of the model
+#' summary(model)
+#'
+#' # make predictions
+#' predictions <- predict(model, df)
+#'
+#' # save and load the model
+#' path <- "path/to/model"
+#' write.ml(model, path)
+#' savedModel <- read.ml(path)
+#' summary(savedModel)
+#' }
+#' @note spark.mlp since 2.1.0
+setMethod("spark.mlp", signature(data = "SparkDataFrame", formula = "formula"),
+          function(data, formula, layers, blockSize = 128, solver = "l-bfgs", maxIter = 100,
+                   tol = 1E-6, stepSize = 0.03, seed = NULL, initialWeights = NULL) {
+            formula <- paste(deparse(formula), collapse = "")
+            if (is.null(layers)) {
+              stop("layers must be an integer vector with length > 1.")
+            }
+            layers <- as.integer(na.omit(layers))
+            if (length(layers) <= 1) {
+              stop("layers must be an integer vector with length > 1.")
+            }
+            if (!is.null(seed)) {
+              seed <- as.character(as.integer(seed))
+            }
+            if (!is.null(initialWeights)) {
+              initialWeights <- as.array(as.numeric(na.omit(initialWeights)))
+            }
+            jobj <- callJStatic("org.apache.spark.ml.r.MultilayerPerceptronClassifierWrapper",
+                                "fit", data@sdf, formula, as.integer(blockSize), as.array(layers),
+                                as.character(solver), as.integer(maxIter), as.numeric(tol),
+                                as.numeric(stepSize), seed, initialWeights)
+            new("MultilayerPerceptronClassificationModel", jobj = jobj)
+          })
+
+# Returns the summary of a Multilayer Perceptron Classification Model produced by \code{spark.mlp}
+
+#' @param object a Multilayer Perceptron Classification Model fitted by \code{spark.mlp}
+#' @return \code{summary} returns summary information of the fitted model, which is a list.
+#' The list includes \code{numOfInputs} (number of inputs), \code{numOfOutputs}
+#' (number of outputs), \code{layers} (array of layer sizes including input
+#' and output layers), and \code{weights} (the weights of layers).
+#' \code{weights} is a numeric vector whose length equals the number of connections
+#' expected for the given architecture (e.g., 112 connection weights for an 8-10-2 network).
+#' @rdname spark.mlp
+#' @export
+#' @aliases summary,MultilayerPerceptronClassificationModel-method
+#' @note summary(MultilayerPerceptronClassificationModel) since 2.1.0
+setMethod("summary", signature(object = "MultilayerPerceptronClassificationModel"),
+          function(object) {
+            jobj <- object@jobj
+            layers <- unlist(callJMethod(jobj, "layers"))
+            numOfInputs <- head(layers, n = 1)
+            numOfOutputs <- tail(layers, n = 1)
+            weights <- callJMethod(jobj, "weights")
+            list(numOfInputs = numOfInputs, numOfOutputs = numOfOutputs,
+                 layers = layers, weights = weights)
+          })
+
+# Makes predictions from a model produced by spark.mlp().
+
+#' @param newData a SparkDataFrame for testing.
+#' @return \code{predict} returns a SparkDataFrame containing predicted labels in a column named
+#' "prediction".
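The \code{summary} documentation above states that \code{weights} has 112 entries for an 8-10-2 network. That count follows from each pair of adjacent layers contributing (inputs + 1 bias) * outputs connections; the hypothetical helper below checks it in base R:

# Number of connection weights for a fully connected MLP with bias terms:
# each pair of adjacent layers contributes (n_in + 1) * n_out weights.
countWeights <- function(layers) {
  sum((head(layers, -1) + 1) * tail(layers, -1))
}

countWeights(c(8, 10, 2))   # (8 + 1) * 10 + (10 + 1) * 2 = 112
countWeights(c(4, 3))       # 15, matching the initialWeights length in the example above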
+#' @rdname spark.mlp +#' @aliases predict,MultilayerPerceptronClassificationModel-method +#' @export +#' @note predict(MultilayerPerceptronClassificationModel) since 2.1.0 +setMethod("predict", signature(object = "MultilayerPerceptronClassificationModel"), + function(object, newData) { + predict_internal(object, newData) + }) + +# Saves the Multilayer Perceptron Classification Model to the input path. + +#' @param path the directory where the model is saved. +#' @param overwrite overwrites or not if the output path already exists. Default is FALSE +#' which means throw exception if the output path exists. +#' +#' @rdname spark.mlp +#' @aliases write.ml,MultilayerPerceptronClassificationModel,character-method +#' @export +#' @seealso \link{write.ml} +#' @note write.ml(MultilayerPerceptronClassificationModel, character) since 2.1.0 +setMethod("write.ml", signature(object = "MultilayerPerceptronClassificationModel", + path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) + +#' Naive Bayes Models +#' +#' \code{spark.naiveBayes} fits a Bernoulli naive Bayes model against a SparkDataFrame. +#' Users can call \code{summary} to print a summary of the fitted model, \code{predict} to make +#' predictions on new data, and \code{write.ml}/\code{read.ml} to save/load fitted models. +#' Only categorical data is supported. +#' +#' @param data a \code{SparkDataFrame} of observations and labels for model fitting. +#' @param formula a symbolic description of the model to be fitted. Currently only a few formula +#' operators are supported, including '~', '.', ':', '+', and '-'. +#' @param smoothing smoothing parameter. +#' @param ... additional argument(s) passed to the method. Currently only \code{smoothing}. +#' @return \code{spark.naiveBayes} returns a fitted naive Bayes model. +#' @rdname spark.naiveBayes +#' @aliases spark.naiveBayes,SparkDataFrame,formula-method +#' @name spark.naiveBayes +#' @seealso e1071: \url{https://cran.r-project.org/package=e1071} +#' @export +#' @examples +#' \dontrun{ +#' data <- as.data.frame(UCBAdmissions) +#' df <- createDataFrame(data) +#' +#' # fit a Bernoulli naive Bayes model +#' model <- spark.naiveBayes(df, Admit ~ Gender + Dept, smoothing = 0) +#' +#' # get the summary of the model +#' summary(model) +#' +#' # make predictions +#' predictions <- predict(model, df) +#' +#' # save and load the model +#' path <- "path/to/model" +#' write.ml(model, path) +#' savedModel <- read.ml(path) +#' summary(savedModel) +#' } +#' @note spark.naiveBayes since 2.0.0 +setMethod("spark.naiveBayes", signature(data = "SparkDataFrame", formula = "formula"), + function(data, formula, smoothing = 1.0) { + formula <- paste(deparse(formula), collapse = "") + jobj <- callJStatic("org.apache.spark.ml.r.NaiveBayesWrapper", "fit", + formula, data@sdf, smoothing) + new("NaiveBayesModel", jobj = jobj) + }) + +# Returns the summary of a naive Bayes model produced by \code{spark.naiveBayes} + +#' @param object a naive Bayes model fitted by \code{spark.naiveBayes}. +#' @return \code{summary} returns summary information of the fitted model, which is a list. +#' The list includes \code{apriori} (the label distribution) and +#' \code{tables} (conditional probabilities given the target label). 
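The \code{apriori} and \code{tables} components described above come back as labeled matrices, with one column per class for \code{apriori} and one row per class for \code{tables}. A base-R sketch of that layout with invented values for a two-class, three-feature model (the label and feature names are hypothetical):

# Invented values showing only the shape of what summary(NaiveBayesModel) returns.
labels <- c("Admitted", "Rejected")
features <- c("GenderFemale", "DeptB", "DeptC")

apriori <- t(as.matrix(c(0.45, 0.55)))
colnames(apriori) <- labels

tables <- matrix(c(0.3, 0.6, 0.2, 0.5, 0.4, 0.1), nrow = length(labels))
rownames(tables) <- labels
colnames(tables) <- features

apriori   # 1 x 2 matrix of class priors
tables    # 2 x 3 matrix of conditional probabilities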
+#' @rdname spark.naiveBayes +#' @export +#' @note summary(NaiveBayesModel) since 2.0.0 +setMethod("summary", signature(object = "NaiveBayesModel"), + function(object) { + jobj <- object@jobj + features <- callJMethod(jobj, "features") + labels <- callJMethod(jobj, "labels") + apriori <- callJMethod(jobj, "apriori") + apriori <- t(as.matrix(unlist(apriori))) + colnames(apriori) <- unlist(labels) + tables <- callJMethod(jobj, "tables") + tables <- matrix(tables, nrow = length(labels)) + rownames(tables) <- unlist(labels) + colnames(tables) <- unlist(features) + list(apriori = apriori, tables = tables) + }) + +# Makes predictions from a naive Bayes model or a model produced by spark.naiveBayes(), +# similarly to R package e1071's predict. + +#' @param newData a SparkDataFrame for testing. +#' @return \code{predict} returns a SparkDataFrame containing predicted labeled in a column named +#' "prediction". +#' @rdname spark.naiveBayes +#' @export +#' @note predict(NaiveBayesModel) since 2.0.0 +setMethod("predict", signature(object = "NaiveBayesModel"), + function(object, newData) { + predict_internal(object, newData) + }) + +# Saves the Bernoulli naive Bayes model to the input path. + +#' @param path the directory where the model is saved. +#' @param overwrite overwrites or not if the output path already exists. Default is FALSE +#' which means throw exception if the output path exists. +#' +#' @rdname spark.naiveBayes +#' @export +#' @seealso \link{write.ml} +#' @note write.ml(NaiveBayesModel, character) since 2.0.0 +setMethod("write.ml", signature(object = "NaiveBayesModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) diff --git a/R/pkg/R/mllib_clustering.R b/R/pkg/R/mllib_clustering.R new file mode 100644 index 0000000000000..c44358838703f --- /dev/null +++ b/R/pkg/R/mllib_clustering.R @@ -0,0 +1,456 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# mllib_clustering.R: Provides methods for MLlib clustering algorithms integration + +#' S4 class that represents a GaussianMixtureModel +#' +#' @param jobj a Java object reference to the backing Scala GaussianMixtureModel +#' @export +#' @note GaussianMixtureModel since 2.1.0 +setClass("GaussianMixtureModel", representation(jobj = "jobj")) + +#' S4 class that represents a KMeansModel +#' +#' @param jobj a Java object reference to the backing Scala KMeansModel +#' @export +#' @note KMeansModel since 2.0.0 +setClass("KMeansModel", representation(jobj = "jobj")) + +#' S4 class that represents an LDAModel +#' +#' @param jobj a Java object reference to the backing Scala LDAWrapper +#' @export +#' @note LDAModel since 2.1.0 +setClass("LDAModel", representation(jobj = "jobj")) + +#' Multivariate Gaussian Mixture Model (GMM) +#' +#' Fits multivariate gaussian mixture model against a Spark DataFrame, similarly to R's +#' mvnormalmixEM(). Users can call \code{summary} to print a summary of the fitted model, +#' \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} +#' to save/load fitted models. +#' +#' @param data a SparkDataFrame for training. +#' @param formula a symbolic description of the model to be fitted. Currently only a few formula +#' operators are supported, including '~', '.', ':', '+', and '-'. +#' Note that the response variable of formula is empty in spark.gaussianMixture. +#' @param k number of independent Gaussians in the mixture model. +#' @param maxIter maximum iteration number. +#' @param tol the convergence tolerance. +#' @param ... additional arguments passed to the method. +#' @aliases spark.gaussianMixture,SparkDataFrame,formula-method +#' @return \code{spark.gaussianMixture} returns a fitted multivariate gaussian mixture model. +#' @rdname spark.gaussianMixture +#' @name spark.gaussianMixture +#' @seealso mixtools: \url{https://cran.r-project.org/package=mixtools} +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' library(mvtnorm) +#' set.seed(100) +#' a <- rmvnorm(4, c(0, 0)) +#' b <- rmvnorm(6, c(3, 4)) +#' data <- rbind(a, b) +#' df <- createDataFrame(as.data.frame(data)) +#' model <- spark.gaussianMixture(df, ~ V1 + V2, k = 2) +#' summary(model) +#' +#' # fitted values on training data +#' fitted <- predict(model, df) +#' head(select(fitted, "V1", "prediction")) +#' +#' # save fitted model to input path +#' path <- "path/to/model" +#' write.ml(model, path) +#' +#' # can also read back the saved model and print +#' savedModel <- read.ml(path) +#' summary(savedModel) +#' } +#' @note spark.gaussianMixture since 2.1.0 +#' @seealso \link{predict}, \link{read.ml}, \link{write.ml} +setMethod("spark.gaussianMixture", signature(data = "SparkDataFrame", formula = "formula"), + function(data, formula, k = 2, maxIter = 100, tol = 0.01) { + formula <- paste(deparse(formula), collapse = "") + jobj <- callJStatic("org.apache.spark.ml.r.GaussianMixtureWrapper", "fit", data@sdf, + formula, as.integer(k), as.integer(maxIter), as.numeric(tol)) + new("GaussianMixtureModel", jobj = jobj) + }) + +# Get the summary of a multivariate gaussian mixture model + +#' @param object a fitted gaussian mixture model. +#' @return \code{summary} returns summary of the fitted model, which is a list. +#' The list includes the model's \code{lambda} (lambda), \code{mu} (mu), +#' \code{sigma} (sigma), and \code{posterior} (posterior). 
+#' @aliases spark.gaussianMixture,SparkDataFrame,formula-method +#' @rdname spark.gaussianMixture +#' @export +#' @note summary(GaussianMixtureModel) since 2.1.0 +setMethod("summary", signature(object = "GaussianMixtureModel"), + function(object) { + jobj <- object@jobj + is.loaded <- callJMethod(jobj, "isLoaded") + lambda <- unlist(callJMethod(jobj, "lambda")) + muList <- callJMethod(jobj, "mu") + sigmaList <- callJMethod(jobj, "sigma") + k <- callJMethod(jobj, "k") + dim <- callJMethod(jobj, "dim") + mu <- c() + for (i in 1 : k) { + start <- (i - 1) * dim + 1 + end <- i * dim + mu[[i]] <- unlist(muList[start : end]) + } + sigma <- c() + for (i in 1 : k) { + start <- (i - 1) * dim * dim + 1 + end <- i * dim * dim + sigma[[i]] <- t(matrix(sigmaList[start : end], ncol = dim)) + } + posterior <- if (is.loaded) { + NULL + } else { + dataFrame(callJMethod(jobj, "posterior")) + } + list(lambda = lambda, mu = mu, sigma = sigma, + posterior = posterior, is.loaded = is.loaded) + }) + +# Predicted values based on a gaussian mixture model + +#' @param newData a SparkDataFrame for testing. +#' @return \code{predict} returns a SparkDataFrame containing predicted labels in a column named +#' "prediction". +#' @aliases predict,GaussianMixtureModel,SparkDataFrame-method +#' @rdname spark.gaussianMixture +#' @export +#' @note predict(GaussianMixtureModel) since 2.1.0 +setMethod("predict", signature(object = "GaussianMixtureModel"), + function(object, newData) { + predict_internal(object, newData) + }) + +# Save fitted MLlib model to the input path + +#' @param path the directory where the model is saved. +#' @param overwrite overwrites or not if the output path already exists. Default is FALSE +#' which means throw exception if the output path exists. +#' +#' @aliases write.ml,GaussianMixtureModel,character-method +#' @rdname spark.gaussianMixture +#' @export +#' @note write.ml(GaussianMixtureModel, character) since 2.1.0 +setMethod("write.ml", signature(object = "GaussianMixtureModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) + +#' K-Means Clustering Model +#' +#' Fits a k-means clustering model against a Spark DataFrame, similarly to R's kmeans(). +#' Users can call \code{summary} to print a summary of the fitted model, \code{predict} to make +#' predictions on new data, and \code{write.ml}/\code{read.ml} to save/load fitted models. +#' +#' @param data a SparkDataFrame for training. +#' @param formula a symbolic description of the model to be fitted. Currently only a few formula +#' operators are supported, including '~', '.', ':', '+', and '-'. +#' Note that the response variable of formula is empty in spark.kmeans. +#' @param k number of centers. +#' @param maxIter maximum iteration number. +#' @param initMode the initialization algorithm choosen to fit the model. +#' @param ... additional argument(s) passed to the method. +#' @return \code{spark.kmeans} returns a fitted k-means model. 
+#' @rdname spark.kmeans +#' @aliases spark.kmeans,SparkDataFrame,formula-method +#' @name spark.kmeans +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' data(iris) +#' df <- createDataFrame(iris) +#' model <- spark.kmeans(df, Sepal_Length ~ Sepal_Width, k = 4, initMode = "random") +#' summary(model) +#' +#' # fitted values on training data +#' fitted <- predict(model, df) +#' head(select(fitted, "Sepal_Length", "prediction")) +#' +#' # save fitted model to input path +#' path <- "path/to/model" +#' write.ml(model, path) +#' +#' # can also read back the saved model and print +#' savedModel <- read.ml(path) +#' summary(savedModel) +#' } +#' @note spark.kmeans since 2.0.0 +#' @seealso \link{predict}, \link{read.ml}, \link{write.ml} +setMethod("spark.kmeans", signature(data = "SparkDataFrame", formula = "formula"), + function(data, formula, k = 2, maxIter = 20, initMode = c("k-means||", "random")) { + formula <- paste(deparse(formula), collapse = "") + initMode <- match.arg(initMode) + jobj <- callJStatic("org.apache.spark.ml.r.KMeansWrapper", "fit", data@sdf, formula, + as.integer(k), as.integer(maxIter), initMode) + new("KMeansModel", jobj = jobj) + }) + +# Get the summary of a k-means model + +#' @param object a fitted k-means model. +#' @return \code{summary} returns summary information of the fitted model, which is a list. +#' The list includes the model's \code{k} (number of cluster centers), +#' \code{coefficients} (model cluster centers), +#' \code{size} (number of data points in each cluster), and \code{cluster} +#' (cluster centers of the transformed data). +#' @rdname spark.kmeans +#' @export +#' @note summary(KMeansModel) since 2.0.0 +setMethod("summary", signature(object = "KMeansModel"), + function(object) { + jobj <- object@jobj + is.loaded <- callJMethod(jobj, "isLoaded") + features <- callJMethod(jobj, "features") + coefficients <- callJMethod(jobj, "coefficients") + k <- callJMethod(jobj, "k") + size <- callJMethod(jobj, "size") + coefficients <- t(matrix(coefficients, ncol = k)) + colnames(coefficients) <- unlist(features) + rownames(coefficients) <- 1:k + cluster <- if (is.loaded) { + NULL + } else { + dataFrame(callJMethod(jobj, "cluster")) + } + list(k = k, coefficients = coefficients, size = size, + cluster = cluster, is.loaded = is.loaded) + }) + +# Predicted values based on a k-means model + +#' @param newData a SparkDataFrame for testing. +#' @return \code{predict} returns the predicted values based on a k-means model. +#' @rdname spark.kmeans +#' @export +#' @note predict(KMeansModel) since 2.0.0 +setMethod("predict", signature(object = "KMeansModel"), + function(object, newData) { + predict_internal(object, newData) + }) + +#' Get fitted result from a k-means model +#' +#' Get fitted result from a k-means model, similarly to R's fitted(). +#' Note: A saved-loaded model does not support this method. +#' +#' @param object a fitted k-means model. +#' @param method type of fitted results, \code{"centers"} for cluster centers +#' or \code{"classes"} for assigned classes. +#' @param ... additional argument(s) passed to the method. +#' @return \code{fitted} returns a SparkDataFrame containing fitted values. 
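A short sketch, not part of this patch, of inspecting a fitted k-means model; df and model are taken from the example above, and fitted() is documented directly below.

stats <- summary(model)
stats$coefficients                        # matrix of cluster centers, one row per cluster
stats$size                                # number of training rows assigned to each cluster
head(fitted(model, method = "classes"))   # cluster assignment for each training row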
+#' @rdname fitted +#' @export +#' @examples +#' \dontrun{ +#' model <- spark.kmeans(trainingData, ~ ., 2) +#' fitted.model <- fitted(model) +#' showDF(fitted.model) +#'} +#' @note fitted since 2.0.0 +setMethod("fitted", signature(object = "KMeansModel"), + function(object, method = c("centers", "classes")) { + method <- match.arg(method) + jobj <- object@jobj + is.loaded <- callJMethod(jobj, "isLoaded") + if (is.loaded) { + stop("Saved-loaded k-means model does not support 'fitted' method") + } else { + dataFrame(callJMethod(jobj, "fitted", method)) + } + }) + +# Save fitted MLlib model to the input path + +#' @param path the directory where the model is saved. +#' @param overwrite overwrites or not if the output path already exists. Default is FALSE +#' which means throw exception if the output path exists. +#' +#' @rdname spark.kmeans +#' @export +#' @note write.ml(KMeansModel, character) since 2.0.0 +setMethod("write.ml", signature(object = "KMeansModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) + +#' Latent Dirichlet Allocation +#' +#' \code{spark.lda} fits a Latent Dirichlet Allocation model on a SparkDataFrame. Users can call +#' \code{summary} to get a summary of the fitted LDA model, \code{spark.posterior} to compute +#' posterior probabilities on new data, \code{spark.perplexity} to compute log perplexity on new +#' data and \code{write.ml}/\code{read.ml} to save/load fitted models. +#' +#' @param data A SparkDataFrame for training. +#' @param features Features column name. Either libSVM-format column or character-format column is +#' valid. +#' @param k Number of topics. +#' @param maxIter Maximum iterations. +#' @param optimizer Optimizer to train an LDA model, "online" or "em", default is "online". +#' @param subsamplingRate (For online optimizer) Fraction of the corpus to be sampled and used in +#' each iteration of mini-batch gradient descent, in range (0, 1]. +#' @param topicConcentration concentration parameter (commonly named \code{beta} or \code{eta}) for +#' the prior placed on topic distributions over terms, default -1 to set automatically on the +#' Spark side. Use \code{summary} to retrieve the effective topicConcentration. Only 1-size +#' numeric is accepted. +#' @param docConcentration concentration parameter (commonly named \code{alpha}) for the +#' prior placed on documents distributions over topics (\code{theta}), default -1 to set +#' automatically on the Spark side. Use \code{summary} to retrieve the effective +#' docConcentration. Only 1-size or \code{k}-size numeric is accepted. +#' @param customizedStopWords stopwords that need to be removed from the given corpus. Ignore the +#' parameter if libSVM-format column is used as the features column. +#' @param maxVocabSize maximum vocabulary size, default 1 << 18 +#' @param ... additional argument(s) passed to the method. +#' @return \code{spark.lda} returns a fitted Latent Dirichlet Allocation model. 
+#' @rdname spark.lda +#' @aliases spark.lda,SparkDataFrame-method +#' @seealso topicmodels: \url{https://cran.r-project.org/package=topicmodels} +#' @export +#' @examples +#' \dontrun{ +#' # nolint start +#' # An example "path/to/file" can be +#' # paste0(Sys.getenv("SPARK_HOME"), "/data/mllib/sample_lda_libsvm_data.txt") +#' # nolint end +#' text <- read.df("path/to/file", source = "libsvm") +#' model <- spark.lda(data = text, optimizer = "em") +#' +#' # get a summary of the model +#' summary(model) +#' +#' # compute posterior probabilities +#' posterior <- spark.posterior(model, text) +#' showDF(posterior) +#' +#' # compute perplexity +#' perplexity <- spark.perplexity(model, text) +#' +#' # save and load the model +#' path <- "path/to/model" +#' write.ml(model, path) +#' savedModel <- read.ml(path) +#' summary(savedModel) +#' } +#' @note spark.lda since 2.1.0 +setMethod("spark.lda", signature(data = "SparkDataFrame"), + function(data, features = "features", k = 10, maxIter = 20, optimizer = c("online", "em"), + subsamplingRate = 0.05, topicConcentration = -1, docConcentration = -1, + customizedStopWords = "", maxVocabSize = bitwShiftL(1, 18)) { + optimizer <- match.arg(optimizer) + jobj <- callJStatic("org.apache.spark.ml.r.LDAWrapper", "fit", data@sdf, features, + as.integer(k), as.integer(maxIter), optimizer, + as.numeric(subsamplingRate), topicConcentration, + as.array(docConcentration), as.array(customizedStopWords), + maxVocabSize) + new("LDAModel", jobj = jobj) + }) + +# Returns the summary of a Latent Dirichlet Allocation model produced by \code{spark.lda} + +#' @param object A Latent Dirichlet Allocation model fitted by \code{spark.lda}. +#' @param maxTermsPerTopic Maximum number of terms to collect for each topic. Default value of 10. +#' @return \code{summary} returns summary information of the fitted model, which is a list. 
+#' The list includes +#' \item{\code{docConcentration}}{concentration parameter commonly named \code{alpha} for +#' the prior placed on documents distributions over topics \code{theta}} +#' \item{\code{topicConcentration}}{concentration parameter commonly named \code{beta} or +#' \code{eta} for the prior placed on topic distributions over terms} +#' \item{\code{logLikelihood}}{log likelihood of the entire corpus} +#' \item{\code{logPerplexity}}{log perplexity} +#' \item{\code{isDistributed}}{TRUE for distributed model while FALSE for local model} +#' \item{\code{vocabSize}}{number of terms in the corpus} +#' \item{\code{topics}}{top 10 terms and their weights of all topics} +#' \item{\code{vocabulary}}{whole terms of the training corpus, NULL if libsvm format file +#' used as training set} +#' @rdname spark.lda +#' @aliases summary,LDAModel-method +#' @export +#' @note summary(LDAModel) since 2.1.0 +setMethod("summary", signature(object = "LDAModel"), + function(object, maxTermsPerTopic) { + maxTermsPerTopic <- as.integer(ifelse(missing(maxTermsPerTopic), 10, maxTermsPerTopic)) + jobj <- object@jobj + docConcentration <- callJMethod(jobj, "docConcentration") + topicConcentration <- callJMethod(jobj, "topicConcentration") + logLikelihood <- callJMethod(jobj, "logLikelihood") + logPerplexity <- callJMethod(jobj, "logPerplexity") + isDistributed <- callJMethod(jobj, "isDistributed") + vocabSize <- callJMethod(jobj, "vocabSize") + topics <- dataFrame(callJMethod(jobj, "topics", maxTermsPerTopic)) + vocabulary <- callJMethod(jobj, "vocabulary") + list(docConcentration = unlist(docConcentration), + topicConcentration = topicConcentration, + logLikelihood = logLikelihood, logPerplexity = logPerplexity, + isDistributed = isDistributed, vocabSize = vocabSize, + topics = topics, vocabulary = unlist(vocabulary)) + }) + +# Returns the log perplexity of a Latent Dirichlet Allocation model produced by \code{spark.lda} + +#' @return \code{spark.perplexity} returns the log perplexity of given SparkDataFrame, or the log +#' perplexity of the training data if missing argument "data". +#' @rdname spark.lda +#' @aliases spark.perplexity,LDAModel-method +#' @export +#' @note spark.perplexity(LDAModel) since 2.1.0 +setMethod("spark.perplexity", signature(object = "LDAModel", data = "SparkDataFrame"), + function(object, data) { + ifelse(missing(data), callJMethod(object@jobj, "logPerplexity"), + callJMethod(object@jobj, "computeLogPerplexity", data@sdf)) + }) + +# Returns posterior probabilities from a Latent Dirichlet Allocation model produced by spark.lda() + +#' @param newData A SparkDataFrame for testing. +#' @return \code{spark.posterior} returns a SparkDataFrame containing posterior probabilities +#' vectors named "topicDistribution". +#' @rdname spark.lda +#' @aliases spark.posterior,LDAModel,SparkDataFrame-method +#' @export +#' @note spark.posterior(LDAModel) since 2.1.0 +setMethod("spark.posterior", signature(object = "LDAModel", newData = "SparkDataFrame"), + function(object, newData) { + predict_internal(object, newData) + }) + +# Saves the Latent Dirichlet Allocation model to the input path. + +#' @param path The directory where the model is saved. +#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE +#' which means throw exception if the output path exists. 
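A brief sketch, not part of this patch, showing how the summary, perplexity and posterior helpers above fit together; model and the text SparkDataFrame are reused from the spark.lda example.

stats <- summary(model, maxTermsPerTopic = 5)
stats$logPerplexity                 # log perplexity on the training corpus
head(stats$topics)                  # SparkDataFrame of top terms per topic
spark.perplexity(model, text)       # log perplexity on a (here identical) held-out corpus
head(spark.posterior(model, text))  # per-document topic distributions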
+#' +#' @rdname spark.lda +#' @aliases write.ml,LDAModel,character-method +#' @export +#' @seealso \link{read.ml} +#' @note write.ml(LDAModel, character) since 2.1.0 +setMethod("write.ml", signature(object = "LDAModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) diff --git a/R/pkg/R/mllib_recommendation.R b/R/pkg/R/mllib_recommendation.R new file mode 100644 index 0000000000000..fa794249085d7 --- /dev/null +++ b/R/pkg/R/mllib_recommendation.R @@ -0,0 +1,162 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# mllib_recommendation.R: Provides methods for MLlib recommendation algorithms integration + +#' S4 class that represents an ALSModel +#' +#' @param jobj a Java object reference to the backing Scala ALSWrapper +#' @export +#' @note ALSModel since 2.1.0 +setClass("ALSModel", representation(jobj = "jobj")) + +#' Alternating Least Squares (ALS) for Collaborative Filtering +#' +#' \code{spark.als} learns latent factors in collaborative filtering via alternating least +#' squares. Users can call \code{summary} to obtain fitted latent factors, \code{predict} +#' to make predictions on new data, and \code{write.ml}/\code{read.ml} to save/load fitted models. +#' +#' For more details, see +#' \href{http://spark.apache.org/docs/latest/ml-collaborative-filtering.html}{MLlib: +#' Collaborative Filtering}. +#' +#' @param data a SparkDataFrame for training. +#' @param ratingCol column name for ratings. +#' @param userCol column name for user ids. Ids must be (or can be coerced into) integers. +#' @param itemCol column name for item ids. Ids must be (or can be coerced into) integers. +#' @param rank rank of the matrix factorization (> 0). +#' @param regParam regularization parameter (>= 0). +#' @param maxIter maximum number of iterations (>= 0). +#' @param nonnegative logical value indicating whether to apply nonnegativity constraints. +#' @param implicitPrefs logical value indicating whether to use implicit preference. +#' @param alpha alpha parameter in the implicit preference formulation (>= 0). +#' @param seed integer seed for random number generation. +#' @param numUserBlocks number of user blocks used to parallelize computation (> 0). +#' @param numItemBlocks number of item blocks used to parallelize computation (> 0). +#' @param checkpointInterval number of checkpoint intervals (>= 1) or disable checkpoint (-1). +#' @param ... additional argument(s) passed to the method. +#' @return \code{spark.als} returns a fitted ALS model. 
+#' @rdname spark.als +#' @aliases spark.als,SparkDataFrame-method +#' @name spark.als +#' @export +#' @examples +#' \dontrun{ +#' ratings <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(1, 2, 4.0), +#' list(2, 1, 1.0), list(2, 2, 5.0)) +#' df <- createDataFrame(ratings, c("user", "item", "rating")) +#' model <- spark.als(df, "rating", "user", "item") +#' +#' # extract latent factors +#' stats <- summary(model) +#' userFactors <- stats$userFactors +#' itemFactors <- stats$itemFactors +#' +#' # make predictions +#' predicted <- predict(model, df) +#' showDF(predicted) +#' +#' # save and load the model +#' path <- "path/to/model" +#' write.ml(model, path) +#' savedModel <- read.ml(path) +#' summary(savedModel) +#' +#' # set other arguments +#' modelS <- spark.als(df, "rating", "user", "item", rank = 20, +#' regParam = 0.1, nonnegative = TRUE) +#' statsS <- summary(modelS) +#' } +#' @note spark.als since 2.1.0 +setMethod("spark.als", signature(data = "SparkDataFrame"), + function(data, ratingCol = "rating", userCol = "user", itemCol = "item", + rank = 10, regParam = 0.1, maxIter = 10, nonnegative = FALSE, + implicitPrefs = FALSE, alpha = 1.0, numUserBlocks = 10, numItemBlocks = 10, + checkpointInterval = 10, seed = 0) { + + if (!is.numeric(rank) || rank <= 0) { + stop("rank should be a positive number.") + } + if (!is.numeric(regParam) || regParam < 0) { + stop("regParam should be a nonnegative number.") + } + if (!is.numeric(maxIter) || maxIter <= 0) { + stop("maxIter should be a positive number.") + } + + jobj <- callJStatic("org.apache.spark.ml.r.ALSWrapper", + "fit", data@sdf, ratingCol, userCol, itemCol, as.integer(rank), + regParam, as.integer(maxIter), implicitPrefs, alpha, nonnegative, + as.integer(numUserBlocks), as.integer(numItemBlocks), + as.integer(checkpointInterval), as.integer(seed)) + new("ALSModel", jobj = jobj) + }) + +# Returns a summary of the ALS model produced by spark.als. + +#' @param object a fitted ALS model. +#' @return \code{summary} returns summary information of the fitted model, which is a list. +#' The list includes \code{user} (the names of the user column), +#' \code{item} (the item column), \code{rating} (the rating column), \code{userFactors} +#' (the estimated user factors), \code{itemFactors} (the estimated item factors), +#' and \code{rank} (rank of the matrix factorization model). +#' @rdname spark.als +#' @aliases summary,ALSModel-method +#' @export +#' @note summary(ALSModel) since 2.1.0 +setMethod("summary", signature(object = "ALSModel"), + function(object) { + jobj <- object@jobj + user <- callJMethod(jobj, "userCol") + item <- callJMethod(jobj, "itemCol") + rating <- callJMethod(jobj, "ratingCol") + userFactors <- dataFrame(callJMethod(jobj, "userFactors")) + itemFactors <- dataFrame(callJMethod(jobj, "itemFactors")) + rank <- callJMethod(jobj, "rank") + list(user = user, item = item, rating = rating, userFactors = userFactors, + itemFactors = itemFactors, rank = rank) + }) + +# Makes predictions from an ALS model or a model produced by spark.als. + +#' @param newData a SparkDataFrame for testing. +#' @return \code{predict} returns a SparkDataFrame containing predicted values. +#' @rdname spark.als +#' @aliases predict,ALSModel-method +#' @export +#' @note predict(ALSModel) since 2.1.0 +setMethod("predict", signature(object = "ALSModel"), + function(object, newData) { + predict_internal(object, newData) + }) + +# Saves the ALS model to the input path. + +#' @param path the directory where the model is saved. 
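An illustrative sketch, not part of this patch, of pulling the learned factors out of the ALS summary list; df and model are reused from the spark.als example above.

stats <- summary(model)
stats$rank                # rank of the factorization
head(stats$userFactors)   # SparkDataFrame of per-user latent factors
head(stats$itemFactors)   # SparkDataFrame of per-item latent factors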
+#' @param overwrite logical value indicating whether to overwrite if the output path +#' already exists. Default is FALSE which means throw exception +#' if the output path exists. +#' +#' @rdname spark.als +#' @aliases write.ml,ALSModel,character-method +#' @export +#' @seealso \link{read.ml} +#' @note write.ml(ALSModel, character) since 2.1.0 +setMethod("write.ml", signature(object = "ALSModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) diff --git a/R/pkg/R/mllib_regression.R b/R/pkg/R/mllib_regression.R new file mode 100644 index 0000000000000..a480168a29126 --- /dev/null +++ b/R/pkg/R/mllib_regression.R @@ -0,0 +1,448 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# mllib_regression.R: Provides methods for MLlib regression algorithms +# (except for tree-based algorithms) integration + +#' S4 class that represents a AFTSurvivalRegressionModel +#' +#' @param jobj a Java object reference to the backing Scala AFTSurvivalRegressionWrapper +#' @export +#' @note AFTSurvivalRegressionModel since 2.0.0 +setClass("AFTSurvivalRegressionModel", representation(jobj = "jobj")) + +#' S4 class that represents a generalized linear model +#' +#' @param jobj a Java object reference to the backing Scala GeneralizedLinearRegressionWrapper +#' @export +#' @note GeneralizedLinearRegressionModel since 2.0.0 +setClass("GeneralizedLinearRegressionModel", representation(jobj = "jobj")) + +#' S4 class that represents an IsotonicRegressionModel +#' +#' @param jobj a Java object reference to the backing Scala IsotonicRegressionModel +#' @export +#' @note IsotonicRegressionModel since 2.1.0 +setClass("IsotonicRegressionModel", representation(jobj = "jobj")) + +#' Generalized Linear Models +#' +#' Fits generalized linear model against a Spark DataFrame. +#' Users can call \code{summary} to print a summary of the fitted model, \code{predict} to make +#' predictions on new data, and \code{write.ml}/\code{read.ml} to save/load fitted models. +#' +#' @param data a SparkDataFrame for training. +#' @param formula a symbolic description of the model to be fitted. Currently only a few formula +#' operators are supported, including '~', '.', ':', '+', and '-'. +#' @param family a description of the error distribution and link function to be used in the model. +#' This can be a character string naming a family function, a family function or +#' the result of a call to a family function. Refer R family at +#' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}. +#' @param tol positive convergence tolerance of iterations. +#' @param maxIter integer giving the maximal number of IRLS iterations. +#' @param weightCol the weight column name. 
If this is not set or \code{NULL}, we treat all instance +#' weights as 1.0. +#' @param regParam regularization parameter for L2 regularization. +#' @param ... additional arguments passed to the method. +#' @aliases spark.glm,SparkDataFrame,formula-method +#' @return \code{spark.glm} returns a fitted generalized linear model. +#' @rdname spark.glm +#' @name spark.glm +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' data(iris) +#' df <- createDataFrame(iris) +#' model <- spark.glm(df, Sepal_Length ~ Sepal_Width, family = "gaussian") +#' summary(model) +#' +#' # fitted values on training data +#' fitted <- predict(model, df) +#' head(select(fitted, "Sepal_Length", "prediction")) +#' +#' # save fitted model to input path +#' path <- "path/to/model" +#' write.ml(model, path) +#' +#' # can also read back the saved model and print +#' savedModel <- read.ml(path) +#' summary(savedModel) +#' } +#' @note spark.glm since 2.0.0 +#' @seealso \link{glm}, \link{read.ml} +setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), + function(data, formula, family = gaussian, tol = 1e-6, maxIter = 25, weightCol = NULL, + regParam = 0.0) { + if (is.character(family)) { + family <- get(family, mode = "function", envir = parent.frame()) + } + if (is.function(family)) { + family <- family() + } + if (is.null(family$family)) { + print(family) + stop("'family' not recognized") + } + + formula <- paste(deparse(formula), collapse = "") + if (is.null(weightCol)) { + weightCol <- "" + } + + jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper", + "fit", formula, data@sdf, family$family, family$link, + tol, as.integer(maxIter), as.character(weightCol), regParam) + new("GeneralizedLinearRegressionModel", jobj = jobj) + }) + +#' Generalized Linear Models (R-compliant) +#' +#' Fits a generalized linear model, similarly to R's glm(). +#' @param formula a symbolic description of the model to be fitted. Currently only a few formula +#' operators are supported, including '~', '.', ':', '+', and '-'. +#' @param data a SparkDataFrame or R's glm data for training. +#' @param family a description of the error distribution and link function to be used in the model. +#' This can be a character string naming a family function, a family function or +#' the result of a call to a family function. Refer R family at +#' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}. +#' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance +#' weights as 1.0. +#' @param epsilon positive convergence tolerance of iterations. +#' @param maxit integer giving the maximal number of IRLS iterations. +#' @return \code{glm} returns a fitted generalized linear model. +#' @rdname glm +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' data(iris) +#' df <- createDataFrame(iris) +#' model <- glm(Sepal_Length ~ Sepal_Width, df, family = "gaussian") +#' summary(model) +#' } +#' @note glm since 1.5.0 +#' @seealso \link{spark.glm} +setMethod("glm", signature(formula = "formula", family = "ANY", data = "SparkDataFrame"), + function(formula, family = gaussian, data, epsilon = 1e-6, maxit = 25, weightCol = NULL) { + spark.glm(data, formula, family, tol = epsilon, maxIter = maxit, weightCol = weightCol) + }) + +# Returns the summary of a model produced by glm() or spark.glm(), similarly to R's summary(). + +#' @param object a fitted generalized linear model. 
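Because the family argument is resolved exactly as in the fitting code above (a character string is looked up, a family function is called, and a family object is used as-is), the three calls below are equivalent; df, y and x are hypothetical placeholders for a SparkDataFrame with a count-valued response.

m1 <- spark.glm(df, y ~ x, family = "poisson")
m2 <- spark.glm(df, y ~ x, family = poisson)
m3 <- spark.glm(df, y ~ x, family = poisson(link = "log"))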
+#' @return \code{summary} returns summary information of the fitted model, which is a list.
+#'         The list of components includes at least the \code{coefficients} (coefficients matrix,
+#'         which includes coefficients, standard error of coefficients, t value and p value),
+#'         \code{null.deviance} and \code{deviance} (null and residual deviance),
+#'         \code{df.null} and \code{df.residual} (null and residual degrees of freedom),
+#'         \code{aic} (AIC) and \code{iter} (number of IRLS iterations).
+#'         If there are collinear columns in the data, the coefficients matrix only provides
+#'         coefficients.
+#' @rdname spark.glm
+#' @export
+#' @note summary(GeneralizedLinearRegressionModel) since 2.0.0
+setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"),
+          function(object) {
+            jobj <- object@jobj
+            is.loaded <- callJMethod(jobj, "isLoaded")
+            features <- callJMethod(jobj, "rFeatures")
+            coefficients <- callJMethod(jobj, "rCoefficients")
+            dispersion <- callJMethod(jobj, "rDispersion")
+            null.deviance <- callJMethod(jobj, "rNullDeviance")
+            deviance <- callJMethod(jobj, "rDeviance")
+            df.null <- callJMethod(jobj, "rResidualDegreeOfFreedomNull")
+            df.residual <- callJMethod(jobj, "rResidualDegreeOfFreedom")
+            aic <- callJMethod(jobj, "rAic")
+            iter <- callJMethod(jobj, "rNumIterations")
+            family <- callJMethod(jobj, "rFamily")
+            deviance.resid <- if (is.loaded) {
+              NULL
+            } else {
+              dataFrame(callJMethod(jobj, "rDevianceResiduals"))
+            }
+            # If the underlying WeightedLeastSquares uses the "normal" solver, we can provide
+            # coefficients, standard errors of coefficients, t values and p values. Otherwise the
+            # model is fitted by the local "l-bfgs" solver and only coefficients are available.
+            if (length(features) == length(coefficients)) {
+              coefficients <- matrix(coefficients, ncol = 1)
+              colnames(coefficients) <- c("Estimate")
+              rownames(coefficients) <- unlist(features)
+            } else {
+              coefficients <- matrix(coefficients, ncol = 4)
+              colnames(coefficients) <- c("Estimate", "Std. Error", "t value", "Pr(>|t|)")
+              rownames(coefficients) <- unlist(features)
+            }
+            ans <- list(deviance.resid = deviance.resid, coefficients = coefficients,
+                        dispersion = dispersion, null.deviance = null.deviance,
+                        deviance = deviance, df.null = df.null, df.residual = df.residual,
+                        aic = aic, iter = iter, family = family, is.loaded = is.loaded)
+            class(ans) <- "summary.GeneralizedLinearRegressionModel"
+            ans
+          })
+
+# Prints the summary of GeneralizedLinearRegressionModel
+
+#' @rdname spark.glm
+#' @param x summary object of fitted generalized linear model returned by \code{summary} function.
+#' @export
+#' @note print.summary.GeneralizedLinearRegressionModel since 2.0.0
+print.summary.GeneralizedLinearRegressionModel <- function(x, ...)
{ + if (x$is.loaded) { + cat("\nSaved-loaded model does not support output 'Deviance Residuals'.\n") + } else { + x$deviance.resid <- setNames(unlist(approxQuantile(x$deviance.resid, "devianceResiduals", + c(0.0, 0.25, 0.5, 0.75, 1.0), 0.01)), c("Min", "1Q", "Median", "3Q", "Max")) + x$deviance.resid <- zapsmall(x$deviance.resid, 5L) + cat("\nDeviance Residuals: \n") + cat("(Note: These are approximate quantiles with relative error <= 0.01)\n") + print.default(x$deviance.resid, digits = 5L, na.print = "", print.gap = 2L) + } + + cat("\nCoefficients:\n") + print.default(x$coefficients, digits = 5L, na.print = "", print.gap = 2L) + + cat("\n(Dispersion parameter for ", x$family, " family taken to be ", format(x$dispersion), + ")\n\n", apply(cbind(paste(format(c("Null", "Residual"), justify = "right"), "deviance:"), + format(unlist(x[c("null.deviance", "deviance")]), digits = 5L), + " on", format(unlist(x[c("df.null", "df.residual")])), " degrees of freedom\n"), + 1L, paste, collapse = " "), sep = "") + cat("AIC: ", format(x$aic, digits = 4L), "\n\n", + "Number of Fisher Scoring iterations: ", x$iter, "\n\n", sep = "") + invisible(x) + } + +# Makes predictions from a generalized linear model produced by glm() or spark.glm(), +# similarly to R's predict(). + +#' @param newData a SparkDataFrame for testing. +#' @return \code{predict} returns a SparkDataFrame containing predicted labels in a column named +#' "prediction". +#' @rdname spark.glm +#' @export +#' @note predict(GeneralizedLinearRegressionModel) since 1.5.0 +setMethod("predict", signature(object = "GeneralizedLinearRegressionModel"), + function(object, newData) { + predict_internal(object, newData) + }) + +# Saves the generalized linear model to the input path. + +#' @param path the directory where the model is saved. +#' @param overwrite overwrites or not if the output path already exists. Default is FALSE +#' which means throw exception if the output path exists. +#' +#' @rdname spark.glm +#' @export +#' @note write.ml(GeneralizedLinearRegressionModel, character) since 2.0.0 +setMethod("write.ml", signature(object = "GeneralizedLinearRegressionModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) + +#' Isotonic Regression Model +#' +#' Fits an Isotonic Regression model against a Spark DataFrame, similarly to R's isoreg(). +#' Users can print, make predictions on the produced model and save the model to the input path. +#' +#' @param data SparkDataFrame for training. +#' @param formula A symbolic description of the model to be fitted. Currently only a few formula +#' operators are supported, including '~', '.', ':', '+', and '-'. +#' @param isotonic Whether the output sequence should be isotonic/increasing (TRUE) or +#' antitonic/decreasing (FALSE). +#' @param featureIndex The index of the feature if \code{featuresCol} is a vector column +#' (default: 0), no effect otherwise. +#' @param weightCol The weight column name. +#' @param ... additional arguments passed to the method. +#' @return \code{spark.isoreg} returns a fitted Isotonic Regression model. 
+#' @rdname spark.isoreg +#' @aliases spark.isoreg,SparkDataFrame,formula-method +#' @name spark.isoreg +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' data <- list(list(7.0, 0.0), list(5.0, 1.0), list(3.0, 2.0), +#' list(5.0, 3.0), list(1.0, 4.0)) +#' df <- createDataFrame(data, c("label", "feature")) +#' model <- spark.isoreg(df, label ~ feature, isotonic = FALSE) +#' # return model boundaries and prediction as lists +#' result <- summary(model, df) +#' # prediction based on fitted model +#' predict_data <- list(list(-2.0), list(-1.0), list(0.5), +#' list(0.75), list(1.0), list(2.0), list(9.0)) +#' predict_df <- createDataFrame(predict_data, c("feature")) +#' # get prediction column +#' predict_result <- collect(select(predict(model, predict_df), "prediction")) +#' +#' # save fitted model to input path +#' path <- "path/to/model" +#' write.ml(model, path) +#' +#' # can also read back the saved model and print +#' savedModel <- read.ml(path) +#' summary(savedModel) +#' } +#' @note spark.isoreg since 2.1.0 +setMethod("spark.isoreg", signature(data = "SparkDataFrame", formula = "formula"), + function(data, formula, isotonic = TRUE, featureIndex = 0, weightCol = NULL) { + formula <- paste(deparse(formula), collapse = "") + + if (is.null(weightCol)) { + weightCol <- "" + } + + jobj <- callJStatic("org.apache.spark.ml.r.IsotonicRegressionWrapper", "fit", + data@sdf, formula, as.logical(isotonic), as.integer(featureIndex), + as.character(weightCol)) + new("IsotonicRegressionModel", jobj = jobj) + }) + +# Get the summary of an IsotonicRegressionModel model + +#' @return \code{summary} returns summary information of the fitted model, which is a list. +#' The list includes model's \code{boundaries} (boundaries in increasing order) +#' and \code{predictions} (predictions associated with the boundaries at the same index). +#' @rdname spark.isoreg +#' @aliases summary,IsotonicRegressionModel-method +#' @export +#' @note summary(IsotonicRegressionModel) since 2.1.0 +setMethod("summary", signature(object = "IsotonicRegressionModel"), + function(object) { + jobj <- object@jobj + boundaries <- callJMethod(jobj, "boundaries") + predictions <- callJMethod(jobj, "predictions") + list(boundaries = boundaries, predictions = predictions) + }) + +# Predicted values based on an isotonicRegression model + +#' @param object a fitted IsotonicRegressionModel. +#' @param newData SparkDataFrame for testing. +#' @return \code{predict} returns a SparkDataFrame containing predicted values. +#' @rdname spark.isoreg +#' @aliases predict,IsotonicRegressionModel,SparkDataFrame-method +#' @export +#' @note predict(IsotonicRegressionModel) since 2.1.0 +setMethod("predict", signature(object = "IsotonicRegressionModel"), + function(object, newData) { + predict_internal(object, newData) + }) + +# Save fitted IsotonicRegressionModel to the input path + +#' @param path The directory where the model is saved. +#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE +#' which means throw exception if the output path exists. 
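A small sketch, not part of this patch, of reading the fitted step function back out of the summary method above; model is reused from the spark.isoreg example.

result <- summary(model)
result$boundaries     # boundaries of the step function, in increasing order
result$predictions    # fitted value associated with each boundary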
+#' +#' @rdname spark.isoreg +#' @aliases write.ml,IsotonicRegressionModel,character-method +#' @export +#' @note write.ml(IsotonicRegression, character) since 2.1.0 +setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) + +#' Accelerated Failure Time (AFT) Survival Regression Model +#' +#' \code{spark.survreg} fits an accelerated failure time (AFT) survival regression model on +#' a SparkDataFrame. Users can call \code{summary} to get a summary of the fitted AFT model, +#' \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} to +#' save/load fitted models. +#' +#' @param data a SparkDataFrame for training. +#' @param formula a symbolic description of the model to be fitted. Currently only a few formula +#' operators are supported, including '~', ':', '+', and '-'. +#' Note that operator '.' is not supported currently. +#' @return \code{spark.survreg} returns a fitted AFT survival regression model. +#' @rdname spark.survreg +#' @seealso survival: \url{https://cran.r-project.org/package=survival} +#' @export +#' @examples +#' \dontrun{ +#' df <- createDataFrame(ovarian) +#' model <- spark.survreg(df, Surv(futime, fustat) ~ ecog_ps + rx) +#' +#' # get a summary of the model +#' summary(model) +#' +#' # make predictions +#' predicted <- predict(model, df) +#' showDF(predicted) +#' +#' # save and load the model +#' path <- "path/to/model" +#' write.ml(model, path) +#' savedModel <- read.ml(path) +#' summary(savedModel) +#' } +#' @note spark.survreg since 2.0.0 +setMethod("spark.survreg", signature(data = "SparkDataFrame", formula = "formula"), + function(data, formula) { + formula <- paste(deparse(formula), collapse = "") + jobj <- callJStatic("org.apache.spark.ml.r.AFTSurvivalRegressionWrapper", + "fit", formula, data@sdf) + new("AFTSurvivalRegressionModel", jobj = jobj) + }) + +# Returns a summary of the AFT survival regression model produced by spark.survreg, +# similarly to R's summary(). + +#' @param object a fitted AFT survival regression model. +#' @return \code{summary} returns summary information of the fitted model, which is a list. +#' The list includes the model's \code{coefficients} (features, coefficients, +#' intercept and log(scale)). +#' @rdname spark.survreg +#' @export +#' @note summary(AFTSurvivalRegressionModel) since 2.0.0 +setMethod("summary", signature(object = "AFTSurvivalRegressionModel"), + function(object) { + jobj <- object@jobj + features <- callJMethod(jobj, "rFeatures") + coefficients <- callJMethod(jobj, "rCoefficients") + coefficients <- as.matrix(unlist(coefficients)) + colnames(coefficients) <- c("Value") + rownames(coefficients) <- unlist(features) + list(coefficients = coefficients) + }) + +# Makes predictions from an AFT survival regression model or a model produced by +# spark.survreg, similarly to R package survival's predict. + +#' @param newData a SparkDataFrame for testing. +#' @return \code{predict} returns a SparkDataFrame containing predicted values +#' on the original scale of the data (mean predicted value at scale = 1.0). +#' @rdname spark.survreg +#' @export +#' @note predict(AFTSurvivalRegressionModel) since 2.0.0 +setMethod("predict", signature(object = "AFTSurvivalRegressionModel"), + function(object, newData) { + predict_internal(object, newData) + }) + +# Saves the AFT survival regression model to the input path. + +#' @param path the directory where the model is saved. 
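A short sketch, not part of this patch, of what the AFT summary and predict methods above return; df and model are reused from the spark.survreg example.

stats <- summary(model)
stats$coefficients    # single "Value" column; rows cover the intercept, features and Log(scale)
head(select(predict(model, df), "prediction"))   # predictions on the original scale of the data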
+#' @param overwrite overwrites or not if the output path already exists. Default is FALSE
+#'                  which means throw exception if the output path exists.
+#' @rdname spark.survreg
+#' @export
+#' @note write.ml(AFTSurvivalRegressionModel, character) since 2.0.0
+#' @seealso \link{write.ml}
+setMethod("write.ml", signature(object = "AFTSurvivalRegressionModel", path = "character"),
+          function(object, path, overwrite = FALSE) {
+            write_internal(object, path, overwrite)
+          })
diff --git a/R/pkg/R/mllib_stat.R b/R/pkg/R/mllib_stat.R
new file mode 100644
index 0000000000000..3e013f1d45e38
--- /dev/null
+++ b/R/pkg/R/mllib_stat.R
@@ -0,0 +1,127 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# mllib_stat.R: Provides methods for MLlib statistics algorithms integration
+
+#' S4 class that represents a KSTest
+#'
+#' @param jobj a Java object reference to the backing Scala KSTestWrapper
+#' @export
+#' @note KSTest since 2.1.0
+setClass("KSTest", representation(jobj = "jobj"))
+
+#' (One-Sample) Kolmogorov-Smirnov Test
+#'
+#' @description
+#' \code{spark.kstest} conducts the two-sided Kolmogorov-Smirnov (KS) test for data sampled from a
+#' continuous distribution.
+#'
+#' By comparing the largest difference between the empirical cumulative
+#' distribution of the sample data and the theoretical distribution, we can provide a test for
+#' the null hypothesis that the sample data comes from that theoretical distribution.
+#'
+#' Users can call \code{summary} to obtain a summary of the test, and \code{print.summary.KSTest}
+#' to print out a summary result.
+#'
+#' @param data a SparkDataFrame of user data.
+#' @param testCol column name where the test data is from. It should be a column of double type.
+#' @param nullHypothesis name of the theoretical distribution tested against. Currently only
+#'                       \code{"norm"} for normal distribution is supported.
+#' @param distParams parameter(s) of the distribution. For \code{nullHypothesis = "norm"},
+#'                   we can provide as a vector the mean and standard deviation of
+#'                   the distribution. If none is provided, then standard normal will be used.
+#'                   If only one is provided, then the standard deviation will be set to be one.
+#' @param ... additional argument(s) passed to the method.
+#' @return \code{spark.kstest} returns a test result object.
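An illustration, not part of this patch, of the distParams defaults described above; df and its numeric column "test" are as in the example that follows. All calls but the last test against a standard normal.

t1 <- spark.kstest(df, "test", "norm")           # defaults to mean 0, sd 1
t2 <- spark.kstest(df, "test", "norm", c(0))     # sd filled in as 1
t3 <- spark.kstest(df, "test", "norm", c(0, 1))  # both given explicitly
t4 <- spark.kstest(df, "test", "norm", c(1, 2))  # mean 1, sd 2
summary(t1)$p.value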
+#' @rdname spark.kstest +#' @aliases spark.kstest,SparkDataFrame-method +#' @name spark.kstest +#' @seealso \href{http://spark.apache.org/docs/latest/mllib-statistics.html#hypothesis-testing}{ +#' MLlib: Hypothesis Testing} +#' @export +#' @examples +#' \dontrun{ +#' data <- data.frame(test = c(0.1, 0.15, 0.2, 0.3, 0.25)) +#' df <- createDataFrame(data) +#' test <- spark.kstest(df, "test", "norm", c(0, 1)) +#' +#' # get a summary of the test result +#' testSummary <- summary(test) +#' testSummary +#' +#' # print out the summary in an organized way +#' print.summary.KSTest(testSummary) +#' } +#' @note spark.kstest since 2.1.0 +setMethod("spark.kstest", signature(data = "SparkDataFrame"), + function(data, testCol = "test", nullHypothesis = c("norm"), distParams = c(0, 1)) { + tryCatch(match.arg(nullHypothesis), + error = function(e) { + msg <- paste("Distribution", nullHypothesis, "is not supported.") + stop(msg) + }) + if (nullHypothesis == "norm") { + distParams <- as.numeric(distParams) + mu <- ifelse(length(distParams) < 1, 0, distParams[1]) + sigma <- ifelse(length(distParams) < 2, 1, distParams[2]) + jobj <- callJStatic("org.apache.spark.ml.r.KSTestWrapper", + "test", data@sdf, testCol, nullHypothesis, + as.array(c(mu, sigma))) + new("KSTest", jobj = jobj) + } +}) + +# Get the summary of Kolmogorov-Smirnov (KS) Test. + +#' @param object test result object of KSTest by \code{spark.kstest}. +#' @return \code{summary} returns summary information of KSTest object, which is a list. +#' The list includes the \code{p.value} (p-value), \code{statistic} (test statistic +#' computed for the test), \code{nullHypothesis} (the null hypothesis with its +#' parameters tested against) and \code{degreesOfFreedom} (degrees of freedom of the test). +#' @rdname spark.kstest +#' @aliases summary,KSTest-method +#' @export +#' @note summary(KSTest) since 2.1.0 +setMethod("summary", signature(object = "KSTest"), + function(object) { + jobj <- object@jobj + pValue <- callJMethod(jobj, "pValue") + statistic <- callJMethod(jobj, "statistic") + nullHypothesis <- callJMethod(jobj, "nullHypothesis") + distName <- callJMethod(jobj, "distName") + distParams <- unlist(callJMethod(jobj, "distParams")) + degreesOfFreedom <- callJMethod(jobj, "degreesOfFreedom") + + ans <- list(p.value = pValue, statistic = statistic, nullHypothesis = nullHypothesis, + nullHypothesis.name = distName, nullHypothesis.parameters = distParams, + degreesOfFreedom = degreesOfFreedom, jobj = jobj) + class(ans) <- "summary.KSTest" + ans + }) + +# Prints the summary of KSTest + +#' @rdname spark.kstest +#' @param x summary object of KSTest returned by \code{summary}. +#' @export +#' @note print.summary.KSTest since 2.1.0 +print.summary.KSTest <- function(x, ...) { + jobj <- x$jobj + summaryStr <- callJMethod(jobj, "summary") + cat(summaryStr, "\n") + invisible(x) +} diff --git a/R/pkg/R/mllib_tree.R b/R/pkg/R/mllib_tree.R new file mode 100644 index 0000000000000..0d53fad061809 --- /dev/null +++ b/R/pkg/R/mllib_tree.R @@ -0,0 +1,496 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# mllib_tree.R: Provides methods for MLlib tree-based algorithms integration + +#' S4 class that represents a GBTRegressionModel +#' +#' @param jobj a Java object reference to the backing Scala GBTRegressionModel +#' @export +#' @note GBTRegressionModel since 2.1.0 +setClass("GBTRegressionModel", representation(jobj = "jobj")) + +#' S4 class that represents a GBTClassificationModel +#' +#' @param jobj a Java object reference to the backing Scala GBTClassificationModel +#' @export +#' @note GBTClassificationModel since 2.1.0 +setClass("GBTClassificationModel", representation(jobj = "jobj")) + +#' S4 class that represents a RandomForestRegressionModel +#' +#' @param jobj a Java object reference to the backing Scala RandomForestRegressionModel +#' @export +#' @note RandomForestRegressionModel since 2.1.0 +setClass("RandomForestRegressionModel", representation(jobj = "jobj")) + +#' S4 class that represents a RandomForestClassificationModel +#' +#' @param jobj a Java object reference to the backing Scala RandomForestClassificationModel +#' @export +#' @note RandomForestClassificationModel since 2.1.0 +setClass("RandomForestClassificationModel", representation(jobj = "jobj")) + +# Create the summary of a tree ensemble model (eg. Random Forest, GBT) +summary.treeEnsemble <- function(model) { + jobj <- model@jobj + formula <- callJMethod(jobj, "formula") + numFeatures <- callJMethod(jobj, "numFeatures") + features <- callJMethod(jobj, "features") + featureImportances <- callJMethod(callJMethod(jobj, "featureImportances"), "toString") + numTrees <- callJMethod(jobj, "numTrees") + treeWeights <- callJMethod(jobj, "treeWeights") + list(formula = formula, + numFeatures = numFeatures, + features = features, + featureImportances = featureImportances, + numTrees = numTrees, + treeWeights = treeWeights, + jobj = jobj) +} + +# Prints the summary of tree ensemble models (eg. Random Forest, GBT) +print.summary.treeEnsemble <- function(x) { + jobj <- x$jobj + cat("Formula: ", x$formula) + cat("\nNumber of features: ", x$numFeatures) + cat("\nFeatures: ", unlist(x$features)) + cat("\nFeature importances: ", x$featureImportances) + cat("\nNumber of trees: ", x$numTrees) + cat("\nTree weights: ", unlist(x$treeWeights)) + + summaryStr <- callJMethod(jobj, "summary") + cat("\n", summaryStr, "\n") + invisible(x) +} + +#' Gradient Boosted Tree Model for Regression and Classification +#' +#' \code{spark.gbt} fits a Gradient Boosted Tree Regression model or Classification model on a +#' SparkDataFrame. Users can call \code{summary} to get a summary of the fitted +#' Gradient Boosted Tree model, \code{predict} to make predictions on new data, and +#' \code{write.ml}/\code{read.ml} to save/load fitted models. +#' For more details, see +#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#gradient-boosted-tree-regression}{ +#' GBT Regression} and +#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#gradient-boosted-tree-classifier}{ +#' GBT Classification} +#' +#' @param data a SparkDataFrame for training. 
+#' @param formula a symbolic description of the model to be fitted. Currently only a few formula +#' operators are supported, including '~', ':', '+', and '-'. +#' @param type type of model, one of "regression" or "classification", to fit +#' @param maxDepth Maximum depth of the tree (>= 0). +#' @param maxBins Maximum number of bins used for discretizing continuous features and for choosing +#' how to split on features at each node. More bins give higher granularity. Must be +#' >= 2 and >= number of categories in any categorical feature. +#' @param maxIter Param for maximum number of iterations (>= 0). +#' @param stepSize Param for Step size to be used for each iteration of optimization. +#' @param lossType Loss function which GBT tries to minimize. +#' For classification, must be "logistic". For regression, must be one of +#' "squared" (L2) and "absolute" (L1), default is "squared". +#' @param seed integer seed for random number generation. +#' @param subsamplingRate Fraction of the training data used for learning each decision tree, in +#' range (0, 1]. +#' @param minInstancesPerNode Minimum number of instances each child must have after split. If a +#' split causes the left or right child to have fewer than +#' minInstancesPerNode, the split will be discarded as invalid. Should be +#' >= 1. +#' @param minInfoGain Minimum information gain for a split to be considered at a tree node. +#' @param checkpointInterval Param for set checkpoint interval (>= 1) or disable checkpoint (-1). +#' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. +#' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with +#' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching +#' can speed up training of deeper trees. Users can set how often should the +#' cache be checkpointed or disable it by setting checkpointInterval. +#' @param ... additional arguments passed to the method. +#' @aliases spark.gbt,SparkDataFrame,formula-method +#' @return \code{spark.gbt} returns a fitted Gradient Boosted Tree model. +#' @rdname spark.gbt +#' @name spark.gbt +#' @export +#' @examples +#' \dontrun{ +#' # fit a Gradient Boosted Tree Regression Model +#' df <- createDataFrame(longley) +#' model <- spark.gbt(df, Employed ~ ., type = "regression", maxDepth = 5, maxBins = 16) +#' +#' # get the summary of the model +#' summary(model) +#' +#' # make predictions +#' predictions <- predict(model, df) +#' +#' # save and load the model +#' path <- "path/to/model" +#' write.ml(model, path) +#' savedModel <- read.ml(path) +#' summary(savedModel) +#' +#' # fit a Gradient Boosted Tree Classification Model +#' # label must be binary - Only binary classification is supported for GBT. 
+#' df <- createDataFrame(iris[iris$Species != "virginica", ]) +#' model <- spark.gbt(df, Species ~ Petal_Length + Petal_Width, "classification") +#' +#' # numeric label is also supported +#' iris2 <- iris[iris$Species != "virginica", ] +#' iris2$NumericSpecies <- ifelse(iris2$Species == "setosa", 0, 1) +#' df <- createDataFrame(iris2) +#' model <- spark.gbt(df, NumericSpecies ~ ., type = "classification") +#' } +#' @note spark.gbt since 2.1.0 +setMethod("spark.gbt", signature(data = "SparkDataFrame", formula = "formula"), + function(data, formula, type = c("regression", "classification"), + maxDepth = 5, maxBins = 32, maxIter = 20, stepSize = 0.1, lossType = NULL, + seed = NULL, subsamplingRate = 1.0, minInstancesPerNode = 1, minInfoGain = 0.0, + checkpointInterval = 10, maxMemoryInMB = 256, cacheNodeIds = FALSE) { + type <- match.arg(type) + formula <- paste(deparse(formula), collapse = "") + if (!is.null(seed)) { + seed <- as.character(as.integer(seed)) + } + switch(type, + regression = { + if (is.null(lossType)) lossType <- "squared" + lossType <- match.arg(lossType, c("squared", "absolute")) + jobj <- callJStatic("org.apache.spark.ml.r.GBTRegressorWrapper", + "fit", data@sdf, formula, as.integer(maxDepth), + as.integer(maxBins), as.integer(maxIter), + as.numeric(stepSize), as.integer(minInstancesPerNode), + as.numeric(minInfoGain), as.integer(checkpointInterval), + lossType, seed, as.numeric(subsamplingRate), + as.integer(maxMemoryInMB), as.logical(cacheNodeIds)) + new("GBTRegressionModel", jobj = jobj) + }, + classification = { + if (is.null(lossType)) lossType <- "logistic" + lossType <- match.arg(lossType, "logistic") + jobj <- callJStatic("org.apache.spark.ml.r.GBTClassifierWrapper", + "fit", data@sdf, formula, as.integer(maxDepth), + as.integer(maxBins), as.integer(maxIter), + as.numeric(stepSize), as.integer(minInstancesPerNode), + as.numeric(minInfoGain), as.integer(checkpointInterval), + lossType, seed, as.numeric(subsamplingRate), + as.integer(maxMemoryInMB), as.logical(cacheNodeIds)) + new("GBTClassificationModel", jobj = jobj) + } + ) + }) + +# Get the summary of a Gradient Boosted Tree Regression Model + +#' @return \code{summary} returns summary information of the fitted model, which is a list. +#' The list of components includes \code{formula} (formula), +#' \code{numFeatures} (number of features), \code{features} (list of features), +#' \code{featureImportances} (feature importances), \code{numTrees} (number of trees), +#' and \code{treeWeights} (tree weights). +#' @rdname spark.gbt +#' @aliases summary,GBTRegressionModel-method +#' @export +#' @note summary(GBTRegressionModel) since 2.1.0 +setMethod("summary", signature(object = "GBTRegressionModel"), + function(object) { + ans <- summary.treeEnsemble(object) + class(ans) <- "summary.GBTRegressionModel" + ans + }) + +# Prints the summary of Gradient Boosted Tree Regression Model + +#' @param x summary object of Gradient Boosted Tree regression model or classification model +#' returned by \code{summary}. +#' @rdname spark.gbt +#' @export +#' @note print.summary.GBTRegressionModel since 2.1.0 +print.summary.GBTRegressionModel <- function(x, ...) 
{ + print.summary.treeEnsemble(x) +} + +# Get the summary of a Gradient Boosted Tree Classification Model + +#' @rdname spark.gbt +#' @aliases summary,GBTClassificationModel-method +#' @export +#' @note summary(GBTClassificationModel) since 2.1.0 +setMethod("summary", signature(object = "GBTClassificationModel"), + function(object) { + ans <- summary.treeEnsemble(object) + class(ans) <- "summary.GBTClassificationModel" + ans + }) + +# Prints the summary of Gradient Boosted Tree Classification Model + +#' @rdname spark.gbt +#' @export +#' @note print.summary.GBTClassificationModel since 2.1.0 +print.summary.GBTClassificationModel <- function(x, ...) { + print.summary.treeEnsemble(x) +} + +# Makes predictions from a Gradient Boosted Tree Regression model or Classification model + +#' @param newData a SparkDataFrame for testing. +#' @return \code{predict} returns a SparkDataFrame containing predicted labeled in a column named +#' "prediction". +#' @rdname spark.gbt +#' @aliases predict,GBTRegressionModel-method +#' @export +#' @note predict(GBTRegressionModel) since 2.1.0 +setMethod("predict", signature(object = "GBTRegressionModel"), + function(object, newData) { + predict_internal(object, newData) + }) + +#' @rdname spark.gbt +#' @aliases predict,GBTClassificationModel-method +#' @export +#' @note predict(GBTClassificationModel) since 2.1.0 +setMethod("predict", signature(object = "GBTClassificationModel"), + function(object, newData) { + predict_internal(object, newData) + }) + +# Save the Gradient Boosted Tree Regression or Classification model to the input path. + +#' @param object A fitted Gradient Boosted Tree regression model or classification model. +#' @param path The directory where the model is saved. +#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE +#' which means throw exception if the output path exists. +#' @aliases write.ml,GBTRegressionModel,character-method +#' @rdname spark.gbt +#' @export +#' @note write.ml(GBTRegressionModel, character) since 2.1.0 +setMethod("write.ml", signature(object = "GBTRegressionModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) + +#' @aliases write.ml,GBTClassificationModel,character-method +#' @rdname spark.gbt +#' @export +#' @note write.ml(GBTClassificationModel, character) since 2.1.0 +setMethod("write.ml", signature(object = "GBTClassificationModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) + +#' Random Forest Model for Regression and Classification +#' +#' \code{spark.randomForest} fits a Random Forest Regression model or Classification model on +#' a SparkDataFrame. Users can call \code{summary} to get a summary of the fitted Random Forest +#' model, \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} to +#' save/load fitted models. +#' For more details, see +#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#random-forest-regression}{ +#' Random Forest Regression} and +#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#random-forest-classifier}{ +#' Random Forest Classification} +#' +#' @param data a SparkDataFrame for training. +#' @param formula a symbolic description of the model to be fitted. Currently only a few formula +#' operators are supported, including '~', ':', '+', and '-'. 
+#' Random Forest Model for Regression and Classification
+#'
+#' \code{spark.randomForest} fits a Random Forest Regression model or Classification model on
+#' a SparkDataFrame. Users can call \code{summary} to get a summary of the fitted Random Forest
+#' model, \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} to
+#' save/load fitted models.
+#' For more details, see
+#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#random-forest-regression}{
+#' Random Forest Regression} and
+#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#random-forest-classifier}{
+#' Random Forest Classification}
+#'
+#' @param data a SparkDataFrame for training.
+#' @param formula a symbolic description of the model to be fitted. Currently only a few formula
+#'                operators are supported, including '~', ':', '+', and '-'.
+#' @param type type of model to fit, one of "regression" or "classification".
+#' @param maxDepth Maximum depth of the tree (>= 0).
+#' @param maxBins Maximum number of bins used for discretizing continuous features and for choosing
+#'                how to split on features at each node. More bins give higher granularity. Must be
+#'                >= 2 and >= number of categories in any categorical feature.
+#' @param numTrees Number of trees to train (>= 1).
+#' @param impurity Criterion used for information gain calculation.
+#'                 For regression, must be "variance". For classification, must be one of
+#'                 "entropy" or "gini", default is "gini".
+#' @param featureSubsetStrategy The number of features to consider for splits at each tree node.
+#'        Supported options: "auto", "all", "onethird", "sqrt", "log2", (0.0-1.0], [1-n].
+#' @param seed integer seed for random number generation.
+#' @param subsamplingRate Fraction of the training data used for learning each decision tree, in
+#'                        range (0, 1].
+#' @param minInstancesPerNode Minimum number of instances each child must have after split.
+#' @param minInfoGain Minimum information gain for a split to be considered at a tree node.
+#' @param checkpointInterval Param for setting checkpoint interval (>= 1) or disabling checkpoint (-1).
+#' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation.
+#' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with
+#'                     nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching
+#'                     can speed up training of deeper trees. Users can set how often the cache
+#'                     should be checkpointed or disable it by setting checkpointInterval.
+#' @param ... additional arguments passed to the method.
+#' @aliases spark.randomForest,SparkDataFrame,formula-method
+#' @return \code{spark.randomForest} returns a fitted Random Forest model.
+#' @rdname spark.randomForest +#' @name spark.randomForest +#' @export +#' @examples +#' \dontrun{ +#' # fit a Random Forest Regression Model +#' df <- createDataFrame(longley) +#' model <- spark.randomForest(df, Employed ~ ., type = "regression", maxDepth = 5, maxBins = 16) +#' +#' # get the summary of the model +#' summary(model) +#' +#' # make predictions +#' predictions <- predict(model, df) +#' +#' # save and load the model +#' path <- "path/to/model" +#' write.ml(model, path) +#' savedModel <- read.ml(path) +#' summary(savedModel) +#' +#' # fit a Random Forest Classification Model +#' df <- createDataFrame(iris) +#' model <- spark.randomForest(df, Species ~ Petal_Length + Petal_Width, "classification") +#' } +#' @note spark.randomForest since 2.1.0 +setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "formula"), + function(data, formula, type = c("regression", "classification"), + maxDepth = 5, maxBins = 32, numTrees = 20, impurity = NULL, + featureSubsetStrategy = "auto", seed = NULL, subsamplingRate = 1.0, + minInstancesPerNode = 1, minInfoGain = 0.0, checkpointInterval = 10, + maxMemoryInMB = 256, cacheNodeIds = FALSE) { + type <- match.arg(type) + formula <- paste(deparse(formula), collapse = "") + if (!is.null(seed)) { + seed <- as.character(as.integer(seed)) + } + switch(type, + regression = { + if (is.null(impurity)) impurity <- "variance" + impurity <- match.arg(impurity, "variance") + jobj <- callJStatic("org.apache.spark.ml.r.RandomForestRegressorWrapper", + "fit", data@sdf, formula, as.integer(maxDepth), + as.integer(maxBins), as.integer(numTrees), + impurity, as.integer(minInstancesPerNode), + as.numeric(minInfoGain), as.integer(checkpointInterval), + as.character(featureSubsetStrategy), seed, + as.numeric(subsamplingRate), + as.integer(maxMemoryInMB), as.logical(cacheNodeIds)) + new("RandomForestRegressionModel", jobj = jobj) + }, + classification = { + if (is.null(impurity)) impurity <- "gini" + impurity <- match.arg(impurity, c("gini", "entropy")) + jobj <- callJStatic("org.apache.spark.ml.r.RandomForestClassifierWrapper", + "fit", data@sdf, formula, as.integer(maxDepth), + as.integer(maxBins), as.integer(numTrees), + impurity, as.integer(minInstancesPerNode), + as.numeric(minInfoGain), as.integer(checkpointInterval), + as.character(featureSubsetStrategy), seed, + as.numeric(subsamplingRate), + as.integer(maxMemoryInMB), as.logical(cacheNodeIds)) + new("RandomForestClassificationModel", jobj = jobj) + } + ) + }) + +# Get the summary of a Random Forest Regression Model + +#' @return \code{summary} returns summary information of the fitted model, which is a list. +#' The list of components includes \code{formula} (formula), +#' \code{numFeatures} (number of features), \code{features} (list of features), +#' \code{featureImportances} (feature importances), \code{numTrees} (number of trees), +#' and \code{treeWeights} (tree weights). +#' @rdname spark.randomForest +#' @aliases summary,RandomForestRegressionModel-method +#' @export +#' @note summary(RandomForestRegressionModel) since 2.1.0 +setMethod("summary", signature(object = "RandomForestRegressionModel"), + function(object) { + ans <- summary.treeEnsemble(object) + class(ans) <- "summary.RandomForestRegressionModel" + ans + }) + +# Prints the summary of Random Forest Regression Model + +#' @param x summary object of Random Forest regression model or classification model +#' returned by \code{summary}. 
+#' @rdname spark.randomForest
+#' @export
+#' @note print.summary.RandomForestRegressionModel since 2.1.0
+print.summary.RandomForestRegressionModel <- function(x, ...) {
+  print.summary.treeEnsemble(x)
+}
+
+# Get the summary of a Random Forest Classification Model
+
+#' @rdname spark.randomForest
+#' @aliases summary,RandomForestClassificationModel-method
+#' @export
+#' @note summary(RandomForestClassificationModel) since 2.1.0
+setMethod("summary", signature(object = "RandomForestClassificationModel"),
+          function(object) {
+            ans <- summary.treeEnsemble(object)
+            class(ans) <- "summary.RandomForestClassificationModel"
+            ans
+          })
+
+# Prints the summary of Random Forest Classification Model
+
+#' @rdname spark.randomForest
+#' @export
+#' @note print.summary.RandomForestClassificationModel since 2.1.0
+print.summary.RandomForestClassificationModel <- function(x, ...) {
+  print.summary.treeEnsemble(x)
+}
+
+# Makes predictions from a Random Forest Regression model or Classification model
+
+#' @param newData a SparkDataFrame for testing.
+#' @return \code{predict} returns a SparkDataFrame containing predicted labels in a column named
+#'         "prediction".
+#' @rdname spark.randomForest
+#' @aliases predict,RandomForestRegressionModel-method
+#' @export
+#' @note predict(RandomForestRegressionModel) since 2.1.0
+setMethod("predict", signature(object = "RandomForestRegressionModel"),
+          function(object, newData) {
+            predict_internal(object, newData)
+          })
+
+#' @rdname spark.randomForest
+#' @aliases predict,RandomForestClassificationModel-method
+#' @export
+#' @note predict(RandomForestClassificationModel) since 2.1.0
+setMethod("predict", signature(object = "RandomForestClassificationModel"),
+          function(object, newData) {
+            predict_internal(object, newData)
+          })
+
+# Save the Random Forest Regression or Classification model to the input path.
+
+#' @param object A fitted Random Forest regression model or classification model.
+#' @param path The directory where the model is saved.
+#' @param overwrite Whether to overwrite if the output path already exists. Default is FALSE,
+#'                  which means an exception is thrown if the output path exists.
+#'
+#' @aliases write.ml,RandomForestRegressionModel,character-method
+#' @rdname spark.randomForest
+#' @export
+#' @note write.ml(RandomForestRegressionModel, character) since 2.1.0
+setMethod("write.ml", signature(object = "RandomForestRegressionModel", path = "character"),
+          function(object, path, overwrite = FALSE) {
+            write_internal(object, path, overwrite)
+          })
+
+#' @aliases write.ml,RandomForestClassificationModel,character-method
+#' @rdname spark.randomForest
+#' @export
+#' @note write.ml(RandomForestClassificationModel, character) since 2.1.0
+setMethod("write.ml", signature(object = "RandomForestClassificationModel", path = "character"),
+          function(object, path, overwrite = FALSE) {
+            write_internal(object, path, overwrite)
+          })
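Another illustrative sketch for review, not part of the patch: the persistence round trip through write.ml and read.ml, which the helpers in mllib_utils.R below implement via write_internal and the RWrappers class dispatch. Parameter values and the temp path are arbitrary; an active SparkR session is assumed.

# Fit a Random Forest regression model on a toy dataset.
df <- suppressWarnings(createDataFrame(longley))
rfModel <- spark.randomForest(df, Employed ~ ., type = "regression",
                              numTrees = 20, maxDepth = 5, seed = 123)

# write.ml() delegates to write_internal(), i.e. the wrapper's underlying Scala writer.
modelPath <- tempfile(pattern = "spark-rf", fileext = ".tmp")
write.ml(rfModel, modelPath, overwrite = TRUE)

# read.ml() inspects the saved wrapper class and returns the matching S4 model,
# here a RandomForestRegressionModel, ready for predict()/summary().
restored <- read.ml(modelPath)
head(select(predict(restored, df), "prediction"))

unlink(modelPath, recursive = TRUE)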
diff --git a/R/pkg/R/mllib_utils.R b/R/pkg/R/mllib_utils.R
new file mode 100644
index 0000000000000..720ee41c58065
--- /dev/null
+++ b/R/pkg/R/mllib_utils.R
@@ -0,0 +1,119 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# mllib_utils.R: Utilities for MLlib integration
+
+# Integration with R's standard functions.
+# Most of MLlib's algorithms are provided in two flavours:
+# - a specialization of the default R methods (glm). These methods try to respect
+#   the inputs and the outputs of R's method to the largest extent, but some small differences
+#   may exist.
+# - a set of methods that reflect the arguments of the other languages supported by Spark. These
+#   methods are prefixed with the `spark.` prefix: spark.glm, spark.kmeans, etc.
+
+#' Saves the MLlib model to the input path
+#'
+#' Saves the MLlib model to the input path. For more information, see the specific
+#' MLlib model below.
+#' @rdname write.ml
+#' @name write.ml
+#' @export
+#' @seealso \link{spark.glm}, \link{glm},
+#' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.gbt}, \link{spark.isoreg},
+#' @seealso \link{spark.kmeans},
+#' @seealso \link{spark.lda}, \link{spark.logit}, \link{spark.mlp}, \link{spark.naiveBayes},
+#' @seealso \link{spark.randomForest}, \link{spark.survreg},
+#' @seealso \link{read.ml}
+NULL
+
+#' Makes predictions from an MLlib model
+#'
+#' Makes predictions from an MLlib model. For more information, see the specific
+#' MLlib model below.
+#' @rdname predict
+#' @name predict
+#' @export
+#' @seealso \link{spark.glm}, \link{glm},
+#' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.gbt}, \link{spark.isoreg},
+#' @seealso \link{spark.kmeans},
+#' @seealso \link{spark.logit}, \link{spark.mlp}, \link{spark.naiveBayes},
+#' @seealso \link{spark.randomForest}, \link{spark.survreg}
+NULL
+
+write_internal <- function(object, path, overwrite = FALSE) {
+  writer <- callJMethod(object@jobj, "write")
+  if (overwrite) {
+    writer <- callJMethod(writer, "overwrite")
+  }
+  invisible(callJMethod(writer, "save", path))
+}
+
+predict_internal <- function(object, newData) {
+  dataFrame(callJMethod(object@jobj, "transform", newData@sdf))
+}
+
+#' Load a fitted MLlib model from the input path.
+#'
+#' @param path path of the model to read.
+#' @return A fitted MLlib model.
+#' @rdname read.ml +#' @name read.ml +#' @export +#' @seealso \link{write.ml} +#' @examples +#' \dontrun{ +#' path <- "path/to/model" +#' model <- read.ml(path) +#' } +#' @note read.ml since 2.0.0 +read.ml <- function(path) { + path <- suppressWarnings(normalizePath(path)) + sparkSession <- getSparkSession() + callJStatic("org.apache.spark.ml.r.RWrappers", "session", sparkSession) + jobj <- callJStatic("org.apache.spark.ml.r.RWrappers", "load", path) + if (isInstanceOf(jobj, "org.apache.spark.ml.r.NaiveBayesWrapper")) { + new("NaiveBayesModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.AFTSurvivalRegressionWrapper")) { + new("AFTSurvivalRegressionModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper")) { + new("GeneralizedLinearRegressionModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.KMeansWrapper")) { + new("KMeansModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.LDAWrapper")) { + new("LDAModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.MultilayerPerceptronClassifierWrapper")) { + new("MultilayerPerceptronClassificationModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.IsotonicRegressionWrapper")) { + new("IsotonicRegressionModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GaussianMixtureWrapper")) { + new("GaussianMixtureModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.ALSWrapper")) { + new("ALSModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.LogisticRegressionWrapper")) { + new("LogisticRegressionModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.RandomForestRegressorWrapper")) { + new("RandomForestRegressionModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.RandomForestClassifierWrapper")) { + new("RandomForestClassificationModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GBTRegressorWrapper")) { + new("GBTRegressionModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GBTClassifierWrapper")) { + new("GBTClassificationModel", jobj = jobj) + } else { + stop("Unsupported model: ", jobj) + } +} diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R deleted file mode 100644 index 0f0d831c6f7c0..0000000000000 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ /dev/null @@ -1,1170 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -library(testthat) - -context("MLlib functions") - -# Tests for MLlib functions in SparkR -sparkSession <- sparkR.session(enableHiveSupport = FALSE) - -absoluteSparkPath <- function(x) { - sparkHome <- sparkR.conf("spark.home") - file.path(sparkHome, x) -} - -test_that("formula of spark.glm", { - training <- suppressWarnings(createDataFrame(iris)) - # directly calling the spark API - # dot minus and intercept vs native glm - model <- spark.glm(training, Sepal_Width ~ . - Species + 0) - vals <- collect(select(predict(model, training), "prediction")) - rVals <- predict(glm(Sepal.Width ~ . - Species + 0, data = iris), iris) - expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) - - # feature interaction vs native glm - model <- spark.glm(training, Sepal_Width ~ Species:Sepal_Length) - vals <- collect(select(predict(model, training), "prediction")) - rVals <- predict(glm(Sepal.Width ~ Species:Sepal.Length, data = iris), iris) - expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) - - # glm should work with long formula - training <- suppressWarnings(createDataFrame(iris)) - training$LongLongLongLongLongName <- training$Sepal_Width - training$VeryLongLongLongLonLongName <- training$Sepal_Length - training$AnotherLongLongLongLongName <- training$Species - model <- spark.glm(training, LongLongLongLongLongName ~ VeryLongLongLongLonLongName + - AnotherLongLongLongLongName) - vals <- collect(select(predict(model, training), "prediction")) - rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris) - expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) -}) - -test_that("spark.glm and predict", { - training <- suppressWarnings(createDataFrame(iris)) - # gaussian family - model <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species) - prediction <- predict(model, training) - expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") - vals <- collect(select(prediction, "prediction")) - rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris) - expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) - - # poisson family - model <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species, - family = poisson(link = identity)) - prediction <- predict(model, training) - expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") - vals <- collect(select(prediction, "prediction")) - rVals <- suppressWarnings(predict(glm(Sepal.Width ~ Sepal.Length + Species, - data = iris, family = poisson(link = identity)), iris)) - expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) - - # Test stats::predict is working - x <- rnorm(15) - y <- x + rnorm(15) - expect_equal(length(predict(lm(y ~ x))), 15) -}) - -test_that("spark.glm summary", { - # gaussian family - training <- suppressWarnings(createDataFrame(iris)) - stats <- summary(spark.glm(training, Sepal_Width ~ Sepal_Length + Species)) - - rStats <- summary(glm(Sepal.Width ~ Sepal.Length + Species, data = iris)) - - coefs <- unlist(stats$coefficients) - rCoefs <- unlist(rStats$coefficients) - expect_true(all(abs(rCoefs - coefs) < 1e-4)) - expect_true(all( - rownames(stats$coefficients) == - c("(Intercept)", "Sepal_Length", "Species_versicolor", "Species_virginica"))) - expect_equal(stats$dispersion, rStats$dispersion) - expect_equal(stats$null.deviance, rStats$null.deviance) - expect_equal(stats$deviance, rStats$deviance) - expect_equal(stats$df.null, rStats$df.null) - expect_equal(stats$df.residual, rStats$df.residual) - expect_equal(stats$aic, 
rStats$aic) - - out <- capture.output(print(stats)) - expect_match(out[2], "Deviance Residuals:") - expect_true(any(grepl("AIC: 59.22", out))) - - # binomial family - df <- suppressWarnings(createDataFrame(iris)) - training <- df[df$Species %in% c("versicolor", "virginica"), ] - stats <- summary(spark.glm(training, Species ~ Sepal_Length + Sepal_Width, - family = binomial(link = "logit"))) - - rTraining <- iris[iris$Species %in% c("versicolor", "virginica"), ] - rStats <- summary(glm(Species ~ Sepal.Length + Sepal.Width, data = rTraining, - family = binomial(link = "logit"))) - - coefs <- unlist(stats$coefficients) - rCoefs <- unlist(rStats$coefficients) - expect_true(all(abs(rCoefs - coefs) < 1e-4)) - expect_true(all( - rownames(stats$coefficients) == - c("(Intercept)", "Sepal_Length", "Sepal_Width"))) - expect_equal(stats$dispersion, rStats$dispersion) - expect_equal(stats$null.deviance, rStats$null.deviance) - expect_equal(stats$deviance, rStats$deviance) - expect_equal(stats$df.null, rStats$df.null) - expect_equal(stats$df.residual, rStats$df.residual) - expect_equal(stats$aic, rStats$aic) - - # Test spark.glm works with weighted dataset - a1 <- c(0, 1, 2, 3) - a2 <- c(5, 2, 1, 3) - w <- c(1, 2, 3, 4) - b <- c(1, 0, 1, 0) - data <- as.data.frame(cbind(a1, a2, w, b)) - df <- createDataFrame(data) - - stats <- summary(spark.glm(df, b ~ a1 + a2, family = "binomial", weightCol = "w")) - rStats <- summary(glm(b ~ a1 + a2, family = "binomial", data = data, weights = w)) - - coefs <- unlist(stats$coefficients) - rCoefs <- unlist(rStats$coefficients) - expect_true(all(abs(rCoefs - coefs) < 1e-3)) - expect_true(all(rownames(stats$coefficients) == c("(Intercept)", "a1", "a2"))) - expect_equal(stats$dispersion, rStats$dispersion) - expect_equal(stats$null.deviance, rStats$null.deviance) - expect_equal(stats$deviance, rStats$deviance) - expect_equal(stats$df.null, rStats$df.null) - expect_equal(stats$df.residual, rStats$df.residual) - expect_equal(stats$aic, rStats$aic) - - # Test summary works on base GLM models - baseModel <- stats::glm(Sepal.Width ~ Sepal.Length + Species, data = iris) - baseSummary <- summary(baseModel) - expect_true(abs(baseSummary$deviance - 12.19313) < 1e-4) - - # Test spark.glm works with regularization parameter - data <- as.data.frame(cbind(a1, a2, b)) - df <- suppressWarnings(createDataFrame(data)) - regStats <- summary(spark.glm(df, b ~ a1 + a2, regParam = 1.0)) - expect_equal(regStats$aic, 13.32836, tolerance = 1e-4) # 13.32836 is from summary() result - - # Test spark.glm works on collinear data - A <- matrix(c(1, 2, 3, 4, 2, 4, 6, 8), 4, 2) - b <- c(1, 2, 3, 4) - data <- as.data.frame(cbind(A, b)) - df <- createDataFrame(data) - stats <- summary(spark.glm(df, b ~ . 
- 1)) - coefs <- unlist(stats$coefficients) - expect_true(all(abs(c(0.5, 0.25) - coefs) < 1e-4)) -}) - -test_that("spark.glm save/load", { - training <- suppressWarnings(createDataFrame(iris)) - m <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species) - s <- summary(m) - - modelPath <- tempfile(pattern = "spark-glm", fileext = ".tmp") - write.ml(m, modelPath) - expect_error(write.ml(m, modelPath)) - write.ml(m, modelPath, overwrite = TRUE) - m2 <- read.ml(modelPath) - s2 <- summary(m2) - - expect_equal(s$coefficients, s2$coefficients) - expect_equal(rownames(s$coefficients), rownames(s2$coefficients)) - expect_equal(s$dispersion, s2$dispersion) - expect_equal(s$null.deviance, s2$null.deviance) - expect_equal(s$deviance, s2$deviance) - expect_equal(s$df.null, s2$df.null) - expect_equal(s$df.residual, s2$df.residual) - expect_equal(s$aic, s2$aic) - expect_equal(s$iter, s2$iter) - expect_true(!s$is.loaded) - expect_true(s2$is.loaded) - - unlink(modelPath) -}) - - - -test_that("formula of glm", { - training <- suppressWarnings(createDataFrame(iris)) - # dot minus and intercept vs native glm - model <- glm(Sepal_Width ~ . - Species + 0, data = training) - vals <- collect(select(predict(model, training), "prediction")) - rVals <- predict(glm(Sepal.Width ~ . - Species + 0, data = iris), iris) - expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) - - # feature interaction vs native glm - model <- glm(Sepal_Width ~ Species:Sepal_Length, data = training) - vals <- collect(select(predict(model, training), "prediction")) - rVals <- predict(glm(Sepal.Width ~ Species:Sepal.Length, data = iris), iris) - expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) - - # glm should work with long formula - training <- suppressWarnings(createDataFrame(iris)) - training$LongLongLongLongLongName <- training$Sepal_Width - training$VeryLongLongLongLonLongName <- training$Sepal_Length - training$AnotherLongLongLongLongName <- training$Species - model <- glm(LongLongLongLongLongName ~ VeryLongLongLongLonLongName + AnotherLongLongLongLongName, - data = training) - vals <- collect(select(predict(model, training), "prediction")) - rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris) - expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) -}) - -test_that("glm and predict", { - training <- suppressWarnings(createDataFrame(iris)) - # gaussian family - model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training) - prediction <- predict(model, training) - expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") - vals <- collect(select(prediction, "prediction")) - rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris) - expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) - - # poisson family - model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training, - family = poisson(link = identity)) - prediction <- predict(model, training) - expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") - vals <- collect(select(prediction, "prediction")) - rVals <- suppressWarnings(predict(glm(Sepal.Width ~ Sepal.Length + Species, - data = iris, family = poisson(link = identity)), iris)) - expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) - - # Test stats::predict is working - x <- rnorm(15) - y <- x + rnorm(15) - expect_equal(length(predict(lm(y ~ x))), 15) -}) - -test_that("glm summary", { - # gaussian family - training <- suppressWarnings(createDataFrame(iris)) - stats <- 
summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training)) - - rStats <- summary(glm(Sepal.Width ~ Sepal.Length + Species, data = iris)) - - coefs <- unlist(stats$coefficients) - rCoefs <- unlist(rStats$coefficients) - expect_true(all(abs(rCoefs - coefs) < 1e-4)) - expect_true(all( - rownames(stats$coefficients) == - c("(Intercept)", "Sepal_Length", "Species_versicolor", "Species_virginica"))) - expect_equal(stats$dispersion, rStats$dispersion) - expect_equal(stats$null.deviance, rStats$null.deviance) - expect_equal(stats$deviance, rStats$deviance) - expect_equal(stats$df.null, rStats$df.null) - expect_equal(stats$df.residual, rStats$df.residual) - expect_equal(stats$aic, rStats$aic) - - # binomial family - df <- suppressWarnings(createDataFrame(iris)) - training <- df[df$Species %in% c("versicolor", "virginica"), ] - stats <- summary(glm(Species ~ Sepal_Length + Sepal_Width, data = training, - family = binomial(link = "logit"))) - - rTraining <- iris[iris$Species %in% c("versicolor", "virginica"), ] - rStats <- summary(glm(Species ~ Sepal.Length + Sepal.Width, data = rTraining, - family = binomial(link = "logit"))) - - coefs <- unlist(stats$coefficients) - rCoefs <- unlist(rStats$coefficients) - expect_true(all(abs(rCoefs - coefs) < 1e-4)) - expect_true(all( - rownames(stats$coefficients) == - c("(Intercept)", "Sepal_Length", "Sepal_Width"))) - expect_equal(stats$dispersion, rStats$dispersion) - expect_equal(stats$null.deviance, rStats$null.deviance) - expect_equal(stats$deviance, rStats$deviance) - expect_equal(stats$df.null, rStats$df.null) - expect_equal(stats$df.residual, rStats$df.residual) - expect_equal(stats$aic, rStats$aic) - - # Test summary works on base GLM models - baseModel <- stats::glm(Sepal.Width ~ Sepal.Length + Species, data = iris) - baseSummary <- summary(baseModel) - expect_true(abs(baseSummary$deviance - 12.19313) < 1e-4) -}) - -test_that("glm save/load", { - training <- suppressWarnings(createDataFrame(iris)) - m <- glm(Sepal_Width ~ Sepal_Length + Species, data = training) - s <- summary(m) - - modelPath <- tempfile(pattern = "glm", fileext = ".tmp") - write.ml(m, modelPath) - expect_error(write.ml(m, modelPath)) - write.ml(m, modelPath, overwrite = TRUE) - m2 <- read.ml(modelPath) - s2 <- summary(m2) - - expect_equal(s$coefficients, s2$coefficients) - expect_equal(rownames(s$coefficients), rownames(s2$coefficients)) - expect_equal(s$dispersion, s2$dispersion) - expect_equal(s$null.deviance, s2$null.deviance) - expect_equal(s$deviance, s2$deviance) - expect_equal(s$df.null, s2$df.null) - expect_equal(s$df.residual, s2$df.residual) - expect_equal(s$aic, s2$aic) - expect_equal(s$iter, s2$iter) - expect_true(!s$is.loaded) - expect_true(s2$is.loaded) - - unlink(modelPath) -}) - -test_that("spark.kmeans", { - newIris <- iris - newIris$Species <- NULL - training <- suppressWarnings(createDataFrame(newIris)) - - take(training, 1) - - model <- spark.kmeans(data = training, ~ ., k = 2, maxIter = 10, initMode = "random") - sample <- take(select(predict(model, training), "prediction"), 1) - expect_equal(typeof(sample$prediction), "integer") - expect_equal(sample$prediction, 1) - - # Test stats::kmeans is working - statsModel <- kmeans(x = newIris, centers = 2) - expect_equal(sort(unique(statsModel$cluster)), c(1, 2)) - - # Test fitted works on KMeans - fitted.model <- fitted(model) - expect_equal(sort(collect(distinct(select(fitted.model, "prediction")))$prediction), c(0, 1)) - - # Test summary works on KMeans - summary.model <- summary(model) - cluster <- 
summary.model$cluster - k <- summary.model$k - expect_equal(k, 2) - expect_equal(sort(collect(distinct(select(cluster, "prediction")))$prediction), c(0, 1)) - - # Test model save/load - modelPath <- tempfile(pattern = "spark-kmeans", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - summary2 <- summary(model2) - expect_equal(sort(unlist(summary.model$size)), sort(unlist(summary2$size))) - expect_equal(summary.model$coefficients, summary2$coefficients) - expect_true(!summary.model$is.loaded) - expect_true(summary2$is.loaded) - - unlink(modelPath) -}) - -test_that("spark.mlp", { - df <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"), - source = "libsvm") - model <- spark.mlp(df, label ~ features, blockSize = 128, layers = c(4, 5, 4, 3), - solver = "l-bfgs", maxIter = 100, tol = 0.5, stepSize = 1, seed = 1) - - # Test summary method - summary <- summary(model) - expect_equal(summary$numOfInputs, 4) - expect_equal(summary$numOfOutputs, 3) - expect_equal(summary$layers, c(4, 5, 4, 3)) - expect_equal(length(summary$weights), 64) - expect_equal(head(summary$weights, 5), list(-0.878743, 0.2154151, -1.16304, -0.6583214, 1.009825), - tolerance = 1e-6) - - # Test predict method - mlpTestDF <- df - mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) - expect_equal(head(mlpPredictions$prediction, 6), c("1.0", "0.0", "0.0", "0.0", "0.0", "0.0")) - - # Test model save/load - modelPath <- tempfile(pattern = "spark-mlp", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - summary2 <- summary(model2) - - expect_equal(summary2$numOfInputs, 4) - expect_equal(summary2$numOfOutputs, 3) - expect_equal(summary2$layers, c(4, 5, 4, 3)) - expect_equal(length(summary2$weights), 64) - - unlink(modelPath) - - # Test default parameter - model <- spark.mlp(df, label ~ features, layers = c(4, 5, 4, 3)) - mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) - expect_equal(head(mlpPredictions$prediction, 10), - c("1.0", "1.0", "1.0", "1.0", "0.0", "1.0", "2.0", "2.0", "1.0", "0.0")) - - # Test illegal parameter - expect_error(spark.mlp(df, label ~ features, layers = NULL), - "layers must be a integer vector with length > 1.") - expect_error(spark.mlp(df, label ~ features, layers = c()), - "layers must be a integer vector with length > 1.") - expect_error(spark.mlp(df, label ~ features, layers = c(3)), - "layers must be a integer vector with length > 1.") - - # Test random seed - # default seed - model <- spark.mlp(df, label ~ features, layers = c(4, 5, 4, 3), maxIter = 10) - mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) - expect_equal(head(mlpPredictions$prediction, 10), - c("1.0", "1.0", "1.0", "1.0", "0.0", "1.0", "2.0", "2.0", "1.0", "0.0")) - # seed equals 10 - model <- spark.mlp(df, label ~ features, layers = c(4, 5, 4, 3), maxIter = 10, seed = 10) - mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) - expect_equal(head(mlpPredictions$prediction, 10), - c("1.0", "1.0", "1.0", "1.0", "0.0", "1.0", "2.0", "2.0", "1.0", "0.0")) - - # test initialWeights - model <- spark.mlp(df, label ~ features, layers = c(4, 3), maxIter = 2, initialWeights = - c(0, 0, 0, 0, 0, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9)) - mlpPredictions <- collect(select(predict(model, mlpTestDF), 
"prediction")) - expect_equal(head(mlpPredictions$prediction, 10), - c("1.0", "1.0", "1.0", "1.0", "2.0", "1.0", "2.0", "2.0", "1.0", "0.0")) - - model <- spark.mlp(df, label ~ features, layers = c(4, 3), maxIter = 2, initialWeights = - c(0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 5.0, 5.0, 5.0, 5.0, 9.0, 9.0, 9.0, 9.0, 9.0)) - mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) - expect_equal(head(mlpPredictions$prediction, 10), - c("1.0", "1.0", "1.0", "1.0", "2.0", "1.0", "2.0", "2.0", "1.0", "0.0")) - - model <- spark.mlp(df, label ~ features, layers = c(4, 3), maxIter = 2) - mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) - expect_equal(head(mlpPredictions$prediction, 10), - c("1.0", "1.0", "1.0", "1.0", "0.0", "1.0", "0.0", "2.0", "1.0", "0.0")) - - # Test formula works well - df <- suppressWarnings(createDataFrame(iris)) - model <- spark.mlp(df, Species ~ Sepal_Length + Sepal_Width + Petal_Length + Petal_Width, - layers = c(4, 3)) - summary <- summary(model) - expect_equal(summary$numOfInputs, 4) - expect_equal(summary$numOfOutputs, 3) - expect_equal(summary$layers, c(4, 3)) - expect_equal(length(summary$weights), 15) - expect_equal(head(summary$weights, 5), list(-1.1957257, -5.2693685, 7.4489734, -6.3751413, - -10.2376130), tolerance = 1e-6) -}) - -test_that("spark.naiveBayes", { - # R code to reproduce the result. - # We do not support instance weights yet. So we ignore the frequencies. - # - #' library(e1071) - #' t <- as.data.frame(Titanic) - #' t1 <- t[t$Freq > 0, -5] - #' m <- naiveBayes(Survived ~ ., data = t1) - #' m - #' predict(m, t1) - # - # -- output of 'm' - # - # A-priori probabilities: - # Y - # No Yes - # 0.4166667 0.5833333 - # - # Conditional probabilities: - # Class - # Y 1st 2nd 3rd Crew - # No 0.2000000 0.2000000 0.4000000 0.2000000 - # Yes 0.2857143 0.2857143 0.2857143 0.1428571 - # - # Sex - # Y Male Female - # No 0.5 0.5 - # Yes 0.5 0.5 - # - # Age - # Y Child Adult - # No 0.2000000 0.8000000 - # Yes 0.4285714 0.5714286 - # - # -- output of 'predict(m, t1)' - # - # Yes Yes Yes Yes No No Yes Yes No No Yes Yes Yes Yes Yes Yes Yes Yes No No Yes Yes No No - # - - t <- as.data.frame(Titanic) - t1 <- t[t$Freq > 0, -5] - df <- suppressWarnings(createDataFrame(t1)) - m <- spark.naiveBayes(df, Survived ~ ., smoothing = 0.0) - s <- summary(m) - expect_equal(as.double(s$apriori[1, "Yes"]), 0.5833333, tolerance = 1e-6) - expect_equal(sum(s$apriori), 1) - expect_equal(as.double(s$tables["Yes", "Age_Adult"]), 0.5714286, tolerance = 1e-6) - p <- collect(select(predict(m, df), "prediction")) - expect_equal(p$prediction, c("Yes", "Yes", "Yes", "Yes", "No", "No", "Yes", "Yes", "No", "No", - "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "No", "No", - "Yes", "Yes", "No", "No")) - - # Test model save/load - modelPath <- tempfile(pattern = "spark-naiveBayes", fileext = ".tmp") - write.ml(m, modelPath) - expect_error(write.ml(m, modelPath)) - write.ml(m, modelPath, overwrite = TRUE) - m2 <- read.ml(modelPath) - s2 <- summary(m2) - expect_equal(s$apriori, s2$apriori) - expect_equal(s$tables, s2$tables) - - unlink(modelPath) - - # Test e1071::naiveBayes - if (requireNamespace("e1071", quietly = TRUE)) { - expect_error(m <- e1071::naiveBayes(Survived ~ ., data = t1), NA) - expect_equal(as.character(predict(m, t1[1, ])), "Yes") - } - - # Test numeric response variable - t1$NumericSurvived <- ifelse(t1$Survived == "No", 0, 1) - t2 <- t1[-4] - df <- suppressWarnings(createDataFrame(t2)) - m <- spark.naiveBayes(df, NumericSurvived ~ ., 
smoothing = 0.0) - s <- summary(m) - expect_equal(as.double(s$apriori[1, 1]), 0.5833333, tolerance = 1e-6) - expect_equal(sum(s$apriori), 1) - expect_equal(as.double(s$tables[1, "Age_Adult"]), 0.5714286, tolerance = 1e-6) -}) - -test_that("spark.survreg", { - # R code to reproduce the result. - # - #' rData <- list(time = c(4, 3, 1, 1, 2, 2, 3), status = c(1, 1, 1, 0, 1, 1, 0), - #' x = c(0, 2, 1, 1, 1, 0, 0), sex = c(0, 0, 0, 0, 1, 1, 1)) - #' library(survival) - #' model <- survreg(Surv(time, status) ~ x + sex, rData) - #' summary(model) - #' predict(model, data) - # - # -- output of 'summary(model)' - # - # Value Std. Error z p - # (Intercept) 1.315 0.270 4.88 1.07e-06 - # x -0.190 0.173 -1.10 2.72e-01 - # sex -0.253 0.329 -0.77 4.42e-01 - # Log(scale) -1.160 0.396 -2.93 3.41e-03 - # - # -- output of 'predict(model, data)' - # - # 1 2 3 4 5 6 7 - # 3.724591 2.545368 3.079035 3.079035 2.390146 2.891269 2.891269 - # - data <- list(list(4, 1, 0, 0), list(3, 1, 2, 0), list(1, 1, 1, 0), - list(1, 0, 1, 0), list(2, 1, 1, 1), list(2, 1, 0, 1), list(3, 0, 0, 1)) - df <- createDataFrame(data, c("time", "status", "x", "sex")) - model <- spark.survreg(df, Surv(time, status) ~ x + sex) - stats <- summary(model) - coefs <- as.vector(stats$coefficients[, 1]) - rCoefs <- c(1.3149571, -0.1903409, -0.2532618, -1.1599800) - expect_equal(coefs, rCoefs, tolerance = 1e-4) - expect_true(all( - rownames(stats$coefficients) == - c("(Intercept)", "x", "sex", "Log(scale)"))) - p <- collect(select(predict(model, df), "prediction")) - expect_equal(p$prediction, c(3.724591, 2.545368, 3.079035, 3.079035, - 2.390146, 2.891269, 2.891269), tolerance = 1e-4) - - # Test model save/load - modelPath <- tempfile(pattern = "spark-survreg", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - stats2 <- summary(model2) - coefs2 <- as.vector(stats2$coefficients[, 1]) - expect_equal(coefs, coefs2) - expect_equal(rownames(stats$coefficients), rownames(stats2$coefficients)) - - unlink(modelPath) - - # Test survival::survreg - if (requireNamespace("survival", quietly = TRUE)) { - rData <- list(time = c(4, 3, 1, 1, 2, 2, 3), status = c(1, 1, 1, 0, 1, 1, 0), - x = c(0, 2, 1, 1, 1, 0, 0), sex = c(0, 0, 0, 0, 1, 1, 1)) - expect_error( - model <- survival::survreg(formula = survival::Surv(time, status) ~ x + sex, data = rData), - NA) - expect_equal(predict(model, rData)[[1]], 3.724591, tolerance = 1e-4) - } -}) - -test_that("spark.isotonicRegression", { - label <- c(7.0, 5.0, 3.0, 5.0, 1.0) - feature <- c(0.0, 1.0, 2.0, 3.0, 4.0) - weight <- c(1.0, 1.0, 1.0, 1.0, 1.0) - data <- as.data.frame(cbind(label, feature, weight)) - df <- createDataFrame(data) - - model <- spark.isoreg(df, label ~ feature, isotonic = FALSE, - weightCol = "weight") - # only allow one variable on the right hand side of the formula - expect_error(model2 <- spark.isoreg(df, ~., isotonic = FALSE)) - result <- summary(model) - expect_equal(result$predictions, list(7, 5, 4, 4, 1)) - - # Test model prediction - predict_data <- list(list(-2.0), list(-1.0), list(0.5), - list(0.75), list(1.0), list(2.0), list(9.0)) - predict_df <- createDataFrame(predict_data, c("feature")) - predict_result <- collect(select(predict(model, predict_df), "prediction")) - expect_equal(predict_result$prediction, c(7.0, 7.0, 6.0, 5.5, 5.0, 4.0, 1.0)) - - # Test model save/load - modelPath <- tempfile(pattern = "spark-isotonicRegression", fileext = ".tmp") - write.ml(model, 
modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - expect_equal(result, summary(model2)) - - unlink(modelPath) -}) - -test_that("spark.logit", { - # R code to reproduce the result. - # nolint start - #' library(glmnet) - #' iris.x = as.matrix(iris[, 1:4]) - #' iris.y = as.factor(as.character(iris[, 5])) - #' logit = glmnet(iris.x, iris.y, family="multinomial", alpha=0, lambda=0.5) - #' coef(logit) - # - # $setosa - # 5 x 1 sparse Matrix of class "dgCMatrix" - # s0 - # 1.0981324 - # Sepal.Length -0.2909860 - # Sepal.Width 0.5510907 - # Petal.Length -0.1915217 - # Petal.Width -0.4211946 - # - # $versicolor - # 5 x 1 sparse Matrix of class "dgCMatrix" - # s0 - # 1.520061e+00 - # Sepal.Length 2.524501e-02 - # Sepal.Width -5.310313e-01 - # Petal.Length 3.656543e-02 - # Petal.Width -3.144464e-05 - # - # $virginica - # 5 x 1 sparse Matrix of class "dgCMatrix" - # s0 - # -2.61819385 - # Sepal.Length 0.26574097 - # Sepal.Width -0.02005932 - # Petal.Length 0.15495629 - # Petal.Width 0.42122607 - # nolint end - - # Test multinomial logistic regression againt three classes - df <- suppressWarnings(createDataFrame(iris)) - model <- spark.logit(df, Species ~ ., regParam = 0.5) - summary <- summary(model) - versicolorCoefsR <- c(1.52, 0.03, -0.53, 0.04, 0.00) - virginicaCoefsR <- c(-2.62, 0.27, -0.02, 0.16, 0.42) - setosaCoefsR <- c(1.10, -0.29, 0.55, -0.19, -0.42) - versicolorCoefs <- unlist(summary$coefficients[, "versicolor"]) - virginicaCoefs <- unlist(summary$coefficients[, "virginica"]) - setosaCoefs <- unlist(summary$coefficients[, "setosa"]) - expect_true(all(abs(versicolorCoefsR - versicolorCoefs) < 0.1)) - expect_true(all(abs(virginicaCoefsR - virginicaCoefs) < 0.1)) - expect_true(all(abs(setosaCoefs - setosaCoefs) < 0.1)) - - # Test model save and load - modelPath <- tempfile(pattern = "spark-logit", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - coefs <- summary(model)$coefficients - coefs2 <- summary(model2)$coefficients - expect_equal(coefs, coefs2) - unlink(modelPath) - - # R code to reproduce the result. 
- # nolint start - #' library(glmnet) - #' iris2 <- iris[iris$Species %in% c("versicolor", "virginica"), ] - #' iris.x = as.matrix(iris2[, 1:4]) - #' iris.y = as.factor(as.character(iris2[, 5])) - #' logit = glmnet(iris.x, iris.y, family="multinomial", alpha=0, lambda=0.5) - #' coef(logit) - # - # $versicolor - # 5 x 1 sparse Matrix of class "dgCMatrix" - # s0 - # 3.93844796 - # Sepal.Length -0.13538675 - # Sepal.Width -0.02386443 - # Petal.Length -0.35076451 - # Petal.Width -0.77971954 - # - # $virginica - # 5 x 1 sparse Matrix of class "dgCMatrix" - # s0 - # -3.93844796 - # Sepal.Length 0.13538675 - # Sepal.Width 0.02386443 - # Petal.Length 0.35076451 - # Petal.Width 0.77971954 - # - #' logit = glmnet(iris.x, iris.y, family="binomial", alpha=0, lambda=0.5) - #' coef(logit) - # - # 5 x 1 sparse Matrix of class "dgCMatrix" - # s0 - # (Intercept) -6.0824412 - # Sepal.Length 0.2458260 - # Sepal.Width 0.1642093 - # Petal.Length 0.4759487 - # Petal.Width 1.0383948 - # - # nolint end - - # Test multinomial logistic regression againt two classes - df <- suppressWarnings(createDataFrame(iris)) - training <- df[df$Species %in% c("versicolor", "virginica"), ] - model <- spark.logit(training, Species ~ ., regParam = 0.5, family = "multinomial") - summary <- summary(model) - versicolorCoefsR <- c(3.94, -0.16, -0.02, -0.35, -0.78) - virginicaCoefsR <- c(-3.94, 0.16, -0.02, 0.35, 0.78) - versicolorCoefs <- unlist(summary$coefficients[, "versicolor"]) - virginicaCoefs <- unlist(summary$coefficients[, "virginica"]) - expect_true(all(abs(versicolorCoefsR - versicolorCoefs) < 0.1)) - expect_true(all(abs(virginicaCoefsR - virginicaCoefs) < 0.1)) - - # Test binomial logistic regression againt two classes - model <- spark.logit(training, Species ~ ., regParam = 0.5) - summary <- summary(model) - coefsR <- c(-6.08, 0.25, 0.16, 0.48, 1.04) - coefs <- unlist(summary$coefficients[, "Estimate"]) - expect_true(all(abs(coefsR - coefs) < 0.1)) - - # Test prediction with string label - prediction <- predict(model, training) - expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "character") - expected <- c("versicolor", "versicolor", "virginica", "versicolor", "versicolor", - "versicolor", "versicolor", "versicolor", "versicolor", "versicolor") - expect_equal(as.list(take(select(prediction, "prediction"), 10))[[1]], expected) - - # Test prediction with numeric label - label <- c(0.0, 0.0, 0.0, 1.0, 1.0) - feature <- c(1.1419053, 0.9194079, -0.9498666, -1.1069903, 0.2809776) - data <- as.data.frame(cbind(label, feature)) - df <- createDataFrame(data) - model <- spark.logit(df, label ~ feature) - prediction <- collect(select(predict(model, df), "prediction")) - expect_equal(prediction$prediction, c("0.0", "0.0", "1.0", "1.0", "0.0")) -}) - -test_that("spark.gaussianMixture", { - # R code to reproduce the result. 
- # nolint start - #' library(mvtnorm) - #' set.seed(1) - #' a <- rmvnorm(7, c(0, 0)) - #' b <- rmvnorm(8, c(10, 10)) - #' data <- rbind(a, b) - #' model <- mvnormalmixEM(data, k = 2) - #' model$lambda - # - # [1] 0.4666667 0.5333333 - # - #' model$mu - # - # [1] 0.11731091 -0.06192351 - # [1] 10.363673 9.897081 - # - #' model$sigma - # - # [[1]] - # [,1] [,2] - # [1,] 0.62049934 0.06880802 - # [2,] 0.06880802 1.27431874 - # - # [[2]] - # [,1] [,2] - # [1,] 0.2961543 0.160783 - # [2,] 0.1607830 1.008878 - # nolint end - data <- list(list(-0.6264538, 0.1836433), list(-0.8356286, 1.5952808), - list(0.3295078, -0.8204684), list(0.4874291, 0.7383247), - list(0.5757814, -0.3053884), list(1.5117812, 0.3898432), - list(-0.6212406, -2.2146999), list(11.1249309, 9.9550664), - list(9.9838097, 10.9438362), list(10.8212212, 10.5939013), - list(10.9189774, 10.7821363), list(10.0745650, 8.0106483), - list(10.6198257, 9.9438713), list(9.8442045, 8.5292476), - list(9.5218499, 10.4179416)) - df <- createDataFrame(data, c("x1", "x2")) - model <- spark.gaussianMixture(df, ~ x1 + x2, k = 2) - stats <- summary(model) - rLambda <- c(0.4666667, 0.5333333) - rMu <- c(0.11731091, -0.06192351, 10.363673, 9.897081) - rSigma <- c(0.62049934, 0.06880802, 0.06880802, 1.27431874, - 0.2961543, 0.160783, 0.1607830, 1.008878) - expect_equal(stats$lambda, rLambda, tolerance = 1e-3) - expect_equal(unlist(stats$mu), rMu, tolerance = 1e-3) - expect_equal(unlist(stats$sigma), rSigma, tolerance = 1e-3) - p <- collect(select(predict(model, df), "prediction")) - expect_equal(p$prediction, c(0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1)) - - # Test model save/load - modelPath <- tempfile(pattern = "spark-gaussianMixture", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - stats2 <- summary(model2) - expect_equal(stats$lambda, stats2$lambda) - expect_equal(unlist(stats$mu), unlist(stats2$mu)) - expect_equal(unlist(stats$sigma), unlist(stats2$sigma)) - - unlink(modelPath) -}) - -test_that("spark.lda with libsvm", { - text <- read.df(absoluteSparkPath("data/mllib/sample_lda_libsvm_data.txt"), source = "libsvm") - model <- spark.lda(text, optimizer = "em") - - stats <- summary(model, 10) - isDistributed <- stats$isDistributed - logLikelihood <- stats$logLikelihood - logPerplexity <- stats$logPerplexity - vocabSize <- stats$vocabSize - topics <- stats$topicTopTerms - weights <- stats$topicTopTermsWeights - vocabulary <- stats$vocabulary - - expect_false(isDistributed) - expect_true(logLikelihood <= 0 & is.finite(logLikelihood)) - expect_true(logPerplexity >= 0 & is.finite(logPerplexity)) - expect_equal(vocabSize, 11) - expect_true(is.null(vocabulary)) - - # Test model save/load - modelPath <- tempfile(pattern = "spark-lda", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - stats2 <- summary(model2) - - expect_false(stats2$isDistributed) - expect_equal(logLikelihood, stats2$logLikelihood) - expect_equal(logPerplexity, stats2$logPerplexity) - expect_equal(vocabSize, stats2$vocabSize) - expect_equal(vocabulary, stats2$vocabulary) - - unlink(modelPath) -}) - -test_that("spark.lda with text input", { - text <- read.text(absoluteSparkPath("data/mllib/sample_lda_data.txt")) - model <- spark.lda(text, optimizer = "online", features = "value") - - stats <- summary(model) - isDistributed <- stats$isDistributed - 
logLikelihood <- stats$logLikelihood - logPerplexity <- stats$logPerplexity - vocabSize <- stats$vocabSize - topics <- stats$topicTopTerms - weights <- stats$topicTopTermsWeights - vocabulary <- stats$vocabulary - - expect_false(isDistributed) - expect_true(logLikelihood <= 0 & is.finite(logLikelihood)) - expect_true(logPerplexity >= 0 & is.finite(logPerplexity)) - expect_equal(vocabSize, 10) - expect_true(setequal(stats$vocabulary, c("0", "1", "2", "3", "4", "5", "6", "7", "8", "9"))) - - # Test model save/load - modelPath <- tempfile(pattern = "spark-lda-text", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - stats2 <- summary(model2) - - expect_false(stats2$isDistributed) - expect_equal(logLikelihood, stats2$logLikelihood) - expect_equal(logPerplexity, stats2$logPerplexity) - expect_equal(vocabSize, stats2$vocabSize) - expect_true(all.equal(vocabulary, stats2$vocabulary)) - - unlink(modelPath) -}) - -test_that("spark.posterior and spark.perplexity", { - text <- read.text(absoluteSparkPath("data/mllib/sample_lda_data.txt")) - model <- spark.lda(text, features = "value", k = 3) - - # Assert perplexities are equal - stats <- summary(model) - logPerplexity <- spark.perplexity(model, text) - expect_equal(logPerplexity, stats$logPerplexity) - - # Assert the sum of every topic distribution is equal to 1 - posterior <- spark.posterior(model, text) - local.posterior <- collect(posterior)$topicDistribution - expect_equal(length(local.posterior), sum(unlist(local.posterior))) -}) - -test_that("spark.als", { - data <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(1, 2, 4.0), - list(2, 1, 1.0), list(2, 2, 5.0)) - df <- createDataFrame(data, c("user", "item", "score")) - model <- spark.als(df, ratingCol = "score", userCol = "user", itemCol = "item", - rank = 10, maxIter = 5, seed = 0, regParam = 0.1) - stats <- summary(model) - expect_equal(stats$rank, 10) - test <- createDataFrame(list(list(0, 2), list(1, 0), list(2, 0)), c("user", "item")) - predictions <- collect(predict(model, test)) - - expect_equal(predictions$prediction, c(-0.1380762, 2.6258414, -1.5018409), - tolerance = 1e-4) - - # Test model save/load - modelPath <- tempfile(pattern = "spark-als", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - stats2 <- summary(model2) - expect_equal(stats2$rating, "score") - userFactors <- collect(stats$userFactors) - itemFactors <- collect(stats$itemFactors) - userFactors2 <- collect(stats2$userFactors) - itemFactors2 <- collect(stats2$itemFactors) - - orderUser <- order(userFactors$id) - orderUser2 <- order(userFactors2$id) - expect_equal(userFactors$id[orderUser], userFactors2$id[orderUser2]) - expect_equal(userFactors$features[orderUser], userFactors2$features[orderUser2]) - - orderItem <- order(itemFactors$id) - orderItem2 <- order(itemFactors2$id) - expect_equal(itemFactors$id[orderItem], itemFactors2$id[orderItem2]) - expect_equal(itemFactors$features[orderItem], itemFactors2$features[orderItem2]) - - unlink(modelPath) -}) - -test_that("spark.kstest", { - data <- data.frame(test = c(0.1, 0.15, 0.2, 0.3, 0.25, -1, -0.5)) - df <- createDataFrame(data) - testResult <- spark.kstest(df, "test", "norm") - stats <- summary(testResult) - - rStats <- ks.test(data$test, "pnorm", alternative = "two.sided") - - expect_equal(stats$p.value, rStats$p.value, 
tolerance = 1e-4) - expect_equal(stats$statistic, unname(rStats$statistic), tolerance = 1e-4) - expect_match(capture.output(stats)[1], "Kolmogorov-Smirnov test summary:") - - testResult <- spark.kstest(df, "test", "norm", -0.5) - stats <- summary(testResult) - - rStats <- ks.test(data$test, "pnorm", -0.5, 1, alternative = "two.sided") - - expect_equal(stats$p.value, rStats$p.value, tolerance = 1e-4) - expect_equal(stats$statistic, unname(rStats$statistic), tolerance = 1e-4) - expect_match(capture.output(stats)[1], "Kolmogorov-Smirnov test summary:") - - # Test print.summary.KSTest - printStats <- capture.output(print.summary.KSTest(stats)) - expect_match(printStats[1], "Kolmogorov-Smirnov test summary:") - expect_match(printStats[5], - "Low presumption against null hypothesis: Sample follows theoretical distribution. ") -}) - -test_that("spark.randomForest", { - # regression - data <- suppressWarnings(createDataFrame(longley)) - model <- spark.randomForest(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16, - numTrees = 1) - - predictions <- collect(predict(model, data)) - expect_equal(predictions$prediction, c(60.323, 61.122, 60.171, 61.187, - 63.221, 63.639, 64.989, 63.761, - 66.019, 67.857, 68.169, 66.513, - 68.655, 69.564, 69.331, 70.551), - tolerance = 1e-4) - - stats <- summary(model) - expect_equal(stats$numTrees, 1) - expect_error(capture.output(stats), NA) - expect_true(length(capture.output(stats)) > 6) - - model <- spark.randomForest(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16, - numTrees = 20, seed = 123) - predictions <- collect(predict(model, data)) - expect_equal(predictions$prediction, c(60.32820, 61.22315, 60.69025, 62.11070, - 63.53160, 64.05470, 65.12710, 64.30450, - 66.70910, 67.86125, 68.08700, 67.21865, - 68.89275, 69.53180, 69.39640, 69.68250), - - tolerance = 1e-4) - stats <- summary(model) - expect_equal(stats$numTrees, 20) - - modelPath <- tempfile(pattern = "spark-randomForestRegression", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - stats2 <- summary(model2) - expect_equal(stats$formula, stats2$formula) - expect_equal(stats$numFeatures, stats2$numFeatures) - expect_equal(stats$features, stats2$features) - expect_equal(stats$featureImportances, stats2$featureImportances) - expect_equal(stats$numTrees, stats2$numTrees) - expect_equal(stats$treeWeights, stats2$treeWeights) - - unlink(modelPath) - - # classification - data <- suppressWarnings(createDataFrame(iris)) - model <- spark.randomForest(data, Species ~ Petal_Length + Petal_Width, "classification", - maxDepth = 5, maxBins = 16) - - stats <- summary(model) - expect_equal(stats$numFeatures, 2) - expect_equal(stats$numTrees, 20) - expect_error(capture.output(stats), NA) - expect_true(length(capture.output(stats)) > 6) - # Test string prediction values - predictions <- collect(predict(model, data))$prediction - expect_equal(length(grep("setosa", predictions)), 50) - expect_equal(length(grep("versicolor", predictions)), 50) - - modelPath <- tempfile(pattern = "spark-randomForestClassification", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - stats2 <- summary(model2) - expect_equal(stats$depth, stats2$depth) - expect_equal(stats$numNodes, stats2$numNodes) - expect_equal(stats$numClasses, stats2$numClasses) - - unlink(modelPath) - - # Test numeric 
response variable - labelToIndex <- function(species) { - switch(as.character(species), - setosa = 0.0, - versicolor = 1.0, - virginica = 2.0 - ) - } - iris$NumericSpecies <- lapply(iris$Species, labelToIndex) - data <- suppressWarnings(createDataFrame(iris[-5])) - model <- spark.randomForest(data, NumericSpecies ~ Petal_Length + Petal_Width, "classification", - maxDepth = 5, maxBins = 16) - stats <- summary(model) - expect_equal(stats$numFeatures, 2) - expect_equal(stats$numTrees, 20) - # Test numeric prediction values - predictions <- collect(predict(model, data))$prediction - expect_equal(length(grep("1.0", predictions)), 50) - expect_equal(length(grep("2.0", predictions)), 50) - - # spark.randomForest classification can work on libsvm data - data <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"), - source = "libsvm") - model <- spark.randomForest(data, label ~ features, "classification") - expect_equal(summary(model)$numFeatures, 4) -}) - -test_that("spark.gbt", { - # regression - data <- suppressWarnings(createDataFrame(longley)) - model <- spark.gbt(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16, seed = 123) - predictions <- collect(predict(model, data)) - expect_equal(predictions$prediction, c(60.323, 61.122, 60.171, 61.187, - 63.221, 63.639, 64.989, 63.761, - 66.019, 67.857, 68.169, 66.513, - 68.655, 69.564, 69.331, 70.551), - tolerance = 1e-4) - stats <- summary(model) - expect_equal(stats$numTrees, 20) - expect_equal(stats$formula, "Employed ~ .") - expect_equal(stats$numFeatures, 6) - expect_equal(length(stats$treeWeights), 20) - - modelPath <- tempfile(pattern = "spark-gbtRegression", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - stats2 <- summary(model2) - expect_equal(stats$formula, stats2$formula) - expect_equal(stats$numFeatures, stats2$numFeatures) - expect_equal(stats$features, stats2$features) - expect_equal(stats$featureImportances, stats2$featureImportances) - expect_equal(stats$numTrees, stats2$numTrees) - expect_equal(stats$treeWeights, stats2$treeWeights) - - unlink(modelPath) - - # classification - # label must be binary - GBTClassifier currently only supports binary classification. 
- iris2 <- iris[iris$Species != "virginica", ] - data <- suppressWarnings(createDataFrame(iris2)) - model <- spark.gbt(data, Species ~ Petal_Length + Petal_Width, "classification") - stats <- summary(model) - expect_equal(stats$numFeatures, 2) - expect_equal(stats$numTrees, 20) - expect_error(capture.output(stats), NA) - expect_true(length(capture.output(stats)) > 6) - predictions <- collect(predict(model, data))$prediction - # test string prediction values - expect_equal(length(grep("setosa", predictions)), 50) - expect_equal(length(grep("versicolor", predictions)), 50) - - modelPath <- tempfile(pattern = "spark-gbtClassification", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - stats2 <- summary(model2) - expect_equal(stats$depth, stats2$depth) - expect_equal(stats$numNodes, stats2$numNodes) - expect_equal(stats$numClasses, stats2$numClasses) - - unlink(modelPath) - - iris2$NumericSpecies <- ifelse(iris2$Species == "setosa", 0, 1) - df <- suppressWarnings(createDataFrame(iris2)) - m <- spark.gbt(df, NumericSpecies ~ ., type = "classification") - s <- summary(m) - # test numeric prediction values - expect_equal(iris2$NumericSpecies, as.double(collect(predict(m, df))$prediction)) - expect_equal(s$numFeatures, 5) - expect_equal(s$numTrees, 20) - - # spark.gbt classification can work on libsvm data - data <- read.df(absoluteSparkPath("data/mllib/sample_binary_classification_data.txt"), - source = "libsvm") - model <- spark.gbt(data, label ~ features, "classification") - expect_equal(summary(model)$numFeatures, 692) -}) - -sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_mllib_classification.R b/R/pkg/inst/tests/testthat/test_mllib_classification.R new file mode 100644 index 0000000000000..2e0dea321e7bf --- /dev/null +++ b/R/pkg/inst/tests/testthat/test_mllib_classification.R @@ -0,0 +1,341 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(testthat) + +context("MLlib classification algorithms, except for tree-based algorithms") + +# Tests for MLlib classification algorithms in SparkR +sparkSession <- sparkR.session(enableHiveSupport = FALSE) + +absoluteSparkPath <- function(x) { + sparkHome <- sparkR.conf("spark.home") + file.path(sparkHome, x) +} + +test_that("spark.logit", { + # R code to reproduce the result. 
+  # nolint start
+  #' library(glmnet)
+  #' iris.x = as.matrix(iris[, 1:4])
+  #' iris.y = as.factor(as.character(iris[, 5]))
+  #' logit = glmnet(iris.x, iris.y, family="multinomial", alpha=0, lambda=0.5)
+  #' coef(logit)
+  #
+  # $setosa
+  # 5 x 1 sparse Matrix of class "dgCMatrix"
+  # s0
+  # 1.0981324
+  # Sepal.Length -0.2909860
+  # Sepal.Width 0.5510907
+  # Petal.Length -0.1915217
+  # Petal.Width -0.4211946
+  #
+  # $versicolor
+  # 5 x 1 sparse Matrix of class "dgCMatrix"
+  # s0
+  # 1.520061e+00
+  # Sepal.Length 2.524501e-02
+  # Sepal.Width -5.310313e-01
+  # Petal.Length 3.656543e-02
+  # Petal.Width -3.144464e-05
+  #
+  # $virginica
+  # 5 x 1 sparse Matrix of class "dgCMatrix"
+  # s0
+  # -2.61819385
+  # Sepal.Length 0.26574097
+  # Sepal.Width -0.02005932
+  # Petal.Length 0.15495629
+  # Petal.Width 0.42122607
+  # nolint end
+
+  # Test multinomial logistic regression against three classes
+  df <- suppressWarnings(createDataFrame(iris))
+  model <- spark.logit(df, Species ~ ., regParam = 0.5)
+  summary <- summary(model)
+  versicolorCoefsR <- c(1.52, 0.03, -0.53, 0.04, 0.00)
+  virginicaCoefsR <- c(-2.62, 0.27, -0.02, 0.16, 0.42)
+  setosaCoefsR <- c(1.10, -0.29, 0.55, -0.19, -0.42)
+  versicolorCoefs <- unlist(summary$coefficients[, "versicolor"])
+  virginicaCoefs <- unlist(summary$coefficients[, "virginica"])
+  setosaCoefs <- unlist(summary$coefficients[, "setosa"])
+  expect_true(all(abs(versicolorCoefsR - versicolorCoefs) < 0.1))
+  expect_true(all(abs(virginicaCoefsR - virginicaCoefs) < 0.1))
+  expect_true(all(abs(setosaCoefsR - setosaCoefs) < 0.1))
+
+  # Test model save and load
+  modelPath <- tempfile(pattern = "spark-logit", fileext = ".tmp")
+  write.ml(model, modelPath)
+  expect_error(write.ml(model, modelPath))
+  write.ml(model, modelPath, overwrite = TRUE)
+  model2 <- read.ml(modelPath)
+  coefs <- summary(model)$coefficients
+  coefs2 <- summary(model2)$coefficients
+  expect_equal(coefs, coefs2)
+  unlink(modelPath)
+
+  # R code to reproduce the result.
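+  # (Editor's note, illustrative) The glmnet block below fits the two-class subset twice, with
+  # family = "multinomial" and with family = "binomial"; the spark.logit calls are checked against
+  # each reference in turn. With the default family = "auto", spark.logit falls back to binomial
+  # regression when the label column has only two levels. A minimal SparkR sketch of the call
+  # being verified, assuming an active SparkR session:
+  # nolint start
+  #' training <- suppressWarnings(createDataFrame(iris))
+  #' training <- training[training$Species %in% c("versicolor", "virginica"), ]
+  #' summary(spark.logit(training, Species ~ ., regParam = 0.5))$coefficients
+  # nolint end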
+  # nolint start
+  #' library(glmnet)
+  #' iris2 <- iris[iris$Species %in% c("versicolor", "virginica"), ]
+  #' iris.x = as.matrix(iris2[, 1:4])
+  #' iris.y = as.factor(as.character(iris2[, 5]))
+  #' logit = glmnet(iris.x, iris.y, family="multinomial", alpha=0, lambda=0.5)
+  #' coef(logit)
+  #
+  # $versicolor
+  # 5 x 1 sparse Matrix of class "dgCMatrix"
+  # s0
+  # 3.93844796
+  # Sepal.Length -0.13538675
+  # Sepal.Width -0.02386443
+  # Petal.Length -0.35076451
+  # Petal.Width -0.77971954
+  #
+  # $virginica
+  # 5 x 1 sparse Matrix of class "dgCMatrix"
+  # s0
+  # -3.93844796
+  # Sepal.Length 0.13538675
+  # Sepal.Width 0.02386443
+  # Petal.Length 0.35076451
+  # Petal.Width 0.77971954
+  #
+  #' logit = glmnet(iris.x, iris.y, family="binomial", alpha=0, lambda=0.5)
+  #' coef(logit)
+  #
+  # 5 x 1 sparse Matrix of class "dgCMatrix"
+  # s0
+  # (Intercept) -6.0824412
+  # Sepal.Length 0.2458260
+  # Sepal.Width 0.1642093
+  # Petal.Length 0.4759487
+  # Petal.Width 1.0383948
+  #
+  # nolint end
+
+  # Test multinomial logistic regression against two classes
+  df <- suppressWarnings(createDataFrame(iris))
+  training <- df[df$Species %in% c("versicolor", "virginica"), ]
+  model <- spark.logit(training, Species ~ ., regParam = 0.5, family = "multinomial")
+  summary <- summary(model)
+  versicolorCoefsR <- c(3.94, -0.16, -0.02, -0.35, -0.78)
+  virginicaCoefsR <- c(-3.94, 0.16, -0.02, 0.35, 0.78)
+  versicolorCoefs <- unlist(summary$coefficients[, "versicolor"])
+  virginicaCoefs <- unlist(summary$coefficients[, "virginica"])
+  expect_true(all(abs(versicolorCoefsR - versicolorCoefs) < 0.1))
+  expect_true(all(abs(virginicaCoefsR - virginicaCoefs) < 0.1))
+
+  # Test binomial logistic regression against two classes
+  model <- spark.logit(training, Species ~ ., regParam = 0.5)
+  summary <- summary(model)
+  coefsR <- c(-6.08, 0.25, 0.16, 0.48, 1.04)
+  coefs <- unlist(summary$coefficients[, "Estimate"])
+  expect_true(all(abs(coefsR - coefs) < 0.1))
+
+  # Test prediction with string label
+  prediction <- predict(model, training)
+  expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "character")
+  expected <- c("versicolor", "versicolor", "virginica", "versicolor", "versicolor",
+                "versicolor", "versicolor", "versicolor", "versicolor", "versicolor")
+  expect_equal(as.list(take(select(prediction, "prediction"), 10))[[1]], expected)
+
+  # Test prediction with numeric label
+  label <- c(0.0, 0.0, 0.0, 1.0, 1.0)
+  feature <- c(1.1419053, 0.9194079, -0.9498666, -1.1069903, 0.2809776)
+  data <- as.data.frame(cbind(label, feature))
+  df <- createDataFrame(data)
+  model <- spark.logit(df, label ~ feature)
+  prediction <- collect(select(predict(model, df), "prediction"))
+  expect_equal(prediction$prediction, c("0.0", "0.0", "1.0", "1.0", "0.0"))
+})
+
+test_that("spark.mlp", {
+  df <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"),
+                source = "libsvm")
+  model <- spark.mlp(df, label ~ features, blockSize = 128, layers = c(4, 5, 4, 3),
+                     solver = "l-bfgs", maxIter = 100, tol = 0.5, stepSize = 1, seed = 1)
+
+  # Test summary method
+  summary <- summary(model)
+  expect_equal(summary$numOfInputs, 4)
+  expect_equal(summary$numOfOutputs, 3)
+  expect_equal(summary$layers, c(4, 5, 4, 3))
+  expect_equal(length(summary$weights), 64)
+  expect_equal(head(summary$weights, 5), list(-0.878743, 0.2154151, -1.16304, -0.6583214, 1.009825),
+               tolerance = 1e-6)
+
+  # Test predict method
+  mlpTestDF <- df
+  mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction"))
+
expect_equal(head(mlpPredictions$prediction, 6), c("1.0", "0.0", "0.0", "0.0", "0.0", "0.0")) + + # Test model save/load + modelPath <- tempfile(pattern = "spark-mlp", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + summary2 <- summary(model2) + + expect_equal(summary2$numOfInputs, 4) + expect_equal(summary2$numOfOutputs, 3) + expect_equal(summary2$layers, c(4, 5, 4, 3)) + expect_equal(length(summary2$weights), 64) + + unlink(modelPath) + + # Test default parameter + model <- spark.mlp(df, label ~ features, layers = c(4, 5, 4, 3)) + mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) + expect_equal(head(mlpPredictions$prediction, 10), + c("1.0", "1.0", "1.0", "1.0", "0.0", "1.0", "2.0", "2.0", "1.0", "0.0")) + + # Test illegal parameter + expect_error(spark.mlp(df, label ~ features, layers = NULL), + "layers must be a integer vector with length > 1.") + expect_error(spark.mlp(df, label ~ features, layers = c()), + "layers must be a integer vector with length > 1.") + expect_error(spark.mlp(df, label ~ features, layers = c(3)), + "layers must be a integer vector with length > 1.") + + # Test random seed + # default seed + model <- spark.mlp(df, label ~ features, layers = c(4, 5, 4, 3), maxIter = 10) + mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) + expect_equal(head(mlpPredictions$prediction, 10), + c("1.0", "1.0", "1.0", "1.0", "0.0", "1.0", "2.0", "2.0", "1.0", "0.0")) + # seed equals 10 + model <- spark.mlp(df, label ~ features, layers = c(4, 5, 4, 3), maxIter = 10, seed = 10) + mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) + expect_equal(head(mlpPredictions$prediction, 10), + c("1.0", "1.0", "1.0", "1.0", "0.0", "1.0", "2.0", "2.0", "1.0", "0.0")) + + # test initialWeights + model <- spark.mlp(df, label ~ features, layers = c(4, 3), maxIter = 2, initialWeights = + c(0, 0, 0, 0, 0, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9)) + mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) + expect_equal(head(mlpPredictions$prediction, 10), + c("1.0", "1.0", "1.0", "1.0", "2.0", "1.0", "2.0", "2.0", "1.0", "0.0")) + + model <- spark.mlp(df, label ~ features, layers = c(4, 3), maxIter = 2, initialWeights = + c(0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 5.0, 5.0, 5.0, 5.0, 9.0, 9.0, 9.0, 9.0, 9.0)) + mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) + expect_equal(head(mlpPredictions$prediction, 10), + c("1.0", "1.0", "1.0", "1.0", "2.0", "1.0", "2.0", "2.0", "1.0", "0.0")) + + model <- spark.mlp(df, label ~ features, layers = c(4, 3), maxIter = 2) + mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) + expect_equal(head(mlpPredictions$prediction, 10), + c("1.0", "1.0", "1.0", "1.0", "0.0", "1.0", "0.0", "2.0", "1.0", "0.0")) + + # Test formula works well + df <- suppressWarnings(createDataFrame(iris)) + model <- spark.mlp(df, Species ~ Sepal_Length + Sepal_Width + Petal_Length + Petal_Width, + layers = c(4, 3)) + summary <- summary(model) + expect_equal(summary$numOfInputs, 4) + expect_equal(summary$numOfOutputs, 3) + expect_equal(summary$layers, c(4, 3)) + expect_equal(length(summary$weights), 15) + expect_equal(head(summary$weights, 5), list(-1.1957257, -5.2693685, 7.4489734, -6.3751413, + -10.2376130), tolerance = 1e-6) +}) + +test_that("spark.naiveBayes", { + # R code to reproduce the result. + # We do not support instance weights yet. 
So we ignore the frequencies. + # + #' library(e1071) + #' t <- as.data.frame(Titanic) + #' t1 <- t[t$Freq > 0, -5] + #' m <- naiveBayes(Survived ~ ., data = t1) + #' m + #' predict(m, t1) + # + # -- output of 'm' + # + # A-priori probabilities: + # Y + # No Yes + # 0.4166667 0.5833333 + # + # Conditional probabilities: + # Class + # Y 1st 2nd 3rd Crew + # No 0.2000000 0.2000000 0.4000000 0.2000000 + # Yes 0.2857143 0.2857143 0.2857143 0.1428571 + # + # Sex + # Y Male Female + # No 0.5 0.5 + # Yes 0.5 0.5 + # + # Age + # Y Child Adult + # No 0.2000000 0.8000000 + # Yes 0.4285714 0.5714286 + # + # -- output of 'predict(m, t1)' + # + # Yes Yes Yes Yes No No Yes Yes No No Yes Yes Yes Yes Yes Yes Yes Yes No No Yes Yes No No + # + + t <- as.data.frame(Titanic) + t1 <- t[t$Freq > 0, -5] + df <- suppressWarnings(createDataFrame(t1)) + m <- spark.naiveBayes(df, Survived ~ ., smoothing = 0.0) + s <- summary(m) + expect_equal(as.double(s$apriori[1, "Yes"]), 0.5833333, tolerance = 1e-6) + expect_equal(sum(s$apriori), 1) + expect_equal(as.double(s$tables["Yes", "Age_Adult"]), 0.5714286, tolerance = 1e-6) + p <- collect(select(predict(m, df), "prediction")) + expect_equal(p$prediction, c("Yes", "Yes", "Yes", "Yes", "No", "No", "Yes", "Yes", "No", "No", + "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "No", "No", + "Yes", "Yes", "No", "No")) + + # Test model save/load + modelPath <- tempfile(pattern = "spark-naiveBayes", fileext = ".tmp") + write.ml(m, modelPath) + expect_error(write.ml(m, modelPath)) + write.ml(m, modelPath, overwrite = TRUE) + m2 <- read.ml(modelPath) + s2 <- summary(m2) + expect_equal(s$apriori, s2$apriori) + expect_equal(s$tables, s2$tables) + + unlink(modelPath) + + # Test e1071::naiveBayes + if (requireNamespace("e1071", quietly = TRUE)) { + expect_error(m <- e1071::naiveBayes(Survived ~ ., data = t1), NA) + expect_equal(as.character(predict(m, t1[1, ])), "Yes") + } + + # Test numeric response variable + t1$NumericSurvived <- ifelse(t1$Survived == "No", 0, 1) + t2 <- t1[-4] + df <- suppressWarnings(createDataFrame(t2)) + m <- spark.naiveBayes(df, NumericSurvived ~ ., smoothing = 0.0) + s <- summary(m) + expect_equal(as.double(s$apriori[1, 1]), 0.5833333, tolerance = 1e-6) + expect_equal(sum(s$apriori), 1) + expect_equal(as.double(s$tables[1, "Age_Adult"]), 0.5714286, tolerance = 1e-6) +}) + +sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_mllib_clustering.R b/R/pkg/inst/tests/testthat/test_mllib_clustering.R new file mode 100644 index 0000000000000..1980fffd80cc6 --- /dev/null +++ b/R/pkg/inst/tests/testthat/test_mllib_clustering.R @@ -0,0 +1,224 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +library(testthat) + +context("MLlib clustering algorithms") + +# Tests for MLlib clustering algorithms in SparkR +sparkSession <- sparkR.session(enableHiveSupport = FALSE) + +absoluteSparkPath <- function(x) { + sparkHome <- sparkR.conf("spark.home") + file.path(sparkHome, x) +} + +test_that("spark.gaussianMixture", { + # R code to reproduce the result. + # nolint start + #' library(mvtnorm) + #' set.seed(1) + #' a <- rmvnorm(7, c(0, 0)) + #' b <- rmvnorm(8, c(10, 10)) + #' data <- rbind(a, b) + #' model <- mvnormalmixEM(data, k = 2) + #' model$lambda + # + # [1] 0.4666667 0.5333333 + # + #' model$mu + # + # [1] 0.11731091 -0.06192351 + # [1] 10.363673 9.897081 + # + #' model$sigma + # + # [[1]] + # [,1] [,2] + # [1,] 0.62049934 0.06880802 + # [2,] 0.06880802 1.27431874 + # + # [[2]] + # [,1] [,2] + # [1,] 0.2961543 0.160783 + # [2,] 0.1607830 1.008878 + # nolint end + data <- list(list(-0.6264538, 0.1836433), list(-0.8356286, 1.5952808), + list(0.3295078, -0.8204684), list(0.4874291, 0.7383247), + list(0.5757814, -0.3053884), list(1.5117812, 0.3898432), + list(-0.6212406, -2.2146999), list(11.1249309, 9.9550664), + list(9.9838097, 10.9438362), list(10.8212212, 10.5939013), + list(10.9189774, 10.7821363), list(10.0745650, 8.0106483), + list(10.6198257, 9.9438713), list(9.8442045, 8.5292476), + list(9.5218499, 10.4179416)) + df <- createDataFrame(data, c("x1", "x2")) + model <- spark.gaussianMixture(df, ~ x1 + x2, k = 2) + stats <- summary(model) + rLambda <- c(0.4666667, 0.5333333) + rMu <- c(0.11731091, -0.06192351, 10.363673, 9.897081) + rSigma <- c(0.62049934, 0.06880802, 0.06880802, 1.27431874, + 0.2961543, 0.160783, 0.1607830, 1.008878) + expect_equal(stats$lambda, rLambda, tolerance = 1e-3) + expect_equal(unlist(stats$mu), rMu, tolerance = 1e-3) + expect_equal(unlist(stats$sigma), rSigma, tolerance = 1e-3) + p <- collect(select(predict(model, df), "prediction")) + expect_equal(p$prediction, c(0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1)) + + # Test model save/load + modelPath <- tempfile(pattern = "spark-gaussianMixture", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$lambda, stats2$lambda) + expect_equal(unlist(stats$mu), unlist(stats2$mu)) + expect_equal(unlist(stats$sigma), unlist(stats2$sigma)) + + unlink(modelPath) +}) + +test_that("spark.kmeans", { + newIris <- iris + newIris$Species <- NULL + training <- suppressWarnings(createDataFrame(newIris)) + + take(training, 1) + + model <- spark.kmeans(data = training, ~ ., k = 2, maxIter = 10, initMode = "random") + sample <- take(select(predict(model, training), "prediction"), 1) + expect_equal(typeof(sample$prediction), "integer") + expect_equal(sample$prediction, 1) + + # Test stats::kmeans is working + statsModel <- kmeans(x = newIris, centers = 2) + expect_equal(sort(unique(statsModel$cluster)), c(1, 2)) + + # Test fitted works on KMeans + fitted.model <- fitted(model) + expect_equal(sort(collect(distinct(select(fitted.model, "prediction")))$prediction), c(0, 1)) + + # Test summary works on KMeans + summary.model <- summary(model) + cluster <- summary.model$cluster + k <- summary.model$k + expect_equal(k, 2) + expect_equal(sort(collect(distinct(select(cluster, "prediction")))$prediction), c(0, 1)) + + # Test model save/load + modelPath <- tempfile(pattern = "spark-kmeans", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, 
modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + summary2 <- summary(model2) + expect_equal(sort(unlist(summary.model$size)), sort(unlist(summary2$size))) + expect_equal(summary.model$coefficients, summary2$coefficients) + expect_true(!summary.model$is.loaded) + expect_true(summary2$is.loaded) + + unlink(modelPath) +}) + +test_that("spark.lda with libsvm", { + text <- read.df(absoluteSparkPath("data/mllib/sample_lda_libsvm_data.txt"), source = "libsvm") + model <- spark.lda(text, optimizer = "em") + + stats <- summary(model, 10) + isDistributed <- stats$isDistributed + logLikelihood <- stats$logLikelihood + logPerplexity <- stats$logPerplexity + vocabSize <- stats$vocabSize + topics <- stats$topicTopTerms + weights <- stats$topicTopTermsWeights + vocabulary <- stats$vocabulary + + expect_false(isDistributed) + expect_true(logLikelihood <= 0 & is.finite(logLikelihood)) + expect_true(logPerplexity >= 0 & is.finite(logPerplexity)) + expect_equal(vocabSize, 11) + expect_true(is.null(vocabulary)) + + # Test model save/load + modelPath <- tempfile(pattern = "spark-lda", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + + expect_false(stats2$isDistributed) + expect_equal(logLikelihood, stats2$logLikelihood) + expect_equal(logPerplexity, stats2$logPerplexity) + expect_equal(vocabSize, stats2$vocabSize) + expect_equal(vocabulary, stats2$vocabulary) + + unlink(modelPath) +}) + +test_that("spark.lda with text input", { + text <- read.text(absoluteSparkPath("data/mllib/sample_lda_data.txt")) + model <- spark.lda(text, optimizer = "online", features = "value") + + stats <- summary(model) + isDistributed <- stats$isDistributed + logLikelihood <- stats$logLikelihood + logPerplexity <- stats$logPerplexity + vocabSize <- stats$vocabSize + topics <- stats$topicTopTerms + weights <- stats$topicTopTermsWeights + vocabulary <- stats$vocabulary + + expect_false(isDistributed) + expect_true(logLikelihood <= 0 & is.finite(logLikelihood)) + expect_true(logPerplexity >= 0 & is.finite(logPerplexity)) + expect_equal(vocabSize, 10) + expect_true(setequal(stats$vocabulary, c("0", "1", "2", "3", "4", "5", "6", "7", "8", "9"))) + + # Test model save/load + modelPath <- tempfile(pattern = "spark-lda-text", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + + expect_false(stats2$isDistributed) + expect_equal(logLikelihood, stats2$logLikelihood) + expect_equal(logPerplexity, stats2$logPerplexity) + expect_equal(vocabSize, stats2$vocabSize) + expect_true(all.equal(vocabulary, stats2$vocabulary)) + + unlink(modelPath) +}) + +test_that("spark.posterior and spark.perplexity", { + text <- read.text(absoluteSparkPath("data/mllib/sample_lda_data.txt")) + model <- spark.lda(text, features = "value", k = 3) + + # Assert perplexities are equal + stats <- summary(model) + logPerplexity <- spark.perplexity(model, text) + expect_equal(logPerplexity, stats$logPerplexity) + + # Assert the sum of every topic distribution is equal to 1 + posterior <- spark.posterior(model, text) + local.posterior <- collect(posterior)$topicDistribution + expect_equal(length(local.posterior), sum(unlist(local.posterior))) +}) + +sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_mllib_recommendation.R 
b/R/pkg/inst/tests/testthat/test_mllib_recommendation.R new file mode 100644 index 0000000000000..6b1040db93050 --- /dev/null +++ b/R/pkg/inst/tests/testthat/test_mllib_recommendation.R @@ -0,0 +1,65 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(testthat) + +context("MLlib recommendation algorithms") + +# Tests for MLlib recommendation algorithms in SparkR +sparkSession <- sparkR.session(enableHiveSupport = FALSE) + +test_that("spark.als", { + data <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(1, 2, 4.0), + list(2, 1, 1.0), list(2, 2, 5.0)) + df <- createDataFrame(data, c("user", "item", "score")) + model <- spark.als(df, ratingCol = "score", userCol = "user", itemCol = "item", + rank = 10, maxIter = 5, seed = 0, regParam = 0.1) + stats <- summary(model) + expect_equal(stats$rank, 10) + test <- createDataFrame(list(list(0, 2), list(1, 0), list(2, 0)), c("user", "item")) + predictions <- collect(predict(model, test)) + + expect_equal(predictions$prediction, c(-0.1380762, 2.6258414, -1.5018409), + tolerance = 1e-4) + + # Test model save/load + modelPath <- tempfile(pattern = "spark-als", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats2$rating, "score") + userFactors <- collect(stats$userFactors) + itemFactors <- collect(stats$itemFactors) + userFactors2 <- collect(stats2$userFactors) + itemFactors2 <- collect(stats2$itemFactors) + + orderUser <- order(userFactors$id) + orderUser2 <- order(userFactors2$id) + expect_equal(userFactors$id[orderUser], userFactors2$id[orderUser2]) + expect_equal(userFactors$features[orderUser], userFactors2$features[orderUser2]) + + orderItem <- order(itemFactors$id) + orderItem2 <- order(itemFactors2$id) + expect_equal(itemFactors$id[orderItem], itemFactors2$id[orderItem2]) + expect_equal(itemFactors$features[orderItem], itemFactors2$features[orderItem2]) + + unlink(modelPath) +}) + +sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_mllib_regression.R b/R/pkg/inst/tests/testthat/test_mllib_regression.R new file mode 100644 index 0000000000000..e20dafa4147af --- /dev/null +++ b/R/pkg/inst/tests/testthat/test_mllib_regression.R @@ -0,0 +1,417 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(testthat) + +context("MLlib regression algorithms, except for tree-based algorithms") + +# Tests for MLlib regression algorithms in SparkR +sparkSession <- sparkR.session(enableHiveSupport = FALSE) + +test_that("formula of spark.glm", { + training <- suppressWarnings(createDataFrame(iris)) + # directly calling the spark API + # dot minus and intercept vs native glm + model <- spark.glm(training, Sepal_Width ~ . - Species + 0) + vals <- collect(select(predict(model, training), "prediction")) + rVals <- predict(glm(Sepal.Width ~ . - Species + 0, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) + + # feature interaction vs native glm + model <- spark.glm(training, Sepal_Width ~ Species:Sepal_Length) + vals <- collect(select(predict(model, training), "prediction")) + rVals <- predict(glm(Sepal.Width ~ Species:Sepal.Length, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) + + # glm should work with long formula + training <- suppressWarnings(createDataFrame(iris)) + training$LongLongLongLongLongName <- training$Sepal_Width + training$VeryLongLongLongLonLongName <- training$Sepal_Length + training$AnotherLongLongLongLongName <- training$Species + model <- spark.glm(training, LongLongLongLongLongName ~ VeryLongLongLongLonLongName + + AnotherLongLongLongLongName) + vals <- collect(select(predict(model, training), "prediction")) + rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) +}) + +test_that("spark.glm and predict", { + training <- suppressWarnings(createDataFrame(iris)) + # gaussian family + model <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species) + prediction <- predict(model, training) + expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") + vals <- collect(select(prediction, "prediction")) + rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) + + # poisson family + model <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species, + family = poisson(link = identity)) + prediction <- predict(model, training) + expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") + vals <- collect(select(prediction, "prediction")) + rVals <- suppressWarnings(predict(glm(Sepal.Width ~ Sepal.Length + Species, + data = iris, family = poisson(link = identity)), iris)) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) + + # Test stats::predict is working + x <- rnorm(15) + y <- x + rnorm(15) + expect_equal(length(predict(lm(y ~ x))), 15) +}) + +test_that("spark.glm summary", { + # gaussian family + training <- suppressWarnings(createDataFrame(iris)) + stats <- summary(spark.glm(training, Sepal_Width ~ Sepal_Length + Species)) + + rStats <- summary(glm(Sepal.Width ~ Sepal.Length + Species, data = iris)) + + coefs <- unlist(stats$coefficients) + rCoefs <- unlist(rStats$coefficients) + expect_true(all(abs(rCoefs - coefs) < 1e-4)) + expect_true(all( + rownames(stats$coefficients) == + 
c("(Intercept)", "Sepal_Length", "Species_versicolor", "Species_virginica"))) + expect_equal(stats$dispersion, rStats$dispersion) + expect_equal(stats$null.deviance, rStats$null.deviance) + expect_equal(stats$deviance, rStats$deviance) + expect_equal(stats$df.null, rStats$df.null) + expect_equal(stats$df.residual, rStats$df.residual) + expect_equal(stats$aic, rStats$aic) + + out <- capture.output(print(stats)) + expect_match(out[2], "Deviance Residuals:") + expect_true(any(grepl("AIC: 59.22", out))) + + # binomial family + df <- suppressWarnings(createDataFrame(iris)) + training <- df[df$Species %in% c("versicolor", "virginica"), ] + stats <- summary(spark.glm(training, Species ~ Sepal_Length + Sepal_Width, + family = binomial(link = "logit"))) + + rTraining <- iris[iris$Species %in% c("versicolor", "virginica"), ] + rStats <- summary(glm(Species ~ Sepal.Length + Sepal.Width, data = rTraining, + family = binomial(link = "logit"))) + + coefs <- unlist(stats$coefficients) + rCoefs <- unlist(rStats$coefficients) + expect_true(all(abs(rCoefs - coefs) < 1e-4)) + expect_true(all( + rownames(stats$coefficients) == + c("(Intercept)", "Sepal_Length", "Sepal_Width"))) + expect_equal(stats$dispersion, rStats$dispersion) + expect_equal(stats$null.deviance, rStats$null.deviance) + expect_equal(stats$deviance, rStats$deviance) + expect_equal(stats$df.null, rStats$df.null) + expect_equal(stats$df.residual, rStats$df.residual) + expect_equal(stats$aic, rStats$aic) + + # Test spark.glm works with weighted dataset + a1 <- c(0, 1, 2, 3) + a2 <- c(5, 2, 1, 3) + w <- c(1, 2, 3, 4) + b <- c(1, 0, 1, 0) + data <- as.data.frame(cbind(a1, a2, w, b)) + df <- createDataFrame(data) + + stats <- summary(spark.glm(df, b ~ a1 + a2, family = "binomial", weightCol = "w")) + rStats <- summary(glm(b ~ a1 + a2, family = "binomial", data = data, weights = w)) + + coefs <- unlist(stats$coefficients) + rCoefs <- unlist(rStats$coefficients) + expect_true(all(abs(rCoefs - coefs) < 1e-3)) + expect_true(all(rownames(stats$coefficients) == c("(Intercept)", "a1", "a2"))) + expect_equal(stats$dispersion, rStats$dispersion) + expect_equal(stats$null.deviance, rStats$null.deviance) + expect_equal(stats$deviance, rStats$deviance) + expect_equal(stats$df.null, rStats$df.null) + expect_equal(stats$df.residual, rStats$df.residual) + expect_equal(stats$aic, rStats$aic) + + # Test summary works on base GLM models + baseModel <- stats::glm(Sepal.Width ~ Sepal.Length + Species, data = iris) + baseSummary <- summary(baseModel) + expect_true(abs(baseSummary$deviance - 12.19313) < 1e-4) + + # Test spark.glm works with regularization parameter + data <- as.data.frame(cbind(a1, a2, b)) + df <- suppressWarnings(createDataFrame(data)) + regStats <- summary(spark.glm(df, b ~ a1 + a2, regParam = 1.0)) + expect_equal(regStats$aic, 13.32836, tolerance = 1e-4) # 13.32836 is from summary() result + + # Test spark.glm works on collinear data + A <- matrix(c(1, 2, 3, 4, 2, 4, 6, 8), 4, 2) + b <- c(1, 2, 3, 4) + data <- as.data.frame(cbind(A, b)) + df <- createDataFrame(data) + stats <- summary(spark.glm(df, b ~ . 
- 1)) + coefs <- unlist(stats$coefficients) + expect_true(all(abs(c(0.5, 0.25) - coefs) < 1e-4)) +}) + +test_that("spark.glm save/load", { + training <- suppressWarnings(createDataFrame(iris)) + m <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species) + s <- summary(m) + + modelPath <- tempfile(pattern = "spark-glm", fileext = ".tmp") + write.ml(m, modelPath) + expect_error(write.ml(m, modelPath)) + write.ml(m, modelPath, overwrite = TRUE) + m2 <- read.ml(modelPath) + s2 <- summary(m2) + + expect_equal(s$coefficients, s2$coefficients) + expect_equal(rownames(s$coefficients), rownames(s2$coefficients)) + expect_equal(s$dispersion, s2$dispersion) + expect_equal(s$null.deviance, s2$null.deviance) + expect_equal(s$deviance, s2$deviance) + expect_equal(s$df.null, s2$df.null) + expect_equal(s$df.residual, s2$df.residual) + expect_equal(s$aic, s2$aic) + expect_equal(s$iter, s2$iter) + expect_true(!s$is.loaded) + expect_true(s2$is.loaded) + + unlink(modelPath) +}) + +test_that("formula of glm", { + training <- suppressWarnings(createDataFrame(iris)) + # dot minus and intercept vs native glm + model <- glm(Sepal_Width ~ . - Species + 0, data = training) + vals <- collect(select(predict(model, training), "prediction")) + rVals <- predict(glm(Sepal.Width ~ . - Species + 0, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) + + # feature interaction vs native glm + model <- glm(Sepal_Width ~ Species:Sepal_Length, data = training) + vals <- collect(select(predict(model, training), "prediction")) + rVals <- predict(glm(Sepal.Width ~ Species:Sepal.Length, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) + + # glm should work with long formula + training <- suppressWarnings(createDataFrame(iris)) + training$LongLongLongLongLongName <- training$Sepal_Width + training$VeryLongLongLongLonLongName <- training$Sepal_Length + training$AnotherLongLongLongLongName <- training$Species + model <- glm(LongLongLongLongLongName ~ VeryLongLongLongLonLongName + AnotherLongLongLongLongName, + data = training) + vals <- collect(select(predict(model, training), "prediction")) + rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) +}) + +test_that("glm and predict", { + training <- suppressWarnings(createDataFrame(iris)) + # gaussian family + model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training) + prediction <- predict(model, training) + expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") + vals <- collect(select(prediction, "prediction")) + rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) + + # poisson family + model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training, + family = poisson(link = identity)) + prediction <- predict(model, training) + expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") + vals <- collect(select(prediction, "prediction")) + rVals <- suppressWarnings(predict(glm(Sepal.Width ~ Sepal.Length + Species, + data = iris, family = poisson(link = identity)), iris)) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) + + # Test stats::predict is working + x <- rnorm(15) + y <- x + rnorm(15) + expect_equal(length(predict(lm(y ~ x))), 15) +}) + +test_that("glm summary", { + # gaussian family + training <- suppressWarnings(createDataFrame(iris)) + stats <- summary(glm(Sepal_Width ~ 
Sepal_Length + Species, data = training)) + + rStats <- summary(glm(Sepal.Width ~ Sepal.Length + Species, data = iris)) + + coefs <- unlist(stats$coefficients) + rCoefs <- unlist(rStats$coefficients) + expect_true(all(abs(rCoefs - coefs) < 1e-4)) + expect_true(all( + rownames(stats$coefficients) == + c("(Intercept)", "Sepal_Length", "Species_versicolor", "Species_virginica"))) + expect_equal(stats$dispersion, rStats$dispersion) + expect_equal(stats$null.deviance, rStats$null.deviance) + expect_equal(stats$deviance, rStats$deviance) + expect_equal(stats$df.null, rStats$df.null) + expect_equal(stats$df.residual, rStats$df.residual) + expect_equal(stats$aic, rStats$aic) + + # binomial family + df <- suppressWarnings(createDataFrame(iris)) + training <- df[df$Species %in% c("versicolor", "virginica"), ] + stats <- summary(glm(Species ~ Sepal_Length + Sepal_Width, data = training, + family = binomial(link = "logit"))) + + rTraining <- iris[iris$Species %in% c("versicolor", "virginica"), ] + rStats <- summary(glm(Species ~ Sepal.Length + Sepal.Width, data = rTraining, + family = binomial(link = "logit"))) + + coefs <- unlist(stats$coefficients) + rCoefs <- unlist(rStats$coefficients) + expect_true(all(abs(rCoefs - coefs) < 1e-4)) + expect_true(all( + rownames(stats$coefficients) == + c("(Intercept)", "Sepal_Length", "Sepal_Width"))) + expect_equal(stats$dispersion, rStats$dispersion) + expect_equal(stats$null.deviance, rStats$null.deviance) + expect_equal(stats$deviance, rStats$deviance) + expect_equal(stats$df.null, rStats$df.null) + expect_equal(stats$df.residual, rStats$df.residual) + expect_equal(stats$aic, rStats$aic) + + # Test summary works on base GLM models + baseModel <- stats::glm(Sepal.Width ~ Sepal.Length + Species, data = iris) + baseSummary <- summary(baseModel) + expect_true(abs(baseSummary$deviance - 12.19313) < 1e-4) +}) + +test_that("glm save/load", { + training <- suppressWarnings(createDataFrame(iris)) + m <- glm(Sepal_Width ~ Sepal_Length + Species, data = training) + s <- summary(m) + + modelPath <- tempfile(pattern = "glm", fileext = ".tmp") + write.ml(m, modelPath) + expect_error(write.ml(m, modelPath)) + write.ml(m, modelPath, overwrite = TRUE) + m2 <- read.ml(modelPath) + s2 <- summary(m2) + + expect_equal(s$coefficients, s2$coefficients) + expect_equal(rownames(s$coefficients), rownames(s2$coefficients)) + expect_equal(s$dispersion, s2$dispersion) + expect_equal(s$null.deviance, s2$null.deviance) + expect_equal(s$deviance, s2$deviance) + expect_equal(s$df.null, s2$df.null) + expect_equal(s$df.residual, s2$df.residual) + expect_equal(s$aic, s2$aic) + expect_equal(s$iter, s2$iter) + expect_true(!s$is.loaded) + expect_true(s2$is.loaded) + + unlink(modelPath) +}) + +test_that("spark.isoreg", { + label <- c(7.0, 5.0, 3.0, 5.0, 1.0) + feature <- c(0.0, 1.0, 2.0, 3.0, 4.0) + weight <- c(1.0, 1.0, 1.0, 1.0, 1.0) + data <- as.data.frame(cbind(label, feature, weight)) + df <- createDataFrame(data) + + model <- spark.isoreg(df, label ~ feature, isotonic = FALSE, + weightCol = "weight") + # only allow one variable on the right hand side of the formula + expect_error(model2 <- spark.isoreg(df, ~., isotonic = FALSE)) + result <- summary(model) + expect_equal(result$predictions, list(7, 5, 4, 4, 1)) + + # Test model prediction + predict_data <- list(list(-2.0), list(-1.0), list(0.5), + list(0.75), list(1.0), list(2.0), list(9.0)) + predict_df <- createDataFrame(predict_data, c("feature")) + predict_result <- collect(select(predict(model, predict_df), "prediction")) + 
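+  # (Editor's note, illustrative) The expected values below follow the model's piecewise-linear
+  # interpolation: features inside the training range (e.g. 0.5 and 0.75) are interpolated between
+  # the fitted boundary predictions (7, 5, 4, 4, 1), while features outside the range (-2, -1, 9)
+  # receive the nearest boundary prediction.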
expect_equal(predict_result$prediction, c(7.0, 7.0, 6.0, 5.5, 5.0, 4.0, 1.0)) + + # Test model save/load + modelPath <- tempfile(pattern = "spark-isoreg", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + expect_equal(result, summary(model2)) + + unlink(modelPath) +}) + +test_that("spark.survreg", { + # R code to reproduce the result. + # + #' rData <- list(time = c(4, 3, 1, 1, 2, 2, 3), status = c(1, 1, 1, 0, 1, 1, 0), + #' x = c(0, 2, 1, 1, 1, 0, 0), sex = c(0, 0, 0, 0, 1, 1, 1)) + #' library(survival) + #' model <- survreg(Surv(time, status) ~ x + sex, rData) + #' summary(model) + #' predict(model, data) + # + # -- output of 'summary(model)' + # + # Value Std. Error z p + # (Intercept) 1.315 0.270 4.88 1.07e-06 + # x -0.190 0.173 -1.10 2.72e-01 + # sex -0.253 0.329 -0.77 4.42e-01 + # Log(scale) -1.160 0.396 -2.93 3.41e-03 + # + # -- output of 'predict(model, data)' + # + # 1 2 3 4 5 6 7 + # 3.724591 2.545368 3.079035 3.079035 2.390146 2.891269 2.891269 + # + data <- list(list(4, 1, 0, 0), list(3, 1, 2, 0), list(1, 1, 1, 0), + list(1, 0, 1, 0), list(2, 1, 1, 1), list(2, 1, 0, 1), list(3, 0, 0, 1)) + df <- createDataFrame(data, c("time", "status", "x", "sex")) + model <- spark.survreg(df, Surv(time, status) ~ x + sex) + stats <- summary(model) + coefs <- as.vector(stats$coefficients[, 1]) + rCoefs <- c(1.3149571, -0.1903409, -0.2532618, -1.1599800) + expect_equal(coefs, rCoefs, tolerance = 1e-4) + expect_true(all( + rownames(stats$coefficients) == + c("(Intercept)", "x", "sex", "Log(scale)"))) + p <- collect(select(predict(model, df), "prediction")) + expect_equal(p$prediction, c(3.724591, 2.545368, 3.079035, 3.079035, + 2.390146, 2.891269, 2.891269), tolerance = 1e-4) + + # Test model save/load + modelPath <- tempfile(pattern = "spark-survreg", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + coefs2 <- as.vector(stats2$coefficients[, 1]) + expect_equal(coefs, coefs2) + expect_equal(rownames(stats$coefficients), rownames(stats2$coefficients)) + + unlink(modelPath) + + # Test survival::survreg + if (requireNamespace("survival", quietly = TRUE)) { + rData <- list(time = c(4, 3, 1, 1, 2, 2, 3), status = c(1, 1, 1, 0, 1, 1, 0), + x = c(0, 2, 1, 1, 1, 0, 0), sex = c(0, 0, 0, 0, 1, 1, 1)) + expect_error( + model <- survival::survreg(formula = survival::Surv(time, status) ~ x + sex, data = rData), + NA) + expect_equal(predict(model, rData)[[1]], 3.724591, tolerance = 1e-4) + } +}) + +sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_mllib_stat.R b/R/pkg/inst/tests/testthat/test_mllib_stat.R new file mode 100644 index 0000000000000..beb148e7702fd --- /dev/null +++ b/R/pkg/inst/tests/testthat/test_mllib_stat.R @@ -0,0 +1,53 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(testthat) + +context("MLlib statistics algorithms") + +# Tests for MLlib statistics algorithms in SparkR +sparkSession <- sparkR.session(enableHiveSupport = FALSE) + +test_that("spark.kstest", { + data <- data.frame(test = c(0.1, 0.15, 0.2, 0.3, 0.25, -1, -0.5)) + df <- createDataFrame(data) + testResult <- spark.kstest(df, "test", "norm") + stats <- summary(testResult) + + rStats <- ks.test(data$test, "pnorm", alternative = "two.sided") + + expect_equal(stats$p.value, rStats$p.value, tolerance = 1e-4) + expect_equal(stats$statistic, unname(rStats$statistic), tolerance = 1e-4) + expect_match(capture.output(stats)[1], "Kolmogorov-Smirnov test summary:") + + testResult <- spark.kstest(df, "test", "norm", -0.5) + stats <- summary(testResult) + + rStats <- ks.test(data$test, "pnorm", -0.5, 1, alternative = "two.sided") + + expect_equal(stats$p.value, rStats$p.value, tolerance = 1e-4) + expect_equal(stats$statistic, unname(rStats$statistic), tolerance = 1e-4) + expect_match(capture.output(stats)[1], "Kolmogorov-Smirnov test summary:") + + # Test print.summary.KSTest + printStats <- capture.output(print.summary.KSTest(stats)) + expect_match(printStats[1], "Kolmogorov-Smirnov test summary:") + expect_match(printStats[5], + "Low presumption against null hypothesis: Sample follows theoretical distribution. ") +}) + +sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_mllib_tree.R b/R/pkg/inst/tests/testthat/test_mllib_tree.R new file mode 100644 index 0000000000000..5d13539be8a83 --- /dev/null +++ b/R/pkg/inst/tests/testthat/test_mllib_tree.R @@ -0,0 +1,203 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +library(testthat) + +context("MLlib tree-based algorithms") + +# Tests for MLlib tree-based algorithms in SparkR +sparkSession <- sparkR.session(enableHiveSupport = FALSE) + +absoluteSparkPath <- function(x) { + sparkHome <- sparkR.conf("spark.home") + file.path(sparkHome, x) +} + +test_that("spark.gbt", { + # regression + data <- suppressWarnings(createDataFrame(longley)) + model <- spark.gbt(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16, seed = 123) + predictions <- collect(predict(model, data)) + expect_equal(predictions$prediction, c(60.323, 61.122, 60.171, 61.187, + 63.221, 63.639, 64.989, 63.761, + 66.019, 67.857, 68.169, 66.513, + 68.655, 69.564, 69.331, 70.551), + tolerance = 1e-4) + stats <- summary(model) + expect_equal(stats$numTrees, 20) + expect_equal(stats$formula, "Employed ~ .") + expect_equal(stats$numFeatures, 6) + expect_equal(length(stats$treeWeights), 20) + + modelPath <- tempfile(pattern = "spark-gbtRegression", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$formula, stats2$formula) + expect_equal(stats$numFeatures, stats2$numFeatures) + expect_equal(stats$features, stats2$features) + expect_equal(stats$featureImportances, stats2$featureImportances) + expect_equal(stats$numTrees, stats2$numTrees) + expect_equal(stats$treeWeights, stats2$treeWeights) + + unlink(modelPath) + + # classification + # label must be binary - GBTClassifier currently only supports binary classification. + iris2 <- iris[iris$Species != "virginica", ] + data <- suppressWarnings(createDataFrame(iris2)) + model <- spark.gbt(data, Species ~ Petal_Length + Petal_Width, "classification") + stats <- summary(model) + expect_equal(stats$numFeatures, 2) + expect_equal(stats$numTrees, 20) + expect_error(capture.output(stats), NA) + expect_true(length(capture.output(stats)) > 6) + predictions <- collect(predict(model, data))$prediction + # test string prediction values + expect_equal(length(grep("setosa", predictions)), 50) + expect_equal(length(grep("versicolor", predictions)), 50) + + modelPath <- tempfile(pattern = "spark-gbtClassification", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$depth, stats2$depth) + expect_equal(stats$numNodes, stats2$numNodes) + expect_equal(stats$numClasses, stats2$numClasses) + + unlink(modelPath) + + iris2$NumericSpecies <- ifelse(iris2$Species == "setosa", 0, 1) + df <- suppressWarnings(createDataFrame(iris2)) + m <- spark.gbt(df, NumericSpecies ~ ., type = "classification") + s <- summary(m) + # test numeric prediction values + expect_equal(iris2$NumericSpecies, as.double(collect(predict(m, df))$prediction)) + expect_equal(s$numFeatures, 5) + expect_equal(s$numTrees, 20) + + # spark.gbt classification can work on libsvm data + data <- read.df(absoluteSparkPath("data/mllib/sample_binary_classification_data.txt"), + source = "libsvm") + model <- spark.gbt(data, label ~ features, "classification") + expect_equal(summary(model)$numFeatures, 692) +}) + +test_that("spark.randomForest", { + # regression + data <- suppressWarnings(createDataFrame(longley)) + model <- spark.randomForest(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16, + numTrees = 1) + + predictions <- collect(predict(model, data)) + 
expect_equal(predictions$prediction, c(60.323, 61.122, 60.171, 61.187, + 63.221, 63.639, 64.989, 63.761, + 66.019, 67.857, 68.169, 66.513, + 68.655, 69.564, 69.331, 70.551), + tolerance = 1e-4) + + stats <- summary(model) + expect_equal(stats$numTrees, 1) + expect_error(capture.output(stats), NA) + expect_true(length(capture.output(stats)) > 6) + + model <- spark.randomForest(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16, + numTrees = 20, seed = 123) + predictions <- collect(predict(model, data)) + expect_equal(predictions$prediction, c(60.32820, 61.22315, 60.69025, 62.11070, + 63.53160, 64.05470, 65.12710, 64.30450, + 66.70910, 67.86125, 68.08700, 67.21865, + 68.89275, 69.53180, 69.39640, 69.68250), + + tolerance = 1e-4) + stats <- summary(model) + expect_equal(stats$numTrees, 20) + + modelPath <- tempfile(pattern = "spark-randomForestRegression", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$formula, stats2$formula) + expect_equal(stats$numFeatures, stats2$numFeatures) + expect_equal(stats$features, stats2$features) + expect_equal(stats$featureImportances, stats2$featureImportances) + expect_equal(stats$numTrees, stats2$numTrees) + expect_equal(stats$treeWeights, stats2$treeWeights) + + unlink(modelPath) + + # classification + data <- suppressWarnings(createDataFrame(iris)) + model <- spark.randomForest(data, Species ~ Petal_Length + Petal_Width, "classification", + maxDepth = 5, maxBins = 16) + + stats <- summary(model) + expect_equal(stats$numFeatures, 2) + expect_equal(stats$numTrees, 20) + expect_error(capture.output(stats), NA) + expect_true(length(capture.output(stats)) > 6) + # Test string prediction values + predictions <- collect(predict(model, data))$prediction + expect_equal(length(grep("setosa", predictions)), 50) + expect_equal(length(grep("versicolor", predictions)), 50) + + modelPath <- tempfile(pattern = "spark-randomForestClassification", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$depth, stats2$depth) + expect_equal(stats$numNodes, stats2$numNodes) + expect_equal(stats$numClasses, stats2$numClasses) + + unlink(modelPath) + + # Test numeric response variable + labelToIndex <- function(species) { + switch(as.character(species), + setosa = 0.0, + versicolor = 1.0, + virginica = 2.0 + ) + } + iris$NumericSpecies <- lapply(iris$Species, labelToIndex) + data <- suppressWarnings(createDataFrame(iris[-5])) + model <- spark.randomForest(data, NumericSpecies ~ Petal_Length + Petal_Width, "classification", + maxDepth = 5, maxBins = 16) + stats <- summary(model) + expect_equal(stats$numFeatures, 2) + expect_equal(stats$numTrees, 20) + # Test numeric prediction values + predictions <- collect(predict(model, data))$prediction + expect_equal(length(grep("1.0", predictions)), 50) + expect_equal(length(grep("2.0", predictions)), 50) + + # spark.randomForest classification can work on libsvm data + data <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"), + source = "libsvm") + model <- spark.randomForest(data, label ~ features, "classification") + expect_equal(summary(model)$numFeatures, 4) +}) + +sparkR.session.stop() diff --git 
a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index 5b42843717e90..f219c5605b643 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -86,7 +86,7 @@ public int compare(RecordPointerAndKeyPrefix r1, RecordPointerAndKeyPrefix r2) { private final PrefixComparators.RadixSortSupport radixSortSupport; /** - * Within this buffer, position {@code 2 * i} holds a pointer pointer to the record at + * Within this buffer, position {@code 2 * i} holds a pointer to the record at * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix. * * Only part of the array will be used to store the pointers, the rest part is preserved as diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java index 430bf677edbdf..d9f84d10e9051 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java @@ -25,7 +25,7 @@ * Supports sorting an array of (record pointer, key prefix) pairs. * Used in {@link UnsafeInMemorySorter}. *

- * Within each long[] buffer, position {@code 2 * i} holds a pointer pointer to the record at + * Within each long[] buffer, position {@code 2 * i} holds a pointer to the record at * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix. */ public final class UnsafeSortDataFormat diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 6f5c31d7ab71c..4ca442b629fd9 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -317,7 +317,7 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf, pool } - // Make sure that that we aren't going to exceed the max RPC message size by making sure + // Make sure that we aren't going to exceed the max RPC message size by making sure // we use broadcast to send large map output statuses. if (minSizeForBroadcast > maxRpcMessageSize) { val msg = s"spark.shuffle.mapOutput.minSizeForBroadcast ($minSizeForBroadcast bytes) must " + diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index 7745387dbceba..8c1b5f7bf0d9b 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -98,7 +98,7 @@ case class FetchFailed( * 4 task failures, instead we immediately go back to the stage which generated the map output, * and regenerate the missing data. (2) we don't count fetch failures for blacklisting, since * presumably its not the fault of the executor where the task ran, but the executor which - * stored the data. This is especially important because we we might rack up a bunch of + * stored the data. This is especially important because we might rack up a bunch of * fetch-failures in rapid succession, on all nodes of the cluster, due to one bad node. 
*/ override def countTowardsTaskFailures: Boolean = false diff --git a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala index 79f4d06c8460e..320af5cf97550 100644 --- a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala @@ -43,7 +43,7 @@ import org.apache.spark.util.{ThreadUtils, Utils} * Execute using * ./bin/spark-class org.apache.spark.deploy.FaultToleranceTest * - * Make sure that that the environment includes the following properties in SPARK_DAEMON_JAVA_OPTS + * Make sure that the environment includes the following properties in SPARK_DAEMON_JAVA_OPTS * *and* SPARK_JAVA_OPTS: * - spark.deploy.recoveryMode=ZOOKEEPER * - spark.deploy.zookeeper.url=172.17.42.1:2181 diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 3011ed0f95d8e..cd241d6d22745 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -97,6 +97,13 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) .map { d => Utils.resolveURI(d).toString } .getOrElse(DEFAULT_LOG_DIR) + private val HISTORY_UI_ACLS_ENABLE = conf.getBoolean("spark.history.ui.acls.enable", false) + private val HISTORY_UI_ADMIN_ACLS = conf.get("spark.history.ui.admin.acls", "") + private val HISTORY_UI_ADMIN_ACLS_GROUPS = conf.get("spark.history.ui.admin.acls.groups", "") + logInfo(s"History server ui acls " + (if (HISTORY_UI_ACLS_ENABLE) "enabled" else "disabled") + + "; users with admin permissions: " + HISTORY_UI_ADMIN_ACLS.toString + + "; groups with admin permissions" + HISTORY_UI_ADMIN_ACLS_GROUPS.toString) + private val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) private val fs = Utils.getHadoopFileSystem(logDir, hadoopConf) @@ -250,13 +257,14 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) val appListener = replay(fileStatus, isApplicationCompleted(fileStatus), replayBus) if (appListener.appId.isDefined) { - val uiAclsEnabled = conf.getBoolean("spark.history.ui.acls.enable", false) - ui.getSecurityManager.setAcls(uiAclsEnabled) + ui.getSecurityManager.setAcls(HISTORY_UI_ACLS_ENABLE) // make sure to set admin acls before view acls so they are properly picked up - ui.getSecurityManager.setAdminAcls(appListener.adminAcls.getOrElse("")) - ui.getSecurityManager.setViewAcls(attempt.sparkUser, - appListener.viewAcls.getOrElse("")) - ui.getSecurityManager.setAdminAclsGroups(appListener.adminAclsGroups.getOrElse("")) + val adminAcls = HISTORY_UI_ADMIN_ACLS + "," + appListener.adminAcls.getOrElse("") + ui.getSecurityManager.setAdminAcls(adminAcls) + ui.getSecurityManager.setViewAcls(attempt.sparkUser, appListener.viewAcls.getOrElse("")) + val adminAclsGroups = HISTORY_UI_ADMIN_ACLS_GROUPS + "," + + appListener.adminAclsGroups.getOrElse("") + ui.getSecurityManager.setAdminAclsGroups(adminAclsGroups) ui.getSecurityManager.setViewAclsGroups(appListener.viewAclsGroups.getOrElse("")) Some(LoadedAppUI(ui, updateProbe(appId, attemptId, attempt.fileSize))) } else { diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 2b00a4a6b3e52..54f39f7620e5d 100644 --- 
a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -291,7 +291,7 @@ object HistoryServer extends Logging { /** * Create a security manager. - * This turns off security in the SecurityManager, so that the the History Server can start + * This turns off security in the SecurityManager, so that the History Server can start * in a Spark cluster where security is enabled. * @param config configuration for the SecurityManager constructor * @return the security manager for use in constructing the History Server. diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index f963a4606068b..e48817ebbafdd 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -445,12 +445,25 @@ private[deploy] class Worker( // Create local dirs for the executor. These are passed to the executor via the // SPARK_EXECUTOR_DIRS environment variable, and deleted by the Worker when the // application finishes. - val appLocalDirs = appDirectories.getOrElse(appId, - Utils.getOrCreateLocalRootDirs(conf).map { dir => - val appDir = Utils.createDirectory(dir, namePrefix = "executor") - Utils.chmod700(appDir) - appDir.getAbsolutePath() - }.toSeq) + val appLocalDirs = appDirectories.getOrElse(appId, { + val localRootDirs = Utils.getOrCreateLocalRootDirs(conf) + val dirs = localRootDirs.flatMap { dir => + try { + val appDir = Utils.createDirectory(dir, namePrefix = "executor") + Utils.chmod700(appDir) + Some(appDir.getAbsolutePath()) + } catch { + case e: IOException => + logWarning(s"${e.getMessage}. Ignoring this directory.") + None + } + }.toSeq + if (dirs.isEmpty) { + throw new IOException("No subfolder can be created in " + + s"${localRootDirs.mkString(",")}.") + } + dirs + }) appDirectories(appId) = appLocalDirs val manager = new ExecutorRunner( appId, diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 92a27902c6696..4a38560d8deca 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -92,10 +92,9 @@ private[spark] class CoarseGrainedExecutorBackend( if (executor == null) { exitExecutor(1, "Received LaunchTask command but executor was null") } else { - val taskDesc = ser.deserialize[TaskDescription](data.value) + val taskDesc = TaskDescription.decode(data.value) logInfo("Got assigned task " + taskDesc.taskId) - executor.launchTask(this, taskId = taskDesc.taskId, attemptNumber = taskDesc.attemptNumber, - taskDesc.name, taskDesc.serializedTask) + executor.launchTask(this, taskDesc) } case KillTask(taskId, _, interruptThread) => diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 3346f6dd1f975..789198f52ca45 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -26,7 +26,7 @@ import java.util.concurrent.{ConcurrentHashMap, TimeUnit} import javax.annotation.concurrent.GuardedBy import scala.collection.JavaConverters._ -import scala.collection.mutable.{ArrayBuffer, HashMap} +import 
scala.collection.mutable.{ArrayBuffer, HashMap, Map} import scala.util.control.NonFatal import org.apache.spark._ @@ -34,7 +34,7 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.rpc.RpcTimeout -import org.apache.spark.scheduler.{AccumulableInfo, DirectTaskResult, IndirectTaskResult, Task} +import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Task, TaskDescription} import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.{StorageLevel, TaskResultBlockId} import org.apache.spark.util._ @@ -148,15 +148,9 @@ private[spark] class Executor( startDriverHeartbeater() - def launchTask( - context: ExecutorBackend, - taskId: Long, - attemptNumber: Int, - taskName: String, - serializedTask: ByteBuffer): Unit = { - val tr = new TaskRunner(context, taskId = taskId, attemptNumber = attemptNumber, taskName, - serializedTask) - runningTasks.put(taskId, tr) + def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = { + val tr = new TaskRunner(context, taskDescription) + runningTasks.put(taskDescription.taskId, tr) threadPool.execute(tr) } @@ -212,13 +206,12 @@ private[spark] class Executor( class TaskRunner( execBackend: ExecutorBackend, - val taskId: Long, - val attemptNumber: Int, - taskName: String, - serializedTask: ByteBuffer) + private val taskDescription: TaskDescription) extends Runnable { + val taskId = taskDescription.taskId val threadName = s"Executor task launch worker for task $taskId" + private val taskName = taskDescription.name /** Whether this task has been killed. */ @volatile private var killed = false @@ -287,16 +280,14 @@ private[spark] class Executor( startGCTime = computeTotalGcTime() try { - val (taskFiles, taskJars, taskProps, taskBytes) = - Task.deserializeWithDependencies(serializedTask) - // Must be set before updateDependencies() is called, in case fetching dependencies // requires access to properties contained within (e.g. for access control). - Executor.taskDeserializationProps.set(taskProps) + Executor.taskDeserializationProps.set(taskDescription.properties) - updateDependencies(taskFiles, taskJars) - task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader) - task.localProperties = taskProps + updateDependencies(taskDescription.addedFiles, taskDescription.addedJars) + task = ser.deserialize[Task[Any]]( + taskDescription.serializedTask, Thread.currentThread.getContextClassLoader) + task.localProperties = taskDescription.properties task.setTaskMemoryManager(taskMemoryManager) // If this task has been killed before we deserialized it, let's quit now. Otherwise, @@ -321,7 +312,7 @@ private[spark] class Executor( val value = try { val res = task.run( taskAttemptId = taskId, - attemptNumber = attemptNumber, + attemptNumber = taskDescription.attemptNumber, metricsSystem = env.metricsSystem) threwException = false res @@ -637,7 +628,7 @@ private[spark] class Executor( * Download any missing dependencies if we receive a new set of files and JARs from the * SparkContext. Also adds any new JARs we fetched to the class loader. 
*/ - private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) { + private def updateDependencies(newFiles: Map[String, Long], newJars: Map[String, Long]) { lazy val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) synchronized { // Fetch missing dependencies diff --git a/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala b/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala index f7a991770d402..8dd1a1ea059be 100644 --- a/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala @@ -92,7 +92,7 @@ class ShuffleReadMetrics private[spark] () extends Serializable { private[spark] def setRecordsRead(v: Long): Unit = _recordsRead.setValue(v) /** - * Resets the value of the current metrics (`this`) and and merges all the independent + * Resets the value of the current metrics (`this`) and merges all the independent * [[TempShuffleReadMetrics]] into `this`. */ private[spark] def setMergeValues(metrics: Seq[TempShuffleReadMetrics]): Unit = { diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index dc70eb82d2b54..3d4ea3cccc934 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -37,7 +37,7 @@ import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.util.Utils /** - * A BlockTransferService that uses Netty to fetch a set of blocks at at time. + * A BlockTransferService that uses Netty to fetch a set of blocks at time. */ private[spark] class NettyBlockTransferService( conf: SparkConf, diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 5becca6c06a6c..51976f666dfce 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -215,89 +215,3 @@ private[spark] abstract class Task[T]( } } } - -/** - * Handles transmission of tasks and their dependencies, because this can be slightly tricky. We - * need to send the list of JARs and files added to the SparkContext with each task to ensure that - * worker nodes find out about it, but we can't make it part of the Task because the user's code in - * the task might depend on one of the JARs. Thus we serialize each task as multiple objects, by - * first writing out its dependencies. - */ -private[spark] object Task { - /** - * Serialize a task and the current app dependencies (files and JARs added to the SparkContext) - */ - def serializeWithDependencies( - task: Task[_], - currentFiles: mutable.Map[String, Long], - currentJars: mutable.Map[String, Long], - serializer: SerializerInstance) - : ByteBuffer = { - - val out = new ByteBufferOutputStream(4096) - val dataOut = new DataOutputStream(out) - - // Write currentFiles - dataOut.writeInt(currentFiles.size) - for ((name, timestamp) <- currentFiles) { - dataOut.writeUTF(name) - dataOut.writeLong(timestamp) - } - - // Write currentJars - dataOut.writeInt(currentJars.size) - for ((name, timestamp) <- currentJars) { - dataOut.writeUTF(name) - dataOut.writeLong(timestamp) - } - - // Write the task properties separately so it is available before full task deserialization. 
- val propBytes = Utils.serialize(task.localProperties) - dataOut.writeInt(propBytes.length) - dataOut.write(propBytes) - - // Write the task itself and finish - dataOut.flush() - val taskBytes = serializer.serialize(task) - Utils.writeByteBuffer(taskBytes, out) - out.close() - out.toByteBuffer - } - - /** - * Deserialize the list of dependencies in a task serialized with serializeWithDependencies, - * and return the task itself as a serialized ByteBuffer. The caller can then update its - * ClassLoaders and deserialize the task. - * - * @return (taskFiles, taskJars, taskProps, taskBytes) - */ - def deserializeWithDependencies(serializedTask: ByteBuffer) - : (HashMap[String, Long], HashMap[String, Long], Properties, ByteBuffer) = { - - val in = new ByteBufferInputStream(serializedTask) - val dataIn = new DataInputStream(in) - - // Read task's files - val taskFiles = new HashMap[String, Long]() - val numFiles = dataIn.readInt() - for (i <- 0 until numFiles) { - taskFiles(dataIn.readUTF()) = dataIn.readLong() - } - - // Read task's JARs - val taskJars = new HashMap[String, Long]() - val numJars = dataIn.readInt() - for (i <- 0 until numJars) { - taskJars(dataIn.readUTF()) = dataIn.readLong() - } - - val propLength = dataIn.readInt() - val propBytes = new Array[Byte](propLength) - dataIn.readFully(propBytes, 0, propLength) - val taskProps = Utils.deserialize[Properties](propBytes) - - // Create a sub-buffer for the rest of the data, which is the serialized Task object - val subBuffer = serializedTask.slice() // ByteBufferInputStream will have read just up to task - (taskFiles, taskJars, taskProps, subBuffer) - } -} diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala index 45c742cbff5e7..78aa5c40010cc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala @@ -17,13 +17,31 @@ package org.apache.spark.scheduler +import java.io.{DataInputStream, DataOutputStream} import java.nio.ByteBuffer +import java.util.Properties -import org.apache.spark.util.SerializableBuffer +import scala.collection.JavaConverters._ +import scala.collection.mutable.{HashMap, Map} + +import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Utils} /** * Description of a task that gets passed onto executors to be executed, usually created by * `TaskSetManager.resourceOffer`. + * + * TaskDescriptions and the associated Task need to be serialized carefully for two reasons: + * + * (1) When a TaskDescription is received by an Executor, the Executor needs to first get the + * list of JARs and files and add these to the classpath, and set the properties, before + * deserializing the Task object (serializedTask). This is why the Properties are included + * in the TaskDescription, even though they're also in the serialized task. + * (2) Because a TaskDescription is serialized and sent to an executor for each task, efficient + * serialization (both in terms of serialization time and serialized buffer size) is + * important. For this reason, we serialize TaskDescriptions ourselves with the + * TaskDescription.encode and TaskDescription.decode methods. This results in a smaller + * serialized size because it avoids serializing unnecessary fields in the Map objects + * (which can introduce significant overhead when the maps are small). 
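To make the encode/decode path described above concrete, here is a minimal, hypothetical round-trip sketch that is not part of this patch; it assumes code living under the org.apache.spark.scheduler package (TaskDescription is private[spark]) and every field value below is invented.

```scala
package org.apache.spark.scheduler

import java.nio.ByteBuffer
import java.util.Properties

import scala.collection.mutable.HashMap

object TaskDescriptionRoundTrip {
  def main(args: Array[String]): Unit = {
    // Hypothetical dependency maps: name -> timestamp, as kept by the SparkContext.
    val files = HashMap[String, Long]("hdfs:///tmp/dep.txt" -> 1L)
    val jars = HashMap[String, Long]("hdfs:///tmp/dep.jar" -> 2L)
    val props = new Properties()
    props.setProperty("spark.job.description", "example")

    val desc = new TaskDescription(
      taskId = 42L,
      attemptNumber = 0,
      executorId = "exec-1",
      name = "task 42.0 in stage 0.0",
      index = 0,
      addedFiles = files,
      addedJars = jars,
      properties = props,
      serializedTask = ByteBuffer.wrap(Array[Byte](1, 2, 3)))

    // The driver encodes the description into a single ByteBuffer for the RPC message...
    val wire: ByteBuffer = TaskDescription.encode(desc)
    // ...and the executor decodes the metadata fields before it ever touches serializedTask.
    val decoded = TaskDescription.decode(wire)
    assert(decoded.taskId == 42L && decoded.addedJars == jars)
  }
}
```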
*/ private[spark] class TaskDescription( val taskId: Long, @@ -31,13 +49,88 @@ private[spark] class TaskDescription( val executorId: String, val name: String, val index: Int, // Index within this task's TaskSet - _serializedTask: ByteBuffer) - extends Serializable { + val addedFiles: Map[String, Long], + val addedJars: Map[String, Long], + val properties: Properties, + val serializedTask: ByteBuffer) { + + override def toString: String = "TaskDescription(TID=%d, index=%d)".format(taskId, index) +} - // Because ByteBuffers are not serializable, wrap the task in a SerializableBuffer - private val buffer = new SerializableBuffer(_serializedTask) +private[spark] object TaskDescription { + private def serializeStringLongMap(map: Map[String, Long], dataOut: DataOutputStream): Unit = { + dataOut.writeInt(map.size) + for ((key, value) <- map) { + dataOut.writeUTF(key) + dataOut.writeLong(value) + } + } - def serializedTask: ByteBuffer = buffer.value + def encode(taskDescription: TaskDescription): ByteBuffer = { + val bytesOut = new ByteBufferOutputStream(4096) + val dataOut = new DataOutputStream(bytesOut) - override def toString: String = "TaskDescription(TID=%d, index=%d)".format(taskId, index) + dataOut.writeLong(taskDescription.taskId) + dataOut.writeInt(taskDescription.attemptNumber) + dataOut.writeUTF(taskDescription.executorId) + dataOut.writeUTF(taskDescription.name) + dataOut.writeInt(taskDescription.index) + + // Write files. + serializeStringLongMap(taskDescription.addedFiles, dataOut) + + // Write jars. + serializeStringLongMap(taskDescription.addedJars, dataOut) + + // Write properties. + dataOut.writeInt(taskDescription.properties.size()) + taskDescription.properties.asScala.foreach { case (key, value) => + dataOut.writeUTF(key) + dataOut.writeUTF(value) + } + + // Write the task. The task is already serialized, so write it directly to the byte buffer. + Utils.writeByteBuffer(taskDescription.serializedTask, bytesOut) + + dataOut.close() + bytesOut.close() + bytesOut.toByteBuffer + } + + private def deserializeStringLongMap(dataIn: DataInputStream): HashMap[String, Long] = { + val map = new HashMap[String, Long]() + val mapSize = dataIn.readInt() + for (i <- 0 until mapSize) { + map(dataIn.readUTF()) = dataIn.readLong() + } + map + } + + def decode(byteBuffer: ByteBuffer): TaskDescription = { + val dataIn = new DataInputStream(new ByteBufferInputStream(byteBuffer)) + val taskId = dataIn.readLong() + val attemptNumber = dataIn.readInt() + val executorId = dataIn.readUTF() + val name = dataIn.readUTF() + val index = dataIn.readInt() + + // Read files. + val taskFiles = deserializeStringLongMap(dataIn) + + // Read jars. + val taskJars = deserializeStringLongMap(dataIn) + + // Read properties. + val properties = new Properties() + val numProperties = dataIn.readInt() + for (i <- 0 until numProperties) { + properties.setProperty(dataIn.readUTF(), dataIn.readUTF()) + } + + // Create a sub-buffer for the serialized task into its own buffer (to be deserialized later). 
+ val serializedTask = byteBuffer.slice() + + new TaskDescription(taskId, attemptNumber, executorId, name, index, taskFiles, taskJars, + properties, serializedTask) + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala index b1addc128e696..a284f7956cd31 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala @@ -143,8 +143,12 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul logError( "Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader) case ex: Exception => // No-op + } finally { + // If there's an error while deserializing the TaskEndReason, this Runnable + // will die. Still tell the scheduler about the task failure, to avoid a hang + // where the scheduler thinks the task is still running. + scheduler.handleFailedTask(taskSetManager, tid, taskState, reason) } - scheduler.handleFailedTask(taskSetManager, tid, taskState, reason) } }) } catch { diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 3756c216f5ecb..c7ff13cebfaaa 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -446,9 +446,8 @@ private[spark] class TaskSetManager( lastLaunchTime = curTime } // Serialize and return the task - val startTime = clock.getTimeMillis() val serializedTask: ByteBuffer = try { - Task.serializeWithDependencies(task, sched.sc.addedFiles, sched.sc.addedJars, ser) + ser.serialize(task) } catch { // If the task cannot be serialized, then there's no point to re-attempt the task, // as it will always fail. So just abort the whole task-set. @@ -475,8 +474,16 @@ private[spark] class TaskSetManager( s"partition ${task.partitionId}, $taskLocality, ${serializedTask.limit} bytes)") sched.dagScheduler.taskStarted(task, info) - new TaskDescription(taskId = taskId, attemptNumber = attemptNum, execId, - taskName, index, serializedTask) + new TaskDescription( + taskId, + attemptNum, + execId, + taskName, + index, + sched.sc.addedFiles, + sched.sc.addedJars, + task.localProperties, + serializedTask) } } else { None diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 3452487e72e88..31575c0ca0d15 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -98,11 +98,6 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // Executors that have been lost, but for which we don't yet know the real exit reason. 
protected val executorsPendingLossReason = new HashSet[String] - // If this DriverEndpoint is changed to support multiple threads, - // then this may need to be changed so that we don't share the serializer - // instance across threads - private val ser = SparkEnv.get.closureSerializer.newInstance() - protected val addressToExecutorId = new HashMap[RpcAddress, String] private val reviveThread = @@ -249,7 +244,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // Launch tasks returned by a set of resource offers private def launchTasks(tasks: Seq[Seq[TaskDescription]]) { for (task <- tasks.flatten) { - val serializedTask = ser.serialize(task) + val serializedTask = TaskDescription.encode(task) if (serializedTask.limit >= maxRpcMessageSize) { scheduler.taskIdToTaskSetManager.get(task.taskId).foreach { taskSetMgr => try { diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala index 7a73e8ed8a38f..625f998cd4608 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala @@ -84,8 +84,7 @@ private[spark] class LocalEndpoint( val offers = IndexedSeq(new WorkerOffer(localExecutorId, localExecutorHostname, freeCores)) for (task <- scheduler.resourceOffers(offers).flatten) { freeCores -= scheduler.CPUS_PER_TASK - executor.launchTask(executorBackend, taskId = task.taskId, attemptNumber = task.attemptNumber, - task.name, task.serializedTask) + executor.launchTask(executorBackend, task) } } } diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala index 386fdfd218a88..3bfdf95db84c6 100644 --- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala +++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala @@ -350,7 +350,7 @@ object SizeEstimator extends Logging { // 3. consistent fields layouts throughout the hierarchy: This means we should layout // superclass first. And we can use superclass's shellSize as a starting point to layout the // other fields in this class. - // 4. class alignment: HotSpot rounds field blocks up to to HeapOopSize not 4 bytes, confirmed + // 4. class alignment: HotSpot rounds field blocks up to HeapOopSize not 4 bytes, confirmed // with Aleksey. see https://bugs.openjdk.java.net/browse/CODETOOLS-7901322 // // The real world field layout is much more complicated. 
There are three kinds of fields diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 078cc3d5b4f0c..0dcf0307e113f 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -237,9 +237,11 @@ private[spark] object Utils extends Logging { if (bb.hasArray) { out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining()) } else { + val originalPosition = bb.position() val bbval = new Array[Byte](bb.remaining()) bb.get(bbval) out.write(bbval) + bb.position(originalPosition) } } @@ -250,9 +252,11 @@ private[spark] object Utils extends Logging { if (bb.hasArray) { out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining()) } else { + val originalPosition = bb.position() val bbval = new Array[Byte](bb.remaining()) bb.get(bbval) out.write(bbval) + bb.position(originalPosition) } } diff --git a/core/src/test/scala/org/apache/spark/deploy/history/ApplicationCacheSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/ApplicationCacheSuite.scala index e3304be792af7..7998e3702c122 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/ApplicationCacheSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/ApplicationCacheSuite.scala @@ -253,7 +253,7 @@ class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar assertNotFound(appId, None) } - test("Test that if an attempt ID is is set, it must be used in lookups") { + test("Test that if an attempt ID is set, it must be used in lookups") { val operations = new StubCacheOperations() val clock = new ManualClock(1) implicit val cache = new ApplicationCache(operations, retainedApplications = 10, clock = clock) diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 027f412c7581b..8cb359ed45192 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -35,10 +35,11 @@ import org.scalatest.BeforeAndAfter import org.scalatest.Matchers import org.scalatest.concurrent.Eventually._ -import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.internal.Logging import org.apache.spark.io._ import org.apache.spark.scheduler._ +import org.apache.spark.security.GroupMappingServiceProvider import org.apache.spark.util.{Clock, JsonProtocol, ManualClock, Utils} class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matchers with Logging { @@ -474,6 +475,102 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc } } + test("support history server ui admin acls") { + def createAndCheck(conf: SparkConf, properties: (String, String)*) + (checkFn: SecurityManager => Unit): Unit = { + // Empty the testDir for each test. 
+ if (testDir.exists() && testDir.isDirectory) { + testDir.listFiles().foreach { f => if (f.isFile) f.delete() } + } + + var provider: FsHistoryProvider = null + try { + provider = new FsHistoryProvider(conf) + val log = newLogFile("app1", Some("attempt1"), inProgress = false) + writeFile(log, true, None, + SparkListenerApplicationStart("app1", Some("app1"), System.currentTimeMillis(), + "test", Some("attempt1")), + SparkListenerEnvironmentUpdate(Map( + "Spark Properties" -> properties.toSeq, + "JVM Information" -> Seq.empty, + "System Properties" -> Seq.empty, + "Classpath Entries" -> Seq.empty + )), + SparkListenerApplicationEnd(System.currentTimeMillis())) + + provider.checkForLogs() + val appUi = provider.getAppUI("app1", Some("attempt1")) + + assert(appUi.nonEmpty) + val securityManager = appUi.get.ui.securityManager + checkFn(securityManager) + } finally { + if (provider != null) { + provider.stop() + } + } + } + + // Test both history ui admin acls and application acls are configured. + val conf1 = createTestConf() + .set("spark.history.ui.acls.enable", "true") + .set("spark.history.ui.admin.acls", "user1,user2") + .set("spark.history.ui.admin.acls.groups", "group1") + .set("spark.user.groups.mapping", classOf[TestGroupsMappingProvider].getName) + + createAndCheck(conf1, ("spark.admin.acls", "user"), ("spark.admin.acls.groups", "group")) { + securityManager => + // Test whether user has permission to access UI. + securityManager.checkUIViewPermissions("user1") should be (true) + securityManager.checkUIViewPermissions("user2") should be (true) + securityManager.checkUIViewPermissions("user") should be (true) + securityManager.checkUIViewPermissions("abc") should be (false) + + // Test whether user with admin group has permission to access UI. + securityManager.checkUIViewPermissions("user3") should be (true) + securityManager.checkUIViewPermissions("user4") should be (true) + securityManager.checkUIViewPermissions("user5") should be (true) + securityManager.checkUIViewPermissions("user6") should be (false) + } + + // Test only history ui admin acls are configured. + val conf2 = createTestConf() + .set("spark.history.ui.acls.enable", "true") + .set("spark.history.ui.admin.acls", "user1,user2") + .set("spark.history.ui.admin.acls.groups", "group1") + .set("spark.user.groups.mapping", classOf[TestGroupsMappingProvider].getName) + createAndCheck(conf2) { securityManager => + // Test whether user has permission to access UI. + securityManager.checkUIViewPermissions("user1") should be (true) + securityManager.checkUIViewPermissions("user2") should be (true) + // Check the unknown "user" should return false + securityManager.checkUIViewPermissions("user") should be (false) + + // Test whether user with admin group has permission to access UI. + securityManager.checkUIViewPermissions("user3") should be (true) + securityManager.checkUIViewPermissions("user4") should be (true) + // Check the "user5" without mapping relation should return false + securityManager.checkUIViewPermissions("user5") should be (false) + } + + // Test neither history ui admin acls nor application acls are configured. + val conf3 = createTestConf() + .set("spark.history.ui.acls.enable", "true") + .set("spark.user.groups.mapping", classOf[TestGroupsMappingProvider].getName) + createAndCheck(conf3) { securityManager => + // Test whether user has permission to access UI. 
+ securityManager.checkUIViewPermissions("user1") should be (false) + securityManager.checkUIViewPermissions("user2") should be (false) + securityManager.checkUIViewPermissions("user") should be (false) + + // Test whether user with admin group has permission to access UI. + // Check should be failed since we don't have acl group settings. + securityManager.checkUIViewPermissions("user3") should be (false) + securityManager.checkUIViewPermissions("user4") should be (false) + securityManager.checkUIViewPermissions("user5") should be (false) + } + } + /** * Asks the provider to check for logs and calls a function to perform checks on the updated * app list. Example: @@ -532,3 +629,15 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc } } + +class TestGroupsMappingProvider extends GroupMappingServiceProvider { + private val mappings = Map( + "user3" -> "group1", + "user4" -> "group1", + "user5" -> "group") + + override def getGroups(username: String): Set[String] = { + mappings.get(username).map(Set(_)).getOrElse(Set.empty) + } +} + diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala index 742500d87d32a..f94baaa30d18d 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala @@ -18,9 +18,10 @@ package org.apache.spark.executor import java.nio.ByteBuffer +import java.util.Properties import java.util.concurrent.CountDownLatch -import scala.collection.mutable.HashMap +import scala.collection.mutable.Map import org.mockito.Matchers._ import org.mockito.Mockito.{mock, when} @@ -32,7 +33,7 @@ import org.apache.spark.TaskState.TaskState import org.apache.spark.memory.MemoryManager import org.apache.spark.metrics.MetricsSystem import org.apache.spark.rpc.RpcEnv -import org.apache.spark.scheduler.{FakeTask, Task} +import org.apache.spark.scheduler.{FakeTask, TaskDescription} import org.apache.spark.serializer.JavaSerializer class ExecutorSuite extends SparkFunSuite { @@ -52,13 +53,18 @@ class ExecutorSuite extends SparkFunSuite { when(mockEnv.memoryManager).thenReturn(mockMemoryManager) when(mockEnv.closureSerializer).thenReturn(serializer) val fakeTaskMetrics = serializer.newInstance().serialize(TaskMetrics.registered).array() - - val serializedTask = - Task.serializeWithDependencies( - new FakeTask(0, 0, Nil, fakeTaskMetrics), - HashMap[String, Long](), - HashMap[String, Long](), - serializer.newInstance()) + val serializedTask = serializer.newInstance().serialize( + new FakeTask(0, 0, Nil, fakeTaskMetrics)) + val taskDescription = new TaskDescription( + taskId = 0, + attemptNumber = 0, + executorId = "", + name = "", + index = 0, + addedFiles = Map[String, Long](), + addedJars = Map[String, Long](), + properties = new Properties, + serializedTask) // we use latches to force the program to run in this order: // +-----------------------------+---------------------------------------+ @@ -108,7 +114,7 @@ class ExecutorSuite extends SparkFunSuite { try { executor = new Executor("id", "localhost", mockEnv, userClassPath = Nil, isLocal = true) // the task will be launched in a dedicated worker thread - executor.launchTask(mockExecutorBackend, 0, 0, "", serializedTask) + executor.launchTask(mockExecutorBackend, taskDescription) executorSuiteHelper.latch1.await() // we know the task will be started, but not yet deserialized, because of the latches we diff --git 
a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 5e8a854e46a0f..f3d3f701af462 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -1819,7 +1819,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === HashSet(makeBlockManagerId("hostA"))) - // Reducer should run where RDD 2 has preferences, even though though it also has a shuffle dep + // Reducer should run where RDD 2 has preferences, even though it also has a shuffle dep val reduceTaskSet = taskSets(1) assertLocations(reduceTaskSet, Seq(Seq("hostB"))) complete(reduceTaskSet, Seq((Success, 42))) @@ -2058,7 +2058,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou // Now complete tasks in the second task set val newTaskSet = taskSets(1) - assert(newTaskSet.tasks.size === 2) // Both tasks 0 and 1 were on on hostA + assert(newTaskSet.tasks.size === 2) // Both tasks 0 and 1 were on hostA runEvent(makeCompletionEvent(newTaskSet.tasks(0), Success, makeMapStatus("hostB", 2))) assert(results.size === 0) // Map stage job should not be complete yet runEvent(makeCompletionEvent(newTaskSet.tasks(1), Success, makeMapStatus("hostB", 2))) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala new file mode 100644 index 0000000000000..9f1fe0515732e --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import java.nio.ByteBuffer +import java.util.Properties + +import scala.collection.mutable.HashMap + +import org.apache.spark.SparkFunSuite + +class TaskDescriptionSuite extends SparkFunSuite { + test("encoding and then decoding a TaskDescription results in the same TaskDescription") { + val originalFiles = new HashMap[String, Long]() + originalFiles.put("fileUrl1", 1824) + originalFiles.put("fileUrl2", 2) + + val originalJars = new HashMap[String, Long]() + originalJars.put("jar1", 3) + + val originalProperties = new Properties() + originalProperties.put("property1", "18") + originalProperties.put("property2", "test value") + + // Create a dummy byte buffer for the task. 
+ val taskBuffer = ByteBuffer.wrap(Array[Byte](1, 2, 3, 4)) + + val originalTaskDescription = new TaskDescription( + taskId = 1520589, + attemptNumber = 2, + executorId = "testExecutor", + name = "task for test", + index = 19, + originalFiles, + originalJars, + originalProperties, + taskBuffer + ) + + val serializedTaskDescription = TaskDescription.encode(originalTaskDescription) + val decodedTaskDescription = TaskDescription.decode(serializedTaskDescription) + + // Make sure that all of the fields in the decoded task description match the original. + assert(decodedTaskDescription.taskId === originalTaskDescription.taskId) + assert(decodedTaskDescription.attemptNumber === originalTaskDescription.attemptNumber) + assert(decodedTaskDescription.executorId === originalTaskDescription.executorId) + assert(decodedTaskDescription.name === originalTaskDescription.name) + assert(decodedTaskDescription.index === originalTaskDescription.index) + assert(decodedTaskDescription.addedFiles.equals(originalFiles)) + assert(decodedTaskDescription.addedJars.equals(originalJars)) + assert(decodedTaskDescription.properties.equals(originalTaskDescription.properties)) + assert(decodedTaskDescription.serializedTask.equals(taskBuffer)) + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala index c9e682f53c105..3e55d399e9df9 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler -import java.io.File +import java.io.{File, ObjectInputStream} import java.net.URL import java.nio.ByteBuffer @@ -248,5 +248,24 @@ class TaskResultGetterSuite extends SparkFunSuite with BeforeAndAfter with Local assert(resSizeAfter.exists(_.toString.toLong > 0L)) } + test("failed task is handled when error occurs deserializing the reason") { + sc = new SparkContext("local", "test", conf) + val rdd = sc.parallelize(Seq(1), 1).map { _ => + throw new UndeserializableException + } + val message = intercept[SparkException] { + rdd.collect() + }.getMessage + // Job failed, even though the failure reason is unknown. 
+ val unknownFailure = """(?s).*Lost task.*: UnknownReason.*""".r + assert(unknownFailure.findFirstMatchIn(message).isDefined) + } + +} + +private class UndeserializableException extends Exception { + private def readObject(in: ObjectInputStream): Unit = { + throw new NoClassDefFoundError() + } } diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index fb7b91222b499..442a603cae791 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -17,7 +17,8 @@ package org.apache.spark.util -import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File, FileOutputStream, PrintStream} +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataOutput, DataOutputStream, File, + FileOutputStream, PrintStream} import java.lang.{Double => JDouble, Float => JFloat} import java.net.{BindException, ServerSocket, URI} import java.nio.{ByteBuffer, ByteOrder} @@ -389,6 +390,28 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { assert(Utils.deserializeLongValue(bbuf.array) === testval) } + test("writeByteBuffer should not change ByteBuffer position") { + // Test a buffer with an underlying array, for both writeByteBuffer methods. + val testBuffer = ByteBuffer.wrap(Array[Byte](1, 2, 3, 4)) + assert(testBuffer.hasArray) + val bytesOut = new ByteBufferOutputStream(4096) + Utils.writeByteBuffer(testBuffer, bytesOut) + assert(testBuffer.position() === 0) + + val dataOut = new DataOutputStream(bytesOut) + Utils.writeByteBuffer(testBuffer, dataOut: DataOutput) + assert(testBuffer.position() === 0) + + // Test a buffer without an underlying array, for both writeByteBuffer methods. + val testDirectBuffer = ByteBuffer.allocateDirect(8) + assert(!testDirectBuffer.hasArray()) + Utils.writeByteBuffer(testDirectBuffer, bytesOut) + assert(testDirectBuffer.position() === 0) + + Utils.writeByteBuffer(testDirectBuffer, dataOut: DataOutput) + assert(testDirectBuffer.position() === 0) + } + test("get iterator size") { val empty = Seq[Int]() assert(Utils.getIteratorSize(empty.toIterator) === 0L) diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala index 7f0838268a111..c8b6a3346a460 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala @@ -53,7 +53,7 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { conf } - test("single insert insert") { + test("single insert") { val conf = createSparkConf(loadDefaults = false) sc = new SparkContext("local", "test", conf) val map = createExternalMap[Int] diff --git a/docs/configuration.md b/docs/configuration.md index bd67144007e92..d2fdef02df9e0 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -59,6 +59,7 @@ The following format is accepted: 1p or 1pb (pebibytes = 1024 tebibytes) ## Dynamically Loading Spark Properties + In some cases, you may want to avoid hard-coding certain configurations in a `SparkConf`. For instance, if you'd like to run the same application with different masters or different amounts of memory. Spark allows you to simply create an empty conf: @@ -106,7 +107,8 @@ line will appear. 
For all other configuration properties, you can assume the def Most of the properties that control internal settings have reasonable default values. Some of the most common options to set are: -#### Application Properties +### Application Properties + @@ -215,7 +217,8 @@ of the most common options to set are: Apart from these, the following properties are also available, and may be useful in some situations: -#### Runtime Environment +### Runtime Environment +
<tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
@@ -471,7 +474,8 @@ Apart from these, the following properties are also available, and may be useful
<tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
-#### Shuffle Behavior +### Shuffle Behavior + @@ -612,7 +616,8 @@ Apart from these, the following properties are also available, and may be useful
<tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
-#### Spark UI +### Spark UI + @@ -745,7 +750,8 @@ Apart from these, the following properties are also available, and may be useful
<tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
-#### Compression and Serialization +### Compression and Serialization + @@ -891,7 +897,8 @@ Apart from these, the following properties are also available, and may be useful
<tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
-#### Memory Management +### Memory Management + @@ -981,7 +988,8 @@ Apart from these, the following properties are also available, and may be useful
<tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
-#### Execution Behavior +### Execution Behavior + @@ -1108,7 +1116,8 @@ Apart from these, the following properties are also available, and may be useful
<tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
-#### Networking +### Networking + @@ -1139,13 +1148,13 @@ Apart from these, the following properties are also available, and may be useful @@ -1217,7 +1226,8 @@ Apart from these, the following properties are also available, and may be useful
<tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
  <td><code>spark.driver.bindAddress</code></td>
  <td>(value of spark.driver.host)</td>
  <td>
    Hostname or IP address where to bind listening sockets. This config overrides the SPARK_LOCAL_IP
    environment variable (see below).

    It also allows a different address from the local one to be advertised to executors or external systems.
    This is useful, for example, when running containers with bridged networking. For this to properly
    work, the different ports used by the driver (RPC, block manager and UI) need to be forwarded from the
    container's host.
  </td>
-#### Scheduling +### Scheduling + @@ -1467,7 +1477,8 @@ Apart from these, the following properties are also available, and may be useful
<tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
-#### Dynamic Allocation +### Dynamic Allocation + @@ -1548,7 +1559,8 @@ Apart from these, the following properties are also available, and may be useful
<tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
-#### Security +### Security + @@ -1729,7 +1741,7 @@ Apart from these, the following properties are also available, and may be useful
<tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
-#### TLS / SSL +### TLS / SSL @@ -1737,21 +1749,21 @@ Apart from these, the following properties are also available, and may be useful @@ -1835,7 +1847,8 @@ Apart from these, the following properties are also available, and may be useful
<tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
  <td><code>spark.ssl.enabled</code></td>
  <td>false</td>
  <td>
    Whether to enable SSL connections on all supported protocols.

    When spark.ssl.enabled is configured, spark.ssl.protocol is required.

    All the SSL settings like spark.ssl.xxx where xxx is a particular configuration property,
    denote the global configuration for all the supported protocols. In order to override the
    global configuration for the particular protocol, the properties must be overwritten in the
    protocol-specific namespace.

    Use spark.ssl.YYY.XXX settings to overwrite the global configuration for particular protocol
    denoted by YYY. Example values for YYY include fs, ui, standalone, and historyServer.
    See SSL Configuration for details on hierarchical SSL configuration for services.
  </td>
-#### Spark SQL +### Spark SQL + Running the SET -v command will show the entire list of the SQL configuration.
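For reference, a minimal Scala sketch of pulling the same listing (mirroring the R `showDF` call visible in the surrounding context); the application name and the local master below are placeholders.

```scala
import org.apache.spark.sql.SparkSession

// Placeholder app name and local master, for experimentation only.
val spark = SparkSession.builder()
  .appName("sql-config-listing")
  .master("local[*]")
  .getOrCreate()

// List all SQL configuration properties, their defaults and descriptions via SET -v.
val properties = spark.sql("SET -v")
properties.show(numRows = 200, truncate = false)
```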

@@ -1877,7 +1890,8 @@ showDF(properties, numRows = 200, truncate = FALSE)
-#### Spark Streaming +### Spark Streaming + @@ -1998,7 +2012,8 @@ showDF(properties, numRows = 200, truncate = FALSE)
<tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
-#### SparkR +### SparkR + @@ -2047,7 +2062,7 @@ showDF(properties, numRows = 200, truncate = FALSE)
<tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
-#### Deploy +### Deploy @@ -2070,15 +2085,16 @@ showDF(properties, numRows = 200, truncate = FALSE)
<tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
-#### Cluster Managers +### Cluster Managers + Each cluster manager in Spark has additional configuration options. Configurations can be found on the pages for each mode: -##### [YARN](running-on-yarn.html#configuration) +#### [YARN](running-on-yarn.html#configuration) -##### [Mesos](running-on-mesos.html#configuration) +#### [Mesos](running-on-mesos.html#configuration) -##### [Standalone Mode](spark-standalone.html#cluster-launch-scripts) +#### [Standalone Mode](spark-standalone.html#cluster-launch-scripts) # Environment Variables diff --git a/docs/img/structured-streaming-watermark-append-mode.png b/docs/img/structured-streaming-watermark-append-mode.png new file mode 100644 index 0000000000000..541d5bf399b76 Binary files /dev/null and b/docs/img/structured-streaming-watermark-append-mode.png differ diff --git a/docs/img/structured-streaming-watermark-update-mode.png b/docs/img/structured-streaming-watermark-update-mode.png new file mode 100644 index 0000000000000..6827849c3269d Binary files /dev/null and b/docs/img/structured-streaming-watermark-update-mode.png differ diff --git a/docs/img/structured-streaming-watermark.png b/docs/img/structured-streaming-watermark.png deleted file mode 100644 index f21fbda171013..0000000000000 Binary files a/docs/img/structured-streaming-watermark.png and /dev/null differ diff --git a/docs/img/structured-streaming.pptx b/docs/img/structured-streaming.pptx index f5bdfc078cad9..2ffd9f2a51399 100644 Binary files a/docs/img/structured-streaming.pptx and b/docs/img/structured-streaming.pptx differ diff --git a/docs/ml-features.md b/docs/ml-features.md index 1d3449746c9be..d67fce3c9528a 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -752,7 +752,7 @@ for more details on the API. `Interaction` is a `Transformer` which takes vector or double-valued columns, and generates a single vector column that contains the product of all combinations of one value from each input column. -For example, if you have 2 vector type columns each of which has 3 dimensions as input columns, then then you'll get a 9-dimensional vector as the output column. +For example, if you have 2 vector type columns each of which has 3 dimensions as input columns, then you'll get a 9-dimensional vector as the output column. **Examples** diff --git a/docs/mllib-statistics.md b/docs/mllib-statistics.md index 12797bd8688e1..430c0690457e8 100644 --- a/docs/mllib-statistics.md +++ b/docs/mllib-statistics.md @@ -354,7 +354,7 @@ v = u.map(lambda x: 1.0 + 2.0 * x) useful for visualizing empirical probability distributions without requiring assumptions about the particular distribution that the observed samples are drawn from. It computes an estimate of the probability density function of a random variables, evaluated at a given set of points. It achieves -this estimate by expressing the PDF of the empirical distribution at a particular point as the the +this estimate by expressing the PDF of the empirical distribution at a particular point as the mean of PDFs of normal distributions centered around each of the samples.
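To make the kernel density estimation description above concrete, here is a short illustrative Scala sketch using MLlib's KernelDensity; it assumes an existing SparkContext named `sc`, and the sample values and bandwidth are made up.

```scala
import org.apache.spark.mllib.stat.KernelDensity
import org.apache.spark.rdd.RDD

// Observed samples drawn from some unknown distribution (values are illustrative).
val samples: RDD[Double] = sc.parallelize(Seq(1.0, 1.5, 2.0, 2.2, 3.7))

// Estimate the PDF as the mean of Gaussian kernels centered on each sample,
// with bandwidth 0.5, and evaluate it at a few query points.
val kd = new KernelDensity()
  .setSample(samples)
  .setBandwidth(0.5)
val densities: Array[Double] = kd.estimate(Array(1.0, 2.0, 3.0))
```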
diff --git a/docs/monitoring.md b/docs/monitoring.md index 7a1de52668f1a..e918174e2b57c 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -169,6 +169,28 @@ The history server can be configured as follows: If disabled, no access control checks are made. + + spark.history.ui.admin.acls + empty + + Comma separated list of users/administrators that have view access to all the Spark applications in + history server. By default only the users permitted to view the application at run-time could + access the related application history, with this, configured users/administrators could also + have the permission to access it. + Putting a "*" in the list means any user can have the privilege of admin. + + + + spark.history.ui.admin.acls.groups + empty + + Comma separated list of groups that have view access to all the Spark applications in + history server. By default only the groups permitted to view the application at run-time could + access the related application history, with this, configured groups could also + have the permission to access it. + Putting a "*" in the list means any group can have the privilege of admin. + + spark.history.fs.cleaner.enabled false @@ -276,7 +298,7 @@ can be identified by their `[attempt-id]`. In the API listed below, when running /applications/[app-id]/stages/[stage-id]/[stage-attempt-id] - Details for the given stage attempt + Details for the given stage attempt. /applications/[app-id]/stages/[stage-id]/[stage-attempt-id]/taskSummary @@ -321,6 +343,34 @@ can be identified by their `[attempt-id]`. In the API listed below, when running /applications/[base-app-id]/[attempt-id]/logs Download the event logs for a specific application attempt as a zip file. + + /applications/[app-id]/streaming/statistics + Statistics for the streaming context. + + + /applications/[app-id]/streaming/receivers + A list of all streaming receivers. + + + /applications/[app-id]/streaming/receivers/[stream-id] + Details of the given receiver. + + + /applications/[app-id]/streaming/batches + A list of all retained batches. + + + /applications/[app-id]/streaming/batches/[batch-id] + Details of the given batch. + + + /applications/[app-id]/streaming/batches/[batch-id]/operations + A list of all output operations of the given batch. + + + /applications/[app-id]/streaming/batches/[batch-id]/operations/[outputOp-id] + Details of the given operation and given batch. + The number of jobs and stages which can retrieved is constrained by the same retention diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 4cd21aef91235..f4c89e58fa431 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -522,14 +522,11 @@ Hive metastore. Persistent tables will still exist even after your Spark program long as you maintain your connection to the same metastore. A DataFrame for a persistent table can be created by calling the `table` method on a `SparkSession` with the name of the table. -By default `saveAsTable` will create a "managed table", meaning that the location of the data will -be controlled by the metastore. Managed tables will also have their data deleted automatically -when a table is dropped. - -Currently, `saveAsTable` does not expose an API supporting the creation of an "external table" from a `DataFrame`. -However, this functionality can be achieved by providing a `path` option to the `DataFrameWriter` with `path` as the key -and location of the external table as its value (a string) when saving the table with `saveAsTable`. 
When an External table -is dropped only its metadata is removed. +For file-based data sources, e.g. text, parquet, json, etc., you can specify a custom table path via the +`path` option, e.g. `df.write.option("path", "/some/path").saveAsTable("t")`. When the table is dropped, +the custom table path will not be removed and the table data is still there. If no custom table path is +specified, Spark will write data to a default table path under the warehouse directory. When the table is +dropped, the default table path will be removed too. Starting from Spark 2.1, persistent datasource tables have per-partition metadata stored in the Hive metastore. This brings several benefits: @@ -954,6 +951,53 @@ adds support for finding tables in the MetaStore and writing queries using HiveQ
+### Specifying storage format for Hive tables
+
+When you create a Hive table, you need to define how this table should read/write data from/to the file system,
+i.e. the "input format" and "output format". You also need to define how this table should deserialize the data
+to rows, or serialize rows to data, i.e. the "serde". The following options can be used to specify the storage
+format ("serde", "input format", "output format"), e.g. `CREATE TABLE src(id int) USING hive OPTIONS(fileFormat 'parquet')`.
+By default, we will read the table files as plain text. Note that the Hive storage handler is not supported yet when
+creating a table; you can create a table using a storage handler on the Hive side, and use Spark SQL to read it.
+
+<table class="table">
+  <tr><th>Property Name</th><th>Meaning</th></tr>
+  <tr>
+    <td><code>fileFormat</code></td>
+    <td>
+      A fileFormat is a kind of package of storage format specifications, including "serde", "input format" and
+      "output format". Currently we support 6 fileFormats: 'sequencefile', 'rcfile', 'orc', 'parquet', 'textfile' and 'avro'.
+    </td>
+  </tr>
+  <tr>
+    <td><code>inputFormat, outputFormat</code></td>
+    <td>
+      These 2 options specify the name of a corresponding `InputFormat` and `OutputFormat` class as a string literal,
+      e.g. `org.apache.hadoop.hive.ql.io.orc.OrcInputFormat`. These 2 options must appear in a pair, and you cannot
+      specify them if you have already specified the `fileFormat` option.
+    </td>
+  </tr>
+  <tr>
+    <td><code>serde</code></td>
+    <td>
+      This option specifies the name of a serde class. When the `fileFormat` option is specified, do not specify this option
+      if the given `fileFormat` already includes the serde information. Currently "sequencefile", "textfile" and "rcfile"
+      don't include the serde information and you can use this option with these 3 fileFormats.
+    </td>
+  </tr>
+  <tr>
+    <td><code>fieldDelim, escapeDelim, collectionDelim, mapkeyDelim, lineDelim</code></td>
+    <td>
+      These options can only be used with "textfile" fileFormat. They define how to read delimited files into rows.
+    </td>
+  </tr>
+</table>
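As a brief illustration of these options (not part of the patch), the following hedged Scala sketch assumes a Hive-enabled SparkSession; the table names are invented, and the OPTIONS syntax follows the `fileFormat 'parquet'` example quoted above.

```scala
import org.apache.spark.sql.SparkSession

// Assumes Hive support is available in the session.
val spark = SparkSession.builder().enableHiveSupport().getOrCreate()

// Hive table whose storage is fully described by the 'parquet' fileFormat package.
spark.sql("CREATE TABLE src(id int) USING hive OPTIONS(fileFormat 'parquet')")

// 'textfile' does not bundle a serde, so delimiter options such as fieldDelim apply.
spark.sql(
  "CREATE TABLE src_csv(id int, name string) USING hive " +
    "OPTIONS(fileFormat 'textfile', fieldDelim ',')")
```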
+ +All other properties defined with `OPTIONS` will be regarded as Hive serde properties. + ### Interacting with Different Versions of Hive Metastore One of the most important pieces of Spark SQL's Hive support is interaction with Hive metastore, @@ -1369,6 +1413,14 @@ options. - Dataset and DataFrame API `explode` has been deprecated, alternatively, use `functions.explode()` with `select` or `flatMap` - Dataset and DataFrame API `registerTempTable` has been deprecated and replaced by `createOrReplaceTempView` + - Changes to `CREATE TABLE ... LOCATION` behavior for Hive tables. + - From Spark 2.0, `CREATE TABLE ... LOCATION` is equivalent to `CREATE EXTERNAL TABLE ... LOCATION` + in order to prevent accidental dropping the existing data in the user-provided locations. + That means, a Hive table created in Spark SQL with the user-specified location is always a Hive external table. + Dropping external tables will not remove the data. Users are not allowed to specify the location for Hive managed tables. + Note that this is different from the Hive behavior. + - As a result, `DROP TABLE` statements on those tables will not remove the data. + ## Upgrading From Spark SQL 1.5 to 1.6 - From Spark 1.6, by default the Thrift server runs in multi-session mode. Which means each JDBC/ODBC diff --git a/docs/structured-streaming-kafka-integration.md b/docs/structured-streaming-kafka-integration.md index 2458bb5ffa298..9b82e8e7449a1 100644 --- a/docs/structured-streaming-kafka-integration.md +++ b/docs/structured-streaming-kafka-integration.md @@ -244,7 +244,7 @@ Note that the following Kafka params cannot be set and the Kafka source will thr - **group.id**: Kafka source will create a unique group id for each query automatically. - **auto.offset.reset**: Set the source option `startingOffsets` to specify where to start instead. Structured Streaming manages which offsets are consumed internally, rather - than rely on the kafka Consumer to do it. This will ensure that no data is missed when when new + than rely on the kafka Consumer to do it. This will ensure that no data is missed when new topics/partitions are dynamically subscribed. Note that `startingOffsets` only applies when a new Streaming query is started, and that resuming will always pick up from where the query left off. - **key.deserializer**: Keys are always deserialized as byte arrays with ByteArrayDeserializer. Use diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 799f636505b3e..52dbbc830af64 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -374,7 +374,7 @@ The "Output" is defined as what gets written out to the external storage. The ou - *Append Mode* - Only the new rows appended in the Result Table since the last trigger will be written to the external storage. This is applicable only on the queries where existing rows in the Result Table are not expected to change. - - *Update Mode* - Only the rows that were updated in the Result Table since the last trigger will be written to the external storage (not available yet in Spark 2.0). Note that this is different from the Complete Mode in that this mode does not output the rows that are not changed. + - *Update Mode* - Only the rows that were updated in the Result Table since the last trigger will be written to the external storage (available since Spark 2.1.1). 
Note that this is different from the Complete Mode in that this mode only outputs the rows that have changed since the last trigger. Note that each mode is applicable on certain types of queries. This is discussed in detail [later](#output-modes). @@ -424,7 +424,7 @@ Streaming DataFrames can be created through the `DataStreamReader` interface ([Scala](api/scala/index.html#org.apache.spark.sql.streaming.DataStreamReader)/[Java](api/java/org/apache/spark/sql/streaming/DataStreamReader.html)/[Python](api/python/pyspark.sql.html#pyspark.sql.streaming.DataStreamReader) docs) returned by `SparkSession.readStream()`. Similar to the read interface for creating static DataFrame, you can specify the details of the source – data format, schema, options, etc. -#### Data Sources +#### Input Sources In Spark 2.0, there are a few built-in sources. - **File source** - Reads files written in a directory as a stream of data. Supported file formats are text, csv, json, parquet. See the docs of the DataStreamReader interface for a more up-to-date list, and supported options for each file format. Note that the files must be atomically placed in the given directory, which in most file systems, can be achieved by file move operations. @@ -433,6 +433,54 @@ In Spark 2.0, there are a few built-in sources. - **Socket source (for testing)** - Reads UTF8 text data from a socket connection. The listening server socket is at the driver. Note that this should be used only for testing as this does not provide end-to-end fault-tolerance guarantees. +Some sources are not fault-tolerant because they do not guarantee that data can be replayed using +checkpointed offsets after a failure. See the earlier section on +[fault-tolerance semantics](#fault-tolerance-semantics). +Here are the details of all the sources in Spark. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
SourceOptionsFault-tolerantNotes
File source + path: path to the input directory, and common to all file formats. +

+ For file-format-specific options, see the related methods in DataStreamReader + (Scala/Java/Python). + E.g. for "parquet" format options see DataStreamReader.parquet()
YesSupports glob paths, but does not support multiple comma-separated paths/globs.
Socket Source + host: host to connect to, must be specified
+ port: port to connect to, must be specified +
No
Kafka Source + See the Kafka Integration Guide. + Yes
+ Here are some examples.
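As a rough sketch of how the sources above are wired up (not part of this patch; the host, port, input directory, and schema below are placeholders chosen for illustration):

{% highlight scala %}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{StringType, StructType}

val spark = SparkSession.builder.appName("InputSourcesSketch").getOrCreate()

// Socket source (for testing only); host and port are placeholder values.
val socketDF = spark.readStream
  .format("socket")
  .option("host", "localhost")
  .option("port", 9999)
  .load()

// File source; the schema and directory path are assumptions for this sketch.
// File sources require the schema to be specified up front.
val userSchema = new StructType().add("name", StringType).add("value", StringType)
val csvDF = spark.readStream
  .schema(userSchema)
  .csv("/path/to/input/dir")
{% endhighlight %}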
@@ -680,7 +728,7 @@ windowedCounts = words.groupBy( ### Handling Late Data and Watermarking Now consider what happens if one of the events arrives late to the application. -For example, say, a word generated at 12:04 (i.e. event time) could be received received by +For example, say, a word generated at 12:04 (i.e. event time) could be received by the application at 12:11. The application should use the time 12:04 instead of 12:11 to update the older counts for the window `12:00 - 12:10`. This occurs naturally in our window-based grouping – Structured Streaming can maintain the intermediate state @@ -753,34 +801,47 @@ windowedCounts = words In this example, we are defining the watermark of the query on the value of the column "timestamp", and also defining "10 minutes" as the threshold of how late is the data allowed to be. If this query -is run in Append output mode (discussed later in [Output Modes](#output-modes) section), -the engine will track the current event time from the column "timestamp" and wait for additional -"10 minutes" in event time before finalizing the windowed counts and adding them to the Result Table. +is run in Update output mode (discussed later in [Output Modes](#output-modes) section), +the engine will keep updating counts of a window in the Result Table until the window is older +than the watermark, which lags behind the current event time in column "timestamp" by 10 minutes. Here is an illustration. -![Watermarking in Append Mode](img/structured-streaming-watermark.png) +![Watermarking in Update Mode](img/structured-streaming-watermark-update-mode.png) As shown in the illustration, the maximum event time tracked by the engine is the *blue dashed line*, and the watermark set as `(max event time - '10 mins')` at the beginning of every trigger is the red line For example, when the engine observes the data `(12:14, dog)`, it sets the watermark for the next trigger as `12:04`. -For the window `12:00 - 12:10`, the partial counts are maintained as internal state while the system -is waiting for late data. After the system finds data (i.e. `(12:21, owl)`) such that the -watermark exceeds 12:10, the partial count is finalized and appended to the table. This count will -not change any further as all "too-late" data older than 12:10 will be ignored. - -Note that in Append output mode, the system has to wait for "late threshold" time -before it can output the aggregation of a window. This may not be ideal if data can be very late, -(say 1 day) and you like to have partial counts without waiting for a day. In future, we will add -Update output mode which would allows every update to aggregates to be written to sink every trigger. +This watermark lets the engine maintain intermediate state for an additional 10 minutes to allow late +data to be counted. For example, the data `(12:09, cat)` is out of order and late, and it falls in +windows `12:05 - 12:15` and `12:10 - 12:20`. Since it is still ahead of the watermark `12:04` in +the trigger, the engine still maintains the intermediate counts as state and correctly updates the +counts of the related windows. However, when the watermark is updated to 12:11, the intermediate +state for window `(12:00 - 12:10)` is cleared, and all subsequent data (e.g. `(12:04, donkey)`) +is considered "too late" and therefore ignored. Note that after every trigger, +the updated counts (i.e. purple rows) are written to sink as the trigger output, as dictated by +the Update mode. + +Some sinks (e.g.
files) may not support fine-grained updates that Update Mode requires. To work +with them, we also support Append Mode, where only the *final counts* are written to sink. +This is illustrated below. + +![Watermarking in Append Mode](img/structured-streaming-watermark-append-mode.png) + +Similar to the Update Mode earlier, the engine maintains intermediate counts for each window. +However, the partial counts are not updated to the Result Table and not written to sink. The engine +waits for "10 mins" for late data to be counted, +then drops intermediate state of a window < watermark, and appends the final +counts to the Result Table/sink. For example, the final counts of window `12:00 - 12:10` are +appended to the Result Table only after the watermark is updated to `12:11`. **Conditions for watermarking to clean aggregation state** It is important to note that the following conditions must be satisfied for the watermarking to -clean the state in aggregation queries *(as of Spark 2.1, subject to change in the future)*. +clean the state in aggregation queries *(as of Spark 2.1.1, subject to change in the future)*. -- **Output mode must be Append.** Complete mode requires all aggregate data to be preserved, and hence -cannot use watermarking to drop intermediate state. See the [Output Modes](#output-modes) section -for detailed explanation of the semantics of each output mode. +- **Output mode must be Append or Update.** Complete mode requires all aggregate data to be preserved, +and hence cannot use watermarking to drop intermediate state. See the [Output Modes](#output-modes) +section for detailed explanation of the semantics of each output mode. - The aggregation must have either the event-time column, or a `window` on the event-time column. @@ -835,8 +896,9 @@ streamingDf.join(staticDf, "type", "right_join") # right outer join with a stat
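Tying the watermarking discussion above back to code, here is a minimal sketch (not part of this patch); `words` is assumed to be a streaming DataFrame with a `timestamp` column carrying event time and a `word` column:

{% highlight scala %}
import org.apache.spark.sql.functions.{col, window}

// Tolerate data that is up to 10 minutes late, then aggregate into sliding windows.
val windowedCounts = words
  .withWatermark("timestamp", "10 minutes")
  .groupBy(window(col("timestamp"), "10 minutes", "5 minutes"), col("word"))
  .count()

// Update mode emits only the rows whose counts changed in each trigger;
// Append mode would instead emit each window once, after the watermark passes it.
val query = windowedCounts.writeStream
  .outputMode("update")
  .format("console")
  .start()
{% endhighlight %}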
### Unsupported Operations -However, note that all of the operations applicable on static DataFrames/Datasets are not supported in streaming DataFrames/Datasets yet. While some of these unsupported operations will be supported in future releases of Spark, there are others which are fundamentally hard to implement on streaming data efficiently. For example, sorting is not supported on the input streaming Dataset, as it requires keeping track of all the data received in the stream. This is therefore fundamentally hard to execute efficiently. As of Spark 2.0, some of the unsupported operations are as follows - +There are a few DataFrame/Dataset operations that are not supported with streaming DataFrames/Datasets. +Some of them are as follows. + - Multiple streaming aggregations (i.e. a chain of aggregations on a streaming DF) are not yet supported on streaming Datasets. - Limit and take first N rows are not supported on streaming Datasets. @@ -863,7 +925,12 @@ In addition, there are some Dataset methods that will not work on streaming Data - `show()` - Instead use the console sink (see next section). -If you try any of these operations, you will see an AnalysisException like "operation XYZ is not supported with streaming DataFrames/Datasets". +If you try any of these operations, you will see an `AnalysisException` like "operation XYZ is not supported with streaming DataFrames/Datasets". +While some of them may be supported in future releases of Spark, +there are others which are fundamentally hard to implement on streaming data efficiently. +For example, sorting on the input stream is not supported, as it requires keeping +track of all the data received in the stream. This is therefore fundamentally hard to execute +efficiently. ## Starting Streaming Queries Once you have defined the final result DataFrame/Dataset, all that is left is for you start the streaming computation. To do that, you have to use the `DataStreamWriter` @@ -894,11 +961,11 @@ fault-tolerant sink). For example, queries with only `select`, - **Complete mode** - The whole Result Table will be outputted to the sink after every trigger. This is supported for aggregation queries. -- **Update mode** - (*not available in Spark 2.1*) Only the rows in the Result Table that were +- **Update mode** - (*Available since Spark 2.1.1*) Only the rows in the Result Table that were updated since the last trigger will be outputted to the sink. More information to be added in future releases. -Different types of streaming queries support different output modes. +Different types of streaming queries support different output modes. Here is the compatibility matrix. @@ -909,36 +976,38 @@ Here is the compatibility matrix. - - - + + - - - + + + - - + + @@ -954,49 +1023,94 @@ There are a few types of built-in output sinks. - **File sink** - Stores the output to a directory. +{% highlight scala %} +writeStream + .format("parquet") // can be "orc", "json", "csv", etc. + .option("path", "path/to/destination/dir") + .start() +{% endhighlight %} + - **Foreach sink** - Runs arbitrary computation on the records in the output. See later in the section for more details. +{% highlight scala %} +writeStream + .foreach(...) + .start() +{% endhighlight %} + - **Console sink (for debugging)** - Prints the output to the console/stdout every time there is a trigger. Both, Append and Complete output modes, are supported. 
This should be used for debugging purposes on low data volumes as the entire output is collected and stored in the driver's memory after every trigger. -- **Memory sink (for debugging)** - The output is stored in memory as an in-memory table. Both, Append and Complete output modes, are supported. This should be used for debugging purposes on low data volumes as the entire output is collected and stored in the driver's memory after every trigger. +{% highlight scala %} +writeStream + .format("console") + .start() +{% endhighlight %} + +- **Memory sink (for debugging)** - The output is stored in memory as an in-memory table. +Both, Append and Complete output modes, are supported. This should be used for debugging purposes +on low data volumes as the entire output is collected and stored in the driver's memory. +Hence, use it with caution. + +{% highlight scala %} +writeStream + .format("memory") + .queryName("tableName") + .start() +{% endhighlight %} -Here is a table of all the sinks, and the corresponding settings. +Some sinks are not fault-tolerant because they do not guarantee persistence of the output and are +meant for debugging purposes only. See the earlier section on +[fault-tolerance semantics](#fault-tolerance-semantics). +Here are the details of all the sinks in Spark.
Notes

Queries without aggregation
Append - Complete mode note supported as it is infeasible to keep all data in the Result Table. + Queries without aggregationAppend + Complete mode not supported as it is infeasible to keep all data in the Result Table.
Queries with aggregationAggregation on event-time with watermarkAppend, CompleteQueries with aggregationAggregation on event-time with watermarkAppend, Update, Complete Append mode uses watermark to drop old aggregation state. But the output of a windowed aggregation is delayed the late threshold specified in `withWatermark()` as by the modes semantics, rows can be added to the Result Table only once after they are - finalized (i.e. after watermark is crossed). See - Late Data section for more details. + finalized (i.e. after watermark is crossed). See the + Late Data section for more details. +

+ Update mode uses watermark to drop old aggregation state.

Complete mode does not drop old aggregation state since by definition this mode preserves all data in the Result Table.
Other aggregationsCompleteOther aggregationsComplete, Update + Since no watermark is defined (it is only defined in the other category), + old aggregation state is not dropped. +

Append mode is not supported as aggregates can update, thus violating the semantics of this mode. -

- Complete mode does drop not old aggregation state since by definition this mode - preserves all data in the Result Table.
- + - + - - + + - - + + - - - + + + @@ -1007,7 +1121,7 @@ Here is a table of all the sinks, and the corresponding settings.
Sink Supported Output ModesUsageOptions Fault-tolerant Notes
File Sink Append
writeStream
.format("parquet")
.option(
"checkpointLocation",
"path/to/checkpoint/dir")
.option(
"path",
"path/to/destination/dir")
.start()
+ path: path to the output directory, must be specified. + maxFilesPerTrigger: maximum number of new files to be considered in every trigger (default: no max) +
+ latestFirst: whether to process the latest new files first, useful when there is a large backlog of files (default: false) +

+ For file-format-specific options, see the related methods in DataFrameWriter + (Scala/Java/Python). + E.g. for "parquet" format options see DataFrameWriter.parquet() +
Yes Supports writes to partitioned tables. Partitioning by time may be useful.
Foreach SinkAll modes
writeStream
.foreach(...)
.start()
Append, Update, CompleteNone Depends on ForeachWriter implementation More details in the next section
Console SinkAppend, Complete
writeStream
.format("console")
.start()
Append, Update, Complete + numRows: Number of rows to print every trigger (default: 20) +
+ truncate: Whether to truncate the output if too long (default: true) +
No
Memory Sink Append, Complete
writeStream
.format("memory")
.queryName("table")
.start()
NoSaves the output data as a table, for interactive querying. Table name is the query name.NoneNo. But in Complete Mode, a restarted query will recreate the full table.Table name is the query name.
-Finally, you have to call `start()` to actually start the execution of the query. This returns a StreamingQuery object which is a handle to the continuously running execution. You can use this object to manage the query, which we will discuss in the next subsection. For now, let’s understand all this with a few examples. +Note that you have to call `start()` to actually start the execution of the query. This returns a StreamingQuery object which is a handle to the continuously running execution. You can use this object to manage the query, which we will discuss in the next subsection. For now, let’s understand all this with a few examples.
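For completeness, a minimal end-to-end sketch of starting and managing a query (not part of this patch; `spark` is the active SparkSession and `aggregates` is an assumed aggregated streaming DataFrame from the earlier examples):

{% highlight scala %}
// Write the running aggregates to an in-memory table (debugging only) and start the query.
val query = aggregates.writeStream
  .outputMode("complete")
  .format("memory")
  .queryName("aggTable")   // the in-memory table is named after the query
  .start()

// The in-memory table can be queried interactively while the stream runs.
spark.sql("SELECT * FROM aggTable").show()

// Block until the query is stopped or fails.
query.awaitTermination()
{% endhighlight %}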
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java index 9041244279079..0e5d00565b71a 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java @@ -52,7 +52,7 @@ public static void main(String[] args) { double ll = model.logLikelihood(dataset); double lp = model.logPerplexity(dataset); System.out.println("The lower bound on the log likelihood of the entire corpus: " + ll); - System.out.println("The upper bound bound on perplexity: " + lp); + System.out.println("The upper bound on perplexity: " + lp); // Describe topics. Dataset topics = model.describeTopics(3); diff --git a/examples/src/main/java/org/apache/spark/examples/sql/hive/JavaSparkHiveExample.java b/examples/src/main/java/org/apache/spark/examples/sql/hive/JavaSparkHiveExample.java index 052153c9e9736..8d06d38cf2c6b 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/hive/JavaSparkHiveExample.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/hive/JavaSparkHiveExample.java @@ -64,7 +64,7 @@ public static void main(String[] args) { .enableHiveSupport() .getOrCreate(); - spark.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)"); + spark.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING) USING hive"); spark.sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src"); // Queries are expressed in HiveQL diff --git a/examples/src/main/python/ml/lda_example.py b/examples/src/main/python/ml/lda_example.py index 2dc1742ff7a0b..a8b346f72cd6f 100644 --- a/examples/src/main/python/ml/lda_example.py +++ b/examples/src/main/python/ml/lda_example.py @@ -46,7 +46,7 @@ ll = model.logLikelihood(dataset) lp = model.logPerplexity(dataset) print("The lower bound on the log likelihood of the entire corpus: " + str(ll)) - print("The upper bound bound on perplexity: " + str(lp)) + print("The upper bound on perplexity: " + str(lp)) # Describe topics. topics = model.describeTopics(3) diff --git a/examples/src/main/python/sql/hive.py b/examples/src/main/python/sql/hive.py index ad83fe1cf14b5..ba01544a5bd2e 100644 --- a/examples/src/main/python/sql/hive.py +++ b/examples/src/main/python/sql/hive.py @@ -44,7 +44,7 @@ .getOrCreate() # spark is an existing SparkSession - spark.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") + spark.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING) USING hive") spark.sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") # Queries are expressed in HiveQL diff --git a/examples/src/main/r/RSparkSQLExample.R b/examples/src/main/r/RSparkSQLExample.R index 373a36dba14f0..e647f0e1e9f17 100644 --- a/examples/src/main/r/RSparkSQLExample.R +++ b/examples/src/main/r/RSparkSQLExample.R @@ -195,7 +195,7 @@ head(teenagers) # $example on:spark_hive$ # enableHiveSupport defaults to TRUE sparkR.session(enableHiveSupport = TRUE) -sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") +sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING) USING hive") sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") # Queries can be expressed in HiveQL. 
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala index 22b3b0e3ad9c1..4215d37cb59d5 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala @@ -50,7 +50,7 @@ object LDAExample { val ll = model.logLikelihood(dataset) val lp = model.logPerplexity(dataset) println(s"The lower bound on the log likelihood of the entire corpus: $ll") - println(s"The upper bound bound on perplexity: $lp") + println(s"The upper bound on perplexity: $lp") // Describe topics. val topics = model.describeTopics(3) diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/SparkHiveExample.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/SparkHiveExample.scala index ded18dacf1fe3..d29ed958fe8f4 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/SparkHiveExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/SparkHiveExample.scala @@ -50,7 +50,7 @@ object SparkHiveExample { import spark.implicits._ import spark.sql - sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") + sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING) USING hive") sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") // Queries are expressed in HiveQL diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala index 41f27e937662f..e5b63aa1a77ef 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala @@ -45,7 +45,7 @@ import org.apache.flume.sink.AbstractSink * the thread itself is blocked and a reference to it saved off. * * When the ack for that batch is received, - * the thread which created the transaction is is retrieved and it commits the transaction with the + * the thread which created the transaction is retrieved and it commits the transaction with the * channel from the same thread it was originally created in (since Flume transactions are * thread local). If a nack is received instead, the sink rolls back the transaction. If no ack * is received within the specified timeout, the transaction is rolled back too. If an ack comes diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index aa01238f91247..ff9965b854c63 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -212,7 +212,7 @@ private[kafka010] class KafkaSourceProvider extends StreamSourceProvider |Instead set the source option '$STARTING_OFFSETS_OPTION_KEY' to 'earliest' or 'latest' |to specify where to start. Structured Streaming manages which offsets are consumed |internally, rather than relying on the kafkaConsumer to do it. This will ensure that no - |data is missed when when new topics/partitions are dynamically subscribed. Note that + |data is missed when new topics/partitions are dynamically subscribed. 
Note that |'$STARTING_OFFSETS_OPTION_KEY' only applies when a new Streaming query is started, and |that resuming will always pick up from where the query left off. See the docs for more |details. diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala index a4d81a680979e..18a5a1509a33a 100644 --- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala @@ -129,7 +129,7 @@ abstract class KinesisBackedBlockRDDTests(aggregateTestData: Boolean) /** * Test the WriteAheadLogBackedRDD, by writing some partitions of the data to block manager - * and the rest to a write ahead log, and then reading reading it all back using the RDD. + * and the rest to a write ahead log, and then reading it all back using the RDD. * It can also test if the partitions that were read from the log were again stored in * block manager. * diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index e58b30d66588c..cbd508ae79acd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -308,6 +308,10 @@ final class OneVsRest @Since("1.4.0") ( override def fit(dataset: Dataset[_]): OneVsRestModel = { transformSchema(dataset.schema) + val instr = Instrumentation.create(this, dataset) + instr.logParams(labelCol, featuresCol, predictionCol) + instr.logNamedValue("classifier", $(classifier).getClass.getCanonicalName) + // determine number of classes either from metadata if provided, or via computation. val labelSchema = dataset.schema($(labelCol)) val computeNumClasses: () => Int = () => { @@ -316,6 +320,7 @@ final class OneVsRest @Since("1.4.0") ( maxLabelIndex.toInt + 1 } val numClasses = MetadataUtils.getNumClasses(labelSchema).fold(computeNumClasses())(identity) + instr.logNumClasses(numClasses) val multiclassLabeled = dataset.select($(labelCol), $(featuresCol)) @@ -339,6 +344,7 @@ final class OneVsRest @Since("1.4.0") ( paramMap.put(classifier.predictionCol -> getPredictionCol) classifier.fit(trainingDataset, paramMap) }.toArray[ClassificationModel[_, _]] + instr.logNumFeatures(models.head.numFeatures) if (handlePersistence) { multiclassLabeled.unpersist() @@ -352,6 +358,7 @@ final class OneVsRest @Since("1.4.0") ( case attr: Attribute => attr } val model = new OneVsRestModel(uid, labelAttribute.toMetadata(), models).setParent(this) + instr.logSuccess(model) copyValues(model) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index e8b6c4105055d..ac45117529efa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -512,7 +512,7 @@ abstract class LDAModel private[ml] ( } /** - * Calculate an upper bound bound on perplexity. (Lower is better.) + * Calculate an upper bound on perplexity. (Lower is better.) * See Equation (16) in the Online LDA paper (Hoffman et al., 2010). 
* * WARNING: If this model is an instance of [[DistributedLDAModel]] (produced when [[optimizer]] diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 3e852a683672a..995780bf64f63 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -457,10 +457,12 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel] .map { row => Rating(row.getInt(0), row.getInt(1), row.getFloat(2)) } - val instrLog = Instrumentation.create(this, ratings) - instrLog.logParams(rank, numUserBlocks, numItemBlocks, implicitPrefs, alpha, userCol, + + val instr = Instrumentation.create(this, ratings) + instr.logParams(rank, numUserBlocks, numItemBlocks, implicitPrefs, alpha, userCol, itemCol, ratingCol, predictionCol, maxIter, regParam, nonnegative, checkpointInterval, seed, intermediateStorageLevel, finalStorageLevel) + val (userFactors, itemFactors) = ALS.train(ratings, rank = $(rank), numUserBlocks = $(numUserBlocks), numItemBlocks = $(numItemBlocks), maxIter = $(maxIter), regParam = $(regParam), implicitPrefs = $(implicitPrefs), @@ -471,7 +473,7 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel] val userDF = userFactors.toDF("id", "features") val itemDF = itemFactors.toDF("id", "features") val model = new ALSModel(uid, $(rank), userDF, itemDF).setParent(this) - instrLog.logSuccess(model) + instr.logSuccess(model) copyValues(model) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index b5aa7ce4e115b..89bbc1556c2d8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -45,9 +45,12 @@ private[libsvm] class LibSVMOutputWriter( private val writer = CodecStreams.createOutputStreamWriter(context, new Path(path)) - override def write(row: Row): Unit = { - val label = row.get(0) - val vector = row.get(1).asInstanceOf[Vector] + // This `asInstanceOf` is safe because it's guaranteed by `LibSVMFileFormat.verifySchema` + private val udt = dataSchema(1).dataType.asInstanceOf[VectorUDT] + + override def write(row: InternalRow): Unit = { + val label = row.getDouble(0) + val vector = udt.deserialize(row.getStruct(1, udt.sqlType.length)) writer.write(label.toString) vector.foreachActive { case (i, v) => writer.write(s" ${i + 1}:$v") @@ -115,6 +118,7 @@ private[libsvm] class LibSVMFileFormat extends TextBasedFileFormat with DataSour job: Job, options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { + verifySchema(dataSchema) new OutputWriterFactory { override def newInstance( path: String, diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 85191d46fd360..2012d6ca8b5ea 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -101,6 +101,11 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) val epm = $(estimatorParamMaps) val numModels = epm.length val metrics = new Array[Double](epm.length) + + val instr = Instrumentation.create(this, dataset) + instr.logParams(numFolds, seed) + 
logTuningParams(instr) + val splits = MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed)) splits.zipWithIndex.foreach { case ((training, validation), splitIndex) => val trainingDataset = sparkSession.createDataFrame(training, schema).cache() @@ -127,6 +132,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) logInfo(s"Best set of parameters:\n${epm(bestIndex)}") logInfo(s"Best cross-validation metric: $bestMetric.") val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]] + instr.logSuccess(bestModel) copyValues(new CrossValidatorModel(uid, bestModel, metrics).setParent(this)) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 5d1a39f7c16db..db7c9d13d301a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -97,6 +97,10 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St val numModels = epm.length val metrics = new Array[Double](epm.length) + val instr = Instrumentation.create(this, dataset) + instr.logParams(trainRatio, seed) + logTuningParams(instr) + val Array(trainingDataset, validationDataset) = dataset.randomSplit(Array($(trainRatio), 1 - $(trainRatio)), $(seed)) trainingDataset.cache() @@ -123,6 +127,7 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St logInfo(s"Best set of parameters:\n${epm(bestIndex)}") logInfo(s"Best train validation split metric: $bestMetric.") val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]] + instr.logSuccess(bestModel) copyValues(new TrainValidationSplitModel(uid, bestModel, metrics).setParent(this)) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala index 26fd73814d70a..d55eb14d03456 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala @@ -26,7 +26,7 @@ import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params} import org.apache.spark.ml.param.shared.HasSeed -import org.apache.spark.ml.util.{DefaultParamsReader, DefaultParamsWriter, MetaAlgorithmReadWrite, MLWritable} +import org.apache.spark.ml.util._ import org.apache.spark.ml.util.DefaultParamsReader.Metadata import org.apache.spark.sql.types.StructType @@ -76,6 +76,15 @@ private[ml] trait ValidatorParams extends HasSeed with Params { } est.copy(firstEstimatorParamMap).transformSchema(schema) } + + /** + * Instrumentation logging for tuning params including the inner estimator and evaluator info. 
+ */ + protected def logTuningParams(instrumentation: Instrumentation[_]): Unit = { + instrumentation.logNamedValue("estimator", $(estimator).getClass.getCanonicalName) + instrumentation.logNamedValue("evaluator", $(evaluator).getClass.getCanonicalName) + instrumentation.logNamedValue("estimatorParamMapsLength", $(estimatorParamMaps).length) + } } private[ml] object ValidatorParams { diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala index feba56d18b05f..7c46f45c59717 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala @@ -87,8 +87,12 @@ private[spark] class Instrumentation[E <: Estimator[_]] private ( /** * Logs the value with customized name field. */ - def logNamedValue(name: String, num: Long): Unit = { - log(compact(render(name -> num))) + def logNamedValue(name: String, value: String): Unit = { + log(compact(render(name -> value))) + } + + def logNamedValue(name: String, value: Long): Unit = { + log(compact(render(name -> value))) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala index 5cbfbff3e4a62..4d6520d0b2ee0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala @@ -54,7 +54,7 @@ private[python] class Word2VecModelWrapper(model: Word2VecModel) { } /** - * Finds words similar to the the vector representation of a word without + * Finds words similar to the vector representation of a word without * filtering results. * @param vector a vector * @param num number of synonyms to find diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 25ffd8561fe37..ae336982092d6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -245,7 +245,7 @@ class LocalLDAModel private[spark] ( } /** - * Calculate an upper bound bound on perplexity. (Lower is better.) + * Calculate an upper bound on perplexity. (Lower is better.) * See Equation (16) in original Online LDA paper. 
* * @param documents test corpus to use for calculating perplexity @@ -745,12 +745,12 @@ class DistributedLDAModel private[clustering] ( val N_wk = vertex._2 val smoothed_N_wk: TopicCounts = N_wk + (eta - 1.0) val phi_wk: TopicCounts = smoothed_N_wk :/ smoothed_N_k - (eta - 1.0) * sum(phi_wk.map(math.log)) + sumPrior + (eta - 1.0) * sum(phi_wk.map(math.log)) } else { val N_kj = vertex._2 val smoothed_N_kj: TopicCounts = N_kj + (alpha - 1.0) val theta_kj: TopicCounts = normalize(smoothed_N_kj, 1.0) - (alpha - 1.0) * sum(theta_kj.map(math.log)) + sumPrior + (alpha - 1.0) * sum(theta_kj.map(math.log)) } } graph.vertices.aggregate(0.0)(seqOp, _ + _) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala index 46deb545af3f0..f44c8fe351459 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala @@ -29,7 +29,7 @@ import org.apache.spark.streaming.dstream.DStream /** * :: DeveloperApi :: * StreamingLinearAlgorithm implements methods for continuously - * training a generalized linear model model on streaming data, + * training a generalized linear model on streaming data, * and using it for prediction on (possibly different) streaming data. * * This class takes as type parameters a GeneralizedLinearModel, diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 1308210417b1b..c14dcbd55287d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -360,7 +360,7 @@ class LogisticRegressionSuite } // force it to use raw2prediction - model.setProbabilityCol("") + model.setRawPredictionCol("rawPrediction").setProbabilityCol("") val resultsUsingRaw2Predict = model.transform(smallMultinomialDataset).select("prediction").as[Double].collect() resultsUsingRaw2Predict.zip(results.select("prediction").as[Double].collect()).foreach { @@ -368,7 +368,7 @@ class LogisticRegressionSuite } // force it to use probability2prediction - model.setRawPredictionCol("") + model.setRawPredictionCol("").setProbabilityCol("probability") val resultsUsingProb2Predict = model.transform(smallMultinomialDataset).select("prediction").as[Double].collect() resultsUsingProb2Predict.zip(results.select("prediction").as[Double].collect()).foreach { @@ -414,7 +414,7 @@ class LogisticRegressionSuite } // force it to use raw2prediction - model.setProbabilityCol("") + model.setRawPredictionCol("rawPrediction").setProbabilityCol("") val resultsUsingRaw2Predict = model.transform(smallBinaryDataset).select("prediction").as[Double].collect() resultsUsingRaw2Predict.zip(results.select("prediction").as[Double].collect()).foreach { @@ -422,7 +422,7 @@ class LogisticRegressionSuite } // force it to use probability2prediction - model.setRawPredictionCol("") + model.setRawPredictionCol("").setProbabilityCol("probability") val resultsUsingProb2Predict = model.transform(smallBinaryDataset).select("prediction").as[Double].collect() resultsUsingProb2Predict.zip(results.select("prediction").as[Double].collect()).foreach { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala index 3f39deddf20b4..9aa11fbdbe868 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -260,6 +260,14 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead Vectors.dense(model2.topicsMatrix.toArray) absTol 1e-6) assert(Vectors.dense(model.getDocConcentration) ~== Vectors.dense(model2.getDocConcentration) absTol 1e-6) + val logPrior = model.asInstanceOf[DistributedLDAModel].logPrior + val logPrior2 = model2.asInstanceOf[DistributedLDAModel].logPrior + val trainingLogLikelihood = + model.asInstanceOf[DistributedLDAModel].trainingLogLikelihood + val trainingLogLikelihood2 = + model2.asInstanceOf[DistributedLDAModel].trainingLogLikelihood + assert(logPrior ~== logPrior2 absTol 1e-6) + assert(trainingLogLikelihood ~== trainingLogLikelihood2 absTol 1e-6) } val lda = new LDA() testEstimatorAndModelReadWrite(lda, dataset, diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala index 2517de59fed63..c701f3823842c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala @@ -17,12 +17,12 @@ package org.apache.spark.ml.source.libsvm -import java.io.File +import java.io.{File, IOException} import java.nio.charset.StandardCharsets import com.google.common.io.Files -import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{Row, SaveMode} @@ -100,7 +100,7 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { test("write libsvm data failed due to invalid schema") { val df = spark.read.format("text").load(path) - intercept[SparkException] { + intercept[IOException] { df.write.format("libsvm").save(path + "_2") } } diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index 35d0aefa04a8e..54510e0beae7c 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -699,7 +699,7 @@ def logLikelihood(self, dataset): @since("2.0.0") def logPerplexity(self, dataset): """ - Calculate an upper bound bound on perplexity. (Lower is better.) + Calculate an upper bound on perplexity. (Lower is better.) See Equation (16) in the Online LDA paper (Hoffman et al., 2010). WARNING: If this model is an instance of :py:class:`DistributedLDAModel` (produced when diff --git a/python/pyspark/ml/linalg/__init__.py b/python/pyspark/ml/linalg/__init__.py index 1705c156ce4c8..b765343251965 100644 --- a/python/pyspark/ml/linalg/__init__.py +++ b/python/pyspark/ml/linalg/__init__.py @@ -481,7 +481,7 @@ def __init__(self, size, *args): >>> SparseVector(4, {1:1.0, 6:2.0}) Traceback (most recent call last): ... - AssertionError: Index 6 is out of the the size of vector with size=4 + AssertionError: Index 6 is out of the size of vector with size=4 >>> SparseVector(4, {-1:1.0}) Traceback (most recent call last): ... 
@@ -521,7 +521,7 @@ def __init__(self, size, *args): if self.indices.size > 0: assert np.max(self.indices) < self.size, \ - "Index %d is out of the the size of vector with size=%d" \ + "Index %d is out of the size of vector with size=%d" \ % (np.max(self.indices), self.size) assert np.min(self.indices) >= 0, \ "Contains negative index %d" % (np.min(self.indices)) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index b9d90384e3e2c..10e42d0f9d322 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -730,8 +730,9 @@ def join(self, other, on=None, how=None): a join expression (Column), or a list of Columns. If `on` is a string or a list of strings indicating the name of the join column(s), the column(s) must exist on both sides, and this performs an equi-join. - :param how: str, default 'inner'. - One of `inner`, `outer`, `left_outer`, `right_outer`, `leftsemi`. + :param how: str, default ``inner``. Must be one of: ``inner``, ``cross``, ``outer``, + ``full``, ``full_outer``, ``left``, ``left_outer``, ``right``, ``right_outer``, + ``left_semi``, and ``left_anti``. The following performs a full outer join between ``df1`` and ``df2``. diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index d8abafcde3846..7fe901a4fbbaf 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -157,17 +157,21 @@ def _(): 'dense_rank': """returns the rank of rows within a window partition, without any gaps. - The difference between rank and denseRank is that denseRank leaves no gaps in ranking - sequence when there are ties. That is, if you were ranking a competition using denseRank + The difference between rank and dense_rank is that dense_rank leaves no gaps in ranking + sequence when there are ties. That is, if you were ranking a competition using dense_rank and had three people tie for second place, you would say that all three were in second - place and that the next person came in third.""", + place and that the next person came in third. Rank would give me sequential numbers, making + the person that came in third place (after the ties) would register as coming in fifth. + + This is equivalent to the DENSE_RANK function in SQL.""", 'rank': """returns the rank of rows within a window partition. - The difference between rank and denseRank is that denseRank leaves no gaps in ranking - sequence when there are ties. That is, if you were ranking a competition using denseRank + The difference between rank and dense_rank is that dense_rank leaves no gaps in ranking + sequence when there are ties. That is, if you were ranking a competition using dense_rank and had three people tie for second place, you would say that all three were in second - place and that the next person came in third. + place and that the next person came in third. Rank would give me sequential numbers, making + the person that came in third place (after the ties) would register as coming in fifth. This is equivalent to the RANK function in SQL.""", 'cume_dist': diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 4a023123b6eca..26b54a7fb3709 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1389,7 +1389,9 @@ class Row(tuple): ``key in row`` will search through row keys. Row can be used to create a row object by using named arguments, - the fields will be sorted by names. + the fields will be sorted by names. 
It is not allowed to omit + a named argument to represent the value is None or missing. This should be + explicitly set to None in this case. >>> row = Row(name="Alice", age=11) >>> row diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index 2a85ec01bc92a..7bc6a59ad3b26 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -95,7 +95,7 @@ def install_exception_handler(): original = py4j.protocol.get_return_value # The original `get_return_value` is not patched, it's idempotent. patched = capture_sql_exception(original) - # only patch the one used in in py4j.java_gateway (call Java API) + # only patch the one used in py4j.java_gateway (call Java API) py4j.java_gateway.get_return_value = patched diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala index ee9149ce0208b..b252539782580 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala @@ -29,7 +29,8 @@ import org.apache.spark.{SparkConf, SparkEnv, TaskState} import org.apache.spark.TaskState import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging -import org.apache.spark.scheduler.cluster.mesos.{MesosSchedulerUtils, MesosTaskLaunchData} +import org.apache.spark.scheduler.TaskDescription +import org.apache.spark.scheduler.cluster.mesos.MesosSchedulerUtils import org.apache.spark.util.Utils private[spark] class MesosExecutorBackend @@ -84,14 +85,12 @@ private[spark] class MesosExecutorBackend } override def launchTask(d: ExecutorDriver, taskInfo: TaskInfo) { - val taskId = taskInfo.getTaskId.getValue.toLong - val taskData = MesosTaskLaunchData.fromByteString(taskInfo.getData) + val taskDescription = TaskDescription.decode(taskInfo.getData.asReadOnlyByteBuffer()) if (executor == null) { logError("Received launchTask but executor was null") } else { SparkHadoopUtil.get.runAsSparkUser { () => - executor.launchTask(this, taskId = taskId, attemptNumber = taskData.attemptNumber, - taskInfo.getName, taskData.serializedTask) + executor.launchTask(this, taskDescription) } } } diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala index 779ffb52299cc..7e561916a71e2 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala @@ -351,7 +351,7 @@ private[spark] class MesosFineGrainedSchedulerBackend( .setExecutor(executorInfo) .setName(task.name) .addAllResources(cpuResources.asJava) - .setData(MesosTaskLaunchData(task.serializedTask, task.attemptNumber).toByteString) + .setData(ByteString.copyFrom(TaskDescription.encode(task))) .build() (taskInfo, finalResources.asJava) } diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchData.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchData.scala deleted file mode 100644 index 8370b61145e45..0000000000000 --- 
a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchData.scala +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.scheduler.cluster.mesos - -import java.nio.ByteBuffer - -import org.apache.mesos.protobuf.ByteString - -import org.apache.spark.internal.Logging - -/** - * Wrapper for serializing the data sent when launching Mesos tasks. - */ -private[spark] case class MesosTaskLaunchData( - serializedTask: ByteBuffer, - attemptNumber: Int) extends Logging { - - def toByteString: ByteString = { - val dataBuffer = ByteBuffer.allocate(4 + serializedTask.limit) - dataBuffer.putInt(attemptNumber) - dataBuffer.put(serializedTask) - dataBuffer.rewind - logDebug(s"ByteBuffer size: [${dataBuffer.remaining}]") - ByteString.copyFrom(dataBuffer) - } -} - -private[spark] object MesosTaskLaunchData extends Logging { - def fromByteString(byteString: ByteString): MesosTaskLaunchData = { - val byteBuffer = byteString.asReadOnlyByteBuffer() - logDebug(s"ByteBuffer size: [${byteBuffer.remaining}]") - val attemptNumber = byteBuffer.getInt // updates the position by 4 bytes - val serializedTask = byteBuffer.slice() // subsequence starting at the current position - MesosTaskLaunchData(serializedTask, attemptNumber) - } -} diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala index 1d7a86f4b0904..4ee85b91830a9 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala @@ -21,6 +21,7 @@ import java.nio.ByteBuffer import java.util.Arrays import java.util.Collection import java.util.Collections +import java.util.Properties import scala.collection.JavaConverters._ import scala.collection.mutable @@ -246,7 +247,16 @@ class MesosFineGrainedSchedulerBackendSuite mesosOffers.get(2).getHostname, (minCpu - backend.mesosExecutorCores).toInt ) - val taskDesc = new TaskDescription(1L, 0, "s1", "n1", 0, ByteBuffer.wrap(new Array[Byte](0))) + val taskDesc = new TaskDescription( + taskId = 1L, + attemptNumber = 0, + executorId = "s1", + name = "n1", + index = 0, + addedFiles = mutable.Map.empty[String, Long], + addedJars = mutable.Map.empty[String, Long], + properties = new Properties(), + ByteBuffer.wrap(new Array[Byte](0))) when(taskScheduler.resourceOffers(expectedWorkerOffers)).thenReturn(Seq(Seq(taskDesc))) when(taskScheduler.CPUS_PER_TASK).thenReturn(2) @@ -345,7 
+355,16 @@ class MesosFineGrainedSchedulerBackendSuite 2 // Deducting 1 for executor ) - val taskDesc = new TaskDescription(1L, 0, "s1", "n1", 0, ByteBuffer.wrap(new Array[Byte](0))) + val taskDesc = new TaskDescription( + taskId = 1L, + attemptNumber = 0, + executorId = "s1", + name = "n1", + index = 0, + addedFiles = mutable.Map.empty[String, Long], + addedJars = mutable.Map.empty[String, Long], + properties = new Properties(), + ByteBuffer.wrap(new Array[Byte](0))) when(taskScheduler.resourceOffers(expectedWorkerOffers)).thenReturn(Seq(Seq(taskDesc))) when(taskScheduler.CPUS_PER_TASK).thenReturn(1) diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchDataSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchDataSuite.scala deleted file mode 100644 index 5a81bb335fdb7..0000000000000 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchDataSuite.scala +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.scheduler.cluster.mesos - -import java.nio.ByteBuffer - -import org.apache.spark.SparkFunSuite - -class MesosTaskLaunchDataSuite extends SparkFunSuite { - test("serialize and deserialize data must be same") { - val serializedTask = ByteBuffer.allocate(40) - (Range(100, 110).map(serializedTask.putInt(_))) - serializedTask.rewind - val attemptNumber = 100 - val byteString = MesosTaskLaunchData(serializedTask, attemptNumber).toByteString - serializedTask.rewind - val mesosTaskLaunchData = MesosTaskLaunchData.fromByteString(byteString) - assert(mesosTaskLaunchData.attemptNumber == attemptNumber) - assert(mesosTaskLaunchData.serializedTask.equals(serializedTask)) - } -} diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala index 8772e26f4314d..fb2d61f621c8a 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala @@ -32,7 +32,7 @@ private[yarn] case class ContainerLocalityPreferences(nodes: Array[String], rack /** * This strategy is calculating the optimal locality preferences of YARN containers by considering - * the node ratio of pending tasks, number of required cores/containers and and locality of current + * the node ratio of pending tasks, number of required cores/containers and locality of current * existing and pending allocated containers. 
The target of this algorithm is to maximize the number * of tasks that would run locally. * diff --git a/sbin/start-history-server.sh b/sbin/start-history-server.sh index 6851d99b7e8f4..38a43b98c3992 100755 --- a/sbin/start-history-server.sh +++ b/sbin/start-history-server.sh @@ -31,4 +31,4 @@ fi . "${SPARK_HOME}/sbin/spark-config.sh" . "${SPARK_HOME}/bin/load-spark-env.sh" -exec "${SPARK_HOME}/sbin"/spark-daemon.sh start org.apache.spark.deploy.history.HistoryServer 1 $@ +exec "${SPARK_HOME}/sbin"/spark-daemon.sh start org.apache.spark.deploy.history.HistoryServer 1 "$@" diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index a34087cb6cd4a..3222a9cdc2c4e 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -69,16 +69,18 @@ statement | ALTER DATABASE identifier SET DBPROPERTIES tablePropertyList #setDatabaseProperties | DROP DATABASE (IF EXISTS)? identifier (RESTRICT | CASCADE)? #dropDatabase | createTableHeader ('(' colTypeList ')')? tableProvider - (OPTIONS tablePropertyList)? + (OPTIONS options=tablePropertyList)? (PARTITIONED BY partitionColumnNames=identifierList)? - bucketSpec? (AS? query)? #createTableUsing + bucketSpec? locationSpec? + (COMMENT comment=STRING)? + (AS? query)? #createTable | createTableHeader ('(' columns=colTypeList ')')? - (COMMENT STRING)? + (COMMENT comment=STRING)? (PARTITIONED BY '(' partitionColumns=colTypeList ')')? bucketSpec? skewSpec? rowFormat? createFileFormat? locationSpec? (TBLPROPERTIES tablePropertyList)? - (AS? query)? #createTable + (AS? query)? #createHiveTable | CREATE TABLE (IF NOT EXISTS)? target=tableIdentifier LIKE source=tableIdentifier #createTableLike | ANALYZE TABLE tableIdentifier partitionSpec? COMPUTE STATISTICS diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index d205547698c5b..86de90984ca00 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -196,7 +196,7 @@ public void setNullAt(int i) { assertIndexIsValid(i); BitSetMethods.set(baseObject, baseOffset, i); // To preserve row equality, zero out the value when setting the column to null. - // Since this row does does not currently support updates to variable-length values, we don't + // Since this row does not currently support updates to variable-length values, we don't // have to worry about zeroing out that data. 
Platform.putLong(baseObject, getFieldOffset(i), 0); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index ad218cf88de23..7f7dd51aa2650 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -312,12 +312,50 @@ object ScalaReflection extends ScalaReflection { "array", ObjectType(classOf[Array[Any]])) - StaticInvoke( + val wrappedArray = StaticInvoke( scala.collection.mutable.WrappedArray.getClass, ObjectType(classOf[Seq[_]]), "make", array :: Nil) + if (localTypeOf[scala.collection.mutable.WrappedArray[_]] <:< t.erasure) { + wrappedArray + } else { + // Convert to another type using `to` + val cls = mirror.runtimeClass(t.typeSymbol.asClass) + import scala.collection.generic.CanBuildFrom + import scala.reflect.ClassTag + + // Some canBuildFrom methods take an implicit ClassTag parameter + val cbfParams = try { + cls.getDeclaredMethod("canBuildFrom", classOf[ClassTag[_]]) + StaticInvoke( + ClassTag.getClass, + ObjectType(classOf[ClassTag[_]]), + "apply", + StaticInvoke( + cls, + ObjectType(classOf[Class[_]]), + "getClass" + ) :: Nil + ) :: Nil + } catch { + case _: NoSuchMethodException => Nil + } + + Invoke( + wrappedArray, + "to", + ObjectType(cls), + StaticInvoke( + cls, + ObjectType(classOf[CanBuildFrom[_, _, _]]), + "canBuildFrom", + cbfParams + ) :: Nil + ) + } + case t if t <:< localTypeOf[Map[_, _]] => // TODO: add walked type path for map val TypeRef(_, _, Seq(keyType, valueType)) = t diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala index 96a11e352ec50..1504a522798b0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala @@ -33,6 +33,8 @@ class AttributeMap[A](baseMap: Map[ExprId, (Attribute, A)]) override def get(k: Attribute): Option[A] = baseMap.get(k.exprId).map(_._2) + override def contains(k: Attribute): Boolean = get(k).isDefined + override def + [B1 >: A](kv: (Attribute, B1)): Map[Attribute, B1] = baseMap.values.toMap + kv override def iterator: Iterator[(Attribute, A)] = baseMap.valuesIterator diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 13115f4728950..07d294b108548 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -516,7 +516,7 @@ case class CumeDist() extends RowNumberLike with SizeBasedWindowFunction { * into the number of buckets); both variables are based on the size of the current partition. * During the calculation process the function keeps track of the current row number, the current * bucket number, and the row number at which the bucket will change (bucketThreshold). 
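The ScalaReflection change above first materializes deserialized array data as a scala.collection.mutable.WrappedArray and, only when the requested static type is not a WrappedArray supertype, builds the target collection from it via `to`, looking up the companion object's canBuildFrom (and passing a ClassTag argument when that method declares one). A plain Scala 2.11 sketch of that conversion path, independent of Spark; the values are made up:

import scala.collection.immutable.Queue
import scala.collection.mutable.WrappedArray

object SeqConversionSketch {
  def main(args: Array[String]): Unit = {
    // Step 1: the deserializer produces a WrappedArray from the raw array data.
    val wrapped: WrappedArray[Int] = WrappedArray.make(Array(1, 2, 3))
    // Step 2: when the requested type is, say, List[Int] or Queue[Int], it is built
    // from the WrappedArray with `to`, driven by the companion's implicit CanBuildFrom.
    val asList: List[Int] = wrapped.to[List]
    val asQueue: Queue[Int] = wrapped.to[Queue]
    println(asList)   // List(1, 2, 3)
    println(asQueue)  // Queue(1, 2, 3)
  }
}

When the requested type is a WrappedArray supertype such as Seq, the conversion is skipped entirely, as the new ScalaReflectionSuite test further down verifies.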
When the - * current row number reaches bucket threshold, the bucket value is increased by one and the the + * current row number reaches bucket threshold, the bucket value is increased by one and the * threshold is increased by the bucket size (plus one extra if the current bucket is padded). * * This documentation has been based upon similar documentation for the Hive and Presto projects. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index d583fa31b0b28..9b52a9cc817c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.ProjectEstimation import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -53,6 +54,9 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend override def validConstraints: Set[Expression] = child.constraints.union(getAliasedConstraints(projectList)) + + override lazy val statistics: Statistics = + ProjectEstimation.estimate(this).getOrElse(super.statistics) } /** @@ -795,7 +799,7 @@ case object OneRowRelation extends LeafNode { /** * Computes [[Statistics]] for this plan. The default implementation assumes the output - * cardinality is the product of of all child plan's cardinality, i.e. applies in the case + * cardinality is the product of all child plan's cardinality, i.e. applies in the case * of cartesian joins. * * [[LeafNode]]s must override this. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala new file mode 100644 index 0000000000000..f099e32267461 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.plans.logical.statsEstimation + +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan} +import org.apache.spark.sql.types.StringType + + +object EstimationUtils { + + /** Check if each plan has rowCount in its statistics. */ + def rowCountsExist(plans: LogicalPlan*): Boolean = + plans.forall(_.statistics.rowCount.isDefined) + + /** Get column stats for output attributes. */ + def getOutputMap(inputMap: AttributeMap[ColumnStat], output: Seq[Attribute]) + : AttributeMap[ColumnStat] = { + AttributeMap(output.flatMap(a => inputMap.get(a).map(a -> _))) + } + + def getRowSize(attributes: Seq[Attribute], attrStats: AttributeMap[ColumnStat]): Long = { + // We assign a generic overhead for a Row object, the actual overhead is different for different + // Row format. + 8 + attributes.map { attr => + if (attrStats.contains(attr)) { + attr.dataType match { + case StringType => + // UTF8String: base + offset + numBytes + attrStats(attr).avgLen + 8 + 4 + case _ => + attrStats(attr).avgLen + } + } else { + attr.dataType.defaultSize + } + }.sum + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala new file mode 100644 index 0000000000000..6d63b09fd41b8 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.plans.logical.statsEstimation + +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap} +import org.apache.spark.sql.catalyst.plans.logical.{Project, Statistics} + +object ProjectEstimation { + import EstimationUtils._ + + def estimate(project: Project): Option[Statistics] = { + if (rowCountsExist(project.child)) { + val childStats = project.child.statistics + val inputAttrStats = childStats.attributeStats + // Match alias with its child's column stat + val aliasStats = project.expressions.collect { + case alias @ Alias(attr: Attribute, _) if inputAttrStats.contains(attr) => + alias.toAttribute -> inputAttrStats(attr) + } + val outputAttrStats = + getOutputMap(AttributeMap(inputAttrStats.toSeq ++ aliasStats), project.output) + Some(childStats.copy( + sizeInBytes = childStats.rowCount.get * getRowSize(project.output, outputAttrStats), + attributeStats = outputAttrStats)) + } else { + None + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala index a7f7a8a66382a..29e49a58375b9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala @@ -27,6 +27,8 @@ class CaseInsensitiveMap(map: Map[String, String]) extends Map[String, String] override def get(k: String): Option[String] = baseMap.get(k.toLowerCase) + override def contains(k: String): Boolean = baseMap.contains(k.toLowerCase) + override def + [B1 >: String](kv: (String, B1)): Map[String, B1] = baseMap + kv.copy(_1 = kv._1.toLowerCase) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 235ca8d2633a1..a96a3b7af2912 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -142,7 +142,7 @@ object DateTimeUtils { } /** - * Returns the number of days since epoch from from java.sql.Date. + * Returns the number of days since epoch from java.sql.Date. */ def fromJavaDate(date: Date): SQLDate = { millisToDays(date.getTime) @@ -503,7 +503,7 @@ object DateTimeUtils { } /** - * Calculates the year and and the number of the day in the year for the given + * Calculates the year and the number of the day in the year for the given * number of days. The given days is the number of days since 1.1.1970. 
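To make the sizing arithmetic of the new EstimationUtils and ProjectEstimation above concrete: getRowSize charges a flat 8 bytes of row overhead, then avgLen per column with statistics, adding 8 + 4 bytes for StringType columns to cover the UTF8String base and offset/numBytes fields, and falls back to dataType.defaultSize for columns without statistics. A small hand computation under made-up column statistics:

object RowSizeSketch extends App {
  // Hypothetical stats: one IntegerType column (avgLen = 4) and one StringType column (avgLen = 10).
  // getRowSize = 8 (generic row overhead) + 4 + (10 + 8 + 4 for UTF8String base/offset/numBytes)
  val rowSize = 8 + 4 + (10 + 8 + 4)   // 34 bytes per output row
  // ProjectEstimation keeps the child's row count and rescales only the size estimate.
  val rowCount = 1000L                  // assumed child rowCount
  val sizeInBytes = rowCount * rowSize
  println(sizeInBytes)                  // 34000
}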
* * The calculation uses the fact that the period 1.1.2001 until 31.12.2400 is diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 43b6afd9ad896..650a35398f3e8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -291,6 +291,37 @@ class ScalaReflectionSuite extends SparkFunSuite { .cls.isAssignableFrom(classOf[org.apache.spark.sql.catalyst.util.GenericArrayData])) } + test("SPARK 16792: Get correct deserializer for List[_]") { + val listDeserializer = deserializerFor[List[Int]] + assert(listDeserializer.dataType == ObjectType(classOf[List[_]])) + } + + test("serialize and deserialize arbitrary sequence types") { + import scala.collection.immutable.Queue + val queueSerializer = serializerFor[Queue[Int]](BoundReference( + 0, ObjectType(classOf[Queue[Int]]), nullable = false)) + assert(queueSerializer.dataType.head.dataType == + ArrayType(IntegerType, containsNull = false)) + val queueDeserializer = deserializerFor[Queue[Int]] + assert(queueDeserializer.dataType == ObjectType(classOf[Queue[_]])) + + import scala.collection.mutable.ArrayBuffer + val arrayBufferSerializer = serializerFor[ArrayBuffer[Int]](BoundReference( + 0, ObjectType(classOf[ArrayBuffer[Int]]), nullable = false)) + assert(arrayBufferSerializer.dataType.head.dataType == + ArrayType(IntegerType, containsNull = false)) + val arrayBufferDeserializer = deserializerFor[ArrayBuffer[Int]] + assert(arrayBufferDeserializer.dataType == ObjectType(classOf[ArrayBuffer[_]])) + + // Check whether conversion is skipped when using WrappedArray[_] supertype + // (would otherwise needlessly add overhead) + import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke + val seqDeserializer = deserializerFor[Seq[Int]] + assert(seqDeserializer.asInstanceOf[StaticInvoke].staticObject == + scala.collection.mutable.WrappedArray.getClass) + assert(seqDeserializer.asInstanceOf[StaticInvoke].functionName == "make") + } + private val dataTypeForComplexData = dataTypeFor[ComplexData] private val typeOfComplexData = typeOf[ComplexData] diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala index 97cfb5f06dd73..273f95f91ee50 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala @@ -52,7 +52,7 @@ class AttributeSetSuite extends SparkFunSuite { assert((aSet ++ bSet).contains(aLower) === true) } - test("extracts all references references") { + test("extracts all references ") { val addSet = AttributeSet(Add(aUpper, Alias(bUpper, "test")()):: Nil) assert(addSet.contains(aUpper)) assert(addSet.contains(aLower)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala new file mode 100644 index 0000000000000..4a1bed84f84e8 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one 
or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.statsEstimation + +import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeMap, AttributeReference} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ +import org.apache.spark.sql.types.IntegerType + + +class ProjectEstimationSuite extends StatsEstimationTestBase { + + test("estimate project with alias") { + val ar1 = AttributeReference("key1", IntegerType)() + val ar2 = AttributeReference("key2", IntegerType)() + val colStat1 = ColumnStat(2, Some(1), Some(2), 0, 4, 4) + val colStat2 = ColumnStat(1, Some(10), Some(10), 0, 4, 4) + + val child = StatsTestPlan( + outputList = Seq(ar1, ar2), + stats = Statistics( + sizeInBytes = 2 * (4 + 4), + rowCount = Some(2), + attributeStats = AttributeMap(Seq(ar1 -> colStat1, ar2 -> colStat2)))) + + val project = Project(Seq(ar1, Alias(ar2, "abc")()), child) + val expectedColStats = Seq("key1" -> colStat1, "abc" -> colStat2) + val expectedAttrStats = toAttributeMap(expectedColStats, project) + // The number of rows won't change for project. + val expectedStats = Statistics( + sizeInBytes = 2 * getRowSize(project.output, expectedAttrStats), + rowCount = Some(2), + attributeStats = expectedAttrStats) + assert(project.statistics == expectedStats) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala new file mode 100644 index 0000000000000..fa5b290ecb17c --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.statsEstimation + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LeafNode, LogicalPlan, Statistics} + + +class StatsEstimationTestBase extends SparkFunSuite { + + /** Convert (column name, column stat) pairs to an AttributeMap based on plan output. */ + def toAttributeMap(colStats: Seq[(String, ColumnStat)], plan: LogicalPlan) + : AttributeMap[ColumnStat] = { + val nameToAttr: Map[String, Attribute] = plan.output.map(a => (a.name, a)).toMap + AttributeMap(colStats.map(kv => nameToAttr(kv._1) -> kv._2)) + } +} + +/** + * This class is used for unit-testing. It's a logical plan whose output and stats are passed in. + */ +protected case class StatsTestPlan(outputList: Seq[Attribute], stats: Statistics) extends LeafNode { + override def output: Seq[Attribute] = outputList + override lazy val statistics = stats +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 405f38ad49ef6..3127ebf67967c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -393,7 +393,9 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { // Drop the existing table catalog.dropTable(tableIdentWithDB, ignoreIfNotExists = true, purge = false) - createTable(tableIdent) + createTable(tableIdentWithDB) + // Refresh the cache of the table in the catalog. + catalog.refreshTable(tableIdentWithDB) case _ => createTable(tableIdent) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index c1cedd8541a83..1a7a5ba798077 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -44,7 +44,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection} import org.apache.spark.sql.catalyst.util.usePrettyExpression -import org.apache.spark.sql.execution.{FileRelation, LogicalRDD, QueryExecution, SQLExecution} +import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.command.{CreateViewCommand, ExplainCommand, GlobalTempView, LocalTempView} import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.python.EvaluatePython @@ -361,7 +361,7 @@ class Dataset[T] private[sql]( * method used to map columns depend on the type of `U`: * - When `U` is a class, fields for the class will be mapped to columns of the same name * (case sensitivity is determined by `spark.sql.caseSensitive`). - * - When `U` is a tuple, the columns will be be mapped by ordinal (i.e. the first column will + * - When `U` is a tuple, the columns will be mapped by ordinal (i.e. the first column will * be assigned to `_1`). * - When `U` is a primitive type (i.e. String, Int, etc), then the first column of the * `DataFrame` will be used. @@ -750,14 +750,18 @@ class Dataset[T] private[sql]( } /** - * Equi-join with another `DataFrame` using the given columns. + * Equi-join with another `DataFrame` using the given columns. A cross join with a predicate + * is specified as an inner join. 
If you would explicitly like to perform a cross join use the + * `crossJoin` method. * * Different from other join functions, the join columns will only appear once in the output, * i.e. similar to SQL's `JOIN USING` syntax. * * @param right Right side of the join operation. * @param usingColumns Names of the columns to join on. This columns must exist on both sides. - * @param joinType One of: `inner`, `outer`, `left_outer`, `right_outer`, `leftsemi`. + * @param joinType Type of join to perform. Default `inner`. Must be one of: + * `inner`, `cross`, `outer`, `full`, `full_outer`, `left`, `left_outer`, + * `right`, `right_outer`, `left_semi`, `left_anti`. * * @note If you perform a self-join using this function without aliasing the input * `DataFrame`s, you will NOT be able to reference any columns after the join, since @@ -812,7 +816,9 @@ class Dataset[T] private[sql]( * * @param right Right side of the join. * @param joinExprs Join expression. - * @param joinType One of: `inner`, `outer`, `left_outer`, `right_outer`, `leftsemi`. + * @param joinType Type of join to perform. Default `inner`. Must be one of: + * `inner`, `cross`, `outer`, `full`, `full_outer`, `left`, `left_outer`, + * `right`, `right_outer`, `left_semi`, `left_anti`. * * @group untypedrel * @since 2.0.0 @@ -889,7 +895,9 @@ class Dataset[T] private[sql]( * * @param other Right side of the join. * @param condition Join expression. - * @param joinType One of: `inner`, `outer`, `left_outer`, `right_outer`, `leftsemi`. + * @param joinType Type of join to perform. Default `inner`. Must be one of: + * `inner`, `cross`, `outer`, `full`, `full_outer`, `left`, `left_outer`, + * `right`, `right_outer`, `left_semi`, `left_anti`. * * @group typedrel * @since 1.6.0 @@ -2096,9 +2104,7 @@ class Dataset[T] private[sql]( * @group action * @since 1.6.0 */ - def head(n: Int): Array[T] = withTypedCallback("head", limit(n)) { df => - df.collect(needCallback = false) - } + def head(n: Int): Array[T] = withAction("head", limit(n).queryExecution)(collectFromPlan) /** * Returns the first row. @@ -2325,7 +2331,7 @@ class Dataset[T] private[sql]( def takeAsList(n: Int): java.util.List[T] = java.util.Arrays.asList(take(n) : _*) /** - * Returns an array that contains all of [[Row]]s in this Dataset. + * Returns an array that contains all rows in this Dataset. * * Running collect requires moving all the data into the application's driver process, and * doing so on a very large dataset can crash the driver process with OutOfMemoryError. @@ -2335,10 +2341,10 @@ class Dataset[T] private[sql]( * @group action * @since 1.6.0 */ - def collect(): Array[T] = collect(needCallback = true) + def collect(): Array[T] = withAction("collect", queryExecution)(collectFromPlan) /** - * Returns a Java list that contains all of [[Row]]s in this Dataset. + * Returns a Java list that contains all rows in this Dataset. * * Running collect requires moving all the data into the application's driver process, and * doing so on a very large dataset can crash the driver process with OutOfMemoryError. 
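The joinType documentation updated above now enumerates the accepted strings (inner, cross, outer/full/full_outer, left/left_outer, right/right_outer, left_semi, left_anti). A short usage sketch with made-up data and a local SparkSession, purely illustrative:

import org.apache.spark.sql.SparkSession

object JoinTypeSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("joins").getOrCreate()
    import spark.implicits._

    val left  = Seq((1, "a"), (2, "b")).toDF("id", "l")
    val right = Seq((2, "x"), (3, "y")).toDF("id", "r")

    // Equi-join on a shared column; the join column appears once in the output.
    left.join(right, Seq("id"), "left_outer").show()

    // left_anti keeps only the left rows that have no match on the right.
    left.join(right, Seq("id"), "left_anti").show()

    // A Cartesian product has to be requested explicitly.
    left.crossJoin(right).show()

    spark.stop()
  }
}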
@@ -2346,27 +2352,13 @@ class Dataset[T] private[sql]( * @group action * @since 1.6.0 */ - def collectAsList(): java.util.List[T] = withCallback("collectAsList", toDF()) { _ => - withNewExecutionId { - val values = queryExecution.executedPlan.executeCollect().map(boundEnc.fromRow) - java.util.Arrays.asList(values : _*) - } - } - - private def collect(needCallback: Boolean): Array[T] = { - def execute(): Array[T] = withNewExecutionId { - queryExecution.executedPlan.executeCollect().map(boundEnc.fromRow) - } - - if (needCallback) { - withCallback("collect", toDF())(_ => execute()) - } else { - execute() - } + def collectAsList(): java.util.List[T] = withAction("collectAsList", queryExecution) { plan => + val values = collectFromPlan(plan) + java.util.Arrays.asList(values : _*) } /** - * Return an iterator that contains all of [[Row]]s in this Dataset. + * Return an iterator that contains all rows in this Dataset. * * The iterator will consume as much memory as the largest partition in this Dataset. * @@ -2377,9 +2369,9 @@ class Dataset[T] private[sql]( * @group action * @since 2.0.0 */ - def toLocalIterator(): java.util.Iterator[T] = withCallback("toLocalIterator", toDF()) { _ => - withNewExecutionId { - queryExecution.executedPlan.executeToIterator().map(boundEnc.fromRow).asJava + def toLocalIterator(): java.util.Iterator[T] = { + withAction("toLocalIterator", queryExecution) { plan => + plan.executeToIterator().map(boundEnc.fromRow).asJava } } @@ -2388,8 +2380,8 @@ class Dataset[T] private[sql]( * @group action * @since 1.6.0 */ - def count(): Long = withCallback("count", groupBy().count()) { df => - df.collect(needCallback = false).head.getLong(0) + def count(): Long = withAction("count", groupBy().count().queryExecution) { plan => + plan.executeCollect().head.getLong(0) } /** @@ -2762,38 +2754,30 @@ class Dataset[T] private[sql]( * Wrap a Dataset action to track the QueryExecution and time cost, then report to the * user-registered callback functions. */ - private def withCallback[U](name: String, df: DataFrame)(action: DataFrame => U) = { + private def withAction[U](name: String, qe: QueryExecution)(action: SparkPlan => U) = { try { - df.queryExecution.executedPlan.foreach { plan => + qe.executedPlan.foreach { plan => plan.resetMetrics() } val start = System.nanoTime() - val result = action(df) + val result = SQLExecution.withNewExecutionId(sparkSession, qe) { + action(qe.executedPlan) + } val end = System.nanoTime() - sparkSession.listenerManager.onSuccess(name, df.queryExecution, end - start) + sparkSession.listenerManager.onSuccess(name, qe, end - start) result } catch { case e: Exception => - sparkSession.listenerManager.onFailure(name, df.queryExecution, e) + sparkSession.listenerManager.onFailure(name, qe, e) throw e } } - private def withTypedCallback[A, B](name: String, ds: Dataset[A])(action: Dataset[A] => B) = { - try { - ds.queryExecution.executedPlan.foreach { plan => - plan.resetMetrics() - } - val start = System.nanoTime() - val result = action(ds) - val end = System.nanoTime() - sparkSession.listenerManager.onSuccess(name, ds.queryExecution, end - start) - result - } catch { - case e: Exception => - sparkSession.listenerManager.onFailure(name, ds.queryExecution, e) - throw e - } + /** + * Collect all elements from a spark plan. 
+ */ + private def collectFromPlan(plan: SparkPlan): Array[T] = { + plan.executeCollect().map(boundEnc.fromRow) } private def sortInternal(global: Boolean, sortExprs: Seq[Column]): Dataset[T] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index 872a78b578d28..2caf723669f63 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder * @since 1.6.0 */ @InterfaceStability.Evolving -abstract class SQLImplicits { +abstract class SQLImplicits extends LowPrioritySQLImplicits { protected def _sqlContext: SQLContext @@ -45,9 +45,6 @@ abstract class SQLImplicits { } } - /** @since 1.6.0 */ - implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = Encoders.product[T] - // Primitives /** @since 1.6.0 */ @@ -112,33 +109,96 @@ abstract class SQLImplicits { // Seqs - /** @since 1.6.1 */ - implicit def newIntSeqEncoder: Encoder[Seq[Int]] = ExpressionEncoder() + /** + * @since 1.6.1 + * @deprecated use [[newIntSequenceEncoder]] + */ + def newIntSeqEncoder: Encoder[Seq[Int]] = ExpressionEncoder() - /** @since 1.6.1 */ - implicit def newLongSeqEncoder: Encoder[Seq[Long]] = ExpressionEncoder() + /** + * @since 1.6.1 + * @deprecated use [[newLongSequenceEncoder]] + */ + def newLongSeqEncoder: Encoder[Seq[Long]] = ExpressionEncoder() - /** @since 1.6.1 */ - implicit def newDoubleSeqEncoder: Encoder[Seq[Double]] = ExpressionEncoder() + /** + * @since 1.6.1 + * @deprecated use [[newDoubleSequenceEncoder]] + */ + def newDoubleSeqEncoder: Encoder[Seq[Double]] = ExpressionEncoder() - /** @since 1.6.1 */ - implicit def newFloatSeqEncoder: Encoder[Seq[Float]] = ExpressionEncoder() + /** + * @since 1.6.1 + * @deprecated use [[newFloatSequenceEncoder]] + */ + def newFloatSeqEncoder: Encoder[Seq[Float]] = ExpressionEncoder() - /** @since 1.6.1 */ - implicit def newByteSeqEncoder: Encoder[Seq[Byte]] = ExpressionEncoder() + /** + * @since 1.6.1 + * @deprecated use [[newByteSequenceEncoder]] + */ + def newByteSeqEncoder: Encoder[Seq[Byte]] = ExpressionEncoder() - /** @since 1.6.1 */ - implicit def newShortSeqEncoder: Encoder[Seq[Short]] = ExpressionEncoder() + /** + * @since 1.6.1 + * @deprecated use [[newShortSequenceEncoder]] + */ + def newShortSeqEncoder: Encoder[Seq[Short]] = ExpressionEncoder() - /** @since 1.6.1 */ - implicit def newBooleanSeqEncoder: Encoder[Seq[Boolean]] = ExpressionEncoder() + /** + * @since 1.6.1 + * @deprecated use [[newBooleanSequenceEncoder]] + */ + def newBooleanSeqEncoder: Encoder[Seq[Boolean]] = ExpressionEncoder() - /** @since 1.6.1 */ - implicit def newStringSeqEncoder: Encoder[Seq[String]] = ExpressionEncoder() + /** + * @since 1.6.1 + * @deprecated use [[newStringSequenceEncoder]] + */ + def newStringSeqEncoder: Encoder[Seq[String]] = ExpressionEncoder() - /** @since 1.6.1 */ + /** + * @since 1.6.1 + * @deprecated use [[newProductSequenceEncoder]] + */ implicit def newProductSeqEncoder[A <: Product : TypeTag]: Encoder[Seq[A]] = ExpressionEncoder() + /** @since 2.2.0 */ + implicit def newIntSequenceEncoder[T <: Seq[Int] : TypeTag]: Encoder[T] = + ExpressionEncoder() + + /** @since 2.2.0 */ + implicit def newLongSequenceEncoder[T <: Seq[Long] : TypeTag]: Encoder[T] = + ExpressionEncoder() + + /** @since 2.2.0 */ + implicit def newDoubleSequenceEncoder[T <: Seq[Double] : TypeTag]: Encoder[T] = + 
ExpressionEncoder() + + /** @since 2.2.0 */ + implicit def newFloatSequenceEncoder[T <: Seq[Float] : TypeTag]: Encoder[T] = + ExpressionEncoder() + + /** @since 2.2.0 */ + implicit def newByteSequenceEncoder[T <: Seq[Byte] : TypeTag]: Encoder[T] = + ExpressionEncoder() + + /** @since 2.2.0 */ + implicit def newShortSequenceEncoder[T <: Seq[Short] : TypeTag]: Encoder[T] = + ExpressionEncoder() + + /** @since 2.2.0 */ + implicit def newBooleanSequenceEncoder[T <: Seq[Boolean] : TypeTag]: Encoder[T] = + ExpressionEncoder() + + /** @since 2.2.0 */ + implicit def newStringSequenceEncoder[T <: Seq[String] : TypeTag]: Encoder[T] = + ExpressionEncoder() + + /** @since 2.2.0 */ + implicit def newProductSequenceEncoder[T <: Seq[Product] : TypeTag]: Encoder[T] = + ExpressionEncoder() + // Arrays /** @since 1.6.1 */ @@ -193,3 +253,16 @@ abstract class SQLImplicits { implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name) } + +/** + * Lower priority implicit methods for converting Scala objects into [[Dataset]]s. + * Conflicting implicits are placed here to disambiguate resolution. + * + * Reasons for including specific implicits: + * newProductEncoder - to disambiguate for [[List]]s which are both [[Seq]] and [[Product]] + */ +trait LowPrioritySQLImplicits { + /** @since 1.6.0 */ + implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = Encoders.product[T] + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index 526623a36d2a1..4ca1347008575 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -22,6 +22,7 @@ import java.util.concurrent.locks.ReentrantReadWriteLock import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions.SubqueryExpression import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.Dataset import org.apache.spark.sql.execution.columnar.InMemoryRelation @@ -131,12 +132,16 @@ class CacheManager extends Logging { /** Replaces segments of the given logical plan with cached versions where possible. */ def useCachedData(plan: LogicalPlan): LogicalPlan = { - plan transformDown { + val newPlan = plan transformDown { case currentFragment => lookupCachedData(currentFragment) .map(_.cachedRepresentation.withOutput(currentFragment.output)) .getOrElse(currentFragment) } + + newPlan transformAllExpressions { + case s: SubqueryExpression => s.withNewPlan(useCachedData(s.plan)) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 14a983e43b005..41768d451261a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -342,11 +342,11 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { } /** - * Create a data source table, returning a [[CreateTable]] logical plan. + * Create a table, returning a [[CreateTable]] logical plan. 
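The SQLImplicits rework above widens the sequence encoders from exactly Seq[X] to any T <: Seq[X], and moves newProductEncoder into the LowPrioritySQLImplicits trait so that types such as List, which are both Seq and Product, resolve to the sequence encoder without an implicit ambiguity. A sketch of what this enables, with made-up data and a local SparkSession:

import scala.collection.immutable.Queue
import org.apache.spark.sql.SparkSession

object SequenceEncoderSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("encoders").getOrCreate()
    import spark.implicits._

    // List[Int] is both a Seq and a Product; with the product encoder demoted to
    // lower priority this no longer triggers an ambiguous-implicit error.
    val lists = Seq(List(1, 2, 3), List(4, 5)).toDS()
    lists.show()

    // Other Seq subtypes, e.g. Queue[String], now pick up an encoder as well.
    val queues = Seq(Queue("a", "b"), Queue("c")).toDS()
    queues.show()

    spark.stop()
  }
}

The older Seq-specific encoder methods are retained above, marked deprecated, for source compatibility.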
* * Expected format: * {{{ - * CREATE [EXTERNAL] TABLE [IF NOT EXISTS] [db_name.]table_name + * CREATE [TEMPORARY] TABLE [IF NOT EXISTS] [db_name.]table_name * USING table_provider * [OPTIONS table_property_list] * [PARTITIONED BY (col_name, col_name, ...)] @@ -354,19 +354,18 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { * [SORTED BY (col_name [ASC|DESC], ...)] * INTO num_buckets BUCKETS * ] + * [LOCATION path] + * [COMMENT table_comment] * [AS select_statement]; * }}} */ - override def visitCreateTableUsing(ctx: CreateTableUsingContext): LogicalPlan = withOrigin(ctx) { + override def visitCreateTable(ctx: CreateTableContext): LogicalPlan = withOrigin(ctx) { val (table, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader) if (external) { operationNotAllowed("CREATE EXTERNAL TABLE ... USING", ctx) } - val options = Option(ctx.tablePropertyList).map(visitPropertyKeyValues).getOrElse(Map.empty) + val options = Option(ctx.options).map(visitPropertyKeyValues).getOrElse(Map.empty) val provider = ctx.tableProvider.qualifiedName.getText - if (provider.toLowerCase == DDLUtils.HIVE_PROVIDER) { - throw new AnalysisException("Cannot create hive serde table with CREATE TABLE USING") - } val schema = Option(ctx.colTypeList()).map(createSchema) val partitionColumnNames = Option(ctx.partitionColumnNames) @@ -374,10 +373,17 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { .getOrElse(Array.empty[String]) val bucketSpec = Option(ctx.bucketSpec()).map(visitBucketSpec) - // TODO: this may be wrong for non file-based data source like JDBC, which should be external - // even there is no `path` in options. We should consider allow the EXTERNAL keyword. + val location = Option(ctx.locationSpec).map(visitLocationSpec) val storage = DataSource.buildStorageFormatFromOptions(options) - val tableType = if (storage.locationUri.isDefined) { + + if (location.isDefined && storage.locationUri.isDefined) { + throw new ParseException( + "LOCATION and 'path' in OPTIONS are both used to indicate the custom table path, " + + "you can only specify one of them.", ctx) + } + val customLocation = storage.locationUri.orElse(location) + + val tableType = if (customLocation.isDefined) { CatalogTableType.EXTERNAL } else { CatalogTableType.MANAGED @@ -386,12 +392,12 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { val tableDesc = CatalogTable( identifier = table, tableType = tableType, - storage = storage, + storage = storage.copy(locationUri = customLocation), schema = schema.getOrElse(new StructType), provider = Some(provider), partitionColumnNames = partitionColumnNames, - bucketSpec = bucketSpec - ) + bucketSpec = bucketSpec, + comment = Option(ctx.comment).map(string)) // Determine the storage mode. val mode = if (ifNotExists) SaveMode.Ignore else SaveMode.ErrorIfExists @@ -1011,10 +1017,10 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { } /** - * Create a table, returning a [[CreateTable]] logical plan. + * Create a Hive serde table, returning a [[CreateTable]] logical plan. * - * This is not used to create datasource tables, which is handled through - * "CREATE TABLE ... USING ...". + * This is a legacy syntax for Hive compatibility, we recommend users to use the Spark SQL + * CREATE TABLE syntax to create Hive serde table, e.g. "CREATE TABLE ... USING hive ..." * * Note: several features are currently not supported - temporary tables, bucketing, * skewed columns and storage handlers (STORED BY). 
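With the grammar and visitCreateTable changes above, the data source CREATE TABLE syntax accepts LOCATION and COMMENT clauses directly; supplying both LOCATION and a 'path' entry in OPTIONS is rejected with a ParseException, and a custom location makes the table EXTERNAL. A sketch of the accepted shape; the table name, path, and options below are made up:

import org.apache.spark.sql.SparkSession

object CreateTableSyntaxSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("create-table").getOrCreate()
    // LOCATION marks the table as EXTERNAL; do not also put 'path' in OPTIONS.
    spark.sql("""
      CREATE TABLE IF NOT EXISTS events (id INT, name STRING)
      USING parquet
      OPTIONS (compression 'snappy')
      PARTITIONED BY (id)
      LOCATION '/tmp/events'
      COMMENT 'example data source table'
    """)
    spark.stop()
  }
}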
@@ -1032,7 +1038,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { * [AS select_statement]; * }}} */ - override def visitCreateTable(ctx: CreateTableContext): LogicalPlan = withOrigin(ctx) { + override def visitCreateHiveTable(ctx: CreateHiveTableContext): LogicalPlan = withOrigin(ctx) { val (name, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader) // TODO: implement temporary tables if (temp) { @@ -1046,7 +1052,6 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { if (ctx.bucketSpec != null) { operationNotAllowed("CREATE TABLE ... CLUSTERED BY", ctx) } - val comment = Option(ctx.STRING).map(string) val dataCols = Option(ctx.columns).map(visitColTypeList).getOrElse(Nil) val partitionCols = Option(ctx.partitionColumns).map(visitColTypeList).getOrElse(Nil) val properties = Option(ctx.tablePropertyList).map(visitPropertyKeyValues).getOrElse(Map.empty) @@ -1057,19 +1062,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { val schema = StructType(dataCols ++ partitionCols) // Storage format - val defaultStorage: CatalogStorageFormat = { - val defaultStorageType = conf.getConfString("hive.default.fileformat", "textfile") - val defaultHiveSerde = HiveSerDe.sourceToSerDe(defaultStorageType) - CatalogStorageFormat( - locationUri = None, - inputFormat = defaultHiveSerde.flatMap(_.inputFormat) - .orElse(Some("org.apache.hadoop.mapred.TextInputFormat")), - outputFormat = defaultHiveSerde.flatMap(_.outputFormat) - .orElse(Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")), - serde = defaultHiveSerde.flatMap(_.serde), - compressed = false, - properties = Map()) - } + val defaultStorage = HiveSerDe.getDefaultStorage(conf) validateRowFormatFileFormat(ctx.rowFormat, ctx.createFileFormat, ctx) val fileStorage = Option(ctx.createFileFormat).map(visitCreateFileFormat) .getOrElse(CatalogStorageFormat.empty) @@ -1104,7 +1097,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { provider = Some(DDLUtils.HIVE_PROVIDER), partitionColumnNames = partitionCols.map(_.name), properties = properties, - comment = comment) + comment = Option(ctx.comment).map(string)) val mode = if (ifNotExists) SaveMode.Ignore else SaveMode.ErrorIfExists diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 28808f8e3ee14..1257d1728c311 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -408,8 +408,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object DDLStrategy extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case CreateTable(tableDesc, mode, None) - if tableDesc.provider.get == DDLUtils.HIVE_PROVIDER => + case CreateTable(tableDesc, mode, None) if DDLUtils.isHiveTable(tableDesc) => val cmd = CreateTableCommand(tableDesc, ifNotExists = mode == SaveMode.Ignore) ExecutedCommandExec(cmd) :: Nil @@ -421,8 +420,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // CREATE TABLE ... AS SELECT ... 
for hive serde table is handled in hive module, by rule // `CreateTables` - case CreateTable(tableDesc, mode, Some(query)) - if tableDesc.provider.get != DDLUtils.HIVE_PROVIDER => + case CreateTable(tableDesc, mode, Some(query)) if DDLUtils.isDatasourceTable(tableDesc) => val cmd = CreateDataSourceTableAsSelectCommand( tableDesc, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index c64c7ad943d49..73b21533a26c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -138,7 +138,7 @@ case class CreateDataSourceTableAsSelectCommand( val tableIdentWithDB = table.identifier.copy(database = Some(db)) val tableName = tableIdentWithDB.unquotedString - val result = if (sessionState.catalog.tableExists(tableIdentWithDB)) { + if (sessionState.catalog.tableExists(tableIdentWithDB)) { assert(mode != SaveMode.Overwrite, s"Expect the table $tableName has been dropped when the save mode is Overwrite") @@ -150,14 +150,16 @@ case class CreateDataSourceTableAsSelectCommand( return Seq.empty } - saveDataIntoTable(sparkSession, table, table.storage.locationUri, query, mode) + saveDataIntoTable( + sparkSession, table, table.storage.locationUri, query, mode, tableExists = true) } else { val tableLocation = if (table.tableType == CatalogTableType.MANAGED) { Some(sessionState.catalog.defaultTablePath(table.identifier)) } else { table.storage.locationUri } - val result = saveDataIntoTable(sparkSession, table, tableLocation, query, mode) + val result = saveDataIntoTable( + sparkSession, table, tableLocation, query, mode, tableExists = false) val newTable = table.copy( storage = table.storage.copy(locationUri = tableLocation), // We will use the schema of resolved.relation as the schema of the table (instead of @@ -165,20 +167,17 @@ case class CreateDataSourceTableAsSelectCommand( // provider (for example, see org.apache.spark.sql.parquet.DefaultSource). schema = result.schema) sessionState.catalog.createTable(newTable, ignoreIfExists = false) - result - } - result match { - case fs: HadoopFsRelation if table.partitionColumnNames.nonEmpty && - sparkSession.sqlContext.conf.manageFilesourcePartitions => - // Need to recover partitions into the metastore so our saved data is visible. - sparkSession.sessionState.executePlan( - AlterTableRecoverPartitionsCommand(table.identifier)).toRdd - case _ => + result match { + case fs: HadoopFsRelation if table.partitionColumnNames.nonEmpty && + sparkSession.sqlContext.conf.manageFilesourcePartitions => + // Need to recover partitions into the metastore so our saved data is visible. + sparkSession.sessionState.executePlan( + AlterTableRecoverPartitionsCommand(table.identifier)).toRdd + case _ => + } } - // Refresh the cache of the table in the catalog. - sessionState.catalog.refreshTable(tableIdentWithDB) Seq.empty[Row] } @@ -187,7 +186,8 @@ case class CreateDataSourceTableAsSelectCommand( table: CatalogTable, tableLocation: Option[String], data: LogicalPlan, - mode: SaveMode): BaseRelation = { + mode: SaveMode, + tableExists: Boolean): BaseRelation = { // Create the relation based on the input logical plan: `data`. 
val pathOption = tableLocation.map("path" -> _) val dataSource = DataSource( @@ -196,7 +196,7 @@ case class CreateDataSourceTableAsSelectCommand( partitionColumns = table.partitionColumnNames, bucketSpec = table.bucketSpec, options = table.storage.properties ++ pathOption, - catalogTable = Some(table)) + catalogTable = if (tableExists) Some(table) else None) try { dataSource.write(mode, Dataset.ofRows(session, query)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 59a29e884739a..82cbb4aa47445 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -762,8 +762,12 @@ case class AlterTableSetLocationCommand( object DDLUtils { val HIVE_PROVIDER = "hive" + def isHiveTable(table: CatalogTable): Boolean = { + table.provider.isDefined && table.provider.get.toLowerCase == HIVE_PROVIDER + } + def isDatasourceTable(table: CatalogTable): Boolean = { - table.provider.isDefined && table.provider.get != HIVE_PROVIDER + table.provider.isDefined && table.provider.get.toLowerCase != HIVE_PROVIDER } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index ac3f0688bbd42..b7f3559b6559b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -466,13 +466,18 @@ case class DataSource( // SPARK-17230: Resolve the partition columns so InsertIntoHadoopFsRelationCommand does // not need to have the query as child, to avoid to analyze an optimized query, // because InsertIntoHadoopFsRelationCommand will be optimized first. - val columns = partitionColumns.map { name => + val partitionAttributes = partitionColumns.map { name => val plan = data.logicalPlan plan.resolve(name :: Nil, data.sparkSession.sessionState.analyzer.resolver).getOrElse { throw new AnalysisException( s"Unable to resolve $name given [${plan.output.map(_.name).mkString(", ")}]") }.asInstanceOf[Attribute] } + val fileIndex = catalogTable.map(_.identifier).map { tableIdent => + sparkSession.table(tableIdent).queryExecution.analyzed.collect { + case LogicalRelation(t: HadoopFsRelation, _, _) => t.location + }.head + } // For partitioned relation r, r.schema's column ordering can be different from the column // ordering of data.logicalPlan (partition columns are all moved after data column). This // will be adjusted within InsertIntoHadoopFsRelation. @@ -480,15 +485,14 @@ case class DataSource( InsertIntoHadoopFsRelationCommand( outputPath = outputPath, staticPartitions = Map.empty, - customPartitionLocations = Map.empty, - partitionColumns = columns, + partitionColumns = partitionAttributes, bucketSpec = bucketSpec, fileFormat = format, - refreshFunction = _ => Unit, // No existing table needs to be refreshed. options = options, query = data.logicalPlan, mode = mode, - catalogTable = catalogTable) + catalogTable = catalogTable, + fileIndex = fileIndex) sparkSession.sessionState.executePlan(plan).toRdd // Replace the schema with that of the DataFrame we just wrote out to avoid re-inferring it. 
copy(userSpecifiedSchema = Some(data.schema.asNullable)).resolveRelation() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 61f0d43f2470e..3d3db06eee0c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -197,91 +197,19 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] { val partitionSchema = actualQuery.resolve( t.partitionSchema, t.sparkSession.sessionState.analyzer.resolver) - val partitionsTrackedByCatalog = - t.sparkSession.sessionState.conf.manageFilesourcePartitions && - l.catalogTable.isDefined && l.catalogTable.get.partitionColumnNames.nonEmpty && - l.catalogTable.get.tracksPartitionsInCatalog - - var initialMatchingPartitions: Seq[TablePartitionSpec] = Nil - var customPartitionLocations: Map[TablePartitionSpec, String] = Map.empty - val staticPartitions = parts.filter(_._2.nonEmpty).map { case (k, v) => k -> v.get } - // When partitions are tracked by the catalog, compute all custom partition locations that - // may be relevant to the insertion job. - if (partitionsTrackedByCatalog) { - val matchingPartitions = t.sparkSession.sessionState.catalog.listPartitions( - l.catalogTable.get.identifier, Some(staticPartitions)) - initialMatchingPartitions = matchingPartitions.map(_.spec) - customPartitionLocations = getCustomPartitionLocations( - t.sparkSession, l.catalogTable.get, outputPath, matchingPartitions) - } - - // Callback for updating metastore partition metadata after the insertion job completes. - // TODO(ekl) consider moving this into InsertIntoHadoopFsRelationCommand - def refreshPartitionsCallback(updatedPartitions: Seq[TablePartitionSpec]): Unit = { - if (partitionsTrackedByCatalog) { - val newPartitions = updatedPartitions.toSet -- initialMatchingPartitions - if (newPartitions.nonEmpty) { - AlterTableAddPartitionCommand( - l.catalogTable.get.identifier, newPartitions.toSeq.map(p => (p, None)), - ifNotExists = true).run(t.sparkSession) - } - if (overwrite) { - val deletedPartitions = initialMatchingPartitions.toSet -- updatedPartitions - if (deletedPartitions.nonEmpty) { - AlterTableDropPartitionCommand( - l.catalogTable.get.identifier, deletedPartitions.toSeq, - ifExists = true, purge = false, - retainData = true /* already deleted */).run(t.sparkSession) - } - } - } - t.location.refresh() - } - - val insertCmd = InsertIntoHadoopFsRelationCommand( + InsertIntoHadoopFsRelationCommand( outputPath, staticPartitions, - customPartitionLocations, partitionSchema, t.bucketSpec, t.fileFormat, - refreshPartitionsCallback, t.options, actualQuery, mode, - table) - - insertCmd - } - - /** - * Given a set of input partitions, returns those that have locations that differ from the - * Hive default (e.g. /k1=v1/k2=v2). These partitions were manually assigned locations by - * the user. 
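The getCustomPartitionLocations helper removed here from DataSourceAnalysis (and re-added to InsertIntoHadoopFsRelationCommand further down) compares each catalog partition's location against the Hive-style default path under the table's output path and keeps only the partitions whose locations differ. A small illustration of that comparison; all paths and the partition spec are hypothetical:

object CustomPartitionLocationSketch extends App {
  val qualifiedOutputPath = "hdfs://nn/warehouse/t"
  // Hive-style default layout for spec Map("k1" -> "v1", "k2" -> "v2"):
  val defaultLocation = s"$qualifiedOutputPath/k1=v1/k2=v2"
  // Location recorded in the catalog for that partition:
  val catalogLocation = "hdfs://nn/custom/area/v1-v2"
  // Only partitions whose catalog location differs from the default end up in
  // the customPartitionLocations map.
  val isCustom = catalogLocation != defaultLocation
  println(isCustom) // true
}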
- * - * @return a mapping from partition specs to their custom locations - */ - private def getCustomPartitionLocations( - spark: SparkSession, - table: CatalogTable, - basePath: Path, - partitions: Seq[CatalogTablePartition]): Map[TablePartitionSpec, String] = { - val hadoopConf = spark.sessionState.newHadoopConf - val fs = basePath.getFileSystem(hadoopConf) - val qualifiedBasePath = basePath.makeQualified(fs.getUri, fs.getWorkingDirectory) - partitions.flatMap { p => - val defaultLocation = qualifiedBasePath.suffix( - "/" + PartitioningUtils.getPathFragment(p.spec, table.partitionSchema)).toString - val catalogLocation = new Path(p.location).makeQualified( - fs.getUri, fs.getWorkingDirectory).toString - if (catalogLocation != defaultLocation) { - Some(p.spec -> catalogLocation) - } else { - None - } - }.toMap + table, + Some(t.location)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 1eb4541e2c103..16c5193eda8df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -64,18 +64,18 @@ object FileFormatWriter extends Logging { val outputWriterFactory: OutputWriterFactory, val allColumns: Seq[Attribute], val partitionColumns: Seq[Attribute], - val nonPartitionColumns: Seq[Attribute], + val dataColumns: Seq[Attribute], val bucketSpec: Option[BucketSpec], val path: String, val customPartitionLocations: Map[TablePartitionSpec, String], val maxRecordsPerFile: Long) extends Serializable { - assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ nonPartitionColumns), + assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ dataColumns), s""" |All columns: ${allColumns.mkString(", ")} |Partition columns: ${partitionColumns.mkString(", ")} - |Non-partition columns: ${nonPartitionColumns.mkString(", ")} + |Data columns: ${dataColumns.mkString(", ")} """.stripMargin) } @@ -120,7 +120,7 @@ object FileFormatWriter extends Logging { outputWriterFactory = outputWriterFactory, allColumns = queryExecution.logical.output, partitionColumns = partitionColumns, - nonPartitionColumns = dataColumns, + dataColumns = dataColumns, bucketSpec = bucketSpec, path = outputSpec.outputPath, customPartitionLocations = outputSpec.customPartitionLocations, @@ -246,9 +246,8 @@ object FileFormatWriter extends Logging { currentWriter = description.outputWriterFactory.newInstance( path = tmpFilePath, - dataSchema = description.nonPartitionColumns.toStructType, + dataSchema = description.dataColumns.toStructType, context = taskAttemptContext) - currentWriter.initConverter(dataSchema = description.nonPartitionColumns.toStructType) } override def execute(iter: Iterator[InternalRow]): Set[String] = { @@ -267,7 +266,7 @@ object FileFormatWriter extends Logging { } val internalRow = iter.next() - currentWriter.writeInternal(internalRow) + currentWriter.write(internalRow) recordsInFile += 1 } releaseResources() @@ -364,9 +363,8 @@ object FileFormatWriter extends Logging { currentWriter = description.outputWriterFactory.newInstance( path = path, - dataSchema = description.nonPartitionColumns.toStructType, + dataSchema = description.dataColumns.toStructType, context = taskAttemptContext) - currentWriter.initConverter(description.nonPartitionColumns.toStructType) } override def execute(iter: 
Iterator[InternalRow]): Set[String] = { @@ -383,7 +381,7 @@ object FileFormatWriter extends Logging { // Returns the data columns to be written given an input row val getOutputRow = UnsafeProjection.create( - description.nonPartitionColumns, description.allColumns) + description.dataColumns, description.allColumns) // Returns the partition path given a partition key. val getPartitionStringFunc = UnsafeProjection.create( @@ -392,7 +390,7 @@ object FileFormatWriter extends Logging { // Sorts the data before write, so that we only need one writer at the same time. val sorter = new UnsafeKVExternalSorter( sortingKeySchema, - StructType.fromAttributes(description.nonPartitionColumns), + StructType.fromAttributes(description.dataColumns), SparkEnv.get.blockManager, SparkEnv.get.serializerManager, TaskContext.get().taskMemoryManager().pageSizeBytes, @@ -448,7 +446,7 @@ object FileFormatWriter extends Logging { newOutputWriter(currentKey, getPartitionStringFunc, fileCounter) } - currentWriter.writeInternal(sortedIterator.getValue) + currentWriter.write(sortedIterator.getValue) recordsInFile += 1 } releaseResources() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index ead323320243a..6d0671d4cb16f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -106,7 +106,7 @@ object FileSourceStrategy extends Strategy with Logging { val outputAttributes = readDataColumns ++ partitionColumns val scan = - new FileSourceScanExec( + FileSourceScanExec( fsRelation, outputAttributes, outputSchema, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index 53c884c22b9dc..423009e4eccea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -23,11 +23,11 @@ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.internal.io.FileCommitProtocol import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable} +import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogTablePartition} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.command.RunnableCommand +import org.apache.spark.sql.execution.command._ /** * A command for writing data to a [[HadoopFsRelation]]. Supports both overwriting and appending. @@ -37,23 +37,18 @@ import org.apache.spark.sql.execution.command.RunnableCommand * overwrites: when the spec is empty, all partitions are overwritten. * When it covers a prefix of the partition keys, only partitions matching * the prefix are overwritten. - * @param customPartitionLocations mapping of partition specs to their custom locations. 
The - * caller should guarantee that exactly those table partitions - * falling under the specified static partition keys are contained - * in this map, and that no other partitions are. */ case class InsertIntoHadoopFsRelationCommand( outputPath: Path, staticPartitions: TablePartitionSpec, - customPartitionLocations: Map[TablePartitionSpec, String], partitionColumns: Seq[Attribute], bucketSpec: Option[BucketSpec], fileFormat: FileFormat, - refreshFunction: Seq[TablePartitionSpec] => Unit, options: Map[String, String], - @transient query: LogicalPlan, + query: LogicalPlan, mode: SaveMode, - catalogTable: Option[CatalogTable]) + catalogTable: Option[CatalogTable], + fileIndex: Option[FileIndex]) extends RunnableCommand { import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName @@ -74,12 +69,30 @@ case class InsertIntoHadoopFsRelationCommand( val fs = outputPath.getFileSystem(hadoopConf) val qualifiedOutputPath = outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + val partitionsTrackedByCatalog = sparkSession.sessionState.conf.manageFilesourcePartitions && + catalogTable.isDefined && + catalogTable.get.partitionColumnNames.nonEmpty && + catalogTable.get.tracksPartitionsInCatalog + + var initialMatchingPartitions: Seq[TablePartitionSpec] = Nil + var customPartitionLocations: Map[TablePartitionSpec, String] = Map.empty + + // When partitions are tracked by the catalog, compute all custom partition locations that + // may be relevant to the insertion job. + if (partitionsTrackedByCatalog) { + val matchingPartitions = sparkSession.sessionState.catalog.listPartitions( + catalogTable.get.identifier, Some(staticPartitions)) + initialMatchingPartitions = matchingPartitions.map(_.spec) + customPartitionLocations = getCustomPartitionLocations( + fs, catalogTable.get, qualifiedOutputPath, matchingPartitions) + } + val pathExists = fs.exists(qualifiedOutputPath) val doInsertion = (mode, pathExists) match { case (SaveMode.ErrorIfExists, true) => throw new AnalysisException(s"path $qualifiedOutputPath already exists.") case (SaveMode.Overwrite, true) => - deleteMatchingPartitions(fs, qualifiedOutputPath) + deleteMatchingPartitions(fs, qualifiedOutputPath, customPartitionLocations) true case (SaveMode.Append, _) | (SaveMode.Overwrite, _) | (SaveMode.ErrorIfExists, false) => true @@ -98,6 +111,27 @@ case class InsertIntoHadoopFsRelationCommand( outputPath = outputPath.toString, isAppend = isAppend) + // Callback for updating metastore partition metadata after the insertion job completes. 
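The callback defined next reconciles the metastore's partition metadata with the partitions the write job actually produced: specs that show up in the job output but were not in the initial catalog listing are added, and under overwrite semantics any initially matching spec that received no data is dropped (its files were already deleted on disk, hence retainData). A minimal standalone sketch of that set arithmetic, with plain strings and hypothetical callbacks standing in for TablePartitionSpec and the AlterTable*PartitionCommand calls:

```scala
// Sketch only; the names below are illustrative, not Spark API.
def reconcilePartitions(
    initialMatching: Set[String],   // partitions the catalog knew about before the job
    updatedByJob: Set[String],      // partitions the job actually wrote to
    isOverwrite: Boolean,
    addToCatalog: Set[String] => Unit,
    dropFromCatalog: Set[String] => Unit): Unit = {
  // Newly created partitions must be registered in the catalog.
  val newPartitions = updatedByJob -- initialMatching
  if (newPartitions.nonEmpty) addToCatalog(newPartitions)

  // On overwrite, matching partitions that received no new data were wiped on
  // disk, so only their catalog entries still need to be removed.
  if (isOverwrite) {
    val deletedPartitions = initialMatching -- updatedByJob
    if (deletedPartitions.nonEmpty) dropFromCatalog(deletedPartitions)
  }
}
```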
+ def refreshPartitionsCallback(updatedPartitions: Seq[TablePartitionSpec]): Unit = { + if (partitionsTrackedByCatalog) { + val newPartitions = updatedPartitions.toSet -- initialMatchingPartitions + if (newPartitions.nonEmpty) { + AlterTableAddPartitionCommand( + catalogTable.get.identifier, newPartitions.toSeq.map(p => (p, None)), + ifNotExists = true).run(sparkSession) + } + if (mode == SaveMode.Overwrite) { + val deletedPartitions = initialMatchingPartitions.toSet -- updatedPartitions + if (deletedPartitions.nonEmpty) { + AlterTableDropPartitionCommand( + catalogTable.get.identifier, deletedPartitions.toSeq, + ifExists = true, purge = false, + retainData = true /* already deleted */).run(sparkSession) + } + } + } + } + FileFormatWriter.write( sparkSession = sparkSession, queryExecution = Dataset.ofRows(sparkSession, query).queryExecution, @@ -108,8 +142,10 @@ case class InsertIntoHadoopFsRelationCommand( hadoopConf = hadoopConf, partitionColumns = partitionColumns, bucketSpec = bucketSpec, - refreshFunction = refreshFunction, + refreshFunction = refreshPartitionsCallback, options = options) + + fileIndex.foreach(_.refresh()) } else { logInfo("Skipping insertion into a relation that already exists.") } @@ -121,7 +157,10 @@ case class InsertIntoHadoopFsRelationCommand( * Deletes all partition files that match the specified static prefix. Partitions with custom * locations are also cleared based on the custom locations map given to this class. */ - private def deleteMatchingPartitions(fs: FileSystem, qualifiedOutputPath: Path): Unit = { + private def deleteMatchingPartitions( + fs: FileSystem, + qualifiedOutputPath: Path, + customPartitionLocations: Map[TablePartitionSpec, String]): Unit = { val staticPartitionPrefix = if (staticPartitions.nonEmpty) { "/" + partitionColumns.flatMap { p => staticPartitions.get(p.name) match { @@ -152,4 +191,29 @@ case class InsertIntoHadoopFsRelationCommand( } } } + + /** + * Given a set of input partitions, returns those that have locations that differ from the + * Hive default (e.g. /k1=v1/k2=v2). These partitions were manually assigned locations by + * the user. + * + * @return a mapping from partition specs to their custom locations + */ + private def getCustomPartitionLocations( + fs: FileSystem, + table: CatalogTable, + qualifiedOutputPath: Path, + partitions: Seq[CatalogTablePartition]): Map[TablePartitionSpec, String] = { + partitions.flatMap { p => + val defaultLocation = qualifiedOutputPath.suffix( + "/" + PartitioningUtils.getPathFragment(p.spec, table.partitionSchema)).toString + val catalogLocation = new Path(p.location).makeQualified( + fs.getUri, fs.getWorkingDirectory).toString + if (catalogLocation != defaultLocation) { + Some(p.spec -> catalogLocation) + } else { + None + } + }.toMap + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/OutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/OutputWriter.scala index a73c8146c1b0d..868e5371426c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/OutputWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/OutputWriter.scala @@ -47,19 +47,6 @@ abstract class OutputWriterFactory extends Serializable { path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter - - /** - * Returns a new instance of [[OutputWriter]] that will write data to the given path. 
- * This method gets called by each task on executor to write InternalRows to - * format-specific files. Compared to the other `newInstance()`, this is a newer API that - * passes only the path that the writer must write to. The writer must write to the exact path - * and not modify it (do not add subdirectories, extensions, etc.). All other - * file-format-specific information needed to create the writer must be passed - * through the [[OutputWriterFactory]] implementation. - */ - def newWriter(path: String): OutputWriter = { - throw new UnsupportedOperationException("newInstance with just path not supported") - } } @@ -74,22 +61,11 @@ abstract class OutputWriter { * Persists a single row. Invoked on the executor side. When writing to dynamically partitioned * tables, dynamic partition columns are not included in rows to be written. */ - def write(row: Row): Unit + def write(row: InternalRow): Unit /** * Closes the [[OutputWriter]]. Invoked on the executor side after all rows are persisted, before * the task output is committed. */ def close(): Unit - - private var converter: InternalRow => Row = _ - - protected[sql] def initConverter(dataSchema: StructType) = { - converter = - CatalystTypeConverters.createToScalaConverter(dataSchema).asInstanceOf[InternalRow => Row] - } - - protected[sql] def writeInternal(row: InternalRow): Unit = { - write(converter(row)) - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index bc290702dc37f..bad59961ace12 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -41,7 +41,7 @@ object PartitionPath { } /** - * Holds a directory in a partitioned collection of files as well as as the partition values + * Holds a directory in a partitioned collection of files as well as the partition values * in the form of a Row. Before scanning, the files at `path` need to be enumerated. 
*/ case class PartitionPath(values: InternalRow, path: Path) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala index 23c07eb630d31..8c19be48c2118 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala @@ -221,9 +221,7 @@ private[csv] class CsvOutputWriter( row.get(ordinal, dt).toString } - override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal") - - override protected[sql] def writeInternal(row: InternalRow): Unit = { + override def write(row: InternalRow): Unit = { csvWriter.writeRow(rowToString(row), printHeader) printHeader = false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index a9d8ddfe9d805..be1f94dbad912 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -159,9 +159,7 @@ private[json] class JsonOutputWriter( // create the Generator without separator inserted between 2 records private[this] val gen = new JacksonGenerator(dataSchema, writer, options) - override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal") - - override protected[sql] def writeInternal(row: InternalRow): Unit = { + override def write(row: InternalRow): Unit = { gen.write(row) gen.writeLineEnding() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala index 5c0f8af17a232..8361762b09703 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala @@ -37,9 +37,7 @@ private[parquet] class ParquetOutputWriter(path: String, context: TaskAttemptCon }.getRecordWriter(context) } - override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal") - - override def writeInternal(row: InternalRow): Unit = recordWriter.write(null, row) + override def write(row: InternalRow): Unit = recordWriter.write(null, row) override def close(): Unit = recordWriter.close(context) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 07b16671f744b..94ba814fa51ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -109,7 +109,7 @@ case class AnalyzeCreateTable(sparkSession: SparkSession) extends Rule[LogicalPl throw new AnalysisException("Saving data into a view is not allowed.") } - if (existingTable.provider.get == DDLUtils.HIVE_PROVIDER) { + if (DDLUtils.isHiveTable(existingTable)) { throw new AnalysisException(s"Saving data in the Hive serde table $tableName is " + "not supported yet. 
Please use the insertInto() API as an alternative.") } @@ -233,7 +233,7 @@ case class AnalyzeCreateTable(sparkSession: SparkSession) extends Rule[LogicalPl checkDuplication(normalizedPartitionCols, "partition") if (schema.nonEmpty && normalizedPartitionCols.length == schema.length) { - if (table.provider.get == DDLUtils.HIVE_PROVIDER) { + if (DDLUtils.isHiveTable(table)) { // When we hit this branch, it means users didn't specify schema for the table to be // created, as we always include partition columns in table schema for hive serde tables. // The real schema will be inferred at hive metastore by hive serde, plus the given @@ -380,8 +380,7 @@ case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] { object HiveOnlyCheck extends (LogicalPlan => Unit) { def apply(plan: LogicalPlan): Unit = { plan.foreach { - case CreateTable(tableDesc, _, Some(_)) - if tableDesc.provider.get == DDLUtils.HIVE_PROVIDER => + case CreateTable(tableDesc, _, Some(_)) if DDLUtils.isHiveTable(tableDesc) => throw new AnalysisException("Hive support is required to use CREATE Hive TABLE AS SELECT") case _ => // OK diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala index 897e535953331..6f6e3016864b6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala @@ -132,9 +132,7 @@ class TextOutputWriter( private val writer = CodecStreams.createOutputStream(context, new Path(path)) - override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal") - - override protected[sql] def writeInternal(row: InternalRow): Unit = { + override def write(row: InternalRow): Unit = { if (!row.isNullAt(0)) { val utf8string = row.getUTF8String(0) utf8string.writeTo(writer) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 650439a193015..9a080fd3c97c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -785,10 +785,13 @@ object functions { /** * Window function: returns the rank of rows within a window partition, without any gaps. * - * The difference between rank and denseRank is that denseRank leaves no gaps in ranking - * sequence when there are ties. That is, if you were ranking a competition using denseRank + * The difference between rank and dense_rank is that denseRank leaves no gaps in ranking + * sequence when there are ties. That is, if you were ranking a competition using dense_rank * and had three people tie for second place, you would say that all three were in second - * place and that the next person came in third. + * place and that the next person came in third. Rank would give me sequential numbers, making + * the person that came in third place (after the ties) would register as coming in fifth. + * + * This is equivalent to the DENSE_RANK function in SQL. * * @group window_funcs * @since 1.6.0 @@ -929,10 +932,11 @@ object functions { /** * Window function: returns the rank of rows within a window partition. * - * The difference between rank and denseRank is that denseRank leaves no gaps in ranking - * sequence when there are ties. 
That is, if you were ranking a competition using denseRank + * The difference between rank and dense_rank is that dense_rank leaves no gaps in ranking + * sequence when there are ties. That is, if you were ranking a competition using dense_rank * and had three people tie for second place, you would say that all three were in second - * place and that the next person came in third. + * place and that the next person came in third. Rank would give me sequential numbers, making + * the person that came in third place (after the ties) would register as coming in fifth. * * This is equivalent to the RANK function in SQL. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala index 52e648a917d8b..ca46a1151e3e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala @@ -17,12 +17,49 @@ package org.apache.spark.sql.internal +import org.apache.spark.sql.catalyst.catalog.CatalogStorageFormat + case class HiveSerDe( inputFormat: Option[String] = None, outputFormat: Option[String] = None, serde: Option[String] = None) object HiveSerDe { + val serdeMap = Map( + "sequencefile" -> + HiveSerDe( + inputFormat = Option("org.apache.hadoop.mapred.SequenceFileInputFormat"), + outputFormat = Option("org.apache.hadoop.mapred.SequenceFileOutputFormat")), + + "rcfile" -> + HiveSerDe( + inputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileInputFormat"), + outputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileOutputFormat"), + serde = Option("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe")), + + "orc" -> + HiveSerDe( + inputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat"), + outputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat"), + serde = Option("org.apache.hadoop.hive.ql.io.orc.OrcSerde")), + + "parquet" -> + HiveSerDe( + inputFormat = Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat"), + outputFormat = Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat"), + serde = Option("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")), + + "textfile" -> + HiveSerDe( + inputFormat = Option("org.apache.hadoop.mapred.TextInputFormat"), + outputFormat = Option("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")), + + "avro" -> + HiveSerDe( + inputFormat = Option("org.apache.hadoop.hive.ql.io.avro.AvroContainerInputFormat"), + outputFormat = Option("org.apache.hadoop.hive.ql.io.avro.AvroContainerOutputFormat"), + serde = Option("org.apache.hadoop.hive.serde2.avro.AvroSerDe"))) + /** * Get the Hive SerDe information from the data source abbreviation string or classname. 
* @@ -31,41 +68,6 @@ object HiveSerDe { * @return HiveSerDe associated with the specified source */ def sourceToSerDe(source: String): Option[HiveSerDe] = { - val serdeMap = Map( - "sequencefile" -> - HiveSerDe( - inputFormat = Option("org.apache.hadoop.mapred.SequenceFileInputFormat"), - outputFormat = Option("org.apache.hadoop.mapred.SequenceFileOutputFormat")), - - "rcfile" -> - HiveSerDe( - inputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileInputFormat"), - outputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileOutputFormat"), - serde = Option("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe")), - - "orc" -> - HiveSerDe( - inputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat"), - outputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat"), - serde = Option("org.apache.hadoop.hive.ql.io.orc.OrcSerde")), - - "parquet" -> - HiveSerDe( - inputFormat = Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat"), - outputFormat = Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat"), - serde = Option("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")), - - "textfile" -> - HiveSerDe( - inputFormat = Option("org.apache.hadoop.mapred.TextInputFormat"), - outputFormat = Option("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")), - - "avro" -> - HiveSerDe( - inputFormat = Option("org.apache.hadoop.hive.ql.io.avro.AvroContainerInputFormat"), - outputFormat = Option("org.apache.hadoop.hive.ql.io.avro.AvroContainerOutputFormat"), - serde = Option("org.apache.hadoop.hive.serde2.avro.AvroSerDe"))) - val key = source.toLowerCase match { case s if s.startsWith("org.apache.spark.sql.parquet") => "parquet" case s if s.startsWith("org.apache.spark.sql.orc") => "orc" @@ -77,4 +79,16 @@ object HiveSerDe { serdeMap.get(key) } + + def getDefaultStorage(conf: SQLConf): CatalogStorageFormat = { + val defaultStorageType = conf.getConfString("hive.default.fileformat", "textfile") + val defaultHiveSerde = sourceToSerDe(defaultStorageType) + CatalogStorageFormat.empty.copy( + inputFormat = defaultHiveSerde.flatMap(_.inputFormat) + .orElse(Some("org.apache.hadoop.mapred.TextInputFormat")), + outputFormat = defaultHiveSerde.flatMap(_.outputFormat) + .orElse(Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")), + serde = defaultHiveSerde.flatMap(_.serde) + .orElse(Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"))) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index 0ce47b152c59b..abb00ce02e91b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -115,7 +115,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { } /** - * Specifies the underlying output data source. Built-in options include "parquet" for now. + * Specifies the underlying output data source. * * @since 2.0.0 */ @@ -137,9 +137,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { * predicates on the partitioned columns. In order for partitioning to work well, the number * of distinct values in each column should typically be less than tens of thousands. * - * This was initially applicable for Parquet but in 1.5+ covers JSON, text, ORC and avro as well. 
- * - * @since 1.4.0 + * @since 2.0.0 */ @scala.annotation.varargs def partitionBy(colNames: String*): DataStreamWriter[T] = { @@ -285,7 +283,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { /** * Starts the execution of the streaming query, which will continually send results to the given - * `ForeachWriter` as as new data arrives. The `ForeachWriter` can be used to send the data + * `ForeachWriter` as new data arrives. The `ForeachWriter` can be used to send the data * generated by the `DataFrame`/`Dataset` to an external system. * * Scala example: diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-group-by.sql new file mode 100644 index 0000000000000..b1d96b32c2478 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-group-by.sql @@ -0,0 +1,239 @@ +-- A test suite for GROUP BY in parent side, subquery, and both predicate subquery +-- It includes correlated cases. + +create temporary view t1 as select * from values + ("t1a", 6S, 8, 10L, float(15.0), 20D, 20E2, timestamp '2014-04-04 01:00:00.000', date '2014-04-04'), + ("t1b", 8S, 16, 19L, float(17.0), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04'), + ("t1a", 16S, 12, 21L, float(15.0), 20D, 20E2, timestamp '2014-06-04 01:02:00.001', date '2014-06-04'), + ("t1a", 16S, 12, 10L, float(15.0), 20D, 20E2, timestamp '2014-07-04 01:01:00.000', date '2014-07-04'), + ("t1c", 8S, 16, 19L, float(17.0), 25D, 26E2, timestamp '2014-05-04 01:02:00.001', date '2014-05-05'), + ("t1d", null, 16, 22L, float(17.0), 25D, 26E2, timestamp '2014-06-04 01:01:00.000', null), + ("t1d", null, 16, 19L, float(17.0), 25D, 26E2, timestamp '2014-07-04 01:02:00.001', null), + ("t1e", 10S, null, 25L, float(17.0), 25D, 26E2, timestamp '2014-08-04 01:01:00.000', date '2014-08-04'), + ("t1e", 10S, null, 19L, float(17.0), 25D, 26E2, timestamp '2014-09-04 01:02:00.001', date '2014-09-04'), + ("t1d", 10S, null, 12L, float(17.0), 25D, 26E2, timestamp '2015-05-04 01:01:00.000', date '2015-05-04'), + ("t1a", 6S, 8, 10L, float(15.0), 20D, 20E2, timestamp '2014-04-04 01:02:00.001', date '2014-04-04'), + ("t1e", 10S, null, 19L, float(17.0), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04') + as t1(t1a, t1b, t1c, t1d, t1e, t1f, t1g, t1h, t1i); + +create temporary view t2 as select * from values + ("t2a", 6S, 12, 14L, float(15), 20D, 20E2, timestamp '2014-04-04 01:01:00.000', date '2014-04-04'), + ("t1b", 10S, 12, 19L, float(17), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04'), + ("t1b", 8S, 16, 119L, float(17), 25D, 26E2, timestamp '2015-05-04 01:01:00.000', date '2015-05-04'), + ("t1c", 12S, 16, 219L, float(17), 25D, 26E2, timestamp '2016-05-04 01:01:00.000', date '2016-05-04'), + ("t1b", null, 16, 319L, float(17), 25D, 26E2, timestamp '2017-05-04 01:01:00.000', null), + ("t2e", 8S, null, 419L, float(17), 25D, 26E2, timestamp '2014-06-04 01:01:00.000', date '2014-06-04'), + ("t1f", 19S, null, 519L, float(17), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04'), + ("t1b", 10S, 12, 19L, float(17), 25D, 26E2, timestamp '2014-06-04 01:01:00.000', date '2014-06-04'), + ("t1b", 8S, 16, 19L, float(17), 25D, 26E2, timestamp '2014-07-04 01:01:00.000', date '2014-07-04'), + ("t1c", 12S, 16, 19L, float(17), 25D, 26E2, timestamp '2014-08-04 01:01:00.000', date '2014-08-05'), + ("t1e", 8S, null, 19L, float(17), 25D, 26E2, timestamp '2014-09-04 01:01:00.000', 
date '2014-09-04'), + ("t1f", 19S, null, 19L, float(17), 25D, 26E2, timestamp '2014-10-04 01:01:00.000', date '2014-10-04'), + ("t1b", null, 16, 19L, float(17), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', null) + as t2(t2a, t2b, t2c, t2d, t2e, t2f, t2g, t2h, t2i); + +create temporary view t3 as select * from values + ("t3a", 6S, 12, 110L, float(15), 20D, 20E2, timestamp '2014-04-04 01:02:00.000', date '2014-04-04'), + ("t3a", 6S, 12, 10L, float(15), 20D, 20E2, timestamp '2014-05-04 01:02:00.000', date '2014-05-04'), + ("t1b", 10S, 12, 219L, float(17), 25D, 26E2, timestamp '2014-05-04 01:02:00.000', date '2014-05-04'), + ("t1b", 10S, 12, 19L, float(17), 25D, 26E2, timestamp '2014-05-04 01:02:00.000', date '2014-05-04'), + ("t1b", 8S, 16, 319L, float(17), 25D, 26E2, timestamp '2014-06-04 01:02:00.000', date '2014-06-04'), + ("t1b", 8S, 16, 19L, float(17), 25D, 26E2, timestamp '2014-07-04 01:02:00.000', date '2014-07-04'), + ("t3c", 17S, 16, 519L, float(17), 25D, 26E2, timestamp '2014-08-04 01:02:00.000', date '2014-08-04'), + ("t3c", 17S, 16, 19L, float(17), 25D, 26E2, timestamp '2014-09-04 01:02:00.000', date '2014-09-05'), + ("t1b", null, 16, 419L, float(17), 25D, 26E2, timestamp '2014-10-04 01:02:00.000', null), + ("t1b", null, 16, 19L, float(17), 25D, 26E2, timestamp '2014-11-04 01:02:00.000', null), + ("t3b", 8S, null, 719L, float(17), 25D, 26E2, timestamp '2014-05-04 01:02:00.000', date '2014-05-04'), + ("t3b", 8S, null, 19L, float(17), 25D, 26E2, timestamp '2015-05-04 01:02:00.000', date '2015-05-04') + as t3(t3a, t3b, t3c, t3d, t3e, t3f, t3g, t3h, t3i); + +-- correlated IN subquery +-- GROUP BY in parent side +-- TC 01.01 +SELECT t1a, + Avg(t1b) +FROM t1 +WHERE t1a IN (SELECT t2a + FROM t2) +GROUP BY t1a; + +-- TC 01.02 +SELECT t1a, + Max(t1b) +FROM t1 +WHERE t1b IN (SELECT t2b + FROM t2 + WHERE t1a = t2a) +GROUP BY t1a, + t1d; + +-- TC 01.03 +SELECT t1a, + t1b +FROM t1 +WHERE t1c IN (SELECT t2c + FROM t2 + WHERE t1a = t2a) +GROUP BY t1a, + t1b; + +-- TC 01.04 +SELECT t1a, + Sum(DISTINCT( t1b )) +FROM t1 +WHERE t1c IN (SELECT t2c + FROM t2 + WHERE t1a = t2a) + OR t1c IN (SELECT t3c + FROM t3 + WHERE t1a = t3a) +GROUP BY t1a, + t1c; + +-- TC 01.05 +SELECT t1a, + Sum(DISTINCT( t1b )) +FROM t1 +WHERE t1c IN (SELECT t2c + FROM t2 + WHERE t1a = t2a) + AND t1c IN (SELECT t3c + FROM t3 + WHERE t1a = t3a) +GROUP BY t1a, + t1c; + +-- TC 01.06 +SELECT t1a, + Count(DISTINCT( t1b )) +FROM t1 +WHERE t1c IN (SELECT t2c + FROM t2 + WHERE t1a = t2a) +GROUP BY t1a, + t1c +HAVING t1a = "t1b"; + +-- GROUP BY in subquery +-- TC 01.07 +SELECT * +FROM t1 +WHERE t1b IN (SELECT Max(t2b) + FROM t2 + GROUP BY t2a); + +-- TC 01.08 +SELECT * +FROM (SELECT t2a, + t2b + FROM t2 + WHERE t2a IN (SELECT t1a + FROM t1 + WHERE t1b = t2b) + GROUP BY t2a, + t2b) t2; + +-- TC 01.09 +SELECT Count(DISTINCT( * )) +FROM t1 +WHERE t1b IN (SELECT Min(t2b) + FROM t2 + WHERE t1a = t2a + AND t1c = t2c + GROUP BY t2a); + +-- TC 01.10 +SELECT t1a, + t1b +FROM t1 +WHERE t1c IN (SELECT Max(t2c) + FROM t2 + WHERE t1a = t2a + GROUP BY t2a, + t2c + HAVING t2c > 8); + +-- TC 01.11 +SELECT t1a, + t1b +FROM t1 +WHERE t1c IN (SELECT t2c + FROM t2 + WHERE t2a IN (SELECT Min(t3a) + FROM t3 + WHERE t3a = t2a + GROUP BY t3b) + GROUP BY t2c); + +-- GROUP BY in both +-- TC 01.12 +SELECT t1a, + Min(t1b) +FROM t1 +WHERE t1c IN (SELECT Min(t2c) + FROM t2 + WHERE t2b = t1b + GROUP BY t2a) +GROUP BY t1a; + +-- TC 01.13 +SELECT t1a, + Min(t1b) +FROM t1 +WHERE t1c IN (SELECT Min(t2c) + FROM t2 + WHERE t2b IN (SELECT Min(t3b) + FROM t3 + WHERE t2a = 
t3a + GROUP BY t3a) + GROUP BY t2c) +GROUP BY t1a, + t1d; + +-- TC 01.14 +SELECT t1a, + Min(t1b) +FROM t1 +WHERE t1c IN (SELECT Min(t2c) + FROM t2 + WHERE t2b = t1b + GROUP BY t2a) + AND t1d IN (SELECT t3d + FROM t3 + WHERE t1c = t3c + GROUP BY t3d) +GROUP BY t1a; + +-- TC 01.15 +SELECT t1a, + Min(t1b) +FROM t1 +WHERE t1c IN (SELECT Min(t2c) + FROM t2 + WHERE t2b = t1b + GROUP BY t2a) + OR t1d IN (SELECT t3d + FROM t3 + WHERE t1c = t3c + GROUP BY t3d) +GROUP BY t1a; + +-- TC 01.16 +SELECT t1a, + Min(t1b) +FROM t1 +WHERE t1c IN (SELECT Min(t2c) + FROM t2 + WHERE t2b = t1b + GROUP BY t2a + HAVING t2a > t1a) + OR t1d IN (SELECT t3d + FROM t3 + WHERE t1c = t3c + GROUP BY t3d + HAVING t3d = t1d) +GROUP BY t1a +HAVING Min(t1b) IS NOT NULL; + + + diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/simple-in.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/simple-in.sql new file mode 100644 index 0000000000000..20370b045e803 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/simple-in.sql @@ -0,0 +1,112 @@ +-- A test suite for simple IN predicate subquery +-- It includes correlated cases. + +create temporary view t1 as select * from values + ("t1a", 6S, 8, 10L, float(15.0), 20D, 20E2, timestamp '2014-04-04 01:00:00.000', date '2014-04-04'), + ("t1b", 8S, 16, 19L, float(17.0), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04'), + ("t1a", 16S, 12, 21L, float(15.0), 20D, 20E2, timestamp '2014-06-04 01:02:00.001', date '2014-06-04'), + ("t1a", 16S, 12, 10L, float(15.0), 20D, 20E2, timestamp '2014-07-04 01:01:00.000', date '2014-07-04'), + ("t1c", 8S, 16, 19L, float(17.0), 25D, 26E2, timestamp '2014-05-04 01:02:00.001', date '2014-05-05'), + ("t1d", null, 16, 22L, float(17.0), 25D, 26E2, timestamp '2014-06-04 01:01:00.000', null), + ("t1d", null, 16, 19L, float(17.0), 25D, 26E2, timestamp '2014-07-04 01:02:00.001', null), + ("t1e", 10S, null, 25L, float(17.0), 25D, 26E2, timestamp '2014-08-04 01:01:00.000', date '2014-08-04'), + ("t1e", 10S, null, 19L, float(17.0), 25D, 26E2, timestamp '2014-09-04 01:02:00.001', date '2014-09-04'), + ("t1d", 10S, null, 12L, float(17.0), 25D, 26E2, timestamp '2015-05-04 01:01:00.000', date '2015-05-04'), + ("t1a", 6S, 8, 10L, float(15.0), 20D, 20E2, timestamp '2014-04-04 01:02:00.001', date '2014-04-04'), + ("t1e", 10S, null, 19L, float(17.0), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04') + as t1(t1a, t1b, t1c, t1d, t1e, t1f, t1g, t1h, t1i); + +create temporary view t2 as select * from values + ("t2a", 6S, 12, 14L, float(15), 20D, 20E2, timestamp '2014-04-04 01:01:00.000', date '2014-04-04'), + ("t1b", 10S, 12, 19L, float(17), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04'), + ("t1b", 8S, 16, 119L, float(17), 25D, 26E2, timestamp '2015-05-04 01:01:00.000', date '2015-05-04'), + ("t1c", 12S, 16, 219L, float(17), 25D, 26E2, timestamp '2016-05-04 01:01:00.000', date '2016-05-04'), + ("t1b", null, 16, 319L, float(17), 25D, 26E2, timestamp '2017-05-04 01:01:00.000', null), + ("t2e", 8S, null, 419L, float(17), 25D, 26E2, timestamp '2014-06-04 01:01:00.000', date '2014-06-04'), + ("t1f", 19S, null, 519L, float(17), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04'), + ("t1b", 10S, 12, 19L, float(17), 25D, 26E2, timestamp '2014-06-04 01:01:00.000', date '2014-06-04'), + ("t1b", 8S, 16, 19L, float(17), 25D, 26E2, timestamp '2014-07-04 01:01:00.000', date '2014-07-04'), + ("t1c", 12S, 16, 19L, float(17), 25D, 26E2, timestamp 
'2014-08-04 01:01:00.000', date '2014-08-05'), + ("t1e", 8S, null, 19L, float(17), 25D, 26E2, timestamp '2014-09-04 01:01:00.000', date '2014-09-04'), + ("t1f", 19S, null, 19L, float(17), 25D, 26E2, timestamp '2014-10-04 01:01:00.000', date '2014-10-04'), + ("t1b", null, 16, 19L, float(17), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', null) + as t2(t2a, t2b, t2c, t2d, t2e, t2f, t2g, t2h, t2i); + +create temporary view t3 as select * from values + ("t3a", 6S, 12, 110L, float(15), 20D, 20E2, timestamp '2014-04-04 01:02:00.000', date '2014-04-04'), + ("t3a", 6S, 12, 10L, float(15), 20D, 20E2, timestamp '2014-05-04 01:02:00.000', date '2014-05-04'), + ("t1b", 10S, 12, 219L, float(17), 25D, 26E2, timestamp '2014-05-04 01:02:00.000', date '2014-05-04'), + ("t1b", 10S, 12, 19L, float(17), 25D, 26E2, timestamp '2014-05-04 01:02:00.000', date '2014-05-04'), + ("t1b", 8S, 16, 319L, float(17), 25D, 26E2, timestamp '2014-06-04 01:02:00.000', date '2014-06-04'), + ("t1b", 8S, 16, 19L, float(17), 25D, 26E2, timestamp '2014-07-04 01:02:00.000', date '2014-07-04'), + ("t3c", 17S, 16, 519L, float(17), 25D, 26E2, timestamp '2014-08-04 01:02:00.000', date '2014-08-04'), + ("t3c", 17S, 16, 19L, float(17), 25D, 26E2, timestamp '2014-09-04 01:02:00.000', date '2014-09-05'), + ("t1b", null, 16, 419L, float(17), 25D, 26E2, timestamp '2014-10-04 01:02:00.000', null), + ("t1b", null, 16, 19L, float(17), 25D, 26E2, timestamp '2014-11-04 01:02:00.000', null), + ("t3b", 8S, null, 719L, float(17), 25D, 26E2, timestamp '2014-05-04 01:02:00.000', date '2014-05-04'), + ("t3b", 8S, null, 19L, float(17), 25D, 26E2, timestamp '2015-05-04 01:02:00.000', date '2015-05-04') + as t3(t3a, t3b, t3c, t3d, t3e, t3f, t3g, t3h, t3i); + +-- correlated IN subquery +-- simple select +-- TC 01.01 +SELECT * +FROM t1 +WHERE t1a IN (SELECT t2a + FROM t2); + +-- TC 01.02 +SELECT * +FROM t1 +WHERE t1b IN (SELECT t2b + FROM t2 + WHERE t1a = t2a); + +-- TC 01.03 +SELECT t1a, + t1b +FROM t1 +WHERE t1c IN (SELECT t2b + FROM t2 + WHERE t1a != t2a); + +-- TC 01.04 +SELECT t1a, + t1b +FROM t1 +WHERE t1c IN (SELECT t2b + FROM t2 + WHERE t1a = t2a + OR t1b > t2b); + +-- TC 01.05 +SELECT t1a, + t1b +FROM t1 +WHERE t1c IN (SELECT t2b + FROM t2 + WHERE t2i IN (SELECT t3i + FROM t3 + WHERE t2c = t3c)); + +-- TC 01.06 +SELECT t1a, + t1b +FROM t1 +WHERE t1c IN (SELECT t2b + FROM t2 + WHERE t2a IN (SELECT t3a + FROM t3 + WHERE t2c = t3c + AND t2b IS NOT NULL)); + +-- simple select for NOT IN +-- TC 01.07 +SELECT DISTINCT( t1a ), + t1b, + t1h +FROM t1 +WHERE t1a NOT IN (SELECT t2a + FROM t2); + + diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-group-by.sql.out new file mode 100644 index 0000000000000..a159aa81eff1c --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-group-by.sql.out @@ -0,0 +1,357 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 19 + + +-- !query 0 +create temporary view t1 as select * from values + ("t1a", 6S, 8, 10L, float(15.0), 20D, 20E2, timestamp '2014-04-04 01:00:00.000', date '2014-04-04'), + ("t1b", 8S, 16, 19L, float(17.0), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04'), + ("t1a", 16S, 12, 21L, float(15.0), 20D, 20E2, timestamp '2014-06-04 01:02:00.001', date '2014-06-04'), + ("t1a", 16S, 12, 10L, float(15.0), 20D, 20E2, timestamp '2014-07-04 01:01:00.000', date '2014-07-04'), + ("t1c", 8S, 16, 19L, float(17.0), 25D, 26E2, 
timestamp '2014-05-04 01:02:00.001', date '2014-05-05'), + ("t1d", null, 16, 22L, float(17.0), 25D, 26E2, timestamp '2014-06-04 01:01:00.000', null), + ("t1d", null, 16, 19L, float(17.0), 25D, 26E2, timestamp '2014-07-04 01:02:00.001', null), + ("t1e", 10S, null, 25L, float(17.0), 25D, 26E2, timestamp '2014-08-04 01:01:00.000', date '2014-08-04'), + ("t1e", 10S, null, 19L, float(17.0), 25D, 26E2, timestamp '2014-09-04 01:02:00.001', date '2014-09-04'), + ("t1d", 10S, null, 12L, float(17.0), 25D, 26E2, timestamp '2015-05-04 01:01:00.000', date '2015-05-04'), + ("t1a", 6S, 8, 10L, float(15.0), 20D, 20E2, timestamp '2014-04-04 01:02:00.001', date '2014-04-04'), + ("t1e", 10S, null, 19L, float(17.0), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04') + as t1(t1a, t1b, t1c, t1d, t1e, t1f, t1g, t1h, t1i) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +create temporary view t2 as select * from values + ("t2a", 6S, 12, 14L, float(15), 20D, 20E2, timestamp '2014-04-04 01:01:00.000', date '2014-04-04'), + ("t1b", 10S, 12, 19L, float(17), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04'), + ("t1b", 8S, 16, 119L, float(17), 25D, 26E2, timestamp '2015-05-04 01:01:00.000', date '2015-05-04'), + ("t1c", 12S, 16, 219L, float(17), 25D, 26E2, timestamp '2016-05-04 01:01:00.000', date '2016-05-04'), + ("t1b", null, 16, 319L, float(17), 25D, 26E2, timestamp '2017-05-04 01:01:00.000', null), + ("t2e", 8S, null, 419L, float(17), 25D, 26E2, timestamp '2014-06-04 01:01:00.000', date '2014-06-04'), + ("t1f", 19S, null, 519L, float(17), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04'), + ("t1b", 10S, 12, 19L, float(17), 25D, 26E2, timestamp '2014-06-04 01:01:00.000', date '2014-06-04'), + ("t1b", 8S, 16, 19L, float(17), 25D, 26E2, timestamp '2014-07-04 01:01:00.000', date '2014-07-04'), + ("t1c", 12S, 16, 19L, float(17), 25D, 26E2, timestamp '2014-08-04 01:01:00.000', date '2014-08-05'), + ("t1e", 8S, null, 19L, float(17), 25D, 26E2, timestamp '2014-09-04 01:01:00.000', date '2014-09-04'), + ("t1f", 19S, null, 19L, float(17), 25D, 26E2, timestamp '2014-10-04 01:01:00.000', date '2014-10-04'), + ("t1b", null, 16, 19L, float(17), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', null) + as t2(t2a, t2b, t2c, t2d, t2e, t2f, t2g, t2h, t2i) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +create temporary view t3 as select * from values + ("t3a", 6S, 12, 110L, float(15), 20D, 20E2, timestamp '2014-04-04 01:02:00.000', date '2014-04-04'), + ("t3a", 6S, 12, 10L, float(15), 20D, 20E2, timestamp '2014-05-04 01:02:00.000', date '2014-05-04'), + ("t1b", 10S, 12, 219L, float(17), 25D, 26E2, timestamp '2014-05-04 01:02:00.000', date '2014-05-04'), + ("t1b", 10S, 12, 19L, float(17), 25D, 26E2, timestamp '2014-05-04 01:02:00.000', date '2014-05-04'), + ("t1b", 8S, 16, 319L, float(17), 25D, 26E2, timestamp '2014-06-04 01:02:00.000', date '2014-06-04'), + ("t1b", 8S, 16, 19L, float(17), 25D, 26E2, timestamp '2014-07-04 01:02:00.000', date '2014-07-04'), + ("t3c", 17S, 16, 519L, float(17), 25D, 26E2, timestamp '2014-08-04 01:02:00.000', date '2014-08-04'), + ("t3c", 17S, 16, 19L, float(17), 25D, 26E2, timestamp '2014-09-04 01:02:00.000', date '2014-09-05'), + ("t1b", null, 16, 419L, float(17), 25D, 26E2, timestamp '2014-10-04 01:02:00.000', null), + ("t1b", null, 16, 19L, float(17), 25D, 26E2, timestamp '2014-11-04 01:02:00.000', null), + ("t3b", 8S, null, 719L, float(17), 25D, 26E2, timestamp '2014-05-04 01:02:00.000', date '2014-05-04'), + 
("t3b", 8S, null, 19L, float(17), 25D, 26E2, timestamp '2015-05-04 01:02:00.000', date '2015-05-04') + as t3(t3a, t3b, t3c, t3d, t3e, t3f, t3g, t3h, t3i) +-- !query 2 schema +struct<> +-- !query 2 output + + + +-- !query 3 +SELECT t1a, + Avg(t1b) +FROM t1 +WHERE t1a IN (SELECT t2a + FROM t2) +GROUP BY t1a +-- !query 3 schema +struct +-- !query 3 output +t1b 8.0 +t1c 8.0 +t1e 10.0 + + +-- !query 4 +SELECT t1a, + Max(t1b) +FROM t1 +WHERE t1b IN (SELECT t2b + FROM t2 + WHERE t1a = t2a) +GROUP BY t1a, + t1d +-- !query 4 schema +struct +-- !query 4 output +t1b 8 + + +-- !query 5 +SELECT t1a, + t1b +FROM t1 +WHERE t1c IN (SELECT t2c + FROM t2 + WHERE t1a = t2a) +GROUP BY t1a, + t1b +-- !query 5 schema +struct +-- !query 5 output +t1b 8 +t1c 8 + + +-- !query 6 +SELECT t1a, + Sum(DISTINCT( t1b )) +FROM t1 +WHERE t1c IN (SELECT t2c + FROM t2 + WHERE t1a = t2a) + OR t1c IN (SELECT t3c + FROM t3 + WHERE t1a = t3a) +GROUP BY t1a, + t1c +-- !query 6 schema +struct +-- !query 6 output +t1b 8 +t1c 8 + + +-- !query 7 +SELECT t1a, + Sum(DISTINCT( t1b )) +FROM t1 +WHERE t1c IN (SELECT t2c + FROM t2 + WHERE t1a = t2a) + AND t1c IN (SELECT t3c + FROM t3 + WHERE t1a = t3a) +GROUP BY t1a, + t1c +-- !query 7 schema +struct +-- !query 7 output +t1b 8 + + +-- !query 8 +SELECT t1a, + Count(DISTINCT( t1b )) +FROM t1 +WHERE t1c IN (SELECT t2c + FROM t2 + WHERE t1a = t2a) +GROUP BY t1a, + t1c +HAVING t1a = "t1b" +-- !query 8 schema +struct +-- !query 8 output +t1b 1 + + +-- !query 9 +SELECT * +FROM t1 +WHERE t1b IN (SELECT Max(t2b) + FROM t2 + GROUP BY t2a) +-- !query 9 schema +struct +-- !query 9 output +t1a 6 8 10 15.0 20.0 2000 2014-04-04 01:00:00 2014-04-04 +t1a 6 8 10 15.0 20.0 2000 2014-04-04 01:02:00.001 2014-04-04 +t1b 8 16 19 17.0 25.0 2600 2014-05-04 01:01:00 2014-05-04 +t1c 8 16 19 17.0 25.0 2600 2014-05-04 01:02:00.001 2014-05-05 +t1d 10 NULL 12 17.0 25.0 2600 2015-05-04 01:01:00 2015-05-04 +t1e 10 NULL 19 17.0 25.0 2600 2014-05-04 01:01:00 2014-05-04 +t1e 10 NULL 19 17.0 25.0 2600 2014-09-04 01:02:00.001 2014-09-04 +t1e 10 NULL 25 17.0 25.0 2600 2014-08-04 01:01:00 2014-08-04 + + +-- !query 10 +SELECT * +FROM (SELECT t2a, + t2b + FROM t2 + WHERE t2a IN (SELECT t1a + FROM t1 + WHERE t1b = t2b) + GROUP BY t2a, + t2b) t2 +-- !query 10 schema +struct +-- !query 10 output +t1b 8 + + +-- !query 11 +SELECT Count(DISTINCT( * )) +FROM t1 +WHERE t1b IN (SELECT Min(t2b) + FROM t2 + WHERE t1a = t2a + AND t1c = t2c + GROUP BY t2a) +-- !query 11 schema +struct +-- !query 11 output +1 + + +-- !query 12 +SELECT t1a, + t1b +FROM t1 +WHERE t1c IN (SELECT Max(t2c) + FROM t2 + WHERE t1a = t2a + GROUP BY t2a, + t2c + HAVING t2c > 8) +-- !query 12 schema +struct +-- !query 12 output +t1b 8 +t1c 8 + + +-- !query 13 +SELECT t1a, + t1b +FROM t1 +WHERE t1c IN (SELECT t2c + FROM t2 + WHERE t2a IN (SELECT Min(t3a) + FROM t3 + WHERE t3a = t2a + GROUP BY t3b) + GROUP BY t2c) +-- !query 13 schema +struct +-- !query 13 output +t1a 16 +t1a 16 +t1b 8 +t1c 8 +t1d NULL +t1d NULL + + +-- !query 14 +SELECT t1a, + Min(t1b) +FROM t1 +WHERE t1c IN (SELECT Min(t2c) + FROM t2 + WHERE t2b = t1b + GROUP BY t2a) +GROUP BY t1a +-- !query 14 schema +struct +-- !query 14 output +t1b 8 +t1c 8 + + +-- !query 15 +SELECT t1a, + Min(t1b) +FROM t1 +WHERE t1c IN (SELECT Min(t2c) + FROM t2 + WHERE t2b IN (SELECT Min(t3b) + FROM t3 + WHERE t2a = t3a + GROUP BY t3a) + GROUP BY t2c) +GROUP BY t1a, + t1d +-- !query 15 schema +struct +-- !query 15 output +t1b 8 +t1c 8 +t1d NULL +t1d NULL + + +-- !query 16 +SELECT t1a, + Min(t1b) +FROM t1 +WHERE t1c IN (SELECT 
Min(t2c) + FROM t2 + WHERE t2b = t1b + GROUP BY t2a) + AND t1d IN (SELECT t3d + FROM t3 + WHERE t1c = t3c + GROUP BY t3d) +GROUP BY t1a +-- !query 16 schema +struct +-- !query 16 output +t1b 8 +t1c 8 + + +-- !query 17 +SELECT t1a, + Min(t1b) +FROM t1 +WHERE t1c IN (SELECT Min(t2c) + FROM t2 + WHERE t2b = t1b + GROUP BY t2a) + OR t1d IN (SELECT t3d + FROM t3 + WHERE t1c = t3c + GROUP BY t3d) +GROUP BY t1a +-- !query 17 schema +struct +-- !query 17 output +t1a 16 +t1b 8 +t1c 8 +t1d NULL + + +-- !query 18 +SELECT t1a, + Min(t1b) +FROM t1 +WHERE t1c IN (SELECT Min(t2c) + FROM t2 + WHERE t2b = t1b + GROUP BY t2a + HAVING t2a > t1a) + OR t1d IN (SELECT t3d + FROM t3 + WHERE t1c = t3c + GROUP BY t3d + HAVING t3d = t1d) +GROUP BY t1a +HAVING Min(t1b) IS NOT NULL +-- !query 18 schema +struct +-- !query 18 output +t1a 16 +t1b 8 +t1c 8 diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/simple-in.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/simple-in.sql.out new file mode 100644 index 0000000000000..66493d7fcc92d --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/simple-in.sql.out @@ -0,0 +1,176 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 10 + + +-- !query 0 +create temporary view t1 as select * from values + ("t1a", 6S, 8, 10L, float(15.0), 20D, 20E2, timestamp '2014-04-04 01:00:00.000', date '2014-04-04'), + ("t1b", 8S, 16, 19L, float(17.0), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04'), + ("t1a", 16S, 12, 21L, float(15.0), 20D, 20E2, timestamp '2014-06-04 01:02:00.001', date '2014-06-04'), + ("t1a", 16S, 12, 10L, float(15.0), 20D, 20E2, timestamp '2014-07-04 01:01:00.000', date '2014-07-04'), + ("t1c", 8S, 16, 19L, float(17.0), 25D, 26E2, timestamp '2014-05-04 01:02:00.001', date '2014-05-05'), + ("t1d", null, 16, 22L, float(17.0), 25D, 26E2, timestamp '2014-06-04 01:01:00.000', null), + ("t1d", null, 16, 19L, float(17.0), 25D, 26E2, timestamp '2014-07-04 01:02:00.001', null), + ("t1e", 10S, null, 25L, float(17.0), 25D, 26E2, timestamp '2014-08-04 01:01:00.000', date '2014-08-04'), + ("t1e", 10S, null, 19L, float(17.0), 25D, 26E2, timestamp '2014-09-04 01:02:00.001', date '2014-09-04'), + ("t1d", 10S, null, 12L, float(17.0), 25D, 26E2, timestamp '2015-05-04 01:01:00.000', date '2015-05-04'), + ("t1a", 6S, 8, 10L, float(15.0), 20D, 20E2, timestamp '2014-04-04 01:02:00.001', date '2014-04-04'), + ("t1e", 10S, null, 19L, float(17.0), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04') + as t1(t1a, t1b, t1c, t1d, t1e, t1f, t1g, t1h, t1i) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +create temporary view t2 as select * from values + ("t2a", 6S, 12, 14L, float(15), 20D, 20E2, timestamp '2014-04-04 01:01:00.000', date '2014-04-04'), + ("t1b", 10S, 12, 19L, float(17), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04'), + ("t1b", 8S, 16, 119L, float(17), 25D, 26E2, timestamp '2015-05-04 01:01:00.000', date '2015-05-04'), + ("t1c", 12S, 16, 219L, float(17), 25D, 26E2, timestamp '2016-05-04 01:01:00.000', date '2016-05-04'), + ("t1b", null, 16, 319L, float(17), 25D, 26E2, timestamp '2017-05-04 01:01:00.000', null), + ("t2e", 8S, null, 419L, float(17), 25D, 26E2, timestamp '2014-06-04 01:01:00.000', date '2014-06-04'), + ("t1f", 19S, null, 519L, float(17), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04'), + ("t1b", 10S, 12, 19L, float(17), 25D, 26E2, timestamp '2014-06-04 01:01:00.000', date 
'2014-06-04'), + ("t1b", 8S, 16, 19L, float(17), 25D, 26E2, timestamp '2014-07-04 01:01:00.000', date '2014-07-04'), + ("t1c", 12S, 16, 19L, float(17), 25D, 26E2, timestamp '2014-08-04 01:01:00.000', date '2014-08-05'), + ("t1e", 8S, null, 19L, float(17), 25D, 26E2, timestamp '2014-09-04 01:01:00.000', date '2014-09-04'), + ("t1f", 19S, null, 19L, float(17), 25D, 26E2, timestamp '2014-10-04 01:01:00.000', date '2014-10-04'), + ("t1b", null, 16, 19L, float(17), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', null) + as t2(t2a, t2b, t2c, t2d, t2e, t2f, t2g, t2h, t2i) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +create temporary view t3 as select * from values + ("t3a", 6S, 12, 110L, float(15), 20D, 20E2, timestamp '2014-04-04 01:02:00.000', date '2014-04-04'), + ("t3a", 6S, 12, 10L, float(15), 20D, 20E2, timestamp '2014-05-04 01:02:00.000', date '2014-05-04'), + ("t1b", 10S, 12, 219L, float(17), 25D, 26E2, timestamp '2014-05-04 01:02:00.000', date '2014-05-04'), + ("t1b", 10S, 12, 19L, float(17), 25D, 26E2, timestamp '2014-05-04 01:02:00.000', date '2014-05-04'), + ("t1b", 8S, 16, 319L, float(17), 25D, 26E2, timestamp '2014-06-04 01:02:00.000', date '2014-06-04'), + ("t1b", 8S, 16, 19L, float(17), 25D, 26E2, timestamp '2014-07-04 01:02:00.000', date '2014-07-04'), + ("t3c", 17S, 16, 519L, float(17), 25D, 26E2, timestamp '2014-08-04 01:02:00.000', date '2014-08-04'), + ("t3c", 17S, 16, 19L, float(17), 25D, 26E2, timestamp '2014-09-04 01:02:00.000', date '2014-09-05'), + ("t1b", null, 16, 419L, float(17), 25D, 26E2, timestamp '2014-10-04 01:02:00.000', null), + ("t1b", null, 16, 19L, float(17), 25D, 26E2, timestamp '2014-11-04 01:02:00.000', null), + ("t3b", 8S, null, 719L, float(17), 25D, 26E2, timestamp '2014-05-04 01:02:00.000', date '2014-05-04'), + ("t3b", 8S, null, 19L, float(17), 25D, 26E2, timestamp '2015-05-04 01:02:00.000', date '2015-05-04') + as t3(t3a, t3b, t3c, t3d, t3e, t3f, t3g, t3h, t3i) +-- !query 2 schema +struct<> +-- !query 2 output + + + +-- !query 3 +SELECT * +FROM t1 +WHERE t1a IN (SELECT t2a + FROM t2) +-- !query 3 schema +struct +-- !query 3 output +t1b 8 16 19 17.0 25.0 2600 2014-05-04 01:01:00 2014-05-04 +t1c 8 16 19 17.0 25.0 2600 2014-05-04 01:02:00.001 2014-05-05 +t1e 10 NULL 19 17.0 25.0 2600 2014-05-04 01:01:00 2014-05-04 +t1e 10 NULL 19 17.0 25.0 2600 2014-09-04 01:02:00.001 2014-09-04 +t1e 10 NULL 25 17.0 25.0 2600 2014-08-04 01:01:00 2014-08-04 + + +-- !query 4 +SELECT * +FROM t1 +WHERE t1b IN (SELECT t2b + FROM t2 + WHERE t1a = t2a) +-- !query 4 schema +struct +-- !query 4 output +t1b 8 16 19 17.0 25.0 2600 2014-05-04 01:01:00 2014-05-04 + + +-- !query 5 +SELECT t1a, + t1b +FROM t1 +WHERE t1c IN (SELECT t2b + FROM t2 + WHERE t1a != t2a) +-- !query 5 schema +struct +-- !query 5 output +t1a 16 +t1a 16 +t1a 6 +t1a 6 + + +-- !query 6 +SELECT t1a, + t1b +FROM t1 +WHERE t1c IN (SELECT t2b + FROM t2 + WHERE t1a = t2a + OR t1b > t2b) +-- !query 6 schema +struct +-- !query 6 output +t1a 16 +t1a 16 + + +-- !query 7 +SELECT t1a, + t1b +FROM t1 +WHERE t1c IN (SELECT t2b + FROM t2 + WHERE t2i IN (SELECT t3i + FROM t3 + WHERE t2c = t3c)) +-- !query 7 schema +struct +-- !query 7 output +t1a 6 +t1a 6 + + +-- !query 8 +SELECT t1a, + t1b +FROM t1 +WHERE t1c IN (SELECT t2b + FROM t2 + WHERE t2a IN (SELECT t3a + FROM t3 + WHERE t2c = t3c + AND t2b IS NOT NULL)) +-- !query 8 schema +struct +-- !query 8 output +t1a 6 +t1a 6 + + +-- !query 9 +SELECT DISTINCT( t1a ), + t1b, + t1h +FROM t1 +WHERE t1a NOT IN (SELECT t2a + FROM t2) +-- !query 9 schema +struct 
+-- !query 9 output +t1a 16 2014-06-04 01:02:00.001 +t1a 16 2014-07-04 01:01:00 +t1a 6 2014-04-04 01:00:00 +t1a 6 2014-04-04 01:02:00.001 +t1d 10 2015-05-04 01:01:00 +t1d NULL 2014-06-04 01:01:00 +t1d NULL 2014-07-04 01:02:00.001 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index f42402e1cc7d2..fb4812adf108a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -24,6 +24,8 @@ import scala.language.postfixOps import org.scalatest.concurrent.Eventually._ import org.apache.spark.CleanerListener +import org.apache.spark.sql.catalyst.expressions.{Expression, SubqueryExpression} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.RDDScanExec import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.execution.exchange.ShuffleExchange @@ -37,6 +39,14 @@ private case class BigData(s: String) class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext { import testImplicits._ + override def afterEach(): Unit = { + try { + spark.catalog.clearCache() + } finally { + super.afterEach() + } + } + def rddIdOf(tableName: String): Int = { val plan = spark.table(tableName).queryExecution.sparkPlan plan.collect { @@ -53,6 +63,16 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext maybeBlock.nonEmpty } + private def getNumInMemoryRelations(plan: LogicalPlan): Int = { + var sum = plan.collect { case _: InMemoryRelation => 1 }.sum + plan.transformAllExpressions { + case e: SubqueryExpression => + sum += getNumInMemoryRelations(e.plan) + e + } + sum + } + test("withColumn doesn't invalidate cached dataframe") { var evalCount = 0 val myUDF = udf((x: String) => { evalCount += 1; "result" }) @@ -565,4 +585,56 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext case i: InMemoryRelation => i }.size == 1) } + + test("SPARK-19093 Caching in side subquery") { + withTempView("t1") { + Seq(1).toDF("c1").createOrReplaceTempView("t1") + spark.catalog.cacheTable("t1") + val cachedPlan = + sql( + """ + |SELECT * FROM t1 + |WHERE + |NOT EXISTS (SELECT * FROM t1) + """.stripMargin).queryExecution.optimizedPlan + assert(getNumInMemoryRelations(cachedPlan) == 2) + } + } + + test("SPARK-19093 scalar and nested predicate query") { + withTempView("t1", "t2", "t3", "t4") { + Seq(1).toDF("c1").createOrReplaceTempView("t1") + Seq(2).toDF("c1").createOrReplaceTempView("t2") + Seq(1).toDF("c1").createOrReplaceTempView("t3") + Seq(1).toDF("c1").createOrReplaceTempView("t4") + spark.catalog.cacheTable("t1") + spark.catalog.cacheTable("t2") + spark.catalog.cacheTable("t3") + spark.catalog.cacheTable("t4") + + // Nested predicate subquery + val cachedPlan = + sql( + """ + |SELECT * FROM t1 + |WHERE + |c1 IN (SELECT c1 FROM t2 WHERE c1 IN (SELECT c1 FROM t3 WHERE c1 = 1)) + """.stripMargin).queryExecution.optimizedPlan + assert(getNumInMemoryRelations(cachedPlan) == 3) + + // Scalar subquery and predicate subquery + val cachedPlan2 = + sql( + """ + |SELECT * FROM (SELECT max(c1) FROM t1 GROUP BY c1) + |WHERE + |c1 = (SELECT max(c1) FROM t2 GROUP BY c1) + |OR + |EXISTS (SELECT c1 FROM t3) + |OR + |c1 IN (SELECT c1 FROM t4) + """.stripMargin).queryExecution.optimizedPlan + assert(getNumInMemoryRelations(cachedPlan2) == 4) + } + } } diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala index 4296ec543e275..22d5c47a6fb51 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala @@ -257,7 +257,7 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with B } } - test("time window in SQL with with two expressions") { + test("time window in SQL with two expressions") { withTempTable { table => checkAnswer( spark.sql( @@ -272,7 +272,7 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with B } } - test("time window in SQL with with three expressions") { + test("time window in SQL with three expressions") { withTempTable { table => checkAnswer( spark.sql( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index f8d4c61967f95..6b50cb3e48c76 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -17,10 +17,21 @@ package org.apache.spark.sql +import scala.collection.immutable.Queue +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.sql.test.SharedSQLContext case class IntClass(value: Int) +case class SeqClass(s: Seq[Int]) + +case class ListClass(l: List[Int]) + +case class QueueClass(q: Queue[Int]) + +case class ComplexClass(seq: SeqClass, list: ListClass, queue: QueueClass) + package object packageobject { case class PackageClass(value: Int) } @@ -130,6 +141,62 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { checkDataset(Seq(Array(Tuple1(1))).toDS(), Array(Tuple1(1))) } + test("arbitrary sequences") { + checkDataset(Seq(Queue(1)).toDS(), Queue(1)) + checkDataset(Seq(Queue(1.toLong)).toDS(), Queue(1.toLong)) + checkDataset(Seq(Queue(1.toDouble)).toDS(), Queue(1.toDouble)) + checkDataset(Seq(Queue(1.toFloat)).toDS(), Queue(1.toFloat)) + checkDataset(Seq(Queue(1.toByte)).toDS(), Queue(1.toByte)) + checkDataset(Seq(Queue(1.toShort)).toDS(), Queue(1.toShort)) + checkDataset(Seq(Queue(true)).toDS(), Queue(true)) + checkDataset(Seq(Queue("test")).toDS(), Queue("test")) + checkDataset(Seq(Queue(Tuple1(1))).toDS(), Queue(Tuple1(1))) + + checkDataset(Seq(ArrayBuffer(1)).toDS(), ArrayBuffer(1)) + checkDataset(Seq(ArrayBuffer(1.toLong)).toDS(), ArrayBuffer(1.toLong)) + checkDataset(Seq(ArrayBuffer(1.toDouble)).toDS(), ArrayBuffer(1.toDouble)) + checkDataset(Seq(ArrayBuffer(1.toFloat)).toDS(), ArrayBuffer(1.toFloat)) + checkDataset(Seq(ArrayBuffer(1.toByte)).toDS(), ArrayBuffer(1.toByte)) + checkDataset(Seq(ArrayBuffer(1.toShort)).toDS(), ArrayBuffer(1.toShort)) + checkDataset(Seq(ArrayBuffer(true)).toDS(), ArrayBuffer(true)) + checkDataset(Seq(ArrayBuffer("test")).toDS(), ArrayBuffer("test")) + checkDataset(Seq(ArrayBuffer(Tuple1(1))).toDS(), ArrayBuffer(Tuple1(1))) + } + + test("sequence and product combinations") { + // Case classes + checkDataset(Seq(SeqClass(Seq(1))).toDS(), SeqClass(Seq(1))) + checkDataset(Seq(Seq(SeqClass(Seq(1)))).toDS(), Seq(SeqClass(Seq(1)))) + checkDataset(Seq(List(SeqClass(Seq(1)))).toDS(), List(SeqClass(Seq(1)))) + checkDataset(Seq(Queue(SeqClass(Seq(1)))).toDS(), Queue(SeqClass(Seq(1)))) + + checkDataset(Seq(ListClass(List(1))).toDS(), ListClass(List(1))) + 
checkDataset(Seq(Seq(ListClass(List(1)))).toDS(), Seq(ListClass(List(1)))) + checkDataset(Seq(List(ListClass(List(1)))).toDS(), List(ListClass(List(1)))) + checkDataset(Seq(Queue(ListClass(List(1)))).toDS(), Queue(ListClass(List(1)))) + + checkDataset(Seq(QueueClass(Queue(1))).toDS(), QueueClass(Queue(1))) + checkDataset(Seq(Seq(QueueClass(Queue(1)))).toDS(), Seq(QueueClass(Queue(1)))) + checkDataset(Seq(List(QueueClass(Queue(1)))).toDS(), List(QueueClass(Queue(1)))) + checkDataset(Seq(Queue(QueueClass(Queue(1)))).toDS(), Queue(QueueClass(Queue(1)))) + + val complex = ComplexClass(SeqClass(Seq(1)), ListClass(List(2)), QueueClass(Queue(3))) + checkDataset(Seq(complex).toDS(), complex) + checkDataset(Seq(Seq(complex)).toDS(), Seq(complex)) + checkDataset(Seq(List(complex)).toDS(), List(complex)) + checkDataset(Seq(Queue(complex)).toDS(), Queue(complex)) + + // Tuples + checkDataset(Seq(Seq(1) -> Seq(2)).toDS(), Seq(1) -> Seq(2)) + checkDataset(Seq(List(1) -> Queue(2)).toDS(), List(1) -> Queue(2)) + checkDataset(Seq(List(Seq("test1") -> List(Queue("test2")))).toDS(), + List(Seq("test1") -> List(Queue("test2")))) + + // Complex + checkDataset(Seq(ListClass(List(1)) -> Queue("test" -> SeqClass(Seq(2)))).toDS(), + ListClass(List(1)) -> Queue("test" -> SeqClass(Seq(2)))) + } + test("package objects") { import packageobject._ checkDataset(Seq(PackageClass(1)).toDS(), PackageClass(1)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 375da224aaa7f..0bfc92fdb6218 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -363,7 +363,7 @@ class PlannerSuite extends SharedSQLContext { // This is a regression test for SPARK-9703 test("EnsureRequirements should not repartition if only ordering requirement is unsatisfied") { // Consider an operator that imposes both output distribution and ordering requirements on its - // children, such as sort sort merge join. If the distribution requirements are satisfied but + // children, such as sort merge join. If the distribution requirements are satisfied but // the output ordering requirements are unsatisfied, then the planner should only add sorts and // should not need to add additional shuffles / exchanges. 
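To observe the behaviour this comment describes from the DataFrame API, one can pre-partition both join inputs on the join key and inspect the physical plan; under those assumptions the merge join should only introduce sorts, not further exchanges. A rough sketch against a local session (hypothetical setup, not part of this suite), with broadcast joins disabled so a sort merge join is actually chosen:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.SortExec
import org.apache.spark.sql.execution.exchange.ShuffleExchange

val spark = SparkSession.builder()
  .master("local[2]")
  .appName("ensure-requirements-sketch")
  .getOrCreate()
import spark.implicits._

// Force a sort merge join instead of a broadcast join for this tiny data set.
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")

// Both sides are already hash-partitioned on the join key, so the join's
// distribution requirement is satisfied; only its ordering requirement is not.
val left = Seq((1, "a"), (2, "b")).toDF("k", "v").repartition($"k")
val right = Seq((1, "x"), (2, "y")).toDF("k", "w").repartition($"k")

val plan = left.join(right, "k").queryExecution.executedPlan
val shuffles = plan.collect { case e: ShuffleExchange => e }
val sorts = plan.collect { case s: SortExec => s }

// Expected shape under these assumptions: the only exchanges are the two explicit
// repartition() calls, while EnsureRequirements adds SortExec nodes for the join.
println(s"shuffles = ${shuffles.size}, sorts = ${sorts.size}")
```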
val outputOrdering = Seq(SortOrder(Literal(1), Ascending)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index b070138be05d5..15e490fb30a27 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -121,7 +121,8 @@ class SparkSqlParserSuite extends PlanTest { tableType: CatalogTableType = CatalogTableType.MANAGED, storage: CatalogStorageFormat = CatalogStorageFormat.empty.copy( inputFormat = HiveSerDe.sourceToSerDe("textfile").get.inputFormat, - outputFormat = HiveSerDe.sourceToSerDe("textfile").get.outputFormat), + outputFormat = HiveSerDe.sourceToSerDe("textfile").get.outputFormat, + serde = Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")), schema: StructType = new StructType, provider: Option[String] = Some("hive"), partitionColumnNames: Seq[String] = Seq.empty, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala index 1a5e5226c2ec9..76bb9e5929a71 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala @@ -236,7 +236,7 @@ class DDLCommandSuite extends PlanTest { comparePlans(parsed4, expected4) } - test("create table - table file format") { + test("create hive table - table file format") { val allSources = Seq("parquet", "parquetfile", "orc", "orcfile", "avro", "avrofile", "sequencefile", "rcfile", "textfile") @@ -245,13 +245,14 @@ class DDLCommandSuite extends PlanTest { val ct = parseAs[CreateTable](query) val hiveSerde = HiveSerDe.sourceToSerDe(s) assert(hiveSerde.isDefined) - assert(ct.tableDesc.storage.serde == hiveSerde.get.serde) + assert(ct.tableDesc.storage.serde == + hiveSerde.get.serde.orElse(Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"))) assert(ct.tableDesc.storage.inputFormat == hiveSerde.get.inputFormat) assert(ct.tableDesc.storage.outputFormat == hiveSerde.get.outputFormat) } } - test("create table - row format and table file format") { + test("create hive table - row format and table file format") { val createTableStart = "CREATE TABLE my_tab ROW FORMAT" val fileFormat = s"STORED AS INPUTFORMAT 'inputfmt' OUTPUTFORMAT 'outputfmt'" val query1 = s"$createTableStart SERDE 'anything' $fileFormat" @@ -262,13 +263,15 @@ class DDLCommandSuite extends PlanTest { assert(parsed1.tableDesc.storage.serde == Some("anything")) assert(parsed1.tableDesc.storage.inputFormat == Some("inputfmt")) assert(parsed1.tableDesc.storage.outputFormat == Some("outputfmt")) + val parsed2 = parseAs[CreateTable](query2) - assert(parsed2.tableDesc.storage.serde.isEmpty) + assert(parsed2.tableDesc.storage.serde == + Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) assert(parsed2.tableDesc.storage.inputFormat == Some("inputfmt")) assert(parsed2.tableDesc.storage.outputFormat == Some("outputfmt")) } - test("create table - row format serde and generic file format") { + test("create hive table - row format serde and generic file format") { val allSources = Seq("parquet", "orc", "avro", "sequencefile", "rcfile", "textfile") val supportedSources = Set("sequencefile", "rcfile", "textfile") @@ -287,7 +290,7 @@ class DDLCommandSuite extends PlanTest { } } - 
test("create table - row format delimited and generic file format") { + test("create hive table - row format delimited and generic file format") { val allSources = Seq("parquet", "orc", "avro", "sequencefile", "rcfile", "textfile") val supportedSources = Set("textfile") @@ -297,7 +300,8 @@ class DDLCommandSuite extends PlanTest { val ct = parseAs[CreateTable](query) val hiveSerde = HiveSerDe.sourceToSerDe(s) assert(hiveSerde.isDefined) - assert(ct.tableDesc.storage.serde == hiveSerde.get.serde) + assert(ct.tableDesc.storage.serde == + hiveSerde.get.serde.orElse(Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"))) assert(ct.tableDesc.storage.inputFormat == hiveSerde.get.inputFormat) assert(ct.tableDesc.storage.outputFormat == hiveSerde.get.outputFormat) } else { @@ -306,7 +310,7 @@ class DDLCommandSuite extends PlanTest { } } - test("create external table - location must be specified") { + test("create hive external table - location must be specified") { assertUnsupported( sql = "CREATE EXTERNAL TABLE my_tab", containsThesePhrases = Seq("create external table", "location")) @@ -316,7 +320,7 @@ class DDLCommandSuite extends PlanTest { assert(ct.tableDesc.storage.locationUri == Some("/something/anything")) } - test("create table - property values must be set") { + test("create hive table - property values must be set") { assertUnsupported( sql = "CREATE TABLE my_tab TBLPROPERTIES('key_without_value', 'key_with_value'='x')", containsThesePhrases = Seq("key_without_value")) @@ -326,14 +330,14 @@ class DDLCommandSuite extends PlanTest { containsThesePhrases = Seq("key_without_value")) } - test("create table - location implies external") { + test("create hive table - location implies external") { val query = "CREATE TABLE my_tab LOCATION '/something/anything'" val ct = parseAs[CreateTable](query) assert(ct.tableDesc.tableType == CatalogTableType.EXTERNAL) assert(ct.tableDesc.storage.locationUri == Some("/something/anything")) } - test("create table using - with partitioned by") { + test("create table - with partitioned by") { val query = "CREATE TABLE my_tab(a INT comment 'test', b STRING) " + "USING parquet PARTITIONED BY (a)" @@ -357,7 +361,7 @@ class DDLCommandSuite extends PlanTest { } } - test("create table using - with bucket") { + test("create table - with bucket") { val query = "CREATE TABLE my_tab(a INT, b STRING) USING parquet " + "CLUSTERED BY (a) SORTED BY (b) INTO 5 BUCKETS" @@ -379,6 +383,57 @@ class DDLCommandSuite extends PlanTest { } } + test("create table - with comment") { + val sql = "CREATE TABLE my_tab(a INT, b STRING) USING parquet COMMENT 'abc'" + + val expectedTableDesc = CatalogTable( + identifier = TableIdentifier("my_tab"), + tableType = CatalogTableType.MANAGED, + storage = CatalogStorageFormat.empty, + schema = new StructType().add("a", IntegerType).add("b", StringType), + provider = Some("parquet"), + comment = Some("abc")) + + parser.parsePlan(sql) match { + case CreateTable(tableDesc, _, None) => + assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime)) + case other => + fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," + + s"got ${other.getClass.getName}: $sql") + } + } + + test("create table - with location") { + val v1 = "CREATE TABLE my_tab(a INT, b STRING) USING parquet LOCATION '/tmp/file'" + + val expectedTableDesc = CatalogTable( + identifier = TableIdentifier("my_tab"), + tableType = CatalogTableType.EXTERNAL, + storage = CatalogStorageFormat.empty.copy(locationUri = Some("/tmp/file")), + 
schema = new StructType().add("a", IntegerType).add("b", StringType), + provider = Some("parquet")) + + parser.parsePlan(v1) match { + case CreateTable(tableDesc, _, None) => + assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime)) + case other => + fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," + + s"got ${other.getClass.getName}: $v1") + } + + val v2 = + """ + |CREATE TABLE my_tab(a INT, b STRING) + |USING parquet + |OPTIONS (path '/tmp/file') + |LOCATION '/tmp/file' + """.stripMargin + val e = intercept[ParseException] { + parser.parsePlan(v2) + } + assert(e.message.contains("you can only specify one of them.")) + } + // ALTER TABLE table_name RENAME TO new_table_name; // ALTER VIEW view_name RENAME TO new_view_name; test("alter table/view: rename table/view") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index 22f59f63d6f79..f67444fbc49d6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -144,7 +144,7 @@ class FileStreamSinkSuite extends StreamTest { } // This tests whether FileStreamSink works with aggregations. Specifically, it tests - // whether the the correct streaming QueryExecution (i.e. IncrementalExecution) is used to + // whether the correct streaming QueryExecution (i.e. IncrementalExecution) is used to // to execute the trigger for writing data to file sink. See SPARK-18440 for more details. test("writing with aggregation") { diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/AbstractService.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/AbstractService.java index c2a2b2d478af7..9dd0efc03968d 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/AbstractService.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/AbstractService.java @@ -151,7 +151,7 @@ public long getStartTime() { } /** - * Verify that that a service is in a given state. + * Verify that a service is in a given state. * * @param currentState * the desired state diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/ServiceOperations.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/ServiceOperations.java index 8946219d85ccd..a2c580d6acc71 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/ServiceOperations.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/ServiceOperations.java @@ -33,7 +33,7 @@ private ServiceOperations() { } /** - * Verify that that a service is in a given state. + * Verify that a service is in a given state. * @param state the actual state a service is in * @param expectedState the desired state * @throws IllegalStateException if the service state is different from diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/ServiceStateChangeListener.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/ServiceStateChangeListener.java index d1aadad04cf67..a1ff10dc2bc9f 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/ServiceStateChangeListener.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/ServiceStateChangeListener.java @@ -29,7 +29,7 @@ public interface ServiceStateChangeListener { * have changed state before this callback is invoked. 
* * This operation is invoked on the thread that initiated the state change, - * while the service itself in in a synchronized section. + * while the service itself in a synchronized section. * * 1. Any long-lived operation here will prevent the service state * change from completing in a timely manner.
  2. diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/TypeDescriptor.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/TypeDescriptor.java index 562b3f5e67865..b80fd67884add 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/TypeDescriptor.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/TypeDescriptor.java @@ -98,7 +98,7 @@ public void setTypeQualifiers(TypeQualifiers typeQualifiers) { * For datetime types this is the length in characters of the String representation * (assuming the maximum allowed precision of the fractional seconds component). * For binary data this is the length in bytes. - * Null is returned for for data types where the column size is not applicable. + * Null is returned for data types where the column size is not applicable. */ public Integer getColumnSize() { if (type.isNumericType()) { diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 5cd4935e225ee..d217e9b4feb6d 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -178,7 +178,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "skewjoin", "database", - // These tests fail and and exit the JVM. + // These tests fail and exit the JVM. "auto_join18_multi_distinct", "join18_multi_distinct", "input44", diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index fde6d4a94762d..474a2c868e20a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -215,7 +215,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat tableDefinition.storage.locationUri } - if (tableDefinition.provider.get == DDLUtils.HIVE_PROVIDER) { + if (DDLUtils.isHiveTable(tableDefinition)) { val tableWithDataSourceProps = tableDefinition.copy( // We can't leave `locationUri` empty and count on Hive metastore to set a default table // location, because Hive metastore uses hive.metastore.warehouse.dir to generate default @@ -533,7 +533,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } else { val oldTableDef = getRawTable(db, withStatsProps.identifier.table) - val newStorage = if (tableDefinition.provider.get == DDLUtils.HIVE_PROVIDER) { + val newStorage = if (DDLUtils.isHiveTable(tableDefinition)) { tableDefinition.storage } else { // We can't alter the table storage of data source table directly for 2 reasons: diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala index bea073fb4848a..52892f1c7270c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala @@ -64,6 +64,7 @@ private[hive] class HiveSessionState(sparkSession: SparkSession) AnalyzeCreateTable(sparkSession) :: PreprocessTableInsertion(conf) :: DataSourceAnalysis(conf) :: + new DetermineHiveSerde(conf) :: (if 
(conf.runSQLonFile) new ResolveDataSource(sparkSession) :: Nil else Nil) override val extendedCheckRules = Seq(PreWriteCheck(conf, catalog)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 773c4a39d854c..6d5cc5778aef8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -18,14 +18,73 @@ package org.apache.spark.sql.hive import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.catalog.CatalogStorageFormat import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.command.ExecutedCommandExec +import org.apache.spark.sql.execution.command.{DDLUtils, ExecutedCommandExec} import org.apache.spark.sql.execution.datasources.CreateTable import org.apache.spark.sql.hive.execution._ +import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} + + +/** + * Determine the serde/format of the Hive serde table, according to the storage properties. + */ +class DetermineHiveSerde(conf: SQLConf) extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case c @ CreateTable(t, _, query) if DDLUtils.isHiveTable(t) && t.storage.serde.isEmpty => + if (t.bucketSpec.isDefined) { + throw new AnalysisException("Creating bucketed Hive serde table is not supported yet.") + } + if (t.partitionColumnNames.nonEmpty && query.isDefined) { + val errorMessage = "A Create Table As Select (CTAS) statement is not allowed to " + + "create a partitioned table using Hive's file formats. " + + "Please use the syntax of \"CREATE TABLE tableName USING dataSource " + + "OPTIONS (...) PARTITIONED BY ...\" to create a partitioned table through a " + + "CTAS statement." + throw new AnalysisException(errorMessage) + } + + val defaultStorage = HiveSerDe.getDefaultStorage(conf) + val options = new HiveOptions(t.storage.properties) + + val fileStorage = if (options.fileFormat.isDefined) { + HiveSerDe.sourceToSerDe(options.fileFormat.get) match { + case Some(s) => + CatalogStorageFormat.empty.copy( + inputFormat = s.inputFormat, + outputFormat = s.outputFormat, + serde = s.serde) + case None => + throw new IllegalArgumentException(s"invalid fileFormat: '${options.fileFormat.get}'") + } + } else if (options.hasInputOutputFormat) { + CatalogStorageFormat.empty.copy( + inputFormat = options.inputFormat, + outputFormat = options.outputFormat) + } else { + CatalogStorageFormat.empty + } + + val rowStorage = if (options.serde.isDefined) { + CatalogStorageFormat.empty.copy(serde = options.serde) + } else { + CatalogStorageFormat.empty + } + + val storage = t.storage.copy( + inputFormat = fileStorage.inputFormat.orElse(defaultStorage.inputFormat), + outputFormat = fileStorage.outputFormat.orElse(defaultStorage.outputFormat), + serde = rowStorage.serde.orElse(fileStorage.serde).orElse(defaultStorage.serde), + properties = options.serdeProperties) + + c.copy(tableDesc = t.copy(storage = storage)) + } +} private[hive] trait HiveStrategies { // Possibly being too clever with types here... or not clever enough. 
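
For context on how the new DetermineHiveSerde rule above behaves end to end, here is a minimal sketch, assuming a Hive-enabled SparkSession; the table name demo_tbl is made up for illustration, and the expected storage values simply restate the assertions added to HiveDDLCommandSuite later in this patch.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.TableIdentifier

// Assumes a session built with Hive support; demo_tbl is a throwaway table name.
val spark = SparkSession.builder().enableHiveSupport().getOrCreate()

// New syntax: the serde/format is picked from OPTIONS instead of ROW FORMAT / STORED AS.
spark.sql(
  """
    |CREATE TABLE demo_tbl (id INT, name STRING)
    |USING hive
    |OPTIONS (fileFormat 'parquet')
  """.stripMargin)

val storage = spark.sessionState.catalog.getTableMetadata(TableIdentifier("demo_tbl")).storage
// fileFormat 'parquet' resolves via HiveSerDe.sourceToSerDe to:
//   storage.serde        == Some("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")
//   storage.inputFormat  == Some("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat")
//   storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat")
// With no OPTIONS at all, the rule falls back to the session defaults:
//   LazySimpleSerDe + TextInputFormat + HiveIgnoreKeyTextOutputFormat.

The same rule also rejects bucketed Hive serde tables ("Creating bucketed Hive serde table is not supported yet") and partitioned Hive CTAS, as exercised by the new HiveDDLCommandSuite tests below.
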
@@ -49,15 +108,7 @@ private[hive] trait HiveStrategies { InsertIntoHiveTable( table, partition, planLater(child), overwrite, ifNotExists) :: Nil - case CreateTable(tableDesc, mode, Some(query)) if tableDesc.provider.get == "hive" => - val newTableDesc = if (tableDesc.storage.serde.isEmpty) { - // add default serde - tableDesc.withNewStorage( - serde = Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) - } else { - tableDesc - } - + case CreateTable(tableDesc, mode, Some(query)) if DDLUtils.isHiveTable(tableDesc) => // Currently we will never hit this branch, as SQL string API can only use `Ignore` or // `ErrorIfExists` mode, and `DataFrameWriter.saveAsTable` doesn't support hive serde // tables yet. @@ -68,7 +119,7 @@ private[hive] trait HiveStrategies { val dbName = tableDesc.identifier.database.getOrElse(sparkSession.catalog.currentDatabase) val cmd = CreateHiveTableAsSelectCommand( - newTableDesc.copy(identifier = tableDesc.identifier.copy(database = Some(dbName))), + tableDesc.copy(identifier = tableDesc.identifier.copy(database = Some(dbName))), query, mode == SaveMode.Ignore) ExecutedCommandExec(cmd) :: Nil diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala new file mode 100644 index 0000000000000..35b7a681f12e0 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap + +/** + * Options for the Hive data source. Note that rule `DetermineHiveSerde` will extract Hive + * serde/format information from these options. 
+ */ +class HiveOptions(@transient private val parameters: CaseInsensitiveMap) extends Serializable { + import HiveOptions._ + + def this(parameters: Map[String, String]) = this(new CaseInsensitiveMap(parameters)) + + val fileFormat = parameters.get(FILE_FORMAT).map(_.toLowerCase) + val inputFormat = parameters.get(INPUT_FORMAT) + val outputFormat = parameters.get(OUTPUT_FORMAT) + + if (inputFormat.isDefined != outputFormat.isDefined) { + throw new IllegalArgumentException("Cannot specify only inputFormat or outputFormat, you " + + "have to specify both of them.") + } + + def hasInputOutputFormat: Boolean = inputFormat.isDefined + + if (fileFormat.isDefined && inputFormat.isDefined) { + throw new IllegalArgumentException("Cannot specify fileFormat and inputFormat/outputFormat " + + "together for Hive data source.") + } + + val serde = parameters.get(SERDE) + + if (fileFormat.isDefined && serde.isDefined) { + if (!Set("sequencefile", "textfile", "rcfile").contains(fileFormat.get)) { + throw new IllegalArgumentException( + s"fileFormat '${fileFormat.get}' already specifies a serde.") + } + } + + val containsDelimiters = delimiterOptions.keys.exists(parameters.contains) + + if (containsDelimiters) { + if (serde.isDefined) { + throw new IllegalArgumentException("Cannot specify delimiters with a custom serde.") + } + if (fileFormat.isEmpty) { + throw new IllegalArgumentException("Cannot specify delimiters without fileFormat.") + } + if (fileFormat.get != "textfile") { + throw new IllegalArgumentException("Cannot specify delimiters as they are only compatible " + + s"with fileFormat 'textfile', not ${fileFormat.get}.") + } + } + + for (lineDelim <- parameters.get("lineDelim") if lineDelim != "\n") { + throw new IllegalArgumentException("Hive data source only support newline '\\n' as " + + s"line delimiter, but given: $lineDelim.") + } + + def serdeProperties: Map[String, String] = parameters.filterKeys { + k => !lowerCasedOptionNames.contains(k.toLowerCase) + }.map { case (k, v) => delimiterOptions.getOrElse(k, k) -> v } +} + +object HiveOptions { + private val lowerCasedOptionNames = collection.mutable.Set[String]() + + private def newOption(name: String): String = { + lowerCasedOptionNames += name.toLowerCase + name + } + + val FILE_FORMAT = newOption("fileFormat") + val INPUT_FORMAT = newOption("inputFormat") + val OUTPUT_FORMAT = newOption("outputFormat") + val SERDE = newOption("serde") + + // A map from the public delimiter option keys to the underlying Hive serde property keys. + val delimiterOptions = Map( + "fieldDelim" -> "field.delim", + "escapeDelim" -> "escape.delim", + // The following typo is inherited from Hive... + "collectionDelim" -> "colelction.delim", + "mapkeyDelim" -> "mapkey.delim", + "lineDelim" -> "line.delim").map { case (k, v) => k.toLowerCase -> v } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala index c80695bd3e0fe..7ee5fc543c55b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala @@ -40,7 +40,7 @@ import org.apache.spark.util.Utils * The Hive table scan operator. Column and partition pruning are both handled. * * @param requestedAttributes Attributes to be fetched from the Hive table. - * @param relation The Hive table be be scanned. + * @param relation The Hive table be scanned. 
* @param partitionPruningPred An optional partition pruning predicate for partitioned table. */ private[hive] diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index 0a7631f782193..f496c01ce9ff7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -239,10 +239,7 @@ private[orc] class OrcOutputWriter( ).asInstanceOf[RecordWriter[NullWritable, Writable]] } - override def write(row: Row): Unit = - throw new UnsupportedOperationException("call writeInternal") - - override protected[sql] def writeInternal(row: InternalRow): Unit = { + override def write(row: InternalRow): Unit = { recordWriter.write(NullWritable.get(), serializer.serialize(row)) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala index 3f1f86c278db0..5a3fcd7a759c0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala @@ -27,7 +27,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.types.StructType -private[orc] object OrcFileOperator extends Logging { +private[hive] object OrcFileOperator extends Logging { /** * Retrieves an ORC file reader from a given path. The path can point to either a directory or a * single ORC file. If it points to a directory, it picks any non-empty ORC file within that diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala index d13e29b3029b1..b67e5f6fe57a1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala @@ -31,7 +31,6 @@ import org.apache.spark.sql.catalyst.plans.logical.{Generate, ScriptTransformati import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.CreateTable import org.apache.spark.sql.hive.test.{TestHive, TestHiveSingleton} -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types.StructType @@ -51,6 +50,12 @@ class HiveDDLCommandSuite extends PlanTest with SQLTestUtils with TestHiveSingle assert(e.getMessage.toLowerCase.contains("operation not allowed")) } + private def analyzeCreateTable(sql: String): CatalogTable = { + TestHive.sessionState.analyzer.execute(parser.parsePlan(sql)).collect { + case CreateTable(tableDesc, mode, _) => tableDesc + }.head + } + test("Test CTAS #1") { val s1 = """CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view @@ -76,7 +81,7 @@ class HiveDDLCommandSuite extends PlanTest with SQLTestUtils with TestHiveSingle assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) assert(desc.storage.serde == Some("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe")) - assert(desc.properties == Map(("p1", "v1"), ("p2", "v2"))) + assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) } test("Test CTAS #2") { @@ -107,7 +112,7 @@ class HiveDDLCommandSuite extends PlanTest with SQLTestUtils with TestHiveSingle assert(desc.storage.inputFormat == 
Some("parquet.hive.DeprecatedParquetInputFormat")) assert(desc.storage.outputFormat == Some("parquet.hive.DeprecatedParquetOutputFormat")) assert(desc.storage.serde == Some("parquet.hive.serde.ParquetHiveSerDe")) - assert(desc.properties == Map(("p1", "v1"), ("p2", "v2"))) + assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) } test("Test CTAS #3") { @@ -125,7 +130,7 @@ class HiveDDLCommandSuite extends PlanTest with SQLTestUtils with TestHiveSingle assert(desc.storage.inputFormat == Some("org.apache.hadoop.mapred.TextInputFormat")) assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")) - assert(desc.storage.serde.isEmpty) + assert(desc.storage.serde == Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) assert(desc.properties == Map()) } @@ -305,7 +310,7 @@ class HiveDDLCommandSuite extends PlanTest with SQLTestUtils with TestHiveSingle Some("org.apache.hadoop.mapred.TextInputFormat")) assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")) - assert(desc.storage.serde.isEmpty) + assert(desc.storage.serde == Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) assert(desc.storage.properties.isEmpty) assert(desc.properties.isEmpty) assert(desc.comment.isEmpty) @@ -412,7 +417,7 @@ class HiveDDLCommandSuite extends PlanTest with SQLTestUtils with TestHiveSingle val (desc2, _) = extractTableDesc(query2) assert(desc1.storage.inputFormat == Some("winput")) assert(desc1.storage.outputFormat == Some("wowput")) - assert(desc1.storage.serde.isEmpty) + assert(desc1.storage.serde == Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) assert(desc2.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat")) assert(desc2.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat")) assert(desc2.storage.serde == Some("org.apache.hadoop.hive.ql.io.orc.OrcSerde")) @@ -592,4 +597,94 @@ class HiveDDLCommandSuite extends PlanTest with SQLTestUtils with TestHiveSingle val hiveClient = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client assert(hiveClient.getConf("hive.in.test", "") == "true") } + + test("create hive serde table with new syntax - basic") { + val sql = + """ + |CREATE TABLE t + |(id int, name string COMMENT 'blabla') + |USING hive + |OPTIONS (fileFormat 'parquet', my_prop 1) + |LOCATION '/tmp/file' + |COMMENT 'BLABLA' + """.stripMargin + + val table = analyzeCreateTable(sql) + assert(table.schema == new StructType() + .add("id", "int") + .add("name", "string", nullable = true, comment = "blabla")) + assert(table.provider == Some(DDLUtils.HIVE_PROVIDER)) + assert(table.storage.locationUri == Some("/tmp/file")) + assert(table.storage.properties == Map("my_prop" -> "1")) + assert(table.comment == Some("BLABLA")) + + assert(table.storage.inputFormat == + Some("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat")) + assert(table.storage.outputFormat == + Some("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat")) + assert(table.storage.serde == + Some("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")) + } + + test("create hive serde table with new syntax - with partition and bucketing") { + val v1 = "CREATE TABLE t (c1 int, c2 int) USING hive PARTITIONED BY (c2)" + val table = analyzeCreateTable(v1) + assert(table.schema == new StructType().add("c1", "int").add("c2", "int")) + assert(table.partitionColumnNames == Seq("c2")) + // check the default formats + 
assert(table.storage.serde == Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) + assert(table.storage.inputFormat == Some("org.apache.hadoop.mapred.TextInputFormat")) + assert(table.storage.outputFormat == + Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")) + + val v2 = "CREATE TABLE t (c1 int, c2 int) USING hive CLUSTERED BY (c2) INTO 4 BUCKETS" + val e2 = intercept[AnalysisException](analyzeCreateTable(v2)) + assert(e2.message.contains("Creating bucketed Hive serde table is not supported yet")) + + val v3 = + """ + |CREATE TABLE t (c1 int, c2 int) USING hive + |PARTITIONED BY (c2) + |CLUSTERED BY (c2) INTO 4 BUCKETS""".stripMargin + val e3 = intercept[AnalysisException](analyzeCreateTable(v3)) + assert(e3.message.contains("Creating bucketed Hive serde table is not supported yet")) + } + + test("create hive serde table with new syntax - Hive options error checking") { + val v1 = "CREATE TABLE t (c1 int) USING hive OPTIONS (inputFormat 'abc')" + val e1 = intercept[IllegalArgumentException](analyzeCreateTable(v1)) + assert(e1.getMessage.contains("Cannot specify only inputFormat or outputFormat")) + + val v2 = "CREATE TABLE t (c1 int) USING hive OPTIONS " + + "(fileFormat 'x', inputFormat 'a', outputFormat 'b')" + val e2 = intercept[IllegalArgumentException](analyzeCreateTable(v2)) + assert(e2.getMessage.contains( + "Cannot specify fileFormat and inputFormat/outputFormat together")) + + val v3 = "CREATE TABLE t (c1 int) USING hive OPTIONS (fileFormat 'parquet', serde 'a')" + val e3 = intercept[IllegalArgumentException](analyzeCreateTable(v3)) + assert(e3.getMessage.contains("fileFormat 'parquet' already specifies a serde")) + + val v4 = "CREATE TABLE t (c1 int) USING hive OPTIONS (serde 'a', fieldDelim ' ')" + val e4 = intercept[IllegalArgumentException](analyzeCreateTable(v4)) + assert(e4.getMessage.contains("Cannot specify delimiters with a custom serde")) + + val v5 = "CREATE TABLE t (c1 int) USING hive OPTIONS (fieldDelim ' ')" + val e5 = intercept[IllegalArgumentException](analyzeCreateTable(v5)) + assert(e5.getMessage.contains("Cannot specify delimiters without fileFormat")) + + val v6 = "CREATE TABLE t (c1 int) USING hive OPTIONS (fileFormat 'parquet', fieldDelim ' ')" + val e6 = intercept[IllegalArgumentException](analyzeCreateTable(v6)) + assert(e6.getMessage.contains( + "Cannot specify delimiters as they are only compatible with fileFormat 'textfile'")) + + // The value of 'fileFormat' option is case-insensitive. 
+ val v7 = "CREATE TABLE t (c1 int) USING hive OPTIONS (fileFormat 'TEXTFILE', lineDelim ',')" + val e7 = intercept[IllegalArgumentException](analyzeCreateTable(v7)) + assert(e7.getMessage.contains("Hive data source only support newline '\\n' as line delimiter")) + + val v8 = "CREATE TABLE t (c1 int) USING hive OPTIONS (fileFormat 'wrong')" + val e8 = intercept[IllegalArgumentException](analyzeCreateTable(v8)) + assert(e8.getMessage.contains("invalid fileFormat: 'wrong'")) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala index 6fee45824ea3f..2f02bb5d3b649 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala @@ -68,6 +68,6 @@ class HiveExternalCatalogSuite extends ExternalCatalogSuite { val rawTable = externalCatalog.client.getTable("db1", "hive_tbl") assert(!rawTable.properties.contains(HiveExternalCatalog.DATASOURCE_PROVIDER)) - assert(externalCatalog.getTable("db1", "hive_tbl").provider == Some(DDLUtils.HIVE_PROVIDER)) + assert(DDLUtils.isHiveTable(externalCatalog.getTable("db1", "hive_tbl"))) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 2b8d4e2bb3bb5..aed825e2f3d26 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -1189,21 +1189,6 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv } } - test("create a data source table using hive") { - val tableName = "tab1" - withTable (tableName) { - val e = intercept[AnalysisException] { - sql( - s""" - |CREATE TABLE $tableName - |(col1 int) - |USING hive - """.stripMargin) - }.getMessage - assert(e.contains("Cannot create hive serde table with CREATE TABLE USING")) - } - } - test("create a temp view using hive") { val tableName = "tab1" withTable (tableName) { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala index f88fc4a2ce0ec..44233cfbf0f36 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.util.Utils class PartitionProviderCompatibilitySuite extends QueryTest with TestHiveSingleton with SQLTestUtils { + import testImplicits._ private def setupPartitionedDatasourceTable(tableName: String, dir: File): Unit = { spark.range(5).selectExpr("id as fieldOne", "id as partCol").write @@ -195,12 +196,30 @@ class PartitionProviderCompatibilitySuite withTempDir { dir => setupPartitionedDatasourceTable("test", dir) if (enabled) { - spark.sql("msck repair table test") + assert(spark.table("test").count() == 0) + } else { + assert(spark.table("test").count() == 5) } - assert(spark.sql("select * from test").count() == 5) - spark.range(10).selectExpr("id as fieldOne", "id as partCol") + // Table `test` has 5 partitions, from `partCol=0` to `partCol=4`, which are invisible + // because we have not run `REPAIR TABLE` yet. 
Here we add 10 more partitions from + // `partCol=3` to `partCol=12`, to test the following behaviors: + // 1. invisible partitions are still invisible if they are not overwritten. + // 2. invisible partitions become visible if they are overwritten. + // 3. newly added partitions should be visible. + spark.range(3, 13).selectExpr("id as fieldOne", "id as partCol") .write.partitionBy("partCol").mode("append").saveAsTable("test") - assert(spark.sql("select * from test").count() == 15) + + if (enabled) { + // Only the newly written partitions are visible, which means the partitions + // `partCol=0`, `partCol=1` and `partCol=2` are still invisible, so we can only see + // 5 + 10 - 3 = 12 records. + assert(spark.table("test").count() == 12) + // Repair the table to make all partitions visible. + sql("msck repair table test") + assert(spark.table("test").count() == 15) + } else { + assert(spark.table("test").count() == 15) + } } } } @@ -449,4 +468,17 @@ class PartitionProviderCompatibilitySuite assert(spark.sql("show partitions test").count() == 5) } } + + test("append data with DataFrameWriter") { + testCustomLocations { + val df = Seq((1L, 0, 0), (2L, 0, 0)).toDF("id", "P1", "P2") + df.write.partitionBy("P1", "P2").mode("append").saveAsTable("test") + assert(spark.sql("select * from test").count() == 2) + assert(spark.sql("show partitions test").count() == 4) + val df2 = Seq((3L, 2, 2)).toDF("id", "P1", "P2") + df2.write.partitionBy("P1", "P2").mode("append").saveAsTable("test") + assert(spark.sql("select * from test").count() == 3) + assert(spark.sql("show partitions test").count() == 5) + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 8b34219530259..3ac07d093381c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogTable, Cat import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.hive.HiveExternalCatalog +import org.apache.spark.sql.hive.orc.OrcFileOperator import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION @@ -1250,4 +1251,42 @@ class HiveDDLSuite assert(e.message.contains("unknown is not a valid partition column")) } } + + test("create hive serde table with new syntax") { + withTable("t", "t2", "t3") { + withTempPath { path => + sql( + s""" + |CREATE TABLE t(id int) USING hive + |OPTIONS(fileFormat 'orc', compression 'Zlib') + |LOCATION '${path.getCanonicalPath}' + """.stripMargin) + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + assert(DDLUtils.isHiveTable(table)) + assert(table.storage.serde == Some("org.apache.hadoop.hive.ql.io.orc.OrcSerde")) + assert(table.storage.properties.get("compression") == Some("Zlib")) + assert(spark.table("t").collect().isEmpty) + + sql("INSERT INTO t SELECT 1") + checkAnswer(spark.table("t"), Row(1)) + // Check if this is compressed as ZLIB. 
+ val maybeOrcFile = path.listFiles().find(_.getName.endsWith("part-00000")) + assert(maybeOrcFile.isDefined) + val orcFilePath = maybeOrcFile.get.toPath.toString + val expectedCompressionKind = + OrcFileOperator.getFileReader(orcFilePath).get.getCompression + assert("ZLIB" === expectedCompressionKind.name()) + + sql("CREATE TABLE t2 USING HIVE AS SELECT 1 AS c1, 'a' AS c2") + val table2 = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t2")) + assert(DDLUtils.isHiveTable(table2)) + assert(table2.storage.serde == Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) + checkAnswer(spark.table("t2"), Row(1, "a")) + + sql("CREATE TABLE t3(a int, p int) USING hive PARTITIONED BY (p)") + sql("INSERT INTO t3 PARTITION(p=1) SELECT 0") + checkAnswer(spark.table("t3"), Row(0, 1)) + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala index 463c368fc42b1..e678cf6f22c20 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala @@ -93,8 +93,6 @@ class OrcHadoopFsRelationSuite extends HadoopFsRelationTest { .orc(path) // Check if this is compressed as ZLIB. - val conf = spark.sessionState.newHadoopConf() - val fs = FileSystem.getLocal(conf) val maybeOrcFile = new File(path).listFiles().find(_.getName.endsWith(".zlib.orc")) assert(maybeOrcFile.isDefined) val orcFilePath = maybeOrcFile.get.toPath.toString diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestSource.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestSource.scala index abc7c8cc4db89..7501334f94dd2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestSource.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestSource.scala @@ -20,7 +20,8 @@ package org.apache.spark.sql.sources import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.TaskContext -import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory} import org.apache.spark.sql.types.StructType @@ -42,14 +43,14 @@ class CommitFailureTestSource extends SimpleTextSource { path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new SimpleTextOutputWriter(path, context) { + new SimpleTextOutputWriter(path, dataSchema, context) { var failed = false TaskContext.get().addTaskFailureListener { (t: TaskContext, e: Throwable) => failed = true SimpleTextRelation.callbackCalled = true } - override def write(row: Row): Unit = { + override def write(row: InternalRow): Unit = { if (SimpleTextRelation.failWriter) { sys.error("Intentional task writer failure for testing purpose.") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index 5fdf6152590ce..1607c97cd6acb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -50,7 +50,7 @@ class SimpleTextSource extends TextBasedFileFormat with DataSourceRegister { path: String, dataSchema: 
StructType, context: TaskAttemptContext): OutputWriter = { - new SimpleTextOutputWriter(path, context) + new SimpleTextOutputWriter(path, dataSchema, context) } override def getFileExtension(context: TaskAttemptContext): String = "" @@ -117,13 +117,13 @@ class SimpleTextSource extends TextBasedFileFormat with DataSourceRegister { } } -class SimpleTextOutputWriter(path: String, context: TaskAttemptContext) +class SimpleTextOutputWriter(path: String, dataSchema: StructType, context: TaskAttemptContext) extends OutputWriter { private val writer = CodecStreams.createOutputStreamWriter(context, new Path(path)) - override def write(row: Row): Unit = { - val serialized = row.toSeq.map { v => + override def write(row: InternalRow): Unit = { + val serialized = row.toSeq(dataSchema).map { v => if (v == null) "" else v.toString }.mkString(",") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/State.scala b/streaming/src/main/scala/org/apache/spark/streaming/State.scala index 3f560f889f055..23cf48eb06738 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/State.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/State.scala @@ -178,7 +178,7 @@ private[streaming] class StateImpl[S] extends State[S] { removed } - /** Whether the state has been been updated */ + /** Whether the state has been updated */ def isUpdated(): Boolean = { updated } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala index a3c125c306954..9a760e2947d0b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala @@ -88,7 +88,7 @@ abstract class InputDStream[T: ClassTag](_ssc: StreamingContext) if (!super.isTimeValid(time)) { false // Time not valid } else { - // Time is valid, but check it it is more than lastValidTime + // Time is valid, but check it is more than lastValidTime if (lastValidTime != null && time < lastValidTime) { logWarning(s"isTimeValid called with $time whereas the last valid time " + s"is $lastValidTime") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala index e8c814ba7184b..9b6bc71c7a5b5 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala @@ -326,7 +326,7 @@ class MapWithStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with B // Create a MapWithStateRDD that has a long lineage using the data RDD with a long lineage val stateRDDWithLongLineage = makeStateRDDWithLongLineageDataRDD(longLineageRDD) - // Create a new MapWithStateRDD, with the lineage lineage MapWithStateRDD as the parent + // Create a new MapWithStateRDD, with the lineage MapWithStateRDD as the parent new MapWithStateRDD[Int, Int, Int, Int]( stateRDDWithLongLineage, stateRDDWithLongLineage.sparkContext.emptyRDD[(Int, Int)].partitionBy(partitioner), diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala index a37fac87300b7..c5e695a33ae10 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala +++ 
b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala @@ -108,7 +108,7 @@ class WriteAheadLogBackedBlockRDDSuite /** * Test the WriteAheadLogBackedRDD, by writing some partitions of the data to block manager - * and the rest to a write ahead log, and then reading reading it all back using the RDD. + * and the rest to a write ahead log, and then reading it all back using the RDD. * It can also test if the partitions that were read from the log were again stored in * block manager. * diff --git a/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala index a1d0561bf308a..b70383ecde4d8 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala @@ -90,7 +90,7 @@ class BlockGeneratorSuite extends SparkFunSuite with BeforeAndAfter { listener.pushedData.asScala.toSeq should contain theSameElementsInOrderAs (data1) assert(listener.onAddDataCalled === false) // should be called only with addDataWithCallback() - // Verify addDataWithCallback() add data+metadata and and callbacks are called correctly + // Verify addDataWithCallback() add data+metadata and callbacks are called correctly val data2 = 11 to 20 val metadata2 = data2.map { _.toString } data2.zip(metadata2).foreach { case (d, m) => blockGenerator.addDataWithCallback(d, m) } @@ -103,7 +103,7 @@ class BlockGeneratorSuite extends SparkFunSuite with BeforeAndAfter { listener.pushedData.asScala.toSeq should contain theSameElementsInOrderAs combined } - // Verify addMultipleDataWithCallback() add data+metadata and and callbacks are called correctly + // Verify addMultipleDataWithCallback() add data+metadata and callbacks are called correctly val data3 = 21 to 30 val metadata3 = "metadata" blockGenerator.addMultipleDataWithCallback(data3.iterator, metadata3)
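
As a footnote to the OutputWriter change earlier in this patch (write(Row) is replaced by write(InternalRow), so writers now need the data schema to turn internal values back into printable form), here is a minimal sketch of a writer under the new signature, modeled on the SimpleTextOutputWriter test source above; the class name CsvLikeOutputWriter and the one-line-per-row format are illustrative only.

import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.TaskAttemptContext

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.{CodecStreams, OutputWriter}
import org.apache.spark.sql.types.StructType

// Writes each InternalRow as a comma-separated line. The schema is passed in so that
// row.toSeq(dataSchema) can decode the internal representation of every column.
class CsvLikeOutputWriter(path: String, dataSchema: StructType, context: TaskAttemptContext)
  extends OutputWriter {

  private val writer = CodecStreams.createOutputStreamWriter(context, new Path(path))

  override def write(row: InternalRow): Unit = {
    val line = row.toSeq(dataSchema).map(v => if (v == null) "" else v.toString).mkString(",")
    writer.write(line)
    writer.write("\n")
  }

  override def close(): Unit = writer.close()
}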