diff --git a/.gitignore b/.gitignore index d162fa9cca994..d54d21b802be8 100644 --- a/.gitignore +++ b/.gitignore @@ -63,6 +63,8 @@ ec2/lib/ rat-results.txt scalastyle.txt scalastyle-output.xml +R-unit-tests.log +R/unit-tests.out # For Hive metastore_db/ diff --git a/.rat-excludes b/.rat-excludes index 8c0b328b8dd85..a314e7126ae5b 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -73,3 +73,5 @@ local-1422981780767/* local-1425081759269/* local-1426533911241/* local-1427397477963/* +DESCRIPTION +NAMESPACE diff --git a/R/.gitignore b/R/.gitignore new file mode 100644 index 0000000000000..9a5889ba28b2a --- /dev/null +++ b/R/.gitignore @@ -0,0 +1,6 @@ +*.o +*.so +*.Rd +lib +pkg/man +pkg/html diff --git a/R/DOCUMENTATION.md b/R/DOCUMENTATION.md new file mode 100644 index 0000000000000..931d01549b265 --- /dev/null +++ b/R/DOCUMENTATION.md @@ -0,0 +1,12 @@ +# SparkR Documentation + +SparkR documentation is generated using in-source comments annotated using using +`roxygen2`. After making changes to the documentation, to generate man pages, +you can run the following from an R console in the SparkR home directory + + library(devtools) + devtools::document(pkg="./pkg", roclets=c("rd")) + +You can verify if your changes are good by running + + R CMD check pkg/ diff --git a/R/README.md b/R/README.md new file mode 100644 index 0000000000000..a6970e39b55f3 --- /dev/null +++ b/R/README.md @@ -0,0 +1,67 @@ +# R on Spark + +SparkR is an R package that provides a light-weight frontend to use Spark from R. + +### SparkR development + +#### Build Spark + +Build Spark with [Maven](http://spark.apache.org/docs/latest/building-spark.html#building-with-buildmvn) and include the `-PsparkR` profile to build the R package. For example to use the default Hadoop versions you can run +``` + build/mvn -DskipTests -Psparkr package +``` + +#### Running sparkR + +You can start using SparkR by launching the SparkR shell with + + ./bin/sparkR + +The `sparkR` script automatically creates a SparkContext with Spark by default in +local mode. To specify the Spark master of a cluster for the automatically created +SparkContext, you can run + + ./bin/sparkR --master "local[2]" + +To set other options like driver memory, executor memory etc. you can pass in the [spark-submit](http://spark.apache.org/docs/latest/submitting-applications.html) arguments to `./bin/sparkR` + +#### Using SparkR from RStudio + +If you wish to use SparkR from RStudio or other R frontends you will need to set some environment variables which point SparkR to your Spark installation. For example +``` +# Set this to where Spark is installed +Sys.setenv(SPARK_HOME="/Users/shivaram/spark") +# This line loads SparkR from the installed directory +.libPaths(c(file.path(Sys.getenv("SPARK_HOME"), "R", "lib"), .libPaths())) +library(SparkR) +sc <- sparkR.init(master="local") +``` + +#### Making changes to SparkR + +The [instructions](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark) for making contributions to Spark also apply to SparkR. +If you only make R file changes (i.e. no Scala changes) then you can just re-install the R package using `R/install-dev.sh` and test your changes. +Once you have made your changes, please include unit tests for them and run existing unit tests using the `run-tests.sh` script as described below. + +#### Generating documentation + +The SparkR documentation (Rd files and HTML files) are not a part of the source repository. To generate them you can run the script `R/create-docs.sh`. This script uses `devtools` and `knitr` to generate the docs and these packages need to be installed on the machine before using the script. + +### Examples, Unit tests + +SparkR comes with several sample programs in the `examples/src/main/r` directory. +To run one of them, use `./bin/sparkR `. For example: + + ./bin/sparkR examples/src/main/r/pi.R local[2] + +You can also run the unit-tests for SparkR by running (you need to install the [testthat](http://cran.r-project.org/web/packages/testthat/index.html) package first): + + R -e 'install.packages("testthat", repos="http://cran.us.r-project.org")' + ./R/run-tests.sh + +### Running on YARN +The `./bin/spark-submit` and `./bin/sparkR` can also be used to submit jobs to YARN clusters. You will need to set YARN conf dir before doing so. For example on CDH you can run +``` +export YARN_CONF_DIR=/etc/hadoop/conf +./bin/spark-submit --master yarn examples/src/main/r/pi.R 4 +``` diff --git a/R/WINDOWS.md b/R/WINDOWS.md new file mode 100644 index 0000000000000..3f889c0ca3d1e --- /dev/null +++ b/R/WINDOWS.md @@ -0,0 +1,13 @@ +## Building SparkR on Windows + +To build SparkR on Windows, the following steps are required + +1. Install R (>= 3.1) and [Rtools](http://cran.r-project.org/bin/windows/Rtools/). Make sure to +include Rtools and R in `PATH`. +2. Install +[JDK7](http://www.oracle.com/technetwork/java/javase/downloads/jdk7-downloads-1880260.html) and set +`JAVA_HOME` in the system environment variables. +3. Download and install [Maven](http://maven.apache.org/download.html). Also include the `bin` +directory in Maven in `PATH`. +4. Set `MAVEN_OPTS` as described in [Building Spark](http://spark.apache.org/docs/latest/building-spark.html). +5. Open a command shell (`cmd`) in the Spark directory and run `mvn -DskipTests -Psparkr package` diff --git a/R/create-docs.sh b/R/create-docs.sh new file mode 100755 index 0000000000000..4194172a2e115 --- /dev/null +++ b/R/create-docs.sh @@ -0,0 +1,46 @@ +#!/bin/bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Script to create API docs for SparkR +# This requires `devtools` and `knitr` to be installed on the machine. + +# After running this script the html docs can be found in +# $SPARK_HOME/R/pkg/html + +# Figure out where the script is +export FWDIR="$(cd "`dirname "$0"`"; pwd)" +pushd $FWDIR + +# Generate Rd file +Rscript -e 'library(devtools); devtools::document(pkg="./pkg", roclets=c("rd"))' + +# Install the package +./install-dev.sh + +# Now create HTML files + +# knit_rd puts html in current working directory +mkdir -p pkg/html +pushd pkg/html + +Rscript -e 'library(SparkR, lib.loc="../../lib"); library(knitr); knit_rd("SparkR")' + +popd + +popd diff --git a/R/install-dev.bat b/R/install-dev.bat new file mode 100644 index 0000000000000..008a5c668bc45 --- /dev/null +++ b/R/install-dev.bat @@ -0,0 +1,27 @@ +@echo off + +rem +rem Licensed to the Apache Software Foundation (ASF) under one or more +rem contributor license agreements. See the NOTICE file distributed with +rem this work for additional information regarding copyright ownership. +rem The ASF licenses this file to You under the Apache License, Version 2.0 +rem (the "License"); you may not use this file except in compliance with +rem the License. You may obtain a copy of the License at +rem +rem http://www.apache.org/licenses/LICENSE-2.0 +rem +rem Unless required by applicable law or agreed to in writing, software +rem distributed under the License is distributed on an "AS IS" BASIS, +rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +rem See the License for the specific language governing permissions and +rem limitations under the License. +rem + +rem Install development version of SparkR +rem + +set SPARK_HOME=%~dp0.. + +MKDIR %SPARK_HOME%\R\lib + +R.exe CMD INSTALL --library="%SPARK_HOME%\R\lib" %SPARK_HOME%\R\pkg\ diff --git a/R/install-dev.sh b/R/install-dev.sh new file mode 100755 index 0000000000000..55ed6f4be1a4a --- /dev/null +++ b/R/install-dev.sh @@ -0,0 +1,36 @@ +#!/bin/bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# This scripts packages the SparkR source files (R and C files) and +# creates a package that can be loaded in R. The package is by default installed to +# $FWDIR/lib and the package can be loaded by using the following command in R: +# +# library(SparkR, lib.loc="$FWDIR/lib") +# +# NOTE(shivaram): Right now we use $SPARK_HOME/R/lib to be the installation directory +# to load the SparkR package on the worker nodes. + + +FWDIR="$(cd `dirname $0`; pwd)" +LIB_DIR="$FWDIR/lib" + +mkdir -p $LIB_DIR + +# Install R +R CMD INSTALL --library=$LIB_DIR $FWDIR/pkg/ diff --git a/R/log4j.properties b/R/log4j.properties new file mode 100644 index 0000000000000..701adb2a3da1d --- /dev/null +++ b/R/log4j.properties @@ -0,0 +1,28 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Set everything to be logged to the file target/unit-tests.log +log4j.rootCategory=INFO, file +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=true +log4j.appender.file.file=R-unit-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n + +# Ignore messages below warning level from Jetty, because it's a bit verbose +log4j.logger.org.eclipse.jetty=WARN +org.eclipse.jetty.LEVEL=WARN diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION new file mode 100644 index 0000000000000..1842b97d43651 --- /dev/null +++ b/R/pkg/DESCRIPTION @@ -0,0 +1,35 @@ +Package: SparkR +Type: Package +Title: R frontend for Spark +Version: 1.4.0 +Date: 2013-09-09 +Author: The Apache Software Foundation +Maintainer: Shivaram Venkataraman +Imports: + methods +Depends: + R (>= 3.0), + methods, +Suggests: + testthat +Description: R frontend for Spark +License: Apache License (== 2.0) +Collate: + 'generics.R' + 'jobj.R' + 'SQLTypes.R' + 'RDD.R' + 'pairRDD.R' + 'column.R' + 'group.R' + 'DataFrame.R' + 'SQLContext.R' + 'broadcast.R' + 'context.R' + 'deserialize.R' + 'serialize.R' + 'sparkR.R' + 'backend.R' + 'client.R' + 'utils.R' + 'zzz.R' diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE new file mode 100644 index 0000000000000..a354cdce74afa --- /dev/null +++ b/R/pkg/NAMESPACE @@ -0,0 +1,182 @@ +#exportPattern("^[[:alpha:]]+") +exportClasses("RDD") +exportClasses("Broadcast") +exportMethods( + "aggregateByKey", + "aggregateRDD", + "cache", + "checkpoint", + "coalesce", + "cogroup", + "collect", + "collectAsMap", + "collectPartition", + "combineByKey", + "count", + "countByKey", + "countByValue", + "distinct", + "Filter", + "filterRDD", + "first", + "flatMap", + "flatMapValues", + "fold", + "foldByKey", + "foreach", + "foreachPartition", + "fullOuterJoin", + "glom", + "groupByKey", + "join", + "keyBy", + "keys", + "length", + "lapply", + "lapplyPartition", + "lapplyPartitionsWithIndex", + "leftOuterJoin", + "lookup", + "map", + "mapPartitions", + "mapPartitionsWithIndex", + "mapValues", + "maximum", + "minimum", + "numPartitions", + "partitionBy", + "persist", + "pipeRDD", + "reduce", + "reduceByKey", + "reduceByKeyLocally", + "repartition", + "rightOuterJoin", + "sampleRDD", + "saveAsTextFile", + "saveAsObjectFile", + "sortBy", + "sortByKey", + "sumRDD", + "take", + "takeOrdered", + "takeSample", + "top", + "unionRDD", + "unpersist", + "value", + "values", + "zipRDD", + "zipWithIndex", + "zipWithUniqueId" + ) + +# S3 methods exported +export( + "textFile", + "objectFile", + "parallelize", + "hashCode", + "includePackage", + "broadcast", + "setBroadcastValue", + "setCheckpointDir" + ) +export("sparkR.init") +export("sparkR.stop") +export("print.jobj") +useDynLib(SparkR, stringHashCode) +importFrom(methods, setGeneric, setMethod, setOldClass) + +# SparkRSQL + +exportClasses("DataFrame") + +exportMethods("columns", + "distinct", + "dtypes", + "explain", + "filter", + "groupBy", + "head", + "insertInto", + "intersect", + "isLocal", + "limit", + "orderBy", + "names", + "printSchema", + "registerTempTable", + "repartition", + "sampleDF", + "saveAsParquetFile", + "saveAsTable", + "saveDF", + "schema", + "select", + "selectExpr", + "show", + "showDF", + "sortDF", + "subtract", + "toJSON", + "toRDD", + "unionAll", + "where", + "withColumn", + "withColumnRenamed") + +exportClasses("Column") + +exportMethods("abs", + "alias", + "approxCountDistinct", + "asc", + "avg", + "cast", + "contains", + "countDistinct", + "desc", + "endsWith", + "getField", + "getItem", + "isNotNull", + "isNull", + "last", + "like", + "lower", + "max", + "mean", + "min", + "rlike", + "sqrt", + "startsWith", + "substr", + "sum", + "sumDistinct", + "upper") + +exportClasses("GroupedData") +exportMethods("agg") + +export("sparkRSQL.init", + "sparkRHive.init") + +export("cacheTable", + "clearCache", + "createDataFrame", + "createExternalTable", + "dropTempTable", + "jsonFile", + "jsonRDD", + "loadDF", + "parquetFile", + "sql", + "table", + "tableNames", + "tables", + "toDF", + "uncacheTable") + +export("print.structType", + "print.structField") diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R new file mode 100644 index 0000000000000..feafd56909a67 --- /dev/null +++ b/R/pkg/R/DataFrame.R @@ -0,0 +1,1270 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# DataFrame.R - DataFrame class and methods implemented in S4 OO classes + +#' @include jobj.R SQLTypes.R RDD.R pairRDD.R column.R group.R +NULL + +setOldClass("jobj") + +#' @title S4 class that represents a DataFrame +#' @description DataFrames can be created using functions like +#' \code{jsonFile}, \code{table} etc. +#' @rdname DataFrame +#' @seealso jsonFile, table +#' +#' @param env An R environment that stores bookkeeping states of the DataFrame +#' @param sdf A Java object reference to the backing Scala DataFrame +#' @export +setClass("DataFrame", + slots = list(env = "environment", + sdf = "jobj")) + +setMethod("initialize", "DataFrame", function(.Object, sdf, isCached) { + .Object@env <- new.env() + .Object@env$isCached <- isCached + + .Object@sdf <- sdf + .Object +}) + +#' @rdname DataFrame +#' @export +dataFrame <- function(sdf, isCached = FALSE) { + new("DataFrame", sdf, isCached) +} + +############################ DataFrame Methods ############################################## + +#' Print Schema of a DataFrame +#' +#' Prints out the schema in tree format +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname printSchema +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' printSchema(df) +#'} +setMethod("printSchema", + signature(x = "DataFrame"), + function(x) { + schemaString <- callJMethod(schema(x)$jobj, "treeString") + cat(schemaString) + }) + +#' Get schema object +#' +#' Returns the schema of this DataFrame as a structType object. +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname schema +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' dfSchema <- schema(df) +#'} +setMethod("schema", + signature(x = "DataFrame"), + function(x) { + structType(callJMethod(x@sdf, "schema")) + }) + +#' Explain +#' +#' Print the logical and physical Catalyst plans to the console for debugging. +#' +#' @param x A SparkSQL DataFrame +#' @param extended Logical. If extended is False, explain() only prints the physical plan. +#' @rdname explain +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' explain(df, TRUE) +#'} +setMethod("explain", + signature(x = "DataFrame"), + function(x, extended = FALSE) { + queryExec <- callJMethod(x@sdf, "queryExecution") + if (extended) { + cat(callJMethod(queryExec, "toString")) + } else { + execPlan <- callJMethod(queryExec, "executedPlan") + cat(callJMethod(execPlan, "toString")) + } + }) + +#' isLocal +#' +#' Returns True if the `collect` and `take` methods can be run locally +#' (without any Spark executors). +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname isLocal +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' isLocal(df) +#'} +setMethod("isLocal", + signature(x = "DataFrame"), + function(x) { + callJMethod(x@sdf, "isLocal") + }) + +#' ShowDF +#' +#' Print the first numRows rows of a DataFrame +#' +#' @param x A SparkSQL DataFrame +#' @param numRows The number of rows to print. Defaults to 20. +#' +#' @rdname showDF +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' showDF(df) +#'} +setMethod("showDF", + signature(x = "DataFrame"), + function(x, numRows = 20) { + cat(callJMethod(x@sdf, "showString", numToInt(numRows)), "\n") + }) + +#' show +#' +#' Print the DataFrame column names and types +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname show +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' show(df) +#'} +setMethod("show", "DataFrame", + function(object) { + cols <- lapply(dtypes(object), function(l) { + paste(l, collapse = ":") + }) + s <- paste(cols, collapse = ", ") + cat(paste("DataFrame[", s, "]\n", sep = "")) + }) + +#' DataTypes +#' +#' Return all column names and their data types as a list +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname dtypes +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' dtypes(df) +#'} +setMethod("dtypes", + signature(x = "DataFrame"), + function(x) { + lapply(schema(x)$fields(), function(f) { + c(f$name(), f$dataType.simpleString()) + }) + }) + +#' Column names +#' +#' Return all column names as a list +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname columns +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' columns(df) +#'} +setMethod("columns", + signature(x = "DataFrame"), + function(x) { + sapply(schema(x)$fields(), function(f) { + f$name() + }) + }) + +#' @rdname columns +#' @export +setMethod("names", + signature(x = "DataFrame"), + function(x) { + columns(x) + }) + +#' Register Temporary Table +#' +#' Registers a DataFrame as a Temporary Table in the SQLContext +#' +#' @param x A SparkSQL DataFrame +#' @param tableName A character vector containing the name of the table +#' +#' @rdname registerTempTable +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' registerTempTable(df, "json_df") +#' new_df <- sql(sqlCtx, "SELECT * FROM json_df") +#'} +setMethod("registerTempTable", + signature(x = "DataFrame", tableName = "character"), + function(x, tableName) { + callJMethod(x@sdf, "registerTempTable", tableName) + }) + +#' insertInto +#' +#' Insert the contents of a DataFrame into a table registered in the current SQL Context. +#' +#' @param x A SparkSQL DataFrame +#' @param tableName A character vector containing the name of the table +#' @param overwrite A logical argument indicating whether or not to overwrite +#' the existing rows in the table. +#' +#' @rdname insertInto +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' df <- loadDF(sqlCtx, path, "parquet") +#' df2 <- loadDF(sqlCtx, path2, "parquet") +#' registerTempTable(df, "table1") +#' insertInto(df2, "table1", overwrite = TRUE) +#'} +setMethod("insertInto", + signature(x = "DataFrame", tableName = "character"), + function(x, tableName, overwrite = FALSE) { + callJMethod(x@sdf, "insertInto", tableName, overwrite) + }) + +#' Cache +#' +#' Persist with the default storage level (MEMORY_ONLY). +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname cache-methods +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' cache(df) +#'} +setMethod("cache", + signature(x = "DataFrame"), + function(x) { + cached <- callJMethod(x@sdf, "cache") + x@env$isCached <- TRUE + x + }) + +#' Persist +#' +#' Persist this DataFrame with the specified storage level. For details of the +#' supported storage levels, refer to +#' http://spark.apache.org/docs/latest/programming-guide.html#rdd-persistence. +#' +#' @param x The DataFrame to persist +#' @rdname persist +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' persist(df, "MEMORY_AND_DISK") +#'} +setMethod("persist", + signature(x = "DataFrame", newLevel = "character"), + function(x, newLevel) { + callJMethod(x@sdf, "persist", getStorageLevel(newLevel)) + x@env$isCached <- TRUE + x + }) + +#' Unpersist +#' +#' Mark this DataFrame as non-persistent, and remove all blocks for it from memory and +#' disk. +#' +#' @param x The DataFrame to unpersist +#' @param blocking Whether to block until all blocks are deleted +#' @rdname unpersist-methods +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' persist(df, "MEMORY_AND_DISK") +#' unpersist(df) +#'} +setMethod("unpersist", + signature(x = "DataFrame"), + function(x, blocking = TRUE) { + callJMethod(x@sdf, "unpersist", blocking) + x@env$isCached <- FALSE + x + }) + +#' Repartition +#' +#' Return a new DataFrame that has exactly numPartitions partitions. +#' +#' @param x A SparkSQL DataFrame +#' @param numPartitions The number of partitions to use. +#' @rdname repartition +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' newDF <- repartition(df, 2L) +#'} +setMethod("repartition", + signature(x = "DataFrame", numPartitions = "numeric"), + function(x, numPartitions) { + sdf <- callJMethod(x@sdf, "repartition", numToInt(numPartitions)) + dataFrame(sdf) + }) + +#' toJSON +#' +#' Convert the rows of a DataFrame into JSON objects and return an RDD where +#' each element contains a JSON string. +#' +#' @param x A SparkSQL DataFrame +#' @return A StringRRDD of JSON objects +#' @rdname tojson +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' newRDD <- toJSON(df) +#'} +setMethod("toJSON", + signature(x = "DataFrame"), + function(x) { + rdd <- callJMethod(x@sdf, "toJSON") + jrdd <- callJMethod(rdd, "toJavaRDD") + RDD(jrdd, serializedMode = "string") + }) + +#' saveAsParquetFile +#' +#' Save the contents of a DataFrame as a Parquet file, preserving the schema. Files written out +#' with this method can be read back in as a DataFrame using parquetFile(). +#' +#' @param x A SparkSQL DataFrame +#' @param path The directory where the file is saved +#' @rdname saveAsParquetFile +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' saveAsParquetFile(df, "/tmp/sparkr-tmp/") +#'} +setMethod("saveAsParquetFile", + signature(x = "DataFrame", path = "character"), + function(x, path) { + invisible(callJMethod(x@sdf, "saveAsParquetFile", path)) + }) + +#' Distinct +#' +#' Return a new DataFrame containing the distinct rows in this DataFrame. +#' +#' @param x A SparkSQL DataFrame +#' @rdname distinct +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' distinctDF <- distinct(df) +#'} +setMethod("distinct", + signature(x = "DataFrame"), + function(x) { + sdf <- callJMethod(x@sdf, "distinct") + dataFrame(sdf) + }) + +#' SampleDF +#' +#' Return a sampled subset of this DataFrame using a random seed. +#' +#' @param x A SparkSQL DataFrame +#' @param withReplacement Sampling with replacement or not +#' @param fraction The (rough) sample target fraction +#' @rdname sampleDF +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' collect(sampleDF(df, FALSE, 0.5)) +#' collect(sampleDF(df, TRUE, 0.5)) +#'} +setMethod("sampleDF", + # TODO : Figure out how to send integer as java.lang.Long to JVM so + # we can send seed as an argument through callJMethod + signature(x = "DataFrame", withReplacement = "logical", + fraction = "numeric"), + function(x, withReplacement, fraction) { + if (fraction < 0.0) stop(cat("Negative fraction value:", fraction)) + sdf <- callJMethod(x@sdf, "sample", withReplacement, fraction) + dataFrame(sdf) + }) + +#' Count +#' +#' Returns the number of rows in a DataFrame +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname count +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' count(df) +#' } +setMethod("count", + signature(x = "DataFrame"), + function(x) { + callJMethod(x@sdf, "count") + }) + +#' Collects all the elements of a Spark DataFrame and coerces them into an R data.frame. +#' +#' @param x A SparkSQL DataFrame +#' @param stringsAsFactors (Optional) A logical indicating whether or not string columns +#' should be converted to factors. FALSE by default. + +#' @rdname collect-methods +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' collected <- collect(df) +#' firstName <- collected[[1]]$name +#' } +setMethod("collect", + signature(x = "DataFrame"), + function(x, stringsAsFactors = FALSE) { + # listCols is a list of raw vectors, one per column + listCols <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "dfToCols", x@sdf) + cols <- lapply(listCols, function(col) { + objRaw <- rawConnection(col) + numRows <- readInt(objRaw) + col <- readCol(objRaw, numRows) + close(objRaw) + col + }) + names(cols) <- columns(x) + do.call(cbind.data.frame, list(cols, stringsAsFactors = stringsAsFactors)) + }) + +#' Limit +#' +#' Limit the resulting DataFrame to the number of rows specified. +#' +#' @param x A SparkSQL DataFrame +#' @param num The number of rows to return +#' @return A new DataFrame containing the number of rows specified. +#' +#' @rdname limit +#' @export +#' @examples +#' \dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' limitedDF <- limit(df, 10) +#' } +setMethod("limit", + signature(x = "DataFrame", num = "numeric"), + function(x, num) { + res <- callJMethod(x@sdf, "limit", as.integer(num)) + dataFrame(res) + }) + +# Take the first NUM rows of a DataFrame and return a the results as a data.frame + +#' @rdname take +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' take(df, 2) +#' } +setMethod("take", + signature(x = "DataFrame", num = "numeric"), + function(x, num) { + limited <- limit(x, num) + collect(limited) + }) + +#' Head +#' +#' Return the first NUM rows of a DataFrame as a data.frame. If NUM is NULL, +#' then head() returns the first 6 rows in keeping with the current data.frame +#' convention in R. +#' +#' @param x A SparkSQL DataFrame +#' @param num The number of rows to return. Default is 6. +#' @return A data.frame +#' +#' @rdname head +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' head(df) +#' } +setMethod("head", + signature(x = "DataFrame"), + function(x, num = 6L) { + # Default num is 6L in keeping with R's data.frame convention + take(x, num) + }) + +#' Return the first row of a DataFrame +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname first +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' first(df) +#' } +setMethod("first", + signature(x = "DataFrame"), + function(x) { + take(x, 1) + }) + +#' toRDD() +#' +#' Converts a Spark DataFrame to an RDD while preserving column names. +#' +#' @param x A Spark DataFrame +#' +#' @rdname DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' rdd <- toRDD(df) +#' } +setMethod("toRDD", + signature(x = "DataFrame"), + function(x) { + jrdd <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "dfToRowRDD", x@sdf) + colNames <- callJMethod(x@sdf, "columns") + rdd <- RDD(jrdd, serializedMode = "row") + lapply(rdd, function(row) { + names(row) <- colNames + row + }) + }) + +#' GroupBy +#' +#' Groups the DataFrame using the specified columns, so we can run aggregation on them. +#' +#' @param x a DataFrame +#' @return a GroupedData +#' @seealso GroupedData +#' @rdname DataFrame +#' @export +#' @examples +#' \dontrun{ +#' # Compute the average for all numeric columns grouped by department. +#' avg(groupBy(df, "department")) +#' +#' # Compute the max age and average salary, grouped by department and gender. +#' agg(groupBy(df, "department", "gender"), salary="avg", "age" -> "max") +#' } +setMethod("groupBy", + signature(x = "DataFrame"), + function(x, ...) { + cols <- list(...) + if (length(cols) >= 1 && class(cols[[1]]) == "character") { + sgd <- callJMethod(x@sdf, "groupBy", cols[[1]], listToSeq(cols[-1])) + } else { + jcol <- lapply(cols, function(c) { c@jc }) + sgd <- callJMethod(x@sdf, "groupBy", listToSeq(jcol)) + } + groupedData(sgd) + }) + +#' Agg +#' +#' Compute aggregates by specifying a list of columns +#' +#' @rdname DataFrame +#' @export +setMethod("agg", + signature(x = "DataFrame"), + function(x, ...) { + agg(groupBy(x), ...) + }) + + +############################## RDD Map Functions ################################## +# All of the following functions mirror the existing RDD map functions, # +# but allow for use with DataFrames by first converting to an RRDD before calling # +# the requested map function. # +################################################################################### + +#' @rdname lapply +setMethod("lapply", + signature(X = "DataFrame", FUN = "function"), + function(X, FUN) { + rdd <- toRDD(X) + lapply(rdd, FUN) + }) + +#' @rdname lapply +setMethod("map", + signature(X = "DataFrame", FUN = "function"), + function(X, FUN) { + lapply(X, FUN) + }) + +#' @rdname flatMap +setMethod("flatMap", + signature(X = "DataFrame", FUN = "function"), + function(X, FUN) { + rdd <- toRDD(X) + flatMap(rdd, FUN) + }) + +#' @rdname lapplyPartition +setMethod("lapplyPartition", + signature(X = "DataFrame", FUN = "function"), + function(X, FUN) { + rdd <- toRDD(X) + lapplyPartition(rdd, FUN) + }) + +#' @rdname lapplyPartition +setMethod("mapPartitions", + signature(X = "DataFrame", FUN = "function"), + function(X, FUN) { + lapplyPartition(X, FUN) + }) + +#' @rdname foreach +setMethod("foreach", + signature(x = "DataFrame", func = "function"), + function(x, func) { + rdd <- toRDD(x) + foreach(rdd, func) + }) + +#' @rdname foreach +setMethod("foreachPartition", + signature(x = "DataFrame", func = "function"), + function(x, func) { + rdd <- toRDD(x) + foreachPartition(rdd, func) + }) + + +############################## SELECT ################################## + +getColumn <- function(x, c) { + column(callJMethod(x@sdf, "col", c)) +} + +#' @rdname select +setMethod("$", signature(x = "DataFrame"), + function(x, name) { + getColumn(x, name) + }) + +setMethod("$<-", signature(x = "DataFrame"), + function(x, name, value) { + stopifnot(class(value) == "Column") + cols <- columns(x) + if (name %in% cols) { + cols <- lapply(cols, function(c) { + if (c == name) { + alias(value, name) + } else { + col(c) + } + }) + nx <- select(x, cols) + } else { + nx <- withColumn(x, name, value) + } + x@sdf <- nx@sdf + x + }) + +#' @rdname select +setMethod("[[", signature(x = "DataFrame"), + function(x, i) { + if (is.numeric(i)) { + cols <- columns(x) + i <- cols[[i]] + } + getColumn(x, i) + }) + +#' @rdname select +setMethod("[", signature(x = "DataFrame", i = "missing"), + function(x, i, j, ...) { + if (is.numeric(j)) { + cols <- columns(x) + j <- cols[j] + } + if (length(j) > 1) { + j <- as.list(j) + } + select(x, j) + }) + +#' Select +#' +#' Selects a set of columns with names or Column expressions. +#' @param x A DataFrame +#' @param col A list of columns or single Column or name +#' @return A new DataFrame with selected columns +#' @export +#' @rdname select +#' @examples +#' \dontrun{ +#' select(df, "*") +#' select(df, "col1", "col2") +#' select(df, df$name, df$age + 1) +#' select(df, c("col1", "col2")) +#' select(df, list(df$name, df$age + 1)) +#' # Columns can also be selected using `[[` and `[` +#' df[[2]] == df[["age"]] +#' df[,2] == df[,"age"] +#' # Similar to R data frames columns can also be selected using `$` +#' df$age +#' } +setMethod("select", signature(x = "DataFrame", col = "character"), + function(x, col, ...) { + sdf <- callJMethod(x@sdf, "select", col, toSeq(...)) + dataFrame(sdf) + }) + +#' @rdname select +#' @export +setMethod("select", signature(x = "DataFrame", col = "Column"), + function(x, col, ...) { + jcols <- lapply(list(col, ...), function(c) { + c@jc + }) + sdf <- callJMethod(x@sdf, "select", listToSeq(jcols)) + dataFrame(sdf) + }) + +#' @rdname select +#' @export +setMethod("select", + signature(x = "DataFrame", col = "list"), + function(x, col) { + cols <- lapply(col, function(c) { + if (class(c)== "Column") { + c@jc + } else { + col(c)@jc + } + }) + sdf <- callJMethod(x@sdf, "select", listToSeq(cols)) + dataFrame(sdf) + }) + +#' SelectExpr +#' +#' Select from a DataFrame using a set of SQL expressions. +#' +#' @param x A DataFrame to be selected from. +#' @param expr A string containing a SQL expression +#' @param ... Additional expressions +#' @return A DataFrame +#' @rdname selectExpr +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' selectExpr(df, "col1", "(col2 * 5) as newCol") +#' } +setMethod("selectExpr", + signature(x = "DataFrame", expr = "character"), + function(x, expr, ...) { + exprList <- list(expr, ...) + sdf <- callJMethod(x@sdf, "selectExpr", listToSeq(exprList)) + dataFrame(sdf) + }) + +#' WithColumn +#' +#' Return a new DataFrame with the specified column added. +#' +#' @param x A DataFrame +#' @param colName A string containing the name of the new column. +#' @param col A Column expression. +#' @return A DataFrame with the new column added. +#' @rdname withColumn +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' newDF <- withColumn(df, "newCol", df$col1 * 5) +#' } +setMethod("withColumn", + signature(x = "DataFrame", colName = "character", col = "Column"), + function(x, colName, col) { + select(x, x$"*", alias(col, colName)) + }) + +#' WithColumnRenamed +#' +#' Rename an existing column in a DataFrame. +#' +#' @param x A DataFrame +#' @param existingCol The name of the column you want to change. +#' @param newCol The new column name. +#' @return A DataFrame with the column name changed. +#' @rdname withColumnRenamed +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' newDF <- withColumnRenamed(df, "col1", "newCol1") +#' } +setMethod("withColumnRenamed", + signature(x = "DataFrame", existingCol = "character", newCol = "character"), + function(x, existingCol, newCol) { + cols <- lapply(columns(x), function(c) { + if (c == existingCol) { + alias(col(c), newCol) + } else { + col(c) + } + }) + select(x, cols) + }) + +setClassUnion("characterOrColumn", c("character", "Column")) + +#' SortDF +#' +#' Sort a DataFrame by the specified column(s). +#' +#' @param x A DataFrame to be sorted. +#' @param col Either a Column object or character vector indicating the field to sort on +#' @param ... Additional sorting fields +#' @return A DataFrame where all elements are sorted. +#' @rdname sortDF +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' sortDF(df, df$col1) +#' sortDF(df, "col1") +#' sortDF(df, asc(df$col1), desc(abs(df$col2))) +#' } +setMethod("sortDF", + signature(x = "DataFrame", col = "characterOrColumn"), + function(x, col, ...) { + if (class(col) == "character") { + sdf <- callJMethod(x@sdf, "sort", col, toSeq(...)) + } else if (class(col) == "Column") { + jcols <- lapply(list(col, ...), function(c) { + c@jc + }) + sdf <- callJMethod(x@sdf, "sort", listToSeq(jcols)) + } + dataFrame(sdf) + }) + +#' @rdname sortDF +#' @export +setMethod("orderBy", + signature(x = "DataFrame", col = "characterOrColumn"), + function(x, col) { + sortDF(x, col) + }) + +#' Filter +#' +#' Filter the rows of a DataFrame according to a given condition. +#' +#' @param x A DataFrame to be sorted. +#' @param condition The condition to sort on. This may either be a Column expression +#' or a string containing a SQL statement +#' @return A DataFrame containing only the rows that meet the condition. +#' @rdname filter +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' filter(df, "col1 > 0") +#' filter(df, df$col2 != "abcdefg") +#' } +setMethod("filter", + signature(x = "DataFrame", condition = "characterOrColumn"), + function(x, condition) { + if (class(condition) == "Column") { + condition <- condition@jc + } + sdf <- callJMethod(x@sdf, "filter", condition) + dataFrame(sdf) + }) + +#' @rdname filter +#' @export +setMethod("where", + signature(x = "DataFrame", condition = "characterOrColumn"), + function(x, condition) { + filter(x, condition) + }) + +#' Join +#' +#' Join two DataFrames based on the given join expression. +#' +#' @param x A Spark DataFrame +#' @param y A Spark DataFrame +#' @param joinExpr (Optional) The expression used to perform the join. joinExpr must be a +#' Column expression. If joinExpr is omitted, join() wil perform a Cartesian join +#' @param joinType The type of join to perform. The following join types are available: +#' 'inner', 'outer', 'left_outer', 'right_outer', 'semijoin'. The default joinType is "inner". +#' @return A DataFrame containing the result of the join operation. +#' @rdname join +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' df1 <- jsonFile(sqlCtx, path) +#' df2 <- jsonFile(sqlCtx, path2) +#' join(df1, df2) # Performs a Cartesian +#' join(df1, df2, df1$col1 == df2$col2) # Performs an inner join based on expression +#' join(df1, df2, df1$col1 == df2$col2, "right_outer") +#' } +setMethod("join", + signature(x = "DataFrame", y = "DataFrame"), + function(x, y, joinExpr = NULL, joinType = NULL) { + if (is.null(joinExpr)) { + sdf <- callJMethod(x@sdf, "join", y@sdf) + } else { + if (class(joinExpr) != "Column") stop("joinExpr must be a Column") + if (is.null(joinType)) { + sdf <- callJMethod(x@sdf, "join", y@sdf, joinExpr@jc) + } else { + if (joinType %in% c("inner", "outer", "left_outer", "right_outer", "semijoin")) { + sdf <- callJMethod(x@sdf, "join", y@sdf, joinExpr@jc, joinType) + } else { + stop("joinType must be one of the following types: ", + "'inner', 'outer', 'left_outer', 'right_outer', 'semijoin'") + } + } + } + dataFrame(sdf) + }) + +#' UnionAll +#' +#' Return a new DataFrame containing the union of rows in this DataFrame +#' and another DataFrame. This is equivalent to `UNION ALL` in SQL. +#' +#' @param x A Spark DataFrame +#' @param y A Spark DataFrame +#' @return A DataFrame containing the result of the union. +#' @rdname unionAll +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' df1 <- jsonFile(sqlCtx, path) +#' df2 <- jsonFile(sqlCtx, path2) +#' unioned <- unionAll(df, df2) +#' } +setMethod("unionAll", + signature(x = "DataFrame", y = "DataFrame"), + function(x, y) { + unioned <- callJMethod(x@sdf, "unionAll", y@sdf) + dataFrame(unioned) + }) + +#' Intersect +#' +#' Return a new DataFrame containing rows only in both this DataFrame +#' and another DataFrame. This is equivalent to `INTERSECT` in SQL. +#' +#' @param x A Spark DataFrame +#' @param y A Spark DataFrame +#' @return A DataFrame containing the result of the intersect. +#' @rdname intersect +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' df1 <- jsonFile(sqlCtx, path) +#' df2 <- jsonFile(sqlCtx, path2) +#' intersectDF <- intersect(df, df2) +#' } +setMethod("intersect", + signature(x = "DataFrame", y = "DataFrame"), + function(x, y) { + intersected <- callJMethod(x@sdf, "intersect", y@sdf) + dataFrame(intersected) + }) + +#' Subtract +#' +#' Return a new DataFrame containing rows in this DataFrame +#' but not in another DataFrame. This is equivalent to `EXCEPT` in SQL. +#' +#' @param x A Spark DataFrame +#' @param y A Spark DataFrame +#' @return A DataFrame containing the result of the subtract operation. +#' @rdname subtract +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' df1 <- jsonFile(sqlCtx, path) +#' df2 <- jsonFile(sqlCtx, path2) +#' subtractDF <- subtract(df, df2) +#' } +setMethod("subtract", + signature(x = "DataFrame", y = "DataFrame"), + function(x, y) { + subtracted <- callJMethod(x@sdf, "except", y@sdf) + dataFrame(subtracted) + }) + +#' Save the contents of the DataFrame to a data source +#' +#' The data source is specified by the `source` and a set of options (...). +#' If `source` is not specified, the default data source configured by +#' spark.sql.sources.default will be used. +#' +#' Additionally, mode is used to specify the behavior of the save operation when +#' data already exists in the data source. There are four modes: +#' append: Contents of this DataFrame are expected to be appended to existing data. +#' overwrite: Existing data is expected to be overwritten by the contents of +# this DataFrame. +#' error: An exception is expected to be thrown. +#' ignore: The save operation is expected to not save the contents of the DataFrame +# and to not change the existing data. +#' +#' @param df A SparkSQL DataFrame +#' @param path A name for the table +#' @param source A name for external data source +#' @param mode One of 'append', 'overwrite', 'error', 'ignore' +#' +#' @rdname saveAsTable +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' saveAsTable(df, "myfile") +#' } +setMethod("saveDF", + signature(df = "DataFrame", path = 'character', source = 'character', + mode = 'character'), + function(df, path = NULL, source = NULL, mode = "append", ...){ + if (is.null(source)) { + sqlCtx <- get(".sparkRSQLsc", envir = .sparkREnv) + source <- callJMethod(sqlCtx, "getConf", "spark.sql.sources.default", + "org.apache.spark.sql.parquet") + } + allModes <- c("append", "overwrite", "error", "ignore") + if (!(mode %in% allModes)) { + stop('mode should be one of "append", "overwrite", "error", "ignore"') + } + jmode <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "saveMode", mode) + options <- varargsToEnv(...) + if (!is.null(path)) { + options[['path']] = path + } + callJMethod(df@sdf, "save", source, jmode, options) + }) + + +#' saveAsTable +#' +#' Save the contents of the DataFrame to a data source as a table +#' +#' The data source is specified by the `source` and a set of options (...). +#' If `source` is not specified, the default data source configured by +#' spark.sql.sources.default will be used. +#' +#' Additionally, mode is used to specify the behavior of the save operation when +#' data already exists in the data source. There are four modes: +#' append: Contents of this DataFrame are expected to be appended to existing data. +#' overwrite: Existing data is expected to be overwritten by the contents of +# this DataFrame. +#' error: An exception is expected to be thrown. +#' ignore: The save operation is expected to not save the contents of the DataFrame +# and to not change the existing data. +#' +#' @param df A SparkSQL DataFrame +#' @param tableName A name for the table +#' @param source A name for external data source +#' @param mode One of 'append', 'overwrite', 'error', 'ignore' +#' +#' @rdname saveAsTable +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' saveAsTable(df, "myfile") +#' } +setMethod("saveAsTable", + signature(df = "DataFrame", tableName = 'character', source = 'character', + mode = 'character'), + function(df, tableName, source = NULL, mode="append", ...){ + if (is.null(source)) { + sqlCtx <- get(".sparkRSQLsc", envir = .sparkREnv) + source <- callJMethod(sqlCtx, "getConf", "spark.sql.sources.default", + "org.apache.spark.sql.parquet") + } + allModes <- c("append", "overwrite", "error", "ignore") + if (!(mode %in% allModes)) { + stop('mode should be one of "append", "overwrite", "error", "ignore"') + } + jmode <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "saveMode", mode) + options <- varargsToEnv(...) + callJMethod(df@sdf, "saveAsTable", tableName, source, jmode, options) + }) + diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R new file mode 100644 index 0000000000000..604ad03c407b9 --- /dev/null +++ b/R/pkg/R/RDD.R @@ -0,0 +1,1539 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# RDD in R implemented in S4 OO system. + +setOldClass("jobj") + +#' @title S4 class that represents an RDD +#' @description RDD can be created using functions like +#' \code{parallelize}, \code{textFile} etc. +#' @rdname RDD +#' @seealso parallelize, textFile +#' +#' @slot env An R environment that stores bookkeeping states of the RDD +#' @slot jrdd Java object reference to the backing JavaRDD +#' to an RDD +#' @export +setClass("RDD", + slots = list(env = "environment", + jrdd = "jobj")) + +setClass("PipelinedRDD", + slots = list(prev = "RDD", + func = "function", + prev_jrdd = "jobj"), + contains = "RDD") + +setMethod("initialize", "RDD", function(.Object, jrdd, serializedMode, + isCached, isCheckpointed) { + # Check that RDD constructor is using the correct version of serializedMode + stopifnot(class(serializedMode) == "character") + stopifnot(serializedMode %in% c("byte", "string", "row")) + # RDD has three serialization types: + # byte: The RDD stores data serialized in R. + # string: The RDD stores data as strings. + # row: The RDD stores the serialized rows of a DataFrame. + + # We use an environment to store mutable states inside an RDD object. + # Note that R's call-by-value semantics makes modifying slots inside an + # object (passed as an argument into a function, such as cache()) difficult: + # i.e. one needs to make a copy of the RDD object and sets the new slot value + # there. + + # The slots are inheritable from superclass. Here, both `env' and `jrdd' are + # inherited from RDD, but only the former is used. + .Object@env <- new.env() + .Object@env$isCached <- isCached + .Object@env$isCheckpointed <- isCheckpointed + .Object@env$serializedMode <- serializedMode + + .Object@jrdd <- jrdd + .Object +}) + +setMethod("initialize", "PipelinedRDD", function(.Object, prev, func, jrdd_val) { + .Object@env <- new.env() + .Object@env$isCached <- FALSE + .Object@env$isCheckpointed <- FALSE + .Object@env$jrdd_val <- jrdd_val + if (!is.null(jrdd_val)) { + # This tracks the serialization mode for jrdd_val + .Object@env$serializedMode <- prev@env$serializedMode + } + + .Object@prev <- prev + + isPipelinable <- function(rdd) { + e <- rdd@env + !(e$isCached || e$isCheckpointed) + } + + if (!inherits(prev, "PipelinedRDD") || !isPipelinable(prev)) { + # This transformation is the first in its stage: + .Object@func <- func + .Object@prev_jrdd <- getJRDD(prev) + .Object@env$prev_serializedMode <- prev@env$serializedMode + # NOTE: We use prev_serializedMode to track the serialization mode of prev_JRDD + # prev_serializedMode is used during the delayed computation of JRDD in getJRDD + } else { + pipelinedFunc <- function(split, iterator) { + func(split, prev@func(split, iterator)) + } + .Object@func <- pipelinedFunc + .Object@prev_jrdd <- prev@prev_jrdd # maintain the pipeline + # Get the serialization mode of the parent RDD + .Object@env$prev_serializedMode <- prev@env$prev_serializedMode + } + + .Object +}) + +#' @rdname RDD +#' @export +#' +#' @param jrdd Java object reference to the backing JavaRDD +#' @param serializedMode Use "byte" if the RDD stores data serialized in R, "string" if the RDD +#' stores strings, and "row" if the RDD stores the rows of a DataFrame +#' @param isCached TRUE if the RDD is cached +#' @param isCheckpointed TRUE if the RDD has been checkpointed +RDD <- function(jrdd, serializedMode = "byte", isCached = FALSE, + isCheckpointed = FALSE) { + new("RDD", jrdd, serializedMode, isCached, isCheckpointed) +} + +PipelinedRDD <- function(prev, func) { + new("PipelinedRDD", prev, func, NULL) +} + +# Return the serialization mode for an RDD. +setGeneric("getSerializedMode", function(rdd, ...) { standardGeneric("getSerializedMode") }) +# For normal RDDs we can directly read the serializedMode +setMethod("getSerializedMode", signature(rdd = "RDD"), function(rdd) rdd@env$serializedMode ) +# For pipelined RDDs if jrdd_val is set then serializedMode should exist +# if not we return the defaultSerialization mode of "byte" as we don't know the serialization +# mode at this point in time. +setMethod("getSerializedMode", signature(rdd = "PipelinedRDD"), + function(rdd) { + if (!is.null(rdd@env$jrdd_val)) { + return(rdd@env$serializedMode) + } else { + return("byte") + } + }) + +# The jrdd accessor function. +setMethod("getJRDD", signature(rdd = "RDD"), function(rdd) rdd@jrdd ) +setMethod("getJRDD", signature(rdd = "PipelinedRDD"), + function(rdd, serializedMode = "byte") { + if (!is.null(rdd@env$jrdd_val)) { + return(rdd@env$jrdd_val) + } + + computeFunc <- function(split, part) { + rdd@func(split, part) + } + + packageNamesArr <- serialize(.sparkREnv[[".packages"]], + connection = NULL) + + broadcastArr <- lapply(ls(.broadcastNames), + function(name) { get(name, .broadcastNames) }) + + serializedFuncArr <- serialize(computeFunc, connection = NULL) + + prev_jrdd <- rdd@prev_jrdd + + if (serializedMode == "string") { + rddRef <- newJObject("org.apache.spark.api.r.StringRRDD", + callJMethod(prev_jrdd, "rdd"), + serializedFuncArr, + rdd@env$prev_serializedMode, + packageNamesArr, + as.character(.sparkREnv[["libname"]]), + broadcastArr, + callJMethod(prev_jrdd, "classTag")) + } else { + rddRef <- newJObject("org.apache.spark.api.r.RRDD", + callJMethod(prev_jrdd, "rdd"), + serializedFuncArr, + rdd@env$prev_serializedMode, + serializedMode, + packageNamesArr, + as.character(.sparkREnv[["libname"]]), + broadcastArr, + callJMethod(prev_jrdd, "classTag")) + } + # Save the serialization flag after we create a RRDD + rdd@env$serializedMode <- serializedMode + rdd@env$jrdd_val <- callJMethod(rddRef, "asJavaRDD") # rddRef$asJavaRDD() + rdd@env$jrdd_val + }) + +setValidity("RDD", + function(object) { + jrdd <- getJRDD(object) + cls <- callJMethod(jrdd, "getClass") + className <- callJMethod(cls, "getName") + if (grep("spark.api.java.*RDD*", className) == 1) { + TRUE + } else { + paste("Invalid RDD class ", className) + } + }) + + +############ Actions and Transformations ############ + +#' Persist an RDD +#' +#' Persist this RDD with the default storage level (MEMORY_ONLY). +#' +#' @param x The RDD to cache +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 2L) +#' cache(rdd) +#'} +#' @rdname cache-methods +#' @aliases cache,RDD-method +setMethod("cache", + signature(x = "RDD"), + function(x) { + callJMethod(getJRDD(x), "cache") + x@env$isCached <- TRUE + x + }) + +#' Persist an RDD +#' +#' Persist this RDD with the specified storage level. For details of the +#' supported storage levels, refer to +#' http://spark.apache.org/docs/latest/programming-guide.html#rdd-persistence. +#' +#' @param x The RDD to persist +#' @param newLevel The new storage level to be assigned +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 2L) +#' persist(rdd, "MEMORY_AND_DISK") +#'} +#' @rdname persist +#' @aliases persist,RDD-method +setMethod("persist", + signature(x = "RDD", newLevel = "character"), + function(x, newLevel) { + callJMethod(getJRDD(x), "persist", getStorageLevel(newLevel)) + x@env$isCached <- TRUE + x + }) + +#' Unpersist an RDD +#' +#' Mark the RDD as non-persistent, and remove all blocks for it from memory and +#' disk. +#' +#' @param x The RDD to unpersist +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 2L) +#' cache(rdd) # rdd@@env$isCached == TRUE +#' unpersist(rdd) # rdd@@env$isCached == FALSE +#'} +#' @rdname unpersist-methods +#' @aliases unpersist,RDD-method +setMethod("unpersist", + signature(x = "RDD"), + function(x) { + callJMethod(getJRDD(x), "unpersist") + x@env$isCached <- FALSE + x + }) + +#' Checkpoint an RDD +#' +#' Mark this RDD for checkpointing. It will be saved to a file inside the +#' checkpoint directory set with setCheckpointDir() and all references to its +#' parent RDDs will be removed. This function must be called before any job has +#' been executed on this RDD. It is strongly recommended that this RDD is +#' persisted in memory, otherwise saving it on a file will require recomputation. +#' +#' @param x The RDD to checkpoint +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' setCheckpointDir(sc, "checkpoints") +#' rdd <- parallelize(sc, 1:10, 2L) +#' checkpoint(rdd) +#'} +#' @rdname checkpoint-methods +#' @aliases checkpoint,RDD-method +setMethod("checkpoint", + signature(x = "RDD"), + function(x) { + jrdd <- getJRDD(x) + callJMethod(jrdd, "checkpoint") + x@env$isCheckpointed <- TRUE + x + }) + +#' Gets the number of partitions of an RDD +#' +#' @param x A RDD. +#' @return the number of partitions of rdd as an integer. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 2L) +#' numPartitions(rdd) # 2L +#'} +#' @rdname numPartitions +#' @aliases numPartitions,RDD-method +setMethod("numPartitions", + signature(x = "RDD"), + function(x) { + jrdd <- getJRDD(x) + partitions <- callJMethod(jrdd, "splits") + callJMethod(partitions, "size") + }) + +#' Collect elements of an RDD +#' +#' @description +#' \code{collect} returns a list that contains all of the elements in this RDD. +#' +#' @param x The RDD to collect +#' @param ... Other optional arguments to collect +#' @param flatten FALSE if the list should not flattened +#' @return a list containing elements in the RDD +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 2L) +#' collect(rdd) # list from 1 to 10 +#' collectPartition(rdd, 0L) # list from 1 to 5 +#'} +#' @rdname collect-methods +#' @aliases collect,RDD-method +setMethod("collect", + signature(x = "RDD"), + function(x, flatten = TRUE) { + # Assumes a pairwise RDD is backed by a JavaPairRDD. + collected <- callJMethod(getJRDD(x), "collect") + convertJListToRList(collected, flatten, + serializedMode = getSerializedMode(x)) + }) + + +#' @description +#' \code{collectPartition} returns a list that contains all of the elements +#' in the specified partition of the RDD. +#' @param partitionId the partition to collect (starts from 0) +#' @rdname collect-methods +#' @aliases collectPartition,integer,RDD-method +setMethod("collectPartition", + signature(x = "RDD", partitionId = "integer"), + function(x, partitionId) { + jPartitionsList <- callJMethod(getJRDD(x), + "collectPartitions", + as.list(as.integer(partitionId))) + + jList <- jPartitionsList[[1]] + convertJListToRList(jList, flatten = TRUE, + serializedMode = getSerializedMode(x)) + }) + +#' @description +#' \code{collectAsMap} returns a named list as a map that contains all of the elements +#' in a key-value pair RDD. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(1, 2), list(3, 4)), 2L) +#' collectAsMap(rdd) # list(`1` = 2, `3` = 4) +#'} +#' @rdname collect-methods +#' @aliases collectAsMap,RDD-method +setMethod("collectAsMap", + signature(x = "RDD"), + function(x) { + pairList <- collect(x) + map <- new.env() + lapply(pairList, function(i) { assign(as.character(i[[1]]), i[[2]], envir = map) }) + as.list(map) + }) + +#' Return the number of elements in the RDD. +#' +#' @param x The RDD to count +#' @return number of elements in the RDD. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' count(rdd) # 10 +#' length(rdd) # Same as count +#'} +#' @rdname count +#' @aliases count,RDD-method +setMethod("count", + signature(x = "RDD"), + function(x) { + countPartition <- function(part) { + as.integer(length(part)) + } + valsRDD <- lapplyPartition(x, countPartition) + vals <- collect(valsRDD) + sum(as.integer(vals)) + }) + +#' Return the number of elements in the RDD +#' @export +#' @rdname count +setMethod("length", + signature(x = "RDD"), + function(x) { + count(x) + }) + +#' Return the count of each unique value in this RDD as a list of +#' (value, count) pairs. +#' +#' Same as countByValue in Spark. +#' +#' @param x The RDD to count +#' @return list of (value, count) pairs, where count is number of each unique +#' value in rdd. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, c(1,2,3,2,1)) +#' countByValue(rdd) # (1,2L), (2,2L), (3,1L) +#'} +#' @rdname countByValue +#' @aliases countByValue,RDD-method +setMethod("countByValue", + signature(x = "RDD"), + function(x) { + ones <- lapply(x, function(item) { list(item, 1L) }) + collect(reduceByKey(ones, `+`, numPartitions(x))) + }) + +#' Apply a function to all elements +#' +#' This function creates a new RDD by applying the given transformation to all +#' elements of the given RDD +#' +#' @param X The RDD to apply the transformation. +#' @param FUN the transformation to apply on each element +#' @return a new RDD created by the transformation. +#' @rdname lapply +#' @aliases lapply +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' multiplyByTwo <- lapply(rdd, function(x) { x * 2 }) +#' collect(multiplyByTwo) # 2,4,6... +#'} +setMethod("lapply", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + func <- function(split, iterator) { + lapply(iterator, FUN) + } + lapplyPartitionsWithIndex(X, func) + }) + +#' @rdname lapply +#' @aliases map,RDD,function-method +setMethod("map", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + lapply(X, FUN) + }) + +#' Flatten results after apply a function to all elements +#' +#' This function return a new RDD by first applying a function to all +#' elements of this RDD, and then flattening the results. +#' +#' @param X The RDD to apply the transformation. +#' @param FUN the transformation to apply on each element +#' @return a new RDD created by the transformation. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' multiplyByTwo <- flatMap(rdd, function(x) { list(x*2, x*10) }) +#' collect(multiplyByTwo) # 2,20,4,40,6,60... +#'} +#' @rdname flatMap +#' @aliases flatMap,RDD,function-method +setMethod("flatMap", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + partitionFunc <- function(part) { + unlist( + lapply(part, FUN), + recursive = F + ) + } + lapplyPartition(X, partitionFunc) + }) + +#' Apply a function to each partition of an RDD +#' +#' Return a new RDD by applying a function to each partition of this RDD. +#' +#' @param X The RDD to apply the transformation. +#' @param FUN the transformation to apply on each partition. +#' @return a new RDD created by the transformation. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' partitionSum <- lapplyPartition(rdd, function(part) { Reduce("+", part) }) +#' collect(partitionSum) # 15, 40 +#'} +#' @rdname lapplyPartition +#' @aliases lapplyPartition,RDD,function-method +setMethod("lapplyPartition", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + lapplyPartitionsWithIndex(X, function(s, part) { FUN(part) }) + }) + +#' mapPartitions is the same as lapplyPartition. +#' +#' @rdname lapplyPartition +#' @aliases mapPartitions,RDD,function-method +setMethod("mapPartitions", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + lapplyPartition(X, FUN) + }) + +#' Return a new RDD by applying a function to each partition of this RDD, while +#' tracking the index of the original partition. +#' +#' @param X The RDD to apply the transformation. +#' @param FUN the transformation to apply on each partition; takes the partition +#' index and a list of elements in the particular partition. +#' @return a new RDD created by the transformation. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 5L) +#' prod <- lapplyPartitionsWithIndex(rdd, function(split, part) { +#' split * Reduce("+", part) }) +#' collect(prod, flatten = FALSE) # 0, 7, 22, 45, 76 +#'} +#' @rdname lapplyPartitionsWithIndex +#' @aliases lapplyPartitionsWithIndex,RDD,function-method +setMethod("lapplyPartitionsWithIndex", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + FUN <- cleanClosure(FUN) + closureCapturingFunc <- function(split, part) { + FUN(split, part) + } + PipelinedRDD(X, closureCapturingFunc) + }) + +#' @rdname lapplyPartitionsWithIndex +#' @aliases mapPartitionsWithIndex,RDD,function-method +setMethod("mapPartitionsWithIndex", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + lapplyPartitionsWithIndex(X, FUN) + }) + +#' This function returns a new RDD containing only the elements that satisfy +#' a predicate (i.e. returning TRUE in a given logical function). +#' The same as `filter()' in Spark. +#' +#' @param x The RDD to be filtered. +#' @param f A unary predicate function. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' unlist(collect(filterRDD(rdd, function (x) { x < 3 }))) # c(1, 2) +#'} +#' @rdname filterRDD +#' @aliases filterRDD,RDD,function-method +setMethod("filterRDD", + signature(x = "RDD", f = "function"), + function(x, f) { + filter.func <- function(part) { + Filter(f, part) + } + lapplyPartition(x, filter.func) + }) + +#' @rdname filterRDD +#' @aliases Filter +setMethod("Filter", + signature(f = "function", x = "RDD"), + function(f, x) { + filterRDD(x, f) + }) + +#' Reduce across elements of an RDD. +#' +#' This function reduces the elements of this RDD using the +#' specified commutative and associative binary operator. +#' +#' @param x The RDD to reduce +#' @param func Commutative and associative function to apply on elements +#' of the RDD. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' reduce(rdd, "+") # 55 +#'} +#' @rdname reduce +#' @aliases reduce,RDD,ANY-method +setMethod("reduce", + signature(x = "RDD", func = "ANY"), + function(x, func) { + + reducePartition <- function(part) { + Reduce(func, part) + } + + partitionList <- collect(lapplyPartition(x, reducePartition), + flatten = FALSE) + Reduce(func, partitionList) + }) + +#' Get the maximum element of an RDD. +#' +#' @param x The RDD to get the maximum element from +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' maximum(rdd) # 10 +#'} +#' @rdname maximum +#' @aliases maximum,RDD +setMethod("maximum", + signature(x = "RDD"), + function(x) { + reduce(x, max) + }) + +#' Get the minimum element of an RDD. +#' +#' @param x The RDD to get the minimum element from +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' minimum(rdd) # 1 +#'} +#' @rdname minimum +#' @aliases minimum,RDD +setMethod("minimum", + signature(x = "RDD"), + function(x) { + reduce(x, min) + }) + +#' Add up the elements in an RDD. +#' +#' @param x The RDD to add up the elements in +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' sumRDD(rdd) # 55 +#'} +#' @rdname sumRDD +#' @aliases sumRDD,RDD +setMethod("sumRDD", + signature(x = "RDD"), + function(x) { + reduce(x, "+") + }) + +#' Applies a function to all elements in an RDD, and force evaluation. +#' +#' @param x The RDD to apply the function +#' @param func The function to be applied. +#' @return invisible NULL. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' foreach(rdd, function(x) { save(x, file=...) }) +#'} +#' @rdname foreach +#' @aliases foreach,RDD,function-method +setMethod("foreach", + signature(x = "RDD", func = "function"), + function(x, func) { + partition.func <- function(x) { + lapply(x, func) + NULL + } + invisible(collect(mapPartitions(x, partition.func))) + }) + +#' Applies a function to each partition in an RDD, and force evaluation. +#' +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' foreachPartition(rdd, function(part) { save(part, file=...); NULL }) +#'} +#' @rdname foreach +#' @aliases foreachPartition,RDD,function-method +setMethod("foreachPartition", + signature(x = "RDD", func = "function"), + function(x, func) { + invisible(collect(mapPartitions(x, func))) + }) + +#' Take elements from an RDD. +#' +#' This function takes the first NUM elements in the RDD and +#' returns them in a list. +#' +#' @param x The RDD to take elements from +#' @param num Number of elements to take +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' take(rdd, 2L) # list(1, 2) +#'} +#' @rdname take +#' @aliases take,RDD,numeric-method +setMethod("take", + signature(x = "RDD", num = "numeric"), + function(x, num) { + resList <- list() + index <- -1 + jrdd <- getJRDD(x) + numPartitions <- numPartitions(x) + + # TODO(shivaram): Collect more than one partition based on size + # estimates similar to the scala version of `take`. + while (TRUE) { + index <- index + 1 + + if (length(resList) >= num || index >= numPartitions) + break + + # a JList of byte arrays + partitionArr <- callJMethod(jrdd, "collectPartitions", as.list(as.integer(index))) + partition <- partitionArr[[1]] + + size <- num - length(resList) + # elems is capped to have at most `size` elements + elems <- convertJListToRList(partition, + flatten = TRUE, + logicalUpperBound = size, + serializedMode = getSerializedMode(x)) + # TODO: Check if this append is O(n^2)? + resList <- append(resList, elems) + } + resList + }) + +#' First +#' +#' Return the first element of an RDD +#' +#' @rdname first +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' first(rdd) +#' } +setMethod("first", + signature(x = "RDD"), + function(x) { + take(x, 1)[[1]] + }) + +#' Removes the duplicates from RDD. +#' +#' This function returns a new RDD containing the distinct elements in the +#' given RDD. The same as `distinct()' in Spark. +#' +#' @param x The RDD to remove duplicates from. +#' @param numPartitions Number of partitions to create. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, c(1,2,2,3,3,3)) +#' sort(unlist(collect(distinct(rdd)))) # c(1, 2, 3) +#'} +#' @rdname distinct +#' @aliases distinct,RDD-method +setMethod("distinct", + signature(x = "RDD"), + function(x, numPartitions = SparkR::numPartitions(x)) { + identical.mapped <- lapply(x, function(x) { list(x, NULL) }) + reduced <- reduceByKey(identical.mapped, + function(x, y) { x }, + numPartitions) + resRDD <- lapply(reduced, function(x) { x[[1]] }) + resRDD + }) + +#' Return an RDD that is a sampled subset of the given RDD. +#' +#' The same as `sample()' in Spark. (We rename it due to signature +#' inconsistencies with the `sample()' function in R's base package.) +#' +#' @param x The RDD to sample elements from +#' @param withReplacement Sampling with replacement or not +#' @param fraction The (rough) sample target fraction +#' @param seed Randomness seed value +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) # ensure each num is in its own split +#' collect(sampleRDD(rdd, FALSE, 0.5, 1618L)) # ~5 distinct elements +#' collect(sampleRDD(rdd, TRUE, 0.5, 9L)) # ~5 elements possibly with duplicates +#'} +#' @rdname sampleRDD +#' @aliases sampleRDD,RDD +setMethod("sampleRDD", + signature(x = "RDD", withReplacement = "logical", + fraction = "numeric", seed = "integer"), + function(x, withReplacement, fraction, seed) { + + # The sampler: takes a partition and returns its sampled version. + samplingFunc <- function(split, part) { + set.seed(seed) + res <- vector("list", length(part)) + len <- 0 + + # Discards some random values to ensure each partition has a + # different random seed. + runif(split) + + for (elem in part) { + if (withReplacement) { + count <- rpois(1, fraction) + if (count > 0) { + res[(len + 1):(len + count)] <- rep(list(elem), count) + len <- len + count + } + } else { + if (runif(1) < fraction) { + len <- len + 1 + res[[len]] <- elem + } + } + } + + # TODO(zongheng): look into the performance of the current + # implementation. Look into some iterator package? Note that + # Scala avoids many calls to creating an empty list and PySpark + # similarly achieves this using `yield'. + if (len > 0) + res[1:len] + else + list() + } + + lapplyPartitionsWithIndex(x, samplingFunc) + }) + +#' Return a list of the elements that are a sampled subset of the given RDD. +#' +#' @param x The RDD to sample elements from +#' @param withReplacement Sampling with replacement or not +#' @param num Number of elements to return +#' @param seed Randomness seed value +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:100) +#' # exactly 5 elements sampled, which may not be distinct +#' takeSample(rdd, TRUE, 5L, 1618L) +#' # exactly 5 distinct elements sampled +#' takeSample(rdd, FALSE, 5L, 16181618L) +#'} +#' @rdname takeSample +#' @aliases takeSample,RDD +setMethod("takeSample", signature(x = "RDD", withReplacement = "logical", + num = "integer", seed = "integer"), + function(x, withReplacement, num, seed) { + # This function is ported from RDD.scala. + fraction <- 0.0 + total <- 0 + multiplier <- 3.0 + initialCount <- count(x) + maxSelected <- 0 + MAXINT <- .Machine$integer.max + + if (num < 0) + stop(paste("Negative number of elements requested")) + + if (initialCount > MAXINT - 1) { + maxSelected <- MAXINT - 1 + } else { + maxSelected <- initialCount + } + + if (num > initialCount && !withReplacement) { + total <- maxSelected + fraction <- multiplier * (maxSelected + 1) / initialCount + } else { + total <- num + fraction <- multiplier * (num + 1) / initialCount + } + + set.seed(seed) + samples <- collect(sampleRDD(x, withReplacement, fraction, + as.integer(ceiling(runif(1, + -MAXINT, + MAXINT))))) + # If the first sample didn't turn out large enough, keep trying to + # take samples; this shouldn't happen often because we use a big + # multiplier for thei initial size + while (length(samples) < total) + samples <- collect(sampleRDD(x, withReplacement, fraction, + as.integer(ceiling(runif(1, + -MAXINT, + MAXINT))))) + + # TODO(zongheng): investigate if this call is an in-place shuffle? + sample(samples)[1:total] + }) + +#' Creates tuples of the elements in this RDD by applying a function. +#' +#' @param x The RDD. +#' @param func The function to be applied. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1, 2, 3)) +#' collect(keyBy(rdd, function(x) { x*x })) # list(list(1, 1), list(4, 2), list(9, 3)) +#'} +#' @rdname keyBy +#' @aliases keyBy,RDD +setMethod("keyBy", + signature(x = "RDD", func = "function"), + function(x, func) { + apply.func <- function(x) { + list(func(x), x) + } + lapply(x, apply.func) + }) + +#' Return a new RDD that has exactly numPartitions partitions. +#' Can increase or decrease the level of parallelism in this RDD. Internally, +#' this uses a shuffle to redistribute data. +#' If you are decreasing the number of partitions in this RDD, consider using +#' coalesce, which can avoid performing a shuffle. +#' +#' @param x The RDD. +#' @param numPartitions Number of partitions to create. +#' @seealso coalesce +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1, 2, 3, 4, 5, 6, 7), 4L) +#' numPartitions(rdd) # 4 +#' numPartitions(repartition(rdd, 2L)) # 2 +#'} +#' @rdname repartition +#' @aliases repartition,RDD +setMethod("repartition", + signature(x = "RDD", numPartitions = "numeric"), + function(x, numPartitions) { + coalesce(x, numToInt(numPartitions), TRUE) + }) + +#' Return a new RDD that is reduced into numPartitions partitions. +#' +#' @param x The RDD. +#' @param numPartitions Number of partitions to create. +#' @seealso repartition +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1, 2, 3, 4, 5), 3L) +#' numPartitions(rdd) # 3 +#' numPartitions(coalesce(rdd, 1L)) # 1 +#'} +#' @rdname coalesce +#' @aliases coalesce,RDD +setMethod("coalesce", + signature(x = "RDD", numPartitions = "numeric"), + function(x, numPartitions, shuffle = FALSE) { + numPartitions <- numToInt(numPartitions) + if (shuffle || numPartitions > SparkR::numPartitions(x)) { + func <- function(s, part) { + set.seed(s) # split as seed + start <- as.integer(sample(numPartitions, 1) - 1) + lapply(seq_along(part), + function(i) { + pos <- (start + i) %% numPartitions + list(pos, part[[i]]) + }) + } + shuffled <- lapplyPartitionsWithIndex(x, func) + repartitioned <- partitionBy(shuffled, numPartitions) + values(repartitioned) + } else { + jrdd <- callJMethod(getJRDD(x), "coalesce", numPartitions, shuffle) + RDD(jrdd) + } + }) + +#' Save this RDD as a SequenceFile of serialized objects. +#' +#' @param x The RDD to save +#' @param path The directory where the file is saved +#' @seealso objectFile +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:3) +#' saveAsObjectFile(rdd, "/tmp/sparkR-tmp") +#'} +#' @rdname saveAsObjectFile +#' @aliases saveAsObjectFile,RDD +setMethod("saveAsObjectFile", + signature(x = "RDD", path = "character"), + function(x, path) { + # If serializedMode == "string" we need to serialize the data before saving it since + # objectFile() assumes serializedMode == "byte". + if (getSerializedMode(x) != "byte") { + x <- serializeToBytes(x) + } + # Return nothing + invisible(callJMethod(getJRDD(x), "saveAsObjectFile", path)) + }) + +#' Save this RDD as a text file, using string representations of elements. +#' +#' @param x The RDD to save +#' @param path The directory where the splits of the text file are saved +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:3) +#' saveAsTextFile(rdd, "/tmp/sparkR-tmp") +#'} +#' @rdname saveAsTextFile +#' @aliases saveAsTextFile,RDD +setMethod("saveAsTextFile", + signature(x = "RDD", path = "character"), + function(x, path) { + func <- function(str) { + toString(str) + } + stringRdd <- lapply(x, func) + # Return nothing + invisible( + callJMethod(getJRDD(stringRdd, serializedMode = "string"), "saveAsTextFile", path)) + }) + +#' Sort an RDD by the given key function. +#' +#' @param x An RDD to be sorted. +#' @param func A function used to compute the sort key for each element. +#' @param ascending A flag to indicate whether the sorting is ascending or descending. +#' @param numPartitions Number of partitions to create. +#' @return An RDD where all elements are sorted. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(3, 2, 1)) +#' collect(sortBy(rdd, function(x) { x })) # list (1, 2, 3) +#'} +#' @rdname sortBy +#' @aliases sortBy,RDD,RDD-method +setMethod("sortBy", + signature(x = "RDD", func = "function"), + function(x, func, ascending = TRUE, numPartitions = SparkR::numPartitions(x)) { + values(sortByKey(keyBy(x, func), ascending, numPartitions)) + }) + +# Helper function to get first N elements from an RDD in the specified order. +# Param: +# x An RDD. +# num Number of elements to return. +# ascending A flag to indicate whether the sorting is ascending or descending. +# Return: +# A list of the first N elements from the RDD in the specified order. +# +takeOrderedElem <- function(x, num, ascending = TRUE) { + if (num <= 0L) { + return(list()) + } + + partitionFunc <- function(part) { + if (num < length(part)) { + # R limitation: order works only on primitive types! + ord <- order(unlist(part, recursive = FALSE), decreasing = !ascending) + list(part[ord[1:num]]) + } else { + list(part) + } + } + + reduceFunc <- function(elems, part) { + newElems <- append(elems, part) + # R limitation: order works only on primitive types! + ord <- order(unlist(newElems, recursive = FALSE), decreasing = !ascending) + newElems[ord[1:num]] + } + + newRdd <- mapPartitions(x, partitionFunc) + reduce(newRdd, reduceFunc) +} + +#' Returns the first N elements from an RDD in ascending order. +#' +#' @param x An RDD. +#' @param num Number of elements to return. +#' @return The first N elements from the RDD in ascending order. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(10, 1, 2, 9, 3, 4, 5, 6, 7)) +#' takeOrdered(rdd, 6L) # list(1, 2, 3, 4, 5, 6) +#'} +#' @rdname takeOrdered +#' @aliases takeOrdered,RDD,RDD-method +setMethod("takeOrdered", + signature(x = "RDD", num = "integer"), + function(x, num) { + takeOrderedElem(x, num) + }) + +#' Returns the top N elements from an RDD. +#' +#' @param x An RDD. +#' @param num Number of elements to return. +#' @return The top N elements from the RDD. +#' @rdname top +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(10, 1, 2, 9, 3, 4, 5, 6, 7)) +#' top(rdd, 6L) # list(10, 9, 7, 6, 5, 4) +#'} +#' @rdname top +#' @aliases top,RDD,RDD-method +setMethod("top", + signature(x = "RDD", num = "integer"), + function(x, num) { + takeOrderedElem(x, num, FALSE) + }) + +#' Fold an RDD using a given associative function and a neutral "zero value". +#' +#' Aggregate the elements of each partition, and then the results for all the +#' partitions, using a given associative function and a neutral "zero value". +#' +#' @param x An RDD. +#' @param zeroValue A neutral "zero value". +#' @param op An associative function for the folding operation. +#' @return The folding result. +#' @rdname fold +#' @seealso reduce +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1, 2, 3, 4, 5)) +#' fold(rdd, 0, "+") # 15 +#'} +#' @rdname fold +#' @aliases fold,RDD,RDD-method +setMethod("fold", + signature(x = "RDD", zeroValue = "ANY", op = "ANY"), + function(x, zeroValue, op) { + aggregateRDD(x, zeroValue, op, op) + }) + +#' Aggregate an RDD using the given combine functions and a neutral "zero value". +#' +#' Aggregate the elements of each partition, and then the results for all the +#' partitions, using given combine functions and a neutral "zero value". +#' +#' @param x An RDD. +#' @param zeroValue A neutral "zero value". +#' @param seqOp A function to aggregate the RDD elements. It may return a different +#' result type from the type of the RDD elements. +#' @param combOp A function to aggregate results of seqOp. +#' @return The aggregation result. +#' @rdname aggregateRDD +#' @seealso reduce +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1, 2, 3, 4)) +#' zeroValue <- list(0, 0) +#' seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } +#' combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } +#' aggregateRDD(rdd, zeroValue, seqOp, combOp) # list(10, 4) +#'} +#' @rdname aggregateRDD +#' @aliases aggregateRDD,RDD,RDD-method +setMethod("aggregateRDD", + signature(x = "RDD", zeroValue = "ANY", seqOp = "ANY", combOp = "ANY"), + function(x, zeroValue, seqOp, combOp) { + partitionFunc <- function(part) { + Reduce(seqOp, part, zeroValue) + } + + partitionList <- collect(lapplyPartition(x, partitionFunc), + flatten = FALSE) + Reduce(combOp, partitionList, zeroValue) + }) + +#' Pipes elements to a forked external process. +#' +#' The same as 'pipe()' in Spark. +#' +#' @param x The RDD whose elements are piped to the forked external process. +#' @param command The command to fork an external process. +#' @param env A named list to set environment variables of the external process. +#' @return A new RDD created by piping all elements to a forked external process. +#' @rdname pipeRDD +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' collect(pipeRDD(rdd, "more") +#' Output: c("1", "2", ..., "10") +#'} +#' @rdname pipeRDD +#' @aliases pipeRDD,RDD,character-method +setMethod("pipeRDD", + signature(x = "RDD", command = "character"), + function(x, command, env = list()) { + func <- function(part) { + trim.trailing.func <- function(x) { + sub("[\r\n]*$", "", toString(x)) + } + input <- unlist(lapply(part, trim.trailing.func)) + res <- system2(command, stdout = TRUE, input = input, env = env) + lapply(res, trim.trailing.func) + } + lapplyPartition(x, func) + }) + +# TODO: Consider caching the name in the RDD's environment +#' Return an RDD's name. +#' +#' @param x The RDD whose name is returned. +#' @rdname name +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1,2,3)) +#' name(rdd) # NULL (if not set before) +#'} +#' @rdname name +#' @aliases name,RDD +setMethod("name", + signature(x = "RDD"), + function(x) { + callJMethod(getJRDD(x), "name") + }) + +#' Set an RDD's name. +#' +#' @param x The RDD whose name is to be set. +#' @param name The RDD name to be set. +#' @return a new RDD renamed. +#' @rdname setName +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1,2,3)) +#' setName(rdd, "myRDD") +#' name(rdd) # "myRDD" +#'} +#' @rdname setName +#' @aliases setName,RDD +setMethod("setName", + signature(x = "RDD", name = "character"), + function(x, name) { + callJMethod(getJRDD(x), "setName", name) + x + }) + +#' Zip an RDD with generated unique Long IDs. +#' +#' Items in the kth partition will get ids k, n+k, 2*n+k, ..., where +#' n is the number of partitions. So there may exist gaps, but this +#' method won't trigger a spark job, which is different from +#' zipWithIndex. +#' +#' @param x An RDD to be zipped. +#' @return An RDD with zipped items. +#' @seealso zipWithIndex +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) +#' collect(zipWithUniqueId(rdd)) +#' # list(list("a", 0), list("b", 3), list("c", 1), list("d", 4), list("e", 2)) +#'} +#' @rdname zipWithUniqueId +#' @aliases zipWithUniqueId,RDD +setMethod("zipWithUniqueId", + signature(x = "RDD"), + function(x) { + n <- numPartitions(x) + + partitionFunc <- function(split, part) { + mapply( + function(item, index) { + list(item, (index - 1) * n + split) + }, + part, + seq_along(part), + SIMPLIFY = FALSE) + } + + lapplyPartitionsWithIndex(x, partitionFunc) + }) + +#' Zip an RDD with its element indices. +#' +#' The ordering is first based on the partition index and then the +#' ordering of items within each partition. So the first item in +#' the first partition gets index 0, and the last item in the last +#' partition receives the largest index. +#' +#' This method needs to trigger a Spark job when this RDD contains +#' more than one partition. +#' +#' @param x An RDD to be zipped. +#' @return An RDD with zipped items. +#' @seealso zipWithUniqueId +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) +#' collect(zipWithIndex(rdd)) +#' # list(list("a", 0), list("b", 1), list("c", 2), list("d", 3), list("e", 4)) +#'} +#' @rdname zipWithIndex +#' @aliases zipWithIndex,RDD +setMethod("zipWithIndex", + signature(x = "RDD"), + function(x) { + n <- numPartitions(x) + if (n > 1) { + nums <- collect(lapplyPartition(x, + function(part) { + list(length(part)) + })) + startIndices <- Reduce("+", nums, accumulate = TRUE) + } + + partitionFunc <- function(split, part) { + if (split == 0) { + startIndex <- 0 + } else { + startIndex <- startIndices[[split]] + } + + mapply( + function(item, index) { + list(item, index - 1 + startIndex) + }, + part, + seq_along(part), + SIMPLIFY = FALSE) + } + + lapplyPartitionsWithIndex(x, partitionFunc) + }) + +#' Coalesce all elements within each partition of an RDD into a list. +#' +#' @param x An RDD. +#' @return An RDD created by coalescing all elements within +#' each partition into a list. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, as.list(1:4), 2L) +#' collect(glom(rdd)) +#' # list(list(1, 2), list(3, 4)) +#'} +#' @rdname glom +#' @aliases glom,RDD +setMethod("glom", + signature(x = "RDD"), + function(x) { + partitionFunc <- function(part) { + list(part) + } + + lapplyPartition(x, partitionFunc) + }) + +############ Binary Functions ############# + +#' Return the union RDD of two RDDs. +#' The same as union() in Spark. +#' +#' @param x An RDD. +#' @param y An RDD. +#' @return a new RDD created by performing the simple union (witout removing +#' duplicates) of two input RDDs. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:3) +#' unionRDD(rdd, rdd) # 1, 2, 3, 1, 2, 3 +#'} +#' @rdname unionRDD +#' @aliases unionRDD,RDD,RDD-method +setMethod("unionRDD", + signature(x = "RDD", y = "RDD"), + function(x, y) { + if (getSerializedMode(x) == getSerializedMode(y)) { + jrdd <- callJMethod(getJRDD(x), "union", getJRDD(y)) + union.rdd <- RDD(jrdd, getSerializedMode(x)) + } else { + # One of the RDDs is not serialized, we need to serialize it first. + if (getSerializedMode(x) != "byte") x <- serializeToBytes(x) + if (getSerializedMode(y) != "byte") y <- serializeToBytes(y) + jrdd <- callJMethod(getJRDD(x), "union", getJRDD(y)) + union.rdd <- RDD(jrdd, "byte") + } + union.rdd + }) + +#' Zip an RDD with another RDD. +#' +#' Zips this RDD with another one, returning key-value pairs with the +#' first element in each RDD second element in each RDD, etc. Assumes +#' that the two RDDs have the same number of partitions and the same +#' number of elements in each partition (e.g. one was made through +#' a map on the other). +#' +#' @param x An RDD to be zipped. +#' @param other Another RDD to be zipped. +#' @return An RDD zipped from the two RDDs. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, 0:4) +#' rdd2 <- parallelize(sc, 1000:1004) +#' collect(zipRDD(rdd1, rdd2)) +#' # list(list(0, 1000), list(1, 1001), list(2, 1002), list(3, 1003), list(4, 1004)) +#'} +#' @rdname zipRDD +#' @aliases zipRDD,RDD +setMethod("zipRDD", + signature(x = "RDD", other = "RDD"), + function(x, other) { + n1 <- numPartitions(x) + n2 <- numPartitions(other) + if (n1 != n2) { + stop("Can only zip RDDs which have the same number of partitions.") + } + + if (getSerializedMode(x) != getSerializedMode(other) || + getSerializedMode(x) == "byte") { + # Append the number of elements in each partition to that partition so that we can later + # check if corresponding partitions of both RDDs have the same number of elements. + # + # Note that this appending also serves the purpose of reserialization, because even if + # any RDD is serialized, we need to reserialize it to make sure its partitions are encoded + # as a single byte array. For example, partitions of an RDD generated from partitionBy() + # may be encoded as multiple byte arrays. + appendLength <- function(part) { + part[[length(part) + 1]] <- length(part) + 1 + part + } + x <- lapplyPartition(x, appendLength) + other <- lapplyPartition(other, appendLength) + } + + zippedJRDD <- callJMethod(getJRDD(x), "zip", getJRDD(other)) + # The zippedRDD's elements are of scala Tuple2 type. The serialized + # flag Here is used for the elements inside the tuples. + serializerMode <- getSerializedMode(x) + zippedRDD <- RDD(zippedJRDD, serializerMode) + + partitionFunc <- function(split, part) { + len <- length(part) + if (len > 0) { + if (serializerMode == "byte") { + lengthOfValues <- part[[len]] + lengthOfKeys <- part[[len - lengthOfValues]] + stopifnot(len == lengthOfKeys + lengthOfValues) + + # check if corresponding partitions of both RDDs have the same number of elements. + if (lengthOfKeys != lengthOfValues) { + stop("Can only zip RDDs with same number of elements in each pair of corresponding partitions.") + } + + if (lengthOfKeys > 1) { + keys <- part[1 : (lengthOfKeys - 1)] + values <- part[(lengthOfKeys + 1) : (len - 1)] + } else { + keys <- list() + values <- list() + } + } else { + # Keys, values must have same length here, because this has + # been validated inside the JavaRDD.zip() function. + keys <- part[c(TRUE, FALSE)] + values <- part[c(FALSE, TRUE)] + } + mapply( + function(k, v) { + list(k, v) + }, + keys, + values, + SIMPLIFY = FALSE, + USE.NAMES = FALSE) + } else { + part + } + } + + PipelinedRDD(zippedRDD, partitionFunc) + }) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R new file mode 100644 index 0000000000000..930ada22f4c38 --- /dev/null +++ b/R/pkg/R/SQLContext.R @@ -0,0 +1,520 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# SQLcontext.R: SQLContext-driven functions + +#' infer the SQL type +infer_type <- function(x) { + if (is.null(x)) { + stop("can not infer type from NULL") + } + + # class of POSIXlt is c("POSIXlt" "POSIXt") + type <- switch(class(x)[[1]], + integer = "integer", + character = "string", + logical = "boolean", + double = "double", + numeric = "double", + raw = "binary", + list = "array", + environment = "map", + Date = "date", + POSIXlt = "timestamp", + POSIXct = "timestamp", + stop(paste("Unsupported type for DataFrame:", class(x)))) + + if (type == "map") { + stopifnot(length(x) > 0) + key <- ls(x)[[1]] + list(type = "map", + keyType = "string", + valueType = infer_type(get(key, x)), + valueContainsNull = TRUE) + } else if (type == "array") { + stopifnot(length(x) > 0) + names <- names(x) + if (is.null(names)) { + list(type = "array", elementType = infer_type(x[[1]]), containsNull = TRUE) + } else { + # StructType + types <- lapply(x, infer_type) + fields <- lapply(1:length(x), function(i) { + list(name = names[[i]], type = types[[i]], nullable = TRUE) + }) + list(type = "struct", fields = fields) + } + } else if (length(x) > 1) { + list(type = "array", elementType = type, containsNull = TRUE) + } else { + type + } +} + +#' dump the schema into JSON string +tojson <- function(x) { + if (is.list(x)) { + names <- names(x) + if (!is.null(names)) { + items <- lapply(names, function(n) { + safe_n <- gsub('"', '\\"', n) + paste(tojson(safe_n), ':', tojson(x[[n]]), sep = '') + }) + d <- paste(items, collapse = ', ') + paste('{', d, '}', sep = '') + } else { + l <- paste(lapply(x, tojson), collapse = ', ') + paste('[', l, ']', sep = '') + } + } else if (is.character(x)) { + paste('"', x, '"', sep = '') + } else if (is.logical(x)) { + if (x) "true" else "false" + } else { + stop(paste("unexpected type:", class(x))) + } +} + +#' Create a DataFrame from an RDD +#' +#' Converts an RDD to a DataFrame by infer the types. +#' +#' @param sqlCtx A SQLContext +#' @param data An RDD or list or data.frame +#' @param schema a list of column names or named list (StructType), optional +#' @return an DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' rdd <- lapply(parallelize(sc, 1:10), function(x) list(a=x, b=as.character(x))) +#' df <- createDataFrame(sqlCtx, rdd) +#' } + +# TODO(davies): support sampling and infer type from NA +createDataFrame <- function(sqlCtx, data, schema = NULL, samplingRatio = 1.0) { + if (is.data.frame(data)) { + # get the names of columns, they will be put into RDD + schema <- names(data) + n <- nrow(data) + m <- ncol(data) + # get rid of factor type + dropFactor <- function(x) { + if (is.factor(x)) { + as.character(x) + } else { + x + } + } + data <- lapply(1:n, function(i) { + lapply(1:m, function(j) { dropFactor(data[i,j]) }) + }) + } + if (is.list(data)) { + sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sqlCtx) + rdd <- parallelize(sc, data) + } else if (inherits(data, "RDD")) { + rdd <- data + } else { + stop(paste("unexpected type:", class(data))) + } + + if (is.null(schema) || is.null(names(schema))) { + row <- first(rdd) + names <- if (is.null(schema)) { + names(row) + } else { + as.list(schema) + } + if (is.null(names)) { + names <- lapply(1:length(row), function(x) { + paste("_", as.character(x), sep = "") + }) + } + + # SPAKR-SQL does not support '.' in column name, so replace it with '_' + # TODO(davies): remove this once SPARK-2775 is fixed + names <- lapply(names, function(n) { + nn <- gsub("[.]", "_", n) + if (nn != n) { + warning(paste("Use", nn, "instead of", n, " as column name")) + } + nn + }) + + types <- lapply(row, infer_type) + fields <- lapply(1:length(row), function(i) { + list(name = names[[i]], type = types[[i]], nullable = TRUE) + }) + schema <- list(type = "struct", fields = fields) + } + + stopifnot(class(schema) == "list") + stopifnot(schema$type == "struct") + stopifnot(class(schema$fields) == "list") + schemaString <- tojson(schema) + + jrdd <- getJRDD(lapply(rdd, function(x) x), "row") + srdd <- callJMethod(jrdd, "rdd") + sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "createDF", + srdd, schemaString, sqlCtx) + dataFrame(sdf) +} + +#' toDF +#' +#' Converts an RDD to a DataFrame by infer the types. +#' +#' @param x An RDD +#' +#' @rdname DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' rdd <- lapply(parallelize(sc, 1:10), function(x) list(a=x, b=as.character(x))) +#' df <- toDF(rdd) +#' } + +setGeneric("toDF", function(x, ...) { standardGeneric("toDF") }) + +setMethod("toDF", signature(x = "RDD"), + function(x, ...) { + sqlCtx <- if (exists(".sparkRHivesc", envir = .sparkREnv)) { + get(".sparkRHivesc", envir = .sparkREnv) + } else if (exists(".sparkRSQLsc", envir = .sparkREnv)) { + get(".sparkRSQLsc", envir = .sparkREnv) + } else { + stop("no SQL context available") + } + createDataFrame(sqlCtx, x, ...) + }) + +#' Create a DataFrame from a JSON file. +#' +#' Loads a JSON file (one object per line), returning the result as a DataFrame +#' It goes through the entire dataset once to determine the schema. +#' +#' @param sqlCtx SQLContext to use +#' @param path Path of file to read. A vector of multiple paths is allowed. +#' @return DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' } + +jsonFile <- function(sqlCtx, path) { + # Allow the user to have a more flexible definiton of the text file path + path <- normalizePath(path) + # Convert a string vector of paths to a string containing comma separated paths + path <- paste(path, collapse = ",") + sdf <- callJMethod(sqlCtx, "jsonFile", path) + dataFrame(sdf) +} + + +#' JSON RDD +#' +#' Loads an RDD storing one JSON object per string as a DataFrame. +#' +#' @param sqlCtx SQLContext to use +#' @param rdd An RDD of JSON string +#' @param schema A StructType object to use as schema +#' @param samplingRatio The ratio of simpling used to infer the schema +#' @return A DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' rdd <- texFile(sc, "path/to/json") +#' df <- jsonRDD(sqlCtx, rdd) +#' } + +# TODO: support schema +jsonRDD <- function(sqlCtx, rdd, schema = NULL, samplingRatio = 1.0) { + rdd <- serializeToString(rdd) + if (is.null(schema)) { + sdf <- callJMethod(sqlCtx, "jsonRDD", callJMethod(getJRDD(rdd), "rdd"), samplingRatio) + dataFrame(sdf) + } else { + stop("not implemented") + } +} + + +#' Create a DataFrame from a Parquet file. +#' +#' Loads a Parquet file, returning the result as a DataFrame. +#' +#' @param sqlCtx SQLContext to use +#' @param ... Path(s) of parquet file(s) to read. +#' @return DataFrame +#' @export + +# TODO: Implement saveasParquetFile and write examples for both +parquetFile <- function(sqlCtx, ...) { + # Allow the user to have a more flexible definiton of the text file path + paths <- lapply(list(...), normalizePath) + sdf <- callJMethod(sqlCtx, "parquetFile", paths) + dataFrame(sdf) +} + +#' SQL Query +#' +#' Executes a SQL query using Spark, returning the result as a DataFrame. +#' +#' @param sqlCtx SQLContext to use +#' @param sqlQuery A character vector containing the SQL query +#' @return DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' registerTempTable(df, "table") +#' new_df <- sql(sqlCtx, "SELECT * FROM table") +#' } + +sql <- function(sqlCtx, sqlQuery) { + sdf <- callJMethod(sqlCtx, "sql", sqlQuery) + dataFrame(sdf) +} + + +#' Create a DataFrame from a SparkSQL Table +#' +#' Returns the specified Table as a DataFrame. The Table must have already been registered +#' in the SQLContext. +#' +#' @param sqlCtx SQLContext to use +#' @param tableName The SparkSQL Table to convert to a DataFrame. +#' @return DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' registerTempTable(df, "table") +#' new_df <- table(sqlCtx, "table") +#' } + +table <- function(sqlCtx, tableName) { + sdf <- callJMethod(sqlCtx, "table", tableName) + dataFrame(sdf) +} + + +#' Tables +#' +#' Returns a DataFrame containing names of tables in the given database. +#' +#' @param sqlCtx SQLContext to use +#' @param databaseName name of the database +#' @return a DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' tables(sqlCtx, "hive") +#' } + +tables <- function(sqlCtx, databaseName = NULL) { + jdf <- if (is.null(databaseName)) { + callJMethod(sqlCtx, "tables") + } else { + callJMethod(sqlCtx, "tables", databaseName) + } + dataFrame(jdf) +} + + +#' Table Names +#' +#' Returns the names of tables in the given database as an array. +#' +#' @param sqlCtx SQLContext to use +#' @param databaseName name of the database +#' @return a list of table names +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' tableNames(sqlCtx, "hive") +#' } + +tableNames <- function(sqlCtx, databaseName = NULL) { + if (is.null(databaseName)) { + callJMethod(sqlCtx, "tableNames") + } else { + callJMethod(sqlCtx, "tableNames", databaseName) + } +} + + +#' Cache Table +#' +#' Caches the specified table in-memory. +#' +#' @param sqlCtx SQLContext to use +#' @param tableName The name of the table being cached +#' @return DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' registerTempTable(df, "table") +#' cacheTable(sqlCtx, "table") +#' } + +cacheTable <- function(sqlCtx, tableName) { + callJMethod(sqlCtx, "cacheTable", tableName) +} + +#' Uncache Table +#' +#' Removes the specified table from the in-memory cache. +#' +#' @param sqlCtx SQLContext to use +#' @param tableName The name of the table being uncached +#' @return DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' registerTempTable(df, "table") +#' uncacheTable(sqlCtx, "table") +#' } + +uncacheTable <- function(sqlCtx, tableName) { + callJMethod(sqlCtx, "uncacheTable", tableName) +} + +#' Clear Cache +#' +#' Removes all cached tables from the in-memory cache. +#' +#' @param sqlCtx SQLContext to use +#' @examples +#' \dontrun{ +#' clearCache(sqlCtx) +#' } + +clearCache <- function(sqlCtx) { + callJMethod(sqlCtx, "clearCache") +} + +#' Drop Temporary Table +#' +#' Drops the temporary table with the given table name in the catalog. +#' If the table has been cached/persisted before, it's also unpersisted. +#' +#' @param sqlCtx SQLContext to use +#' @param tableName The name of the SparkSQL table to be dropped. +#' @examples +#' \dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' df <- loadDF(sqlCtx, path, "parquet") +#' registerTempTable(df, "table") +#' dropTempTable(sqlCtx, "table") +#' } + +dropTempTable <- function(sqlCtx, tableName) { + if (class(tableName) != "character") { + stop("tableName must be a string.") + } + callJMethod(sqlCtx, "dropTempTable", tableName) +} + +#' Load an DataFrame +#' +#' Returns the dataset in a data source as a DataFrame +#' +#' The data source is specified by the `source` and a set of options(...). +#' If `source` is not specified, the default data source configured by +#' "spark.sql.sources.default" will be used. +#' +#' @param sqlCtx SQLContext to use +#' @param path The path of files to load +#' @param source the name of external data source +#' @return DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' df <- load(sqlCtx, "path/to/file.json", source = "json") +#' } + +loadDF <- function(sqlCtx, path = NULL, source = NULL, ...) { + options <- varargsToEnv(...) + if (!is.null(path)) { + options[['path']] <- path + } + sdf <- callJMethod(sqlCtx, "load", source, options) + dataFrame(sdf) +} + +#' Create an external table +#' +#' Creates an external table based on the dataset in a data source, +#' Returns the DataFrame associated with the external table. +#' +#' The data source is specified by the `source` and a set of options(...). +#' If `source` is not specified, the default data source configured by +#' "spark.sql.sources.default" will be used. +#' +#' @param sqlCtx SQLContext to use +#' @param tableName A name of the table +#' @param path The path of files to load +#' @param source the name of external data source +#' @return DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' df <- sparkRSQL.createExternalTable(sqlCtx, "myjson", path="path/to/json", source="json") +#' } + +createExternalTable <- function(sqlCtx, tableName, path = NULL, source = NULL, ...) { + options <- varargsToEnv(...) + if (!is.null(path)) { + options[['path']] <- path + } + sdf <- callJMethod(sqlCtx, "createExternalTable", tableName, source, options) + dataFrame(sdf) +} diff --git a/R/pkg/R/SQLTypes.R b/R/pkg/R/SQLTypes.R new file mode 100644 index 0000000000000..962fba5b3cf03 --- /dev/null +++ b/R/pkg/R/SQLTypes.R @@ -0,0 +1,64 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Utility functions for handling SparkSQL DataTypes. + +# Handler for StructType +structType <- function(st) { + obj <- structure(new.env(parent = emptyenv()), class = "structType") + obj$jobj <- st + obj$fields <- function() { lapply(callJMethod(st, "fields"), structField) } + obj +} + +#' Print a Spark StructType. +#' +#' This function prints the contents of a StructType returned from the +#' SparkR JVM backend. +#' +#' @param x A StructType object +#' @param ... further arguments passed to or from other methods +print.structType <- function(x, ...) { + fieldsList <- lapply(x$fields(), function(i) { i$print() }) + print(fieldsList) +} + +# Handler for StructField +structField <- function(sf) { + obj <- structure(new.env(parent = emptyenv()), class = "structField") + obj$jobj <- sf + obj$name <- function() { callJMethod(sf, "name") } + obj$dataType <- function() { callJMethod(sf, "dataType") } + obj$dataType.toString <- function() { callJMethod(obj$dataType(), "toString") } + obj$dataType.simpleString <- function() { callJMethod(obj$dataType(), "simpleString") } + obj$nullable <- function() { callJMethod(sf, "nullable") } + obj$print <- function() { paste("StructField(", + paste(obj$name(), obj$dataType.toString(), obj$nullable(), sep = ", "), + ")", sep = "") } + obj +} + +#' Print a Spark StructField. +#' +#' This function prints the contents of a StructField returned from the +#' SparkR JVM backend. +#' +#' @param x A StructField object +#' @param ... further arguments passed to or from other methods +print.structField <- function(x, ...) { + cat(x$print()) +} diff --git a/R/pkg/R/backend.R b/R/pkg/R/backend.R new file mode 100644 index 0000000000000..2fb6fae55f28c --- /dev/null +++ b/R/pkg/R/backend.R @@ -0,0 +1,115 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Methods to call into SparkRBackend. + + +# Returns TRUE if object is an instance of given class +isInstanceOf <- function(jobj, className) { + stopifnot(class(jobj) == "jobj") + cls <- callJStatic("java.lang.Class", "forName", className) + callJMethod(cls, "isInstance", jobj) +} + +# Call a Java method named methodName on the object +# specified by objId. objId should be a "jobj" returned +# from the SparkRBackend. +callJMethod <- function(objId, methodName, ...) { + stopifnot(class(objId) == "jobj") + if (!isValidJobj(objId)) { + stop("Invalid jobj ", objId$id, + ". If SparkR was restarted, Spark operations need to be re-executed.") + } + invokeJava(isStatic = FALSE, objId$id, methodName, ...) +} + +# Call a static method on a specified className +callJStatic <- function(className, methodName, ...) { + invokeJava(isStatic = TRUE, className, methodName, ...) +} + +# Create a new object of the specified class name +newJObject <- function(className, ...) { + invokeJava(isStatic = TRUE, className, methodName = "", ...) +} + +# Remove an object from the SparkR backend. This is done +# automatically when a jobj is garbage collected. +removeJObject <- function(objId) { + invokeJava(isStatic = TRUE, "SparkRHandler", "rm", objId) +} + +isRemoveMethod <- function(isStatic, objId, methodName) { + isStatic == TRUE && objId == "SparkRHandler" && methodName == "rm" +} + +# Invoke a Java method on the SparkR backend. Users +# should typically use one of the higher level methods like +# callJMethod, callJStatic etc. instead of using this. +# +# isStatic - TRUE if the method to be called is static +# objId - String that refers to the object on which method is invoked +# Should be a jobj id for non-static methods and the classname +# for static methods +# methodName - name of method to be invoked +invokeJava <- function(isStatic, objId, methodName, ...) { + if (!exists(".sparkRCon", .sparkREnv)) { + stop("No connection to backend found. Please re-run sparkR.init") + } + + # If this isn't a removeJObject call + if (!isRemoveMethod(isStatic, objId, methodName)) { + objsToRemove <- ls(.toRemoveJobjs) + if (length(objsToRemove) > 0) { + sapply(objsToRemove, + function(e) { + removeJObject(e) + }) + rm(list = objsToRemove, envir = .toRemoveJobjs) + } + } + + + rc <- rawConnection(raw(0), "r+") + + writeBoolean(rc, isStatic) + writeString(rc, objId) + writeString(rc, methodName) + + args <- list(...) + writeInt(rc, length(args)) + writeArgs(rc, args) + + # Construct the whole request message to send it once, + # avoiding write-write-read pattern in case of Nagle's algorithm. + # Refer to http://en.wikipedia.org/wiki/Nagle%27s_algorithm for the details. + bytesToSend <- rawConnectionValue(rc) + close(rc) + rc <- rawConnection(raw(0), "r+") + writeInt(rc, length(bytesToSend)) + writeBin(bytesToSend, rc) + requestMessage <- rawConnectionValue(rc) + close(rc) + + conn <- get(".sparkRCon", .sparkREnv) + writeBin(requestMessage, conn) + + # TODO: check the status code to output error information + returnStatus <- readInt(conn) + stopifnot(returnStatus == 0) + readObject(conn) +} diff --git a/R/pkg/R/broadcast.R b/R/pkg/R/broadcast.R new file mode 100644 index 0000000000000..583fa2e7fdcfd --- /dev/null +++ b/R/pkg/R/broadcast.R @@ -0,0 +1,86 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# S4 class representing Broadcast variables + +# Hidden environment that holds values for broadcast variables +# This will not be serialized / shipped by default +.broadcastNames <- new.env() +.broadcastValues <- new.env() +.broadcastIdToName <- new.env() + +#' @title S4 class that represents a Broadcast variable +#' @description Broadcast variables can be created using the broadcast +#' function from a \code{SparkContext}. +#' @rdname broadcast-class +#' @seealso broadcast +#' +#' @param id Id of the backing Spark broadcast variable +#' @export +setClass("Broadcast", slots = list(id = "character")) + +#' @rdname broadcast-class +#' @param value Value of the broadcast variable +#' @param jBroadcastRef reference to the backing Java broadcast object +#' @param objName name of broadcasted object +#' @export +Broadcast <- function(id, value, jBroadcastRef, objName) { + .broadcastValues[[id]] <- value + .broadcastNames[[as.character(objName)]] <- jBroadcastRef + .broadcastIdToName[[id]] <- as.character(objName) + new("Broadcast", id = id) +} + +#' @description +#' \code{value} can be used to get the value of a broadcast variable inside +#' a distributed function. +#' +#' @param bcast The broadcast variable to get +#' @rdname broadcast +#' @aliases value,Broadcast-method +setMethod("value", + signature(bcast = "Broadcast"), + function(bcast) { + if (exists(bcast@id, envir = .broadcastValues)) { + get(bcast@id, envir = .broadcastValues) + } else { + NULL + } + }) + +#' Internal function to set values of a broadcast variable. +#' +#' This function is used internally by Spark to set the value of a broadcast +#' variable on workers. Not intended for use outside the package. +#' +#' @rdname broadcast-internal +#' @seealso broadcast, value + +#' @param bcastId The id of broadcast variable to set +#' @param value The value to be set +#' @export +setBroadcastValue <- function(bcastId, value) { + bcastIdStr <- as.character(bcastId) + .broadcastValues[[bcastIdStr]] <- value +} + +#' Helper function to clear the list of broadcast variables we know about +#' Should be called when the SparkR JVM backend is shutdown +clearBroadcastVariables <- function() { + bcasts <- ls(.broadcastNames) + rm(list = bcasts, envir = .broadcastNames) +} diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R new file mode 100644 index 0000000000000..1281c41213e32 --- /dev/null +++ b/R/pkg/R/client.R @@ -0,0 +1,57 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Client code to connect to SparkRBackend + +# Creates a SparkR client connection object +# if one doesn't already exist +connectBackend <- function(hostname, port, timeout = 6000) { + if (exists(".sparkRcon", envir = .sparkREnv)) { + if (isOpen(.sparkREnv[[".sparkRCon"]])) { + cat("SparkRBackend client connection already exists\n") + return(get(".sparkRcon", envir = .sparkREnv)) + } + } + + con <- socketConnection(host = hostname, port = port, server = FALSE, + blocking = TRUE, open = "wb", timeout = timeout) + + assign(".sparkRCon", con, envir = .sparkREnv) + con +} + +launchBackend <- function(args, sparkHome, jars, sparkSubmitOpts) { + if (.Platform$OS.type == "unix") { + sparkSubmitBinName = "spark-submit" + } else { + sparkSubmitBinName = "spark-submit.cmd" + } + + if (sparkHome != "") { + sparkSubmitBin <- file.path(sparkHome, "bin", sparkSubmitBinName) + } else { + sparkSubmitBin <- sparkSubmitBinName + } + + if (jars != "") { + jars <- paste("--jars", jars) + } + + combinedArgs <- paste(jars, sparkSubmitOpts, args, sep = " ") + cat("Launching java with spark-submit command", sparkSubmitBin, combinedArgs, "\n") + invisible(system2(sparkSubmitBin, combinedArgs, wait = F)) +} diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R new file mode 100644 index 0000000000000..e196305186b9a --- /dev/null +++ b/R/pkg/R/column.R @@ -0,0 +1,199 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Column Class + +#' @include generics.R jobj.R +NULL + +setOldClass("jobj") + +#' @title S4 class that represents a DataFrame column +#' @description The column class supports unary, binary operations on DataFrame columns + +#' @rdname column +#' +#' @param jc reference to JVM DataFrame column +#' @export +setClass("Column", + slots = list(jc = "jobj")) + +setMethod("initialize", "Column", function(.Object, jc) { + .Object@jc <- jc + .Object +}) + +column <- function(jc) { + new("Column", jc) +} + +col <- function(x) { + column(callJStatic("org.apache.spark.sql.functions", "col", x)) +} + +#' @rdname show +setMethod("show", "Column", + function(object) { + cat("Column", callJMethod(object@jc, "toString"), "\n") + }) + +operators <- list( + "+" = "plus", "-" = "minus", "*" = "multiply", "/" = "divide", "%%" = "mod", + "==" = "equalTo", ">" = "gt", "<" = "lt", "!=" = "notEqual", "<=" = "leq", ">=" = "geq", + # we can not override `&&` and `||`, so use `&` and `|` instead + "&" = "and", "|" = "or" #, "!" = "unary_$bang" +) +column_functions1 <- c("asc", "desc", "isNull", "isNotNull") +column_functions2 <- c("like", "rlike", "startsWith", "endsWith", "getField", "getItem", "contains") +functions <- c("min", "max", "sum", "avg", "mean", "count", "abs", "sqrt", + "first", "last", "lower", "upper", "sumDistinct") + +createOperator <- function(op) { + setMethod(op, + signature(e1 = "Column"), + function(e1, e2) { + jc <- if (missing(e2)) { + if (op == "-") { + callJMethod(e1@jc, "unary_$minus") + } else { + callJMethod(e1@jc, operators[[op]]) + } + } else { + if (class(e2) == "Column") { + e2 <- e2@jc + } + callJMethod(e1@jc, operators[[op]], e2) + } + column(jc) + }) +} + +createColumnFunction1 <- function(name) { + setMethod(name, + signature(x = "Column"), + function(x) { + column(callJMethod(x@jc, name)) + }) +} + +createColumnFunction2 <- function(name) { + setMethod(name, + signature(x = "Column"), + function(x, data) { + if (class(data) == "Column") { + data <- data@jc + } + jc <- callJMethod(x@jc, name, data) + column(jc) + }) +} + +createStaticFunction <- function(name) { + setMethod(name, + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", name, x@jc) + column(jc) + }) +} + +createMethods <- function() { + for (op in names(operators)) { + createOperator(op) + } + for (name in column_functions1) { + createColumnFunction1(name) + } + for (name in column_functions2) { + createColumnFunction2(name) + } + for (x in functions) { + createStaticFunction(x) + } +} + +createMethods() + +#' alias +#' +#' Set a new name for a column +setMethod("alias", + signature(object = "Column"), + function(object, data) { + if (is.character(data)) { + column(callJMethod(object@jc, "as", data)) + } else { + stop("data should be character") + } + }) + +#' An expression that returns a substring. +#' +#' @param start starting position +#' @param stop ending position +setMethod("substr", signature(x = "Column"), + function(x, start, stop) { + jc <- callJMethod(x@jc, "substr", as.integer(start - 1), as.integer(stop - start + 1)) + column(jc) + }) + +#' Casts the column to a different data type. +#' @examples +#' \dontrun{ +#' cast(df$age, "string") +#' cast(df$name, list(type="array", elementType="byte", containsNull = TRUE)) +#' } +setMethod("cast", + signature(x = "Column"), + function(x, dataType) { + if (is.character(dataType)) { + column(callJMethod(x@jc, "cast", dataType)) + } else if (is.list(dataType)) { + json <- tojson(dataType) + jdataType <- callJStatic("org.apache.spark.sql.types.DataType", "fromJson", json) + column(callJMethod(x@jc, "cast", jdataType)) + } else { + stop("dataType should be character or list") + } + }) + +#' Approx Count Distinct +#' +#' Returns the approximate number of distinct items in a group. +#' +setMethod("approxCountDistinct", + signature(x = "Column"), + function(x, rsd = 0.95) { + jc <- callJStatic("org.apache.spark.sql.functions", "approxCountDistinct", x@jc, rsd) + column(jc) + }) + +#' Count Distinct +#' +#' returns the number of distinct items in a group. +#' +setMethod("countDistinct", + signature(x = "Column"), + function(x, ...) { + jcol <- lapply(list(...), function (x) { + x@jc + }) + jc <- callJStatic("org.apache.spark.sql.functions", "countDistinct", x@jc, + listToSeq(jcol)) + column(jc) + }) + diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R new file mode 100644 index 0000000000000..2fc0bb294bcce --- /dev/null +++ b/R/pkg/R/context.R @@ -0,0 +1,225 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# context.R: SparkContext driven functions + +getMinSplits <- function(sc, minSplits) { + if (is.null(minSplits)) { + defaultParallelism <- callJMethod(sc, "defaultParallelism") + minSplits <- min(defaultParallelism, 2) + } + as.integer(minSplits) +} + +#' Create an RDD from a text file. +#' +#' This function reads a text file from HDFS, a local file system (available on all +#' nodes), or any Hadoop-supported file system URI, and creates an +#' RDD of strings from it. +#' +#' @param sc SparkContext to use +#' @param path Path of file to read. A vector of multiple paths is allowed. +#' @param minSplits Minimum number of splits to be created. If NULL, the default +#' value is chosen based on available parallelism. +#' @return RDD where each item is of type \code{character} +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' lines <- textFile(sc, "myfile.txt") +#'} +textFile <- function(sc, path, minSplits = NULL) { + # Allow the user to have a more flexible definiton of the text file path + path <- suppressWarnings(normalizePath(path)) + #' Convert a string vector of paths to a string containing comma separated paths + path <- paste(path, collapse = ",") + + jrdd <- callJMethod(sc, "textFile", path, getMinSplits(sc, minSplits)) + # jrdd is of type JavaRDD[String] + RDD(jrdd, "string") +} + +#' Load an RDD saved as a SequenceFile containing serialized objects. +#' +#' The file to be loaded should be one that was previously generated by calling +#' saveAsObjectFile() of the RDD class. +#' +#' @param sc SparkContext to use +#' @param path Path of file to read. A vector of multiple paths is allowed. +#' @param minSplits Minimum number of splits to be created. If NULL, the default +#' value is chosen based on available parallelism. +#' @return RDD containing serialized R objects. +#' @seealso saveAsObjectFile +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- objectFile(sc, "myfile") +#'} +objectFile <- function(sc, path, minSplits = NULL) { + # Allow the user to have a more flexible definiton of the text file path + path <- suppressWarnings(normalizePath(path)) + #' Convert a string vector of paths to a string containing comma separated paths + path <- paste(path, collapse = ",") + + jrdd <- callJMethod(sc, "objectFile", path, getMinSplits(sc, minSplits)) + # Assume the RDD contains serialized R objects. + RDD(jrdd, "byte") +} + +#' Create an RDD from a homogeneous list or vector. +#' +#' This function creates an RDD from a local homogeneous list in R. The elements +#' in the list are split into \code{numSlices} slices and distributed to nodes +#' in the cluster. +#' +#' @param sc SparkContext to use +#' @param coll collection to parallelize +#' @param numSlices number of partitions to create in the RDD +#' @return an RDD created from this collection +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 2) +#' # The RDD should contain 10 elements +#' length(rdd) +#'} +parallelize <- function(sc, coll, numSlices = 1) { + # TODO: bound/safeguard numSlices + # TODO: unit tests for if the split works for all primitives + # TODO: support matrix, data frame, etc + if ((!is.list(coll) && !is.vector(coll)) || is.data.frame(coll)) { + if (is.data.frame(coll)) { + message(paste("context.R: A data frame is parallelized by columns.")) + } else { + if (is.matrix(coll)) { + message(paste("context.R: A matrix is parallelized by elements.")) + } else { + message(paste("context.R: parallelize() currently only supports lists and vectors.", + "Calling as.list() to coerce coll into a list.")) + } + } + coll <- as.list(coll) + } + + if (numSlices > length(coll)) + numSlices <- length(coll) + + sliceLen <- ceiling(length(coll) / numSlices) + slices <- split(coll, rep(1:(numSlices + 1), each = sliceLen)[1:length(coll)]) + + # Serialize each slice: obtain a list of raws, or a list of lists (slices) of + # 2-tuples of raws + serializedSlices <- lapply(slices, serialize, connection = NULL) + + jrdd <- callJStatic("org.apache.spark.api.r.RRDD", + "createRDDFromArray", sc, serializedSlices) + + RDD(jrdd, "byte") +} + +#' Include this specified package on all workers +#' +#' This function can be used to include a package on all workers before the +#' user's code is executed. This is useful in scenarios where other R package +#' functions are used in a function passed to functions like \code{lapply}. +#' NOTE: The package is assumed to be installed on every node in the Spark +#' cluster. +#' +#' @param sc SparkContext to use +#' @param pkg Package name +#' +#' @export +#' @examples +#'\dontrun{ +#' library(Matrix) +#' +#' sc <- sparkR.init() +#' # Include the matrix library we will be using +#' includePackage(sc, Matrix) +#' +#' generateSparse <- function(x) { +#' sparseMatrix(i=c(1, 2, 3), j=c(1, 2, 3), x=c(1, 2, 3)) +#' } +#' +#' rdd <- lapplyPartition(parallelize(sc, 1:2, 2L), generateSparse) +#' collect(rdd) +#'} +includePackage <- function(sc, pkg) { + pkg <- as.character(substitute(pkg)) + if (exists(".packages", .sparkREnv)) { + packages <- .sparkREnv$.packages + } else { + packages <- list() + } + packages <- c(packages, pkg) + .sparkREnv$.packages <- packages +} + +#' @title Broadcast a variable to all workers +#' +#' @description +#' Broadcast a read-only variable to the cluster, returning a \code{Broadcast} +#' object for reading it in distributed functions. +#' +#' @param sc Spark Context to use +#' @param object Object to be broadcast +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:2, 2L) +#' +#' # Large Matrix object that we want to broadcast +#' randomMat <- matrix(nrow=100, ncol=10, data=rnorm(1000)) +#' randomMatBr <- broadcast(sc, randomMat) +#' +#' # Use the broadcast variable inside the function +#' useBroadcast <- function(x) { +#' sum(value(randomMatBr) * x) +#' } +#' sumRDD <- lapply(rdd, useBroadcast) +#'} +broadcast <- function(sc, object) { + objName <- as.character(substitute(object)) + serializedObj <- serialize(object, connection = NULL) + + jBroadcast <- callJMethod(sc, "broadcast", serializedObj) + id <- as.character(callJMethod(jBroadcast, "id")) + + Broadcast(id, object, jBroadcast, objName) +} + +#' @title Set the checkpoint directory +#' +#' Set the directory under which RDDs are going to be checkpointed. The +#' directory must be a HDFS path if running on a cluster. +#' +#' @param sc Spark Context to use +#' @param dirName Directory path +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' setCheckpointDir(sc, "~/checkpoints") +#' rdd <- parallelize(sc, 1:2, 2L) +#' checkpoint(rdd) +#'} +setCheckpointDir <- function(sc, dirName) { + invisible(callJMethod(sc, "setCheckpointDir", suppressWarnings(normalizePath(dirName)))) +} diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R new file mode 100644 index 0000000000000..257b435607ce8 --- /dev/null +++ b/R/pkg/R/deserialize.R @@ -0,0 +1,184 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Utility functions to deserialize objects from Java. + +# Type mapping from Java to R +# +# void -> NULL +# Int -> integer +# String -> character +# Boolean -> logical +# Double -> double +# Long -> double +# Array[Byte] -> raw +# Date -> Date +# Time -> POSIXct +# +# Array[T] -> list() +# Object -> jobj + +readObject <- function(con) { + # Read type first + type <- readType(con) + readTypedObject(con, type) +} + +readTypedObject <- function(con, type) { + switch (type, + "i" = readInt(con), + "c" = readString(con), + "b" = readBoolean(con), + "d" = readDouble(con), + "r" = readRaw(con), + "D" = readDate(con), + "t" = readTime(con), + "l" = readList(con), + "n" = NULL, + "j" = getJobj(readString(con)), + stop(paste("Unsupported type for deserialization", type))) +} + +readString <- function(con) { + stringLen <- readInt(con) + string <- readBin(con, raw(), stringLen, endian = "big") + rawToChar(string) +} + +readInt <- function(con) { + readBin(con, integer(), n = 1, endian = "big") +} + +readDouble <- function(con) { + readBin(con, double(), n = 1, endian = "big") +} + +readBoolean <- function(con) { + as.logical(readInt(con)) +} + +readType <- function(con) { + rawToChar(readBin(con, "raw", n = 1L)) +} + +readDate <- function(con) { + as.Date(readString(con)) +} + +readTime <- function(con) { + t <- readDouble(con) + as.POSIXct(t, origin = "1970-01-01") +} + +# We only support lists where all elements are of same type +readList <- function(con) { + type <- readType(con) + len <- readInt(con) + if (len > 0) { + l <- vector("list", len) + for (i in 1:len) { + l[[i]] <- readTypedObject(con, type) + } + l + } else { + list() + } +} + +readRaw <- function(con) { + dataLen <- readInt(con) + data <- readBin(con, raw(), as.integer(dataLen), endian = "big") +} + +readRawLen <- function(con, dataLen) { + data <- readBin(con, raw(), as.integer(dataLen), endian = "big") +} + +readDeserialize <- function(con) { + # We have two cases that are possible - In one, the entire partition is + # encoded as a byte array, so we have only one value to read. If so just + # return firstData + dataLen <- readInt(con) + firstData <- unserialize( + readBin(con, raw(), as.integer(dataLen), endian = "big")) + + # Else, read things into a list + dataLen <- readInt(con) + if (length(dataLen) > 0 && dataLen > 0) { + data <- list(firstData) + while (length(dataLen) > 0 && dataLen > 0) { + data[[length(data) + 1L]] <- unserialize( + readBin(con, raw(), as.integer(dataLen), endian = "big")) + dataLen <- readInt(con) + } + unlist(data, recursive = FALSE) + } else { + firstData + } +} + +readDeserializeRows <- function(inputCon) { + # readDeserializeRows will deserialize a DataOutputStream composed of + # a list of lists. Since the DOS is one continuous stream and + # the number of rows varies, we put the readRow function in a while loop + # that termintates when the next row is empty. + data <- list() + while(TRUE) { + row <- readRow(inputCon) + if (length(row) == 0) { + break + } + data[[length(data) + 1L]] <- row + } + data # this is a list of named lists now +} + +readRowList <- function(obj) { + # readRowList is meant for use inside an lapply. As a result, it is + # necessary to open a standalone connection for the row and consume + # the numCols bytes inside the read function in order to correctly + # deserialize the row. + rawObj <- rawConnection(obj, "r+") + on.exit(close(rawObj)) + readRow(rawObj) +} + +readRow <- function(inputCon) { + numCols <- readInt(inputCon) + if (length(numCols) > 0 && numCols > 0) { + lapply(1:numCols, function(x) { + obj <- readObject(inputCon) + if (is.null(obj)) { + NA + } else { + obj + } + }) # each row is a list now + } else { + list() + } +} + +# Take a single column as Array[Byte] and deserialize it into an atomic vector +readCol <- function(inputCon, numRows) { + # sapply can not work with POSIXlt + do.call(c, lapply(1:numRows, function(x) { + value <- readObject(inputCon) + # Replace NULL with NA so we can coerce to vectors + if (is.null(value)) NA else value + })) +} diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R new file mode 100644 index 0000000000000..5fb1ccaa84ee2 --- /dev/null +++ b/R/pkg/R/generics.R @@ -0,0 +1,543 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +############ RDD Actions and Transformations ############ + +#' @rdname aggregateRDD +#' @seealso reduce +#' @export +setGeneric("aggregateRDD", function(x, zeroValue, seqOp, combOp) { standardGeneric("aggregateRDD") }) + +#' @rdname cache-methods +#' @export +setGeneric("cache", function(x) { standardGeneric("cache") }) + +#' @rdname coalesce +#' @seealso repartition +#' @export +setGeneric("coalesce", function(x, numPartitions, ...) { standardGeneric("coalesce") }) + +#' @rdname checkpoint-methods +#' @export +setGeneric("checkpoint", function(x) { standardGeneric("checkpoint") }) + +#' @rdname collect-methods +#' @export +setGeneric("collect", function(x, ...) { standardGeneric("collect") }) + +#' @rdname collect-methods +#' @export +setGeneric("collectAsMap", function(x) { standardGeneric("collectAsMap") }) + +#' @rdname collect-methods +#' @export +setGeneric("collectPartition", + function(x, partitionId) { + standardGeneric("collectPartition") + }) + +#' @rdname count +#' @export +setGeneric("count", function(x) { standardGeneric("count") }) + +#' @rdname countByValue +#' @export +setGeneric("countByValue", function(x) { standardGeneric("countByValue") }) + +#' @rdname distinct +#' @export +setGeneric("distinct", function(x, numPartitions = 1L) { standardGeneric("distinct") }) + +#' @rdname filterRDD +#' @export +setGeneric("filterRDD", function(x, f) { standardGeneric("filterRDD") }) + +#' @rdname first +#' @export +setGeneric("first", function(x) { standardGeneric("first") }) + +#' @rdname flatMap +#' @export +setGeneric("flatMap", function(X, FUN) { standardGeneric("flatMap") }) + +#' @rdname fold +#' @seealso reduce +#' @export +setGeneric("fold", function(x, zeroValue, op) { standardGeneric("fold") }) + +#' @rdname foreach +#' @export +setGeneric("foreach", function(x, func) { standardGeneric("foreach") }) + +#' @rdname foreach +#' @export +setGeneric("foreachPartition", function(x, func) { standardGeneric("foreachPartition") }) + +# The jrdd accessor function. +setGeneric("getJRDD", function(rdd, ...) { standardGeneric("getJRDD") }) + +#' @rdname glom +#' @export +setGeneric("glom", function(x) { standardGeneric("glom") }) + +#' @rdname keyBy +#' @export +setGeneric("keyBy", function(x, func) { standardGeneric("keyBy") }) + +#' @rdname lapplyPartition +#' @export +setGeneric("lapplyPartition", function(X, FUN) { standardGeneric("lapplyPartition") }) + +#' @rdname lapplyPartitionsWithIndex +#' @export +setGeneric("lapplyPartitionsWithIndex", + function(X, FUN) { + standardGeneric("lapplyPartitionsWithIndex") + }) + +#' @rdname lapply +#' @export +setGeneric("map", function(X, FUN) { standardGeneric("map") }) + +#' @rdname lapplyPartition +#' @export +setGeneric("mapPartitions", function(X, FUN) { standardGeneric("mapPartitions") }) + +#' @rdname lapplyPartitionsWithIndex +#' @export +setGeneric("mapPartitionsWithIndex", + function(X, FUN) { standardGeneric("mapPartitionsWithIndex") }) + +#' @rdname maximum +#' @export +setGeneric("maximum", function(x) { standardGeneric("maximum") }) + +#' @rdname minimum +#' @export +setGeneric("minimum", function(x) { standardGeneric("minimum") }) + +#' @rdname sumRDD +#' @export +setGeneric("sumRDD", function(x) { standardGeneric("sumRDD") }) + +#' @rdname name +#' @export +setGeneric("name", function(x) { standardGeneric("name") }) + +#' @rdname numPartitions +#' @export +setGeneric("numPartitions", function(x) { standardGeneric("numPartitions") }) + +#' @rdname persist +#' @export +setGeneric("persist", function(x, newLevel) { standardGeneric("persist") }) + +#' @rdname pipeRDD +#' @export +setGeneric("pipeRDD", function(x, command, env = list()) { standardGeneric("pipeRDD")}) + +#' @rdname reduce +#' @export +setGeneric("reduce", function(x, func) { standardGeneric("reduce") }) + +#' @rdname repartition +#' @seealso coalesce +#' @export +setGeneric("repartition", function(x, numPartitions) { standardGeneric("repartition") }) + +#' @rdname sampleRDD +#' @export +setGeneric("sampleRDD", + function(x, withReplacement, fraction, seed) { + standardGeneric("sampleRDD") + }) + +#' @rdname saveAsObjectFile +#' @seealso objectFile +#' @export +setGeneric("saveAsObjectFile", function(x, path) { standardGeneric("saveAsObjectFile") }) + +#' @rdname saveAsTextFile +#' @export +setGeneric("saveAsTextFile", function(x, path) { standardGeneric("saveAsTextFile") }) + +#' @rdname setName +#' @export +setGeneric("setName", function(x, name) { standardGeneric("setName") }) + +#' @rdname sortBy +#' @export +setGeneric("sortBy", + function(x, func, ascending = TRUE, numPartitions = 1L) { + standardGeneric("sortBy") + }) + +#' @rdname take +#' @export +setGeneric("take", function(x, num) { standardGeneric("take") }) + +#' @rdname takeOrdered +#' @export +setGeneric("takeOrdered", function(x, num) { standardGeneric("takeOrdered") }) + +#' @rdname takeSample +#' @export +setGeneric("takeSample", + function(x, withReplacement, num, seed) { + standardGeneric("takeSample") + }) + +#' @rdname top +#' @export +setGeneric("top", function(x, num) { standardGeneric("top") }) + +#' @rdname unionRDD +#' @export +setGeneric("unionRDD", function(x, y) { standardGeneric("unionRDD") }) + +#' @rdname unpersist-methods +#' @export +setGeneric("unpersist", function(x, ...) { standardGeneric("unpersist") }) + +#' @rdname zipRDD +#' @export +setGeneric("zipRDD", function(x, other) { standardGeneric("zipRDD") }) + +#' @rdname zipWithIndex +#' @seealso zipWithUniqueId +#' @export +setGeneric("zipWithIndex", function(x) { standardGeneric("zipWithIndex") }) + +#' @rdname zipWithUniqueId +#' @seealso zipWithIndex +#' @export +setGeneric("zipWithUniqueId", function(x) { standardGeneric("zipWithUniqueId") }) + + +############ Binary Functions ############# + +#' @rdname countByKey +#' @export +setGeneric("countByKey", function(x) { standardGeneric("countByKey") }) + +#' @rdname flatMapValues +#' @export +setGeneric("flatMapValues", function(X, FUN) { standardGeneric("flatMapValues") }) + +#' @rdname keys +#' @export +setGeneric("keys", function(x) { standardGeneric("keys") }) + +#' @rdname lookup +#' @export +setGeneric("lookup", function(x, key) { standardGeneric("lookup") }) + +#' @rdname mapValues +#' @export +setGeneric("mapValues", function(X, FUN) { standardGeneric("mapValues") }) + +#' @rdname values +#' @export +setGeneric("values", function(x) { standardGeneric("values") }) + + + +############ Shuffle Functions ############ + +#' @rdname aggregateByKey +#' @seealso foldByKey, combineByKey +#' @export +setGeneric("aggregateByKey", + function(x, zeroValue, seqOp, combOp, numPartitions) { + standardGeneric("aggregateByKey") + }) + +#' @rdname cogroup +#' @export +setGeneric("cogroup", + function(..., numPartitions) { + standardGeneric("cogroup") + }, + signature = "...") + +#' @rdname combineByKey +#' @seealso groupByKey, reduceByKey +#' @export +setGeneric("combineByKey", + function(x, createCombiner, mergeValue, mergeCombiners, numPartitions) { + standardGeneric("combineByKey") + }) + +#' @rdname foldByKey +#' @seealso aggregateByKey, combineByKey +#' @export +setGeneric("foldByKey", + function(x, zeroValue, func, numPartitions) { + standardGeneric("foldByKey") + }) + +#' @rdname join-methods +#' @export +setGeneric("fullOuterJoin", function(x, y, numPartitions) { standardGeneric("fullOuterJoin") }) + +#' @rdname groupByKey +#' @seealso reduceByKey +#' @export +setGeneric("groupByKey", function(x, numPartitions) { standardGeneric("groupByKey") }) + +#' @rdname join-methods +#' @export +setGeneric("join", function(x, y, ...) { standardGeneric("join") }) + +#' @rdname join-methods +#' @export +setGeneric("leftOuterJoin", function(x, y, numPartitions) { standardGeneric("leftOuterJoin") }) + +#' @rdname partitionBy +#' @export +setGeneric("partitionBy", function(x, numPartitions, ...) { standardGeneric("partitionBy") }) + +#' @rdname reduceByKey +#' @seealso groupByKey +#' @export +setGeneric("reduceByKey", function(x, combineFunc, numPartitions) { standardGeneric("reduceByKey")}) + +#' @rdname reduceByKeyLocally +#' @seealso reduceByKey +#' @export +setGeneric("reduceByKeyLocally", + function(x, combineFunc) { + standardGeneric("reduceByKeyLocally") + }) + +#' @rdname join-methods +#' @export +setGeneric("rightOuterJoin", function(x, y, numPartitions) { standardGeneric("rightOuterJoin") }) + +#' @rdname sortByKey +#' @export +setGeneric("sortByKey", function(x, ascending = TRUE, numPartitions = 1L) { + standardGeneric("sortByKey") +}) + + +################### Broadcast Variable Methods ################# + +#' @rdname broadcast +#' @export +setGeneric("value", function(bcast) { standardGeneric("value") }) + + + +#################### DataFrame Methods ######################## + +#' @rdname schema +#' @export +setGeneric("columns", function(x) {standardGeneric("columns") }) + +#' @rdname schema +#' @export +setGeneric("dtypes", function(x) { standardGeneric("dtypes") }) + +#' @rdname explain +#' @export +setGeneric("explain", function(x, ...) { standardGeneric("explain") }) + +#' @rdname filter +#' @export +setGeneric("filter", function(x, condition) { standardGeneric("filter") }) + +#' @rdname DataFrame +#' @export +setGeneric("groupBy", function(x, ...) { standardGeneric("groupBy") }) + +#' @rdname insertInto +#' @export +setGeneric("insertInto", function(x, tableName, ...) { standardGeneric("insertInto") }) + +#' @rdname intersect +#' @export +setGeneric("intersect", function(x, y) { standardGeneric("intersect") }) + +#' @rdname isLocal +#' @export +setGeneric("isLocal", function(x) { standardGeneric("isLocal") }) + +#' @rdname limit +#' @export +setGeneric("limit", function(x, num) {standardGeneric("limit") }) + +#' @rdname sortDF +#' @export +setGeneric("orderBy", function(x, col) { standardGeneric("orderBy") }) + +#' @rdname schema +#' @export +setGeneric("printSchema", function(x) { standardGeneric("printSchema") }) + +#' @rdname registerTempTable +#' @export +setGeneric("registerTempTable", function(x, tableName) { standardGeneric("registerTempTable") }) + +#' @rdname sampleDF +#' @export +setGeneric("sampleDF", + function(x, withReplacement, fraction, seed) { + standardGeneric("sampleDF") + }) + +#' @rdname saveAsParquetFile +#' @export +setGeneric("saveAsParquetFile", function(x, path) { standardGeneric("saveAsParquetFile") }) + +#' @rdname saveAsTable +#' @export +setGeneric("saveAsTable", function(df, tableName, source, mode, ...) { + standardGeneric("saveAsTable") +}) + +#' @rdname saveAsTable +#' @export +setGeneric("saveDF", function(df, path, source, mode, ...) { standardGeneric("saveDF") }) + +#' @rdname schema +#' @export +setGeneric("schema", function(x) { standardGeneric("schema") }) + +#' @rdname select +#' @export +setGeneric("select", function(x, col, ...) { standardGeneric("select") } ) + +#' @rdname select +#' @export +setGeneric("selectExpr", function(x, expr, ...) { standardGeneric("selectExpr") }) + +#' @rdname showDF +#' @export +setGeneric("showDF", function(x,...) { standardGeneric("showDF") }) + +#' @rdname sortDF +#' @export +setGeneric("sortDF", function(x, col, ...) { standardGeneric("sortDF") }) + +#' @rdname subtract +#' @export +setGeneric("subtract", function(x, y) { standardGeneric("subtract") }) + +#' @rdname tojson +#' @export +setGeneric("toJSON", function(x) { standardGeneric("toJSON") }) + +#' @rdname DataFrame +#' @export +setGeneric("toRDD", function(x) { standardGeneric("toRDD") }) + +#' @rdname unionAll +#' @export +setGeneric("unionAll", function(x, y) { standardGeneric("unionAll") }) + +#' @rdname filter +#' @export +setGeneric("where", function(x, condition) { standardGeneric("where") }) + +#' @rdname withColumn +#' @export +setGeneric("withColumn", function(x, colName, col) { standardGeneric("withColumn") }) + +#' @rdname withColumnRenamed +#' @export +setGeneric("withColumnRenamed", function(x, existingCol, newCol) { + standardGeneric("withColumnRenamed") }) + + +###################### Column Methods ########################## + +#' @rdname column +#' @export +setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCountDistinct") }) + +#' @rdname column +#' @export +setGeneric("asc", function(x) { standardGeneric("asc") }) + +#' @rdname column +#' @export +setGeneric("avg", function(x, ...) { standardGeneric("avg") }) + +#' @rdname column +#' @export +setGeneric("cast", function(x, dataType) { standardGeneric("cast") }) + +#' @rdname column +#' @export +setGeneric("contains", function(x, ...) { standardGeneric("contains") }) +#' @rdname column +#' @export +setGeneric("countDistinct", function(x, ...) { standardGeneric("countDistinct") }) + +#' @rdname column +#' @export +setGeneric("desc", function(x) { standardGeneric("desc") }) + +#' @rdname column +#' @export +setGeneric("endsWith", function(x, ...) { standardGeneric("endsWith") }) + +#' @rdname column +#' @export +setGeneric("getField", function(x, ...) { standardGeneric("getField") }) + +#' @rdname column +#' @export +setGeneric("getItem", function(x, ...) { standardGeneric("getItem") }) + +#' @rdname column +#' @export +setGeneric("isNull", function(x) { standardGeneric("isNull") }) + +#' @rdname column +#' @export +setGeneric("isNotNull", function(x) { standardGeneric("isNotNull") }) + +#' @rdname column +#' @export +setGeneric("last", function(x) { standardGeneric("last") }) + +#' @rdname column +#' @export +setGeneric("like", function(x, ...) { standardGeneric("like") }) + +#' @rdname column +#' @export +setGeneric("lower", function(x) { standardGeneric("lower") }) + +#' @rdname column +#' @export +setGeneric("rlike", function(x, ...) { standardGeneric("rlike") }) + +#' @rdname column +#' @export +setGeneric("startsWith", function(x, ...) { standardGeneric("startsWith") }) + +#' @rdname column +#' @export +setGeneric("sumDistinct", function(x) { standardGeneric("sumDistinct") }) + +#' @rdname column +#' @export +setGeneric("upper", function(x) { standardGeneric("upper") }) + diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R new file mode 100644 index 0000000000000..09fc0a7abe48a --- /dev/null +++ b/R/pkg/R/group.R @@ -0,0 +1,132 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# group.R - GroupedData class and methods implemented in S4 OO classes + +setOldClass("jobj") + +#' @title S4 class that represents a GroupedData +#' @description GroupedDatas can be created using groupBy() on a DataFrame +#' @rdname GroupedData +#' @seealso groupBy +#' +#' @param sgd A Java object reference to the backing Scala GroupedData +#' @export +setClass("GroupedData", + slots = list(sgd = "jobj")) + +setMethod("initialize", "GroupedData", function(.Object, sgd) { + .Object@sgd <- sgd + .Object +}) + +#' @rdname DataFrame +groupedData <- function(sgd) { + new("GroupedData", sgd) +} + + +#' @rdname show +setMethod("show", "GroupedData", + function(object) { + cat("GroupedData\n") + }) + +#' Count +#' +#' Count the number of rows for each group. +#' The resulting DataFrame will also contain the grouping columns. +#' +#' @param x a GroupedData +#' @return a DataFrame +#' @export +#' @examples +#' \dontrun{ +#' count(groupBy(df, "name")) +#' } +setMethod("count", + signature(x = "GroupedData"), + function(x) { + dataFrame(callJMethod(x@sgd, "count")) + }) + +#' Agg +#' +#' Aggregates on the entire DataFrame without groups. +#' The resulting DataFrame will also contain the grouping columns. +#' +#' df2 <- agg(df, = ) +#' df2 <- agg(df, newColName = aggFunction(column)) +#' +#' @param x a GroupedData +#' @return a DataFrame +#' @rdname agg +#' @examples +#' \dontrun{ +#' df2 <- agg(df, age = "sum") # new column name will be created as 'SUM(age#0)' +#' df2 <- agg(df, ageSum = sum(df$age)) # Creates a new column named ageSum +#' } +setGeneric("agg", function (x, ...) { standardGeneric("agg") }) + +setMethod("agg", + signature(x = "GroupedData"), + function(x, ...) { + cols = list(...) + stopifnot(length(cols) > 0) + if (is.character(cols[[1]])) { + cols <- varargsToEnv(...) + sdf <- callJMethod(x@sgd, "agg", cols) + } else if (class(cols[[1]]) == "Column") { + ns <- names(cols) + if (!is.null(ns)) { + for (n in ns) { + if (n != "") { + cols[[n]] = alias(cols[[n]], n) + } + } + } + jcols <- lapply(cols, function(c) { c@jc }) + # the GroupedData.agg(col, cols*) API does not contain grouping Column + sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "aggWithGrouping", + x@sgd, listToSeq(jcols)) + } else { + stop("agg can only support Column or character") + } + dataFrame(sdf) + }) + + +# sum/mean/avg/min/max +methods <- c("sum", "mean", "avg", "min", "max") + +createMethod <- function(name) { + setMethod(name, + signature(x = "GroupedData"), + function(x, ...) { + sdf <- callJMethod(x@sgd, name, toSeq(...)) + dataFrame(sdf) + }) +} + +createMethods <- function() { + for (name in methods) { + createMethod(name) + } +} + +createMethods() + diff --git a/R/pkg/R/jobj.R b/R/pkg/R/jobj.R new file mode 100644 index 0000000000000..4180f146b7fbc --- /dev/null +++ b/R/pkg/R/jobj.R @@ -0,0 +1,101 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# References to objects that exist on the JVM backend +# are maintained using the jobj. + +# Maintain a reference count of Java object references +# This allows us to GC the java object when it is safe +.validJobjs <- new.env(parent = emptyenv()) + +# List of object ids to be removed +.toRemoveJobjs <- new.env(parent = emptyenv()) + +# Check if jobj was created with the current SparkContext +isValidJobj <- function(jobj) { + if (exists(".scStartTime", envir = .sparkREnv)) { + jobj$appId == get(".scStartTime", envir = .sparkREnv) + } else { + FALSE + } +} + +getJobj <- function(objId) { + newObj <- jobj(objId) + if (exists(objId, .validJobjs)) { + .validJobjs[[objId]] <- .validJobjs[[objId]] + 1 + } else { + .validJobjs[[objId]] <- 1 + } + newObj +} + +# Handler for a java object that exists on the backend. +jobj <- function(objId) { + if (!is.character(objId)) { + stop("object id must be a character") + } + # NOTE: We need a new env for a jobj as we can only register + # finalizers for environments or external references pointers. + obj <- structure(new.env(parent = emptyenv()), class = "jobj") + obj$id <- objId + obj$appId <- get(".scStartTime", envir = .sparkREnv) + + # Register a finalizer to remove the Java object when this reference + # is garbage collected in R + reg.finalizer(obj, cleanup.jobj) + obj +} + +#' Print a JVM object reference. +#' +#' This function prints the type and id for an object stored +#' in the SparkR JVM backend. +#' +#' @param x The JVM object reference +#' @param ... further arguments passed to or from other methods +print.jobj <- function(x, ...) { + cls <- callJMethod(x, "getClass") + name <- callJMethod(cls, "getName") + cat("Java ref type", name, "id", x$id, "\n", sep = " ") +} + +cleanup.jobj <- function(jobj) { + if (isValidJobj(jobj)) { + objId <- jobj$id + # If we don't know anything about this jobj, ignore it + if (exists(objId, envir = .validJobjs)) { + .validJobjs[[objId]] <- .validJobjs[[objId]] - 1 + + if (.validJobjs[[objId]] == 0) { + rm(list = objId, envir = .validJobjs) + # NOTE: We cannot call removeJObject here as the finalizer may be run + # in the middle of another RPC. Thus we queue up this object Id to be removed + # and then run all the removeJObject when the next RPC is called. + .toRemoveJobjs[[objId]] <- 1 + } + } + } +} + +clearJobjs <- function() { + valid <- ls(.validJobjs) + rm(list = valid, envir = .validJobjs) + + removeList <- ls(.toRemoveJobjs) + rm(list = removeList, envir = .toRemoveJobjs) +} diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R new file mode 100644 index 0000000000000..c2396c32a7548 --- /dev/null +++ b/R/pkg/R/pairRDD.R @@ -0,0 +1,789 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Operations supported on RDDs contains pairs (i.e key, value) + +############ Actions and Transformations ############ + +#' Look up elements of a key in an RDD +#' +#' @description +#' \code{lookup} returns a list of values in this RDD for key key. +#' +#' @param x The RDD to collect +#' @param key The key to look up for +#' @return a list of values in this RDD for key key +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' pairs <- list(c(1, 1), c(2, 2), c(1, 3)) +#' rdd <- parallelize(sc, pairs) +#' lookup(rdd, 1) # list(1, 3) +#'} +#' @rdname lookup +#' @aliases lookup,RDD-method +setMethod("lookup", + signature(x = "RDD", key = "ANY"), + function(x, key) { + partitionFunc <- function(part) { + filtered <- part[unlist(lapply(part, function(i) { identical(key, i[[1]]) }))] + lapply(filtered, function(i) { i[[2]] }) + } + valsRDD <- lapplyPartition(x, partitionFunc) + collect(valsRDD) + }) + +#' Count the number of elements for each key, and return the result to the +#' master as lists of (key, count) pairs. +#' +#' Same as countByKey in Spark. +#' +#' @param x The RDD to count keys. +#' @return list of (key, count) pairs, where count is number of each key in rdd. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(c("a", 1), c("b", 1), c("a", 1))) +#' countByKey(rdd) # ("a", 2L), ("b", 1L) +#'} +#' @rdname countByKey +#' @aliases countByKey,RDD-method +setMethod("countByKey", + signature(x = "RDD"), + function(x) { + keys <- lapply(x, function(item) { item[[1]] }) + countByValue(keys) + }) + +#' Return an RDD with the keys of each tuple. +#' +#' @param x The RDD from which the keys of each tuple is returned. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(1, 2), list(3, 4))) +#' collect(keys(rdd)) # list(1, 3) +#'} +#' @rdname keys +#' @aliases keys,RDD +setMethod("keys", + signature(x = "RDD"), + function(x) { + func <- function(k) { + k[[1]] + } + lapply(x, func) + }) + +#' Return an RDD with the values of each tuple. +#' +#' @param x The RDD from which the values of each tuple is returned. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(1, 2), list(3, 4))) +#' collect(values(rdd)) # list(2, 4) +#'} +#' @rdname values +#' @aliases values,RDD +setMethod("values", + signature(x = "RDD"), + function(x) { + func <- function(v) { + v[[2]] + } + lapply(x, func) + }) + +#' Applies a function to all values of the elements, without modifying the keys. +#' +#' The same as `mapValues()' in Spark. +#' +#' @param X The RDD to apply the transformation. +#' @param FUN the transformation to apply on the value of each element. +#' @return a new RDD created by the transformation. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' makePairs <- lapply(rdd, function(x) { list(x, x) }) +#' collect(mapValues(makePairs, function(x) { x * 2) }) +#' Output: list(list(1,2), list(2,4), list(3,6), ...) +#'} +#' @rdname mapValues +#' @aliases mapValues,RDD,function-method +setMethod("mapValues", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + func <- function(x) { + list(x[[1]], FUN(x[[2]])) + } + lapply(X, func) + }) + +#' Pass each value in the key-value pair RDD through a flatMap function without +#' changing the keys; this also retains the original RDD's partitioning. +#' +#' The same as 'flatMapValues()' in Spark. +#' +#' @param X The RDD to apply the transformation. +#' @param FUN the transformation to apply on the value of each element. +#' @return a new RDD created by the transformation. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(1, c(1,2)), list(2, c(3,4)))) +#' collect(flatMapValues(rdd, function(x) { x })) +#' Output: list(list(1,1), list(1,2), list(2,3), list(2,4)) +#'} +#' @rdname flatMapValues +#' @aliases flatMapValues,RDD,function-method +setMethod("flatMapValues", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + flatMapFunc <- function(x) { + lapply(FUN(x[[2]]), function(v) { list(x[[1]], v) }) + } + flatMap(X, flatMapFunc) + }) + +############ Shuffle Functions ############ + +#' Partition an RDD by key +#' +#' This function operates on RDDs where every element is of the form list(K, V) or c(K, V). +#' For each element of this RDD, the partitioner is used to compute a hash +#' function and the RDD is partitioned using this hash value. +#' +#' @param x The RDD to partition. Should be an RDD where each element is +#' list(K, V) or c(K, V). +#' @param numPartitions Number of partitions to create. +#' @param ... Other optional arguments to partitionBy. +#' +#' @param partitionFunc The partition function to use. Uses a default hashCode +#' function if not provided +#' @return An RDD partitioned using the specified partitioner. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) +#' rdd <- parallelize(sc, pairs) +#' parts <- partitionBy(rdd, 2L) +#' collectPartition(parts, 0L) # First partition should contain list(1, 2) and list(1, 4) +#'} +#' @rdname partitionBy +#' @aliases partitionBy,RDD,integer-method +setMethod("partitionBy", + signature(x = "RDD", numPartitions = "integer"), + function(x, numPartitions, partitionFunc = hashCode) { + + #if (missing(partitionFunc)) { + # partitionFunc <- hashCode + #} + + partitionFunc <- cleanClosure(partitionFunc) + serializedHashFuncBytes <- serialize(partitionFunc, connection = NULL) + + packageNamesArr <- serialize(.sparkREnv$.packages, + connection = NULL) + broadcastArr <- lapply(ls(.broadcastNames), function(name) { + get(name, .broadcastNames) }) + jrdd <- getJRDD(x) + + # We create a PairwiseRRDD that extends RDD[(Array[Byte], + # Array[Byte])], where the key is the hashed split, the value is + # the content (key-val pairs). + pairwiseRRDD <- newJObject("org.apache.spark.api.r.PairwiseRRDD", + callJMethod(jrdd, "rdd"), + as.integer(numPartitions), + serializedHashFuncBytes, + getSerializedMode(x), + packageNamesArr, + as.character(.sparkREnv$libname), + broadcastArr, + callJMethod(jrdd, "classTag")) + + # Create a corresponding partitioner. + rPartitioner <- newJObject("org.apache.spark.HashPartitioner", + as.integer(numPartitions)) + + # Call partitionBy on the obtained PairwiseRDD. + javaPairRDD <- callJMethod(pairwiseRRDD, "asJavaPairRDD") + javaPairRDD <- callJMethod(javaPairRDD, "partitionBy", rPartitioner) + + # Call .values() on the result to get back the final result, the + # shuffled acutal content key-val pairs. + r <- callJMethod(javaPairRDD, "values") + + RDD(r, serializedMode = "byte") + }) + +#' Group values by key +#' +#' This function operates on RDDs where every element is of the form list(K, V) or c(K, V). +#' and group values for each key in the RDD into a single sequence. +#' +#' @param x The RDD to group. Should be an RDD where each element is +#' list(K, V) or c(K, V). +#' @param numPartitions Number of partitions to create. +#' @return An RDD where each element is list(K, list(V)) +#' @seealso reduceByKey +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) +#' rdd <- parallelize(sc, pairs) +#' parts <- groupByKey(rdd, 2L) +#' grouped <- collect(parts) +#' grouped[[1]] # Should be a list(1, list(2, 4)) +#'} +#' @rdname groupByKey +#' @aliases groupByKey,RDD,integer-method +setMethod("groupByKey", + signature(x = "RDD", numPartitions = "integer"), + function(x, numPartitions) { + shuffled <- partitionBy(x, numPartitions) + groupVals <- function(part) { + vals <- new.env() + keys <- new.env() + pred <- function(item) exists(item$hash, keys) + appendList <- function(acc, i) { + addItemToAccumulator(acc, i) + acc + } + makeList <- function(i) { + acc <- initAccumulator() + addItemToAccumulator(acc, i) + acc + } + # Each item in the partition is list of (K, V) + lapply(part, + function(item) { + item$hash <- as.character(hashCode(item[[1]])) + updateOrCreatePair(item, keys, vals, pred, + appendList, makeList) + }) + # extract out data field + vals <- eapply(vals, + function(i) { + length(i$data) <- i$counter + i$data + }) + # Every key in the environment contains a list + # Convert that to list(K, Seq[V]) + convertEnvsToList(keys, vals) + } + lapplyPartition(shuffled, groupVals) + }) + +#' Merge values by key +#' +#' This function operates on RDDs where every element is of the form list(K, V) or c(K, V). +#' and merges the values for each key using an associative reduce function. +#' +#' @param x The RDD to reduce by key. Should be an RDD where each element is +#' list(K, V) or c(K, V). +#' @param combineFunc The associative reduce function to use. +#' @param numPartitions Number of partitions to create. +#' @return An RDD where each element is list(K, V') where V' is the merged +#' value +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) +#' rdd <- parallelize(sc, pairs) +#' parts <- reduceByKey(rdd, "+", 2L) +#' reduced <- collect(parts) +#' reduced[[1]] # Should be a list(1, 6) +#'} +#' @rdname reduceByKey +#' @aliases reduceByKey,RDD,integer-method +setMethod("reduceByKey", + signature(x = "RDD", combineFunc = "ANY", numPartitions = "integer"), + function(x, combineFunc, numPartitions) { + reduceVals <- function(part) { + vals <- new.env() + keys <- new.env() + pred <- function(item) exists(item$hash, keys) + lapply(part, + function(item) { + item$hash <- as.character(hashCode(item[[1]])) + updateOrCreatePair(item, keys, vals, pred, combineFunc, identity) + }) + convertEnvsToList(keys, vals) + } + locallyReduced <- lapplyPartition(x, reduceVals) + shuffled <- partitionBy(locallyReduced, numPartitions) + lapplyPartition(shuffled, reduceVals) + }) + +#' Merge values by key locally +#' +#' This function operates on RDDs where every element is of the form list(K, V) or c(K, V). +#' and merges the values for each key using an associative reduce function, but return the +#' results immediately to the driver as an R list. +#' +#' @param x The RDD to reduce by key. Should be an RDD where each element is +#' list(K, V) or c(K, V). +#' @param combineFunc The associative reduce function to use. +#' @return A list of elements of type list(K, V') where V' is the merged value for each key +#' @seealso reduceByKey +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) +#' rdd <- parallelize(sc, pairs) +#' reduced <- reduceByKeyLocally(rdd, "+") +#' reduced # list(list(1, 6), list(1.1, 3)) +#'} +#' @rdname reduceByKeyLocally +#' @aliases reduceByKeyLocally,RDD,integer-method +setMethod("reduceByKeyLocally", + signature(x = "RDD", combineFunc = "ANY"), + function(x, combineFunc) { + reducePart <- function(part) { + vals <- new.env() + keys <- new.env() + pred <- function(item) exists(item$hash, keys) + lapply(part, + function(item) { + item$hash <- as.character(hashCode(item[[1]])) + updateOrCreatePair(item, keys, vals, pred, combineFunc, identity) + }) + list(list(keys, vals)) # return hash to avoid re-compute in merge + } + mergeParts <- function(accum, x) { + pred <- function(item) { + exists(item$hash, accum[[1]]) + } + lapply(ls(x[[1]]), + function(name) { + item <- list(x[[1]][[name]], x[[2]][[name]]) + item$hash <- name + updateOrCreatePair(item, accum[[1]], accum[[2]], pred, combineFunc, identity) + }) + accum + } + reduced <- mapPartitions(x, reducePart) + merged <- reduce(reduced, mergeParts) + convertEnvsToList(merged[[1]], merged[[2]]) + }) + +#' Combine values by key +#' +#' Generic function to combine the elements for each key using a custom set of +#' aggregation functions. Turns an RDD[(K, V)] into a result of type RDD[(K, C)], +#' for a "combined type" C. Note that V and C can be different -- for example, one +#' might group an RDD of type (Int, Int) into an RDD of type (Int, Seq[Int]). + +#' Users provide three functions: +#' \itemize{ +#' \item createCombiner, which turns a V into a C (e.g., creates a one-element list) +#' \item mergeValue, to merge a V into a C (e.g., adds it to the end of a list) - +#' \item mergeCombiners, to combine two C's into a single one (e.g., concatentates +#' two lists). +#' } +#' +#' @param x The RDD to combine. Should be an RDD where each element is +#' list(K, V) or c(K, V). +#' @param createCombiner Create a combiner (C) given a value (V) +#' @param mergeValue Merge the given value (V) with an existing combiner (C) +#' @param mergeCombiners Merge two combiners and return a new combiner +#' @param numPartitions Number of partitions to create. +#' @return An RDD where each element is list(K, C) where C is the combined type +#' +#' @seealso groupByKey, reduceByKey +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) +#' rdd <- parallelize(sc, pairs) +#' parts <- combineByKey(rdd, function(x) { x }, "+", "+", 2L) +#' combined <- collect(parts) +#' combined[[1]] # Should be a list(1, 6) +#'} +#' @rdname combineByKey +#' @aliases combineByKey,RDD,ANY,ANY,ANY,integer-method +setMethod("combineByKey", + signature(x = "RDD", createCombiner = "ANY", mergeValue = "ANY", + mergeCombiners = "ANY", numPartitions = "integer"), + function(x, createCombiner, mergeValue, mergeCombiners, numPartitions) { + combineLocally <- function(part) { + combiners <- new.env() + keys <- new.env() + pred <- function(item) exists(item$hash, keys) + lapply(part, + function(item) { + item$hash <- as.character(item[[1]]) + updateOrCreatePair(item, keys, combiners, pred, mergeValue, createCombiner) + }) + convertEnvsToList(keys, combiners) + } + locallyCombined <- lapplyPartition(x, combineLocally) + shuffled <- partitionBy(locallyCombined, numPartitions) + mergeAfterShuffle <- function(part) { + combiners <- new.env() + keys <- new.env() + pred <- function(item) exists(item$hash, keys) + lapply(part, + function(item) { + item$hash <- as.character(item[[1]]) + updateOrCreatePair(item, keys, combiners, pred, mergeCombiners, identity) + }) + convertEnvsToList(keys, combiners) + } + lapplyPartition(shuffled, mergeAfterShuffle) + }) + +#' Aggregate a pair RDD by each key. +#' +#' Aggregate the values of each key in an RDD, using given combine functions +#' and a neutral "zero value". This function can return a different result type, +#' U, than the type of the values in this RDD, V. Thus, we need one operation +#' for merging a V into a U and one operation for merging two U's, The former +#' operation is used for merging values within a partition, and the latter is +#' used for merging values between partitions. To avoid memory allocation, both +#' of these functions are allowed to modify and return their first argument +#' instead of creating a new U. +#' +#' @param x An RDD. +#' @param zeroValue A neutral "zero value". +#' @param seqOp A function to aggregate the values of each key. It may return +#' a different result type from the type of the values. +#' @param combOp A function to aggregate results of seqOp. +#' @return An RDD containing the aggregation result. +#' @seealso foldByKey, combineByKey +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) +#' zeroValue <- list(0, 0) +#' seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } +#' combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } +#' aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L) +#' # list(list(1, list(3, 2)), list(2, list(7, 2))) +#'} +#' @rdname aggregateByKey +#' @aliases aggregateByKey,RDD,ANY,ANY,ANY,integer-method +setMethod("aggregateByKey", + signature(x = "RDD", zeroValue = "ANY", seqOp = "ANY", + combOp = "ANY", numPartitions = "integer"), + function(x, zeroValue, seqOp, combOp, numPartitions) { + createCombiner <- function(v) { + do.call(seqOp, list(zeroValue, v)) + } + + combineByKey(x, createCombiner, seqOp, combOp, numPartitions) + }) + +#' Fold a pair RDD by each key. +#' +#' Aggregate the values of each key in an RDD, using an associative function "func" +#' and a neutral "zero value" which may be added to the result an arbitrary +#' number of times, and must not change the result (e.g., 0 for addition, or +#' 1 for multiplication.). +#' +#' @param x An RDD. +#' @param zeroValue A neutral "zero value". +#' @param func An associative function for folding values of each key. +#' @return An RDD containing the aggregation result. +#' @seealso aggregateByKey, combineByKey +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) +#' foldByKey(rdd, 0, "+", 2L) # list(list(1, 3), list(2, 7)) +#'} +#' @rdname foldByKey +#' @aliases foldByKey,RDD,ANY,ANY,integer-method +setMethod("foldByKey", + signature(x = "RDD", zeroValue = "ANY", + func = "ANY", numPartitions = "integer"), + function(x, zeroValue, func, numPartitions) { + aggregateByKey(x, zeroValue, func, func, numPartitions) + }) + +############ Binary Functions ############# + +#' Join two RDDs +#' +#' @description +#' \code{join} This function joins two RDDs where every element is of the form list(K, V). +#' The key types of the two RDDs should be the same. +#' +#' @param x An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param y An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param numPartitions Number of partitions to create. +#' @return a new RDD containing all pairs of elements with matching keys in +#' two input RDDs. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) +#' rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) +#' join(rdd1, rdd2, 2L) # list(list(1, list(1, 2)), list(1, list(1, 3)) +#'} +#' @rdname join-methods +#' @aliases join,RDD,RDD-method +setMethod("join", + signature(x = "RDD", y = "RDD"), + function(x, y, numPartitions) { + xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) }) + yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) }) + + doJoin <- function(v) { + joinTaggedList(v, list(FALSE, FALSE)) + } + + joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numToInt(numPartitions)), + doJoin) + }) + +#' Left outer join two RDDs +#' +#' @description +#' \code{leftouterjoin} This function left-outer-joins two RDDs where every element is of the form list(K, V). +#' The key types of the two RDDs should be the same. +#' +#' @param x An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param y An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param numPartitions Number of partitions to create. +#' @return For each element (k, v) in x, the resulting RDD will either contain +#' all pairs (k, (v, w)) for (k, w) in rdd2, or the pair (k, (v, NULL)) +#' if no elements in rdd2 have key k. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) +#' rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) +#' leftOuterJoin(rdd1, rdd2, 2L) +#' # list(list(1, list(1, 2)), list(1, list(1, 3)), list(2, list(4, NULL))) +#'} +#' @rdname join-methods +#' @aliases leftOuterJoin,RDD,RDD-method +setMethod("leftOuterJoin", + signature(x = "RDD", y = "RDD", numPartitions = "integer"), + function(x, y, numPartitions) { + xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) }) + yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) }) + + doJoin <- function(v) { + joinTaggedList(v, list(FALSE, TRUE)) + } + + joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numPartitions), doJoin) + }) + +#' Right outer join two RDDs +#' +#' @description +#' \code{rightouterjoin} This function right-outer-joins two RDDs where every element is of the form list(K, V). +#' The key types of the two RDDs should be the same. +#' +#' @param x An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param y An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param numPartitions Number of partitions to create. +#' @return For each element (k, w) in y, the resulting RDD will either contain +#' all pairs (k, (v, w)) for (k, v) in x, or the pair (k, (NULL, w)) +#' if no elements in x have key k. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3))) +#' rdd2 <- parallelize(sc, list(list(1, 1), list(2, 4))) +#' rightOuterJoin(rdd1, rdd2, 2L) +#' # list(list(1, list(2, 1)), list(1, list(3, 1)), list(2, list(NULL, 4))) +#'} +#' @rdname join-methods +#' @aliases rightOuterJoin,RDD,RDD-method +setMethod("rightOuterJoin", + signature(x = "RDD", y = "RDD", numPartitions = "integer"), + function(x, y, numPartitions) { + xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) }) + yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) }) + + doJoin <- function(v) { + joinTaggedList(v, list(TRUE, FALSE)) + } + + joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numPartitions), doJoin) + }) + +#' Full outer join two RDDs +#' +#' @description +#' \code{fullouterjoin} This function full-outer-joins two RDDs where every element is of the form list(K, V). +#' The key types of the two RDDs should be the same. +#' +#' @param x An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param y An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param numPartitions Number of partitions to create. +#' @return For each element (k, v) in x and (k, w) in y, the resulting RDD +#' will contain all pairs (k, (v, w)) for both (k, v) in x and +#' (k, w) in y, or the pair (k, (NULL, w))/(k, (v, NULL)) if no elements +#' in x/y have key k. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3), list(3, 3))) +#' rdd2 <- parallelize(sc, list(list(1, 1), list(2, 4))) +#' fullOuterJoin(rdd1, rdd2, 2L) # list(list(1, list(2, 1)), +#' # list(1, list(3, 1)), +#' # list(2, list(NULL, 4))) +#' # list(3, list(3, NULL)), +#'} +#' @rdname join-methods +#' @aliases fullOuterJoin,RDD,RDD-method +setMethod("fullOuterJoin", + signature(x = "RDD", y = "RDD", numPartitions = "integer"), + function(x, y, numPartitions) { + xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) }) + yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) }) + + doJoin <- function(v) { + joinTaggedList(v, list(TRUE, TRUE)) + } + + joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numPartitions), doJoin) + }) + +#' For each key k in several RDDs, return a resulting RDD that +#' whose values are a list of values for the key in all RDDs. +#' +#' @param ... Several RDDs. +#' @param numPartitions Number of partitions to create. +#' @return a new RDD containing all pairs of elements with values in a list +#' in all RDDs. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) +#' rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) +#' cogroup(rdd1, rdd2, numPartitions = 2L) +#' # list(list(1, list(1, list(2, 3))), list(2, list(list(4), list())) +#'} +#' @rdname cogroup +#' @aliases cogroup,RDD-method +setMethod("cogroup", + "RDD", + function(..., numPartitions) { + rdds <- list(...) + rddsLen <- length(rdds) + for (i in 1:rddsLen) { + rdds[[i]] <- lapply(rdds[[i]], + function(x) { list(x[[1]], list(i, x[[2]])) }) + # TODO(hao): As issue [SparkR-142] mentions, the right value of i + # will not be captured into UDF if getJRDD is not invoked. + # It should be resolved together with that issue. + getJRDD(rdds[[i]]) # Capture the closure. + } + union.rdd <- Reduce(unionRDD, rdds) + group.func <- function(vlist) { + res <- list() + length(res) <- rddsLen + for (x in vlist) { + i <- x[[1]] + acc <- res[[i]] + # Create an accumulator. + if (is.null(acc)) { + acc <- initAccumulator() + } + addItemToAccumulator(acc, x[[2]]) + res[[i]] <- acc + } + lapply(res, function(acc) { + if (is.null(acc)) { + list() + } else { + acc$data + } + }) + } + cogroup.rdd <- mapValues(groupByKey(union.rdd, numPartitions), + group.func) + }) + +#' Sort a (k, v) pair RDD by k. +#' +#' @param x A (k, v) pair RDD to be sorted. +#' @param ascending A flag to indicate whether the sorting is ascending or descending. +#' @param numPartitions Number of partitions to create. +#' @return An RDD where all (k, v) pair elements are sorted. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(3, 1), list(2, 2), list(1, 3))) +#' collect(sortByKey(rdd)) # list (list(1, 3), list(2, 2), list(3, 1)) +#'} +#' @rdname sortByKey +#' @aliases sortByKey,RDD,RDD-method +setMethod("sortByKey", + signature(x = "RDD"), + function(x, ascending = TRUE, numPartitions = SparkR::numPartitions(x)) { + rangeBounds <- list() + + if (numPartitions > 1) { + rddSize <- count(x) + # constant from Spark's RangePartitioner + maxSampleSize <- numPartitions * 20 + fraction <- min(maxSampleSize / max(rddSize, 1), 1.0) + + samples <- collect(keys(sampleRDD(x, FALSE, fraction, 1L))) + + # Note: the built-in R sort() function only works on atomic vectors + samples <- sort(unlist(samples, recursive = FALSE), decreasing = !ascending) + + if (length(samples) > 0) { + rangeBounds <- lapply(seq_len(numPartitions - 1), + function(i) { + j <- ceiling(length(samples) * i / numPartitions) + samples[j] + }) + } + } + + rangePartitionFunc <- function(key) { + partition <- 0 + + # TODO: Use binary search instead of linear search, similar with Spark + while (partition < length(rangeBounds) && key > rangeBounds[[partition + 1]]) { + partition <- partition + 1 + } + + if (ascending) { + partition + } else { + numPartitions - partition - 1 + } + } + + partitionFunc <- function(part) { + sortKeyValueList(part, decreasing = !ascending) + } + + newRDD <- partitionBy(x, numPartitions, rangePartitionFunc) + lapplyPartition(newRDD, partitionFunc) + }) + diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R new file mode 100644 index 0000000000000..8a9c0c652ce24 --- /dev/null +++ b/R/pkg/R/serialize.R @@ -0,0 +1,195 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Utility functions to serialize R objects so they can be read in Java. + +# Type mapping from R to Java +# +# NULL -> Void +# integer -> Int +# character -> String +# logical -> Boolean +# double, numeric -> Double +# raw -> Array[Byte] +# Date -> Date +# POSIXct,POSIXlt -> Time +# +# list[T] -> Array[T], where T is one of above mentioned types +# environment -> Map[String, T], where T is a native type +# jobj -> Object, where jobj is an object created in the backend + +writeObject <- function(con, object, writeType = TRUE) { + # NOTE: In R vectors have same type as objects. So we don't support + # passing in vectors as arrays and instead require arrays to be passed + # as lists. + type <- class(object)[[1]] # class of POSIXlt is c("POSIXlt", "POSIXt") + if (writeType) { + writeType(con, type) + } + switch(type, + NULL = writeVoid(con), + integer = writeInt(con, object), + character = writeString(con, object), + logical = writeBoolean(con, object), + double = writeDouble(con, object), + numeric = writeDouble(con, object), + raw = writeRaw(con, object), + list = writeList(con, object), + jobj = writeJobj(con, object), + environment = writeEnv(con, object), + Date = writeDate(con, object), + POSIXlt = writeTime(con, object), + POSIXct = writeTime(con, object), + stop(paste("Unsupported type for serialization", type))) +} + +writeVoid <- function(con) { + # no value for NULL +} + +writeJobj <- function(con, value) { + if (!isValidJobj(value)) { + stop("invalid jobj ", value$id) + } + writeString(con, value$id) +} + +writeString <- function(con, value) { + writeInt(con, as.integer(nchar(value) + 1)) + writeBin(value, con, endian = "big") +} + +writeInt <- function(con, value) { + writeBin(as.integer(value), con, endian = "big") +} + +writeDouble <- function(con, value) { + writeBin(value, con, endian = "big") +} + +writeBoolean <- function(con, value) { + # TRUE becomes 1, FALSE becomes 0 + writeInt(con, as.integer(value)) +} + +writeRawSerialize <- function(outputCon, batch) { + outputSer <- serialize(batch, ascii = FALSE, connection = NULL) + writeRaw(outputCon, outputSer) +} + +writeRowSerialize <- function(outputCon, rows) { + invisible(lapply(rows, function(r) { + bytes <- serializeRow(r) + writeRaw(outputCon, bytes) + })) +} + +serializeRow <- function(row) { + rawObj <- rawConnection(raw(0), "wb") + on.exit(close(rawObj)) + writeRow(rawObj, row) + rawConnectionValue(rawObj) +} + +writeRow <- function(con, row) { + numCols <- length(row) + writeInt(con, numCols) + for (i in 1:numCols) { + writeObject(con, row[[i]]) + } +} + +writeRaw <- function(con, batch) { + writeInt(con, length(batch)) + writeBin(batch, con, endian = "big") +} + +writeType <- function(con, class) { + type <- switch(class, + NULL = "n", + integer = "i", + character = "c", + logical = "b", + double = "d", + numeric = "d", + raw = "r", + list = "l", + jobj = "j", + environment = "e", + Date = "D", + POSIXlt = 't', + POSIXct = 't', + stop(paste("Unsupported type for serialization", class))) + writeBin(charToRaw(type), con) +} + +# Used to pass arrays where all the elements are of the same type +writeList <- function(con, arr) { + # All elements should be of same type + elemType <- unique(sapply(arr, function(elem) { class(elem) })) + stopifnot(length(elemType) <= 1) + + # TODO: Empty lists are given type "character" right now. + # This may not work if the Java side expects array of any other type. + if (length(elemType) == 0) { + elemType <- class("somestring") + } + + writeType(con, elemType) + writeInt(con, length(arr)) + + if (length(arr) > 0) { + for (a in arr) { + writeObject(con, a, FALSE) + } + } +} + +# Used to pass in hash maps required on Java side. +writeEnv <- function(con, env) { + len <- length(env) + + writeInt(con, len) + if (len > 0) { + writeList(con, as.list(ls(env))) + vals <- lapply(ls(env), function(x) { env[[x]] }) + writeList(con, as.list(vals)) + } +} + +writeDate <- function(con, date) { + writeString(con, as.character(date)) +} + +writeTime <- function(con, time) { + writeDouble(con, as.double(time)) +} + +# Used to serialize in a list of objects where each +# object can be of a different type. Serialization format is +# for each object +writeArgs <- function(con, args) { + if (length(args) > 0) { + for (a in args) { + writeObject(con, a) + } + } +} + +writeStrings <- function(con, stringList) { + writeLines(unlist(stringList), con) +} diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R new file mode 100644 index 0000000000000..bc82df01f0fff --- /dev/null +++ b/R/pkg/R/sparkR.R @@ -0,0 +1,266 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +.sparkREnv <- new.env() + +sparkR.onLoad <- function(libname, pkgname) { + .sparkREnv$libname <- libname +} + +# Utility function that returns TRUE if we have an active connection to the +# backend and FALSE otherwise +connExists <- function(env) { + tryCatch({ + exists(".sparkRCon", envir = env) && isOpen(env[[".sparkRCon"]]) + }, error = function(err) { + return(FALSE) + }) +} + +#' Stop the Spark context. +#' +#' Also terminates the backend this R session is connected to +sparkR.stop <- function() { + env <- .sparkREnv + if (exists(".sparkRCon", envir = env)) { + # cat("Stopping SparkR\n") + if (exists(".sparkRjsc", envir = env)) { + sc <- get(".sparkRjsc", envir = env) + callJMethod(sc, "stop") + rm(".sparkRjsc", envir = env) + } + + if (exists(".backendLaunched", envir = env)) { + callJStatic("SparkRHandler", "stopBackend") + } + + # Also close the connection and remove it from our env + conn <- get(".sparkRCon", envir = env) + close(conn) + + rm(".sparkRCon", envir = env) + rm(".scStartTime", envir = env) + } + + if (exists(".monitorConn", envir = env)) { + conn <- get(".monitorConn", envir = env) + close(conn) + rm(".monitorConn", envir = env) + } + + # Clear all broadcast variables we have + # as the jobj will not be valid if we restart the JVM + clearBroadcastVariables() + + # Clear jobj maps + clearJobjs() +} + +#' Initialize a new Spark Context. +#' +#' This function initializes a new SparkContext. +#' +#' @param master The Spark master URL. +#' @param appName Application name to register with cluster manager +#' @param sparkHome Spark Home directory +#' @param sparkEnvir Named list of environment variables to set on worker nodes. +#' @param sparkExecutorEnv Named list of environment variables to be used when launching executors. +#' @param sparkJars Character string vector of jar files to pass to the worker nodes. +#' @param sparkRLibDir The path where R is installed on the worker nodes. +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init("local[2]", "SparkR", "/home/spark") +#' sc <- sparkR.init("local[2]", "SparkR", "/home/spark", +#' list(spark.executor.memory="1g")) +#' sc <- sparkR.init("yarn-client", "SparkR", "/home/spark", +#' list(spark.executor.memory="1g"), +#' list(LD_LIBRARY_PATH="/directory of JVM libraries (libjvm.so) on workers/"), +#' c("jarfile1.jar","jarfile2.jar")) +#'} + +sparkR.init <- function( + master = "", + appName = "SparkR", + sparkHome = Sys.getenv("SPARK_HOME"), + sparkEnvir = list(), + sparkExecutorEnv = list(), + sparkJars = "", + sparkRLibDir = "") { + + if (exists(".sparkRjsc", envir = .sparkREnv)) { + cat("Re-using existing Spark Context. Please stop SparkR with sparkR.stop() or restart R to create a new Spark Context\n") + return(get(".sparkRjsc", envir = .sparkREnv)) + } + + sparkMem <- Sys.getenv("SPARK_MEM", "512m") + jars <- suppressWarnings(normalizePath(as.character(sparkJars))) + + # Classpath separator is ";" on Windows + # URI needs four /// as from http://stackoverflow.com/a/18522792 + if (.Platform$OS.type == "unix") { + collapseChar <- ":" + uriSep <- "//" + } else { + collapseChar <- ";" + uriSep <- "////" + } + + existingPort <- Sys.getenv("EXISTING_SPARKR_BACKEND_PORT", "") + if (existingPort != "") { + backendPort <- existingPort + } else { + path <- tempfile(pattern = "backend_port") + launchBackend( + args = path, + sparkHome = sparkHome, + jars = jars, + sparkSubmitOpts = Sys.getenv("SPARKR_SUBMIT_ARGS", "sparkr-shell")) + # wait atmost 100 seconds for JVM to launch + wait <- 0.1 + for (i in 1:25) { + Sys.sleep(wait) + if (file.exists(path)) { + break + } + wait <- wait * 1.25 + } + if (!file.exists(path)) { + stop("JVM is not ready after 10 seconds") + } + f <- file(path, open='rb') + backendPort <- readInt(f) + monitorPort <- readInt(f) + close(f) + file.remove(path) + if (length(backendPort) == 0 || backendPort == 0 || + length(monitorPort) == 0 || monitorPort == 0) { + stop("JVM failed to launch") + } + assign(".monitorConn", socketConnection(port = monitorPort), envir = .sparkREnv) + assign(".backendLaunched", 1, envir = .sparkREnv) + } + + .sparkREnv$backendPort <- backendPort + tryCatch({ + connectBackend("localhost", backendPort) + }, error = function(err) { + stop("Failed to connect JVM\n") + }) + + if (nchar(sparkHome) != 0) { + sparkHome <- normalizePath(sparkHome) + } + + if (nchar(sparkRLibDir) != 0) { + .sparkREnv$libname <- sparkRLibDir + } + + sparkEnvirMap <- new.env() + for (varname in names(sparkEnvir)) { + sparkEnvirMap[[varname]] <- sparkEnvir[[varname]] + } + + sparkExecutorEnvMap <- new.env() + if (!any(names(sparkExecutorEnv) == "LD_LIBRARY_PATH")) { + sparkExecutorEnvMap[["LD_LIBRARY_PATH"]] <- paste0("$LD_LIBRARY_PATH:",Sys.getenv("LD_LIBRARY_PATH")) + } + for (varname in names(sparkExecutorEnv)) { + sparkExecutorEnvMap[[varname]] <- sparkExecutorEnv[[varname]] + } + + nonEmptyJars <- Filter(function(x) { x != "" }, jars) + localJarPaths <- sapply(nonEmptyJars, function(j) { utils::URLencode(paste("file:", uriSep, j, sep = "")) }) + + # Set the start time to identify jobjs + # Seconds resolution is good enough for this purpose, so use ints + assign(".scStartTime", as.integer(Sys.time()), envir = .sparkREnv) + + assign( + ".sparkRjsc", + callJStatic( + "org.apache.spark.api.r.RRDD", + "createSparkContext", + master, + appName, + as.character(sparkHome), + as.list(localJarPaths), + sparkEnvirMap, + sparkExecutorEnvMap), + envir = .sparkREnv + ) + + sc <- get(".sparkRjsc", envir = .sparkREnv) + + # Register a finalizer to sleep 1 seconds on R exit to make RStudio happy + reg.finalizer(.sparkREnv, function(x) { Sys.sleep(1) }, onexit = TRUE) + + sc +} + +#' Initialize a new SQLContext. +#' +#' This function creates a SparkContext from an existing JavaSparkContext and +#' then uses it to initialize a new SQLContext +#' +#' @param jsc The existing JavaSparkContext created with SparkR.init() +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#'} + +sparkRSQL.init <- function(jsc) { + if (exists(".sparkRSQLsc", envir = .sparkREnv)) { + return(get(".sparkRSQLsc", envir = .sparkREnv)) + } + + sqlCtx <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", + "createSQLContext", + jsc) + assign(".sparkRSQLsc", sqlCtx, envir = .sparkREnv) + sqlCtx +} + +#' Initialize a new HiveContext. +#' +#' This function creates a HiveContext from an existing JavaSparkContext +#' +#' @param jsc The existing JavaSparkContext created with SparkR.init() +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRHive.init(sc) +#'} + +sparkRHive.init <- function(jsc) { + if (exists(".sparkRHivesc", envir = .sparkREnv)) { + return(get(".sparkRHivesc", envir = .sparkREnv)) + } + + ssc <- callJMethod(jsc, "sc") + hiveCtx <- tryCatch({ + newJObject("org.apache.spark.sql.hive.HiveContext", ssc) + }, error = function(err) { + stop("Spark SQL is not built with Hive support") + }) + + assign(".sparkRHivesc", hiveCtx, envir = .sparkREnv) + hiveCtx +} diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R new file mode 100644 index 0000000000000..c337fb0751e72 --- /dev/null +++ b/R/pkg/R/utils.R @@ -0,0 +1,467 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Utilities and Helpers + +# Given a JList, returns an R list containing the same elements, the number +# of which is optionally upper bounded by `logicalUpperBound` (by default, +# return all elements). Takes care of deserializations and type conversions. +convertJListToRList <- function(jList, flatten, logicalUpperBound = NULL, + serializedMode = "byte") { + arrSize <- callJMethod(jList, "size") + + # Datasets with serializedMode == "string" (such as an RDD directly generated by textFile()): + # each partition is not dense-packed into one Array[Byte], and `arrSize` + # here corresponds to number of logical elements. Thus we can prune here. + if (serializedMode == "string" && !is.null(logicalUpperBound)) { + arrSize <- min(arrSize, logicalUpperBound) + } + + results <- if (arrSize > 0) { + lapply(0:(arrSize - 1), + function(index) { + obj <- callJMethod(jList, "get", as.integer(index)) + + # Assume it is either an R object or a Java obj ref. + if (inherits(obj, "jobj")) { + if (isInstanceOf(obj, "scala.Tuple2")) { + # JavaPairRDD[Array[Byte], Array[Byte]]. + + keyBytes = callJMethod(obj, "_1") + valBytes = callJMethod(obj, "_2") + res <- list(unserialize(keyBytes), + unserialize(valBytes)) + } else { + stop(paste("utils.R: convertJListToRList only supports", + "RDD[Array[Byte]] and", + "JavaPairRDD[Array[Byte], Array[Byte]] for now")) + } + } else { + if (inherits(obj, "raw")) { + if (serializedMode == "byte") { + # RDD[Array[Byte]]. `obj` is a whole partition. + res <- unserialize(obj) + # For serialized datasets, `obj` (and `rRaw`) here corresponds to + # one whole partition dense-packed together. We deserialize the + # whole partition first, then cap the number of elements to be returned. + } else if (serializedMode == "row") { + res <- readRowList(obj) + # For DataFrames that have been converted to RRDDs, we call readRowList + # which will read in each row of the RRDD as a list and deserialize + # each element. + flatten <<- FALSE + # Use global assignment to change the flatten flag. This means + # we don't have to worry about the default argument in other functions + # e.g. collect + } + # TODO: is it possible to distinguish element boundary so that we can + # unserialize only what we need? + if (!is.null(logicalUpperBound)) { + res <- head(res, n = logicalUpperBound) + } + } else { + # obj is of a primitive Java type, is simplified to R's + # corresponding type. + res <- list(obj) + } + } + res + }) + } else { + list() + } + + if (flatten) { + as.list(unlist(results, recursive = FALSE)) + } else { + as.list(results) + } +} + +# Returns TRUE if `name` refers to an RDD in the given environment `env` +isRDD <- function(name, env) { + obj <- get(name, envir = env) + inherits(obj, "RDD") +} + +#' Compute the hashCode of an object +#' +#' Java-style function to compute the hashCode for the given object. Returns +#' an integer value. +#' +#' @details +#' This only works for integer, numeric and character types right now. +#' +#' @param key the object to be hashed +#' @return the hash code as an integer +#' @export +#' @examples +#' hashCode(1L) # 1 +#' hashCode(1.0) # 1072693248 +#' hashCode("1") # 49 +hashCode <- function(key) { + if (class(key) == "integer") { + as.integer(key[[1]]) + } else if (class(key) == "numeric") { + # Convert the double to long and then calculate the hash code + rawVec <- writeBin(key[[1]], con = raw()) + intBits <- packBits(rawToBits(rawVec), "integer") + as.integer(bitwXor(intBits[2], intBits[1])) + } else if (class(key) == "character") { + .Call("stringHashCode", key) + } else { + warning(paste("Could not hash object, returning 0", sep = "")) + as.integer(0) + } +} + +# Create a new RDD with serializedMode == "byte". +# Return itself if already in "byte" format. +serializeToBytes <- function(rdd) { + if (!inherits(rdd, "RDD")) { + stop("Argument 'rdd' is not an RDD type.") + } + if (getSerializedMode(rdd) != "byte") { + ser.rdd <- lapply(rdd, function(x) { x }) + return(ser.rdd) + } else { + return(rdd) + } +} + +# Create a new RDD with serializedMode == "string". +# Return itself if already in "string" format. +serializeToString <- function(rdd) { + if (!inherits(rdd, "RDD")) { + stop("Argument 'rdd' is not an RDD type.") + } + if (getSerializedMode(rdd) != "string") { + ser.rdd <- lapply(rdd, function(x) { toString(x) }) + # force it to create jrdd using "string" + getJRDD(ser.rdd, serializedMode = "string") + return(ser.rdd) + } else { + return(rdd) + } +} + +# Fast append to list by using an accumulator. +# http://stackoverflow.com/questions/17046336/here-we-go-again-append-an-element-to-a-list-in-r +# +# The accumulator should has three fields size, counter and data. +# This function amortizes the allocation cost by doubling +# the size of the list every time it fills up. +addItemToAccumulator <- function(acc, item) { + if(acc$counter == acc$size) { + acc$size <- acc$size * 2 + length(acc$data) <- acc$size + } + acc$counter <- acc$counter + 1 + acc$data[[acc$counter]] <- item +} + +initAccumulator <- function() { + acc <- new.env() + acc$counter <- 0 + acc$data <- list(NULL) + acc$size <- 1 + acc +} + +# Utility function to sort a list of key value pairs +# Used in unit tests +sortKeyValueList <- function(kv_list, decreasing = FALSE) { + keys <- sapply(kv_list, function(x) x[[1]]) + kv_list[order(keys, decreasing = decreasing)] +} + +# Utility function to generate compact R lists from grouped rdd +# Used in Join-family functions +# param: +# tagged_list R list generated via groupByKey with tags(1L, 2L, ...) +# cnull Boolean list where each element determines whether the corresponding list should +# be converted to list(NULL) +genCompactLists <- function(tagged_list, cnull) { + len <- length(tagged_list) + lists <- list(vector("list", len), vector("list", len)) + index <- list(1, 1) + + for (x in tagged_list) { + tag <- x[[1]] + idx <- index[[tag]] + lists[[tag]][[idx]] <- x[[2]] + index[[tag]] <- idx + 1 + } + + len <- lapply(index, function(x) x - 1) + for (i in (1:2)) { + if (cnull[[i]] && len[[i]] == 0) { + lists[[i]] <- list(NULL) + } else { + length(lists[[i]]) <- len[[i]] + } + } + + lists +} + +# Utility function to merge compact R lists +# Used in Join-family functions +# param: +# left/right Two compact lists ready for Cartesian product +mergeCompactLists <- function(left, right) { + result <- list() + length(result) <- length(left) * length(right) + index <- 1 + for (i in left) { + for (j in right) { + result[[index]] <- list(i, j) + index <- index + 1 + } + } + result +} + +# Utility function to wrapper above two operations +# Used in Join-family functions +# param (same as genCompactLists): +# tagged_list R list generated via groupByKey with tags(1L, 2L, ...) +# cnull Boolean list where each element determines whether the corresponding list should +# be converted to list(NULL) +joinTaggedList <- function(tagged_list, cnull) { + lists <- genCompactLists(tagged_list, cnull) + mergeCompactLists(lists[[1]], lists[[2]]) +} + +# Utility function to reduce a key-value list with predicate +# Used in *ByKey functions +# param +# pair key-value pair +# keys/vals env of key/value with hashes +# updateOrCreatePred predicate function +# updateFn update or merge function for existing pair, similar with `mergeVal` @combineByKey +# createFn create function for new pair, similar with `createCombiner` @combinebykey +updateOrCreatePair <- function(pair, keys, vals, updateOrCreatePred, updateFn, createFn) { + # assume hashVal bind to `$hash`, key/val with index 1/2 + hashVal <- pair$hash + key <- pair[[1]] + val <- pair[[2]] + if (updateOrCreatePred(pair)) { + assign(hashVal, do.call(updateFn, list(get(hashVal, envir = vals), val)), envir = vals) + } else { + assign(hashVal, do.call(createFn, list(val)), envir = vals) + assign(hashVal, key, envir = keys) + } +} + +# Utility function to convert key&values envs into key-val list +convertEnvsToList <- function(keys, vals) { + lapply(ls(keys), + function(name) { + list(keys[[name]], vals[[name]]) + }) +} + +# Utility function to capture the varargs into environment object +varargsToEnv <- function(...) { + pairs <- as.list(substitute(list(...)))[-1L] + env <- new.env() + for (name in names(pairs)) { + env[[name]] <- pairs[[name]] + } + env +} + +getStorageLevel <- function(newLevel = c("DISK_ONLY", + "DISK_ONLY_2", + "MEMORY_AND_DISK", + "MEMORY_AND_DISK_2", + "MEMORY_AND_DISK_SER", + "MEMORY_AND_DISK_SER_2", + "MEMORY_ONLY", + "MEMORY_ONLY_2", + "MEMORY_ONLY_SER", + "MEMORY_ONLY_SER_2", + "OFF_HEAP")) { + match.arg(newLevel) + storageLevel <- switch(newLevel, + "DISK_ONLY" = callJStatic("org.apache.spark.storage.StorageLevel", "DISK_ONLY"), + "DISK_ONLY_2" = callJStatic("org.apache.spark.storage.StorageLevel", "DISK_ONLY_2"), + "MEMORY_AND_DISK" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK"), + "MEMORY_AND_DISK_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK_2"), + "MEMORY_AND_DISK_SER" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK_SER"), + "MEMORY_AND_DISK_SER_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK_SER_2"), + "MEMORY_ONLY" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY"), + "MEMORY_ONLY_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY_2"), + "MEMORY_ONLY_SER" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY_SER"), + "MEMORY_ONLY_SER_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY_SER_2"), + "OFF_HEAP" = callJStatic("org.apache.spark.storage.StorageLevel", "OFF_HEAP")) +} + +# Utility function for functions where an argument needs to be integer but we want to allow +# the user to type (for example) `5` instead of `5L` to avoid a confusing error message. +numToInt <- function(num) { + if (as.integer(num) != num) { + warning(paste("Coercing", as.list(sys.call())[[2]], "to integer.")) + } + as.integer(num) +} + +# create a Seq in JVM +toSeq <- function(...) { + callJStatic("org.apache.spark.sql.api.r.SQLUtils", "toSeq", list(...)) +} + +# create a Seq in JVM from a list +listToSeq <- function(l) { + callJStatic("org.apache.spark.sql.api.r.SQLUtils", "toSeq", l) +} + +# Utility function to recursively traverse the Abstract Syntax Tree (AST) of a +# user defined function (UDF), and to examine variables in the UDF to decide +# if their values should be included in the new function environment. +# param +# node The current AST node in the traversal. +# oldEnv The original function environment. +# defVars An Accumulator of variables names defined in the function's calling environment, +# including function argument and local variable names. +# checkedFunc An environment of function objects examined during cleanClosure. It can +# be considered as a "name"-to-"list of functions" mapping. +# newEnv A new function environment to store necessary function dependencies, an output argument. +processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) { + nodeLen <- length(node) + + if (nodeLen > 1 && typeof(node) == "language") { + # Recursive case: current AST node is an internal node, check for its children. + if (length(node[[1]]) > 1) { + for (i in 1:nodeLen) { + processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) + } + } else { # if node[[1]] is length of 1, check for some R special functions. + nodeChar <- as.character(node[[1]]) + if (nodeChar == "{" || nodeChar == "(") { # Skip start symbol. + for (i in 2:nodeLen) { + processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) + } + } else if (nodeChar == "<-" || nodeChar == "=" || + nodeChar == "<<-") { # Assignment Ops. + defVar <- node[[2]] + if (length(defVar) == 1 && typeof(defVar) == "symbol") { + # Add the defined variable name into defVars. + addItemToAccumulator(defVars, as.character(defVar)) + } else { + processClosure(node[[2]], oldEnv, defVars, checkedFuncs, newEnv) + } + for (i in 3:nodeLen) { + processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) + } + } else if (nodeChar == "function") { # Function definition. + # Add parameter names. + newArgs <- names(node[[2]]) + lapply(newArgs, function(arg) { addItemToAccumulator(defVars, arg) }) + for (i in 3:nodeLen) { + processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) + } + } else if (nodeChar == "$") { # Skip the field. + processClosure(node[[2]], oldEnv, defVars, checkedFuncs, newEnv) + } else if (nodeChar == "::" || nodeChar == ":::") { + processClosure(node[[3]], oldEnv, defVars, checkedFuncs, newEnv) + } else { + for (i in 1:nodeLen) { + processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) + } + } + } + } else if (nodeLen == 1 && + (typeof(node) == "symbol" || typeof(node) == "language")) { + # Base case: current AST node is a leaf node and a symbol or a function call. + nodeChar <- as.character(node) + if (!nodeChar %in% defVars$data) { # Not a function parameter or local variable. + func.env <- oldEnv + topEnv <- parent.env(.GlobalEnv) + # Search in function environment, and function's enclosing environments + # up to global environment. There is no need to look into package environments + # above the global or namespace environment that is not SparkR below the global, + # as they are assumed to be loaded on workers. + while (!identical(func.env, topEnv)) { + # Namespaces other than "SparkR" will not be searched. + if (!isNamespace(func.env) || + (getNamespaceName(func.env) == "SparkR" && + !(nodeChar %in% getNamespaceExports("SparkR")))) { # Only include SparkR internals. + # Set parameter 'inherits' to FALSE since we do not need to search in + # attached package environments. + if (tryCatch(exists(nodeChar, envir = func.env, inherits = FALSE), + error = function(e) { FALSE })) { + obj <- get(nodeChar, envir = func.env, inherits = FALSE) + if (is.function(obj)) { # If the node is a function call. + funcList <- mget(nodeChar, envir = checkedFuncs, inherits = F, + ifnotfound = list(list(NULL)))[[1]] + found <- sapply(funcList, function(func) { + ifelse(identical(func, obj), TRUE, FALSE) + }) + if (sum(found) > 0) { # If function has been examined, ignore. + break + } + # Function has not been examined, record it and recursively clean its closure. + assign(nodeChar, + if (is.null(funcList[[1]])) { + list(obj) + } else { + append(funcList, obj) + }, + envir = checkedFuncs) + obj <- cleanClosure(obj, checkedFuncs) + } + assign(nodeChar, obj, envir = newEnv) + break + } + } + + # Continue to search in enclosure. + func.env <- parent.env(func.env) + } + } + } +} + +# Utility function to get user defined function (UDF) dependencies (closure). +# More specifically, this function captures the values of free variables defined +# outside a UDF, and stores them in the function's environment. +# param +# func A function whose closure needs to be captured. +# checkedFunc An environment of function objects examined during cleanClosure. It can be +# considered as a "name"-to-"list of functions" mapping. +# return value +# a new version of func that has an correct environment (closure). +cleanClosure <- function(func, checkedFuncs = new.env()) { + if (is.function(func)) { + newEnv <- new.env(parent = .GlobalEnv) + func.body <- body(func) + oldEnv <- environment(func) + # defVars is an Accumulator of variables names defined in the function's calling + # environment. First, function's arguments are added to defVars. + defVars <- initAccumulator() + argNames <- names(as.list(args(func))) + for (i in 1:(length(argNames) - 1)) { # Remove the ending NULL in pairlist. + addItemToAccumulator(defVars, argNames[i]) + } + # Recursively examine variables in the function body. + processClosure(func.body, oldEnv, defVars, checkedFuncs, newEnv) + environment(func) <- newEnv + } + func +} diff --git a/R/pkg/R/zzz.R b/R/pkg/R/zzz.R new file mode 100644 index 0000000000000..80d796d467943 --- /dev/null +++ b/R/pkg/R/zzz.R @@ -0,0 +1,21 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +.onLoad <- function(libname, pkgname) { + sparkR.onLoad(libname, pkgname) +} + diff --git a/R/pkg/inst/profile/general.R b/R/pkg/inst/profile/general.R new file mode 100644 index 0000000000000..8fe711b622086 --- /dev/null +++ b/R/pkg/inst/profile/general.R @@ -0,0 +1,22 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +.First <- function() { + home <- Sys.getenv("SPARK_HOME") + .libPaths(c(file.path(home, "R", "lib"), .libPaths())) + Sys.setenv(NOAWT=1) +} diff --git a/R/pkg/inst/profile/shell.R b/R/pkg/inst/profile/shell.R new file mode 100644 index 0000000000000..7a7f2031152a0 --- /dev/null +++ b/R/pkg/inst/profile/shell.R @@ -0,0 +1,31 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +.First <- function() { + home <- Sys.getenv("SPARK_HOME") + .libPaths(c(file.path(home, "R", "lib"), .libPaths())) + Sys.setenv(NOAWT=1) + + library(utils) + library(SparkR) + sc <- sparkR.init(Sys.getenv("MASTER", unset = "")) + assign("sc", sc, envir=.GlobalEnv) + sqlCtx <- sparkRSQL.init(sc) + assign("sqlCtx", sqlCtx, envir=.GlobalEnv) + cat("\n Welcome to SparkR!") + cat("\n Spark context is available as sc, SQL context is available as sqlCtx\n") +} diff --git a/R/pkg/inst/tests/test_binaryFile.R b/R/pkg/inst/tests/test_binaryFile.R new file mode 100644 index 0000000000000..ca4218f3819f8 --- /dev/null +++ b/R/pkg/inst/tests/test_binaryFile.R @@ -0,0 +1,90 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("functions on binary files") + +# JavaSparkContext handle +sc <- sparkR.init() + +mockFile = c("Spark is pretty.", "Spark is awesome.") + +test_that("saveAsObjectFile()/objectFile() following textFile() works", { + fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") + fileName2 <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName1) + + rdd <- textFile(sc, fileName1, 1) + saveAsObjectFile(rdd, fileName2) + rdd <- objectFile(sc, fileName2) + expect_equal(collect(rdd), as.list(mockFile)) + + unlink(fileName1) + unlink(fileName2, recursive = TRUE) +}) + +test_that("saveAsObjectFile()/objectFile() works on a parallelized list", { + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + + l <- list(1, 2, 3) + rdd <- parallelize(sc, l, 1) + saveAsObjectFile(rdd, fileName) + rdd <- objectFile(sc, fileName) + expect_equal(collect(rdd), l) + + unlink(fileName, recursive = TRUE) +}) + +test_that("saveAsObjectFile()/objectFile() following RDD transformations works", { + fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") + fileName2 <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName1) + + rdd <- textFile(sc, fileName1) + + words <- flatMap(rdd, function(line) { strsplit(line, " ")[[1]] }) + wordCount <- lapply(words, function(word) { list(word, 1L) }) + + counts <- reduceByKey(wordCount, "+", 2L) + + saveAsObjectFile(counts, fileName2) + counts <- objectFile(sc, fileName2) + + output <- collect(counts) + expected <- list(list("awesome.", 1), list("Spark", 2), list("pretty.", 1), + list("is", 2)) + expect_equal(sortKeyValueList(output), sortKeyValueList(expected)) + + unlink(fileName1) + unlink(fileName2, recursive = TRUE) +}) + +test_that("saveAsObjectFile()/objectFile() works with multiple paths", { + fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") + fileName2 <- tempfile(pattern="spark-test", fileext=".tmp") + + rdd1 <- parallelize(sc, "Spark is pretty.") + saveAsObjectFile(rdd1, fileName1) + rdd2 <- parallelize(sc, "Spark is awesome.") + saveAsObjectFile(rdd2, fileName2) + + rdd <- objectFile(sc, c(fileName1, fileName2)) + expect_true(count(rdd) == 2) + + unlink(fileName1, recursive = TRUE) + unlink(fileName2, recursive = TRUE) +}) + diff --git a/R/pkg/inst/tests/test_binary_function.R b/R/pkg/inst/tests/test_binary_function.R new file mode 100644 index 0000000000000..c15553ba28517 --- /dev/null +++ b/R/pkg/inst/tests/test_binary_function.R @@ -0,0 +1,68 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("binary functions") + +# JavaSparkContext handle +sc <- sparkR.init() + +# Data +nums <- 1:10 +rdd <- parallelize(sc, nums, 2L) + +# File content +mockFile <- c("Spark is pretty.", "Spark is awesome.") + +test_that("union on two RDDs", { + actual <- collect(unionRDD(rdd, rdd)) + expect_equal(actual, as.list(rep(nums, 2))) + + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + text.rdd <- textFile(sc, fileName) + union.rdd <- unionRDD(rdd, text.rdd) + actual <- collect(union.rdd) + expect_equal(actual, c(as.list(nums), mockFile)) + expect_true(getSerializedMode(union.rdd) == "byte") + + rdd<- map(text.rdd, function(x) {x}) + union.rdd <- unionRDD(rdd, text.rdd) + actual <- collect(union.rdd) + expect_equal(actual, as.list(c(mockFile, mockFile))) + expect_true(getSerializedMode(union.rdd) == "byte") + + unlink(fileName) +}) + +test_that("cogroup on two RDDs", { + rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) + rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) + cogroup.rdd <- cogroup(rdd1, rdd2, numPartitions = 2L) + actual <- collect(cogroup.rdd) + expect_equal(actual, + list(list(1, list(list(1), list(2, 3))), list(2, list(list(4), list())))) + + rdd1 <- parallelize(sc, list(list("a", 1), list("a", 4))) + rdd2 <- parallelize(sc, list(list("b", 2), list("a", 3))) + cogroup.rdd <- cogroup(rdd1, rdd2, numPartitions = 2L) + actual <- collect(cogroup.rdd) + + expected <- list(list("b", list(list(), list(2))), list("a", list(list(1, 4), list(3)))) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(expected)) +}) diff --git a/R/pkg/inst/tests/test_broadcast.R b/R/pkg/inst/tests/test_broadcast.R new file mode 100644 index 0000000000000..fee91a427d6d5 --- /dev/null +++ b/R/pkg/inst/tests/test_broadcast.R @@ -0,0 +1,48 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("broadcast variables") + +# JavaSparkContext handle +sc <- sparkR.init() + +# Partitioned data +nums <- 1:2 +rrdd <- parallelize(sc, nums, 2L) + +test_that("using broadcast variable", { + randomMat <- matrix(nrow=10, ncol=10, data=rnorm(100)) + randomMatBr <- broadcast(sc, randomMat) + + useBroadcast <- function(x) { + sum(value(randomMatBr) * x) + } + actual <- collect(lapply(rrdd, useBroadcast)) + expected <- list(sum(randomMat) * 1, sum(randomMat) * 2) + expect_equal(actual, expected) +}) + +test_that("without using broadcast variable", { + randomMat <- matrix(nrow=10, ncol=10, data=rnorm(100)) + + useBroadcast <- function(x) { + sum(randomMat * x) + } + actual <- collect(lapply(rrdd, useBroadcast)) + expected <- list(sum(randomMat) * 1, sum(randomMat) * 2) + expect_equal(actual, expected) +}) diff --git a/R/pkg/inst/tests/test_context.R b/R/pkg/inst/tests/test_context.R new file mode 100644 index 0000000000000..e4aab37436a74 --- /dev/null +++ b/R/pkg/inst/tests/test_context.R @@ -0,0 +1,50 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("test functions in sparkR.R") + +test_that("repeatedly starting and stopping SparkR", { + for (i in 1:4) { + sc <- sparkR.init() + rdd <- parallelize(sc, 1:20, 2L) + expect_equal(count(rdd), 20) + sparkR.stop() + } +}) + +test_that("rdd GC across sparkR.stop", { + sparkR.stop() + sc <- sparkR.init() # sc should get id 0 + rdd1 <- parallelize(sc, 1:20, 2L) # rdd1 should get id 1 + rdd2 <- parallelize(sc, 1:10, 2L) # rdd2 should get id 2 + sparkR.stop() + + sc <- sparkR.init() # sc should get id 0 again + + # GC rdd1 before creating rdd3 and rdd2 after + rm(rdd1) + gc() + + rdd3 <- parallelize(sc, 1:20, 2L) # rdd3 should get id 1 now + rdd4 <- parallelize(sc, 1:10, 2L) # rdd4 should get id 2 now + + rm(rdd2) + gc() + + count(rdd3) + count(rdd4) +}) diff --git a/R/pkg/inst/tests/test_includePackage.R b/R/pkg/inst/tests/test_includePackage.R new file mode 100644 index 0000000000000..8152b448d0870 --- /dev/null +++ b/R/pkg/inst/tests/test_includePackage.R @@ -0,0 +1,57 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("include R packages") + +# JavaSparkContext handle +sc <- sparkR.init() + +# Partitioned data +nums <- 1:2 +rdd <- parallelize(sc, nums, 2L) + +test_that("include inside function", { + # Only run the test if plyr is installed. + if ("plyr" %in% rownames(installed.packages())) { + suppressPackageStartupMessages(library(plyr)) + generateData <- function(x) { + suppressPackageStartupMessages(library(plyr)) + attach(airquality) + result <- transform(Ozone, logOzone = log(Ozone)) + result + } + + data <- lapplyPartition(rdd, generateData) + actual <- collect(data) + } +}) + +test_that("use include package", { + # Only run the test if plyr is installed. + if ("plyr" %in% rownames(installed.packages())) { + suppressPackageStartupMessages(library(plyr)) + generateData <- function(x) { + attach(airquality) + result <- transform(Ozone, logOzone = log(Ozone)) + result + } + + includePackage(sc, plyr) + data <- lapplyPartition(rdd, generateData) + actual <- collect(data) + } +}) diff --git a/R/pkg/inst/tests/test_parallelize_collect.R b/R/pkg/inst/tests/test_parallelize_collect.R new file mode 100644 index 0000000000000..fff028657db37 --- /dev/null +++ b/R/pkg/inst/tests/test_parallelize_collect.R @@ -0,0 +1,109 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("parallelize() and collect()") + +# Mock data +numVector <- c(-10:97) +numList <- list(sqrt(1), sqrt(2), sqrt(3), 4 ** 10) +strVector <- c("Dexter Morgan: I suppose I should be upset, even feel", + "violated, but I'm not. No, in fact, I think this is a friendly", + "message, like \"Hey, wanna play?\" and yes, I want to play. ", + "I really, really do.") +strList <- list("Dexter Morgan: Blood. Sometimes it sets my teeth on edge, ", + "other times it helps me control the chaos.", + "Dexter Morgan: Harry and Dorris Morgan did a wonderful job ", + "raising me. But they're both dead now. I didn't kill them. Honest.") + +numPairs <- list(list(1, 1), list(1, 2), list(2, 2), list(2, 3)) +strPairs <- list(list(strList, strList), list(strList, strList)) + +# JavaSparkContext handle +jsc <- sparkR.init() + +# Tests + +test_that("parallelize() on simple vectors and lists returns an RDD", { + numVectorRDD <- parallelize(jsc, numVector, 1) + numVectorRDD2 <- parallelize(jsc, numVector, 10) + numListRDD <- parallelize(jsc, numList, 1) + numListRDD2 <- parallelize(jsc, numList, 4) + strVectorRDD <- parallelize(jsc, strVector, 2) + strVectorRDD2 <- parallelize(jsc, strVector, 3) + strListRDD <- parallelize(jsc, strList, 4) + strListRDD2 <- parallelize(jsc, strList, 1) + + rdds <- c(numVectorRDD, + numVectorRDD2, + numListRDD, + numListRDD2, + strVectorRDD, + strVectorRDD2, + strListRDD, + strListRDD2) + + for (rdd in rdds) { + expect_true(inherits(rdd, "RDD")) + expect_true(.hasSlot(rdd, "jrdd") + && inherits(rdd@jrdd, "jobj") + && isInstanceOf(rdd@jrdd, "org.apache.spark.api.java.JavaRDD")) + } +}) + +test_that("collect(), following a parallelize(), gives back the original collections", { + numVectorRDD <- parallelize(jsc, numVector, 10) + expect_equal(collect(numVectorRDD), as.list(numVector)) + + numListRDD <- parallelize(jsc, numList, 1) + numListRDD2 <- parallelize(jsc, numList, 4) + expect_equal(collect(numListRDD), as.list(numList)) + expect_equal(collect(numListRDD2), as.list(numList)) + + strVectorRDD <- parallelize(jsc, strVector, 2) + strVectorRDD2 <- parallelize(jsc, strVector, 3) + expect_equal(collect(strVectorRDD), as.list(strVector)) + expect_equal(collect(strVectorRDD2), as.list(strVector)) + + strListRDD <- parallelize(jsc, strList, 4) + strListRDD2 <- parallelize(jsc, strList, 1) + expect_equal(collect(strListRDD), as.list(strList)) + expect_equal(collect(strListRDD2), as.list(strList)) +}) + +test_that("regression: collect() following a parallelize() does not drop elements", { + # 10 %/% 6 = 1, ceiling(10 / 6) = 2 + collLen <- 10 + numPart <- 6 + expected <- runif(collLen) + actual <- collect(parallelize(jsc, expected, numPart)) + expect_equal(actual, as.list(expected)) +}) + +test_that("parallelize() and collect() work for lists of pairs (pairwise data)", { + # use the pairwise logical to indicate pairwise data + numPairsRDDD1 <- parallelize(jsc, numPairs, 1) + numPairsRDDD2 <- parallelize(jsc, numPairs, 2) + numPairsRDDD3 <- parallelize(jsc, numPairs, 3) + expect_equal(collect(numPairsRDDD1), numPairs) + expect_equal(collect(numPairsRDDD2), numPairs) + expect_equal(collect(numPairsRDDD3), numPairs) + # can also leave out the parameter name, if the params are supplied in order + strPairsRDDD1 <- parallelize(jsc, strPairs, 1) + strPairsRDDD2 <- parallelize(jsc, strPairs, 2) + expect_equal(collect(strPairsRDDD1), strPairs) + expect_equal(collect(strPairsRDDD2), strPairs) +}) diff --git a/R/pkg/inst/tests/test_rdd.R b/R/pkg/inst/tests/test_rdd.R new file mode 100644 index 0000000000000..f75e0817b9406 --- /dev/null +++ b/R/pkg/inst/tests/test_rdd.R @@ -0,0 +1,644 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("basic RDD functions") + +# JavaSparkContext handle +sc <- sparkR.init() + +# Data +nums <- 1:10 +rdd <- parallelize(sc, nums, 2L) + +intPairs <- list(list(1L, -1), list(2L, 100), list(2L, 1), list(1L, 200)) +intRdd <- parallelize(sc, intPairs, 2L) + +test_that("get number of partitions in RDD", { + expect_equal(numPartitions(rdd), 2) + expect_equal(numPartitions(intRdd), 2) +}) + +test_that("first on RDD", { + expect_true(first(rdd) == 1) + newrdd <- lapply(rdd, function(x) x + 1) + expect_true(first(newrdd) == 2) +}) + +test_that("count and length on RDD", { + expect_equal(count(rdd), 10) + expect_equal(length(rdd), 10) +}) + +test_that("count by values and keys", { + mods <- lapply(rdd, function(x) { x %% 3 }) + actual <- countByValue(mods) + expected <- list(list(0, 3L), list(1, 4L), list(2, 3L)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) + + actual <- countByKey(intRdd) + expected <- list(list(2L, 2L), list(1L, 2L)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("lapply on RDD", { + multiples <- lapply(rdd, function(x) { 2 * x }) + actual <- collect(multiples) + expect_equal(actual, as.list(nums * 2)) +}) + +test_that("lapplyPartition on RDD", { + sums <- lapplyPartition(rdd, function(part) { sum(unlist(part)) }) + actual <- collect(sums) + expect_equal(actual, list(15, 40)) +}) + +test_that("mapPartitions on RDD", { + sums <- mapPartitions(rdd, function(part) { sum(unlist(part)) }) + actual <- collect(sums) + expect_equal(actual, list(15, 40)) +}) + +test_that("flatMap() on RDDs", { + flat <- flatMap(intRdd, function(x) { list(x, x) }) + actual <- collect(flat) + expect_equal(actual, rep(intPairs, each=2)) +}) + +test_that("filterRDD on RDD", { + filtered.rdd <- filterRDD(rdd, function(x) { x %% 2 == 0 }) + actual <- collect(filtered.rdd) + expect_equal(actual, list(2, 4, 6, 8, 10)) + + filtered.rdd <- Filter(function(x) { x[[2]] < 0 }, intRdd) + actual <- collect(filtered.rdd) + expect_equal(actual, list(list(1L, -1))) + + # Filter out all elements. + filtered.rdd <- filterRDD(rdd, function(x) { x > 10 }) + actual <- collect(filtered.rdd) + expect_equal(actual, list()) +}) + +test_that("lookup on RDD", { + vals <- lookup(intRdd, 1L) + expect_equal(vals, list(-1, 200)) + + vals <- lookup(intRdd, 3L) + expect_equal(vals, list()) +}) + +test_that("several transformations on RDD (a benchmark on PipelinedRDD)", { + rdd2 <- rdd + for (i in 1:12) + rdd2 <- lapplyPartitionsWithIndex( + rdd2, function(split, part) { + part <- as.list(unlist(part) * split + i) + }) + rdd2 <- lapply(rdd2, function(x) x + x) + actual <- collect(rdd2) + expected <- list(24, 24, 24, 24, 24, + 168, 170, 172, 174, 176) + expect_equal(actual, expected) +}) + +test_that("PipelinedRDD support actions: cache(), persist(), unpersist(), checkpoint()", { + # RDD + rdd2 <- rdd + # PipelinedRDD + rdd2 <- lapplyPartitionsWithIndex( + rdd2, + function(split, part) { + part <- as.list(unlist(part) * split) + }) + + cache(rdd2) + expect_true(rdd2@env$isCached) + rdd2 <- lapply(rdd2, function(x) x) + expect_false(rdd2@env$isCached) + + unpersist(rdd2) + expect_false(rdd2@env$isCached) + + persist(rdd2, "MEMORY_AND_DISK") + expect_true(rdd2@env$isCached) + rdd2 <- lapply(rdd2, function(x) x) + expect_false(rdd2@env$isCached) + + unpersist(rdd2) + expect_false(rdd2@env$isCached) + + setCheckpointDir(sc, "checkpoints") + checkpoint(rdd2) + expect_true(rdd2@env$isCheckpointed) + + rdd2 <- lapply(rdd2, function(x) x) + expect_false(rdd2@env$isCached) + expect_false(rdd2@env$isCheckpointed) + + # make sure the data is collectable + collect(rdd2) + + unlink("checkpoints") +}) + +test_that("reduce on RDD", { + sum <- reduce(rdd, "+") + expect_equal(sum, 55) + + # Also test with an inline function + sumInline <- reduce(rdd, function(x, y) { x + y }) + expect_equal(sumInline, 55) +}) + +test_that("lapply with dependency", { + fa <- 5 + multiples <- lapply(rdd, function(x) { fa * x }) + actual <- collect(multiples) + + expect_equal(actual, as.list(nums * 5)) +}) + +test_that("lapplyPartitionsWithIndex on RDDs", { + func <- function(splitIndex, part) { list(splitIndex, Reduce("+", part)) } + actual <- collect(lapplyPartitionsWithIndex(rdd, func), flatten = FALSE) + expect_equal(actual, list(list(0, 15), list(1, 40))) + + pairsRDD <- parallelize(sc, list(list(1, 2), list(3, 4), list(4, 8)), 1L) + partitionByParity <- function(key) { if (key %% 2 == 1) 0 else 1 } + mkTup <- function(splitIndex, part) { list(splitIndex, part) } + actual <- collect(lapplyPartitionsWithIndex( + partitionBy(pairsRDD, 2L, partitionByParity), + mkTup), + FALSE) + expect_equal(actual, list(list(0, list(list(1, 2), list(3, 4))), + list(1, list(list(4, 8))))) +}) + +test_that("sampleRDD() on RDDs", { + expect_equal(unlist(collect(sampleRDD(rdd, FALSE, 1.0, 2014L))), nums) +}) + +test_that("takeSample() on RDDs", { + # ported from RDDSuite.scala, modified seeds + data <- parallelize(sc, 1:100, 2L) + for (seed in 4:5) { + s <- takeSample(data, FALSE, 20L, seed) + expect_equal(length(s), 20L) + expect_equal(length(unique(s)), 20L) + for (elem in s) { + expect_true(elem >= 1 && elem <= 100) + } + } + for (seed in 4:5) { + s <- takeSample(data, FALSE, 200L, seed) + expect_equal(length(s), 100L) + expect_equal(length(unique(s)), 100L) + for (elem in s) { + expect_true(elem >= 1 && elem <= 100) + } + } + for (seed in 4:5) { + s <- takeSample(data, TRUE, 20L, seed) + expect_equal(length(s), 20L) + for (elem in s) { + expect_true(elem >= 1 && elem <= 100) + } + } + for (seed in 4:5) { + s <- takeSample(data, TRUE, 100L, seed) + expect_equal(length(s), 100L) + # Chance of getting all distinct elements is astronomically low, so test we + # got < 100 + expect_true(length(unique(s)) < 100L) + } + for (seed in 4:5) { + s <- takeSample(data, TRUE, 200L, seed) + expect_equal(length(s), 200L) + # Chance of getting all distinct elements is still quite low, so test we + # got < 100 + expect_true(length(unique(s)) < 100L) + } +}) + +test_that("mapValues() on pairwise RDDs", { + multiples <- mapValues(intRdd, function(x) { x * 2 }) + actual <- collect(multiples) + expected <- lapply(intPairs, function(x) { + list(x[[1]], x[[2]] * 2) + }) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("flatMapValues() on pairwise RDDs", { + l <- parallelize(sc, list(list(1, c(1,2)), list(2, c(3,4)))) + actual <- collect(flatMapValues(l, function(x) { x })) + expect_equal(actual, list(list(1,1), list(1,2), list(2,3), list(2,4))) + + # Generate x to x+1 for every value + actual <- collect(flatMapValues(intRdd, function(x) { x:(x + 1) })) + expect_equal(actual, + list(list(1L, -1), list(1L, 0), list(2L, 100), list(2L, 101), + list(2L, 1), list(2L, 2), list(1L, 200), list(1L, 201))) +}) + +test_that("reduceByKeyLocally() on PairwiseRDDs", { + pairs <- parallelize(sc, list(list(1, 2), list(1.1, 3), list(1, 4)), 2L) + actual <- reduceByKeyLocally(pairs, "+") + expect_equal(sortKeyValueList(actual), + sortKeyValueList(list(list(1, 6), list(1.1, 3)))) + + pairs <- parallelize(sc, list(list("abc", 1.2), list(1.1, 0), list("abc", 1.3), + list("bb", 5)), 4L) + actual <- reduceByKeyLocally(pairs, "+") + expect_equal(sortKeyValueList(actual), + sortKeyValueList(list(list("abc", 2.5), list(1.1, 0), list("bb", 5)))) +}) + +test_that("distinct() on RDDs", { + nums.rep2 <- rep(1:10, 2) + rdd.rep2 <- parallelize(sc, nums.rep2, 2L) + uniques <- distinct(rdd.rep2) + actual <- sort(unlist(collect(uniques))) + expect_equal(actual, nums) +}) + +test_that("maximum() on RDDs", { + max <- maximum(rdd) + expect_equal(max, 10) +}) + +test_that("minimum() on RDDs", { + min <- minimum(rdd) + expect_equal(min, 1) +}) + +test_that("sumRDD() on RDDs", { + sum <- sumRDD(rdd) + expect_equal(sum, 55) +}) + +test_that("keyBy on RDDs", { + func <- function(x) { x*x } + keys <- keyBy(rdd, func) + actual <- collect(keys) + expect_equal(actual, lapply(nums, function(x) { list(func(x), x) })) +}) + +test_that("repartition/coalesce on RDDs", { + rdd <- parallelize(sc, 1:20, 4L) # each partition contains 5 elements + + # repartition + r1 <- repartition(rdd, 2) + expect_equal(numPartitions(r1), 2L) + count <- length(collectPartition(r1, 0L)) + expect_true(count >= 8 && count <= 12) + + r2 <- repartition(rdd, 6) + expect_equal(numPartitions(r2), 6L) + count <- length(collectPartition(r2, 0L)) + expect_true(count >=0 && count <= 4) + + # coalesce + r3 <- coalesce(rdd, 1) + expect_equal(numPartitions(r3), 1L) + count <- length(collectPartition(r3, 0L)) + expect_equal(count, 20) +}) + +test_that("sortBy() on RDDs", { + sortedRdd <- sortBy(rdd, function(x) { x * x }, ascending = FALSE) + actual <- collect(sortedRdd) + expect_equal(actual, as.list(sort(nums, decreasing = TRUE))) + + rdd2 <- parallelize(sc, sort(nums, decreasing = TRUE), 2L) + sortedRdd2 <- sortBy(rdd2, function(x) { x * x }) + actual <- collect(sortedRdd2) + expect_equal(actual, as.list(nums)) +}) + +test_that("takeOrdered() on RDDs", { + l <- list(10, 1, 2, 9, 3, 4, 5, 6, 7) + rdd <- parallelize(sc, l) + actual <- takeOrdered(rdd, 6L) + expect_equal(actual, as.list(sort(unlist(l)))[1:6]) + + l <- list("e", "d", "c", "d", "a") + rdd <- parallelize(sc, l) + actual <- takeOrdered(rdd, 3L) + expect_equal(actual, as.list(sort(unlist(l)))[1:3]) +}) + +test_that("top() on RDDs", { + l <- list(10, 1, 2, 9, 3, 4, 5, 6, 7) + rdd <- parallelize(sc, l) + actual <- top(rdd, 6L) + expect_equal(actual, as.list(sort(unlist(l), decreasing = TRUE))[1:6]) + + l <- list("e", "d", "c", "d", "a") + rdd <- parallelize(sc, l) + actual <- top(rdd, 3L) + expect_equal(actual, as.list(sort(unlist(l), decreasing = TRUE))[1:3]) +}) + +test_that("fold() on RDDs", { + actual <- fold(rdd, 0, "+") + expect_equal(actual, Reduce("+", nums, 0)) + + rdd <- parallelize(sc, list()) + actual <- fold(rdd, 0, "+") + expect_equal(actual, 0) +}) + +test_that("aggregateRDD() on RDDs", { + rdd <- parallelize(sc, list(1, 2, 3, 4)) + zeroValue <- list(0, 0) + seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } + combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } + actual <- aggregateRDD(rdd, zeroValue, seqOp, combOp) + expect_equal(actual, list(10, 4)) + + rdd <- parallelize(sc, list()) + actual <- aggregateRDD(rdd, zeroValue, seqOp, combOp) + expect_equal(actual, list(0, 0)) +}) + +test_that("zipWithUniqueId() on RDDs", { + rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) + actual <- collect(zipWithUniqueId(rdd)) + expected <- list(list("a", 0), list("b", 3), list("c", 1), + list("d", 4), list("e", 2)) + expect_equal(actual, expected) + + rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 1L) + actual <- collect(zipWithUniqueId(rdd)) + expected <- list(list("a", 0), list("b", 1), list("c", 2), + list("d", 3), list("e", 4)) + expect_equal(actual, expected) +}) + +test_that("zipWithIndex() on RDDs", { + rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) + actual <- collect(zipWithIndex(rdd)) + expected <- list(list("a", 0), list("b", 1), list("c", 2), + list("d", 3), list("e", 4)) + expect_equal(actual, expected) + + rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 1L) + actual <- collect(zipWithIndex(rdd)) + expected <- list(list("a", 0), list("b", 1), list("c", 2), + list("d", 3), list("e", 4)) + expect_equal(actual, expected) +}) + +test_that("glom() on RDD", { + rdd <- parallelize(sc, as.list(1:4), 2L) + actual <- collect(glom(rdd)) + expect_equal(actual, list(list(1, 2), list(3, 4))) +}) + +test_that("keys() on RDDs", { + keys <- keys(intRdd) + actual <- collect(keys) + expect_equal(actual, lapply(intPairs, function(x) { x[[1]] })) +}) + +test_that("values() on RDDs", { + values <- values(intRdd) + actual <- collect(values) + expect_equal(actual, lapply(intPairs, function(x) { x[[2]] })) +}) + +test_that("pipeRDD() on RDDs", { + actual <- collect(pipeRDD(rdd, "more")) + expected <- as.list(as.character(1:10)) + expect_equal(actual, expected) + + trailed.rdd <- parallelize(sc, c("1", "", "2\n", "3\n\r\n")) + actual <- collect(pipeRDD(trailed.rdd, "sort")) + expected <- list("", "1", "2", "3") + expect_equal(actual, expected) + + rev.nums <- 9:0 + rev.rdd <- parallelize(sc, rev.nums, 2L) + actual <- collect(pipeRDD(rev.rdd, "sort")) + expected <- as.list(as.character(c(5:9, 0:4))) + expect_equal(actual, expected) +}) + +test_that("zipRDD() on RDDs", { + rdd1 <- parallelize(sc, 0:4, 2) + rdd2 <- parallelize(sc, 1000:1004, 2) + actual <- collect(zipRDD(rdd1, rdd2)) + expect_equal(actual, + list(list(0, 1000), list(1, 1001), list(2, 1002), list(3, 1003), list(4, 1004))) + + mockFile = c("Spark is pretty.", "Spark is awesome.") + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + rdd <- textFile(sc, fileName, 1) + actual <- collect(zipRDD(rdd, rdd)) + expected <- lapply(mockFile, function(x) { list(x ,x) }) + expect_equal(actual, expected) + + rdd1 <- parallelize(sc, 0:1, 1) + actual <- collect(zipRDD(rdd1, rdd)) + expected <- lapply(0:1, function(x) { list(x, mockFile[x + 1]) }) + expect_equal(actual, expected) + + rdd1 <- map(rdd, function(x) { x }) + actual <- collect(zipRDD(rdd, rdd1)) + expected <- lapply(mockFile, function(x) { list(x, x) }) + expect_equal(actual, expected) + + unlink(fileName) +}) + +test_that("join() on pairwise RDDs", { + rdd1 <- parallelize(sc, list(list(1,1), list(2,4))) + rdd2 <- parallelize(sc, list(list(1,2), list(1,3))) + actual <- collect(join(rdd1, rdd2, 2L)) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(list(list(1, list(1, 2)), list(1, list(1, 3))))) + + rdd1 <- parallelize(sc, list(list("a",1), list("b",4))) + rdd2 <- parallelize(sc, list(list("a",2), list("a",3))) + actual <- collect(join(rdd1, rdd2, 2L)) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(list(list("a", list(1, 2)), list("a", list(1, 3))))) + + rdd1 <- parallelize(sc, list(list(1,1), list(2,2))) + rdd2 <- parallelize(sc, list(list(3,3), list(4,4))) + actual <- collect(join(rdd1, rdd2, 2L)) + expect_equal(actual, list()) + + rdd1 <- parallelize(sc, list(list("a",1), list("b",2))) + rdd2 <- parallelize(sc, list(list("c",3), list("d",4))) + actual <- collect(join(rdd1, rdd2, 2L)) + expect_equal(actual, list()) +}) + +test_that("leftOuterJoin() on pairwise RDDs", { + rdd1 <- parallelize(sc, list(list(1,1), list(2,4))) + rdd2 <- parallelize(sc, list(list(1,2), list(1,3))) + actual <- collect(leftOuterJoin(rdd1, rdd2, 2L)) + expected <- list(list(1, list(1, 2)), list(1, list(1, 3)), list(2, list(4, NULL))) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(expected)) + + rdd1 <- parallelize(sc, list(list("a",1), list("b",4))) + rdd2 <- parallelize(sc, list(list("a",2), list("a",3))) + actual <- collect(leftOuterJoin(rdd1, rdd2, 2L)) + expected <- list(list("b", list(4, NULL)), list("a", list(1, 2)), list("a", list(1, 3))) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(expected)) + + rdd1 <- parallelize(sc, list(list(1,1), list(2,2))) + rdd2 <- parallelize(sc, list(list(3,3), list(4,4))) + actual <- collect(leftOuterJoin(rdd1, rdd2, 2L)) + expected <- list(list(1, list(1, NULL)), list(2, list(2, NULL))) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(expected)) + + rdd1 <- parallelize(sc, list(list("a",1), list("b",2))) + rdd2 <- parallelize(sc, list(list("c",3), list("d",4))) + actual <- collect(leftOuterJoin(rdd1, rdd2, 2L)) + expected <- list(list("b", list(2, NULL)), list("a", list(1, NULL))) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(expected)) +}) + +test_that("rightOuterJoin() on pairwise RDDs", { + rdd1 <- parallelize(sc, list(list(1,2), list(1,3))) + rdd2 <- parallelize(sc, list(list(1,1), list(2,4))) + actual <- collect(rightOuterJoin(rdd1, rdd2, 2L)) + expected <- list(list(1, list(2, 1)), list(1, list(3, 1)), list(2, list(NULL, 4))) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) + + rdd1 <- parallelize(sc, list(list("a",2), list("a",3))) + rdd2 <- parallelize(sc, list(list("a",1), list("b",4))) + actual <- collect(rightOuterJoin(rdd1, rdd2, 2L)) + expected <- list(list("b", list(NULL, 4)), list("a", list(2, 1)), list("a", list(3, 1))) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(expected)) + + rdd1 <- parallelize(sc, list(list(1,1), list(2,2))) + rdd2 <- parallelize(sc, list(list(3,3), list(4,4))) + actual <- collect(rightOuterJoin(rdd1, rdd2, 2L)) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(list(list(3, list(NULL, 3)), list(4, list(NULL, 4))))) + + rdd1 <- parallelize(sc, list(list("a",1), list("b",2))) + rdd2 <- parallelize(sc, list(list("c",3), list("d",4))) + actual <- collect(rightOuterJoin(rdd1, rdd2, 2L)) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(list(list("d", list(NULL, 4)), list("c", list(NULL, 3))))) +}) + +test_that("fullOuterJoin() on pairwise RDDs", { + rdd1 <- parallelize(sc, list(list(1,2), list(1,3), list(3,3))) + rdd2 <- parallelize(sc, list(list(1,1), list(2,4))) + actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) + expected <- list(list(1, list(2, 1)), list(1, list(3, 1)), list(2, list(NULL, 4)), list(3, list(3, NULL))) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) + + rdd1 <- parallelize(sc, list(list("a",2), list("a",3), list("c", 1))) + rdd2 <- parallelize(sc, list(list("a",1), list("b",4))) + actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) + expected <- list(list("b", list(NULL, 4)), list("a", list(2, 1)), list("a", list(3, 1)), list("c", list(1, NULL))) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(expected)) + + rdd1 <- parallelize(sc, list(list(1,1), list(2,2))) + rdd2 <- parallelize(sc, list(list(3,3), list(4,4))) + actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(list(list(1, list(1, NULL)), list(2, list(2, NULL)), list(3, list(NULL, 3)), list(4, list(NULL, 4))))) + + rdd1 <- parallelize(sc, list(list("a",1), list("b",2))) + rdd2 <- parallelize(sc, list(list("c",3), list("d",4))) + actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(list(list("a", list(1, NULL)), list("b", list(2, NULL)), list("d", list(NULL, 4)), list("c", list(NULL, 3))))) +}) + +test_that("sortByKey() on pairwise RDDs", { + numPairsRdd <- map(rdd, function(x) { list (x, x) }) + sortedRdd <- sortByKey(numPairsRdd, ascending = FALSE) + actual <- collect(sortedRdd) + numPairs <- lapply(nums, function(x) { list (x, x) }) + expect_equal(actual, sortKeyValueList(numPairs, decreasing = TRUE)) + + rdd2 <- parallelize(sc, sort(nums, decreasing = TRUE), 2L) + numPairsRdd2 <- map(rdd2, function(x) { list (x, x) }) + sortedRdd2 <- sortByKey(numPairsRdd2) + actual <- collect(sortedRdd2) + expect_equal(actual, numPairs) + + # sort by string keys + l <- list(list("a", 1), list("b", 2), list("1", 3), list("d", 4), list("2", 5)) + rdd3 <- parallelize(sc, l, 2L) + sortedRdd3 <- sortByKey(rdd3) + actual <- collect(sortedRdd3) + expect_equal(actual, list(list("1", 3), list("2", 5), list("a", 1), list("b", 2), list("d", 4))) + + # test on the boundary cases + + # boundary case 1: the RDD to be sorted has only 1 partition + rdd4 <- parallelize(sc, l, 1L) + sortedRdd4 <- sortByKey(rdd4) + actual <- collect(sortedRdd4) + expect_equal(actual, list(list("1", 3), list("2", 5), list("a", 1), list("b", 2), list("d", 4))) + + # boundary case 2: the sorted RDD has only 1 partition + rdd5 <- parallelize(sc, l, 2L) + sortedRdd5 <- sortByKey(rdd5, numPartitions = 1L) + actual <- collect(sortedRdd5) + expect_equal(actual, list(list("1", 3), list("2", 5), list("a", 1), list("b", 2), list("d", 4))) + + # boundary case 3: the RDD to be sorted has only 1 element + l2 <- list(list("a", 1)) + rdd6 <- parallelize(sc, l2, 2L) + sortedRdd6 <- sortByKey(rdd6) + actual <- collect(sortedRdd6) + expect_equal(actual, l2) + + # boundary case 4: the RDD to be sorted has 0 element + l3 <- list() + rdd7 <- parallelize(sc, l3, 2L) + sortedRdd7 <- sortByKey(rdd7) + actual <- collect(sortedRdd7) + expect_equal(actual, l3) +}) + +test_that("collectAsMap() on a pairwise RDD", { + rdd <- parallelize(sc, list(list(1, 2), list(3, 4))) + vals <- collectAsMap(rdd) + expect_equal(vals, list(`1` = 2, `3` = 4)) + + rdd <- parallelize(sc, list(list("a", 1), list("b", 2))) + vals <- collectAsMap(rdd) + expect_equal(vals, list(a = 1, b = 2)) + + rdd <- parallelize(sc, list(list(1.1, 2.2), list(1.2, 2.4))) + vals <- collectAsMap(rdd) + expect_equal(vals, list(`1.1` = 2.2, `1.2` = 2.4)) + + rdd <- parallelize(sc, list(list(1, "a"), list(2, "b"))) + vals <- collectAsMap(rdd) + expect_equal(vals, list(`1` = "a", `2` = "b")) +}) diff --git a/R/pkg/inst/tests/test_shuffle.R b/R/pkg/inst/tests/test_shuffle.R new file mode 100644 index 0000000000000..d1da8232aea81 --- /dev/null +++ b/R/pkg/inst/tests/test_shuffle.R @@ -0,0 +1,209 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("partitionBy, groupByKey, reduceByKey etc.") + +# JavaSparkContext handle +sc <- sparkR.init() + +# Data +intPairs <- list(list(1L, -1), list(2L, 100), list(2L, 1), list(1L, 200)) +intRdd <- parallelize(sc, intPairs, 2L) + +doublePairs <- list(list(1.5, -1), list(2.5, 100), list(2.5, 1), list(1.5, 200)) +doubleRdd <- parallelize(sc, doublePairs, 2L) + +numPairs <- list(list(1L, 100), list(2L, 200), list(4L, -1), list(3L, 1), + list(3L, 0)) +numPairsRdd <- parallelize(sc, numPairs, length(numPairs)) + +strList <- list("Dexter Morgan: Blood. Sometimes it sets my teeth on edge and ", + "Dexter Morgan: Harry and Dorris Morgan did a wonderful job ") +strListRDD <- parallelize(sc, strList, 4) + +test_that("groupByKey for integers", { + grouped <- groupByKey(intRdd, 2L) + + actual <- collect(grouped) + + expected <- list(list(2L, list(100, 1)), list(1L, list(-1, 200))) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("groupByKey for doubles", { + grouped <- groupByKey(doubleRdd, 2L) + + actual <- collect(grouped) + + expected <- list(list(1.5, list(-1, 200)), list(2.5, list(100, 1))) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("reduceByKey for ints", { + reduced <- reduceByKey(intRdd, "+", 2L) + + actual <- collect(reduced) + + expected <- list(list(2L, 101), list(1L, 199)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("reduceByKey for doubles", { + reduced <- reduceByKey(doubleRdd, "+", 2L) + actual <- collect(reduced) + + expected <- list(list(1.5, 199), list(2.5, 101)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("combineByKey for ints", { + reduced <- combineByKey(intRdd, function(x) { x }, "+", "+", 2L) + + actual <- collect(reduced) + + expected <- list(list(2L, 101), list(1L, 199)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("combineByKey for doubles", { + reduced <- combineByKey(doubleRdd, function(x) { x }, "+", "+", 2L) + actual <- collect(reduced) + + expected <- list(list(1.5, 199), list(2.5, 101)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("aggregateByKey", { + # test aggregateByKey for int keys + rdd <- parallelize(sc, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) + + zeroValue <- list(0, 0) + seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } + combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } + aggregatedRDD <- aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L) + + actual <- collect(aggregatedRDD) + + expected <- list(list(1, list(3, 2)), list(2, list(7, 2))) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) + + # test aggregateByKey for string keys + rdd <- parallelize(sc, list(list("a", 1), list("a", 2), list("b", 3), list("b", 4))) + + zeroValue <- list(0, 0) + seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } + combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } + aggregatedRDD <- aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L) + + actual <- collect(aggregatedRDD) + + expected <- list(list("a", list(3, 2)), list("b", list(7, 2))) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("foldByKey", { + # test foldByKey for int keys + folded <- foldByKey(intRdd, 0, "+", 2L) + + actual <- collect(folded) + + expected <- list(list(2L, 101), list(1L, 199)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) + + # test foldByKey for double keys + folded <- foldByKey(doubleRdd, 0, "+", 2L) + + actual <- collect(folded) + + expected <- list(list(1.5, 199), list(2.5, 101)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) + + # test foldByKey for string keys + stringKeyPairs <- list(list("a", -1), list("b", 100), list("b", 1), list("a", 200)) + + stringKeyRDD <- parallelize(sc, stringKeyPairs) + folded <- foldByKey(stringKeyRDD, 0, "+", 2L) + + actual <- collect(folded) + + expected <- list(list("b", 101), list("a", 199)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) + + # test foldByKey for empty pair RDD + rdd <- parallelize(sc, list()) + folded <- foldByKey(rdd, 0, "+", 2L) + actual <- collect(folded) + expected <- list() + expect_equal(actual, expected) + + # test foldByKey for RDD with only 1 pair + rdd <- parallelize(sc, list(list(1, 1))) + folded <- foldByKey(rdd, 0, "+", 2L) + actual <- collect(folded) + expected <- list(list(1, 1)) + expect_equal(actual, expected) +}) + +test_that("partitionBy() partitions data correctly", { + # Partition by magnitude + partitionByMagnitude <- function(key) { if (key >= 3) 1 else 0 } + + resultRDD <- partitionBy(numPairsRdd, 2L, partitionByMagnitude) + + expected_first <- list(list(1, 100), list(2, 200)) # key < 3 + expected_second <- list(list(4, -1), list(3, 1), list(3, 0)) # key >= 3 + actual_first <- collectPartition(resultRDD, 0L) + actual_second <- collectPartition(resultRDD, 1L) + + expect_equal(sortKeyValueList(actual_first), sortKeyValueList(expected_first)) + expect_equal(sortKeyValueList(actual_second), sortKeyValueList(expected_second)) +}) + +test_that("partitionBy works with dependencies", { + kOne <- 1 + partitionByParity <- function(key) { if (key %% 2 == kOne) 7 else 4 } + + # Partition by parity + resultRDD <- partitionBy(numPairsRdd, numPartitions = 2L, partitionByParity) + + # keys even; 100 %% 2 == 0 + expected_first <- list(list(2, 200), list(4, -1)) + # keys odd; 3 %% 2 == 1 + expected_second <- list(list(1, 100), list(3, 1), list(3, 0)) + actual_first <- collectPartition(resultRDD, 0L) + actual_second <- collectPartition(resultRDD, 1L) + + expect_equal(sortKeyValueList(actual_first), sortKeyValueList(expected_first)) + expect_equal(sortKeyValueList(actual_second), sortKeyValueList(expected_second)) +}) + +test_that("test partitionBy with string keys", { + words <- flatMap(strListRDD, function(line) { strsplit(line, " ")[[1]] }) + wordCount <- lapply(words, function(word) { list(word, 1L) }) + + resultRDD <- partitionBy(wordCount, 2L) + expected_first <- list(list("Dexter", 1), list("Dexter", 1)) + expected_second <- list(list("and", 1), list("and", 1)) + + actual_first <- Filter(function(item) { item[[1]] == "Dexter" }, + collectPartition(resultRDD, 0L)) + actual_second <- Filter(function(item) { item[[1]] == "and" }, + collectPartition(resultRDD, 1L)) + + expect_equal(sortKeyValueList(actual_first), sortKeyValueList(expected_first)) + expect_equal(sortKeyValueList(actual_second), sortKeyValueList(expected_second)) +}) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R new file mode 100644 index 0000000000000..cf5cf6d1692af --- /dev/null +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -0,0 +1,695 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(testthat) + +context("SparkSQL functions") + +# Tests for SparkSQL functions in SparkR + +sc <- sparkR.init() + +sqlCtx <- sparkRSQL.init(sc) + +mockLines <- c("{\"name\":\"Michael\"}", + "{\"name\":\"Andy\", \"age\":30}", + "{\"name\":\"Justin\", \"age\":19}") +jsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp") +parquetPath <- tempfile(pattern="sparkr-test", fileext=".parquet") +writeLines(mockLines, jsonPath) + +test_that("infer types", { + expect_equal(infer_type(1L), "integer") + expect_equal(infer_type(1.0), "double") + expect_equal(infer_type("abc"), "string") + expect_equal(infer_type(TRUE), "boolean") + expect_equal(infer_type(as.Date("2015-03-11")), "date") + expect_equal(infer_type(as.POSIXlt("2015-03-11 12:13:04.043")), "timestamp") + expect_equal(infer_type(c(1L, 2L)), + list(type = 'array', elementType = "integer", containsNull = TRUE)) + expect_equal(infer_type(list(1L, 2L)), + list(type = 'array', elementType = "integer", containsNull = TRUE)) + expect_equal(infer_type(list(a = 1L, b = "2")), + list(type = "struct", + fields = list(list(name = "a", type = "integer", nullable = TRUE), + list(name = "b", type = "string", nullable = TRUE)))) + e <- new.env() + assign("a", 1L, envir = e) + expect_equal(infer_type(e), + list(type = "map", keyType = "string", valueType = "integer", + valueContainsNull = TRUE)) +}) + +test_that("create DataFrame from RDD", { + rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) + df <- createDataFrame(sqlCtx, rdd, list("a", "b")) + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 10) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + + df <- createDataFrame(sqlCtx, rdd) + expect_true(inherits(df, "DataFrame")) + expect_equal(columns(df), c("_1", "_2")) + + fields <- list(list(name = "a", type = "integer", nullable = TRUE), + list(name = "b", type = "string", nullable = TRUE)) + schema <- list(type = "struct", fields = fields) + df <- createDataFrame(sqlCtx, rdd, schema) + expect_true(inherits(df, "DataFrame")) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + + rdd <- lapply(parallelize(sc, 1:10), function(x) { list(a = x, b = as.character(x)) }) + df <- createDataFrame(sqlCtx, rdd) + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 10) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) +}) + +test_that("toDF", { + rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) + df <- toDF(rdd, list("a", "b")) + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 10) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + + df <- toDF(rdd) + expect_true(inherits(df, "DataFrame")) + expect_equal(columns(df), c("_1", "_2")) + + fields <- list(list(name = "a", type = "integer", nullable = TRUE), + list(name = "b", type = "string", nullable = TRUE)) + schema <- list(type = "struct", fields = fields) + df <- toDF(rdd, schema) + expect_true(inherits(df, "DataFrame")) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + + rdd <- lapply(parallelize(sc, 1:10), function(x) { list(a = x, b = as.character(x)) }) + df <- toDF(rdd) + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 10) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) +}) + +test_that("create DataFrame from list or data.frame", { + l <- list(list(1, 2), list(3, 4)) + df <- createDataFrame(sqlCtx, l, c("a", "b")) + expect_equal(columns(df), c("a", "b")) + + l <- list(list(a=1, b=2), list(a=3, b=4)) + df <- createDataFrame(sqlCtx, l) + expect_equal(columns(df), c("a", "b")) + + a <- 1:3 + b <- c("a", "b", "c") + ldf <- data.frame(a, b) + df <- createDataFrame(sqlCtx, ldf) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + expect_equal(count(df), 3) + ldf2 <- collect(df) + expect_equal(ldf$a, ldf2$a) +}) + +test_that("create DataFrame with different data types", { + l <- list(a = 1L, b = 2, c = TRUE, d = "ss", e = as.Date("2012-12-13"), + f = as.POSIXct("2015-03-15 12:13:14.056")) + df <- createDataFrame(sqlCtx, list(l)) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "double"), c("c", "boolean"), + c("d", "string"), c("e", "date"), c("f", "timestamp"))) + expect_equal(count(df), 1) + expect_equal(collect(df), data.frame(l, stringsAsFactors = FALSE)) +}) + +# TODO: enable this test after fix serialization for nested object +#test_that("create DataFrame with nested array and struct", { +# e <- new.env() +# assign("n", 3L, envir = e) +# l <- list(1:10, list("a", "b"), e, list(a="aa", b=3L)) +# df <- createDataFrame(sqlCtx, list(l), c("a", "b", "c", "d")) +# expect_equal(dtypes(df), list(c("a", "array"), c("b", "array"), +# c("c", "map"), c("d", "struct"))) +# expect_equal(count(df), 1) +# ldf <- collect(df) +# expect_equal(ldf[1,], l[[1]]) +#}) + +test_that("jsonFile() on a local file returns a DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 3) +}) + +test_that("jsonRDD() on a RDD with json string", { + rdd <- parallelize(sc, mockLines) + expect_true(count(rdd) == 3) + df <- jsonRDD(sqlCtx, rdd) + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 3) + + rdd2 <- flatMap(rdd, function(x) c(x, x)) + df <- jsonRDD(sqlCtx, rdd2) + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 6) +}) + +test_that("test cache, uncache and clearCache", { + df <- jsonFile(sqlCtx, jsonPath) + registerTempTable(df, "table1") + cacheTable(sqlCtx, "table1") + uncacheTable(sqlCtx, "table1") + clearCache(sqlCtx) + dropTempTable(sqlCtx, "table1") +}) + +test_that("test tableNames and tables", { + df <- jsonFile(sqlCtx, jsonPath) + registerTempTable(df, "table1") + expect_true(length(tableNames(sqlCtx)) == 1) + df <- tables(sqlCtx) + expect_true(count(df) == 1) + dropTempTable(sqlCtx, "table1") +}) + +test_that("registerTempTable() results in a queryable table and sql() results in a new DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + registerTempTable(df, "table1") + newdf <- sql(sqlCtx, "SELECT * FROM table1 where name = 'Michael'") + expect_true(inherits(newdf, "DataFrame")) + expect_true(count(newdf) == 1) + dropTempTable(sqlCtx, "table1") +}) + +test_that("insertInto() on a registered table", { + df <- loadDF(sqlCtx, jsonPath, "json") + saveDF(df, parquetPath, "parquet", "overwrite") + dfParquet <- loadDF(sqlCtx, parquetPath, "parquet") + + lines <- c("{\"name\":\"Bob\", \"age\":24}", + "{\"name\":\"James\", \"age\":35}") + jsonPath2 <- tempfile(pattern="jsonPath2", fileext=".tmp") + parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") + writeLines(lines, jsonPath2) + df2 <- loadDF(sqlCtx, jsonPath2, "json") + saveDF(df2, parquetPath2, "parquet", "overwrite") + dfParquet2 <- loadDF(sqlCtx, parquetPath2, "parquet") + + registerTempTable(dfParquet, "table1") + insertInto(dfParquet2, "table1") + expect_true(count(sql(sqlCtx, "select * from table1")) == 5) + expect_true(first(sql(sqlCtx, "select * from table1 order by age"))$name == "Michael") + dropTempTable(sqlCtx, "table1") + + registerTempTable(dfParquet, "table1") + insertInto(dfParquet2, "table1", overwrite = TRUE) + expect_true(count(sql(sqlCtx, "select * from table1")) == 2) + expect_true(first(sql(sqlCtx, "select * from table1 order by age"))$name == "Bob") + dropTempTable(sqlCtx, "table1") +}) + +test_that("table() returns a new DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + registerTempTable(df, "table1") + tabledf <- table(sqlCtx, "table1") + expect_true(inherits(tabledf, "DataFrame")) + expect_true(count(tabledf) == 3) + dropTempTable(sqlCtx, "table1") +}) + +test_that("toRDD() returns an RRDD", { + df <- jsonFile(sqlCtx, jsonPath) + testRDD <- toRDD(df) + expect_true(inherits(testRDD, "RDD")) + expect_true(count(testRDD) == 3) +}) + +test_that("union on two RDDs created from DataFrames returns an RRDD", { + df <- jsonFile(sqlCtx, jsonPath) + RDD1 <- toRDD(df) + RDD2 <- toRDD(df) + unioned <- unionRDD(RDD1, RDD2) + expect_true(inherits(unioned, "RDD")) + expect_true(SparkR:::getSerializedMode(unioned) == "byte") + expect_true(collect(unioned)[[2]]$name == "Andy") +}) + +test_that("union on mixed serialization types correctly returns a byte RRDD", { + # Byte RDD + nums <- 1:10 + rdd <- parallelize(sc, nums, 2L) + + # String RDD + textLines <- c("Michael", + "Andy, 30", + "Justin, 19") + textPath <- tempfile(pattern="sparkr-textLines", fileext=".tmp") + writeLines(textLines, textPath) + textRDD <- textFile(sc, textPath) + + df <- jsonFile(sqlCtx, jsonPath) + dfRDD <- toRDD(df) + + unionByte <- unionRDD(rdd, dfRDD) + expect_true(inherits(unionByte, "RDD")) + expect_true(SparkR:::getSerializedMode(unionByte) == "byte") + expect_true(collect(unionByte)[[1]] == 1) + expect_true(collect(unionByte)[[12]]$name == "Andy") + + unionString <- unionRDD(textRDD, dfRDD) + expect_true(inherits(unionString, "RDD")) + expect_true(SparkR:::getSerializedMode(unionString) == "byte") + expect_true(collect(unionString)[[1]] == "Michael") + expect_true(collect(unionString)[[5]]$name == "Andy") +}) + +test_that("objectFile() works with row serialization", { + objectPath <- tempfile(pattern="spark-test", fileext=".tmp") + df <- jsonFile(sqlCtx, jsonPath) + dfRDD <- toRDD(df) + saveAsObjectFile(coalesce(dfRDD, 1L), objectPath) + objectIn <- objectFile(sc, objectPath) + + expect_true(inherits(objectIn, "RDD")) + expect_equal(SparkR:::getSerializedMode(objectIn), "byte") + expect_equal(collect(objectIn)[[2]]$age, 30) +}) + +test_that("lapply() on a DataFrame returns an RDD with the correct columns", { + df <- jsonFile(sqlCtx, jsonPath) + testRDD <- lapply(df, function(row) { + row$newCol <- row$age + 5 + row + }) + expect_true(inherits(testRDD, "RDD")) + collected <- collect(testRDD) + expect_true(collected[[1]]$name == "Michael") + expect_true(collected[[2]]$newCol == "35") +}) + +test_that("collect() returns a data.frame", { + df <- jsonFile(sqlCtx, jsonPath) + rdf <- collect(df) + expect_true(is.data.frame(rdf)) + expect_true(names(rdf)[1] == "age") + expect_true(nrow(rdf) == 3) + expect_true(ncol(rdf) == 2) +}) + +test_that("limit() returns DataFrame with the correct number of rows", { + df <- jsonFile(sqlCtx, jsonPath) + dfLimited <- limit(df, 2) + expect_true(inherits(dfLimited, "DataFrame")) + expect_true(count(dfLimited) == 2) +}) + +test_that("collect() and take() on a DataFrame return the same number of rows and columns", { + df <- jsonFile(sqlCtx, jsonPath) + expect_true(nrow(collect(df)) == nrow(take(df, 10))) + expect_true(ncol(collect(df)) == ncol(take(df, 10))) +}) + +test_that("multiple pipeline transformations starting with a DataFrame result in an RDD with the correct values", { + df <- jsonFile(sqlCtx, jsonPath) + first <- lapply(df, function(row) { + row$age <- row$age + 5 + row + }) + second <- lapply(first, function(row) { + row$testCol <- if (row$age == 35 && !is.na(row$age)) TRUE else FALSE + row + }) + expect_true(inherits(second, "RDD")) + expect_true(count(second) == 3) + expect_true(collect(second)[[2]]$age == 35) + expect_true(collect(second)[[2]]$testCol) + expect_false(collect(second)[[3]]$testCol) +}) + +test_that("cache(), persist(), and unpersist() on a DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + expect_false(df@env$isCached) + cache(df) + expect_true(df@env$isCached) + + unpersist(df) + expect_false(df@env$isCached) + + persist(df, "MEMORY_AND_DISK") + expect_true(df@env$isCached) + + unpersist(df) + expect_false(df@env$isCached) + + # make sure the data is collectable + expect_true(is.data.frame(collect(df))) +}) + +test_that("schema(), dtypes(), columns(), names() return the correct values/format", { + df <- jsonFile(sqlCtx, jsonPath) + testSchema <- schema(df) + expect_true(length(testSchema$fields()) == 2) + expect_true(testSchema$fields()[[1]]$dataType.toString() == "LongType") + expect_true(testSchema$fields()[[2]]$dataType.simpleString() == "string") + expect_true(testSchema$fields()[[1]]$name() == "age") + + testTypes <- dtypes(df) + expect_true(length(testTypes[[1]]) == 2) + expect_true(testTypes[[1]][1] == "age") + + testCols <- columns(df) + expect_true(length(testCols) == 2) + expect_true(testCols[2] == "name") + + testNames <- names(df) + expect_true(length(testNames) == 2) + expect_true(testNames[2] == "name") +}) + +test_that("head() and first() return the correct data", { + df <- jsonFile(sqlCtx, jsonPath) + testHead <- head(df) + expect_true(nrow(testHead) == 3) + expect_true(ncol(testHead) == 2) + + testHead2 <- head(df, 2) + expect_true(nrow(testHead2) == 2) + expect_true(ncol(testHead2) == 2) + + testFirst <- first(df) + expect_true(nrow(testFirst) == 1) +}) + +test_that("distinct() on DataFrames", { + lines <- c("{\"name\":\"Michael\"}", + "{\"name\":\"Andy\", \"age\":30}", + "{\"name\":\"Justin\", \"age\":19}", + "{\"name\":\"Justin\", \"age\":19}") + jsonPathWithDup <- tempfile(pattern="sparkr-test", fileext=".tmp") + writeLines(lines, jsonPathWithDup) + + df <- jsonFile(sqlCtx, jsonPathWithDup) + uniques <- distinct(df) + expect_true(inherits(uniques, "DataFrame")) + expect_true(count(uniques) == 3) +}) + +test_that("sampleDF on a DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + sampled <- sampleDF(df, FALSE, 1.0) + expect_equal(nrow(collect(sampled)), count(df)) + expect_true(inherits(sampled, "DataFrame")) + sampled2 <- sampleDF(df, FALSE, 0.1) + expect_true(count(sampled2) < 3) +}) + +test_that("select operators", { + df <- select(jsonFile(sqlCtx, jsonPath), "name", "age") + expect_true(inherits(df$name, "Column")) + expect_true(inherits(df[[2]], "Column")) + expect_true(inherits(df[["age"]], "Column")) + + expect_true(inherits(df[,1], "DataFrame")) + expect_equal(columns(df[,1]), c("name")) + expect_equal(columns(df[,"age"]), c("age")) + df2 <- df[,c("age", "name")] + expect_true(inherits(df2, "DataFrame")) + expect_equal(columns(df2), c("age", "name")) + + df$age2 <- df$age + expect_equal(columns(df), c("name", "age", "age2")) + expect_equal(count(where(df, df$age2 == df$age)), 2) + df$age2 <- df$age * 2 + expect_equal(columns(df), c("name", "age", "age2")) + expect_equal(count(where(df, df$age2 == df$age * 2)), 2) +}) + +test_that("select with column", { + df <- jsonFile(sqlCtx, jsonPath) + df1 <- select(df, "name") + expect_true(columns(df1) == c("name")) + expect_true(count(df1) == 3) + + df2 <- select(df, df$age) + expect_true(columns(df2) == c("age")) + expect_true(count(df2) == 3) +}) + +test_that("selectExpr() on a DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + selected <- selectExpr(df, "age * 2") + expect_true(names(selected) == "(age * 2)") + expect_equal(collect(selected), collect(select(df, df$age * 2L))) + + selected2 <- selectExpr(df, "name as newName", "abs(age) as age") + expect_equal(names(selected2), c("newName", "age")) + expect_true(count(selected2) == 3) +}) + +test_that("column calculation", { + df <- jsonFile(sqlCtx, jsonPath) + d <- collect(select(df, alias(df$age + 1, "age2"))) + expect_true(names(d) == c("age2")) + df2 <- select(df, lower(df$name), abs(df$age)) + expect_true(inherits(df2, "DataFrame")) + expect_true(count(df2) == 3) +}) + +test_that("load() from json file", { + df <- loadDF(sqlCtx, jsonPath, "json") + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 3) +}) + +test_that("save() as parquet file", { + df <- loadDF(sqlCtx, jsonPath, "json") + saveDF(df, parquetPath, "parquet", mode="overwrite") + df2 <- loadDF(sqlCtx, parquetPath, "parquet") + expect_true(inherits(df2, "DataFrame")) + expect_true(count(df2) == 3) +}) + +test_that("test HiveContext", { + hiveCtx <- tryCatch({ + newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc) + }, error = function(err) { + skip("Hive is not build with SparkSQL, skipped") + }) + df <- createExternalTable(hiveCtx, "json", jsonPath, "json") + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 3) + df2 <- sql(hiveCtx, "select * from json") + expect_true(inherits(df2, "DataFrame")) + expect_true(count(df2) == 3) + + jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") + saveAsTable(df, "json", "json", "append", path = jsonPath2) + df3 <- sql(hiveCtx, "select * from json") + expect_true(inherits(df3, "DataFrame")) + expect_true(count(df3) == 6) +}) + +test_that("column operators", { + c <- SparkR:::col("a") + c2 <- (- c + 1 - 2) * 3 / 4.0 + c3 <- (c + c2 - c2) * c2 %% c2 + c4 <- (c > c2) & (c2 <= c3) | (c == c2) & (c2 != c3) +}) + +test_that("column functions", { + c <- SparkR:::col("a") + c2 <- min(c) + max(c) + sum(c) + avg(c) + count(c) + abs(c) + sqrt(c) + c3 <- lower(c) + upper(c) + first(c) + last(c) + c4 <- approxCountDistinct(c) + countDistinct(c) + cast(c, "string") +}) + +test_that("string operators", { + df <- jsonFile(sqlCtx, jsonPath) + expect_equal(count(where(df, like(df$name, "A%"))), 1) + expect_equal(count(where(df, startsWith(df$name, "A"))), 1) + expect_equal(first(select(df, substr(df$name, 1, 2)))[[1]], "Mi") + expect_equal(collect(select(df, cast(df$age, "string")))[[2, 1]], "30") +}) + +test_that("group by", { + df <- jsonFile(sqlCtx, jsonPath) + df1 <- agg(df, name = "max", age = "sum") + expect_true(1 == count(df1)) + df1 <- agg(df, age2 = max(df$age)) + expect_true(1 == count(df1)) + expect_equal(columns(df1), c("age2")) + + gd <- groupBy(df, "name") + expect_true(inherits(gd, "GroupedData")) + df2 <- count(gd) + expect_true(inherits(df2, "DataFrame")) + expect_true(3 == count(df2)) + + df3 <- agg(gd, age = "sum") + expect_true(inherits(df3, "DataFrame")) + expect_true(3 == count(df3)) + + df3 <- agg(gd, age = sum(df$age)) + expect_true(inherits(df3, "DataFrame")) + expect_true(3 == count(df3)) + expect_equal(columns(df3), c("name", "age")) + + df4 <- sum(gd, "age") + expect_true(inherits(df4, "DataFrame")) + expect_true(3 == count(df4)) + expect_true(3 == count(mean(gd, "age"))) + expect_true(3 == count(max(gd, "age"))) +}) + +test_that("sortDF() and orderBy() on a DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + sorted <- sortDF(df, df$age) + expect_true(collect(sorted)[1,2] == "Michael") + + sorted2 <- sortDF(df, "name") + expect_true(collect(sorted2)[2,"age"] == 19) + + sorted3 <- orderBy(df, asc(df$age)) + expect_true(is.na(first(sorted3)$age)) + expect_true(collect(sorted3)[2, "age"] == 19) + + sorted4 <- orderBy(df, desc(df$name)) + expect_true(first(sorted4)$name == "Michael") + expect_true(collect(sorted4)[3,"name"] == "Andy") +}) + +test_that("filter() on a DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + filtered <- filter(df, "age > 20") + expect_true(count(filtered) == 1) + expect_true(collect(filtered)$name == "Andy") + filtered2 <- where(df, df$name != "Michael") + expect_true(count(filtered2) == 2) + expect_true(collect(filtered2)$age[2] == 19) +}) + +test_that("join() on a DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + + mockLines2 <- c("{\"name\":\"Michael\", \"test\": \"yes\"}", + "{\"name\":\"Andy\", \"test\": \"no\"}", + "{\"name\":\"Justin\", \"test\": \"yes\"}", + "{\"name\":\"Bob\", \"test\": \"yes\"}") + jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") + writeLines(mockLines2, jsonPath2) + df2 <- jsonFile(sqlCtx, jsonPath2) + + joined <- join(df, df2) + expect_equal(names(joined), c("age", "name", "name", "test")) + expect_true(count(joined) == 12) + + joined2 <- join(df, df2, df$name == df2$name) + expect_equal(names(joined2), c("age", "name", "name", "test")) + expect_true(count(joined2) == 3) + + joined3 <- join(df, df2, df$name == df2$name, "right_outer") + expect_equal(names(joined3), c("age", "name", "name", "test")) + expect_true(count(joined3) == 4) + expect_true(is.na(collect(orderBy(joined3, joined3$age))$age[2])) + + joined4 <- select(join(df, df2, df$name == df2$name, "outer"), + alias(df$age + 5, "newAge"), df$name, df2$test) + expect_equal(names(joined4), c("newAge", "name", "test")) + expect_true(count(joined4) == 4) + expect_equal(collect(orderBy(joined4, joined4$name))$newAge[3], 24) +}) + +test_that("toJSON() returns an RDD of the correct values", { + df <- jsonFile(sqlCtx, jsonPath) + testRDD <- toJSON(df) + expect_true(inherits(testRDD, "RDD")) + expect_true(SparkR:::getSerializedMode(testRDD) == "string") + expect_equal(collect(testRDD)[[1]], mockLines[1]) +}) + +test_that("showDF()", { + df <- jsonFile(sqlCtx, jsonPath) + expect_output(showDF(df), "age name \nnull Michael\n30 Andy \n19 Justin ") +}) + +test_that("isLocal()", { + df <- jsonFile(sqlCtx, jsonPath) + expect_false(isLocal(df)) +}) + +test_that("unionAll(), subtract(), and intersect() on a DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + + lines <- c("{\"name\":\"Bob\", \"age\":24}", + "{\"name\":\"Andy\", \"age\":30}", + "{\"name\":\"James\", \"age\":35}") + jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") + writeLines(lines, jsonPath2) + df2 <- loadDF(sqlCtx, jsonPath2, "json") + + unioned <- sortDF(unionAll(df, df2), df$age) + expect_true(inherits(unioned, "DataFrame")) + expect_true(count(unioned) == 6) + expect_true(first(unioned)$name == "Michael") + + subtracted <- sortDF(subtract(df, df2), desc(df$age)) + expect_true(inherits(unioned, "DataFrame")) + expect_true(count(subtracted) == 2) + expect_true(first(subtracted)$name == "Justin") + + intersected <- sortDF(intersect(df, df2), df$age) + expect_true(inherits(unioned, "DataFrame")) + expect_true(count(intersected) == 1) + expect_true(first(intersected)$name == "Andy") +}) + +test_that("withColumn() and withColumnRenamed()", { + df <- jsonFile(sqlCtx, jsonPath) + newDF <- withColumn(df, "newAge", df$age + 2) + expect_true(length(columns(newDF)) == 3) + expect_true(columns(newDF)[3] == "newAge") + expect_true(first(filter(newDF, df$name != "Michael"))$newAge == 32) + + newDF2 <- withColumnRenamed(df, "age", "newerAge") + expect_true(length(columns(newDF2)) == 2) + expect_true(columns(newDF2)[1] == "newerAge") +}) + +test_that("saveDF() on DataFrame and works with parquetFile", { + df <- jsonFile(sqlCtx, jsonPath) + saveDF(df, parquetPath, "parquet", mode="overwrite") + parquetDF <- parquetFile(sqlCtx, parquetPath) + expect_true(inherits(parquetDF, "DataFrame")) + expect_equal(count(df), count(parquetDF)) +}) + +test_that("parquetFile works with multiple input paths", { + df <- jsonFile(sqlCtx, jsonPath) + saveDF(df, parquetPath, "parquet", mode="overwrite") + parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") + saveDF(df, parquetPath2, "parquet", mode="overwrite") + parquetDF <- parquetFile(sqlCtx, parquetPath, parquetPath2) + expect_true(inherits(parquetDF, "DataFrame")) + expect_true(count(parquetDF) == count(df)*2) +}) + +unlink(parquetPath) +unlink(jsonPath) diff --git a/R/pkg/inst/tests/test_take.R b/R/pkg/inst/tests/test_take.R new file mode 100644 index 0000000000000..7f4c7c315d787 --- /dev/null +++ b/R/pkg/inst/tests/test_take.R @@ -0,0 +1,67 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("tests RDD function take()") + +# Mock data +numVector <- c(-10:97) +numList <- list(sqrt(1), sqrt(2), sqrt(3), 4 ** 10) +strVector <- c("Dexter Morgan: I suppose I should be upset, even feel", + "violated, but I'm not. No, in fact, I think this is a friendly", + "message, like \"Hey, wanna play?\" and yes, I want to play. ", + "I really, really do.") +strList <- list("Dexter Morgan: Blood. Sometimes it sets my teeth on edge, ", + "other times it helps me control the chaos.", + "Dexter Morgan: Harry and Dorris Morgan did a wonderful job ", + "raising me. But they're both dead now. I didn't kill them. Honest.") + +# JavaSparkContext handle +jsc <- sparkR.init() + +test_that("take() gives back the original elements in correct count and order", { + numVectorRDD <- parallelize(jsc, numVector, 10) + # case: number of elements to take is less than the size of the first partition + expect_equal(take(numVectorRDD, 1), as.list(head(numVector, n = 1))) + # case: number of elements to take is the same as the size of the first partition + expect_equal(take(numVectorRDD, 11), as.list(head(numVector, n = 11))) + # case: number of elements to take is greater than all elements + expect_equal(take(numVectorRDD, length(numVector)), as.list(numVector)) + expect_equal(take(numVectorRDD, length(numVector) + 1), as.list(numVector)) + + numListRDD <- parallelize(jsc, numList, 1) + numListRDD2 <- parallelize(jsc, numList, 4) + expect_equal(take(numListRDD, 3), take(numListRDD2, 3)) + expect_equal(take(numListRDD, 5), take(numListRDD2, 5)) + expect_equal(take(numListRDD, 1), as.list(head(numList, n = 1))) + expect_equal(take(numListRDD2, 999), numList) + + strVectorRDD <- parallelize(jsc, strVector, 2) + strVectorRDD2 <- parallelize(jsc, strVector, 3) + expect_equal(take(strVectorRDD, 4), as.list(strVector)) + expect_equal(take(strVectorRDD2, 2), as.list(head(strVector, n = 2))) + + strListRDD <- parallelize(jsc, strList, 4) + strListRDD2 <- parallelize(jsc, strList, 1) + expect_equal(take(strListRDD, 3), as.list(head(strList, n = 3))) + expect_equal(take(strListRDD2, 1), as.list(head(strList, n = 1))) + + expect_true(length(take(strListRDD, 0)) == 0) + expect_true(length(take(strVectorRDD, 0)) == 0) + expect_true(length(take(numListRDD, 0)) == 0) + expect_true(length(take(numVectorRDD, 0)) == 0) +}) + diff --git a/R/pkg/inst/tests/test_textFile.R b/R/pkg/inst/tests/test_textFile.R new file mode 100644 index 0000000000000..6b87b4b3e0b08 --- /dev/null +++ b/R/pkg/inst/tests/test_textFile.R @@ -0,0 +1,162 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("the textFile() function") + +# JavaSparkContext handle +sc <- sparkR.init() + +mockFile = c("Spark is pretty.", "Spark is awesome.") + +test_that("textFile() on a local file returns an RDD", { + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + rdd <- textFile(sc, fileName) + expect_true(inherits(rdd, "RDD")) + expect_true(count(rdd) > 0) + expect_true(count(rdd) == 2) + + unlink(fileName) +}) + +test_that("textFile() followed by a collect() returns the same content", { + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + rdd <- textFile(sc, fileName) + expect_equal(collect(rdd), as.list(mockFile)) + + unlink(fileName) +}) + +test_that("textFile() word count works as expected", { + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + rdd <- textFile(sc, fileName) + + words <- flatMap(rdd, function(line) { strsplit(line, " ")[[1]] }) + wordCount <- lapply(words, function(word) { list(word, 1L) }) + + counts <- reduceByKey(wordCount, "+", 2L) + output <- collect(counts) + expected <- list(list("pretty.", 1), list("is", 2), list("awesome.", 1), + list("Spark", 2)) + expect_equal(sortKeyValueList(output), sortKeyValueList(expected)) + + unlink(fileName) +}) + +test_that("several transformations on RDD created by textFile()", { + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + rdd <- textFile(sc, fileName) # RDD + for (i in 1:10) { + # PipelinedRDD initially created from RDD + rdd <- lapply(rdd, function(x) paste(x, x)) + } + collect(rdd) + + unlink(fileName) +}) + +test_that("textFile() followed by a saveAsTextFile() returns the same content", { + fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") + fileName2 <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName1) + + rdd <- textFile(sc, fileName1, 1L) + saveAsTextFile(rdd, fileName2) + rdd <- textFile(sc, fileName2) + expect_equal(collect(rdd), as.list(mockFile)) + + unlink(fileName1) + unlink(fileName2) +}) + +test_that("saveAsTextFile() on a parallelized list works as expected", { + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + l <- list(1, 2, 3) + rdd <- parallelize(sc, l, 1L) + saveAsTextFile(rdd, fileName) + rdd <- textFile(sc, fileName) + expect_equal(collect(rdd), lapply(l, function(x) {toString(x)})) + + unlink(fileName) +}) + +test_that("textFile() and saveAsTextFile() word count works as expected", { + fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") + fileName2 <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName1) + + rdd <- textFile(sc, fileName1) + + words <- flatMap(rdd, function(line) { strsplit(line, " ")[[1]] }) + wordCount <- lapply(words, function(word) { list(word, 1L) }) + + counts <- reduceByKey(wordCount, "+", 2L) + + saveAsTextFile(counts, fileName2) + rdd <- textFile(sc, fileName2) + + output <- collect(rdd) + expected <- list(list("awesome.", 1), list("Spark", 2), + list("pretty.", 1), list("is", 2)) + expectedStr <- lapply(expected, function(x) { toString(x) }) + expect_equal(sortKeyValueList(output), sortKeyValueList(expectedStr)) + + unlink(fileName1) + unlink(fileName2) +}) + +test_that("textFile() on multiple paths", { + fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") + fileName2 <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines("Spark is pretty.", fileName1) + writeLines("Spark is awesome.", fileName2) + + rdd <- textFile(sc, c(fileName1, fileName2)) + expect_true(count(rdd) == 2) + + unlink(fileName1) + unlink(fileName2) +}) + +test_that("Pipelined operations on RDDs created using textFile", { + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + rdd <- textFile(sc, fileName) + + lengths <- lapply(rdd, function(x) { length(x) }) + expect_equal(collect(lengths), list(1, 1)) + + lengthsPipelined <- lapply(lengths, function(x) { x + 10 }) + expect_equal(collect(lengthsPipelined), list(11, 11)) + + lengths30 <- lapply(lengthsPipelined, function(x) { x + 20 }) + expect_equal(collect(lengths30), list(31, 31)) + + lengths20 <- lapply(lengths, function(x) { x + 20 }) + expect_equal(collect(lengths20), list(21, 21)) + + unlink(fileName) +}) + diff --git a/R/pkg/inst/tests/test_utils.R b/R/pkg/inst/tests/test_utils.R new file mode 100644 index 0000000000000..9c5bb427932b4 --- /dev/null +++ b/R/pkg/inst/tests/test_utils.R @@ -0,0 +1,137 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("functions in utils.R") + +# JavaSparkContext handle +sc <- sparkR.init() + +test_that("convertJListToRList() gives back (deserializes) the original JLists + of strings and integers", { + # It's hard to manually create a Java List using rJava, since it does not + # support generics well. Instead, we rely on collect() returning a + # JList. + nums <- as.list(1:10) + rdd <- parallelize(sc, nums, 1L) + jList <- callJMethod(rdd@jrdd, "collect") + rList <- convertJListToRList(jList, flatten = TRUE) + expect_equal(rList, nums) + + strs <- as.list("hello", "spark") + rdd <- parallelize(sc, strs, 2L) + jList <- callJMethod(rdd@jrdd, "collect") + rList <- convertJListToRList(jList, flatten = TRUE) + expect_equal(rList, strs) +}) + +test_that("serializeToBytes on RDD", { + # File content + mockFile <- c("Spark is pretty.", "Spark is awesome.") + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + text.rdd <- textFile(sc, fileName) + expect_true(getSerializedMode(text.rdd) == "string") + ser.rdd <- serializeToBytes(text.rdd) + expect_equal(collect(ser.rdd), as.list(mockFile)) + expect_true(getSerializedMode(ser.rdd) == "byte") + + unlink(fileName) +}) + +test_that("cleanClosure on R functions", { + y <- c(1, 2, 3) + g <- function(x) { x + 1 } + f <- function(x) { g(x) + y } + newF <- cleanClosure(f) + env <- environment(newF) + expect_equal(length(ls(env)), 2) # y, g + actual <- get("y", envir = env, inherits = FALSE) + expect_equal(actual, y) + actual <- get("g", envir = env, inherits = FALSE) + expect_equal(actual, g) + + # Test for nested enclosures and package variables. + env2 <- new.env() + funcEnv <- new.env(parent = env2) + f <- function(x) { log(g(x) + y) } + environment(f) <- funcEnv # enclosing relationship: f -> funcEnv -> env2 -> .GlobalEnv + newF <- cleanClosure(f) + env <- environment(newF) + expect_equal(length(ls(env)), 2) # "min" should not be included + actual <- get("y", envir = env, inherits = FALSE) + expect_equal(actual, y) + actual <- get("g", envir = env, inherits = FALSE) + expect_equal(actual, g) + + base <- c(1, 2, 3) + l <- list(field = matrix(1)) + field <- matrix(2) + defUse <- 3 + g <- function(x) { x + y } + f <- function(x) { + defUse <- base::as.integer(x) + 1 # Test for access operators `::`. + lapply(x, g) + 1 # Test for capturing function call "g"'s closure as a argument of lapply. + l$field[1,1] <- 3 # Test for access operators `$`. + res <- defUse + l$field[1,] # Test for def-use chain of "defUse", and "" symbol. + f(res) # Test for recursive calls. + } + newF <- cleanClosure(f) + env <- environment(newF) + expect_equal(length(ls(env)), 3) # Only "g", "l" and "f". No "base", "field" or "defUse". + expect_true("g" %in% ls(env)) + expect_true("l" %in% ls(env)) + expect_true("f" %in% ls(env)) + expect_equal(get("l", envir = env, inherits = FALSE), l) + # "y" should be in the environemnt of g. + newG <- get("g", envir = env, inherits = FALSE) + env <- environment(newG) + expect_equal(length(ls(env)), 1) + actual <- get("y", envir = env, inherits = FALSE) + expect_equal(actual, y) + + # Test for function (and variable) definitions. + f <- function(x) { + g <- function(y) { y * 2 } + g(x) + } + newF <- cleanClosure(f) + env <- environment(newF) + expect_equal(length(ls(env)), 0) # "y" and "g" should not be included. + + # Test for overriding variables in base namespace (Issue: SparkR-196). + nums <- as.list(1:10) + rdd <- parallelize(sc, nums, 2L) + t = 4 # Override base::t in .GlobalEnv. + f <- function(x) { x > t } + newF <- cleanClosure(f) + env <- environment(newF) + expect_equal(ls(env), "t") + expect_equal(get("t", envir = env, inherits = FALSE), t) + actual <- collect(lapply(rdd, f)) + expected <- as.list(c(rep(FALSE, 4), rep(TRUE, 6))) + expect_equal(actual, expected) + + # Test for broadcast variables. + a <- matrix(nrow=10, ncol=10, data=rnorm(100)) + aBroadcast <- broadcast(sc, a) + normMultiply <- function(x) { norm(aBroadcast$value) * x } + newnormMultiply <- SparkR:::cleanClosure(normMultiply) + env <- environment(newnormMultiply) + expect_equal(ls(env), "aBroadcast") + expect_equal(get("aBroadcast", envir = env, inherits = FALSE), aBroadcast) +}) diff --git a/R/pkg/inst/worker/daemon.R b/R/pkg/inst/worker/daemon.R new file mode 100644 index 0000000000000..3584b418a71a9 --- /dev/null +++ b/R/pkg/inst/worker/daemon.R @@ -0,0 +1,52 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Worker daemon + +rLibDir <- Sys.getenv("SPARKR_RLIBDIR") +script <- paste(rLibDir, "SparkR/worker/worker.R", sep = "/") + +# preload SparkR package, speedup worker +.libPaths(c(rLibDir, .libPaths())) +suppressPackageStartupMessages(library(SparkR)) + +port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT")) +inputCon <- socketConnection(port = port, open = "rb", blocking = TRUE, timeout = 3600) + +while (TRUE) { + ready <- socketSelect(list(inputCon)) + if (ready) { + port <- SparkR:::readInt(inputCon) + # There is a small chance that it could be interrupted by signal, retry one time + if (length(port) == 0) { + port <- SparkR:::readInt(inputCon) + if (length(port) == 0) { + cat("quitting daemon\n") + quit(save = "no") + } + } + p <- parallel:::mcfork() + if (inherits(p, "masterProcess")) { + close(inputCon) + Sys.setenv(SPARKR_WORKER_PORT = port) + source(script) + # Set SIGUSR1 so that child can exit + tools::pskill(Sys.getpid(), tools::SIGUSR1) + parallel:::mcexit(0L) + } + } +} diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R new file mode 100644 index 0000000000000..c6542928e8ddd --- /dev/null +++ b/R/pkg/inst/worker/worker.R @@ -0,0 +1,128 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Worker class + +rLibDir <- Sys.getenv("SPARKR_RLIBDIR") +# Set libPaths to include SparkR package as loadNamespace needs this +# TODO: Figure out if we can avoid this by not loading any objects that require +# SparkR namespace +.libPaths(c(rLibDir, .libPaths())) +suppressPackageStartupMessages(library(SparkR)) + +port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT")) +inputCon <- socketConnection(port = port, blocking = TRUE, open = "rb") +outputCon <- socketConnection(port = port, blocking = TRUE, open = "wb") + +# read the index of the current partition inside the RDD +partition <- SparkR:::readInt(inputCon) + +deserializer <- SparkR:::readString(inputCon) +serializer <- SparkR:::readString(inputCon) + +# Include packages as required +packageNames <- unserialize(SparkR:::readRaw(inputCon)) +for (pkg in packageNames) { + suppressPackageStartupMessages(require(as.character(pkg), character.only=TRUE)) +} + +# read function dependencies +funcLen <- SparkR:::readInt(inputCon) +computeFunc <- unserialize(SparkR:::readRawLen(inputCon, funcLen)) +env <- environment(computeFunc) +parent.env(env) <- .GlobalEnv # Attach under global environment. + +# Read and set broadcast variables +numBroadcastVars <- SparkR:::readInt(inputCon) +if (numBroadcastVars > 0) { + for (bcast in seq(1:numBroadcastVars)) { + bcastId <- SparkR:::readInt(inputCon) + value <- unserialize(SparkR:::readRaw(inputCon)) + setBroadcastValue(bcastId, value) + } +} + +# If -1: read as normal RDD; if >= 0, treat as pairwise RDD and treat the int +# as number of partitions to create. +numPartitions <- SparkR:::readInt(inputCon) + +isEmpty <- SparkR:::readInt(inputCon) + +if (isEmpty != 0) { + + if (numPartitions == -1) { + if (deserializer == "byte") { + # Now read as many characters as described in funcLen + data <- SparkR:::readDeserialize(inputCon) + } else if (deserializer == "string") { + data <- as.list(readLines(inputCon)) + } else if (deserializer == "row") { + data <- SparkR:::readDeserializeRows(inputCon) + } + output <- computeFunc(partition, data) + if (serializer == "byte") { + SparkR:::writeRawSerialize(outputCon, output) + } else if (serializer == "row") { + SparkR:::writeRowSerialize(outputCon, output) + } else { + SparkR:::writeStrings(outputCon, output) + } + } else { + if (deserializer == "byte") { + # Now read as many characters as described in funcLen + data <- SparkR:::readDeserialize(inputCon) + } else if (deserializer == "string") { + data <- readLines(inputCon) + } else if (deserializer == "row") { + data <- SparkR:::readDeserializeRows(inputCon) + } + + res <- new.env() + + # Step 1: hash the data to an environment + hashTupleToEnvir <- function(tuple) { + # NOTE: execFunction is the hash function here + hashVal <- computeFunc(tuple[[1]]) + bucket <- as.character(hashVal %% numPartitions) + acc <- res[[bucket]] + # Create a new accumulator + if (is.null(acc)) { + acc <- SparkR:::initAccumulator() + } + SparkR:::addItemToAccumulator(acc, tuple) + res[[bucket]] <- acc + } + invisible(lapply(data, hashTupleToEnvir)) + + # Step 2: write out all of the environment as key-value pairs. + for (name in ls(res)) { + SparkR:::writeInt(outputCon, 2L) + SparkR:::writeInt(outputCon, as.integer(name)) + # Truncate the accumulator list to the number of elements we have + length(res[[name]]$data) <- res[[name]]$counter + SparkR:::writeRawSerialize(outputCon, res[[name]]$data) + } + } +} + +# End of output +if (serializer %in% c("byte", "row")) { + SparkR:::writeInt(outputCon, 0L) +} + +close(outputCon) +close(inputCon) diff --git a/R/pkg/src/Makefile b/R/pkg/src/Makefile new file mode 100644 index 0000000000000..a55a56fe80e10 --- /dev/null +++ b/R/pkg/src/Makefile @@ -0,0 +1,27 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +all: sharelib + +sharelib: string_hash_code.c + R CMD SHLIB -o SparkR.so string_hash_code.c + +clean: + rm -f *.o + rm -f *.so + +.PHONY: all clean diff --git a/R/pkg/src/Makefile.win b/R/pkg/src/Makefile.win new file mode 100644 index 0000000000000..aa486d8228371 --- /dev/null +++ b/R/pkg/src/Makefile.win @@ -0,0 +1,27 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +all: sharelib + +sharelib: string_hash_code.c + R CMD SHLIB -o SparkR.dll string_hash_code.c + +clean: + rm -f *.o + rm -f *.dll + +.PHONY: all clean diff --git a/R/pkg/src/string_hash_code.c b/R/pkg/src/string_hash_code.c new file mode 100644 index 0000000000000..e3274b9a0c547 --- /dev/null +++ b/R/pkg/src/string_hash_code.c @@ -0,0 +1,49 @@ +/* + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +/* + * A C function for R extension which implements the Java String hash algorithm. + * Refer to http://en.wikipedia.org/wiki/Java_hashCode%28%29#The_java.lang.String_hash_function + * + */ + +#include +#include + +/* for compatibility with R before 3.1 */ +#ifndef IS_SCALAR +#define IS_SCALAR(x, type) (TYPEOF(x) == (type) && XLENGTH(x) == 1) +#endif + +SEXP stringHashCode(SEXP string) { + const char* str; + R_xlen_t len, i; + int hashCode = 0; + + if (!IS_SCALAR(string, STRSXP)) { + error("invalid input"); + } + + str = CHAR(asChar(string)); + len = XLENGTH(asChar(string)); + + for (i = 0; i < len; i++) { + hashCode = (hashCode << 5) - hashCode + *str++; + } + + return ScalarInteger(hashCode); +} diff --git a/R/pkg/tests/run-all.R b/R/pkg/tests/run-all.R new file mode 100644 index 0000000000000..4f8a1ed2d83ef --- /dev/null +++ b/R/pkg/tests/run-all.R @@ -0,0 +1,21 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(testthat) +library(SparkR) + +test_package("SparkR") diff --git a/R/run-tests.sh b/R/run-tests.sh new file mode 100755 index 0000000000000..e82ad0ba2cd06 --- /dev/null +++ b/R/run-tests.sh @@ -0,0 +1,39 @@ +#!/bin/bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +FWDIR="$(cd `dirname $0`; pwd)" + +FAILED=0 +LOGFILE=$FWDIR/unit-tests.out +rm -f $LOGFILE + +SPARK_TESTING=1 $FWDIR/../bin/sparkR --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE +FAILED=$((PIPESTATUS[0]||$FAILED)) + +if [[ $FAILED != 0 ]]; then + cat $LOGFILE + echo -en "\033[31m" # Red + echo "Had test failures; see logs." + echo -en "\033[0m" # No color + exit -1 +else + echo -en "\033[32m" # Green + echo "Tests passed." + echo -en "\033[0m" # No color +fi diff --git a/bagel/src/test/resources/log4j.properties b/bagel/src/test/resources/log4j.properties index 853ef0ed2986f..edbecdae92096 100644 --- a/bagel/src/test/resources/log4j.properties +++ b/bagel/src/test/resources/log4j.properties @@ -24,4 +24,4 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/bin/load-spark-env.cmd b/bin/load-spark-env.cmd new file mode 100644 index 0000000000000..36d932c453b6f --- /dev/null +++ b/bin/load-spark-env.cmd @@ -0,0 +1,59 @@ +@echo off + +rem +rem Licensed to the Apache Software Foundation (ASF) under one or more +rem contributor license agreements. See the NOTICE file distributed with +rem this work for additional information regarding copyright ownership. +rem The ASF licenses this file to You under the Apache License, Version 2.0 +rem (the "License"); you may not use this file except in compliance with +rem the License. You may obtain a copy of the License at +rem +rem http://www.apache.org/licenses/LICENSE-2.0 +rem +rem Unless required by applicable law or agreed to in writing, software +rem distributed under the License is distributed on an "AS IS" BASIS, +rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +rem See the License for the specific language governing permissions and +rem limitations under the License. +rem + +rem This script loads spark-env.cmd if it exists, and ensures it is only loaded once. +rem spark-env.cmd is loaded from SPARK_CONF_DIR if set, or within the current directory's +rem conf/ subdirectory. + +if [%SPARK_ENV_LOADED%] == [] ( + set SPARK_ENV_LOADED=1 + + if not [%SPARK_CONF_DIR%] == [] ( + set user_conf_dir=%SPARK_CONF_DIR% + ) else ( + set user_conf_dir=%~dp0..\..\conf + ) + + call :LoadSparkEnv +) + +rem Setting SPARK_SCALA_VERSION if not already set. + +set ASSEMBLY_DIR2=%SPARK_HOME%/assembly/target/scala-2.11 +set ASSEMBLY_DIR1=%SPARK_HOME%/assembly/target/scala-2.10 + +if [%SPARK_SCALA_VERSION%] == [] ( + + if exist %ASSEMBLY_DIR2% if exist %ASSEMBLY_DIR1% ( + echo "Presence of build for both scala versions(SCALA 2.10 and SCALA 2.11) detected." + echo "Either clean one of them or, set SPARK_SCALA_VERSION=2.11 in spark-env.cmd." + exit 1 + ) + if exist %ASSEMBLY_DIR2% ( + set SPARK_SCALA_VERSION=2.11 + ) else ( + set SPARK_SCALA_VERSION=2.10 + ) +) +exit /b 0 + +:LoadSparkEnv +if exist "%user_conf_dir%\spark-env.cmd" ( + call "%user_conf_dir%\spark-env.cmd" +) diff --git a/bin/load-spark-env.sh b/bin/load-spark-env.sh index 2d7070c25d328..95779e9ddbb18 100644 --- a/bin/load-spark-env.sh +++ b/bin/load-spark-env.sh @@ -20,6 +20,7 @@ # This script loads spark-env.sh if it exists, and ensures it is only loaded once. # spark-env.sh is loaded from SPARK_CONF_DIR if set, or within the current directory's # conf/ subdirectory. +FWDIR="$(cd "`dirname "$0"`"/..; pwd)" if [ -z "$SPARK_ENV_LOADED" ]; then export SPARK_ENV_LOADED=1 @@ -41,8 +42,8 @@ fi if [ -z "$SPARK_SCALA_VERSION" ]; then - ASSEMBLY_DIR2="$SPARK_HOME/assembly/target/scala-2.11" - ASSEMBLY_DIR1="$SPARK_HOME/assembly/target/scala-2.10" + ASSEMBLY_DIR2="$FWDIR/assembly/target/scala-2.11" + ASSEMBLY_DIR1="$FWDIR/assembly/target/scala-2.10" if [[ -d "$ASSEMBLY_DIR2" && -d "$ASSEMBLY_DIR1" ]]; then echo -e "Presence of build for both scala versions(SCALA 2.10 and SCALA 2.11) detected." 1>&2 diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd index 4f5eb5e20614d..09b4149c2a439 100644 --- a/bin/pyspark2.cmd +++ b/bin/pyspark2.cmd @@ -20,8 +20,7 @@ rem rem Figure out where the Spark framework is installed set SPARK_HOME=%~dp0.. -rem Load environment variables from conf\spark-env.cmd, if it exists -if exist "%SPARK_HOME%\conf\spark-env.cmd" call "%SPARK_HOME%\conf\spark-env.cmd" +call %SPARK_HOME%\bin\load-spark-env.cmd rem Figure out which Python to use. if "x%PYSPARK_DRIVER_PYTHON%"=="x" ( diff --git a/bin/run-example2.cmd b/bin/run-example2.cmd index b49d0dcb4ff2d..c3e0221fb62e3 100644 --- a/bin/run-example2.cmd +++ b/bin/run-example2.cmd @@ -25,8 +25,7 @@ set FWDIR=%~dp0..\ rem Export this as SPARK_HOME set SPARK_HOME=%FWDIR% -rem Load environment variables from conf\spark-env.cmd, if it exists -if exist "%FWDIR%conf\spark-env.cmd" call "%FWDIR%conf\spark-env.cmd" +call %SPARK_HOME%\bin\load-spark-env.cmd rem Test that an argument was given if not "x%1"=="x" goto arg_given diff --git a/bin/spark-class2.cmd b/bin/spark-class2.cmd index 4ce727bc99128..4b3401d745f2a 100644 --- a/bin/spark-class2.cmd +++ b/bin/spark-class2.cmd @@ -20,8 +20,7 @@ rem rem Figure out where the Spark framework is installed set SPARK_HOME=%~dp0.. -rem Load environment variables from conf\spark-env.cmd, if it exists -if exist "%SPARK_HOME%\conf\spark-env.cmd" call "%SPARK_HOME%\conf\spark-env.cmd" +call %SPARK_HOME%\bin\load-spark-env.cmd rem Test that an argument was given if "x%1"=="x" ( diff --git a/bin/sparkR b/bin/sparkR new file mode 100755 index 0000000000000..8c918e2b09aef --- /dev/null +++ b/bin/sparkR @@ -0,0 +1,39 @@ +#!/bin/bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Figure out where Spark is installed +export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + +source "$SPARK_HOME"/bin/load-spark-env.sh + +function usage() { + if [ -n "$1" ]; then + echo $1 + fi + echo "Usage: ./bin/sparkR [options]" 1>&2 + "$SPARK_HOME"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 + exit $2 +} +export -f usage + +if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then + usage +fi + +exec "$SPARK_HOME"/bin/spark-submit sparkr-shell-main "$@" diff --git a/bin/sparkR.cmd b/bin/sparkR.cmd new file mode 100644 index 0000000000000..d7b60183ca8e0 --- /dev/null +++ b/bin/sparkR.cmd @@ -0,0 +1,23 @@ +@echo off + +rem +rem Licensed to the Apache Software Foundation (ASF) under one or more +rem contributor license agreements. See the NOTICE file distributed with +rem this work for additional information regarding copyright ownership. +rem The ASF licenses this file to You under the Apache License, Version 2.0 +rem (the "License"); you may not use this file except in compliance with +rem the License. You may obtain a copy of the License at +rem +rem http://www.apache.org/licenses/LICENSE-2.0 +rem +rem Unless required by applicable law or agreed to in writing, software +rem distributed under the License is distributed on an "AS IS" BASIS, +rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +rem See the License for the specific language governing permissions and +rem limitations under the License. +rem + +rem This is the entry point for running SparkR. To avoid polluting the +rem environment, it just launches a new cmd to do the real work. + +cmd /V /E /C %~dp0sparkR2.cmd %* diff --git a/bin/sparkR2.cmd b/bin/sparkR2.cmd new file mode 100644 index 0000000000000..e47f22c7300bb --- /dev/null +++ b/bin/sparkR2.cmd @@ -0,0 +1,26 @@ +@echo off + +rem +rem Licensed to the Apache Software Foundation (ASF) under one or more +rem contributor license agreements. See the NOTICE file distributed with +rem this work for additional information regarding copyright ownership. +rem The ASF licenses this file to You under the Apache License, Version 2.0 +rem (the "License"); you may not use this file except in compliance with +rem the License. You may obtain a copy of the License at +rem +rem http://www.apache.org/licenses/LICENSE-2.0 +rem +rem Unless required by applicable law or agreed to in writing, software +rem distributed under the License is distributed on an "AS IS" BASIS, +rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +rem See the License for the specific language governing permissions and +rem limitations under the License. +rem + +rem Figure out where the Spark framework is installed +set SPARK_HOME=%~dp0.. + +call %SPARK_HOME%\bin\load-spark-env.cmd + + +call %SPARK_HOME%\bin\spark-submit2.cmd sparkr-shell-main %* diff --git a/conf/log4j.properties.template b/conf/log4j.properties.template index 89eec7d4b7f61..3a2a88219818f 100644 --- a/conf/log4j.properties.template +++ b/conf/log4j.properties.template @@ -6,7 +6,7 @@ log4j.appender.console.layout=org.apache.log4j.PatternLayout log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n # Settings to quiet third party logs that are too verbose -log4j.logger.org.eclipse.jetty=WARN -log4j.logger.org.eclipse.jetty.util.component.AbstractLifeCycle=ERROR +log4j.logger.org.spark-project.jetty=WARN +log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO diff --git a/core/pom.xml b/core/pom.xml index 5ec75db94b75a..f26dd23def259 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -457,4 +457,55 @@ + + + Windows + + + Windows + + + + \ + .bat + + + + unix + + + unix + + + + / + .sh + + + + sparkr + + + + org.codehaus.mojo + exec-maven-plugin + 1.3.2 + + + sparkr-pkg + compile + + exec + + + + + ..${path.separator}R${path.separator}install-dev${script.extension} + + + + + + + diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults.properties b/core/src/main/resources/org/apache/spark/log4j-defaults.properties index 89eec7d4b7f61..3a2a88219818f 100644 --- a/core/src/main/resources/org/apache/spark/log4j-defaults.properties +++ b/core/src/main/resources/org/apache/spark/log4j-defaults.properties @@ -6,7 +6,7 @@ log4j.appender.console.layout=org.apache.log4j.PatternLayout log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n # Settings to quiet third party logs that are too verbose -log4j.logger.org.eclipse.jetty=WARN -log4j.logger.org.eclipse.jetty.util.component.AbstractLifeCycle=ERROR +log4j.logger.org.spark-project.jetty=WARN +log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 21c6e6ffa6666..9385f557c4614 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -17,10 +17,12 @@ package org.apache.spark +import java.util.concurrent.{Executors, TimeUnit} + import scala.collection.mutable import org.apache.spark.scheduler._ -import org.apache.spark.util.{SystemClock, Clock} +import org.apache.spark.util.{Clock, SystemClock, Utils} /** * An agent that dynamically allocates and removes executors based on the workload. @@ -129,6 +131,10 @@ private[spark] class ExecutorAllocationManager( // Listener for Spark events that impact the allocation policy private val listener = new ExecutorAllocationListener + // Executor that handles the scheduling task. + private val executor = Executors.newSingleThreadScheduledExecutor( + Utils.namedThreadFactory("spark-dynamic-executor-allocation")) + /** * Verify that the settings specified through the config are valid. * If not, throw an appropriate exception. @@ -173,32 +179,24 @@ private[spark] class ExecutorAllocationManager( } /** - * Register for scheduler callbacks to decide when to add and remove executors. + * Register for scheduler callbacks to decide when to add and remove executors, and start + * the scheduling task. */ def start(): Unit = { listenerBus.addListener(listener) - startPolling() + + val scheduleTask = new Runnable() { + override def run(): Unit = Utils.logUncaughtExceptions(schedule()) + } + executor.scheduleAtFixedRate(scheduleTask, 0, intervalMillis, TimeUnit.MILLISECONDS) } /** - * Start the main polling thread that keeps track of when to add and remove executors. + * Stop the allocation manager. */ - private def startPolling(): Unit = { - val t = new Thread { - override def run(): Unit = { - while (true) { - try { - schedule() - } catch { - case e: Exception => logError("Exception in dynamic executor allocation thread!", e) - } - Thread.sleep(intervalMillis) - } - } - } - t.setName("spark-dynamic-executor-allocation") - t.setDaemon(true) - t.start() + def stop(): Unit = { + executor.shutdown() + executor.awaitTermination(10, TimeUnit.SECONDS) } /** diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index 8435e1ea2611c..5871b8c869f03 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -17,15 +17,15 @@ package org.apache.spark -import scala.concurrent.duration._ -import scala.collection.mutable +import java.util.concurrent.{ScheduledFuture, TimeUnit, Executors} -import akka.actor.{Actor, Cancellable} +import scala.collection.mutable import org.apache.spark.executor.TaskMetrics +import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcEnv, RpcCallContext} import org.apache.spark.storage.BlockManagerId import org.apache.spark.scheduler.{SlaveLost, TaskScheduler} -import org.apache.spark.util.ActorLogReceive +import org.apache.spark.util.Utils /** * A heartbeat from executors to the driver. This is a shared message used by several internal @@ -37,6 +37,12 @@ private[spark] case class Heartbeat( taskMetrics: Array[(Long, TaskMetrics)], // taskId -> TaskMetrics blockManagerId: BlockManagerId) +/** + * An event that SparkContext uses to notify HeartbeatReceiver that SparkContext.taskScheduler is + * created. + */ +private[spark] case object TaskSchedulerIsSet + private[spark] case object ExpireDeadHosts private[spark] case class HeartbeatResponse(reregisterBlockManager: Boolean) @@ -44,8 +50,12 @@ private[spark] case class HeartbeatResponse(reregisterBlockManager: Boolean) /** * Lives in the driver to receive heartbeats from executors.. */ -private[spark] class HeartbeatReceiver(sc: SparkContext, scheduler: TaskScheduler) - extends Actor with ActorLogReceive with Logging { +private[spark] class HeartbeatReceiver(sc: SparkContext) + extends ThreadSafeRpcEndpoint with Logging { + + override val rpcEnv: RpcEnv = sc.env.rpcEnv + + private[spark] var scheduler: TaskScheduler = null // executor ID -> timestamp of when the last heartbeat from this executor was received private val executorLastSeen = new mutable.HashMap[String, Long] @@ -61,24 +71,44 @@ private[spark] class HeartbeatReceiver(sc: SparkContext, scheduler: TaskSchedule sc.conf.getOption("spark.network.timeoutInterval").map(_.toLong * 1000). getOrElse(sc.conf.getLong("spark.storage.blockManagerTimeoutIntervalMs", 60000)) - private var timeoutCheckingTask: Cancellable = null - - override def preStart(): Unit = { - import context.dispatcher - timeoutCheckingTask = context.system.scheduler.schedule(0.seconds, - checkTimeoutIntervalMs.milliseconds, self, ExpireDeadHosts) - super.preStart() + private var timeoutCheckingTask: ScheduledFuture[_] = null + + private val timeoutCheckingThread = Executors.newSingleThreadScheduledExecutor( + Utils.namedThreadFactory("heartbeat-timeout-checking-thread")) + + private val killExecutorThread = Executors.newSingleThreadExecutor( + Utils.namedThreadFactory("kill-executor-thread")) + + override def onStart(): Unit = { + timeoutCheckingTask = timeoutCheckingThread.scheduleAtFixedRate(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + Option(self).foreach(_.send(ExpireDeadHosts)) + } + }, 0, checkTimeoutIntervalMs, TimeUnit.MILLISECONDS) } - - override def receiveWithLogging: PartialFunction[Any, Unit] = { - case Heartbeat(executorId, taskMetrics, blockManagerId) => - val unknownExecutor = !scheduler.executorHeartbeatReceived( - executorId, taskMetrics, blockManagerId) - val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor) - executorLastSeen(executorId) = System.currentTimeMillis() - sender ! response + + override def receive: PartialFunction[Any, Unit] = { case ExpireDeadHosts => expireDeadHosts() + case TaskSchedulerIsSet => + scheduler = sc.taskScheduler + } + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case heartbeat @ Heartbeat(executorId, taskMetrics, blockManagerId) => + if (scheduler != null) { + val unknownExecutor = !scheduler.executorHeartbeatReceived( + executorId, taskMetrics, blockManagerId) + val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor) + executorLastSeen(executorId) = System.currentTimeMillis() + context.reply(response) + } else { + // Because Executor will sleep several seconds before sending the first "Heartbeat", this + // case rarely happens. However, if it really happens, log it and ask the executor to + // register itself again. + logWarning(s"Dropping $heartbeat because TaskScheduler is not ready yet") + context.reply(HeartbeatResponse(reregisterBlockManager = true)) + } } private def expireDeadHosts(): Unit = { @@ -91,17 +121,25 @@ private[spark] class HeartbeatReceiver(sc: SparkContext, scheduler: TaskSchedule scheduler.executorLost(executorId, SlaveLost("Executor heartbeat " + s"timed out after ${now - lastSeenMs} ms")) if (sc.supportDynamicAllocation) { - sc.killExecutor(executorId) + // Asynchronously kill the executor to avoid blocking the current thread + killExecutorThread.submit(new Runnable { + override def run(): Unit = sc.killExecutor(executorId) + }) } executorLastSeen.remove(executorId) } } } - override def postStop(): Unit = { + override def onStop(): Unit = { if (timeoutCheckingTask != null) { - timeoutCheckingTask.cancel() + timeoutCheckingTask.cancel(true) } - super.postStop() + timeoutCheckingThread.shutdownNow() + killExecutorThread.shutdownNow() } } + +object HeartbeatReceiver { + val ENDPOINT_NAME = "HeartbeatReceiver" +} diff --git a/core/src/main/scala/org/apache/spark/HttpServer.scala b/core/src/main/scala/org/apache/spark/HttpServer.scala index 09a9ccc226721..8de3a6c04df34 100644 --- a/core/src/main/scala/org/apache/spark/HttpServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpServer.scala @@ -160,7 +160,7 @@ private[spark] class HttpServer( throw new ServerStateException("Server is not started") } else { val scheme = if (securityManager.fileServerSSLOptions.enabled) "https" else "http" - s"$scheme://${Utils.localIpAddress}:$port" + s"$scheme://${Utils.localHostNameForURI()}:$port" } } } diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index c9426c5de23a2..d65c94e410662 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -21,13 +21,11 @@ import java.io._ import java.util.concurrent.ConcurrentHashMap import java.util.zip.{GZIPInputStream, GZIPOutputStream} -import scala.collection.mutable.{HashSet, HashMap, Map} -import scala.concurrent.Await +import scala.collection.mutable.{HashSet, Map} import scala.collection.JavaConversions._ +import scala.reflect.ClassTag -import akka.actor._ -import akka.pattern.ask - +import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, RpcCallContext, RpcEndpoint} import org.apache.spark.scheduler.MapStatus import org.apache.spark.shuffle.MetadataFetchFailedException import org.apache.spark.storage.BlockManagerId @@ -38,14 +36,15 @@ private[spark] case class GetMapOutputStatuses(shuffleId: Int) extends MapOutputTrackerMessage private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage -/** Actor class for MapOutputTrackerMaster */ -private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster, conf: SparkConf) - extends Actor with ActorLogReceive with Logging { +/** RpcEndpoint class for MapOutputTrackerMaster */ +private[spark] class MapOutputTrackerMasterEndpoint( + override val rpcEnv: RpcEnv, tracker: MapOutputTrackerMaster, conf: SparkConf) + extends RpcEndpoint with Logging { val maxAkkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf) - override def receiveWithLogging: PartialFunction[Any, Unit] = { + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case GetMapOutputStatuses(shuffleId: Int) => - val hostPort = sender.path.address.hostPort + val hostPort = context.sender.address.hostPort logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort) val mapOutputStatuses = tracker.getSerializedMapOutputStatuses(shuffleId) val serializedSize = mapOutputStatuses.size @@ -53,19 +52,19 @@ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster val msg = s"Map output statuses were $serializedSize bytes which " + s"exceeds spark.akka.frameSize ($maxAkkaFrameSize bytes)." - /* For SPARK-1244 we'll opt for just logging an error and then throwing an exception. - * Note that on exception the actor will just restart. A bigger refactoring (SPARK-1239) - * will ultimately remove this entire code path. */ + /* For SPARK-1244 we'll opt for just logging an error and then sending it to the sender. + * A bigger refactoring (SPARK-1239) will ultimately remove this entire code path. */ val exception = new SparkException(msg) logError(msg, exception) - throw exception + context.sendFailure(exception) + } else { + context.reply(mapOutputStatuses) } - sender ! mapOutputStatuses case StopMapOutputTracker => - logInfo("MapOutputTrackerActor stopped!") - sender ! true - context.stop(self) + logInfo("MapOutputTrackerMasterEndpoint stopped!") + context.reply(true) + stop() } } @@ -75,12 +74,9 @@ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster * (driver and executor) use different HashMap to store its metadata. */ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging { - private val timeout = AkkaUtils.askTimeout(conf) - private val retryAttempts = AkkaUtils.numRetries(conf) - private val retryIntervalMs = AkkaUtils.retryWaitMs(conf) - /** Set to the MapOutputTrackerActor living on the driver. */ - var trackerActor: ActorRef = _ + /** Set to the MapOutputTrackerMasterEndpoint living on the driver. */ + var trackerEndpoint: RpcEndpointRef = _ /** * This HashMap has different behavior for the driver and the executors. @@ -105,12 +101,12 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging private val fetching = new HashSet[Int] /** - * Send a message to the trackerActor and get its result within a default timeout, or + * Send a message to the trackerEndpoint and get its result within a default timeout, or * throw a SparkException if this fails. */ - protected def askTracker(message: Any): Any = { + protected def askTracker[T: ClassTag](message: Any): T = { try { - AkkaUtils.askWithReply(message, trackerActor, retryAttempts, retryIntervalMs, timeout) + trackerEndpoint.askWithReply[T](message) } catch { case e: Exception => logError("Error communicating with MapOutputTracker", e) @@ -118,9 +114,9 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging } } - /** Send a one-way message to the trackerActor, to which we expect it to reply with true. */ + /** Send a one-way message to the trackerEndpoint, to which we expect it to reply with true. */ protected def sendTracker(message: Any) { - val response = askTracker(message) + val response = askTracker[Boolean](message) if (response != true) { throw new SparkException( "Error reply received from MapOutputTracker. Expecting true, got " + response.toString) @@ -157,11 +153,10 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging if (fetchedStatuses == null) { // We won the race to fetch the output locs; do so - logInfo("Doing the fetch; tracker actor = " + trackerActor) + logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint) // This try-finally prevents hangs due to timeouts: try { - val fetchedBytes = - askTracker(GetMapOutputStatuses(shuffleId)).asInstanceOf[Array[Byte]] + val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId)) fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes) logInfo("Got the output locations") mapStatuses.put(shuffleId, fetchedStatuses) @@ -328,7 +323,7 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) override def stop() { sendTracker(StopMapOutputTracker) mapStatuses.clear() - trackerActor = null + trackerEndpoint = null metadataCleaner.cancel() cachedSerializedStatuses.clear() } @@ -350,17 +345,22 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr private[spark] object MapOutputTracker extends Logging { + val ENDPOINT_NAME = "MapOutputTracker" + // Serialize an array of map output locations into an efficient byte format so that we can send // it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will // generally be pretty compressible because many map outputs will be on the same hostname. def serializeMapStatuses(statuses: Array[MapStatus]): Array[Byte] = { val out = new ByteArrayOutputStream val objOut = new ObjectOutputStream(new GZIPOutputStream(out)) - // Since statuses can be modified in parallel, sync on it - statuses.synchronized { - objOut.writeObject(statuses) + Utils.tryWithSafeFinally { + // Since statuses can be modified in parallel, sync on it + statuses.synchronized { + objOut.writeObject(statuses) + } + } { + objOut.close() } - objOut.close() out.toByteArray } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index d7ad293dd2686..5919888c9a725 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -23,7 +23,7 @@ import java.io._ import java.lang.reflect.Constructor import java.net.URI import java.util.{Arrays, Properties, UUID} -import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger} import java.util.UUID.randomUUID import scala.collection.{Map, Set} @@ -32,8 +32,6 @@ import scala.collection.generic.Growable import scala.collection.mutable.HashMap import scala.reflect.{ClassTag, classTag} -import akka.actor.Props - import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.io.{ArrayWritable, BooleanWritable, BytesWritable, DoubleWritable, @@ -48,12 +46,13 @@ import org.apache.mesos.MesosNativeLibrary import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil} -import org.apache.spark.executor.TriggerThreadDump +import org.apache.spark.executor.{ExecutorEndpoint, TriggerThreadDump} import org.apache.spark.input.{StreamInputFormat, PortableDataStream, WholeTextFileInputFormat, FixedLengthBinaryInputFormat} import org.apache.spark.io.CompressionCodec import org.apache.spark.partial.{ApproximateEvaluator, PartialResult} import org.apache.spark.rdd._ +import org.apache.spark.rpc.RpcAddress import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SparkDeploySchedulerBackend, SimrSchedulerBackend} @@ -95,10 +94,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli val startTime = System.currentTimeMillis() - @volatile private var stopped: Boolean = false + private val stopped: AtomicBoolean = new AtomicBoolean(false) private def assertNotStopped(): Unit = { - if (stopped) { + if (stopped.get()) { throw new IllegalStateException("Cannot call methods on a stopped SparkContext") } } @@ -227,9 +226,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli val appName = conf.get("spark.app.name") private[spark] val isEventLogEnabled = conf.getBoolean("spark.eventLog.enabled", false) - private[spark] val eventLogDir: Option[String] = { + private[spark] val eventLogDir: Option[URI] = { if (isEventLogEnabled) { - Some(conf.get("spark.eventLog.dir", EventLoggingListener.DEFAULT_LOG_DIR).stripSuffix("/")) + val unresolvedDir = conf.get("spark.eventLog.dir", EventLoggingListener.DEFAULT_LOG_DIR) + .stripSuffix("/") + Some(Utils.resolveURI(unresolvedDir)) } else { None } @@ -356,11 +357,17 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli val sparkUser = Utils.getCurrentUserName() executorEnvs("SPARK_USER") = sparkUser + // We need to register "HeartbeatReceiver" before "createTaskScheduler" because Executor will + // retrieve "HeartbeatReceiver" in the constructor. (SPARK-6640) + private val heartbeatReceiver = env.rpcEnv.setupEndpoint( + HeartbeatReceiver.ENDPOINT_NAME, new HeartbeatReceiver(this)) + // Create and start the scheduler private[spark] var (schedulerBackend, taskScheduler) = SparkContext.createTaskScheduler(this, master) - private val heartbeatReceiver = env.actorSystem.actorOf( - Props(new HeartbeatReceiver(this, taskScheduler)), "HeartbeatReceiver") + + heartbeatReceiver.send(TaskSchedulerIsSet) + @volatile private[spark] var dagScheduler: DAGScheduler = _ try { dagScheduler = new DAGScheduler(this) @@ -433,6 +440,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Thread Local variable that can be used by users to pass information down the stack private val localProperties = new InheritableThreadLocal[Properties] { override protected def childValue(parent: Properties): Properties = new Properties(parent) + override protected def initialValue(): Properties = new Properties() } /** @@ -446,10 +454,12 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli if (executorId == SparkContext.DRIVER_IDENTIFIER) { Some(Utils.getThreadDump()) } else { - val (host, port) = env.blockManager.master.getActorSystemHostPortForExecutor(executorId).get - val actorRef = AkkaUtils.makeExecutorRef("ExecutorActor", conf, host, port, env.actorSystem) - Some(AkkaUtils.askWithReply[Array[ThreadStackTrace]](TriggerThreadDump, actorRef, - AkkaUtils.numRetries(conf), AkkaUtils.retryWaitMs(conf), AkkaUtils.askTimeout(conf))) + val (host, port) = env.blockManager.master.getRpcHostPortForExecutor(executorId).get + val endpointRef = env.rpcEnv.setupEndpointRef( + SparkEnv.executorActorSystemName, + RpcAddress(host, port), + ExecutorEndpoint.EXECUTOR_ENDPOINT_NAME) + Some(endpointRef.askWithReply[Array[ThreadStackTrace]](TriggerThreadDump)) } } catch { case e: Exception => @@ -474,9 +484,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * Spark fair scheduler pool. */ def setLocalProperty(key: String, value: String) { - if (localProperties.get() == null) { - localProperties.set(new Properties()) - } if (value == null) { localProperties.get.remove(key) } else { @@ -1138,7 +1145,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * Return whether dynamically adjusting the amount of resources allocated to * this application is supported. This is currently only available for YARN. */ - private[spark] def supportDynamicAllocation = + private[spark] def supportDynamicAllocation = master.contains("yarn") || dynamicAllocationTesting /** @@ -1392,32 +1399,34 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli addedJars.clear() } - /** Shut down the SparkContext. */ + // Shut down the SparkContext. def stop() { - SparkContext.SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { - if (!stopped) { - stopped = true - postApplicationEnd() - ui.foreach(_.stop()) - env.metricsSystem.report() - metadataCleaner.cancel() - cleaner.foreach(_.stop()) - dagScheduler.stop() - dagScheduler = null - listenerBus.stop() - eventLogger.foreach(_.stop()) - env.actorSystem.stop(heartbeatReceiver) - progressBar.foreach(_.stop()) - taskScheduler = null - // TODO: Cache.stop()? - env.stop() - SparkEnv.set(null) - logInfo("Successfully stopped SparkContext") - SparkContext.clearActiveContext() - } else { - logInfo("SparkContext already stopped") - } + // Use the stopping variable to ensure no contention for the stop scenario. + // Still track the stopped variable for use elsewhere in the code. + + if (!stopped.compareAndSet(false, true)) { + logInfo("SparkContext already stopped.") + return } + + postApplicationEnd() + ui.foreach(_.stop()) + env.metricsSystem.report() + metadataCleaner.cancel() + cleaner.foreach(_.stop()) + executorAllocationManager.foreach(_.stop()) + dagScheduler.stop() + dagScheduler = null + listenerBus.stop() + eventLogger.foreach(_.stop()) + env.rpcEnv.stop(heartbeatReceiver) + progressBar.foreach(_.stop()) + taskScheduler = null + // TODO: Cache.stop()? + env.stop() + SparkEnv.set(null) + SparkContext.clearActiveContext() + logInfo("Successfully stopped SparkContext") } @@ -1479,7 +1488,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli partitions: Seq[Int], allowLocal: Boolean, resultHandler: (Int, U) => Unit) { - if (stopped) { + if (stopped.get()) { throw new IllegalStateException("SparkContext has been shutdown") } val callSite = getCallSite @@ -1892,7 +1901,17 @@ object SparkContext extends Logging { private[spark] val SPARK_JOB_INTERRUPT_ON_CANCEL = "spark.job.interruptOnCancel" - private[spark] val DRIVER_IDENTIFIER = "" + /** + * Executor id for the driver. In earlier versions of Spark, this was ``, but this was + * changed to `driver` because the angle brackets caused escaping issues in URLs and XML (see + * SPARK-6716 for more details). + */ + private[spark] val DRIVER_IDENTIFIER = "driver" + + /** + * Legacy version of DRIVER_IDENTIFIER, retained for backwards-compatibility. + */ + private[spark] val LEGACY_DRIVER_IDENTIFIER = "" // The following deprecated objects have already been copied to `object AccumulatorParam` to // make the compiler find them automatically. They are duplicate codes only for backward diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 4a2ed82a40dec..0171488e09562 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -24,7 +24,6 @@ import scala.collection.JavaConversions._ import scala.collection.mutable import scala.util.Properties -import akka.actor._ import com.google.common.collect.MapMaker import org.apache.spark.annotation.DeveloperApi @@ -41,7 +40,7 @@ import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinato import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager} import org.apache.spark.storage._ -import org.apache.spark.util.{AkkaUtils, RpcUtils, Utils} +import org.apache.spark.util.{RpcUtils, Utils} /** * :: DeveloperApi :: @@ -286,16 +285,9 @@ object SparkEnv extends Logging { val closureSerializer = instantiateClassFromConf[Serializer]( "spark.closure.serializer", "org.apache.spark.serializer.JavaSerializer") - def registerOrLookup(name: String, newActor: => Actor): ActorRef = { - if (isDriver) { - logInfo("Registering " + name) - actorSystem.actorOf(Props(newActor), name = name) - } else { - AkkaUtils.makeDriverRef(name, conf, actorSystem) - } - } - - def registerOrLookupEndpoint(name: String, endpointCreator: => RpcEndpoint): RpcEndpointRef = { + def registerOrLookupEndpoint( + name: String, endpointCreator: => RpcEndpoint): + RpcEndpointRef = { if (isDriver) { logInfo("Registering " + name) rpcEnv.setupEndpoint(name, endpointCreator) @@ -312,9 +304,9 @@ object SparkEnv extends Logging { // Have to assign trackerActor after initialization as MapOutputTrackerActor // requires the MapOutputTracker itself - mapOutputTracker.trackerActor = registerOrLookup( - "MapOutputTracker", - new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf)) + mapOutputTracker.trackerEndpoint = registerOrLookupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint( + rpcEnv, mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf)) // Let the user specify short names for shuffle managers val shortShuffleMgrNames = Map( @@ -334,12 +326,13 @@ object SparkEnv extends Logging { new NioBlockTransferService(conf, securityManager) } - val blockManagerMaster = new BlockManagerMaster(registerOrLookup( - "BlockManagerMaster", - new BlockManagerMasterActor(isLocal, conf, listenerBus)), conf, isDriver) + val blockManagerMaster = new BlockManagerMaster(registerOrLookupEndpoint( + BlockManagerMaster.DRIVER_ENDPOINT_NAME, + new BlockManagerMasterEndpoint(rpcEnv, isLocal, conf, listenerBus)), + conf, isDriver) // NB: blockManager is not valid until initialize() is called later. - val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, + val blockManager = new BlockManager(executorId, rpcEnv, blockManagerMaster, serializer, conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager, numUsableCores) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index a023712be1166..8441bb3a3047e 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -661,7 +661,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) */ def flatMapValues[U](f: JFunction[V, java.lang.Iterable[U]]): JavaPairRDD[K, U] = { import scala.collection.JavaConverters._ - def fn = (x: V) => f.call(x).asScala + def fn: (V) => Iterable[U] = (x: V) => f.call(x).asScala implicit val ctag: ClassTag[U] = fakeClassTag fromRDD(rdd.flatMapValues(fn)) } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala index 18ccd625fc8d1..db4e996feb31c 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala @@ -192,7 +192,7 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) */ def sortBy[S](f: JFunction[T, S], ascending: Boolean, numPartitions: Int): JavaRDD[T] = { import scala.collection.JavaConverters._ - def fn = (x: T) => f.call(x) + def fn: (T) => S = (x: T) => f.call(x) import com.google.common.collect.Ordering // shadows scala.math.Ordering implicit val ordering = Ordering.natural().asInstanceOf[Ordering[S]] implicit val ctag: ClassTag[S] = fakeClassTag diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index 8da42934a7d96..8bf0627fc420d 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -17,8 +17,9 @@ package org.apache.spark.api.java -import java.util.{Comparator, List => JList, Iterator => JIterator} +import java.{lang => jl} import java.lang.{Iterable => JIterable, Long => JLong} +import java.util.{Comparator, List => JList, Iterator => JIterator} import scala.collection.JavaConversions._ import scala.collection.JavaConverters._ @@ -93,7 +94,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * of the original partition. */ def mapPartitionsWithIndex[R]( - f: JFunction2[java.lang.Integer, java.util.Iterator[T], java.util.Iterator[R]], + f: JFunction2[jl.Integer, java.util.Iterator[T], java.util.Iterator[R]], preservesPartitioning: Boolean = false): JavaRDD[R] = new JavaRDD(rdd.mapPartitionsWithIndex(((a,b) => f(a,asJavaIterator(b))), preservesPartitioning)(fakeClassTag))(fakeClassTag) @@ -109,7 +110,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Return a new RDD by applying a function to all elements of this RDD. */ def mapToPair[K2, V2](f: PairFunction[T, K2, V2]): JavaPairRDD[K2, V2] = { - def cm = implicitly[ClassTag[(K2, V2)]] + def cm: ClassTag[(K2, V2)] = implicitly[ClassTag[(K2, V2)]] new JavaPairRDD(rdd.map[(K2, V2)](f)(cm))(fakeClassTag[K2], fakeClassTag[V2]) } @@ -119,7 +120,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def flatMap[U](f: FlatMapFunction[T, U]): JavaRDD[U] = { import scala.collection.JavaConverters._ - def fn = (x: T) => f.call(x).asScala + def fn: (T) => Iterable[U] = (x: T) => f.call(x).asScala JavaRDD.fromRDD(rdd.flatMap(fn)(fakeClassTag[U]))(fakeClassTag[U]) } @@ -129,8 +130,8 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def flatMapToDouble(f: DoubleFlatMapFunction[T]): JavaDoubleRDD = { import scala.collection.JavaConverters._ - def fn = (x: T) => f.call(x).asScala - new JavaDoubleRDD(rdd.flatMap(fn).map((x: java.lang.Double) => x.doubleValue())) + def fn: (T) => Iterable[jl.Double] = (x: T) => f.call(x).asScala + new JavaDoubleRDD(rdd.flatMap(fn).map((x: jl.Double) => x.doubleValue())) } /** @@ -139,8 +140,8 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def flatMapToPair[K2, V2](f: PairFlatMapFunction[T, K2, V2]): JavaPairRDD[K2, V2] = { import scala.collection.JavaConverters._ - def fn = (x: T) => f.call(x).asScala - def cm = implicitly[ClassTag[(K2, V2)]] + def fn: (T) => Iterable[(K2, V2)] = (x: T) => f.call(x).asScala + def cm: ClassTag[(K2, V2)] = implicitly[ClassTag[(K2, V2)]] JavaPairRDD.fromRDD(rdd.flatMap(fn)(cm))(fakeClassTag[K2], fakeClassTag[V2]) } @@ -148,7 +149,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Return a new RDD by applying a function to each partition of this RDD. */ def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U]): JavaRDD[U] = { - def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + def fn: (Iterator[T]) => Iterator[U] = { + (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + } JavaRDD.fromRDD(rdd.mapPartitions(fn)(fakeClassTag[U]))(fakeClassTag[U]) } @@ -157,7 +160,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U], preservesPartitioning: Boolean): JavaRDD[U] = { - def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + def fn: (Iterator[T]) => Iterator[U] = { + (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + } JavaRDD.fromRDD( rdd.mapPartitions(fn, preservesPartitioning)(fakeClassTag[U]))(fakeClassTag[U]) } @@ -166,8 +171,10 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Return a new RDD by applying a function to each partition of this RDD. */ def mapPartitionsToDouble(f: DoubleFlatMapFunction[java.util.Iterator[T]]): JavaDoubleRDD = { - def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) - new JavaDoubleRDD(rdd.mapPartitions(fn).map((x: java.lang.Double) => x.doubleValue())) + def fn: (Iterator[T]) => Iterator[jl.Double] = { + (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + } + new JavaDoubleRDD(rdd.mapPartitions(fn).map((x: jl.Double) => x.doubleValue())) } /** @@ -175,7 +182,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def mapPartitionsToPair[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2]): JavaPairRDD[K2, V2] = { - def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + def fn: (Iterator[T]) => Iterator[(K2, V2)] = { + (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + } JavaPairRDD.fromRDD(rdd.mapPartitions(fn))(fakeClassTag[K2], fakeClassTag[V2]) } @@ -184,7 +193,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def mapPartitionsToDouble(f: DoubleFlatMapFunction[java.util.Iterator[T]], preservesPartitioning: Boolean): JavaDoubleRDD = { - def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + def fn: (Iterator[T]) => Iterator[jl.Double] = { + (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + } new JavaDoubleRDD(rdd.mapPartitions(fn, preservesPartitioning) .map(x => x.doubleValue())) } @@ -194,7 +205,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def mapPartitionsToPair[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2], preservesPartitioning: Boolean): JavaPairRDD[K2, V2] = { - def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + def fn: (Iterator[T]) => Iterator[(K2, V2)] = { + (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + } JavaPairRDD.fromRDD( rdd.mapPartitions(fn, preservesPartitioning))(fakeClassTag[K2], fakeClassTag[V2]) } @@ -277,8 +290,10 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def zipPartitions[U, V]( other: JavaRDDLike[U, _], f: FlatMapFunction2[java.util.Iterator[T], java.util.Iterator[U], V]): JavaRDD[V] = { - def fn = (x: Iterator[T], y: Iterator[U]) => asScalaIterator( - f.call(asJavaIterator(x), asJavaIterator(y)).iterator()) + def fn: (Iterator[T], Iterator[U]) => Iterator[V] = { + (x: Iterator[T], y: Iterator[U]) => asScalaIterator( + f.call(asJavaIterator(x), asJavaIterator(y)).iterator()) + } JavaRDD.fromRDD( rdd.zipPartitions(other.rdd)(fn)(other.classTag, fakeClassTag[V]))(fakeClassTag[V]) } @@ -441,8 +456,8 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Return the count of each unique value in this RDD as a map of (value, count) pairs. The final * combine step happens locally on the master, equivalent to running a single reduce task. */ - def countByValue(): java.util.Map[T, java.lang.Long] = - mapAsSerializableJavaMap(rdd.countByValue().map((x => (x._1, new java.lang.Long(x._2))))) + def countByValue(): java.util.Map[T, jl.Long] = + mapAsSerializableJavaMap(rdd.countByValue().map((x => (x._1, new jl.Long(x._2))))) /** * (Experimental) Approximate version of countByValue(). diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 19f4c95fcad74..b1ffba4c546bf 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -605,7 +605,6 @@ private[spark] object PythonRDD extends Logging { */ private def serveIterator[T](items: Iterator[T], threadName: String): Int = { val serverSocket = new ServerSocket(0, 1) - serverSocket.setReuseAddress(true) // Close the socket if no connection in 3 seconds serverSocket.setSoTimeout(3000) @@ -615,9 +614,9 @@ private[spark] object PythonRDD extends Logging { try { val sock = serverSocket.accept() val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) - try { + Utils.tryWithSafeFinally { writeIteratorToStream(items, out) - } finally { + } { out.close() } } catch { @@ -863,9 +862,9 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial val file = File.createTempFile("broadcast", "", dir) path = file.getAbsolutePath val out = new FileOutputStream(file) - try { + Utils.tryWithSafeFinally { Utils.copyStream(in, out) - } finally { + } { out.close() } } diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala new file mode 100644 index 0000000000000..3a2c94bd9d875 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.r + +import java.io.{DataOutputStream, File, FileOutputStream, IOException} +import java.net.{InetSocketAddress, ServerSocket} +import java.util.concurrent.TimeUnit + +import io.netty.bootstrap.ServerBootstrap +import io.netty.channel.{ChannelFuture, ChannelInitializer, EventLoopGroup} +import io.netty.channel.nio.NioEventLoopGroup +import io.netty.channel.socket.SocketChannel +import io.netty.channel.socket.nio.NioServerSocketChannel +import io.netty.handler.codec.LengthFieldBasedFrameDecoder +import io.netty.handler.codec.bytes.{ByteArrayDecoder, ByteArrayEncoder} + +import org.apache.spark.Logging + +/** + * Netty-based backend server that is used to communicate between R and Java. + */ +private[spark] class RBackend { + + private[this] var channelFuture: ChannelFuture = null + private[this] var bootstrap: ServerBootstrap = null + private[this] var bossGroup: EventLoopGroup = null + + def init(): Int = { + bossGroup = new NioEventLoopGroup(2) + val workerGroup = bossGroup + val handler = new RBackendHandler(this) + + bootstrap = new ServerBootstrap() + .group(bossGroup, workerGroup) + .channel(classOf[NioServerSocketChannel]) + + bootstrap.childHandler(new ChannelInitializer[SocketChannel]() { + def initChannel(ch: SocketChannel): Unit = { + ch.pipeline() + .addLast("encoder", new ByteArrayEncoder()) + .addLast("frameDecoder", + // maxFrameLength = 2G + // lengthFieldOffset = 0 + // lengthFieldLength = 4 + // lengthAdjustment = 0 + // initialBytesToStrip = 4, i.e. strip out the length field itself + new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4)) + .addLast("decoder", new ByteArrayDecoder()) + .addLast("handler", handler) + } + }) + + channelFuture = bootstrap.bind(new InetSocketAddress(0)) + channelFuture.syncUninterruptibly() + channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort() + } + + def run(): Unit = { + channelFuture.channel.closeFuture().syncUninterruptibly() + } + + def close(): Unit = { + if (channelFuture != null) { + // close is a local operation and should finish within milliseconds; timeout just to be safe + channelFuture.channel().close().awaitUninterruptibly(10, TimeUnit.SECONDS) + channelFuture = null + } + if (bootstrap != null && bootstrap.group() != null) { + bootstrap.group().shutdownGracefully() + } + if (bootstrap != null && bootstrap.childGroup() != null) { + bootstrap.childGroup().shutdownGracefully() + } + bootstrap = null + } + +} + +private[spark] object RBackend extends Logging { + def main(args: Array[String]): Unit = { + if (args.length < 1) { + System.err.println("Usage: RBackend ") + System.exit(-1) + } + val sparkRBackend = new RBackend() + try { + // bind to random port + val boundPort = sparkRBackend.init() + val serverSocket = new ServerSocket(0, 1) + val listenPort = serverSocket.getLocalPort() + + // tell the R process via temporary file + val path = args(0) + val f = new File(path + ".tmp") + val dos = new DataOutputStream(new FileOutputStream(f)) + dos.writeInt(boundPort) + dos.writeInt(listenPort) + dos.close() + f.renameTo(new File(path)) + + // wait for the end of stdin, then exit + new Thread("wait for socket to close") { + setDaemon(true) + override def run(): Unit = { + // any un-catched exception will also shutdown JVM + val buf = new Array[Byte](1024) + // shutdown JVM if R does not connect back in 10 seconds + serverSocket.setSoTimeout(10000) + try { + val inSocket = serverSocket.accept() + serverSocket.close() + // wait for the end of socket, closed if R process die + inSocket.getInputStream().read(buf) + } finally { + sparkRBackend.close() + System.exit(0) + } + } + }.start() + + sparkRBackend.run() + } catch { + case e: IOException => + logError("Server shutting down: failed with exception ", e) + sparkRBackend.close() + System.exit(1) + } + System.exit(0) + } +} diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala new file mode 100644 index 0000000000000..0075d963711f1 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.r + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} + +import scala.collection.mutable.HashMap + +import io.netty.channel.ChannelHandler.Sharable +import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} + +import org.apache.spark.Logging +import org.apache.spark.api.r.SerDe._ + +/** + * Handler for RBackend + * TODO: This is marked as sharable to get a handle to RBackend. Is it safe to re-use + * this across connections ? + */ +@Sharable +private[r] class RBackendHandler(server: RBackend) + extends SimpleChannelInboundHandler[Array[Byte]] with Logging { + + override def channelRead0(ctx: ChannelHandlerContext, msg: Array[Byte]): Unit = { + val bis = new ByteArrayInputStream(msg) + val dis = new DataInputStream(bis) + + val bos = new ByteArrayOutputStream() + val dos = new DataOutputStream(bos) + + // First bit is isStatic + val isStatic = readBoolean(dis) + val objId = readString(dis) + val methodName = readString(dis) + val numArgs = readInt(dis) + + if (objId == "SparkRHandler") { + methodName match { + case "stopBackend" => + writeInt(dos, 0) + writeType(dos, "void") + server.close() + case "rm" => + try { + val t = readObjectType(dis) + assert(t == 'c') + val objToRemove = readString(dis) + JVMObjectTracker.remove(objToRemove) + writeInt(dos, 0) + writeObject(dos, null) + } catch { + case e: Exception => + logError(s"Removing $objId failed", e) + writeInt(dos, -1) + } + case _ => dos.writeInt(-1) + } + } else { + handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos) + } + + val reply = bos.toByteArray + ctx.write(reply) + } + + override def channelReadComplete(ctx: ChannelHandlerContext): Unit = { + ctx.flush() + } + + override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { + // Close the connection when an exception is raised. + cause.printStackTrace() + ctx.close() + } + + def handleMethodCall( + isStatic: Boolean, + objId: String, + methodName: String, + numArgs: Int, + dis: DataInputStream, + dos: DataOutputStream): Unit = { + var obj: Object = null + try { + val cls = if (isStatic) { + Class.forName(objId) + } else { + JVMObjectTracker.get(objId) match { + case None => throw new IllegalArgumentException("Object not found " + objId) + case Some(o) => + obj = o + o.getClass + } + } + + val args = readArgs(numArgs, dis) + + val methods = cls.getMethods + val selectedMethods = methods.filter(m => m.getName == methodName) + if (selectedMethods.length > 0) { + val methods = selectedMethods.filter { x => + matchMethod(numArgs, args, x.getParameterTypes) + } + if (methods.isEmpty) { + logWarning(s"cannot find matching method ${cls}.$methodName. " + + s"Candidates are:") + selectedMethods.foreach { method => + logWarning(s"$methodName(${method.getParameterTypes.mkString(",")})") + } + throw new Exception(s"No matched method found for $cls.$methodName") + } + val ret = methods.head.invoke(obj, args:_*) + + // Write status bit + writeInt(dos, 0) + writeObject(dos, ret.asInstanceOf[AnyRef]) + } else if (methodName == "") { + // methodName should be "" for constructor + val ctor = cls.getConstructors.filter { x => + matchMethod(numArgs, args, x.getParameterTypes) + }.head + + val obj = ctor.newInstance(args:_*) + + writeInt(dos, 0) + writeObject(dos, obj.asInstanceOf[AnyRef]) + } else { + throw new IllegalArgumentException("invalid method " + methodName + " for object " + objId) + } + } catch { + case e: Exception => + logError(s"$methodName on $objId failed", e) + writeInt(dos, -1) + } + } + + // Read a number of arguments from the data input stream + def readArgs(numArgs: Int, dis: DataInputStream): Array[java.lang.Object] = { + (0 until numArgs).map { arg => + readObject(dis) + }.toArray + } + + // Checks if the arguments passed in args matches the parameter types. + // NOTE: Currently we do exact match. We may add type conversions later. + def matchMethod( + numArgs: Int, + args: Array[java.lang.Object], + parameterTypes: Array[Class[_]]): Boolean = { + if (parameterTypes.length != numArgs) { + return false + } + + for (i <- 0 to numArgs - 1) { + val parameterType = parameterTypes(i) + var parameterWrapperType = parameterType + + // Convert native parameters to Object types as args is Array[Object] here + if (parameterType.isPrimitive) { + parameterWrapperType = parameterType match { + case java.lang.Integer.TYPE => classOf[java.lang.Integer] + case java.lang.Double.TYPE => classOf[java.lang.Double] + case java.lang.Boolean.TYPE => classOf[java.lang.Boolean] + case _ => parameterType + } + } + if (!parameterWrapperType.isInstance(args(i))) { + return false + } + } + true + } +} + +/** + * Helper singleton that tracks JVM objects returned to R. + * This is useful for referencing these objects in RPC calls. + */ +private[r] object JVMObjectTracker { + + // TODO: This map should be thread-safe if we want to support multiple + // connections at the same time + private[this] val objMap = new HashMap[String, Object] + + // TODO: We support only one connection now, so an integer is fine. + // Investigate using use atomic integer in the future. + private[this] var objCounter: Int = 0 + + def getObject(id: String): Object = { + objMap(id) + } + + def get(id: String): Option[Object] = { + objMap.get(id) + } + + def put(obj: Object): String = { + val objId = objCounter.toString + objCounter = objCounter + 1 + objMap.put(objId, obj) + objId + } + + def remove(id: String): Option[Object] = { + objMap.remove(id) + } + +} diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala new file mode 100644 index 0000000000000..5fa4d483b8342 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala @@ -0,0 +1,450 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.r + +import java.io._ +import java.net.ServerSocket +import java.util.{Map => JMap} + +import scala.collection.JavaConversions._ +import scala.io.Source +import scala.reflect.ClassTag +import scala.util.Try + +import org.apache.spark._ +import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext} +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils + +private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( + parent: RDD[T], + numPartitions: Int, + func: Array[Byte], + deserializer: String, + serializer: String, + packageNames: Array[Byte], + rLibDir: String, + broadcastVars: Array[Broadcast[Object]]) + extends RDD[U](parent) with Logging { + override def getPartitions: Array[Partition] = parent.partitions + + override def compute(partition: Partition, context: TaskContext): Iterator[U] = { + + // The parent may be also an RRDD, so we should launch it first. + val parentIterator = firstParent[T].iterator(partition, context) + + // we expect two connections + val serverSocket = new ServerSocket(0, 2) + val listenPort = serverSocket.getLocalPort() + + // The stdout/stderr is shared by multiple tasks, because we use one daemon + // to launch child process as worker. + val errThread = RRDD.createRWorker(rLibDir, listenPort) + + // We use two sockets to separate input and output, then it's easy to manage + // the lifecycle of them to avoid deadlock. + // TODO: optimize it to use one socket + + // the socket used to send out the input of task + serverSocket.setSoTimeout(10000) + val inSocket = serverSocket.accept() + startStdinThread(inSocket.getOutputStream(), parentIterator, partition.index) + + // the socket used to receive the output of task + val outSocket = serverSocket.accept() + val inputStream = new BufferedInputStream(outSocket.getInputStream) + val dataStream = openDataStream(inputStream) + serverSocket.close() + + try { + + return new Iterator[U] { + def next(): U = { + val obj = _nextObj + if (hasNext) { + _nextObj = read() + } + obj + } + + var _nextObj = read() + + def hasNext(): Boolean = { + val hasMore = (_nextObj != null) + if (!hasMore) { + dataStream.close() + } + hasMore + } + } + } catch { + case e: Exception => + throw new SparkException("R computation failed with\n " + errThread.getLines()) + } + } + + /** + * Start a thread to write RDD data to the R process. + */ + private def startStdinThread[T]( + output: OutputStream, + iter: Iterator[T], + partition: Int): Unit = { + + val env = SparkEnv.get + val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt + val stream = new BufferedOutputStream(output, bufferSize) + + new Thread("writer for R") { + override def run(): Unit = { + try { + SparkEnv.set(env) + val dataOut = new DataOutputStream(stream) + dataOut.writeInt(partition) + + SerDe.writeString(dataOut, deserializer) + SerDe.writeString(dataOut, serializer) + + dataOut.writeInt(packageNames.length) + dataOut.write(packageNames) + + dataOut.writeInt(func.length) + dataOut.write(func) + + dataOut.writeInt(broadcastVars.length) + broadcastVars.foreach { broadcast => + // TODO(shivaram): Read a Long in R to avoid this cast + dataOut.writeInt(broadcast.id.toInt) + // TODO: Pass a byte array from R to avoid this cast ? + val broadcastByteArr = broadcast.value.asInstanceOf[Array[Byte]] + dataOut.writeInt(broadcastByteArr.length) + dataOut.write(broadcastByteArr) + } + + dataOut.writeInt(numPartitions) + + if (!iter.hasNext) { + dataOut.writeInt(0) + } else { + dataOut.writeInt(1) + } + + val printOut = new PrintStream(stream) + + def writeElem(elem: Any): Unit = { + if (deserializer == SerializationFormats.BYTE) { + val elemArr = elem.asInstanceOf[Array[Byte]] + dataOut.writeInt(elemArr.length) + dataOut.write(elemArr) + } else if (deserializer == SerializationFormats.ROW) { + dataOut.write(elem.asInstanceOf[Array[Byte]]) + } else if (deserializer == SerializationFormats.STRING) { + printOut.println(elem) + } + } + + for (elem <- iter) { + elem match { + case (key, value) => + writeElem(key) + writeElem(value) + case _ => + writeElem(elem) + } + } + stream.flush() + } catch { + // TODO: We should propogate this error to the task thread + case e: Exception => + logError("R Writer thread got an exception", e) + } finally { + Try(output.close()) + } + } + }.start() + } + + protected def openDataStream(input: InputStream): Closeable + + protected def read(): U +} + +/** + * Form an RDD[(Int, Array[Byte])] from key-value pairs returned from R. + * This is used by SparkR's shuffle operations. + */ +private class PairwiseRRDD[T: ClassTag]( + parent: RDD[T], + numPartitions: Int, + hashFunc: Array[Byte], + deserializer: String, + packageNames: Array[Byte], + rLibDir: String, + broadcastVars: Array[Object]) + extends BaseRRDD[T, (Int, Array[Byte])]( + parent, numPartitions, hashFunc, deserializer, + SerializationFormats.BYTE, packageNames, rLibDir, + broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { + + private var dataStream: DataInputStream = _ + + override protected def openDataStream(input: InputStream): Closeable = { + dataStream = new DataInputStream(input) + dataStream + } + + override protected def read(): (Int, Array[Byte]) = { + try { + val length = dataStream.readInt() + + length match { + case length if length == 2 => + val hashedKey = dataStream.readInt() + val contentPairsLength = dataStream.readInt() + val contentPairs = new Array[Byte](contentPairsLength) + dataStream.readFully(contentPairs) + (hashedKey, contentPairs) + case _ => null // End of input + } + } catch { + case eof: EOFException => { + throw new SparkException("R worker exited unexpectedly (crashed)", eof) + } + } + } + + lazy val asJavaPairRDD : JavaPairRDD[Int, Array[Byte]] = JavaPairRDD.fromRDD(this) +} + +/** + * An RDD that stores serialized R objects as Array[Byte]. + */ +private class RRDD[T: ClassTag]( + parent: RDD[T], + func: Array[Byte], + deserializer: String, + serializer: String, + packageNames: Array[Byte], + rLibDir: String, + broadcastVars: Array[Object]) + extends BaseRRDD[T, Array[Byte]]( + parent, -1, func, deserializer, serializer, packageNames, rLibDir, + broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { + + private var dataStream: DataInputStream = _ + + override protected def openDataStream(input: InputStream): Closeable = { + dataStream = new DataInputStream(input) + dataStream + } + + override protected def read(): Array[Byte] = { + try { + val length = dataStream.readInt() + + length match { + case length if length > 0 => + val obj = new Array[Byte](length) + dataStream.readFully(obj, 0, length) + obj + case _ => null + } + } catch { + case eof: EOFException => { + throw new SparkException("R worker exited unexpectedly (crashed)", eof) + } + } + } + + lazy val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) +} + +/** + * An RDD that stores R objects as Array[String]. + */ +private class StringRRDD[T: ClassTag]( + parent: RDD[T], + func: Array[Byte], + deserializer: String, + packageNames: Array[Byte], + rLibDir: String, + broadcastVars: Array[Object]) + extends BaseRRDD[T, String]( + parent, -1, func, deserializer, SerializationFormats.STRING, packageNames, rLibDir, + broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { + + private var dataStream: BufferedReader = _ + + override protected def openDataStream(input: InputStream): Closeable = { + dataStream = new BufferedReader(new InputStreamReader(input)) + dataStream + } + + override protected def read(): String = { + try { + dataStream.readLine() + } catch { + case e: IOException => { + throw new SparkException("R worker exited unexpectedly (crashed)", e) + } + } + } + + lazy val asJavaRDD : JavaRDD[String] = JavaRDD.fromRDD(this) +} + +private[r] class BufferedStreamThread( + in: InputStream, + name: String, + errBufferSize: Int) extends Thread(name) with Logging { + val lines = new Array[String](errBufferSize) + var lineIdx = 0 + override def run() { + for (line <- Source.fromInputStream(in).getLines) { + synchronized { + lines(lineIdx) = line + lineIdx = (lineIdx + 1) % errBufferSize + } + logInfo(line) + } + } + + def getLines(): String = synchronized { + (0 until errBufferSize).filter { x => + lines((x + lineIdx) % errBufferSize) != null + }.map { x => + lines((x + lineIdx) % errBufferSize) + }.mkString("\n") + } +} + +private[r] object RRDD { + // Because forking processes from Java is expensive, we prefer to launch + // a single R daemon (daemon.R) and tell it to fork new workers for our tasks. + // This daemon currently only works on UNIX-based systems now, so we should + // also fall back to launching workers (worker.R) directly. + private[this] var errThread: BufferedStreamThread = _ + private[this] var daemonChannel: DataOutputStream = _ + + def createSparkContext( + master: String, + appName: String, + sparkHome: String, + jars: Array[String], + sparkEnvirMap: JMap[Object, Object], + sparkExecutorEnvMap: JMap[Object, Object]): JavaSparkContext = { + + val sparkConf = new SparkConf().setAppName(appName) + .setSparkHome(sparkHome) + .setJars(jars) + + // Override `master` if we have a user-specified value + if (master != "") { + sparkConf.setMaster(master) + } else { + // If conf has no master set it to "local" to maintain + // backwards compatibility + sparkConf.setIfMissing("spark.master", "local") + } + + for ((name, value) <- sparkEnvirMap) { + sparkConf.set(name.asInstanceOf[String], value.asInstanceOf[String]) + } + for ((name, value) <- sparkExecutorEnvMap) { + sparkConf.setExecutorEnv(name.asInstanceOf[String], value.asInstanceOf[String]) + } + + new JavaSparkContext(sparkConf) + } + + /** + * Start a thread to print the process's stderr to ours + */ + private def startStdoutThread(proc: Process): BufferedStreamThread = { + val BUFFER_SIZE = 100 + val thread = new BufferedStreamThread(proc.getInputStream, "stdout reader for R", BUFFER_SIZE) + thread.setDaemon(true) + thread.start() + thread + } + + private def createRProcess(rLibDir: String, port: Int, script: String): BufferedStreamThread = { + val rCommand = "Rscript" + val rOptions = "--vanilla" + val rExecScript = rLibDir + "/SparkR/worker/" + script + val pb = new ProcessBuilder(List(rCommand, rOptions, rExecScript)) + // Unset the R_TESTS environment variable for workers. + // This is set by R CMD check as startup.Rs + // (http://svn.r-project.org/R/trunk/src/library/tools/R/testing.R) + // and confuses worker script which tries to load a non-existent file + pb.environment().put("R_TESTS", "") + pb.environment().put("SPARKR_RLIBDIR", rLibDir) + pb.environment().put("SPARKR_WORKER_PORT", port.toString) + pb.redirectErrorStream(true) // redirect stderr into stdout + val proc = pb.start() + val errThread = startStdoutThread(proc) + errThread + } + + /** + * ProcessBuilder used to launch worker R processes. + */ + def createRWorker(rLibDir: String, port: Int): BufferedStreamThread = { + val useDaemon = SparkEnv.get.conf.getBoolean("spark.sparkr.use.daemon", true) + if (!Utils.isWindows && useDaemon) { + synchronized { + if (daemonChannel == null) { + // we expect one connections + val serverSocket = new ServerSocket(0, 1) + val daemonPort = serverSocket.getLocalPort + errThread = createRProcess(rLibDir, daemonPort, "daemon.R") + // the socket used to send out the input of task + serverSocket.setSoTimeout(10000) + val sock = serverSocket.accept() + daemonChannel = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) + serverSocket.close() + } + try { + daemonChannel.writeInt(port) + daemonChannel.flush() + } catch { + case e: IOException => + // daemon process died + daemonChannel.close() + daemonChannel = null + errThread = null + // fail the current task, retry by scheduler + throw e + } + errThread + } + } else { + createRProcess(rLibDir, port, "worker.R") + } + } + + /** + * Create an RRDD given a sequence of byte arrays. Used to create RRDD when `parallelize` is + * called from R. + */ + def createRDDFromArray(jsc: JavaSparkContext, arr: Array[Array[Byte]]): JavaRDD[Array[Byte]] = { + JavaRDD.fromRDD(jsc.sc.parallelize(arr, arr.length)) + } + +} diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala new file mode 100644 index 0000000000000..ccb2a371f4e48 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala @@ -0,0 +1,340 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.r + +import java.io.{DataInputStream, DataOutputStream} +import java.sql.{Date, Time} + +import scala.collection.JavaConversions._ + +/** + * Utility functions to serialize, deserialize objects to / from R + */ +private[spark] object SerDe { + + // Type mapping from R to Java + // + // NULL -> void + // integer -> Int + // character -> String + // logical -> Boolean + // double, numeric -> Double + // raw -> Array[Byte] + // Date -> Date + // POSIXlt/POSIXct -> Time + // + // list[T] -> Array[T], where T is one of above mentioned types + // environment -> Map[String, T], where T is a native type + // jobj -> Object, where jobj is an object created in the backend + + def readObjectType(dis: DataInputStream): Char = { + dis.readByte().toChar + } + + def readObject(dis: DataInputStream): Object = { + val dataType = readObjectType(dis) + readTypedObject(dis, dataType) + } + + def readTypedObject( + dis: DataInputStream, + dataType: Char): Object = { + dataType match { + case 'n' => null + case 'i' => new java.lang.Integer(readInt(dis)) + case 'd' => new java.lang.Double(readDouble(dis)) + case 'b' => new java.lang.Boolean(readBoolean(dis)) + case 'c' => readString(dis) + case 'e' => readMap(dis) + case 'r' => readBytes(dis) + case 'l' => readList(dis) + case 'D' => readDate(dis) + case 't' => readTime(dis) + case 'j' => JVMObjectTracker.getObject(readString(dis)) + case _ => throw new IllegalArgumentException(s"Invalid type $dataType") + } + } + + def readBytes(in: DataInputStream): Array[Byte] = { + val len = readInt(in) + val out = new Array[Byte](len) + val bytesRead = in.readFully(out) + out + } + + def readInt(in: DataInputStream): Int = { + in.readInt() + } + + def readDouble(in: DataInputStream): Double = { + in.readDouble() + } + + def readString(in: DataInputStream): String = { + val len = in.readInt() + val asciiBytes = new Array[Byte](len) + in.readFully(asciiBytes) + assert(asciiBytes(len - 1) == 0) + val str = new String(asciiBytes.dropRight(1).map(_.toChar)) + str + } + + def readBoolean(in: DataInputStream): Boolean = { + val intVal = in.readInt() + if (intVal == 0) false else true + } + + def readDate(in: DataInputStream): Date = { + Date.valueOf(readString(in)) + } + + def readTime(in: DataInputStream): Time = { + val t = in.readDouble() + new Time((t * 1000L).toLong) + } + + def readBytesArr(in: DataInputStream): Array[Array[Byte]] = { + val len = readInt(in) + (0 until len).map(_ => readBytes(in)).toArray + } + + def readIntArr(in: DataInputStream): Array[Int] = { + val len = readInt(in) + (0 until len).map(_ => readInt(in)).toArray + } + + def readDoubleArr(in: DataInputStream): Array[Double] = { + val len = readInt(in) + (0 until len).map(_ => readDouble(in)).toArray + } + + def readBooleanArr(in: DataInputStream): Array[Boolean] = { + val len = readInt(in) + (0 until len).map(_ => readBoolean(in)).toArray + } + + def readStringArr(in: DataInputStream): Array[String] = { + val len = readInt(in) + (0 until len).map(_ => readString(in)).toArray + } + + def readList(dis: DataInputStream): Array[_] = { + val arrType = readObjectType(dis) + arrType match { + case 'i' => readIntArr(dis) + case 'c' => readStringArr(dis) + case 'd' => readDoubleArr(dis) + case 'b' => readBooleanArr(dis) + case 'j' => readStringArr(dis).map(x => JVMObjectTracker.getObject(x)) + case 'r' => readBytesArr(dis) + case _ => throw new IllegalArgumentException(s"Invalid array type $arrType") + } + } + + def readMap(in: DataInputStream): java.util.Map[Object, Object] = { + val len = readInt(in) + if (len > 0) { + val keysType = readObjectType(in) + val keysLen = readInt(in) + val keys = (0 until keysLen).map(_ => readTypedObject(in, keysType)) + + val valuesType = readObjectType(in) + val valuesLen = readInt(in) + val values = (0 until valuesLen).map(_ => readTypedObject(in, valuesType)) + mapAsJavaMap(keys.zip(values).toMap) + } else { + new java.util.HashMap[Object, Object]() + } + } + + // Methods to write out data from Java to R + // + // Type mapping from Java to R + // + // void -> NULL + // Int -> integer + // String -> character + // Boolean -> logical + // Double -> double + // Long -> double + // Array[Byte] -> raw + // Date -> Date + // Time -> POSIXct + // + // Array[T] -> list() + // Object -> jobj + + def writeType(dos: DataOutputStream, typeStr: String): Unit = { + typeStr match { + case "void" => dos.writeByte('n') + case "character" => dos.writeByte('c') + case "double" => dos.writeByte('d') + case "integer" => dos.writeByte('i') + case "logical" => dos.writeByte('b') + case "date" => dos.writeByte('D') + case "time" => dos.writeByte('t') + case "raw" => dos.writeByte('r') + case "list" => dos.writeByte('l') + case "jobj" => dos.writeByte('j') + case _ => throw new IllegalArgumentException(s"Invalid type $typeStr") + } + } + + def writeObject(dos: DataOutputStream, value: Object): Unit = { + if (value == null) { + writeType(dos, "void") + } else { + value.getClass.getName match { + case "java.lang.String" => + writeType(dos, "character") + writeString(dos, value.asInstanceOf[String]) + case "long" | "java.lang.Long" => + writeType(dos, "double") + writeDouble(dos, value.asInstanceOf[Long].toDouble) + case "double" | "java.lang.Double" => + writeType(dos, "double") + writeDouble(dos, value.asInstanceOf[Double]) + case "int" | "java.lang.Integer" => + writeType(dos, "integer") + writeInt(dos, value.asInstanceOf[Int]) + case "boolean" | "java.lang.Boolean" => + writeType(dos, "logical") + writeBoolean(dos, value.asInstanceOf[Boolean]) + case "java.sql.Date" => + writeType(dos, "date") + writeDate(dos, value.asInstanceOf[Date]) + case "java.sql.Time" => + writeType(dos, "time") + writeTime(dos, value.asInstanceOf[Time]) + case "[B" => + writeType(dos, "raw") + writeBytes(dos, value.asInstanceOf[Array[Byte]]) + // TODO: Types not handled right now include + // byte, char, short, float + + // Handle arrays + case "[Ljava.lang.String;" => + writeType(dos, "list") + writeStringArr(dos, value.asInstanceOf[Array[String]]) + case "[I" => + writeType(dos, "list") + writeIntArr(dos, value.asInstanceOf[Array[Int]]) + case "[J" => + writeType(dos, "list") + writeDoubleArr(dos, value.asInstanceOf[Array[Long]].map(_.toDouble)) + case "[D" => + writeType(dos, "list") + writeDoubleArr(dos, value.asInstanceOf[Array[Double]]) + case "[Z" => + writeType(dos, "list") + writeBooleanArr(dos, value.asInstanceOf[Array[Boolean]]) + case "[[B" => + writeType(dos, "list") + writeBytesArr(dos, value.asInstanceOf[Array[Array[Byte]]]) + case otherName => + // Handle array of objects + if (otherName.startsWith("[L")) { + val objArr = value.asInstanceOf[Array[Object]] + writeType(dos, "list") + writeType(dos, "jobj") + dos.writeInt(objArr.length) + objArr.foreach(o => writeJObj(dos, o)) + } else { + writeType(dos, "jobj") + writeJObj(dos, value) + } + } + } + } + + def writeInt(out: DataOutputStream, value: Int): Unit = { + out.writeInt(value) + } + + def writeDouble(out: DataOutputStream, value: Double): Unit = { + out.writeDouble(value) + } + + def writeBoolean(out: DataOutputStream, value: Boolean): Unit = { + val intValue = if (value) 1 else 0 + out.writeInt(intValue) + } + + def writeDate(out: DataOutputStream, value: Date): Unit = { + writeString(out, value.toString) + } + + def writeTime(out: DataOutputStream, value: Time): Unit = { + out.writeDouble(value.getTime.toDouble / 1000.0) + } + + + // NOTE: Only works for ASCII right now + def writeString(out: DataOutputStream, value: String): Unit = { + val len = value.length + out.writeInt(len + 1) // For the \0 + out.writeBytes(value) + out.writeByte(0) + } + + def writeBytes(out: DataOutputStream, value: Array[Byte]): Unit = { + out.writeInt(value.length) + out.write(value) + } + + def writeJObj(out: DataOutputStream, value: Object): Unit = { + val objId = JVMObjectTracker.put(value) + writeString(out, objId) + } + + def writeIntArr(out: DataOutputStream, value: Array[Int]): Unit = { + writeType(out, "integer") + out.writeInt(value.length) + value.foreach(v => out.writeInt(v)) + } + + def writeDoubleArr(out: DataOutputStream, value: Array[Double]): Unit = { + writeType(out, "double") + out.writeInt(value.length) + value.foreach(v => out.writeDouble(v)) + } + + def writeBooleanArr(out: DataOutputStream, value: Array[Boolean]): Unit = { + writeType(out, "logical") + out.writeInt(value.length) + value.foreach(v => writeBoolean(out, v)) + } + + def writeStringArr(out: DataOutputStream, value: Array[String]): Unit = { + writeType(out, "character") + out.writeInt(value.length) + value.foreach(v => writeString(out, v)) + } + + def writeBytesArr(out: DataOutputStream, value: Array[Array[Byte]]): Unit = { + writeType(out, "raw") + out.writeInt(value.length) + value.foreach(v => writeBytes(out, v)) + } +} + +private[r] object SerializationFormats { + val BYTE = "byte" + val STRING = "string" + val ROW = "row" +} diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 74ccfa6d3c9a3..4457c75e8b0fc 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -165,7 +165,7 @@ private[broadcast] object HttpBroadcast extends Logging { private def write(id: Long, value: Any) { val file = getFile(id) val fileOutputStream = new FileOutputStream(file) - try { + Utils.tryWithSafeFinally { val out: OutputStream = { if (compress) { compressionCodec.compressedOutputStream(fileOutputStream) @@ -175,10 +175,13 @@ private[broadcast] object HttpBroadcast extends Logging { } val ser = SparkEnv.get.serializer.newInstance() val serOut = ser.serializeStream(out) - serOut.writeObject(value) - serOut.close() + Utils.tryWithSafeFinally { + serOut.writeObject(value) + } { + serOut.close() + } files += file - } finally { + } { fileOutputStream.close() } } @@ -212,9 +215,11 @@ private[broadcast] object HttpBroadcast extends Logging { } val ser = SparkEnv.get.serializer.newInstance() val serIn = ser.deserializeStream(in) - val obj = serIn.readObject[T]() - serIn.close() - obj + Utils.tryWithSafeFinally { + serIn.readObject[T]() + } { + serIn.close() + } } /** diff --git a/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala b/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala index 3d0d68de8f495..b7ae9c1fc0a23 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala @@ -17,13 +17,15 @@ package org.apache.spark.deploy +import java.net.URI + private[spark] class ApplicationDescription( val name: String, val maxCores: Option[Int], val memoryPerSlave: Int, val command: Command, var appUiUrl: String, - val eventLogDir: Option[String] = None, + val eventLogDir: Option[URI] = None, // short name of compression codec used when writing event logs, if any (e.g. lzf) val eventLogCodec: Option[String] = None) extends Serializable { @@ -36,7 +38,7 @@ private[spark] class ApplicationDescription( memoryPerSlave: Int = memoryPerSlave, command: Command = command, appUiUrl: String = appUiUrl, - eventLogDir: Option[String] = eventLogDir, + eventLogDir: Option[URI] = eventLogDir, eventLogCodec: Option[String] = eventLogCodec): ApplicationDescription = new ApplicationDescription( name, maxCores, memoryPerSlave, command, appUiUrl, eventLogDir, eventLogCodec) diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index 65238af2caa24..8d13b2a2cd4f3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -89,7 +89,7 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) /* Find out driver status then exit the JVM */ def pollAndReportStatus(driverId: String) { - println(s"... waiting before polling master for driver state") + println("... waiting before polling master for driver state") Thread.sleep(5000) println("... polling master for driver state") val statusFuture = (masterActor ? RequestDriverStatus(driverId))(timeout) diff --git a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala index 3ab425aab84c8..f0e77c2ba982b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala +++ b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala @@ -53,7 +53,7 @@ class LocalSparkCluster( /* Start the Master */ val (masterSystem, masterPort, _, _) = Master.startSystemAndActor(localHostname, 0, 0, _conf) masterActorSystems += masterSystem - val masterUrl = "spark://" + localHostname + ":" + masterPort + val masterUrl = "spark://" + Utils.localHostNameForURI() + ":" + masterPort val masters = Array(masterUrl) /* Start the Workers */ diff --git a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala new file mode 100644 index 0000000000000..e99779f299785 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy + +import java.io._ +import java.util.concurrent.{Semaphore, TimeUnit} + +import scala.collection.JavaConversions._ + +import org.apache.hadoop.fs.Path + +import org.apache.spark.api.r.RBackend +import org.apache.spark.util.RedirectThread + +/** + * Main class used to launch SparkR applications using spark-submit. It executes R as a + * subprocess and then has it connect back to the JVM to access system properties etc. + */ +object RRunner { + def main(args: Array[String]): Unit = { + val rFile = PythonRunner.formatPath(args(0)) + + val otherArgs = args.slice(1, args.length) + + // Time to wait for SparkR backend to initialize in seconds + val backendTimeout = sys.env.getOrElse("SPARKR_BACKEND_TIMEOUT", "120").toInt + val rCommand = "Rscript" + + // Check if the file path exists. + // If not, change directory to current working directory for YARN cluster mode + val rF = new File(rFile) + val rFileNormalized = if (!rF.exists()) { + new Path(rFile).getName + } else { + rFile + } + + // Launch a SparkR backend server for the R process to connect to; this will let it see our + // Java system properties etc. + val sparkRBackend = new RBackend() + @volatile var sparkRBackendPort = 0 + val initialized = new Semaphore(0) + val sparkRBackendThread = new Thread("SparkR backend") { + override def run() { + sparkRBackendPort = sparkRBackend.init() + initialized.release() + sparkRBackend.run() + } + } + + sparkRBackendThread.start() + // Wait for RBackend initialization to finish + if (initialized.tryAcquire(backendTimeout, TimeUnit.SECONDS)) { + // Launch R + val returnCode = try { + val builder = new ProcessBuilder(Seq(rCommand, rFileNormalized) ++ otherArgs) + val env = builder.environment() + env.put("EXISTING_SPARKR_BACKEND_PORT", sparkRBackendPort.toString) + val sparkHome = System.getenv("SPARK_HOME") + env.put("R_PROFILE_USER", + Seq(sparkHome, "R", "lib", "SparkR", "profile", "general.R").mkString(File.separator)) + builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize + val process = builder.start() + + new RedirectThread(process.getInputStream, System.out, "redirect R output").start() + + process.waitFor() + } finally { + sparkRBackend.close() + } + System.exit(returnCode) + } else { + System.err.println("SparkR backend did not initialize in " + backendTimeout + " seconds") + System.exit(-1) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index c2568eb4b60ac..cfaebf9ea5050 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -24,11 +24,10 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} import org.apache.hadoop.fs.FileSystem.Statistics import org.apache.hadoop.mapred.JobConf -import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} -import org.apache.hadoop.security.Credentials -import org.apache.hadoop.security.UserGroupInformation +import org.apache.hadoop.mapreduce.JobContext +import org.apache.hadoop.security.{Credentials, UserGroupInformation} -import org.apache.spark.{Logging, SparkContext, SparkConf, SparkException} +import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.Utils @@ -201,6 +200,37 @@ class SparkHadoopUtil extends Logging { val baseStatus = fs.getFileStatus(basePath) if (baseStatus.isDir) recurse(basePath) else Array(baseStatus) } + + private val HADOOP_CONF_PATTERN = "(\\$\\{hadoopconf-[^\\}\\$\\s]+\\})".r.unanchored + + /** + * Substitute variables by looking them up in Hadoop configs. Only variables that match the + * ${hadoopconf- .. } pattern are substituted. + */ + def substituteHadoopVariables(text: String, hadoopConf: Configuration): String = { + text match { + case HADOOP_CONF_PATTERN(matched) => { + logDebug(text + " matched " + HADOOP_CONF_PATTERN) + val key = matched.substring(13, matched.length() - 1) // remove ${hadoopconf- .. } + val eval = Option[String](hadoopConf.get(key)) + .map { value => + logDebug("Substituted " + matched + " with " + value) + text.replace(matched, value) + } + if (eval.isEmpty) { + // The variable was not found in Hadoop configs, so return text as is. + text + } else { + // Continue to substitute more variables. + substituteHadoopVariables(eval.get, hadoopConf) + } + } + case _ => { + logDebug(text + " didn't match " + HADOOP_CONF_PATTERN) + text + } + } + } } object SparkHadoopUtil { diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 660307d19eab4..60bc243ebf40a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -77,6 +77,7 @@ object SparkSubmit { // Special primary resource names that represent shells rather than application jars. private val SPARK_SHELL = "spark-shell" private val PYSPARK_SHELL = "pyspark-shell" + private val SPARKR_SHELL = "sparkr-shell" private val CLASS_NOT_FOUND_EXIT_STATUS = 101 @@ -284,6 +285,13 @@ object SparkSubmit { } } + // Require all R files to be local + if (args.isR && !isYarnCluster) { + if (Utils.nonLocalPaths(args.primaryResource).nonEmpty) { + printErrorAndExit(s"Only local R files are supported: $args.primaryResource") + } + } + // The following modes are not supported or applicable (clusterManager, deployMode) match { case (MESOS, CLUSTER) => @@ -291,6 +299,9 @@ object SparkSubmit { case (STANDALONE, CLUSTER) if args.isPython => printErrorAndExit("Cluster deploy mode is currently not supported for python " + "applications on standalone clusters.") + case (STANDALONE, CLUSTER) if args.isR => + printErrorAndExit("Cluster deploy mode is currently not supported for R " + + "applications on standalone clusters.") case (_, CLUSTER) if isShell(args.primaryResource) => printErrorAndExit("Cluster deploy mode is not applicable to Spark shells.") case (_, CLUSTER) if isSqlShell(args.mainClass) => @@ -317,11 +328,32 @@ object SparkSubmit { } } - // In yarn-cluster mode for a python app, add primary resource and pyFiles to files - // that can be distributed with the job - if (args.isPython && isYarnCluster) { - args.files = mergeFileLists(args.files, args.primaryResource) - args.files = mergeFileLists(args.files, args.pyFiles) + // If we're running a R app, set the main class to our specific R runner + if (args.isR && deployMode == CLIENT) { + if (args.primaryResource == SPARKR_SHELL) { + args.mainClass = "org.apache.spark.api.r.RBackend" + } else { + // If a R file is provided, add it to the child arguments and list of files to deploy. + // Usage: RRunner
[app arguments] + args.mainClass = "org.apache.spark.deploy.RRunner" + args.childArgs = ArrayBuffer(args.primaryResource) ++ args.childArgs + args.files = mergeFileLists(args.files, args.primaryResource) + } + } + + if (isYarnCluster) { + // In yarn-cluster mode for a python app, add primary resource and pyFiles to files + // that can be distributed with the job + if (args.isPython) { + args.files = mergeFileLists(args.files, args.primaryResource) + args.files = mergeFileLists(args.files, args.pyFiles) + } + + // In yarn-cluster mode for a R app, add primary resource to files + // that can be distributed with the job + if (args.isR) { + args.files = mergeFileLists(args.files, args.primaryResource) + } } // Special flag to avoid deprecation warnings at the client @@ -405,8 +437,8 @@ object SparkSubmit { // Add the application jar automatically so the user doesn't have to call sc.addJar // For YARN cluster mode, the jar is already distributed on each node as "app.jar" - // For python files, the primary resource is already distributed as a regular file - if (!isYarnCluster && !args.isPython) { + // For python and R files, the primary resource is already distributed as a regular file + if (!isYarnCluster && !args.isPython && !args.isR) { var jars = sysProps.get("spark.jars").map(x => x.split(",").toSeq).getOrElse(Seq.empty) if (isUserJar(args.primaryResource)) { jars = jars ++ Seq(args.primaryResource) @@ -447,6 +479,10 @@ object SparkSubmit { childArgs += ("--py-files", pyFilesNames) } childArgs += ("--class", "org.apache.spark.deploy.PythonRunner") + } else if (args.isR) { + val mainFile = new Path(args.primaryResource).getName + childArgs += ("--primary-r-file", mainFile) + childArgs += ("--class", "org.apache.spark.deploy.RRunner") } else { if (args.primaryResource != SPARK_INTERNAL) { childArgs += ("--jar", args.primaryResource) @@ -591,15 +627,15 @@ object SparkSubmit { /** * Return whether the given primary resource represents a user jar. */ - private def isUserJar(primaryResource: String): Boolean = { - !isShell(primaryResource) && !isPython(primaryResource) && !isInternal(primaryResource) + private[deploy] def isUserJar(res: String): Boolean = { + !isShell(res) && !isPython(res) && !isInternal(res) && !isR(res) } /** * Return whether the given primary resource represents a shell. */ - private[deploy] def isShell(primaryResource: String): Boolean = { - primaryResource == SPARK_SHELL || primaryResource == PYSPARK_SHELL + private[deploy] def isShell(res: String): Boolean = { + (res == SPARK_SHELL || res == PYSPARK_SHELL || res == SPARKR_SHELL) } /** @@ -619,12 +655,19 @@ object SparkSubmit { /** * Return whether the given primary resource requires running python. */ - private[deploy] def isPython(primaryResource: String): Boolean = { - primaryResource.endsWith(".py") || primaryResource == PYSPARK_SHELL + private[deploy] def isPython(res: String): Boolean = { + res != null && res.endsWith(".py") || res == PYSPARK_SHELL + } + + /** + * Return whether the given primary resource requires running R. + */ + private[deploy] def isR(res: String): Boolean = { + res != null && res.endsWith(".R") || res == SPARKR_SHELL } - private[deploy] def isInternal(primaryResource: String): Boolean = { - primaryResource == SPARK_INTERNAL + private[deploy] def isInternal(res: String): Boolean = { + res == SPARK_INTERNAL } /** diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 6eb73c43470a5..03ecf3fd99ec5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -59,6 +59,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S var verbose: Boolean = false var isPython: Boolean = false var pyFiles: String = null + var isR: Boolean = false var action: SparkSubmitAction = null val sparkProperties: HashMap[String, String] = new HashMap[String, String]() var proxyUser: String = null @@ -158,7 +159,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S .getOrElse(sparkProperties.get("spark.executor.instances").orNull) // Try to set main class from JAR if no --class argument is given - if (mainClass == null && !isPython && primaryResource != null) { + if (mainClass == null && !isPython && !isR && primaryResource != null) { val uri = new URI(primaryResource) val uriScheme = uri.getScheme() @@ -211,9 +212,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S printUsageAndExit(-1) } if (primaryResource == null) { - SparkSubmit.printErrorAndExit("Must specify a primary resource (JAR or Python file)") + SparkSubmit.printErrorAndExit("Must specify a primary resource (JAR or Python or R file)") } - if (mainClass == null && !isPython) { + if (mainClass == null && SparkSubmit.isUserJar(primaryResource)) { SparkSubmit.printErrorAndExit("No main class set in JAR; please specify one with --class") } if (pyFiles != null && !isPython) { @@ -414,6 +415,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S opt } isPython = SparkSubmit.isPython(opt) + isR = SparkSubmit.isR(opt) false } diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala index c1c4812f17fbe..40835b9550586 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala @@ -46,7 +46,7 @@ private[spark] object TestClient { def main(args: Array[String]) { val url = args(0) val conf = new SparkConf - val (actorSystem, _) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0, + val (actorSystem, _) = AkkaUtils.createActorSystem("spark", Utils.localHostName(), 0, conf = conf, securityManager = new SecurityManager(conf)) val desc = new ApplicationDescription("TestClient", Some(1), 512, Command("spark.deploy.client.TestExecutor", Seq(), Map(), Seq(), Seq(), Seq()), "ignored") diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index e46672265d07a..4e8b899669f2f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -116,7 +116,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis if (!fs.exists(path)) { var msg = s"Log directory specified does not exist: $logDir." if (logDir == DEFAULT_LOG_DIR) { - msg += " Did you configure the correct one through spark.fs.history.logDirectory?" + msg += " Did you configure the correct one through spark.history.fs.logDirectory?" } throw new IllegalArgumentException(msg) } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala index 32499b3a784a1..f459ed5b3a1a1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala @@ -24,6 +24,7 @@ import scala.reflect.ClassTag import akka.serialization.Serialization import org.apache.spark.Logging +import org.apache.spark.util.Utils /** @@ -59,9 +60,9 @@ private[master] class FileSystemPersistenceEngine( val serializer = serialization.findSerializerFor(value) val serialized = serializer.toBinary(value) val out = new FileOutputStream(file) - try { + Utils.tryWithSafeFinally { out.write(serialized) - } finally { + } { out.close() } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala index 420442f7564cc..b8fd406fb6f9a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala @@ -27,6 +27,7 @@ import com.fasterxml.jackson.core.JsonProcessingException import com.google.common.base.Charsets import org.apache.spark.{Logging, SparkConf, SPARK_VERSION => sparkVersion} +import org.apache.spark.util.Utils /** * A client that submits applications to the standalone Master using a REST protocol. @@ -148,8 +149,11 @@ private[deploy] class StandaloneRestClient extends Logging { conn.setRequestProperty("charset", "utf-8") conn.setDoOutput(true) val out = new DataOutputStream(conn.getOutputStream) - out.write(json.getBytes(Charsets.UTF_8)) - out.close() + Utils.tryWithSafeFinally { + out.write(json.getBytes(Charsets.UTF_8)) + } { + out.close() + } readResponse(conn) } @@ -241,7 +245,7 @@ private[deploy] class StandaloneRestClient extends Logging { } } else { val failMessage = Option(submitResponse.message).map { ": " + _ }.getOrElse("") - logError("Application submission failed" + failMessage) + logError(s"Application submission failed$failMessage") } } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index 83e24a7a1f80c..7d5acabb95a48 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -50,7 +50,7 @@ private[deploy] class ExecutorRunner( val workerUrl: String, conf: SparkConf, val appLocalDirs: Seq[String], - var state: ExecutorState.Value) + @volatile var state: ExecutorState.Value) extends Logging { private val fullId = appId + "/" + execId diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 900e678ee02ef..8300f9f2190b9 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -21,39 +21,45 @@ import java.net.URL import java.nio.ByteBuffer import scala.collection.mutable -import scala.concurrent.Await +import scala.util.{Failure, Success} -import akka.actor.{Actor, ActorSelection, Props} -import akka.pattern.Patterns -import akka.remote.{RemotingLifecycleEvent, DisassociatedEvent} - -import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv} +import org.apache.spark.rpc._ +import org.apache.spark._ import org.apache.spark.TaskState.TaskState import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.worker.WorkerWatcher import org.apache.spark.scheduler.TaskDescription import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, SignalLogger, Utils} +import org.apache.spark.util.{SignalLogger, Utils} private[spark] class CoarseGrainedExecutorBackend( + override val rpcEnv: RpcEnv, driverUrl: String, executorId: String, hostPort: String, cores: Int, userClassPath: Seq[URL], env: SparkEnv) - extends Actor with ActorLogReceive with ExecutorBackend with Logging { + extends ThreadSafeRpcEndpoint with ExecutorBackend with Logging { Utils.checkHostPort(hostPort, "Expected hostport") var executor: Executor = null - var driver: ActorSelection = null + @volatile var driver: Option[RpcEndpointRef] = None - override def preStart() { + override def onStart() { + import scala.concurrent.ExecutionContext.Implicits.global logInfo("Connecting to driver: " + driverUrl) - driver = context.actorSelection(driverUrl) - driver ! RegisterExecutor(executorId, hostPort, cores, extractLogUrls) - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + rpcEnv.asyncSetupEndpointRefByURI(driverUrl).flatMap { ref => + driver = Some(ref) + ref.sendWithReply[RegisteredExecutor.type]( + RegisterExecutor(executorId, self, hostPort, cores, extractLogUrls)) + } onComplete { + case Success(msg) => Utils.tryLogNonFatalError { + Option(self).foreach(_.send(msg)) // msg must be RegisteredExecutor + } + case Failure(e) => logError(s"Cannot register with driver: $driverUrl", e) + } } def extractLogUrls: Map[String, String] = { @@ -62,7 +68,7 @@ private[spark] class CoarseGrainedExecutorBackend( .map(e => (e._1.substring(prefix.length).toLowerCase, e._2)) } - override def receiveWithLogging: PartialFunction[Any, Unit] = { + override def receive: PartialFunction[Any, Unit] = { case RegisteredExecutor => logInfo("Successfully registered with driver") val (hostname, _) = Utils.parseHostPort(hostPort) @@ -92,23 +98,28 @@ private[spark] class CoarseGrainedExecutorBackend( executor.killTask(taskId, interruptThread) } - case x: DisassociatedEvent => - if (x.remoteAddress == driver.anchorPath.address) { - logError(s"Driver $x disassociated! Shutting down.") - System.exit(1) - } else { - logWarning(s"Received irrelevant DisassociatedEvent $x") - } - case StopExecutor => logInfo("Driver commanded a shutdown") executor.stop() - context.stop(self) - context.system.shutdown() + stop() + rpcEnv.shutdown() + } + + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + if (driver.exists(_.address == remoteAddress)) { + logError(s"Driver $remoteAddress disassociated! Shutting down.") + System.exit(1) + } else { + logWarning(s"An unknown ($remoteAddress) driver disconnected.") + } } override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) { - driver ! StatusUpdate(executorId, taskId, state, data) + val msg = StatusUpdate(executorId, taskId, state, data) + driver match { + case Some(driverRef) => driverRef.send(msg) + case None => logWarning(s"Drop $msg because has not yet connected to driver") + } } } @@ -132,16 +143,14 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { // Bootstrap to fetch the driver's Spark properties. val executorConf = new SparkConf val port = executorConf.getInt("spark.executor.port", 0) - val (fetcher, _) = AkkaUtils.createActorSystem( + val fetcher = RpcEnv.create( "driverPropsFetcher", hostname, port, executorConf, new SecurityManager(executorConf)) - val driver = fetcher.actorSelection(driverUrl) - val timeout = AkkaUtils.askTimeout(executorConf) - val fut = Patterns.ask(driver, RetrieveSparkProps, timeout) - val props = Await.result(fut, timeout).asInstanceOf[Seq[(String, String)]] ++ + val driver = fetcher.setupEndpointRefByURI(driverUrl) + val props = driver.askWithReply[Seq[(String, String)]](RetrieveSparkProps) ++ Seq[(String, String)](("spark.app.id", appId)) fetcher.shutdown() @@ -162,16 +171,14 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { val boundPort = env.conf.getInt("spark.executor.port", 0) assert(boundPort != 0) - // Start the CoarseGrainedExecutorBackend actor. + // Start the CoarseGrainedExecutorBackend endpoint. val sparkHostPort = hostname + ":" + boundPort - env.actorSystem.actorOf( - Props(classOf[CoarseGrainedExecutorBackend], - driverUrl, executorId, sparkHostPort, cores, userClassPath, env), - name = "Executor") + env.rpcEnv.setupEndpoint("Executor", new CoarseGrainedExecutorBackend( + env.rpcEnv, driverUrl, executorId, sparkHostPort, cores, userClassPath, env)) workerUrl.foreach { url => env.rpcEnv.setupEndpoint("WorkerWatcher", new WorkerWatcher(env.rpcEnv, url)) } - env.actorSystem.awaitTermination() + env.rpcEnv.awaitTermination() } } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index bf3135ef081c1..14f99a464b6e9 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -27,8 +27,6 @@ import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.util.control.NonFatal -import akka.actor.Props - import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Task} @@ -88,9 +86,9 @@ private[spark] class Executor( env.blockManager.initialize(conf.getAppId) } - // Create an actor for receiving RPCs from the driver - private val executorActor = env.actorSystem.actorOf( - Props(new ExecutorActor(executorId)), "ExecutorActor") + // Create an RpcEndpoint for receiving RPCs from the driver + private val executorEndpoint = env.rpcEnv.setupEndpoint( + ExecutorEndpoint.EXECUTOR_ENDPOINT_NAME, new ExecutorEndpoint(env.rpcEnv, executorId)) // Whether to load classes in user jars before those in Spark jars private val userClassPathFirst: Boolean = { @@ -139,7 +137,7 @@ private[spark] class Executor( def stop(): Unit = { env.metricsSystem.report() - env.actorSystem.stop(executorActor) + env.rpcEnv.stop(executorEndpoint) isStopped = true threadPool.shutdown() if (!isLocal) { @@ -391,11 +389,8 @@ private[spark] class Executor( } } - private val timeout = AkkaUtils.lookupTimeout(conf) - private val retryAttempts = AkkaUtils.numRetries(conf) - private val retryIntervalMs = AkkaUtils.retryWaitMs(conf) private val heartbeatReceiverRef = - AkkaUtils.makeDriverRef("HeartbeatReceiver", conf, env.actorSystem) + RpcUtils.makeDriverRef(HeartbeatReceiver.ENDPOINT_NAME, conf, env.rpcEnv) /** Reports heartbeat and metrics for active tasks to the driver. */ private def reportHeartBeat(): Unit = { @@ -426,8 +421,7 @@ private[spark] class Executor( val message = Heartbeat(executorId, tasksMetrics.toArray, env.blockManager.blockManagerId) try { - val response = AkkaUtils.askWithReply[HeartbeatResponse](message, heartbeatReceiverRef, - retryAttempts, retryIntervalMs, timeout) + val response = heartbeatReceiverRef.askWithReply[HeartbeatResponse](message) if (response.reregisterBlockManager) { logWarning("Told to re-register on heartbeat") env.blockManager.reregister() diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorEndpoint.scala similarity index 67% rename from core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala rename to core/src/main/scala/org/apache/spark/executor/ExecutorEndpoint.scala index 3e47d13f7545d..cf362f8464735 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorEndpoint.scala @@ -17,10 +17,8 @@ package org.apache.spark.executor -import akka.actor.Actor -import org.apache.spark.Logging - -import org.apache.spark.util.{Utils, ActorLogReceive} +import org.apache.spark.rpc.{RpcEnv, RpcCallContext, RpcEndpoint} +import org.apache.spark.util.Utils /** * Driver -> Executor message to trigger a thread dump. @@ -28,14 +26,18 @@ import org.apache.spark.util.{Utils, ActorLogReceive} private[spark] case object TriggerThreadDump /** - * Actor that runs inside of executors to enable driver -> executor RPC. + * [[RpcEndpoint]] that runs inside of executors to enable driver -> executor RPC. */ private[spark] -class ExecutorActor(executorId: String) extends Actor with ActorLogReceive with Logging { +class ExecutorEndpoint(override val rpcEnv: RpcEnv, executorId: String) extends RpcEndpoint { - override def receiveWithLogging: PartialFunction[Any, Unit] = { + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case TriggerThreadDump => - sender ! Utils.getThreadDump() + context.reply(Utils.getThreadDump()) } } + +object ExecutorEndpoint { + val EXECUTOR_ENDPOINT_NAME = "ExecutorEndpoint" +} diff --git a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala index 04eb2bf9ba4ab..6b898bd4bfc1b 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala @@ -181,7 +181,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, buffer.get(bytes) bytes.foreach(x => print(x + " ")) buffer.position(curPosition) - print(" (" + bytes.size + ")") + print(" (" + bytes.length + ")") } def printBuffer(buffer: ByteBuffer, position: Int, length: Int) { diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index 646df283ac069..3406a7e97e368 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -45,7 +45,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi } result }, - Range(0, self.partitions.size), + Range(0, self.partitions.length), (index: Int, data: Long) => totalCount.addAndGet(data), totalCount.get()) } @@ -54,8 +54,8 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi * Returns a future for retrieving all elements of this RDD. */ def collectAsync(): FutureAction[Seq[T]] = { - val results = new Array[Array[T]](self.partitions.size) - self.context.submitJob[T, Array[T], Seq[T]](self, _.toArray, Range(0, self.partitions.size), + val results = new Array[Array[T]](self.partitions.length) + self.context.submitJob[T, Array[T], Seq[T]](self, _.toArray, Range(0, self.partitions.length), (index, data) => results(index) = data, results.flatten.toSeq) } @@ -111,7 +111,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi */ def foreachAsync(f: T => Unit): FutureAction[Unit] = { val cleanF = self.context.clean(f) - self.context.submitJob[T, Unit, Unit](self, _.foreach(cleanF), Range(0, self.partitions.size), + self.context.submitJob[T, Unit, Unit](self, _.foreach(cleanF), Range(0, self.partitions.length), (index, data) => Unit, Unit) } @@ -119,7 +119,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi * Applies a function f to each partition of this RDD. */ def foreachPartitionAsync(f: Iterator[T] => Unit): FutureAction[Unit] = { - self.context.submitJob[T, Unit, Unit](self, f, Range(0, self.partitions.size), + self.context.submitJob[T, Unit, Unit](self, f, Range(0, self.partitions.length), (index, data) => Unit, Unit) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala index fffa1911f5bc2..71578d1210fde 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala @@ -36,7 +36,7 @@ class BlockRDD[T: ClassTag](@transient sc: SparkContext, @transient val blockIds override def getPartitions: Array[Partition] = { assertValid() - (0 until blockIds.size).map(i => { + (0 until blockIds.length).map(i => { new BlockRDDPartition(blockIds(i), i).asInstanceOf[Partition] }).toArray } diff --git a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala index 9059eb13bb5d8..c1d6971787572 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala @@ -53,11 +53,11 @@ class CartesianRDD[T: ClassTag, U: ClassTag]( extends RDD[Pair[T, U]](sc, Nil) with Serializable { - val numPartitionsInRdd2 = rdd2.partitions.size + val numPartitionsInRdd2 = rdd2.partitions.length override def getPartitions: Array[Partition] = { // create the cross product split - val array = new Array[Partition](rdd1.partitions.size * rdd2.partitions.size) + val array = new Array[Partition](rdd1.partitions.length * rdd2.partitions.length) for (s1 <- rdd1.partitions; s2 <- rdd2.partitions) { val idx = s1.index * numPartitionsInRdd2 + s2.index array(idx) = new CartesianPartition(idx, rdd1, rdd2, s1.index, s2.index) diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala index 1c13e2c372845..0d130dd4c7a60 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala @@ -27,6 +27,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.util.Utils private[spark] class CheckpointRDDPartition(val index: Int) extends Partition {} @@ -48,7 +49,7 @@ class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String) if (fs.exists(cpath)) { val dirContents = fs.listStatus(cpath).map(_.getPath) val partitionFiles = dirContents.filter(_.getName.startsWith("part-")).map(_.toString).sorted - val numPart = partitionFiles.size + val numPart = partitionFiles.length if (numPart > 0 && (! partitionFiles(0).endsWith(CheckpointRDD.splitIdToFile(0)) || ! partitionFiles(numPart-1).endsWith(CheckpointRDD.splitIdToFile(numPart-1)))) { throw new SparkException("Invalid checkpoint directory: " + checkpointPath) @@ -112,8 +113,11 @@ private[spark] object CheckpointRDD extends Logging { } val serializer = env.serializer.newInstance() val serializeStream = serializer.serializeStream(fileOutputStream) - serializeStream.writeAll(iterator) - serializeStream.close() + Utils.tryWithSafeFinally { + serializeStream.writeAll(iterator) + } { + serializeStream.close() + } if (!fs.rename(tempOutputPath, finalOutputPath)) { if (!fs.exists(finalOutputPath)) { diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 07398a6fa62f6..7021a339e879b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -99,7 +99,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: override def getPartitions: Array[Partition] = { val array = new Array[Partition](part.numPartitions) - for (i <- 0 until array.size) { + for (i <- 0 until array.length) { // Each CoGroupPartition will have a dependency per contributing RDD array(i) = new CoGroupPartition(i, rdds.zipWithIndex.map { case (rdd, j) => // Assume each RDD contributed a single dependency, and get it @@ -120,7 +120,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: val sparkConf = SparkEnv.get.conf val externalSorting = sparkConf.getBoolean("spark.shuffle.spill", true) val split = s.asInstanceOf[CoGroupPartition] - val numRdds = split.deps.size + val numRdds = split.deps.length // A list of (rdd iterator, dependency number) pairs val rddIterators = new ArrayBuffer[(Iterator[Product2[K, Any]], Int)] diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala index 5117ccfabfcc2..0c1b02c07d09f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala @@ -166,7 +166,7 @@ private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack: // determines the tradeoff between load-balancing the partitions sizes and their locality // e.g. balanceSlack=0.10 means that it allows up to 10% imbalance in favor of locality - val slack = (balanceSlack * prev.partitions.size).toInt + val slack = (balanceSlack * prev.partitions.length).toInt var noLocality = true // if true if no preferredLocations exists for parent RDD diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala index 71e6e300fec5f..29ca3e9c4bd04 100644 --- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala @@ -70,7 +70,7 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { @Experimental def meanApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = { val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns) - val evaluator = new MeanEvaluator(self.partitions.size, confidence) + val evaluator = new MeanEvaluator(self.partitions.length, confidence) self.context.runApproximateJob(self, processPartition, evaluator, timeout) } @@ -81,7 +81,7 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { @Experimental def sumApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = { val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns) - val evaluator = new SumEvaluator(self.partitions.size, confidence) + val evaluator = new SumEvaluator(self.partitions.length, confidence) self.context.runApproximateJob(self, processPartition, evaluator, timeout) } diff --git a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala index 6fdfdb734d1b8..6afe50161dacd 100644 --- a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala @@ -56,7 +56,7 @@ class OrderedRDDFunctions[K : Ordering : ClassTag, * order of the keys). */ // TODO: this currently doesn't work on P other than Tuple2! - def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.size) + def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.length) : RDD[(K, V)] = { val part = new RangePartitioner(numPartitions, self, ascending) diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 6b4f097ea9ae5..05351ba4ff76b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -823,7 +823,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * RDD will be <= us. */ def subtractByKey[W: ClassTag](other: RDD[(K, W)]): RDD[(K, V)] = - subtractByKey(other, self.partitioner.getOrElse(new HashPartitioner(self.partitions.size))) + subtractByKey(other, self.partitioner.getOrElse(new HashPartitioner(self.partitions.length))) /** Return an RDD with the pairs from `this` whose keys are not in `other`. */ def subtractByKey[W: ClassTag](other: RDD[(K, W)], numPartitions: Int): RDD[(K, V)] = @@ -995,7 +995,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K,V]] require(writer != null, "Unable to obtain RecordWriter") var recordsWritten = 0L - try { + Utils.tryWithSafeFinally { while (iter.hasNext) { val pair = iter.next() writer.write(pair._1, pair._2) @@ -1004,7 +1004,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) maybeUpdateOutputMetrics(bytesWrittenCallback, outputMetrics, recordsWritten) recordsWritten += 1 } - } finally { + } { writer.close(hadoopContext) } committer.commitTask(hadoopContext) @@ -1068,7 +1068,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) writer.setup(context.stageId, context.partitionId, taskAttemptId) writer.open() var recordsWritten = 0L - try { + + Utils.tryWithSafeFinally { while (iter.hasNext) { val record = iter.next() writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef]) @@ -1077,7 +1078,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) maybeUpdateOutputMetrics(bytesWrittenCallback, outputMetrics, recordsWritten) recordsWritten += 1 } - } finally { + } { writer.close() } writer.commit() diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index ddbfd5624e741..d80d94a588346 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -316,7 +316,7 @@ abstract class RDD[T: ClassTag]( /** * Return a new RDD containing the distinct elements in this RDD. */ - def distinct(): RDD[T] = distinct(partitions.size) + def distinct(): RDD[T] = distinct(partitions.length) /** * Return a new RDD that has exactly numPartitions partitions. @@ -488,7 +488,7 @@ abstract class RDD[T: ClassTag]( def sortBy[K]( f: (T) => K, ascending: Boolean = true, - numPartitions: Int = this.partitions.size) + numPartitions: Int = this.partitions.length) (implicit ord: Ordering[K], ctag: ClassTag[K]): RDD[T] = this.keyBy[K](f) .sortByKey(ascending, numPartitions) @@ -852,7 +852,7 @@ abstract class RDD[T: ClassTag]( * RDD will be <= us. */ def subtract(other: RDD[T]): RDD[T] = - subtract(other, partitioner.getOrElse(new HashPartitioner(partitions.size))) + subtract(other, partitioner.getOrElse(new HashPartitioner(partitions.length))) /** * Return an RDD with the elements from `this` that are not in `other`. @@ -986,14 +986,14 @@ abstract class RDD[T: ClassTag]( combOp: (U, U) => U, depth: Int = 2): U = { require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.") - if (partitions.size == 0) { + if (partitions.length == 0) { return Utils.clone(zeroValue, context.env.closureSerializer.newInstance()) } val cleanSeqOp = context.clean(seqOp) val cleanCombOp = context.clean(combOp) val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp) var partiallyAggregated = mapPartitions(it => Iterator(aggregatePartition(it))) - var numPartitions = partiallyAggregated.partitions.size + var numPartitions = partiallyAggregated.partitions.length val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2) // If creating an extra level doesn't help reduce the wall-clock time, we stop tree aggregation. while (numPartitions > scale + numPartitions / scale) { @@ -1026,7 +1026,7 @@ abstract class RDD[T: ClassTag]( } result } - val evaluator = new CountEvaluator(partitions.size, confidence) + val evaluator = new CountEvaluator(partitions.length, confidence) sc.runApproximateJob(this, countElements, evaluator, timeout) } @@ -1061,7 +1061,7 @@ abstract class RDD[T: ClassTag]( } map } - val evaluator = new GroupedCountEvaluator[T](partitions.size, confidence) + val evaluator = new GroupedCountEvaluator[T](partitions.length, confidence) sc.runApproximateJob(this, countPartition, evaluator, timeout) } @@ -1140,7 +1140,7 @@ abstract class RDD[T: ClassTag]( * the same index assignments, you should sort the RDD with sortByKey() or save it to a file. */ def zipWithUniqueId(): RDD[(T, Long)] = { - val n = this.partitions.size.toLong + val n = this.partitions.length.toLong this.mapPartitionsWithIndex { case (k, iter) => iter.zipWithIndex.map { case (item, i) => (item, i * n + k) @@ -1243,7 +1243,7 @@ abstract class RDD[T: ClassTag]( queue ++= util.collection.Utils.takeOrdered(items, num)(ord) Iterator.single(queue) } - if (mapRDDs.partitions.size == 0) { + if (mapRDDs.partitions.length == 0) { Array.empty } else { mapRDDs.reduce { (queue1, queue2) => @@ -1489,7 +1489,7 @@ abstract class RDD[T: ClassTag]( } // The first RDD in the dependency stack has no parents, so no need for a +- def firstDebugString(rdd: RDD[_]): Seq[String] = { - val partitionStr = "(" + rdd.partitions.size + ")" + val partitionStr = "(" + rdd.partitions.length + ")" val leftOffset = (partitionStr.length - 1) / 2 val nextPrefix = (" " * leftOffset) + "|" + (" " * (partitionStr.length - leftOffset)) @@ -1499,7 +1499,7 @@ abstract class RDD[T: ClassTag]( } ++ debugChildren(rdd, nextPrefix) } def shuffleDebugString(rdd: RDD[_], prefix: String = "", isLastChild: Boolean): Seq[String] = { - val partitionStr = "(" + rdd.partitions.size + ")" + val partitionStr = "(" + rdd.partitions.length + ")" val leftOffset = (partitionStr.length - 1) / 2 val thisPrefix = prefix.replaceAll("\\|\\s+$", "") val nextPrefix = ( diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala index f67e5f1857979..6afd63d537d75 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala @@ -94,10 +94,10 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) new SerializableWritable(rdd.context.hadoopConfiguration)) rdd.context.runJob(rdd, CheckpointRDD.writeToFile[T](path.toString, broadcastedConf) _) val newRDD = new CheckpointRDD[T](rdd.context, path.toString) - if (newRDD.partitions.size != rdd.partitions.size) { + if (newRDD.partitions.length != rdd.partitions.length) { throw new SparkException( - "Checkpoint RDD " + newRDD + "(" + newRDD.partitions.size + ") has different " + - "number of partitions than original RDD " + rdd + "(" + rdd.partitions.size + ")") + "Checkpoint RDD " + newRDD + "(" + newRDD.partitions.length + ") has different " + + "number of partitions than original RDD " + rdd + "(" + rdd.partitions.length + ")") } // Change the dependencies and partitions of the RDD diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala index c27f435eb9c5a..e9d745588ee9a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala @@ -76,7 +76,7 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( override def getPartitions: Array[Partition] = { val array = new Array[Partition](part.numPartitions) - for (i <- 0 until array.size) { + for (i <- 0 until array.length) { // Each CoGroupPartition will depend on rdd1 and rdd2 array(i) = new CoGroupPartition(i, Seq(rdd1, rdd2).zipWithIndex.map { case (rdd, j) => dependencies(j) match { diff --git a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala index 4239e7e22af89..3986645350a82 100644 --- a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala @@ -63,7 +63,7 @@ class UnionRDD[T: ClassTag]( extends RDD[T](sc, Nil) { // Nil since we implement getDependencies override def getPartitions: Array[Partition] = { - val array = new Array[Partition](rdds.map(_.partitions.size).sum) + val array = new Array[Partition](rdds.map(_.partitions.length).sum) var pos = 0 for ((rdd, rddIndex) <- rdds.zipWithIndex; split <- rdd.partitions) { array(pos) = new UnionPartition(pos, rdd, rddIndex, split.index) @@ -76,8 +76,8 @@ class UnionRDD[T: ClassTag]( val deps = new ArrayBuffer[Dependency[_]] var pos = 0 for (rdd <- rdds) { - deps += new RangeDependency(rdd, 0, pos, rdd.partitions.size) - pos += rdd.partitions.size + deps += new RangeDependency(rdd, 0, pos, rdd.partitions.length) + pos += rdd.partitions.length } deps } diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala index d0be304762e1f..a96b6c3d23454 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala @@ -52,8 +52,8 @@ private[spark] abstract class ZippedPartitionsBaseRDD[V: ClassTag]( if (preservesPartitioning) firstParent[Any].partitioner else None override def getPartitions: Array[Partition] = { - val numParts = rdds.head.partitions.size - if (!rdds.forall(rdd => rdd.partitions.size == numParts)) { + val numParts = rdds.head.partitions.length + if (!rdds.forall(rdd => rdd.partitions.length == numParts)) { throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions") } Array.tabulate[Partition](numParts) { i => diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala index 8c43a559409f2..523aaf2b860b5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala @@ -41,7 +41,7 @@ class ZippedWithIndexRDD[T: ClassTag](@transient prev: RDD[T]) extends RDD[(T, L /** The start index of each partition. */ @transient private val startIndices: Array[Long] = { - val n = prev.partitions.size + val n = prev.partitions.length if (n == 0) { Array[Long]() } else if (n == 1) { diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 7985941d949c0..e259867c14040 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -30,7 +30,9 @@ import org.apache.spark.util.{AkkaUtils, Utils} /** * An RPC environment. [[RpcEndpoint]]s need to register itself with a name to [[RpcEnv]] to * receives messages. Then [[RpcEnv]] will process messages sent from [[RpcEndpointRef]] or remote - * nodes, and deliver them to corresponding [[RpcEndpoint]]s. + * nodes, and deliver them to corresponding [[RpcEndpoint]]s. For uncaught exceptions caught by + * [[RpcEnv]], [[RpcEnv]] will use [[RpcCallContext.sendFailure]] to send exceptions back to the + * sender, or logging them if no such sender or `NotSerializableException`. * * [[RpcEnv]] also provides some methods to retrieve [[RpcEndpointRef]]s given name or uri. */ @@ -40,10 +42,7 @@ private[spark] abstract class RpcEnv(conf: SparkConf) { /** * Return RpcEndpointRef of the registered [[RpcEndpoint]]. Will be used to implement - * [[RpcEndpoint.self]]. - * - * Note: This method won't return null. `IllegalArgumentException` will be thrown if calling this - * on a non-existent endpoint. + * [[RpcEndpoint.self]]. Return `null` if the corresponding [[RpcEndpointRef]] does not exist. */ private[rpc] def endpointRef(endpoint: RpcEndpoint): RpcEndpointRef @@ -58,20 +57,6 @@ private[spark] abstract class RpcEnv(conf: SparkConf) { */ def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef - /** - * Register a [[RpcEndpoint]] with a name and return its [[RpcEndpointRef]]. [[RpcEnv]] should - * make sure thread-safely sending messages to [[RpcEndpoint]]. - * - * Thread-safety means processing of one message happens before processing of the next message by - * the same [[RpcEndpoint]]. In the other words, changes to internal fields of a [[RpcEndpoint]] - * are visible when processing the next message, and fields in the [[RpcEndpoint]] need not be - * volatile or equivalent. - * - * However, there is no guarantee that the same thread will be executing the same [[RpcEndpoint]] - * for different messages. - */ - def setupThreadSafeEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef - /** * Retrieve the [[RpcEndpointRef]] represented by `uri` asynchronously. */ @@ -181,7 +166,7 @@ private[spark] trait RpcEnvFactory { * constructor onStart receive* onStop * * Note: `receive` can be called concurrently. If you want `receive` is thread-safe, please use - * [[RpcEnv.setupThreadSafeEndpoint]] + * [[ThreadSafeRpcEndpoint]] * * If any error is thrown from one of [[RpcEndpoint]] methods except `onError`, `onError` will be * invoked with the cause. If `onError` throws an error, [[RpcEnv]] will ignore it. @@ -195,7 +180,7 @@ private[spark] trait RpcEndpoint { /** * The [[RpcEndpointRef]] of this [[RpcEndpoint]]. `self` will become valid when `onStart` is - * called. + * called. And `self` will become `null` when `onStop` is called. * * Note: Because before `onStart`, [[RpcEndpoint]] has not yet been registered and there is not * valid [[RpcEndpointRef]] for it. So don't call `self` before `onStart` is called. @@ -278,6 +263,19 @@ private[spark] trait RpcEndpoint { } } +/** + * A trait that requires RpcEnv thread-safely sending messages to it. + * + * Thread-safety means processing of one message happens before processing of the next message by + * the same [[ThreadSafeRpcEndpoint]]. In the other words, changes to internal fields of a + * [[ThreadSafeRpcEndpoint]] are visible when processing the next message, and fields in the + * [[ThreadSafeRpcEndpoint]] need not be volatile or equivalent. + * + * However, there is no guarantee that the same thread will be executing the same + * [[ThreadSafeRpcEndpoint]] for different messages. + */ +trait ThreadSafeRpcEndpoint extends RpcEndpoint + /** * A reference for a remote [[RpcEndpoint]]. [[RpcEndpointRef]] is thread-safe. */ @@ -407,7 +405,8 @@ private[spark] object RpcAddress { } /** - * A callback that [[RpcEndpoint]] can use it to send back a message or failure. + * A callback that [[RpcEndpoint]] can use it to send back a message or failure. It's thread-safe + * and can be called in any thread. */ private[spark] trait RpcCallContext { diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 769d59b7b3343..652e52f2b2e73 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -17,16 +17,16 @@ package org.apache.spark.rpc.akka -import java.net.URI import java.util.concurrent.ConcurrentHashMap -import scala.concurrent.{Await, Future} +import scala.concurrent.Future import scala.concurrent.duration._ import scala.language.postfixOps import scala.reflect.ClassTag import scala.util.control.NonFatal import akka.actor.{ActorSystem, ExtendedActorSystem, Actor, ActorRef, Props, Address} +import akka.event.Logging.Error import akka.pattern.{ask => akkaAsk} import akka.remote.{AssociationEvent, AssociatedEvent, DisassociatedEvent, AssociationErrorEvent} import org.apache.spark.{SparkException, Logging, SparkConf} @@ -82,17 +82,9 @@ private[spark] class AkkaRpcEnv private[akka] ( /** * Retrieve the [[RpcEndpointRef]] of `endpoint`. */ - override def endpointRef(endpoint: RpcEndpoint): RpcEndpointRef = { - val endpointRef = endpointToRef.get(endpoint) - require(endpointRef != null, s"Cannot find RpcEndpointRef of ${endpoint} in ${this}") - endpointRef - } + override def endpointRef(endpoint: RpcEndpoint): RpcEndpointRef = endpointToRef.get(endpoint) override def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = { - setupThreadSafeEndpoint(name, endpoint) - } - - override def setupThreadSafeEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = { @volatile var endpointRef: AkkaRpcEndpointRef = null // Use lazy because the Actor needs to use `endpointRef`. // So `actorRef` should be created after assigning `endpointRef`. @@ -250,10 +242,25 @@ private[spark] class AkkaRpcEnvFactory extends RpcEnvFactory { def create(config: RpcEnvConfig): RpcEnv = { val (actorSystem, boundPort) = AkkaUtils.createActorSystem( config.name, config.host, config.port, config.conf, config.securityManager) + actorSystem.actorOf(Props(classOf[ErrorMonitor]), "ErrorMonitor") new AkkaRpcEnv(actorSystem, config.conf, boundPort) } } +/** + * Monitor errors reported by Akka and log them. + */ +private[akka] class ErrorMonitor extends Actor with ActorLogReceive with Logging { + + override def preStart(): Unit = { + context.system.eventStream.subscribe(self, classOf[Error]) + } + + override def receiveWithLogging: Actor.Receive = { + case Error(cause: Throwable, _, _, message: String) => logError(message, cause) + } +} + private[akka] class AkkaRpcEndpointRef( @transient defaultAddress: RpcAddress, @transient _actorRef: => ActorRef, diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index d35b4f9dbaf88..508fe7b3303ca 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -23,14 +23,11 @@ import java.util.concurrent.{TimeUnit, Executors} import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map, Stack} -import scala.concurrent.Await import scala.concurrent.duration._ +import scala.language.existentials import scala.language.postfixOps import scala.util.control.NonFatal -import akka.pattern.ask -import akka.util.Timeout - import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.executor.TaskMetrics @@ -53,6 +50,10 @@ import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat * not caused by shuffle file loss are handled by the TaskScheduler, which will retry each task * a small number of times before cancelling the whole stage. * + * Here's a checklist to use when making or reviewing changes to this class: + * + * - When adding a new data structure, update `DAGSchedulerSuite.assertDataStructuresEmpty` to + * include the new structure. This will help to catch memory leaks. */ private[spark] class DAGScheduler( @@ -114,6 +115,8 @@ class DAGScheduler( // stray messages to detect. private val failedEpoch = new HashMap[String, Long] + private [scheduler] val outputCommitCoordinator = env.outputCommitCoordinator + // A closure serializer that we reuse. // This is only safe because DAGScheduler runs in a single thread. private val closureSerializer = SparkEnv.get.closureSerializer.newInstance() @@ -131,8 +134,6 @@ class DAGScheduler( private[scheduler] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this) taskScheduler.setDAGScheduler(this) - private val outputCommitCoordinator = env.outputCommitCoordinator - // Called by TaskScheduler to report task's starting. def taskStarted(task: Task[_], taskInfo: TaskInfo) { eventProcessLoop.post(BeginEvent(task, taskInfo)) @@ -165,11 +166,8 @@ class DAGScheduler( taskMetrics: Array[(Long, Int, Int, TaskMetrics)], // (taskId, stageId, stateAttempt, metrics) blockManagerId: BlockManagerId): Boolean = { listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, taskMetrics)) - implicit val timeout = Timeout(600 seconds) - - Await.result( - blockManagerMaster.driverActor ? BlockManagerHeartbeat(blockManagerId), - timeout.duration).asInstanceOf[Boolean] + blockManagerMaster.driverEndpoint.askWithReply[Boolean]( + BlockManagerHeartbeat(blockManagerId), 600 seconds) } // Called by TaskScheduler when an executor fails. @@ -493,7 +491,7 @@ class DAGScheduler( callSite: CallSite, allowLocal: Boolean, resultHandler: (Int, U) => Unit, - properties: Properties = null): JobWaiter[U] = { + properties: Properties): JobWaiter[U] = { // Check to make sure we are not launching a task on a partition that does not exist. val maxPartitions = rdd.partitions.length partitions.find(p => p >= maxPartitions || p < 0).foreach { p => @@ -522,7 +520,7 @@ class DAGScheduler( callSite: CallSite, allowLocal: Boolean, resultHandler: (Int, U) => Unit, - properties: Properties = null): Unit = { + properties: Properties): Unit = { val start = System.nanoTime val waiter = submitJob(rdd, func, partitions, callSite, allowLocal, resultHandler, properties) waiter.awaitResult() match { @@ -542,7 +540,7 @@ class DAGScheduler( evaluator: ApproximateEvaluator[U, R], callSite: CallSite, timeout: Long, - properties: Properties = null): PartialResult[R] = { + properties: Properties): PartialResult[R] = { val listener = new ApproximateActionListener(rdd, func, evaluator, timeout) val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] val partitions = (0 until rdd.partitions.size).toArray @@ -647,13 +645,13 @@ class DAGScheduler( val split = rdd.partitions(job.partitions(0)) val taskContext = new TaskContextImpl(job.finalStage.id, job.partitions(0), taskAttemptId = 0, attemptNumber = 0, runningLocally = true) - TaskContextHelper.setTaskContext(taskContext) + TaskContext.setTaskContext(taskContext) try { val result = job.func(taskContext, rdd.iterator(split, taskContext)) job.listener.taskSucceeded(0, result) } finally { taskContext.markTaskCompleted() - TaskContextHelper.unset() + TaskContext.unset() } } catch { case e: Exception => @@ -689,7 +687,7 @@ class DAGScheduler( // Cancel all jobs belonging to this job group. // First finds all active jobs with this group id, and then kill stages for them. val activeInGroup = activeJobs.filter(activeJob => - groupId == activeJob.properties.get(SparkContext.SPARK_JOB_GROUP_ID)) + Option(activeJob.properties).exists(_.get(SparkContext.SPARK_JOB_GROUP_ID) == groupId)) val jobIds = activeInGroup.map(_.jobId) jobIds.foreach(handleJobCancellation(_, "part of cancelled job group %s".format(groupId))) submitWaitingStages() @@ -716,9 +714,10 @@ class DAGScheduler( // cancelling the stages because if the DAG scheduler is stopped, the entire application // is in the process of getting stopped. val stageFailedMessage = "Stage cancelled because SparkContext was shut down" - runningStages.foreach { stage => - stage.latestInfo.stageFailed(stageFailedMessage) - listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) + // The `toArray` here is necessary so that we don't iterate over `runningStages` while + // mutating it. + runningStages.toArray.foreach { stage => + markStageAsFinished(stage, Some(stageFailedMessage)) } listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobFailed(error))) } @@ -736,7 +735,7 @@ class DAGScheduler( allowLocal: Boolean, callSite: CallSite, listener: JobListener, - properties: Properties = null) { + properties: Properties) { var finalStage: ResultStage = null try { // New stage creation may throw an exception if, for example, jobs are run on a @@ -893,10 +892,9 @@ class DAGScheduler( new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties)) stage.latestInfo.submissionTime = Some(clock.getTimeMillis()) } else { - // Because we posted SparkListenerStageSubmitted earlier, we should post - // SparkListenerStageCompleted here in case there are no tasks to run. - outputCommitCoordinator.stageEnd(stage.id) - listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) + // Because we posted SparkListenerStageSubmitted earlier, we should mark + // the stage as completed here in case there are no tasks to run + markStageAsFinished(stage, None) val debugString = stage match { case stage: ShuffleMapStage => @@ -908,7 +906,6 @@ class DAGScheduler( s"Stage ${stage} is actually done; (partitions: ${stage.numPartitions})" } logDebug(debugString) - runningStages -= stage } } @@ -974,22 +971,6 @@ class DAGScheduler( } val stage = stageIdToStage(task.stageId) - - def markStageAsFinished(stage: Stage, errorMessage: Option[String] = None): Unit = { - val serviceTime = stage.latestInfo.submissionTime match { - case Some(t) => "%.03f".format((clock.getTimeMillis() - t) / 1000.0) - case _ => "Unknown" - } - if (errorMessage.isEmpty) { - logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime)) - stage.latestInfo.completionTime = Some(clock.getTimeMillis()) - } else { - stage.latestInfo.stageFailed(errorMessage.get) - logInfo("%s (%s) failed in %s s".format(stage, stage.name, serviceTime)) - } - listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) - runningStages -= stage - } event.reason match { case Success => listenerBus.post(SparkListenerTaskEnd(stageId, stage.latestInfo.attemptId, taskType, @@ -1105,7 +1086,6 @@ class DAGScheduler( logInfo(s"Marking $failedStage (${failedStage.name}) as failed " + s"due to a fetch failure from $mapStage (${mapStage.name})") markStageAsFinished(failedStage, Some(failureMessage)) - runningStages -= failedStage } if (disallowStageRetryForTest) { @@ -1221,6 +1201,26 @@ class DAGScheduler( submitWaitingStages() } + /** + * Marks a stage as finished and removes it from the list of running stages. + */ + private def markStageAsFinished(stage: Stage, errorMessage: Option[String] = None): Unit = { + val serviceTime = stage.latestInfo.submissionTime match { + case Some(t) => "%.03f".format((clock.getTimeMillis() - t) / 1000.0) + case _ => "Unknown" + } + if (errorMessage.isEmpty) { + logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime)) + stage.latestInfo.completionTime = Some(clock.getTimeMillis()) + } else { + stage.latestInfo.stageFailed(errorMessage.get) + logInfo("%s (%s) failed in %s s".format(stage, stage.name, serviceTime)) + } + outputCommitCoordinator.stageEnd(stage.id) + listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) + runningStages -= stage + } + /** * Aborts all jobs depending on a particular Stage. This is called in response to a task set * being canceled by the TaskScheduler. Use taskSetFailed() to inject this event from outside. @@ -1270,8 +1270,7 @@ class DAGScheduler( if (runningStages.contains(stage)) { try { // cancelTasks will fail if a SchedulerBackend does not implement killTask taskScheduler.cancelTasks(stageId, shouldInterruptThread) - stage.latestInfo.stageFailed(failureReason) - listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) + markStageAsFinished(stage, Some(failureReason)) } catch { case e: UnsupportedOperationException => logInfo(s"Could not cancel tasks for stage $stageId", e) diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index c0d889360ae99..08e7727db2fde 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -47,21 +47,21 @@ import org.apache.spark.util.{JsonProtocol, Utils} */ private[spark] class EventLoggingListener( appId: String, - logBaseDir: String, + logBaseDir: URI, sparkConf: SparkConf, hadoopConf: Configuration) extends SparkListener with Logging { import EventLoggingListener._ - def this(appId: String, logBaseDir: String, sparkConf: SparkConf) = + def this(appId: String, logBaseDir: URI, sparkConf: SparkConf) = this(appId, logBaseDir, sparkConf, SparkHadoopUtil.get.newConfiguration(sparkConf)) private val shouldCompress = sparkConf.getBoolean("spark.eventLog.compress", false) private val shouldOverwrite = sparkConf.getBoolean("spark.eventLog.overwrite", false) private val testing = sparkConf.getBoolean("spark.eventLog.testing", false) private val outputBufferSize = sparkConf.getInt("spark.eventLog.buffer.kb", 100) * 1024 - private val fileSystem = Utils.getHadoopFileSystem(new URI(logBaseDir), hadoopConf) + private val fileSystem = Utils.getHadoopFileSystem(logBaseDir, hadoopConf) private val compressionCodec = if (shouldCompress) { Some(CompressionCodec.createCodec(sparkConf)) @@ -259,13 +259,13 @@ private[spark] object EventLoggingListener extends Logging { * @return A path which consists of file-system-safe characters. */ def getLogPath( - logBaseDir: String, + logBaseDir: URI, appId: String, compressionCodecName: Option[String] = None): String = { val sanitizedAppId = appId.replaceAll("[ :/]", "-").replaceAll("[.${}'\"]", "_").toLowerCase // e.g. app_123, app_123.lzf val logName = sanitizedAppId + compressionCodecName.map { "." + _ }.getOrElse("") - Utils.resolveURI(logBaseDir).toString.stripSuffix("/") + "/" + logName + logBaseDir.toString.stripSuffix("/") + "/" + logName } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala index 9e29fd13821dc..7c184b1dcb308 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala @@ -59,6 +59,13 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf) extends Logging { private val authorizedCommittersByStage: CommittersByStageMap = mutable.Map() private type CommittersByStageMap = mutable.Map[StageId, mutable.Map[PartitionId, TaskAttemptId]] + /** + * Returns whether the OutputCommitCoordinator's internal data structures are all empty. + */ + def isEmpty: Boolean = { + authorizedCommittersByStage.isEmpty + } + /** * Called by tasks to ask whether they can commit their output to HDFS. * diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index fd0d484b45460..6c7d00069acb2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -33,7 +33,7 @@ import org.apache.spark.shuffle.ShuffleWriter * See [[org.apache.spark.scheduler.Task]] for more information. * * @param stageId id of the stage this task belongs to - * @param taskBinary broadcast version of of the RDD and the ShuffleDependency. Once deserialized, + * @param taskBinary broadcast version of the RDD and the ShuffleDependency. Once deserialized, * the type should be (RDD[_], ShuffleDependency[_, _, _]). * @param partition partition of the RDD this task is associated with * @param locs preferred task execution locations for locality scheduling diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 4d9f940813b8e..8b592867ee31d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -22,7 +22,7 @@ import java.nio.ByteBuffer import scala.collection.mutable.HashMap -import org.apache.spark.{TaskContextHelper, TaskContextImpl, TaskContext} +import org.apache.spark.{TaskContextImpl, TaskContext} import org.apache.spark.executor.TaskMetrics import org.apache.spark.serializer.SerializerInstance import org.apache.spark.util.ByteBufferInputStream @@ -54,7 +54,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex final def run(taskAttemptId: Long, attemptNumber: Int): T = { context = new TaskContextImpl(stageId = stageId, partitionId = partitionId, taskAttemptId = taskAttemptId, attemptNumber = attemptNumber, runningLocally = false) - TaskContextHelper.setTaskContext(context) + TaskContext.setTaskContext(context) context.taskMetrics.setHostname(Utils.localHostName()) taskThread = Thread.currentThread() if (_killed) { @@ -64,7 +64,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex runTask(context) } finally { context.markTaskCompleted() - TaskContextHelper.unset() + TaskContext.unset() } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 9bf74f4be198d..70364cea62a80 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -20,6 +20,7 @@ package org.apache.spark.scheduler.cluster import java.nio.ByteBuffer import org.apache.spark.TaskState.TaskState +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.{SerializableBuffer, Utils} private[spark] sealed trait CoarseGrainedClusterMessage extends Serializable @@ -41,6 +42,7 @@ private[spark] object CoarseGrainedClusterMessages { // Executors to driver case class RegisterExecutor( executorId: String, + executorRef: RpcEndpointRef, hostPort: String, cores: Int, logUrls: Map[String, String]) @@ -70,6 +72,8 @@ private[spark] object CoarseGrainedClusterMessages { case class RemoveExecutor(executorId: String, reason: String) extends CoarseGrainedClusterMessage + case class SetupDriver(driver: RpcEndpointRef) extends CoarseGrainedClusterMessage + // Exchanged between the driver and the AM in Yarn client mode case class AddWebUIFilter(filterName:String, filterParams: Map[String, String], proxyBase: String) extends CoarseGrainedClusterMessage @@ -77,7 +81,7 @@ private[spark] object CoarseGrainedClusterMessages { // Messages exchanged between the driver and the cluster manager for executor allocation // In Yarn mode, these are exchanged between the driver and the AM - case object RegisterClusterManager extends CoarseGrainedClusterMessage + case class RegisterClusterManager(am: RpcEndpointRef) extends CoarseGrainedClusterMessage // Request executors by specifying the new total number of executors desired // This includes executors already pending or running diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 5d258d9da4d1a..4c49da87af9dc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -17,20 +17,16 @@ package org.apache.spark.scheduler.cluster +import java.util.concurrent.{TimeUnit, Executors} import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} -import scala.concurrent.Await -import scala.concurrent.duration._ - -import akka.actor._ -import akka.pattern.ask -import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} +import org.apache.spark.rpc._ import org.apache.spark.{ExecutorAllocationClient, Logging, SparkEnv, SparkException, TaskState} import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ -import org.apache.spark.util.{ActorLogReceive, SerializableBuffer, AkkaUtils, Utils} +import org.apache.spark.util.{SerializableBuffer, AkkaUtils, Utils} /** * A scheduler backend that waits for coarse grained executors to connect to it through Akka. @@ -41,7 +37,7 @@ import org.apache.spark.util.{ActorLogReceive, SerializableBuffer, AkkaUtils, Ut * (spark.deploy.*). */ private[spark] -class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSystem: ActorSystem) +class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: RpcEnv) extends ExecutorAllocationClient with SchedulerBackend with Logging { // Use an atomic variable to track total number of cores in the cluster for simplicity and speed @@ -49,7 +45,6 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste // Total number of executors that are currently registered var totalRegisteredExecutors = new AtomicInteger(0) val conf = scheduler.sc.conf - private val timeout = AkkaUtils.askTimeout(conf) private val akkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf) // Submit tasks only after (registered resources / total expected resources) // is equal to at least this value, that is double between 0 and 1. @@ -71,48 +66,26 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste // Executors we have requested the cluster manager to kill that have not died yet private val executorsPendingToRemove = new HashSet[String] - class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor with ActorLogReceive { + class DriverEndpoint(override val rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)]) + extends ThreadSafeRpcEndpoint with Logging { override protected def log = CoarseGrainedSchedulerBackend.this.log - private val addressToExecutorId = new HashMap[Address, String] - override def preStart() { - // Listen for remote client disconnection events, since they don't go through Akka's watch() - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + private val addressToExecutorId = new HashMap[RpcAddress, String] + + private val reviveThread = + Executors.newSingleThreadScheduledExecutor(Utils.namedThreadFactory("driver-revive-thread")) + override def onStart() { // Periodically revive offers to allow delay scheduling to work val reviveInterval = conf.getLong("spark.scheduler.revive.interval", 1000) - import context.dispatcher - context.system.scheduler.schedule(0.millis, reviveInterval.millis, self, ReviveOffers) - } - - def receiveWithLogging: PartialFunction[Any, Unit] = { - case RegisterExecutor(executorId, hostPort, cores, logUrls) => - Utils.checkHostPort(hostPort, "Host port expected " + hostPort) - if (executorDataMap.contains(executorId)) { - sender ! RegisterExecutorFailed("Duplicate executor ID: " + executorId) - } else { - logInfo("Registered executor: " + sender + " with ID " + executorId) - sender ! RegisteredExecutor - - addressToExecutorId(sender.path.address) = executorId - totalCoreCount.addAndGet(cores) - totalRegisteredExecutors.addAndGet(1) - val (host, _) = Utils.parseHostPort(hostPort) - val data = new ExecutorData(sender, sender.path.address, host, cores, cores, logUrls) - // This must be synchronized because variables mutated - // in this block are read when requesting executors - CoarseGrainedSchedulerBackend.this.synchronized { - executorDataMap.put(executorId, data) - if (numPendingExecutors > 0) { - numPendingExecutors -= 1 - logDebug(s"Decremented number of pending executors ($numPendingExecutors left)") - } - } - listenerBus.post( - SparkListenerExecutorAdded(System.currentTimeMillis(), executorId, data)) - makeOffers() + reviveThread.scheduleAtFixedRate(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + Option(self).foreach(_.send(ReviveOffers)) } + }, 0, reviveInterval, TimeUnit.MILLISECONDS) + } + override def receive: PartialFunction[Any, Unit] = { case StatusUpdate(executorId, taskId, state, data) => scheduler.statusUpdate(taskId, state, data.value) if (TaskState.isFinished(state)) { @@ -133,33 +106,58 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste case KillTask(taskId, executorId, interruptThread) => executorDataMap.get(executorId) match { case Some(executorInfo) => - executorInfo.executorActor ! KillTask(taskId, executorId, interruptThread) + executorInfo.executorEndpoint.send(KillTask(taskId, executorId, interruptThread)) case None => // Ignoring the task kill since the executor is not registered. logWarning(s"Attempted to kill task $taskId for unknown executor $executorId.") } + } + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case RegisterExecutor(executorId, executorRef, hostPort, cores, logUrls) => + Utils.checkHostPort(hostPort, "Host port expected " + hostPort) + if (executorDataMap.contains(executorId)) { + context.reply(RegisterExecutorFailed("Duplicate executor ID: " + executorId)) + } else { + logInfo("Registered executor: " + executorRef + " with ID " + executorId) + context.reply(RegisteredExecutor) + + addressToExecutorId(executorRef.address) = executorId + totalCoreCount.addAndGet(cores) + totalRegisteredExecutors.addAndGet(1) + val (host, _) = Utils.parseHostPort(hostPort) + val data = new ExecutorData(executorRef, executorRef.address, host, cores, cores, logUrls) + // This must be synchronized because variables mutated + // in this block are read when requesting executors + CoarseGrainedSchedulerBackend.this.synchronized { + executorDataMap.put(executorId, data) + if (numPendingExecutors > 0) { + numPendingExecutors -= 1 + logDebug(s"Decremented number of pending executors ($numPendingExecutors left)") + } + } + listenerBus.post( + SparkListenerExecutorAdded(System.currentTimeMillis(), executorId, data)) + makeOffers() + } case StopDriver => - sender ! true - context.stop(self) + context.reply(true) + stop() case StopExecutors => logInfo("Asking each executor to shut down") for ((_, executorData) <- executorDataMap) { - executorData.executorActor ! StopExecutor + executorData.executorEndpoint.send(StopExecutor) } - sender ! true + context.reply(true) case RemoveExecutor(executorId, reason) => removeExecutor(executorId, reason) - sender ! true - - case DisassociatedEvent(_, address, _) => - addressToExecutorId.get(address).foreach(removeExecutor(_, - "remote Akka client disassociated")) + context.reply(true) case RetrieveSparkProps => - sender ! sparkProperties + context.reply(sparkProperties) } // Make fake resource offers on all executors @@ -169,6 +167,11 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste }.toSeq)) } + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + addressToExecutorId.get(remoteAddress).foreach(removeExecutor(_, + "remote Rpc client disassociated")) + } + // Make fake resource offers on just one executor def makeOffers(executorId: String) { val executorData = executorDataMap(executorId) @@ -199,7 +202,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste else { val executorData = executorDataMap(task.executorId) executorData.freeCores -= scheduler.CPUS_PER_TASK - executorData.executorActor ! LaunchTask(new SerializableBuffer(serializedTask)) + executorData.executorEndpoint.send(LaunchTask(new SerializableBuffer(serializedTask))) } } } @@ -223,9 +226,13 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste case None => logError(s"Asked to remove non-existent executor $executorId") } } + + override def onStop() { + reviveThread.shutdownNow() + } } - var driverActor: ActorRef = null + var driverEndpoint: RpcEndpointRef = null val taskIdsOnSlave = new HashMap[String, HashSet[String]] override def start() { @@ -236,16 +243,15 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste } } // TODO (prashant) send conf instead of properties - driverActor = actorSystem.actorOf( - Props(new DriverActor(properties)), name = CoarseGrainedSchedulerBackend.ACTOR_NAME) + driverEndpoint = rpcEnv.setupEndpoint( + CoarseGrainedSchedulerBackend.ENDPOINT_NAME, new DriverEndpoint(rpcEnv, properties)) } def stopExecutors() { try { - if (driverActor != null) { + if (driverEndpoint != null) { logInfo("Shutting down all executors") - val future = driverActor.ask(StopExecutors)(timeout) - Await.ready(future, timeout) + driverEndpoint.askWithReply[Boolean](StopExecutors) } } catch { case e: Exception => @@ -256,22 +262,21 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste override def stop() { stopExecutors() try { - if (driverActor != null) { - val future = driverActor.ask(StopDriver)(timeout) - Await.ready(future, timeout) + if (driverEndpoint != null) { + driverEndpoint.askWithReply[Boolean](StopDriver) } } catch { case e: Exception => - throw new SparkException("Error stopping standalone scheduler's driver actor", e) + throw new SparkException("Error stopping standalone scheduler's driver endpoint", e) } } override def reviveOffers() { - driverActor ! ReviveOffers + driverEndpoint.send(ReviveOffers) } override def killTask(taskId: Long, executorId: String, interruptThread: Boolean) { - driverActor ! KillTask(taskId, executorId, interruptThread) + driverEndpoint.send(KillTask(taskId, executorId, interruptThread)) } override def defaultParallelism(): Int = { @@ -281,11 +286,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste // Called by subclasses when notified of a lost worker def removeExecutor(executorId: String, reason: String) { try { - val future = driverActor.ask(RemoveExecutor(executorId, reason))(timeout) - Await.ready(future, timeout) + driverEndpoint.askWithReply[Boolean](RemoveExecutor(executorId, reason)) } catch { case e: Exception => - throw new SparkException("Error notifying standalone scheduler's driver actor", e) + throw new SparkException("Error notifying standalone scheduler's driver endpoint", e) } } @@ -391,5 +395,5 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste } private[spark] object CoarseGrainedSchedulerBackend { - val ACTOR_NAME = "CoarseGrainedScheduler" + val ENDPOINT_NAME = "CoarseGrainedScheduler" } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala index 5e571efe76720..26e72c0bff38d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala @@ -17,20 +17,20 @@ package org.apache.spark.scheduler.cluster -import akka.actor.{Address, ActorRef} +import org.apache.spark.rpc.{RpcEndpointRef, RpcAddress} /** * Grouping of data for an executor used by CoarseGrainedSchedulerBackend. * - * @param executorActor The ActorRef representing this executor + * @param executorEndpoint The ActorRef representing this executor * @param executorAddress The network address of this executor * @param executorHost The hostname that this executor is running on * @param freeCores The current number of cores available for work on the executor * @param totalCores The total number of cores available to the executor */ private[cluster] class ExecutorData( - val executorActor: ActorRef, - val executorAddress: Address, + val executorEndpoint: RpcEndpointRef, + val executorAddress: RpcAddress, override val executorHost: String, var freeCores: Int, override val totalCores: Int, diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala index 06786a59524e7..0324c9dab910b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala @@ -19,16 +19,16 @@ package org.apache.spark.scheduler.cluster import org.apache.hadoop.fs.{Path, FileSystem} +import org.apache.spark.rpc.RpcAddress import org.apache.spark.{Logging, SparkContext, SparkEnv} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.scheduler.TaskSchedulerImpl -import org.apache.spark.util.AkkaUtils private[spark] class SimrSchedulerBackend( scheduler: TaskSchedulerImpl, sc: SparkContext, driverFilePath: String) - extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) + extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) with Logging { val tmpPath = new Path(driverFilePath + "_tmp") @@ -39,12 +39,9 @@ private[spark] class SimrSchedulerBackend( override def start() { super.start() - val driverUrl = AkkaUtils.address( - AkkaUtils.protocol(actorSystem), - SparkEnv.driverActorSystemName, - sc.conf.get("spark.driver.host"), - sc.conf.get("spark.driver.port"), - CoarseGrainedSchedulerBackend.ACTOR_NAME) + val driverUrl = rpcEnv.uriOf(SparkEnv.driverActorSystemName, + RpcAddress(sc.conf.get("spark.driver.host"), sc.conf.get("spark.driver.port").toInt), + CoarseGrainedSchedulerBackend.ENDPOINT_NAME) val conf = SparkHadoopUtil.get.newConfiguration(sc.conf) val fs = FileSystem.get(conf) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index ffd4825705755..7eb3fdc19b5b8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -19,17 +19,18 @@ package org.apache.spark.scheduler.cluster import java.util.concurrent.Semaphore +import org.apache.spark.rpc.RpcAddress import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv} import org.apache.spark.deploy.{ApplicationDescription, Command} import org.apache.spark.deploy.client.{AppClient, AppClientListener} import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SlaveLost, TaskSchedulerImpl} -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.util.Utils private[spark] class SparkDeploySchedulerBackend( scheduler: TaskSchedulerImpl, sc: SparkContext, masters: Array[String]) - extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) + extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) with AppClientListener with Logging { @@ -48,12 +49,9 @@ private[spark] class SparkDeploySchedulerBackend( super.start() // The endpoint for executors to talk to us - val driverUrl = AkkaUtils.address( - AkkaUtils.protocol(actorSystem), - SparkEnv.driverActorSystemName, - conf.get("spark.driver.host"), - conf.get("spark.driver.port"), - CoarseGrainedSchedulerBackend.ACTOR_NAME) + val driverUrl = rpcEnv.uriOf(SparkEnv.driverActorSystemName, + RpcAddress(sc.conf.get("spark.driver.host"), sc.conf.get("spark.driver.port").toInt), + CoarseGrainedSchedulerBackend.ENDPOINT_NAME) val args = Seq( "--driver-url", driverUrl, "--executor-id", "{{EXECUTOR_ID}}", diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index 5a38ad9f2b12c..f72566c370a6f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -19,10 +19,8 @@ package org.apache.spark.scheduler.cluster import scala.concurrent.{Future, ExecutionContext} -import akka.actor.{Actor, ActorRef, Props} -import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} - -import org.apache.spark.SparkContext +import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.rpc._ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.ui.JettyUtils @@ -37,7 +35,7 @@ import scala.util.control.NonFatal private[spark] abstract class YarnSchedulerBackend( scheduler: TaskSchedulerImpl, sc: SparkContext) - extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) { + extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) { if (conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) { minRegisteredRatio = 0.8 @@ -45,10 +43,8 @@ private[spark] abstract class YarnSchedulerBackend( protected var totalExpectedExecutors = 0 - private val yarnSchedulerActor: ActorRef = - actorSystem.actorOf( - Props(new YarnSchedulerActor), - name = YarnSchedulerBackend.ACTOR_NAME) + private val yarnSchedulerEndpoint = rpcEnv.setupEndpoint( + YarnSchedulerBackend.ENDPOINT_NAME, new YarnSchedulerEndpoint(rpcEnv)) private implicit val askTimeout = AkkaUtils.askTimeout(sc.conf) @@ -57,16 +53,14 @@ private[spark] abstract class YarnSchedulerBackend( * This includes executors already pending or running. */ override def doRequestTotalExecutors(requestedTotal: Int): Boolean = { - AkkaUtils.askWithReply[Boolean]( - RequestExecutors(requestedTotal), yarnSchedulerActor, askTimeout) + yarnSchedulerEndpoint.askWithReply[Boolean](RequestExecutors(requestedTotal)) } /** * Request that the ApplicationMaster kill the specified executors. */ override def doKillExecutors(executorIds: Seq[String]): Boolean = { - AkkaUtils.askWithReply[Boolean]( - KillExecutors(executorIds), yarnSchedulerActor, askTimeout) + yarnSchedulerEndpoint.askWithReply[Boolean](KillExecutors(executorIds)) } override def sufficientResourcesRegistered(): Boolean = { @@ -96,64 +90,71 @@ private[spark] abstract class YarnSchedulerBackend( } /** - * An actor that communicates with the ApplicationMaster. + * An [[RpcEndpoint]] that communicates with the ApplicationMaster. */ - private class YarnSchedulerActor extends Actor { - private var amActor: Option[ActorRef] = None - - implicit val askAmActorExecutor = ExecutionContext.fromExecutor( - Utils.newDaemonCachedThreadPool("yarn-scheduler-ask-am-executor")) + private class YarnSchedulerEndpoint(override val rpcEnv: RpcEnv) + extends ThreadSafeRpcEndpoint with Logging { + private var amEndpoint: Option[RpcEndpointRef] = None - override def preStart(): Unit = { - // Listen for disassociation events - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) - } + private val askAmThreadPool = + Utils.newDaemonCachedThreadPool("yarn-scheduler-ask-am-thread-pool") + implicit val askAmExecutor = ExecutionContext.fromExecutor(askAmThreadPool) override def receive: PartialFunction[Any, Unit] = { - case RegisterClusterManager => - logInfo(s"ApplicationMaster registered as $sender") - amActor = Some(sender) + case RegisterClusterManager(am) => + logInfo(s"ApplicationMaster registered as $am") + amEndpoint = Some(am) + + case AddWebUIFilter(filterName, filterParams, proxyBase) => + addWebUIFilter(filterName, filterParams, proxyBase) + + } + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case r: RequestExecutors => - amActor match { - case Some(actor) => - val driverActor = sender + amEndpoint match { + case Some(am) => Future { - driverActor ! AkkaUtils.askWithReply[Boolean](r, actor, askTimeout) + context.reply(am.askWithReply[Boolean](r)) } onFailure { - case NonFatal(e) => logError(s"Sending $r to AM was unsuccessful", e) + case NonFatal(e) => + logError(s"Sending $r to AM was unsuccessful", e) + context.sendFailure(e) } case None => logWarning("Attempted to request executors before the AM has registered!") - sender ! false + context.reply(false) } case k: KillExecutors => - amActor match { - case Some(actor) => - val driverActor = sender + amEndpoint match { + case Some(am) => Future { - driverActor ! AkkaUtils.askWithReply[Boolean](k, actor, askTimeout) + context.reply(am.askWithReply[Boolean](k)) } onFailure { - case NonFatal(e) => logError(s"Sending $k to AM was unsuccessful", e) + case NonFatal(e) => + logError(s"Sending $k to AM was unsuccessful", e) + context.sendFailure(e) } case None => logWarning("Attempted to kill executors before the AM has registered!") - sender ! false + context.reply(false) } - case AddWebUIFilter(filterName, filterParams, proxyBase) => - addWebUIFilter(filterName, filterParams, proxyBase) - sender ! true + } - case d: DisassociatedEvent => - if (amActor.isDefined && sender == amActor.get) { - logWarning(s"ApplicationMaster has disassociated: $d") - } + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + if (amEndpoint.exists(_.address == remoteAddress)) { + logWarning(s"ApplicationMaster has disassociated: $remoteAddress") + } + } + + override def onStop(): Unit ={ + askAmThreadPool.shutdownNow() } } } private[spark] object YarnSchedulerBackend { - val ACTOR_NAME = "YarnScheduler" + val ENDPOINT_NAME = "YarnScheduler" } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index e13de0f46ef89..b037a4966ced0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -47,7 +47,7 @@ private[spark] class CoarseMesosSchedulerBackend( scheduler: TaskSchedulerImpl, sc: SparkContext, master: String) - extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) + extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) with MScheduler with Logging { @@ -148,7 +148,7 @@ private[spark] class CoarseMesosSchedulerBackend( SparkEnv.driverActorSystemName, conf.get("spark.driver.host"), conf.get("spark.driver.port"), - CoarseGrainedSchedulerBackend.ACTOR_NAME) + CoarseGrainedSchedulerBackend.ENDPOINT_NAME) val uri = conf.get("spark.executor.uri", null) if (uri == null) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index eb3f999b5b375..70a477a6895cc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -18,17 +18,14 @@ package org.apache.spark.scheduler.local import java.nio.ByteBuffer +import java.util.concurrent.{Executors, TimeUnit} -import scala.concurrent.duration._ -import scala.language.postfixOps - -import akka.actor.{Actor, ActorRef, Props} - +import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcEndpointRef, RpcEnv} +import org.apache.spark.util.Utils import org.apache.spark.{Logging, SparkContext, SparkEnv, TaskState} import org.apache.spark.TaskState.TaskState import org.apache.spark.executor.{Executor, ExecutorBackend} import org.apache.spark.scheduler.{SchedulerBackend, TaskSchedulerImpl, WorkerOffer} -import org.apache.spark.util.ActorLogReceive private case class ReviveOffers() @@ -39,17 +36,19 @@ private case class KillTask(taskId: Long, interruptThread: Boolean) private case class StopExecutor() /** - * Calls to LocalBackend are all serialized through LocalActor. Using an actor makes the calls on - * LocalBackend asynchronous, which is necessary to prevent deadlock between LocalBackend + * Calls to LocalBackend are all serialized through LocalEndpoint. Using an RpcEndpoint makes the + * calls on LocalBackend asynchronous, which is necessary to prevent deadlock between LocalBackend * and the TaskSchedulerImpl. */ -private[spark] class LocalActor( +private[spark] class LocalEndpoint( + override val rpcEnv: RpcEnv, scheduler: TaskSchedulerImpl, executorBackend: LocalBackend, private val totalCores: Int) - extends Actor with ActorLogReceive with Logging { + extends ThreadSafeRpcEndpoint with Logging { - import context.dispatcher // to use Akka's scheduler.scheduleOnce() + private val reviveThread = Executors.newSingleThreadScheduledExecutor( + Utils.namedThreadFactory("local-revive-thread")) private var freeCores = totalCores @@ -59,7 +58,7 @@ private[spark] class LocalActor( private val executor = new Executor( localExecutorId, localExecutorHostname, SparkEnv.get, isLocal = true) - override def receiveWithLogging: PartialFunction[Any, Unit] = { + override def receive: PartialFunction[Any, Unit] = { case ReviveOffers => reviveOffers() @@ -87,9 +86,17 @@ private[spark] class LocalActor( } if (tasks.isEmpty && scheduler.activeTaskSets.nonEmpty) { // Try to reviveOffer after 1 second, because scheduler may wait for locality timeout - context.system.scheduler.scheduleOnce(1000 millis, self, ReviveOffers) + reviveThread.schedule(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + Option(self).foreach(_.send(ReviveOffers)) + } + }, 1000, TimeUnit.MILLISECONDS) } } + + override def onStop(): Unit = { + reviveThread.shutdownNow() + } } /** @@ -101,31 +108,30 @@ private[spark] class LocalBackend(scheduler: TaskSchedulerImpl, val totalCores: extends SchedulerBackend with ExecutorBackend { private val appId = "local-" + System.currentTimeMillis - var localActor: ActorRef = null + var localEndpoint: RpcEndpointRef = null override def start() { - localActor = SparkEnv.get.actorSystem.actorOf( - Props(new LocalActor(scheduler, this, totalCores)), - "LocalBackendActor") + localEndpoint = SparkEnv.get.rpcEnv.setupEndpoint( + "LocalBackendEndpoint", new LocalEndpoint(SparkEnv.get.rpcEnv, scheduler, this, totalCores)) } override def stop() { - localActor ! StopExecutor + localEndpoint.send(StopExecutor) } override def reviveOffers() { - localActor ! ReviveOffers + localEndpoint.send(ReviveOffers) } override def defaultParallelism(): Int = scheduler.conf.getInt("spark.default.parallelism", totalCores) override def killTask(taskId: Long, executorId: String, interruptThread: Boolean) { - localActor ! KillTask(taskId, interruptThread) + localEndpoint.send(KillTask(taskId, interruptThread)) } override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) { - localActor ! StatusUpdate(taskId, state, serializedData) + localEndpoint.send(StatusUpdate(taskId, state, serializedData)) } override def applicationId(): String = appId diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala index d0178dfde6935..5be3ed771e534 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala @@ -67,7 +67,7 @@ private[spark] trait ShuffleWriterGroup { // org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getHashBasedShuffleBlockData(). private[spark] class FileShuffleBlockManager(conf: SparkConf) - extends ShuffleBlockManager with Logging { + extends ShuffleBlockResolver with Logging { private val transportConf = SparkTransportConf.fromSparkConf(conf) @@ -175,11 +175,6 @@ class FileShuffleBlockManager(conf: SparkConf) } } - override def getBytes(blockId: ShuffleBlockId): Option[ByteBuffer] = { - val segment = getBlockData(blockId) - Some(segment.nioByteBuffer()) - } - override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = { if (consolidateShuffleFiles) { // Search all file groups associated with this shuffle. diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala index 87fd161e06c85..a1741e2875c16 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala @@ -26,6 +26,9 @@ import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.storage._ +import org.apache.spark.util.Utils + +import IndexShuffleBlockManager.NOOP_REDUCE_ID /** * Create and maintain the shuffle blocks' mapping between logic block and physical file location. @@ -39,25 +42,18 @@ import org.apache.spark.storage._ // Note: Changes to the format in this file should be kept in sync with // org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getSortBasedShuffleBlockData(). private[spark] -class IndexShuffleBlockManager(conf: SparkConf) extends ShuffleBlockManager { +class IndexShuffleBlockManager(conf: SparkConf) extends ShuffleBlockResolver { private lazy val blockManager = SparkEnv.get.blockManager private val transportConf = SparkTransportConf.fromSparkConf(conf) - /** - * Mapping to a single shuffleBlockId with reduce ID 0. - * */ - def consolidateId(shuffleId: Int, mapId: Int): ShuffleBlockId = { - ShuffleBlockId(shuffleId, mapId, 0) - } - def getDataFile(shuffleId: Int, mapId: Int): File = { - blockManager.diskBlockManager.getFile(ShuffleDataBlockId(shuffleId, mapId, 0)) + blockManager.diskBlockManager.getFile(ShuffleDataBlockId(shuffleId, mapId, NOOP_REDUCE_ID)) } private def getIndexFile(shuffleId: Int, mapId: Int): File = { - blockManager.diskBlockManager.getFile(ShuffleIndexBlockId(shuffleId, mapId, 0)) + blockManager.diskBlockManager.getFile(ShuffleIndexBlockId(shuffleId, mapId, NOOP_REDUCE_ID)) } /** @@ -83,24 +79,19 @@ class IndexShuffleBlockManager(conf: SparkConf) extends ShuffleBlockManager { def writeIndexFile(shuffleId: Int, mapId: Int, lengths: Array[Long]): Unit = { val indexFile = getIndexFile(shuffleId, mapId) val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexFile))) - try { + Utils.tryWithSafeFinally { // We take in lengths of each block, need to convert it to offsets. var offset = 0L out.writeLong(offset) - for (length <- lengths) { offset += length out.writeLong(offset) } - } finally { + } { out.close() } } - override def getBytes(blockId: ShuffleBlockId): Option[ByteBuffer] = { - Some(getBlockData(blockId).nioByteBuffer()) - } - override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = { // The block is actually going to be a range of a single map output file for this map, so // find out the consolidated file, then the offset within that from our index @@ -123,3 +114,11 @@ class IndexShuffleBlockManager(conf: SparkConf) extends ShuffleBlockManager { override def stop(): Unit = {} } + +private[spark] object IndexShuffleBlockManager { + // No-op reduce ID used in interactions with disk store and BlockObjectWriter. + // The disk store currently expects puts to relate to a (map, reduce) pair, but in the sort + // shuffle outputs for several reduces are glommed into a single file. + // TODO: Avoid this entirely by having the DiskBlockObjectWriter not require a BlockId. + val NOOP_REDUCE_ID = 0 +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala similarity index 68% rename from core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala rename to core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala index b521f0c7fc77e..4342b0d598b16 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala @@ -22,15 +22,19 @@ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.storage.ShuffleBlockId private[spark] -trait ShuffleBlockManager { +/** + * Implementers of this trait understand how to retrieve block data for a logical shuffle block + * identifier (i.e. map, reduce, and shuffle). Implementations may use files or file segments to + * encapsulate shuffle data. This is used by the BlockStore to abstract over different shuffle + * implementations when shuffle data is retrieved. + */ +trait ShuffleBlockResolver { type ShuffleId = Int /** - * Get shuffle block data managed by the local ShuffleBlockManager. - * @return Some(ByteBuffer) if block found, otherwise None. + * Retrieve the data for the specified block. If the data for that block is not available, + * throws an unspecified exception. */ - def getBytes(blockId: ShuffleBlockId): Option[ByteBuffer] - def getBlockData(blockId: ShuffleBlockId): ManagedBuffer def stop(): Unit diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala index a44a8e1249256..978366d1a1d1b 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala @@ -55,7 +55,10 @@ private[spark] trait ShuffleManager { */ def unregisterShuffle(shuffleId: Int): Boolean - def shuffleBlockManager: ShuffleBlockManager + /** + * Return a resolver capable of retrieving shuffle block data based on block coordinates. + */ + def shuffleBlockResolver: ShuffleBlockResolver /** Shut down this ShuffleManager. */ def stop(): Unit diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala index b934480cfb9be..f6e6fe5defe09 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala @@ -23,7 +23,7 @@ import org.apache.spark.scheduler.MapStatus * Obtained inside a map task to write out records to the shuffle system. */ private[spark] trait ShuffleWriter[K, V] { - /** Write a bunch of records to this task's output */ + /** Write a sequence of records to this task's output */ def write(records: Iterator[_ <: Product2[K, V]]): Unit /** Close this writer, passing along whether the map completed */ diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala index 62e0629b34400..2a7df8dd5bd83 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala @@ -53,20 +53,20 @@ private[spark] class HashShuffleManager(conf: SparkConf) extends ShuffleManager override def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext) : ShuffleWriter[K, V] = { new HashShuffleWriter( - shuffleBlockManager, handle.asInstanceOf[BaseShuffleHandle[K, V, _]], mapId, context) + shuffleBlockResolver, handle.asInstanceOf[BaseShuffleHandle[K, V, _]], mapId, context) } /** Remove a shuffle's metadata from the ShuffleManager. */ override def unregisterShuffle(shuffleId: Int): Boolean = { - shuffleBlockManager.removeShuffle(shuffleId) + shuffleBlockResolver.removeShuffle(shuffleId) } - override def shuffleBlockManager: FileShuffleBlockManager = { + override def shuffleBlockResolver: FileShuffleBlockManager = { fileShuffleBlockManager } /** Shut down this ShuffleManager. */ override def stop(): Unit = { - shuffleBlockManager.stop() + shuffleBlockResolver.stop() } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index bda30a56d808e..0497036192154 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -58,7 +58,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager val baseShuffleHandle = handle.asInstanceOf[BaseShuffleHandle[K, V, _]] shuffleMapNumber.putIfAbsent(baseShuffleHandle.shuffleId, baseShuffleHandle.numMaps) new SortShuffleWriter( - shuffleBlockManager, baseShuffleHandle, mapId, context) + shuffleBlockResolver, baseShuffleHandle, mapId, context) } /** Remove a shuffle's metadata from the ShuffleManager. */ @@ -66,18 +66,19 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager if (shuffleMapNumber.containsKey(shuffleId)) { val numMaps = shuffleMapNumber.remove(shuffleId) (0 until numMaps).map{ mapId => - shuffleBlockManager.removeDataByMap(shuffleId, mapId) + shuffleBlockResolver.removeDataByMap(shuffleId, mapId) } } true } - override def shuffleBlockManager: IndexShuffleBlockManager = { + override def shuffleBlockResolver: IndexShuffleBlockManager = { indexShuffleBlockManager } /** Shut down this ShuffleManager. */ override def stop(): Unit = { - shuffleBlockManager.stop() + shuffleBlockResolver.stop() } } + diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 55ea0f17b156a..a066435df6fb0 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -58,8 +58,7 @@ private[spark] class SortShuffleWriter[K, V, C]( // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't // care whether the keys get sorted in each partition; that will be done on the reduce side // if the operation being run is sortByKey. - sorter = new ExternalSorter[K, V, V]( - None, Some(dep.partitioner), None, dep.serializer) + sorter = new ExternalSorter[K, V, V](None, Some(dep.partitioner), None, dep.serializer) sorter.insertAll(records) } @@ -67,7 +66,7 @@ private[spark] class SortShuffleWriter[K, V, C]( // because it just opens a single file, so is typically too fast to measure accurately // (see SPARK-3570). val outputFile = shuffleBlockManager.getDataFile(dep.shuffleId, mapId) - val blockId = shuffleBlockManager.consolidateId(dep.shuffleId, mapId) + val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockManager.NOOP_REDUCE_ID) val partitionLengths = sorter.writePartitionedFile(blockId, context, outputFile) shuffleBlockManager.writeIndexFile(dep.shuffleId, mapId, partitionLengths) @@ -100,3 +99,4 @@ private[spark] class SortShuffleWriter[K, V, C]( } } } + diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 1dff09a75d038..1aa0ef18de118 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -26,7 +26,6 @@ import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration._ import scala.util.Random -import akka.actor.{ActorSystem, Props} import sun.nio.ch.DirectBuffer import org.apache.spark._ @@ -37,6 +36,7 @@ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.ExternalShuffleClient import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo +import org.apache.spark.rpc.RpcEnv import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.shuffle.hash.HashShuffleManager @@ -64,7 +64,7 @@ private[spark] class BlockResult( */ private[spark] class BlockManager( executorId: String, - actorSystem: ActorSystem, + rpcEnv: RpcEnv, val master: BlockManagerMaster, defaultSerializer: Serializer, maxMemory: Long, @@ -136,9 +136,9 @@ private[spark] class BlockManager( // Whether to compress shuffle output temporarily spilled to disk private val compressShuffleSpill = conf.getBoolean("spark.shuffle.spill.compress", true) - private val slaveActor = actorSystem.actorOf( - Props(new BlockManagerSlaveActor(this, mapOutputTracker)), - name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next) + private val slaveEndpoint = rpcEnv.setupEndpoint( + "BlockManagerEndpoint" + BlockManager.ID_GENERATOR.next, + new BlockManagerSlaveEndpoint(rpcEnv, this, mapOutputTracker)) // Pending re-registration action being executed asynchronously or null if none is pending. // Accesses should synchronize on asyncReregisterLock. @@ -167,7 +167,7 @@ private[spark] class BlockManager( */ def this( execId: String, - actorSystem: ActorSystem, + rpcEnv: RpcEnv, master: BlockManagerMaster, serializer: Serializer, conf: SparkConf, @@ -176,7 +176,7 @@ private[spark] class BlockManager( blockTransferService: BlockTransferService, securityManager: SecurityManager, numUsableCores: Int) = { - this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf), + this(execId, rpcEnv, master, serializer, BlockManager.getMaxMemory(conf), conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager, numUsableCores) } @@ -186,7 +186,7 @@ private[spark] class BlockManager( * where it is only learned after registration with the TaskScheduler). * * This method initializes the BlockTransferService and ShuffleClient, registers with the - * BlockManagerMaster, starts the BlockManagerWorker actor, and registers with a local shuffle + * BlockManagerMaster, starts the BlockManagerWorker endpoint, and registers with a local shuffle * service if configured. */ def initialize(appId: String): Unit = { @@ -202,7 +202,7 @@ private[spark] class BlockManager( blockManagerId } - master.registerBlockManager(blockManagerId, maxMemory, slaveActor) + master.registerBlockManager(blockManagerId, maxMemory, slaveEndpoint) // Register Executors' configuration with the local shuffle service, if one should exist. if (externalShuffleServiceEnabled && !blockManagerId.isDriver) { @@ -265,7 +265,7 @@ private[spark] class BlockManager( def reregister(): Unit = { // TODO: We might need to rate limit re-registering. logInfo("BlockManager re-registering with master") - master.registerBlockManager(blockManagerId, maxMemory, slaveActor) + master.registerBlockManager(blockManagerId, maxMemory, slaveEndpoint) reportAllBlocks() } @@ -301,7 +301,7 @@ private[spark] class BlockManager( */ override def getBlockData(blockId: BlockId): ManagedBuffer = { if (blockId.isShuffle) { - shuffleManager.shuffleBlockManager.getBlockData(blockId.asInstanceOf[ShuffleBlockId]) + shuffleManager.shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId]) } else { val blockBytesOpt = doGetLocal(blockId, asBlockResult = false) .asInstanceOf[Option[ByteBuffer]] @@ -439,14 +439,10 @@ private[spark] class BlockManager( // As an optimization for map output fetches, if the block is for a shuffle, return it // without acquiring a lock; the disk store never deletes (recent) items so this should work if (blockId.isShuffle) { - val shuffleBlockManager = shuffleManager.shuffleBlockManager - shuffleBlockManager.getBytes(blockId.asInstanceOf[ShuffleBlockId]) match { - case Some(bytes) => - Some(bytes) - case None => - throw new BlockException( - blockId, s"Block $blockId not found on disk, though it should be") - } + val shuffleBlockManager = shuffleManager.shuffleBlockResolver + // TODO: This should gracefully handle case where local block is not available. Currently + // downstream code will throw an exception. + Option(shuffleBlockManager.getBlockData(blockId.asInstanceOf[ShuffleBlockId]).nioByteBuffer()) } else { doGetLocal(blockId, asBlockResult = false).asInstanceOf[Option[ByteBuffer]] } @@ -1219,7 +1215,7 @@ private[spark] class BlockManager( shuffleClient.close() } diskBlockManager.stop() - actorSystem.stop(slaveActor) + rpcEnv.stop(slaveEndpoint) blockInfo.clear() memoryStore.clear() diskStore.clear() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala index a6f1ebf325a7c..69ac37511e730 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala @@ -60,7 +60,10 @@ class BlockManagerId private ( def port: Int = port_ - def isDriver: Boolean = { executorId == SparkContext.DRIVER_IDENTIFIER } + def isDriver: Boolean = { + executorId == SparkContext.DRIVER_IDENTIFIER || + executorId == SparkContext.LEGACY_DRIVER_IDENTIFIER + } override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { out.writeUTF(executorId_) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 061964826f08b..ceacf043029f3 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -20,35 +20,31 @@ package org.apache.spark.storage import scala.concurrent.{Await, Future} import scala.concurrent.ExecutionContext.Implicits.global -import akka.actor._ - +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.storage.BlockManagerMessages._ import org.apache.spark.util.AkkaUtils private[spark] class BlockManagerMaster( - var driverActor: ActorRef, + var driverEndpoint: RpcEndpointRef, conf: SparkConf, isDriver: Boolean) extends Logging { - private val AKKA_RETRY_ATTEMPTS: Int = AkkaUtils.numRetries(conf) - private val AKKA_RETRY_INTERVAL_MS: Int = AkkaUtils.retryWaitMs(conf) - - val DRIVER_AKKA_ACTOR_NAME = "BlockManagerMaster" val timeout = AkkaUtils.askTimeout(conf) - /** Remove a dead executor from the driver actor. This is only called on the driver side. */ + /** Remove a dead executor from the driver endpoint. This is only called on the driver side. */ def removeExecutor(execId: String) { tell(RemoveExecutor(execId)) logInfo("Removed " + execId + " successfully in removeExecutor") } /** Register the BlockManager's id with the driver. */ - def registerBlockManager(blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { + def registerBlockManager( + blockManagerId: BlockManagerId, maxMemSize: Long, slaveEndpoint: RpcEndpointRef): Unit = { logInfo("Trying to register BlockManager") - tell(RegisterBlockManager(blockManagerId, maxMemSize, slaveActor)) + tell(RegisterBlockManager(blockManagerId, maxMemSize, slaveEndpoint)) logInfo("Registered BlockManager") } @@ -59,7 +55,7 @@ class BlockManagerMaster( memSize: Long, diskSize: Long, tachyonSize: Long): Boolean = { - val res = askDriverWithReply[Boolean]( + val res = driverEndpoint.askWithReply[Boolean]( UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize, tachyonSize)) logDebug(s"Updated info of block $blockId") res @@ -67,12 +63,12 @@ class BlockManagerMaster( /** Get locations of the blockId from the driver */ def getLocations(blockId: BlockId): Seq[BlockManagerId] = { - askDriverWithReply[Seq[BlockManagerId]](GetLocations(blockId)) + driverEndpoint.askWithReply[Seq[BlockManagerId]](GetLocations(blockId)) } /** Get locations of multiple blockIds from the driver */ def getLocations(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = { - askDriverWithReply[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds)) + driverEndpoint.askWithReply[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds)) } /** @@ -85,11 +81,11 @@ class BlockManagerMaster( /** Get ids of other nodes in the cluster from the driver */ def getPeers(blockManagerId: BlockManagerId): Seq[BlockManagerId] = { - askDriverWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId)) + driverEndpoint.askWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId)) } - def getActorSystemHostPortForExecutor(executorId: String): Option[(String, Int)] = { - askDriverWithReply[Option[(String, Int)]](GetActorSystemHostPortForExecutor(executorId)) + def getRpcHostPortForExecutor(executorId: String): Option[(String, Int)] = { + driverEndpoint.askWithReply[Option[(String, Int)]](GetRpcHostPortForExecutor(executorId)) } /** @@ -97,12 +93,12 @@ class BlockManagerMaster( * blocks that the driver knows about. */ def removeBlock(blockId: BlockId) { - askDriverWithReply(RemoveBlock(blockId)) + driverEndpoint.askWithReply[Boolean](RemoveBlock(blockId)) } /** Remove all blocks belonging to the given RDD. */ def removeRdd(rddId: Int, blocking: Boolean) { - val future = askDriverWithReply[Future[Seq[Int]]](RemoveRdd(rddId)) + val future = driverEndpoint.askWithReply[Future[Seq[Int]]](RemoveRdd(rddId)) future.onFailure { case e: Exception => logWarning(s"Failed to remove RDD $rddId - ${e.getMessage}}") @@ -114,7 +110,7 @@ class BlockManagerMaster( /** Remove all blocks belonging to the given shuffle. */ def removeShuffle(shuffleId: Int, blocking: Boolean) { - val future = askDriverWithReply[Future[Seq[Boolean]]](RemoveShuffle(shuffleId)) + val future = driverEndpoint.askWithReply[Future[Seq[Boolean]]](RemoveShuffle(shuffleId)) future.onFailure { case e: Exception => logWarning(s"Failed to remove shuffle $shuffleId - ${e.getMessage}}") @@ -126,7 +122,7 @@ class BlockManagerMaster( /** Remove all blocks belonging to the given broadcast. */ def removeBroadcast(broadcastId: Long, removeFromMaster: Boolean, blocking: Boolean) { - val future = askDriverWithReply[Future[Seq[Int]]]( + val future = driverEndpoint.askWithReply[Future[Seq[Int]]]( RemoveBroadcast(broadcastId, removeFromMaster)) future.onFailure { case e: Exception => @@ -145,11 +141,11 @@ class BlockManagerMaster( * amount of remaining memory. */ def getMemoryStatus: Map[BlockManagerId, (Long, Long)] = { - askDriverWithReply[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus) + driverEndpoint.askWithReply[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus) } def getStorageStatus: Array[StorageStatus] = { - askDriverWithReply[Array[StorageStatus]](GetStorageStatus) + driverEndpoint.askWithReply[Array[StorageStatus]](GetStorageStatus) } /** @@ -165,11 +161,12 @@ class BlockManagerMaster( askSlaves: Boolean = true): Map[BlockManagerId, BlockStatus] = { val msg = GetBlockStatus(blockId, askSlaves) /* - * To avoid potential deadlocks, the use of Futures is necessary, because the master actor + * To avoid potential deadlocks, the use of Futures is necessary, because the master endpoint * should not block on waiting for a block manager, which can in turn be waiting for the - * master actor for a response to a prior message. + * master endpoint for a response to a prior message. */ - val response = askDriverWithReply[Map[BlockManagerId, Future[Option[BlockStatus]]]](msg) + val response = driverEndpoint. + askWithReply[Map[BlockManagerId, Future[Option[BlockStatus]]]](msg) val (blockManagerIds, futures) = response.unzip val result = Await.result(Future.sequence(futures), timeout) if (result == null) { @@ -193,33 +190,28 @@ class BlockManagerMaster( filter: BlockId => Boolean, askSlaves: Boolean): Seq[BlockId] = { val msg = GetMatchingBlockIds(filter, askSlaves) - val future = askDriverWithReply[Future[Seq[BlockId]]](msg) + val future = driverEndpoint.askWithReply[Future[Seq[BlockId]]](msg) Await.result(future, timeout) } - /** Stop the driver actor, called only on the Spark driver node */ + /** Stop the driver endpoint, called only on the Spark driver node */ def stop() { - if (driverActor != null && isDriver) { + if (driverEndpoint != null && isDriver) { tell(StopBlockManagerMaster) - driverActor = null + driverEndpoint = null logInfo("BlockManagerMaster stopped") } } - /** Send a one-way message to the master actor, to which we expect it to reply with true. */ + /** Send a one-way message to the master endpoint, to which we expect it to reply with true. */ private def tell(message: Any) { - if (!askDriverWithReply[Boolean](message)) { - throw new SparkException("BlockManagerMasterActor returned false, expected true.") + if (!driverEndpoint.askWithReply[Boolean](message)) { + throw new SparkException("BlockManagerMasterEndpoint returned false, expected true.") } } - /** - * Send a message to the driver actor and get its result within a default timeout, or - * throw a SparkException if this fails. - */ - private def askDriverWithReply[T](message: Any): T = { - AkkaUtils.askWithReply(message, driverActor, AKKA_RETRY_ATTEMPTS, AKKA_RETRY_INTERVAL_MS, - timeout) - } +} +private[spark] object BlockManagerMaster { + val DRIVER_ENDPOINT_NAME = "BlockManagerMaster" } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala similarity index 83% rename from core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala rename to core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 5b5328016124e..28c73a7d543ff 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -21,25 +21,26 @@ import java.util.{HashMap => JHashMap} import scala.collection.mutable import scala.collection.JavaConversions._ -import scala.concurrent.Future -import scala.concurrent.duration._ +import scala.concurrent.{ExecutionContext, Future} -import akka.actor.{Actor, ActorRef} -import akka.pattern.ask - -import org.apache.spark.{Logging, SparkConf, SparkException} +import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, RpcCallContext, ThreadSafeRpcEndpoint} +import org.apache.spark.{Logging, SparkConf} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler._ import org.apache.spark.storage.BlockManagerMessages._ -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, Utils} +import org.apache.spark.util.Utils /** - * BlockManagerMasterActor is an actor on the master node to track statuses of - * all slaves' block managers. + * BlockManagerMasterEndpoint is an [[ThreadSafeRpcEndpoint]] on the master node to track statuses + * of all slaves' block managers. */ private[spark] -class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus: LiveListenerBus) - extends Actor with ActorLogReceive with Logging { +class BlockManagerMasterEndpoint( + override val rpcEnv: RpcEnv, + val isLocal: Boolean, + conf: SparkConf, + listenerBus: LiveListenerBus) + extends ThreadSafeRpcEndpoint with Logging { // Mapping from block manager id to the block manager's information. private val blockManagerInfo = new mutable.HashMap[BlockManagerId, BlockManagerInfo] @@ -50,68 +51,67 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus // Mapping from block id to the set of block managers that have the block. private val blockLocations = new JHashMap[BlockId, mutable.HashSet[BlockManagerId]] - private val akkaTimeout = AkkaUtils.askTimeout(conf) + private val askThreadPool = Utils.newDaemonCachedThreadPool("block-manager-ask-thread-pool") + private implicit val askExecutionContext = ExecutionContext.fromExecutorService(askThreadPool) - override def receiveWithLogging: PartialFunction[Any, Unit] = { - case RegisterBlockManager(blockManagerId, maxMemSize, slaveActor) => - register(blockManagerId, maxMemSize, slaveActor) - sender ! true + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case RegisterBlockManager(blockManagerId, maxMemSize, slaveEndpoint) => + register(blockManagerId, maxMemSize, slaveEndpoint) + context.reply(true) case UpdateBlockInfo( blockManagerId, blockId, storageLevel, deserializedSize, size, tachyonSize) => - sender ! updateBlockInfo( - blockManagerId, blockId, storageLevel, deserializedSize, size, tachyonSize) + context.reply(updateBlockInfo( + blockManagerId, blockId, storageLevel, deserializedSize, size, tachyonSize)) case GetLocations(blockId) => - sender ! getLocations(blockId) + context.reply(getLocations(blockId)) case GetLocationsMultipleBlockIds(blockIds) => - sender ! getLocationsMultipleBlockIds(blockIds) + context.reply(getLocationsMultipleBlockIds(blockIds)) case GetPeers(blockManagerId) => - sender ! getPeers(blockManagerId) + context.reply(getPeers(blockManagerId)) - case GetActorSystemHostPortForExecutor(executorId) => - sender ! getActorSystemHostPortForExecutor(executorId) + case GetRpcHostPortForExecutor(executorId) => + context.reply(getRpcHostPortForExecutor(executorId)) case GetMemoryStatus => - sender ! memoryStatus + context.reply(memoryStatus) case GetStorageStatus => - sender ! storageStatus + context.reply(storageStatus) case GetBlockStatus(blockId, askSlaves) => - sender ! blockStatus(blockId, askSlaves) + context.reply(blockStatus(blockId, askSlaves)) case GetMatchingBlockIds(filter, askSlaves) => - sender ! getMatchingBlockIds(filter, askSlaves) + context.reply(getMatchingBlockIds(filter, askSlaves)) case RemoveRdd(rddId) => - sender ! removeRdd(rddId) + context.reply(removeRdd(rddId)) case RemoveShuffle(shuffleId) => - sender ! removeShuffle(shuffleId) + context.reply(removeShuffle(shuffleId)) case RemoveBroadcast(broadcastId, removeFromDriver) => - sender ! removeBroadcast(broadcastId, removeFromDriver) + context.reply(removeBroadcast(broadcastId, removeFromDriver)) case RemoveBlock(blockId) => removeBlockFromWorkers(blockId) - sender ! true + context.reply(true) case RemoveExecutor(execId) => removeExecutor(execId) - sender ! true + context.reply(true) case StopBlockManagerMaster => - sender ! true - context.stop(self) + context.reply(true) + stop() case BlockManagerHeartbeat(blockManagerId) => - sender ! heartbeatReceived(blockManagerId) + context.reply(heartbeatReceived(blockManagerId)) - case other => - logWarning("Got unknown message: " + other) } private def removeRdd(rddId: Int): Future[Seq[Int]] = { @@ -129,22 +129,20 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus // Ask the slaves to remove the RDD, and put the result in a sequence of Futures. // The dispatcher is used as an implicit argument into the Future sequence construction. - import context.dispatcher val removeMsg = RemoveRdd(rddId) Future.sequence( blockManagerInfo.values.map { bm => - bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int] + bm.slaveEndpoint.sendWithReply[Int](removeMsg) }.toSeq ) } private def removeShuffle(shuffleId: Int): Future[Seq[Boolean]] = { - // Nothing to do in the BlockManagerMasterActor data structures - import context.dispatcher + // Nothing to do in the BlockManagerMasterEndpoint data structures val removeMsg = RemoveShuffle(shuffleId) Future.sequence( blockManagerInfo.values.map { bm => - bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Boolean] + bm.slaveEndpoint.sendWithReply[Boolean](removeMsg) }.toSeq ) } @@ -155,14 +153,13 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus * from the executors, but not from the driver. */ private def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean): Future[Seq[Int]] = { - import context.dispatcher val removeMsg = RemoveBroadcast(broadcastId, removeFromDriver) val requiredBlockManagers = blockManagerInfo.values.filter { info => removeFromDriver || !info.blockManagerId.isDriver } Future.sequence( requiredBlockManagers.map { bm => - bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int] + bm.slaveEndpoint.sendWithReply[Int](removeMsg) }.toSeq ) } @@ -217,7 +214,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus // Remove the block from the slave's BlockManager. // Doesn't actually wait for a confirmation and the message might get lost. // If message loss becomes frequent, we should add retry logic here. - blockManager.get.slaveActor.ask(RemoveBlock(blockId))(akkaTimeout) + blockManager.get.slaveEndpoint.sendWithReply[Boolean](RemoveBlock(blockId)) } } } @@ -247,17 +244,16 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus private def blockStatus( blockId: BlockId, askSlaves: Boolean): Map[BlockManagerId, Future[Option[BlockStatus]]] = { - import context.dispatcher val getBlockStatus = GetBlockStatus(blockId) /* - * Rather than blocking on the block status query, master actor should simply return + * Rather than blocking on the block status query, master endpoint should simply return * Futures to avoid potential deadlocks. This can arise if there exists a block manager - * that is also waiting for this master actor's response to a previous message. + * that is also waiting for this master endpoint's response to a previous message. */ blockManagerInfo.values.map { info => val blockStatusFuture = if (askSlaves) { - info.slaveActor.ask(getBlockStatus)(akkaTimeout).mapTo[Option[BlockStatus]] + info.slaveEndpoint.sendWithReply[Option[BlockStatus]](getBlockStatus) } else { Future { info.getStatus(blockId) } } @@ -276,13 +272,12 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus private def getMatchingBlockIds( filter: BlockId => Boolean, askSlaves: Boolean): Future[Seq[BlockId]] = { - import context.dispatcher val getMatchingBlockIds = GetMatchingBlockIds(filter) Future.sequence( blockManagerInfo.values.map { info => val future = if (askSlaves) { - info.slaveActor.ask(getMatchingBlockIds)(akkaTimeout).mapTo[Seq[BlockId]] + info.slaveEndpoint.sendWithReply[Seq[BlockId]](getMatchingBlockIds) } else { Future { info.blocks.keys.filter(filter).toSeq } } @@ -291,7 +286,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus ).map(_.flatten.toSeq) } - private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { + private def register(id: BlockManagerId, maxMemSize: Long, slaveEndpoint: RpcEndpointRef) { val time = System.currentTimeMillis() if (!blockManagerInfo.contains(id)) { blockManagerIdByExecutor.get(id.executorId) match { @@ -308,7 +303,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus blockManagerIdByExecutor(id.executorId) = id blockManagerInfo(id) = new BlockManagerInfo( - id, System.currentTimeMillis(), maxMemSize, slaveActor) + id, System.currentTimeMillis(), maxMemSize, slaveEndpoint) } listenerBus.post(SparkListenerBlockManagerAdded(time, id, maxMemSize)) } @@ -379,19 +374,21 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus } /** - * Returns the hostname and port of an executor's actor system, based on the Akka address of its - * BlockManagerSlaveActor. + * Returns the hostname and port of an executor, based on the [[RpcEnv]] address of its + * [[BlockManagerSlaveEndpoint]]. */ - private def getActorSystemHostPortForExecutor(executorId: String): Option[(String, Int)] = { + private def getRpcHostPortForExecutor(executorId: String): Option[(String, Int)] = { for ( blockManagerId <- blockManagerIdByExecutor.get(executorId); - info <- blockManagerInfo.get(blockManagerId); - host <- info.slaveActor.path.address.host; - port <- info.slaveActor.path.address.port + info <- blockManagerInfo.get(blockManagerId) ) yield { - (host, port) + (info.slaveEndpoint.address.host, info.slaveEndpoint.address.port) } } + + override def onStop(): Unit = { + askThreadPool.shutdownNow() + } } @DeveloperApi @@ -412,7 +409,7 @@ private[spark] class BlockManagerInfo( val blockManagerId: BlockManagerId, timeMs: Long, val maxMem: Long, - val slaveActor: ActorRef) + val slaveEndpoint: RpcEndpointRef) extends Logging { private var _lastSeenMs: Long = timeMs diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 48247453edef0..f89d8d7493f7c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -19,8 +19,7 @@ package org.apache.spark.storage import java.io.{Externalizable, ObjectInput, ObjectOutput} -import akka.actor.ActorRef - +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.Utils private[spark] object BlockManagerMessages { @@ -52,7 +51,7 @@ private[spark] object BlockManagerMessages { case class RegisterBlockManager( blockManagerId: BlockManagerId, maxMemSize: Long, - sender: ActorRef) + sender: RpcEndpointRef) extends ToBlockManagerMaster case class UpdateBlockInfo( @@ -92,7 +91,7 @@ private[spark] object BlockManagerMessages { case class GetPeers(blockManagerId: BlockManagerId) extends ToBlockManagerMaster - case class GetActorSystemHostPortForExecutor(executorId: String) extends ToBlockManagerMaster + case class GetRpcHostPortForExecutor(executorId: String) extends ToBlockManagerMaster case class RemoveExecutor(execId: String) extends ToBlockManagerMaster diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala similarity index 61% rename from core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala rename to core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala index 52fb896c4e21f..8980fa8eb70e2 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala @@ -17,41 +17,43 @@ package org.apache.spark.storage -import scala.concurrent.Future - -import akka.actor.{ActorRef, Actor} +import scala.concurrent.{ExecutionContext, Future} +import org.apache.spark.rpc.{RpcEnv, RpcCallContext, RpcEndpoint} +import org.apache.spark.util.Utils import org.apache.spark.{Logging, MapOutputTracker, SparkEnv} import org.apache.spark.storage.BlockManagerMessages._ -import org.apache.spark.util.ActorLogReceive /** - * An actor to take commands from the master to execute options. For example, + * An RpcEndpoint to take commands from the master to execute options. For example, * this is used to remove blocks from the slave's BlockManager. */ private[storage] -class BlockManagerSlaveActor( +class BlockManagerSlaveEndpoint( + override val rpcEnv: RpcEnv, blockManager: BlockManager, mapOutputTracker: MapOutputTracker) - extends Actor with ActorLogReceive with Logging { + extends RpcEndpoint with Logging { - import context.dispatcher + private val asyncThreadPool = + Utils.newDaemonCachedThreadPool("block-manager-slave-async-thread-pool") + private implicit val asyncExecutionContext = ExecutionContext.fromExecutorService(asyncThreadPool) // Operations that involve removing blocks may be slow and should be done asynchronously - override def receiveWithLogging: PartialFunction[Any, Unit] = { + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case RemoveBlock(blockId) => - doAsync[Boolean]("removing block " + blockId, sender) { + doAsync[Boolean]("removing block " + blockId, context) { blockManager.removeBlock(blockId) true } case RemoveRdd(rddId) => - doAsync[Int]("removing RDD " + rddId, sender) { + doAsync[Int]("removing RDD " + rddId, context) { blockManager.removeRdd(rddId) } case RemoveShuffle(shuffleId) => - doAsync[Boolean]("removing shuffle " + shuffleId, sender) { + doAsync[Boolean]("removing shuffle " + shuffleId, context) { if (mapOutputTracker != null) { mapOutputTracker.unregisterShuffle(shuffleId) } @@ -59,30 +61,34 @@ class BlockManagerSlaveActor( } case RemoveBroadcast(broadcastId, _) => - doAsync[Int]("removing broadcast " + broadcastId, sender) { + doAsync[Int]("removing broadcast " + broadcastId, context) { blockManager.removeBroadcast(broadcastId, tellMaster = true) } case GetBlockStatus(blockId, _) => - sender ! blockManager.getStatus(blockId) + context.reply(blockManager.getStatus(blockId)) case GetMatchingBlockIds(filter, _) => - sender ! blockManager.getMatchingBlockIds(filter) + context.reply(blockManager.getMatchingBlockIds(filter)) } - private def doAsync[T](actionMessage: String, responseActor: ActorRef)(body: => T) { + private def doAsync[T](actionMessage: String, context: RpcCallContext)(body: => T) { val future = Future { logDebug(actionMessage) body } future.onSuccess { case response => logDebug("Done " + actionMessage + ", response is " + response) - responseActor ! response - logDebug("Sent response: " + response + " to " + responseActor) + context.reply(response) + logDebug("Sent response: " + response + " to " + context.sender) } future.onFailure { case t: Throwable => logError("Error in " + actionMessage, t) - responseActor ! null.asInstanceOf[T] + context.sendFailure(t) } } + + override def onStop(): Unit = { + asyncThreadPool.shutdownNow() + } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala index f703e50b6b0ac..0dfc91dfaff85 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala @@ -23,6 +23,7 @@ import java.nio.channels.FileChannel import org.apache.spark.Logging import org.apache.spark.serializer.{SerializationStream, Serializer} import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.util.Utils /** * An interface for writing JVM objects to some underlying storage. This interface allows @@ -140,14 +141,17 @@ private[spark] class DiskBlockObjectWriter( override def close() { if (initialized) { - if (syncWrites) { - // Force outstanding writes to disk and track how long it takes - objOut.flush() - callWithTiming { - fos.getFD.sync() + Utils.tryWithSafeFinally { + if (syncWrites) { + // Force outstanding writes to disk and track how long it takes + objOut.flush() + callWithTiming { + fos.getFD.sync() + } } + } { + objOut.close() } - objOut.close() channel = null bs = null diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index 61ef5ff168791..4b232ae7d3180 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -46,10 +46,13 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc val startTime = System.currentTimeMillis val file = diskManager.getFile(blockId) val channel = new FileOutputStream(file).getChannel - while (bytes.remaining > 0) { - channel.write(bytes) + Utils.tryWithSafeFinally { + while (bytes.remaining > 0) { + channel.write(bytes) + } + } { + channel.close() } - channel.close() val finishTime = System.currentTimeMillis logDebug("Block %s stored as %s file on disk in %d ms".format( file.getName, Utils.bytesToString(bytes.limit), finishTime - startTime)) @@ -75,9 +78,9 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc val file = diskManager.getFile(blockId) val outputStream = new FileOutputStream(file) try { - try { + Utils.tryWithSafeFinally { blockManager.dataSerializeStream(blockId, outputStream, values) - } finally { + } { // Close outputStream here because it should be closed before file is deleted. outputStream.close() } @@ -106,8 +109,7 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc private def getBytes(file: File, offset: Long, length: Long): Option[ByteBuffer] = { val channel = new RandomAccessFile(file, "r").getChannel - - try { + Utils.tryWithSafeFinally { // For small files, directly read rather than memory map if (length < minMemoryMapBytes) { val buf = ByteBuffer.allocate(length.toInt) @@ -123,7 +125,7 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc } else { Some(channel.map(MapMode.READ_ONLY, offset, length)) } - } finally { + } { channel.close() } } diff --git a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala index 0186eb30a1905..034525b56f59c 100644 --- a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala +++ b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala @@ -52,6 +52,6 @@ class RDDInfo( private[spark] object RDDInfo { def fromRdd(rdd: RDD[_]): RDDInfo = { val rddName = Option(rdd.name).getOrElse(rdd.id.toString) - new RDDInfo(rdd.id, rddName, rdd.partitions.size, rdd.getStorageLevel) + new RDDInfo(rdd.id, rddName, rdd.partitions.length, rdd.getStorageLevel) } } diff --git a/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala b/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala index 67f572e79314d..77c0bc8b5360a 100644 --- a/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala +++ b/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala @@ -65,7 +65,7 @@ private[spark] class ConsoleProgressBar(sc: SparkContext) extends Logging { val stageIds = sc.statusTracker.getActiveStageIds() val stages = stageIds.map(sc.statusTracker.getStageInfo).flatten.filter(_.numTasks() > 1) .filter(now - _.submissionTime() > FIRST_DELAY).sortBy(_.stageId()) - if (stages.size > 0) { + if (stages.length > 0) { show(now, stages.take(3)) // display at most 3 stages in same time } } @@ -81,7 +81,7 @@ private[spark] class ConsoleProgressBar(sc: SparkContext) extends Logging { val total = s.numTasks() val header = s"[Stage ${s.stageId()}:" val tailer = s"(${s.numCompletedTasks()} + ${s.numActiveTasks()}) / $total]" - val w = width - header.size - tailer.size + val w = width - header.length - tailer.length val bar = if (w > 0) { val percent = w * s.numCompletedTasks() / total (0 until w).map { i => diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index 411fe8da780b2..384f2ad26e281 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -48,7 +48,7 @@ private[spark] abstract class WebUI( protected val handlers = ArrayBuffer[ServletContextHandler]() protected val pageToHandlers = new HashMap[WebUIPage, ArrayBuffer[ServletContextHandler]] protected var serverInfo: Option[ServerInfo] = None - protected val localHostName = Utils.localHostName() + protected val localHostName = Utils.localHostNameForURI() protected val publicHostName = Option(conf.getenv("SPARK_PUBLIC_DNS")).getOrElse(localHostName) private val className = Utils.getFormattedClassName(this) diff --git a/core/src/main/scala/org/apache/spark/util/EventLoop.scala b/core/src/main/scala/org/apache/spark/util/EventLoop.scala index b0ed908b84424..e9b2b8d24b476 100644 --- a/core/src/main/scala/org/apache/spark/util/EventLoop.scala +++ b/core/src/main/scala/org/apache/spark/util/EventLoop.scala @@ -76,9 +76,21 @@ private[spark] abstract class EventLoop[E](name: String) extends Logging { def stop(): Unit = { if (stopped.compareAndSet(false, true)) { eventThread.interrupt() - eventThread.join() - // Call onStop after the event thread exits to make sure onReceive happens before onStop - onStop() + var onStopCalled = false + try { + eventThread.join() + // Call onStop after the event thread exits to make sure onReceive happens before onStop + onStopCalled = true + onStop() + } catch { + case ie: InterruptedException => + Thread.currentThread().interrupt() + if (!onStopCalled) { + // ie is thrown from `eventThread.join()`. Otherwise, we should not call `onStop` since + // it's already called. + onStop() + } + } } else { // Keep quiet to allow calling `stop` multiple times. } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index bb8bd1015668a..a541d660cd5c6 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -34,6 +34,7 @@ import scala.util.Try import scala.util.control.{ControlThrowable, NonFatal} import com.google.common.io.{ByteStreams, Files} +import com.google.common.net.InetAddresses import com.google.common.util.concurrent.ThreadFactoryBuilder import org.apache.commons.lang3.SystemUtils import org.apache.hadoop.conf.Configuration @@ -313,7 +314,7 @@ private[spark] object Utils extends Logging { transferToEnabled: Boolean = false): Long = { var count = 0L - try { + tryWithSafeFinally { if (in.isInstanceOf[FileInputStream] && out.isInstanceOf[FileOutputStream] && transferToEnabled) { // When both streams are File stream, use transferTo to improve copy performance. @@ -353,7 +354,7 @@ private[spark] object Utils extends Logging { } } count - } finally { + } { if (closeStreams) { try { in.close() @@ -789,13 +790,12 @@ private[spark] object Utils extends Logging { * Get the local host's IP address in dotted-quad format (e.g. 1.2.3.4). * Note, this is typically not used from within core spark. */ - lazy val localIpAddress: String = findLocalIpAddress() - lazy val localIpAddressHostname: String = getAddressHostName(localIpAddress) + private lazy val localIpAddress: InetAddress = findLocalInetAddress() - private def findLocalIpAddress(): String = { + private def findLocalInetAddress(): InetAddress = { val defaultIpOverride = System.getenv("SPARK_LOCAL_IP") if (defaultIpOverride != null) { - defaultIpOverride + InetAddress.getByName(defaultIpOverride) } else { val address = InetAddress.getLocalHost if (address.isLoopbackAddress) { @@ -806,15 +806,20 @@ private[spark] object Utils extends Logging { // It's more proper to pick ip address following system output order. val activeNetworkIFs = NetworkInterface.getNetworkInterfaces.toList val reOrderedNetworkIFs = if (isWindows) activeNetworkIFs else activeNetworkIFs.reverse + for (ni <- reOrderedNetworkIFs) { - for (addr <- ni.getInetAddresses if !addr.isLinkLocalAddress && - !addr.isLoopbackAddress && addr.isInstanceOf[Inet4Address]) { + val addresses = ni.getInetAddresses.toList + .filterNot(addr => addr.isLinkLocalAddress || addr.isLoopbackAddress) + if (addresses.nonEmpty) { + val addr = addresses.find(_.isInstanceOf[Inet4Address]).getOrElse(addresses.head) + // because of Inet6Address.toHostName may add interface at the end if it knows about it + val strippedAddress = InetAddress.getByAddress(addr.getAddress) // We've found an address that looks reasonable! logWarning("Your hostname, " + InetAddress.getLocalHost.getHostName + " resolves to" + - " a loopback address: " + address.getHostAddress + "; using " + addr.getHostAddress + - " instead (on interface " + ni.getName + ")") + " a loopback address: " + address.getHostAddress + "; using " + + strippedAddress.getHostAddress + " instead (on interface " + ni.getName + ")") logWarning("Set SPARK_LOCAL_IP if you need to bind to another address") - return addr.getHostAddress + return strippedAddress } } logWarning("Your hostname, " + InetAddress.getLocalHost.getHostName + " resolves to" + @@ -822,7 +827,7 @@ private[spark] object Utils extends Logging { " external IP address!") logWarning("Set SPARK_LOCAL_IP if you need to bind to another address") } - address.getHostAddress + address } } @@ -842,11 +847,14 @@ private[spark] object Utils extends Logging { * Get the local machine's hostname. */ def localHostName(): String = { - customHostname.getOrElse(localIpAddressHostname) + customHostname.getOrElse(localIpAddress.getHostAddress) } - def getAddressHostName(address: String): String = { - InetAddress.getByName(address).getHostName + /** + * Get the local machine's URI. + */ + def localHostNameForURI(): String = { + customHostname.getOrElse(InetAddresses.toUriString(localIpAddress)) } def checkHost(host: String, message: String = "") { @@ -1214,6 +1222,54 @@ private[spark] object Utils extends Logging { } } + /** Executes the given block. Log non-fatal errors if any, and only throw fatal errors */ + def tryLogNonFatalError(block: => Unit) { + try { + block + } catch { + case NonFatal(t) => + logError(s"Uncaught exception in thread ${Thread.currentThread().getName}", t) + } + } + + /** + * Execute a block of code, then a finally block, but if exceptions happen in + * the finally block, do not suppress the original exception. + * + * This is primarily an issue with `finally { out.close() }` blocks, where + * close needs to be called to clean up `out`, but if an exception happened + * in `out.write`, it's likely `out` may be corrupted and `out.close` will + * fail as well. This would then suppress the original/likely more meaningful + * exception from the original `out.write` call. + */ + def tryWithSafeFinally[T](block: => T)(finallyBlock: => Unit): T = { + // It would be nice to find a method on Try that did this + var originalThrowable: Throwable = null + try { + block + } catch { + case t: Throwable => + // Purposefully not using NonFatal, because even fatal exceptions + // we don't want to have our finallyBlock suppress + originalThrowable = t + throw originalThrowable + } finally { + try { + finallyBlock + } catch { + case t: Throwable => + if (originalThrowable != null) { + // We could do originalThrowable.addSuppressed(t), but it's + // not available in JDK 1.6. + logWarning(s"Suppressing exception in finally: " + t.getMessage, t) + throw originalThrowable + } else { + throw t + } + } + } + } + /** Default filtering function for finding call sites using `getCallSite`. */ private def coreExclusionFunction(className: String): Boolean = { // A regular expression to match classes of the "core" Spark API that we want to skip when @@ -2074,7 +2130,7 @@ private[spark] class RedirectThread( override def run() { scala.util.control.Exception.ignoring(classOf[IOException]) { // FIXME: We copy the stream on the level of bytes to avoid encoding problems. - try { + Utils.tryWithSafeFinally { val buf = new Array[Byte](1024) var len = in.read(buf) while (len != -1) { @@ -2082,7 +2138,7 @@ private[spark] class RedirectThread( out.flush() len = in.read(buf) } - } finally { + } { if (propagateEof) { out.close() } diff --git a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala index f79e8e0491ea1..41cb8cfe2afa3 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala @@ -39,7 +39,7 @@ class BitSet(numBits: Int) extends Serializable { val wordIndex = bitIndex >> 6 // divide by 64 var i = 0 while(i < wordIndex) { words(i) = -1; i += 1 } - if(wordIndex < words.size) { + if(wordIndex < words.length) { // Set the remaining bits (note that the mask could still be zero) val mask = ~(-1L << (bitIndex & 0x3f)) words(wordIndex) |= mask diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index b962c101c91da..035f3767ff554 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -664,6 +664,8 @@ private[spark] class ExternalSorter[K, V, C]( } /** + * Exposed for testing purposes. + * * Return an iterator over all the data written to this object, grouped by partition and * aggregated by the requested aggregator. For each partition we then have an iterator over its * contents, and these are expected to be accessed in order (you can't "skip ahead" to one @@ -673,7 +675,7 @@ private[spark] class ExternalSorter[K, V, C]( * For now, we just merge all the spilled files in once pass, but this can be modified to * support hierarchical merging. */ - def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = { + def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = { val usingMap = aggregator.isDefined val collection: SizeTrackingPairCollection[(Int, K), C] = if (usingMap) map else buffer if (spills.isEmpty && partitionWriters == null) { @@ -726,25 +728,19 @@ private[spark] class ExternalSorter[K, V, C]( // this simple we spill out the current in-memory collection so that everything is in files. spillToPartitionFiles(if (aggregator.isDefined) map else buffer) partitionWriters.foreach(_.commitAndClose()) - var out: FileOutputStream = null - var in: FileInputStream = null + val out = new FileOutputStream(outputFile, true) val writeStartTime = System.nanoTime - try { - out = new FileOutputStream(outputFile, true) + util.Utils.tryWithSafeFinally { for (i <- 0 until numPartitions) { - in = new FileInputStream(partitionWriters(i).fileSegment().file) - val size = org.apache.spark.util.Utils.copyStream(in, out, false, transferToEnabled) - in.close() - in = null - lengths(i) = size - } - } finally { - if (out != null) { - out.close() - } - if (in != null) { - in.close() + val in = new FileInputStream(partitionWriters(i).fileSegment().file) + util.Utils.tryWithSafeFinally { + lengths(i) = org.apache.spark.util.Utils.copyStream(in, out, false, transferToEnabled) + } { + in.close() + } } + } { + out.close() context.taskMetrics.shuffleWriteMetrics.foreach( _.incShuffleWriteTime(System.nanoTime - writeStartTime)) } @@ -781,7 +777,7 @@ private[spark] class ExternalSorter[K, V, C]( /** * Read a partition file back as an iterator (used in our iterator method) */ - def readPartitionFile(writer: BlockObjectWriter): Iterator[Product2[K, C]] = { + private def readPartitionFile(writer: BlockObjectWriter): Iterator[Product2[K, C]] = { if (writer.isOpen) { writer.commitAndClose() } diff --git a/core/src/test/resources/log4j.properties b/core/src/test/resources/log4j.properties index 287c8e3563503..eb3b1999eb996 100644 --- a/core/src/test/resources/log4j.properties +++ b/core/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN -org.eclipse.jetty.LEVEL=WARN +log4j.logger.org.spark-project.jetty=WARN +org.spark-project.jetty.LEVEL=WARN diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index bd0f8bdefa171..75399461f2a5f 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -27,19 +27,20 @@ import org.scalatest.Matchers class AccumulatorSuite extends FunSuite with Matchers with LocalSparkContext { - implicit def setAccum[A] = new AccumulableParam[mutable.Set[A], A] { - def addInPlace(t1: mutable.Set[A], t2: mutable.Set[A]) : mutable.Set[A] = { - t1 ++= t2 - t1 - } - def addAccumulator(t1: mutable.Set[A], t2: A) : mutable.Set[A] = { - t1 += t2 - t1 - } - def zero(t: mutable.Set[A]) : mutable.Set[A] = { - new mutable.HashSet[A]() + implicit def setAccum[A]: AccumulableParam[mutable.Set[A], A] = + new AccumulableParam[mutable.Set[A], A] { + def addInPlace(t1: mutable.Set[A], t2: mutable.Set[A]) : mutable.Set[A] = { + t1 ++= t2 + t1 + } + def addAccumulator(t1: mutable.Set[A], t2: A) : mutable.Set[A] = { + t1 += t2 + t1 + } + def zero(t: mutable.Set[A]) : mutable.Set[A] = { + new mutable.HashSet[A]() + } } - } test ("basic accumulation"){ sc = new SparkContext("local", "test") @@ -49,11 +50,10 @@ class AccumulatorSuite extends FunSuite with Matchers with LocalSparkContext { d.foreach{x => acc += x} acc.value should be (210) - - val longAcc = sc.accumulator(0l) + val longAcc = sc.accumulator(0L) val maxInt = Integer.MAX_VALUE.toLong d.foreach{x => longAcc += maxInt + x} - longAcc.value should be (210l + maxInt * 20) + longAcc.value should be (210L + maxInt * 20) } test ("value not assignable from tasks") { diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala index 4b25c200a695a..70529d9216591 100644 --- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala @@ -45,16 +45,17 @@ class CacheManagerSuite extends FunSuite with LocalSparkContext with BeforeAndAf rdd = new RDD[Int](sc, Nil) { override def getPartitions: Array[Partition] = Array(split) override val getDependencies = List[Dependency[_]]() - override def compute(split: Partition, context: TaskContext) = Array(1, 2, 3, 4).iterator + override def compute(split: Partition, context: TaskContext): Iterator[Int] = + Array(1, 2, 3, 4).iterator } rdd2 = new RDD[Int](sc, List(new OneToOneDependency(rdd))) { override def getPartitions: Array[Partition] = firstParent[Int].partitions - override def compute(split: Partition, context: TaskContext) = + override def compute(split: Partition, context: TaskContext): Iterator[Int] = firstParent[Int].iterator(split, context) }.cache() rdd3 = new RDD[Int](sc, List(new OneToOneDependency(rdd2))) { override def getPartitions: Array[Partition] = firstParent[Int].partitions - override def compute(split: Partition, context: TaskContext) = + override def compute(split: Partition, context: TaskContext): Iterator[Int] = firstParent[Int].iterator(split, context) }.cache() } diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index 32abc65385267..e1faddeabec79 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -75,7 +75,8 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { assert(sc.checkpointFile[Int](parCollection.getCheckpointFile.get).collect() === result) assert(parCollection.dependencies != Nil) assert(parCollection.partitions.length === numPartitions) - assert(parCollection.partitions.toList === parCollection.checkpointData.get.getPartitions.toList) + assert(parCollection.partitions.toList === + parCollection.checkpointData.get.getPartitions.toList) assert(parCollection.collect() === result) } @@ -102,13 +103,13 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { } test("UnionRDD") { - def otherRDD = sc.makeRDD(1 to 10, 1) + def otherRDD: RDD[Int] = sc.makeRDD(1 to 10, 1) testRDD(_.union(otherRDD)) testRDDPartitions(_.union(otherRDD)) } test("CartesianRDD") { - def otherRDD = sc.makeRDD(1 to 10, 1) + def otherRDD: RDD[Int] = sc.makeRDD(1 to 10, 1) testRDD(new CartesianRDD(sc, _, otherRDD)) testRDDPartitions(new CartesianRDD(sc, _, otherRDD)) @@ -223,7 +224,8 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { val partitionAfterCheckpoint = serializeDeserialize( unionRDD.partitions.head.asInstanceOf[PartitionerAwareUnionRDDPartition]) assert( - partitionBeforeCheckpoint.parents.head.getClass != partitionAfterCheckpoint.parents.head.getClass, + partitionBeforeCheckpoint.parents.head.getClass != + partitionAfterCheckpoint.parents.head.getClass, "PartitionerAwareUnionRDDPartition.parents not updated after parent RDD is checkpointed" ) } @@ -358,7 +360,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { * Generate an pair RDD (with partitioner) such that both the RDD and its partitions * have large size. */ - def generateFatPairRDD() = { + def generateFatPairRDD(): RDD[(Int, Int)] = { new FatPairRDD(sc.makeRDD(1 to 100, 4), partitioner).mapValues(x => x) } @@ -445,7 +447,8 @@ class FatPairRDD(parent: RDD[Int], _partitioner: Partitioner) extends RDD[(Int, object CheckpointSuite { // This is a custom cogroup function that does not use mapValues like // the PairRDDFunctions.cogroup() - def cogroup[K, V](first: RDD[(K, V)], second: RDD[(K, V)], part: Partitioner) = { + def cogroup[K, V](first: RDD[(K, V)], second: RDD[(K, V)], part: Partitioner) + : RDD[(K, Array[Iterable[V]])] = { new CoGroupedRDD[K]( Seq(first.asInstanceOf[RDD[(K, _)]], second.asInstanceOf[RDD[(K, _)]]), part diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index cdfaacee7da40..1de169d964d23 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -64,7 +64,7 @@ abstract class ContextCleanerSuiteBase(val shuffleManager: Class[_] = classOf[Ha } } - //------ Helper functions ------ + // ------ Helper functions ------ protected def newRDD() = sc.makeRDD(1 to 10) protected def newPairRDD() = newRDD().map(_ -> 1) @@ -370,7 +370,7 @@ class CleanerTester( val cleanerListener = new CleanerListener { def rddCleaned(rddId: Int): Unit = { toBeCleanedRDDIds -= rddId - logInfo("RDD "+ rddId + " cleaned") + logInfo("RDD " + rddId + " cleaned") } def shuffleCleaned(shuffleId: Int): Unit = { diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index abfcee75728dc..3ded1e4af8742 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark import scala.collection.mutable -import org.scalatest.{FunSuite, PrivateMethodTester} +import org.scalatest.{BeforeAndAfter, FunSuite, PrivateMethodTester} import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.ExecutorInfo @@ -28,10 +28,20 @@ import org.apache.spark.util.ManualClock /** * Test add and remove behavior of ExecutorAllocationManager. */ -class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { +class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext with BeforeAndAfter { import ExecutorAllocationManager._ import ExecutorAllocationManagerSuite._ + private val contexts = new mutable.ListBuffer[SparkContext]() + + before { + contexts.clear() + } + + after { + contexts.foreach(_.stop()) + } + test("verify min/max executors") { val conf = new SparkConf() .setMaster("local") @@ -39,18 +49,19 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { .set("spark.dynamicAllocation.enabled", "true") .set("spark.dynamicAllocation.testing", "true") val sc0 = new SparkContext(conf) + contexts += sc0 assert(sc0.executorAllocationManager.isDefined) sc0.stop() // Min < 0 val conf1 = conf.clone().set("spark.dynamicAllocation.minExecutors", "-1") - intercept[SparkException] { new SparkContext(conf1) } + intercept[SparkException] { contexts += new SparkContext(conf1) } SparkEnv.get.stop() SparkContext.clearActiveContext() // Max < 0 val conf2 = conf.clone().set("spark.dynamicAllocation.maxExecutors", "-1") - intercept[SparkException] { new SparkContext(conf2) } + intercept[SparkException] { contexts += new SparkContext(conf2) } SparkEnv.get.stop() SparkContext.clearActiveContext() @@ -665,16 +676,6 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { assert(removeTimes(manager).contains("executor-2")) assert(!removeTimes(manager).contains("executor-1")) } -} - -/** - * Helper methods for testing ExecutorAllocationManager. - * This includes methods to access private methods and fields in ExecutorAllocationManager. - */ -private object ExecutorAllocationManagerSuite extends PrivateMethodTester { - private val schedulerBacklogTimeout = 1L - private val sustainedSchedulerBacklogTimeout = 2L - private val executorIdleTimeout = 3L private def createSparkContext(minExecutors: Int = 1, maxExecutors: Int = 5): SparkContext = { val conf = new SparkConf() @@ -688,9 +689,22 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { sustainedSchedulerBacklogTimeout.toString) .set("spark.dynamicAllocation.executorIdleTimeout", executorIdleTimeout.toString) .set("spark.dynamicAllocation.testing", "true") - new SparkContext(conf) + val sc = new SparkContext(conf) + contexts += sc + sc } +} + +/** + * Helper methods for testing ExecutorAllocationManager. + * This includes methods to access private methods and fields in ExecutorAllocationManager. + */ +private object ExecutorAllocationManagerSuite extends PrivateMethodTester { + private val schedulerBacklogTimeout = 1L + private val sustainedSchedulerBacklogTimeout = 2L + private val executorIdleTimeout = 3L + private def createStageInfo(stageId: Int, numTasks: Int): StageInfo = { new StageInfo(stageId, 0, "name", numTasks, Seq.empty, "no details") } diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index 7acd27c735727..c8f08eed47c76 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -222,7 +222,7 @@ class FileSuite extends FunSuite with LocalSparkContext { val nums = sc.makeRDD(1 to 3).map(x => (new IntWritable(x), new Text("a" * x))) nums.saveAsSequenceFile(outputDir) val output = - sc.newAPIHadoopFile[IntWritable, Text, SequenceFileInputFormat[IntWritable, Text]](outputDir) + sc.newAPIHadoopFile[IntWritable, Text, SequenceFileInputFormat[IntWritable, Text]](outputDir) assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)")) } @@ -451,7 +451,8 @@ class FileSuite extends FunSuite with LocalSparkContext { test ("prevent user from overwriting the empty directory (new Hadoop API)") { sc = new SparkContext("local", "test") - val randomRDD = sc.parallelize(Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1) + val randomRDD = sc.parallelize( + Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1) intercept[FileAlreadyExistsException] { randomRDD.saveAsNewAPIHadoopFile[NewTextOutputFormat[String, String]](tempDir.getPath) } @@ -459,8 +460,10 @@ class FileSuite extends FunSuite with LocalSparkContext { test ("prevent user from overwriting the non-empty directory (new Hadoop API)") { sc = new SparkContext("local", "test") - val randomRDD = sc.parallelize(Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1) - randomRDD.saveAsNewAPIHadoopFile[NewTextOutputFormat[String, String]](tempDir.getPath + "/output") + val randomRDD = sc.parallelize( + Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1) + randomRDD.saveAsNewAPIHadoopFile[NewTextOutputFormat[String, String]]( + tempDir.getPath + "/output") assert(new File(tempDir.getPath + "/output/part-r-00000").exists() === true) intercept[FileAlreadyExistsException] { randomRDD.saveAsNewAPIHadoopFile[NewTextOutputFormat[String, String]](tempDir.getPath) @@ -471,16 +474,20 @@ class FileSuite extends FunSuite with LocalSparkContext { val sf = new SparkConf() sf.setAppName("test").setMaster("local").set("spark.hadoop.validateOutputSpecs", "false") sc = new SparkContext(sf) - val randomRDD = sc.parallelize(Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1) - randomRDD.saveAsNewAPIHadoopFile[NewTextOutputFormat[String, String]](tempDir.getPath + "/output") + val randomRDD = sc.parallelize( + Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1) + randomRDD.saveAsNewAPIHadoopFile[NewTextOutputFormat[String, String]]( + tempDir.getPath + "/output") assert(new File(tempDir.getPath + "/output/part-r-00000").exists() === true) - randomRDD.saveAsNewAPIHadoopFile[NewTextOutputFormat[String, String]](tempDir.getPath + "/output") + randomRDD.saveAsNewAPIHadoopFile[NewTextOutputFormat[String, String]]( + tempDir.getPath + "/output") assert(new File(tempDir.getPath + "/output/part-r-00000").exists() === true) } test ("save Hadoop Dataset through old Hadoop API") { sc = new SparkContext("local", "test") - val randomRDD = sc.parallelize(Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1) + val randomRDD = sc.parallelize( + Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1) val job = new JobConf() job.setOutputKeyClass(classOf[String]) job.setOutputValueClass(classOf[String]) @@ -492,7 +499,8 @@ class FileSuite extends FunSuite with LocalSparkContext { test ("save Hadoop Dataset through new Hadoop API") { sc = new SparkContext("local", "test") - val randomRDD = sc.parallelize(Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1) + val randomRDD = sc.parallelize( + Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1) val job = new Job(sc.hadoopConfiguration) job.setOutputKeyClass(classOf[String]) job.setOutputValueClass(classOf[String]) diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala new file mode 100644 index 0000000000000..0fd570e5297d9 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import scala.concurrent.duration._ +import scala.language.postfixOps + +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.storage.BlockManagerId +import org.scalatest.FunSuite +import org.mockito.Mockito.{mock, spy, verify, when} +import org.mockito.Matchers +import org.mockito.Matchers._ + +import org.apache.spark.scheduler.TaskScheduler +import org.apache.spark.util.RpcUtils +import org.scalatest.concurrent.Eventually._ + +class HeartbeatReceiverSuite extends FunSuite with LocalSparkContext { + + test("HeartbeatReceiver") { + sc = spy(new SparkContext("local[2]", "test")) + val scheduler = mock(classOf[TaskScheduler]) + when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(true) + when(sc.taskScheduler).thenReturn(scheduler) + + val heartbeatReceiver = new HeartbeatReceiver(sc) + sc.env.rpcEnv.setupEndpoint("heartbeat", heartbeatReceiver).send(TaskSchedulerIsSet) + eventually(timeout(5 seconds), interval(5 millis)) { + assert(heartbeatReceiver.scheduler != null) + } + val receiverRef = RpcUtils.makeDriverRef("heartbeat", sc.conf, sc.env.rpcEnv) + + val metrics = new TaskMetrics + val blockManagerId = BlockManagerId("executor-1", "localhost", 12345) + val response = receiverRef.askWithReply[HeartbeatResponse]( + Heartbeat("executor-1", Array(1L -> metrics), blockManagerId)) + + verify(scheduler).executorHeartbeatReceived( + Matchers.eq("executor-1"), Matchers.eq(Array(1L -> metrics)), Matchers.eq(blockManagerId)) + assert(false === response.reregisterBlockManager) + } + + test("HeartbeatReceiver re-register") { + sc = spy(new SparkContext("local[2]", "test")) + val scheduler = mock(classOf[TaskScheduler]) + when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(false) + when(sc.taskScheduler).thenReturn(scheduler) + + val heartbeatReceiver = new HeartbeatReceiver(sc) + sc.env.rpcEnv.setupEndpoint("heartbeat", heartbeatReceiver).send(TaskSchedulerIsSet) + eventually(timeout(5 seconds), interval(5 millis)) { + assert(heartbeatReceiver.scheduler != null) + } + val receiverRef = RpcUtils.makeDriverRef("heartbeat", sc.conf, sc.env.rpcEnv) + + val metrics = new TaskMetrics + val blockManagerId = BlockManagerId("executor-1", "localhost", 12345) + val response = receiverRef.askWithReply[HeartbeatResponse]( + Heartbeat("executor-1", Array(1L -> metrics), blockManagerId)) + + verify(scheduler).executorHeartbeatReceived( + Matchers.eq("executor-1"), Matchers.eq(Array(1L -> metrics)), Matchers.eq(blockManagerId)) + assert(true === response.reregisterBlockManager) + } +} diff --git a/core/src/test/scala/org/apache/spark/ImplicitOrderingSuite.scala b/core/src/test/scala/org/apache/spark/ImplicitOrderingSuite.scala index d895230ecf330..51348c039b5c9 100644 --- a/core/src/test/scala/org/apache/spark/ImplicitOrderingSuite.scala +++ b/core/src/test/scala/org/apache/spark/ImplicitOrderingSuite.scala @@ -51,7 +51,7 @@ private object ImplicitOrderingSuite { override def compare(o: OrderedClass): Int = ??? } - def basicMapExpectations(rdd: RDD[Int]) = { + def basicMapExpectations(rdd: RDD[Int]): List[(Boolean, String)] = { List((rdd.map(x => (x, x)).keyOrdering.isDefined, "rdd.map(x => (x, x)).keyOrdering.isDefined"), (rdd.map(x => (1, x)).keyOrdering.isDefined, @@ -68,7 +68,7 @@ private object ImplicitOrderingSuite { "rdd.map(x => (new OrderedClass, x)).keyOrdering.isDefined")) } - def otherRDDMethodExpectations(rdd: RDD[Int]) = { + def otherRDDMethodExpectations(rdd: RDD[Int]): List[(Boolean, String)] = { List((rdd.groupBy(x => x).keyOrdering.isDefined, "rdd.groupBy(x => x).keyOrdering.isDefined"), (rdd.groupBy(x => new NonOrderedClass).keyOrdering.isEmpty, @@ -82,4 +82,4 @@ private object ImplicitOrderingSuite { (rdd.groupBy((x: Int) => x, new HashPartitioner(5)).keyOrdering.isDefined, "rdd.groupBy((x: Int) => x, new HashPartitioner(5)).keyOrdering.isDefined")) } -} \ No newline at end of file +} diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala index 21487bc24d58a..4d3e09793faff 100644 --- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -188,7 +188,7 @@ class JobCancellationSuite extends FunSuite with Matchers with BeforeAndAfter val rdd = sc.parallelize(1 to 10, 2).map { i => JobCancellationSuite.twoJobsSharingStageSemaphore.acquire() (i, i) - }.reduceByKey(_+_) + }.reduceByKey(_ + _) val f1 = rdd.collectAsync() val f2 = rdd.countAsync() diff --git a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala index 53e367a61715b..8bf2e55defd02 100644 --- a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala +++ b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala @@ -37,7 +37,7 @@ trait LocalSparkContext extends BeforeAndAfterEach with BeforeAndAfterAll { self super.afterEach() } - def resetSparkContext() = { + def resetSparkContext(): Unit = { LocalSparkContext.stop(sc) sc = null } @@ -54,7 +54,7 @@ object LocalSparkContext { } /** Runs `f` by passing in `sc` and ensures that `sc` is stopped. */ - def withSpark[T](sc: SparkContext)(f: SparkContext => T) = { + def withSpark[T](sc: SparkContext)(f: SparkContext => T): T = { try { f(sc) } finally { diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index ccfe0678cb1c3..6295d34be5ca9 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -17,34 +17,37 @@ package org.apache.spark -import scala.concurrent.Await - -import akka.actor._ -import akka.testkit.TestActorRef +import org.mockito.Mockito._ +import org.mockito.Matchers.{any, isA} import org.scalatest.FunSuite +import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcCallContext, RpcEnv} import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus} import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.BlockManagerId -import org.apache.spark.util.AkkaUtils class MapOutputTrackerSuite extends FunSuite { private val conf = new SparkConf + def createRpcEnv(name: String, host: String = "localhost", port: Int = 0, + securityManager: SecurityManager = new SecurityManager(conf)): RpcEnv = { + RpcEnv.create(name, host, port, conf, securityManager) + } + test("master start and stop") { - val actorSystem = ActorSystem("test") + val rpcEnv = createRpcEnv("test") val tracker = new MapOutputTrackerMaster(conf) - tracker.trackerActor = - actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf))) + tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) tracker.stop() - actorSystem.shutdown() + rpcEnv.shutdown() } test("master register shuffle and fetch") { - val actorSystem = ActorSystem("test") + val rpcEnv = createRpcEnv("test") val tracker = new MapOutputTrackerMaster(conf) - tracker.trackerActor = - actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf))) + tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) tracker.registerShuffle(10, 2) assert(tracker.containsShuffle(10)) val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) @@ -57,13 +60,14 @@ class MapOutputTrackerSuite extends FunSuite { assert(statuses.toSeq === Seq((BlockManagerId("a", "hostA", 1000), size1000), (BlockManagerId("b", "hostB", 1000), size10000))) tracker.stop() - actorSystem.shutdown() + rpcEnv.shutdown() } test("master register and unregister shuffle") { - val actorSystem = ActorSystem("test") + val rpcEnv = createRpcEnv("test") val tracker = new MapOutputTrackerMaster(conf) - tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf))) + tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) tracker.registerShuffle(10, 2) val compressedSize1000 = MapStatus.compressSize(1000L) val compressedSize10000 = MapStatus.compressSize(10000L) @@ -78,14 +82,14 @@ class MapOutputTrackerSuite extends FunSuite { assert(tracker.getServerStatuses(10, 0).isEmpty) tracker.stop() - actorSystem.shutdown() + rpcEnv.shutdown() } test("master register shuffle and unregister map output and fetch") { - val actorSystem = ActorSystem("test") + val rpcEnv = createRpcEnv("test") val tracker = new MapOutputTrackerMaster(conf) - tracker.trackerActor = - actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf))) + tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) tracker.registerShuffle(10, 2) val compressedSize1000 = MapStatus.compressSize(1000L) val compressedSize10000 = MapStatus.compressSize(10000L) @@ -104,25 +108,21 @@ class MapOutputTrackerSuite extends FunSuite { intercept[FetchFailedException] { tracker.getServerStatuses(10, 1) } tracker.stop() - actorSystem.shutdown() + rpcEnv.shutdown() } test("remote fetch") { val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, conf = conf, - securityManager = new SecurityManager(conf)) + val rpcEnv = createRpcEnv("spark", hostname, 0, new SecurityManager(conf)) val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, conf = conf, - securityManager = new SecurityManager(conf)) + val slaveRpcEnv = createRpcEnv("spark-slave", hostname, 0, new SecurityManager(conf)) val slaveTracker = new MapOutputTrackerWorker(conf) - val selection = slaveSystem.actorSelection( - AkkaUtils.address(AkkaUtils.protocol(slaveSystem), "spark", "localhost", boundPort, "MapOutputTracker")) - val timeout = AkkaUtils.lookupTimeout(conf) - slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + slaveTracker.trackerEndpoint = + slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) masterTracker.registerShuffle(10, 1) masterTracker.incrementEpoch() @@ -147,8 +147,8 @@ class MapOutputTrackerSuite extends FunSuite { masterTracker.stop() slaveTracker.stop() - actorSystem.shutdown() - slaveSystem.shutdown() + rpcEnv.shutdown() + slaveRpcEnv.shutdown() } test("remote fetch below akka frame size") { @@ -157,19 +157,24 @@ class MapOutputTrackerSuite extends FunSuite { newConf.set("spark.akka.askTimeout", "1") // Fail fast val masterTracker = new MapOutputTrackerMaster(conf) - val actorSystem = ActorSystem("test") - val actorRef = TestActorRef[MapOutputTrackerMasterActor]( - Props(new MapOutputTrackerMasterActor(masterTracker, newConf)))(actorSystem) - val masterActor = actorRef.underlyingActor + val rpcEnv = createRpcEnv("spark") + val masterEndpoint = new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, newConf) + rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint) // Frame size should be ~123B, and no exception should be thrown masterTracker.registerShuffle(10, 1) masterTracker.registerMapOutput(10, 0, MapStatus( BlockManagerId("88", "mph", 1000), Array.fill[Long](10)(0))) - masterActor.receive(GetMapOutputStatuses(10)) + val sender = mock(classOf[RpcEndpointRef]) + when(sender.address).thenReturn(RpcAddress("localhost", 12345)) + val rpcCallContext = mock(classOf[RpcCallContext]) + when(rpcCallContext.sender).thenReturn(sender) + masterEndpoint.receiveAndReply(rpcCallContext)(GetMapOutputStatuses(10)) + verify(rpcCallContext).reply(any()) + verify(rpcCallContext, never()).sendFailure(any()) // masterTracker.stop() // this throws an exception - actorSystem.shutdown() + rpcEnv.shutdown() } test("remote fetch exceeds akka frame size") { @@ -178,12 +183,11 @@ class MapOutputTrackerSuite extends FunSuite { newConf.set("spark.akka.askTimeout", "1") // Fail fast val masterTracker = new MapOutputTrackerMaster(conf) - val actorSystem = ActorSystem("test") - val actorRef = TestActorRef[MapOutputTrackerMasterActor]( - Props(new MapOutputTrackerMasterActor(masterTracker, newConf)))(actorSystem) - val masterActor = actorRef.underlyingActor + val rpcEnv = createRpcEnv("test") + val masterEndpoint = new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, newConf) + rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint) - // Frame size should be ~1.1MB, and MapOutputTrackerMasterActor should throw exception. + // Frame size should be ~1.1MB, and MapOutputTrackerMasterEndpoint should throw exception. // Note that the size is hand-selected here because map output statuses are compressed before // being sent. masterTracker.registerShuffle(20, 100) @@ -191,9 +195,15 @@ class MapOutputTrackerSuite extends FunSuite { masterTracker.registerMapOutput(20, i, new CompressedMapStatus( BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0))) } - intercept[SparkException] { masterActor.receive(GetMapOutputStatuses(20)) } + val sender = mock(classOf[RpcEndpointRef]) + when(sender.address).thenReturn(RpcAddress("localhost", 12345)) + val rpcCallContext = mock(classOf[RpcCallContext]) + when(rpcCallContext.sender).thenReturn(sender) + masterEndpoint.receiveAndReply(rpcCallContext)(GetMapOutputStatuses(20)) + verify(rpcCallContext, never()).reply(any()) + verify(rpcCallContext).sendFailure(isA(classOf[SparkException])) // masterTracker.stop() // this throws an exception - actorSystem.shutdown() + rpcEnv.shutdown() } } diff --git a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala index b7532314ada01..47e3bf6e1ac41 100644 --- a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala +++ b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala @@ -92,7 +92,7 @@ class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMet test("RangePartitioner for keys that are not Comparable (but with Ordering)") { // Row does not extend Comparable, but has an implicit Ordering defined. implicit object RowOrdering extends Ordering[Row] { - override def compare(x: Row, y: Row) = x.value - y.value + override def compare(x: Row, y: Row): Int = x.value - y.value } val rdd = sc.parallelize(1 to 4500).map(x => (Row(x), Row(x))) @@ -212,20 +212,24 @@ class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMet val arrPairs: RDD[(Array[Int], Int)] = sc.parallelize(Array(1, 2, 3, 4), 2).map(x => (Array(x), x)) - assert(intercept[SparkException]{ arrs.distinct() }.getMessage.contains("array")) + def verify(testFun: => Unit): Unit = { + intercept[SparkException](testFun).getMessage.contains("array") + } + + verify(arrs.distinct()) // We can't catch all usages of arrays, since they might occur inside other collections: // assert(fails { arrPairs.distinct() }) - assert(intercept[SparkException]{ arrPairs.partitionBy(new HashPartitioner(2)) }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.join(arrPairs) }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.leftOuterJoin(arrPairs) }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.rightOuterJoin(arrPairs) }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.fullOuterJoin(arrPairs) }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.groupByKey() }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.countByKey() }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.countByKeyApprox(1) }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.cogroup(arrPairs) }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.reduceByKeyLocally(_ + _) }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.reduceByKey(_ + _) }.getMessage.contains("array")) + verify(arrPairs.partitionBy(new HashPartitioner(2))) + verify(arrPairs.join(arrPairs)) + verify(arrPairs.leftOuterJoin(arrPairs)) + verify(arrPairs.rightOuterJoin(arrPairs)) + verify(arrPairs.fullOuterJoin(arrPairs)) + verify(arrPairs.groupByKey()) + verify(arrPairs.countByKey()) + verify(arrPairs.countByKeyApprox(1)) + verify(arrPairs.cogroup(arrPairs)) + verify(arrPairs.reduceByKeyLocally(_ + _)) + verify(arrPairs.reduceByKey(_ + _)) } test("zero-length partitions should be correctly handled") { diff --git a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala index 444a33371bd71..93f46ef11c0e2 100644 --- a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala @@ -36,7 +36,8 @@ class SSLOptionsSuite extends FunSuite with BeforeAndAfterAll { conf.set("spark.ssl.keyPassword", "password") conf.set("spark.ssl.trustStore", trustStorePath) conf.set("spark.ssl.trustStorePassword", "password") - conf.set("spark.ssl.enabledAlgorithms", "TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA") + conf.set("spark.ssl.enabledAlgorithms", + "TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA") conf.set("spark.ssl.protocol", "SSLv3") val opts = SSLOptions.parse(conf, "spark.ssl") @@ -52,7 +53,8 @@ class SSLOptionsSuite extends FunSuite with BeforeAndAfterAll { assert(opts.keyStorePassword === Some("password")) assert(opts.keyPassword === Some("password")) assert(opts.protocol === Some("SSLv3")) - assert(opts.enabledAlgorithms === Set("TLS_RSA_WITH_AES_128_CBC_SHA", "TLS_RSA_WITH_AES_256_CBC_SHA")) + assert(opts.enabledAlgorithms === + Set("TLS_RSA_WITH_AES_128_CBC_SHA", "TLS_RSA_WITH_AES_256_CBC_SHA")) } test("test resolving property with defaults specified ") { @@ -66,7 +68,8 @@ class SSLOptionsSuite extends FunSuite with BeforeAndAfterAll { conf.set("spark.ssl.keyPassword", "password") conf.set("spark.ssl.trustStore", trustStorePath) conf.set("spark.ssl.trustStorePassword", "password") - conf.set("spark.ssl.enabledAlgorithms", "TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA") + conf.set("spark.ssl.enabledAlgorithms", + "TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA") conf.set("spark.ssl.protocol", "SSLv3") val defaultOpts = SSLOptions.parse(conf, "spark.ssl", defaults = None) @@ -83,7 +86,8 @@ class SSLOptionsSuite extends FunSuite with BeforeAndAfterAll { assert(opts.keyStorePassword === Some("password")) assert(opts.keyPassword === Some("password")) assert(opts.protocol === Some("SSLv3")) - assert(opts.enabledAlgorithms === Set("TLS_RSA_WITH_AES_128_CBC_SHA", "TLS_RSA_WITH_AES_256_CBC_SHA")) + assert(opts.enabledAlgorithms === + Set("TLS_RSA_WITH_AES_128_CBC_SHA", "TLS_RSA_WITH_AES_256_CBC_SHA")) } test("test whether defaults can be overridden ") { @@ -99,7 +103,8 @@ class SSLOptionsSuite extends FunSuite with BeforeAndAfterAll { conf.set("spark.ssl.keyPassword", "password") conf.set("spark.ssl.trustStore", trustStorePath) conf.set("spark.ssl.trustStorePassword", "password") - conf.set("spark.ssl.enabledAlgorithms", "TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA") + conf.set("spark.ssl.enabledAlgorithms", + "TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA") conf.set("spark.ui.ssl.enabledAlgorithms", "ABC, DEF") conf.set("spark.ssl.protocol", "SSLv3") diff --git a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala b/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala index ace8123a8961f..308b9ea17708d 100644 --- a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala +++ b/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala @@ -21,10 +21,11 @@ import java.io.File object SSLSampleConfigs { val keyStorePath = new File(this.getClass.getResource("/keystore").toURI).getAbsolutePath - val untrustedKeyStorePath = new File(this.getClass.getResource("/untrusted-keystore").toURI).getAbsolutePath + val untrustedKeyStorePath = new File( + this.getClass.getResource("/untrusted-keystore").toURI).getAbsolutePath val trustStorePath = new File(this.getClass.getResource("/truststore").toURI).getAbsolutePath - def sparkSSLConfig() = { + def sparkSSLConfig(): SparkConf = { val conf = new SparkConf(loadDefaults = false) conf.set("spark.ssl.enabled", "true") conf.set("spark.ssl.keyStore", keyStorePath) @@ -38,7 +39,7 @@ object SSLSampleConfigs { conf } - def sparkSSLConfigUntrusted() = { + def sparkSSLConfigUntrusted(): SparkConf = { val conf = new SparkConf(loadDefaults = false) conf.set("spark.ssl.enabled", "true") conf.set("spark.ssl.keyStore", untrustedKeyStorePath) diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index f57921b768310..d7180516029d5 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -142,7 +142,7 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex test("shuffle on mutable pairs") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks sc = new SparkContext("local-cluster[2,1,512]", "test", conf) - def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2) + def p[T1, T2](_1: T1, _2: T2): MutablePair[T1, T2] = MutablePair(_1, _2) val data = Array(p(1, 1), p(1, 2), p(1, 3), p(2, 1)) val pairs: RDD[MutablePair[Int, Int]] = sc.parallelize(data, 2) val results = new ShuffledRDD[Int, Int, Int](pairs, @@ -155,7 +155,7 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex // This is not in SortingSuite because of the local cluster setup. // Use a local cluster with 2 processes to make sure there are both local and remote blocks sc = new SparkContext("local-cluster[2,1,512]", "test", conf) - def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2) + def p[T1, T2](_1: T1, _2: T2): MutablePair[T1, T2] = MutablePair(_1, _2) val data = Array(p(1, 11), p(3, 33), p(100, 100), p(2, 22)) val pairs: RDD[MutablePair[Int, Int]] = sc.parallelize(data, 2) val results = new OrderedRDDFunctions[Int, Int, MutablePair[Int, Int]](pairs) @@ -169,7 +169,7 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex test("cogroup using mutable pairs") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks sc = new SparkContext("local-cluster[2,1,512]", "test", conf) - def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2) + def p[T1, T2](_1: T1, _2: T2): MutablePair[T1, T2] = MutablePair(_1, _2) val data1 = Seq(p(1, 1), p(1, 2), p(1, 3), p(2, 1)) val data2 = Seq(p(1, "11"), p(1, "12"), p(2, "22"), p(3, "3")) val pairs1: RDD[MutablePair[Int, Int]] = sc.parallelize(data1, 2) @@ -196,7 +196,7 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex test("subtract mutable pairs") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks sc = new SparkContext("local-cluster[2,1,512]", "test", conf) - def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2) + def p[T1, T2](_1: T1, _2: T2): MutablePair[T1, T2] = MutablePair(_1, _2) val data1 = Seq(p(1, 1), p(1, 2), p(1, 3), p(2, 1), p(3, 33)) val data2 = Seq(p(1, "11"), p(1, "12"), p(2, "22")) val pairs1: RDD[MutablePair[Int, Int]] = sc.parallelize(data1, 2) @@ -242,14 +242,14 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex shuffleSpillCompress <- Set(true, false); shuffleCompress <- Set(true, false) ) { - val conf = new SparkConf() + val myConf = conf.clone() .setAppName("test") .setMaster("local") .set("spark.shuffle.spill.compress", shuffleSpillCompress.toString) .set("spark.shuffle.compress", shuffleCompress.toString) .set("spark.shuffle.memoryFraction", "0.001") resetSparkContext() - sc = new SparkContext(conf) + sc = new SparkContext(myConf) try { sc.parallelize(0 until 100000).map(i => (i / 4, i)).groupByKey().collect() } catch { diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index b07c4d93db4e6..94be1c6d6397c 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark import java.io.File +import java.util.concurrent.TimeUnit import com.google.common.base.Charsets._ import com.google.common.io.Files @@ -25,9 +26,11 @@ import com.google.common.io.Files import org.scalatest.FunSuite import org.apache.hadoop.io.BytesWritable - import org.apache.spark.util.Utils +import scala.concurrent.Await +import scala.concurrent.duration.Duration + class SparkContextSuite extends FunSuite with LocalSparkContext { test("Only one SparkContext may be active at a time") { @@ -111,11 +114,13 @@ class SparkContextSuite extends FunSuite with LocalSparkContext { if (length1 != gotten1.length()) { throw new SparkException( - s"file has different length $length1 than added file ${gotten1.length()} : " + absolutePath1) + s"file has different length $length1 than added file ${gotten1.length()} : " + + absolutePath1) } if (length2 != gotten2.length()) { throw new SparkException( - s"file has different length $length2 than added file ${gotten2.length()} : " + absolutePath2) + s"file has different length $length2 than added file ${gotten2.length()} : " + + absolutePath2) } if (absolutePath1 == gotten1.getAbsolutePath) { @@ -173,4 +178,19 @@ class SparkContextSuite extends FunSuite with LocalSparkContext { sc.stop() } } + + test("Cancelling job group should not cause SparkContext to shutdown (SPARK-6414)") { + try { + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + val future = sc.parallelize(Seq(0)).foreachAsync(_ => {Thread.sleep(1000L)}) + sc.cancelJobGroup("nonExistGroupId") + Await.ready(future, Duration(2, TimeUnit.SECONDS)) + + // In SPARK-6414, sc.cancelJobGroup will cause NullPointerException and cause + // SparkContext to shutdown, so the following assertion will fail. + assert(sc.parallelize(1 to 10).count() == 10L) + } finally { + sc.stop() + } + } } diff --git a/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala b/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala index 41d6ea29d5b06..084eb237d70d1 100644 --- a/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala @@ -82,7 +82,8 @@ class StatusTrackerSuite extends FunSuite with Matchers with LocalSparkContext { secondJobFuture.jobIds.head } eventually(timeout(10 seconds)) { - sc.statusTracker.getJobIdsForGroup("my-job-group").toSet should be (Set(firstJobId, secondJobId)) + sc.statusTracker.getJobIdsForGroup("my-job-group").toSet should be ( + Set(firstJobId, secondJobId)) } } -} \ No newline at end of file +} diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index af3272692d7a1..c8fdfa693912e 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -33,7 +33,7 @@ class DummyBroadcastClass(rdd: RDD[Int]) extends Serializable { val broadcast = rdd.context.broadcast(list) val bid = broadcast.id - def doSomething() = { + def doSomething(): Set[(Int, Boolean)] = { rdd.map { x => val bm = SparkEnv.get.blockManager // Check if broadcast block was fetched diff --git a/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala b/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala index 518073dcbb64e..745f9eeee7536 100644 --- a/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala @@ -46,5 +46,4 @@ class ClientSuite extends FunSuite with Matchers { // Invalid syntax. ClientArguments.isValidJarUrl("hdfs:") should be (false) } - } diff --git a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala index b63cfd50fab02..bb21f18596f7a 100644 --- a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala @@ -100,13 +100,13 @@ class JsonProtocolSuite extends FunSuite with JsonTestUtils { appInfo } - def createDriverCommand() = new Command( + def createDriverCommand(): Command = new Command( "org.apache.spark.FakeClass", Seq("some arg --and-some options -g foo"), Map(("K1", "V1"), ("K2", "V2")), Seq("cp1", "cp2"), Seq("lp1", "lp2"), Seq("-Dfoo") ) - def createDriverDesc() = new DriverDescription("hdfs://some-dir/some.jar", 100, 3, - false, createDriverCommand()) + def createDriverDesc(): DriverDescription = + new DriverDescription("hdfs://some-dir/some.jar", 100, 3, false, createDriverCommand()) def createDriverInfo(): DriverInfo = new DriverInfo(3, "driver-3", createDriverDesc(), new Date()) diff --git a/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala index 54dd7c9c45c61..9cdb42814ca32 100644 --- a/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala @@ -56,7 +56,7 @@ class LogUrlsStandaloneSuite extends FunSuite with LocalSparkContext { test("verify that log urls reflect SPARK_PUBLIC_DNS (SPARK-6175)") { val SPARK_PUBLIC_DNS = "public_dns" class MySparkConf extends SparkConf(false) { - override def getenv(name: String) = { + override def getenv(name: String): String = { if (name == "SPARK_PUBLIC_DNS") SPARK_PUBLIC_DNS else super.getenv(name) } diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index e908ba604ebed..fcae603c7d18e 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -50,7 +50,7 @@ class FsHistoryProviderSuite extends FunSuite with BeforeAndAfter with Matchers inProgress: Boolean, codec: Option[String] = None): File = { val ip = if (inProgress) EventLoggingListener.IN_PROGRESS else "" - val logUri = EventLoggingListener.getLogPath(testDir.getAbsolutePath, appId) + val logUri = EventLoggingListener.getLogPath(testDir.toURI, appId) val logPath = new URI(logUri).getPath + ip new File(logPath) } diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 06180add40a62..3988a3fd36b34 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -189,10 +189,10 @@ class HistoryServerSuite extends FunSuite with BeforeAndAfter with Matchers with when(historyServer.getProviderConfig()).thenReturn(Map[String, String]()) val page = new HistoryPage(historyServer) - //when + // when val response = page.render(request) - //then + // then val links = response \\ "a" val justHrefs = for { l <- links diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala index 2fa90e3bd1c63..8e09976636386 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala @@ -508,7 +508,7 @@ private class DummyMaster( exception: Option[Exception] = None) extends Actor { - override def receive = { + override def receive: PartialFunction[Any, Unit] = { case RequestSubmitDriver(driverDesc) => sender ! SubmitDriverResponse(success = true, Some(submitId), submitMessage) case RequestKillDriver(driverId) => @@ -531,7 +531,7 @@ private class SmarterMaster extends Actor { private var counter: Int = 0 private val submittedDrivers = new mutable.HashMap[String, DriverState] - override def receive = { + override def receive: PartialFunction[Any, Unit] = { case RequestSubmitDriver(driverDesc) => val driverId = s"driver-$counter" submittedDrivers(driverId) = RUNNING diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala index 1d64ec201e647..61071ee17256c 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala @@ -129,7 +129,8 @@ class SubmitRestProtocolSuite extends FunSuite { assert(newMessage.sparkProperties("spark.files") === "fireball.png") assert(newMessage.sparkProperties("spark.driver.memory") === "512m") assert(newMessage.sparkProperties("spark.driver.cores") === "180") - assert(newMessage.sparkProperties("spark.driver.extraJavaOptions") === " -Dslices=5 -Dcolor=mostly_red") + assert(newMessage.sparkProperties("spark.driver.extraJavaOptions") === + " -Dslices=5 -Dcolor=mostly_red") assert(newMessage.sparkProperties("spark.driver.extraClassPath") === "food-coloring.jar") assert(newMessage.sparkProperties("spark.driver.extraLibraryPath") === "pickle.jar") assert(newMessage.sparkProperties("spark.driver.supervise") === "false") diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala index 6fca6321e5a1b..a8b9df227c996 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala @@ -35,7 +35,8 @@ class ExecutorRunnerTest extends FunSuite { val er = new ExecutorRunner(appId, 1, appDesc, 8, 500, null, "blah", "worker321", 123, "publicAddr", new File(sparkHome), new File("ooga"), "blah", new SparkConf, Seq("localDir"), ExecutorState.RUNNING) - val builder = CommandUtils.buildProcessBuilder(appDesc.command, 512, sparkHome, er.substituteVariables) + val builder = CommandUtils.buildProcessBuilder( + appDesc.command, 512, sparkHome, er.substituteVariables) assert(builder.command().last === appId) } } diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala index 372d7aa453008..7cc2104281464 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala @@ -37,7 +37,7 @@ class WorkerArgumentsTest extends FunSuite { val args = Array("spark://localhost:0000 ") class MySparkConf extends SparkConf(false) { - override def getenv(name: String) = { + override def getenv(name: String): String = { if (name == "SPARK_WORKER_MEMORY") "50000" else super.getenv(name) } @@ -56,7 +56,7 @@ class WorkerArgumentsTest extends FunSuite { val args = Array("spark://localhost:0000 ") class MySparkConf extends SparkConf(false) { - override def getenv(name: String) = { + override def getenv(name: String): String = { if (name == "SPARK_WORKER_MEMORY") "5G" else super.getenv(name) } diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala index 84e2fd7ad936d..450fba21f4b5c 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala @@ -24,8 +24,10 @@ import org.scalatest.{Matchers, FunSuite} class WorkerSuite extends FunSuite with Matchers { - def cmd(javaOpts: String*) = Command("", Seq.empty, Map.empty, Seq.empty, Seq.empty, Seq(javaOpts:_*)) - def conf(opts: (String, String)*) = new SparkConf(loadDefaults = false).setAll(opts) + def cmd(javaOpts: String*): Command = { + Command("", Seq.empty, Map.empty, Seq.empty, Seq.empty, Seq(javaOpts:_*)) + } + def conf(opts: (String, String)*): SparkConf = new SparkConf(loadDefaults = false).setAll(opts) test("test isUseLocalNodeSSLConfig") { Worker.isUseLocalNodeSSLConfig(cmd("-Dasdf=dfgh")) shouldBe false diff --git a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala index 78fa98a3b9065..190b08d950a02 100644 --- a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala @@ -238,7 +238,7 @@ class InputOutputMetricsSuite extends FunSuite with SharedSparkContext sc.textFile(tmpFilePath, 4) .map(key => (key, 1)) - .reduceByKey(_+_) + .reduceByKey(_ + _) .saveAsTextFile("file://" + tmpFile.getAbsolutePath) sc.listenerBus.waitUntilEmpty(500) diff --git a/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala b/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala index 37e528435aa5d..100ac77dec1f7 100644 --- a/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala @@ -35,7 +35,8 @@ class MetricsConfigSuite extends FunSuite with BeforeAndAfter { val property = conf.getInstance("random") assert(property.size() === 2) - assert(property.getProperty("sink.servlet.class") === "org.apache.spark.metrics.sink.MetricsServlet") + assert(property.getProperty("sink.servlet.class") === + "org.apache.spark.metrics.sink.MetricsServlet") assert(property.getProperty("sink.servlet.path") === "/metrics/json") } @@ -47,16 +48,20 @@ class MetricsConfigSuite extends FunSuite with BeforeAndAfter { assert(masterProp.size() === 5) assert(masterProp.getProperty("sink.console.period") === "20") assert(masterProp.getProperty("sink.console.unit") === "minutes") - assert(masterProp.getProperty("source.jvm.class") === "org.apache.spark.metrics.source.JvmSource") - assert(masterProp.getProperty("sink.servlet.class") === "org.apache.spark.metrics.sink.MetricsServlet") + assert(masterProp.getProperty("source.jvm.class") === + "org.apache.spark.metrics.source.JvmSource") + assert(masterProp.getProperty("sink.servlet.class") === + "org.apache.spark.metrics.sink.MetricsServlet") assert(masterProp.getProperty("sink.servlet.path") === "/metrics/master/json") val workerProp = conf.getInstance("worker") assert(workerProp.size() === 5) assert(workerProp.getProperty("sink.console.period") === "10") assert(workerProp.getProperty("sink.console.unit") === "seconds") - assert(workerProp.getProperty("source.jvm.class") === "org.apache.spark.metrics.source.JvmSource") - assert(workerProp.getProperty("sink.servlet.class") === "org.apache.spark.metrics.sink.MetricsServlet") + assert(workerProp.getProperty("source.jvm.class") === + "org.apache.spark.metrics.source.JvmSource") + assert(workerProp.getProperty("sink.servlet.class") === + "org.apache.spark.metrics.sink.MetricsServlet") assert(workerProp.getProperty("sink.servlet.path") === "/metrics/json") } diff --git a/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala index 0dc59888f7304..be8467354b222 100644 --- a/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala @@ -80,7 +80,7 @@ class JdbcRDDSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { (r: ResultSet) => { r.getInt(1) } ).cache() assert(rdd.count === 100) - assert(rdd.reduce(_+_) === 10100) + assert(rdd.reduce(_ + _) === 10100) } test("large id overflow") { @@ -92,7 +92,7 @@ class JdbcRDDSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { 1131544775L, 567279358897692673L, 20, (r: ResultSet) => { r.getInt(1) } ).cache() assert(rdd.count === 100) - assert(rdd.reduce(_+_) === 5050) + assert(rdd.reduce(_ + _) === 5050) } after { diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index 108f70af43f37..ca0d953d306d8 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -168,13 +168,13 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { test("reduceByKey") { val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) - val sums = pairs.reduceByKey(_+_).collect() + val sums = pairs.reduceByKey(_ + _).collect() assert(sums.toSet === Set((1, 7), (2, 1))) } test("reduceByKey with collectAsMap") { val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) - val sums = pairs.reduceByKey(_+_).collectAsMap() + val sums = pairs.reduceByKey(_ + _).collectAsMap() assert(sums.size === 2) assert(sums(1) === 7) assert(sums(2) === 1) @@ -182,7 +182,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { test("reduceByKey with many output partitons") { val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) - val sums = pairs.reduceByKey(_+_, 10).collect() + val sums = pairs.reduceByKey(_ + _, 10).collect() assert(sums.toSet === Set((1, 7), (2, 1))) } @@ -192,7 +192,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { def getPartition(key: Any) = key.asInstanceOf[Int] } val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 1), (0, 1))).partitionBy(p) - val sums = pairs.reduceByKey(_+_) + val sums = pairs.reduceByKey(_ + _) assert(sums.collect().toSet === Set((1, 4), (0, 1))) assert(sums.partitioner === Some(p)) // count the dependencies to make sure there is only 1 ShuffledRDD @@ -208,7 +208,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { } test("countApproxDistinctByKey") { - def error(est: Long, size: Long) = math.abs(est - size) / size.toDouble + def error(est: Long, size: Long): Double = math.abs(est - size) / size.toDouble /* Since HyperLogLog unique counting is approximate, and the relative standard deviation is * only a statistical bound, the tests can fail for large values of relativeSD. We will be using @@ -465,7 +465,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { test("foldByKey") { val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) - val sums = pairs.foldByKey(0)(_+_).collect() + val sums = pairs.foldByKey(0)(_ + _).collect() assert(sums.toSet === Set((1, 7), (2, 1))) } @@ -505,7 +505,8 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { conf.setOutputCommitter(classOf[FakeOutputCommitter]) FakeOutputCommitter.ran = false - pairs.saveAsHadoopFile("ignored", pairs.keyClass, pairs.valueClass, classOf[FakeOutputFormat], conf) + pairs.saveAsHadoopFile( + "ignored", pairs.keyClass, pairs.valueClass, classOf[FakeOutputFormat], conf) assert(FakeOutputCommitter.ran, "OutputCommitter was never called") } @@ -552,7 +553,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { } private object StratifiedAuxiliary { - def stratifier (fractionPositive: Double) = { + def stratifier (fractionPositive: Double): (Int) => String = { (x: Int) => if (x % 10 < (10 * fractionPositive).toInt) "1" else "0" } @@ -572,7 +573,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { def testSampleExact(stratifiedData: RDD[(String, Int)], samplingRate: Double, seed: Long, - n: Long) = { + n: Long): Unit = { testBernoulli(stratifiedData, true, samplingRate, seed, n) testPoisson(stratifiedData, true, samplingRate, seed, n) } @@ -580,7 +581,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { def testSample(stratifiedData: RDD[(String, Int)], samplingRate: Double, seed: Long, - n: Long) = { + n: Long): Unit = { testBernoulli(stratifiedData, false, samplingRate, seed, n) testPoisson(stratifiedData, false, samplingRate, seed, n) } @@ -590,7 +591,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { exact: Boolean, samplingRate: Double, seed: Long, - n: Long) = { + n: Long): Unit = { val expectedSampleSize = stratifiedData.countByKey() .mapValues(count => math.ceil(count * samplingRate).toInt) val fractions = Map("1" -> samplingRate, "0" -> samplingRate) @@ -612,7 +613,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { exact: Boolean, samplingRate: Double, seed: Long, - n: Long) = { + n: Long): Unit = { val expectedSampleSize = stratifiedData.countByKey().mapValues(count => math.ceil(count * samplingRate).toInt) val fractions = Map("1" -> samplingRate, "0" -> samplingRate) @@ -701,27 +702,27 @@ class FakeOutputFormat() extends OutputFormat[Integer, Integer]() { */ class NewFakeWriter extends NewRecordWriter[Integer, Integer] { - def close(p1: NewTaskAttempContext) = () + def close(p1: NewTaskAttempContext): Unit = () - def write(p1: Integer, p2: Integer) = () + def write(p1: Integer, p2: Integer): Unit = () } class NewFakeCommitter extends NewOutputCommitter { - def setupJob(p1: NewJobContext) = () + def setupJob(p1: NewJobContext): Unit = () def needsTaskCommit(p1: NewTaskAttempContext): Boolean = false - def setupTask(p1: NewTaskAttempContext) = () + def setupTask(p1: NewTaskAttempContext): Unit = () - def commitTask(p1: NewTaskAttempContext) = () + def commitTask(p1: NewTaskAttempContext): Unit = () - def abortTask(p1: NewTaskAttempContext) = () + def abortTask(p1: NewTaskAttempContext): Unit = () } class NewFakeFormat() extends NewOutputFormat[Integer, Integer]() { - def checkOutputSpecs(p1: NewJobContext) = () + def checkOutputSpecs(p1: NewJobContext): Unit = () def getRecordWriter(p1: NewTaskAttempContext): NewRecordWriter[Integer, Integer] = { new NewFakeWriter() @@ -735,7 +736,7 @@ class NewFakeFormat() extends NewOutputFormat[Integer, Integer]() { class ConfigTestFormat() extends NewFakeFormat() with Configurable { var setConfCalled = false - def setConf(p1: Configuration) = { + def setConf(p1: Configuration): Unit = { setConfCalled = true () } diff --git a/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala b/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala index cd193ae4f5238..1880364581c1a 100644 --- a/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala @@ -100,7 +100,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { val data = 1 until 100 val slices = ParallelCollectionRDD.slice(data, 3) assert(slices.size === 3) - assert(slices.map(_.size).reduceLeft(_+_) === 99) + assert(slices.map(_.size).reduceLeft(_ + _) === 99) assert(slices.forall(_.isInstanceOf[Range])) } @@ -108,7 +108,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { val data = 1 to 100 val slices = ParallelCollectionRDD.slice(data, 3) assert(slices.size === 3) - assert(slices.map(_.size).reduceLeft(_+_) === 100) + assert(slices.map(_.size).reduceLeft(_ + _) === 100) assert(slices.forall(_.isInstanceOf[Range])) } @@ -139,7 +139,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { assert(slices(i).isInstanceOf[Range]) val range = slices(i).asInstanceOf[Range] assert(range.start === i * (N / 40), "slice " + i + " start") - assert(range.end === (i+1) * (N / 40), "slice " + i + " end") + assert(range.end === (i + 1) * (N / 40), "slice " + i + " end") assert(range.step === 1, "slice " + i + " step") } } @@ -156,7 +156,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { val slices = ParallelCollectionRDD.slice(d, n) ("n slices" |: slices.size == n) && ("concat to d" |: Seq.concat(slices: _*).mkString(",") == d.mkString(",")) && - ("equal sizes" |: slices.map(_.size).forall(x => x==d.size/n || x==d.size/n+1)) + ("equal sizes" |: slices.map(_.size).forall(x => x == d.size / n || x == d.size /n + 1)) } check(prop) } @@ -174,7 +174,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { ("n slices" |: slices.size == n) && ("all ranges" |: slices.forall(_.isInstanceOf[Range])) && ("concat to d" |: Seq.concat(slices: _*).mkString(",") == d.mkString(",")) && - ("equal sizes" |: slices.map(_.size).forall(x => x==d.size/n || x==d.size/n+1)) + ("equal sizes" |: slices.map(_.size).forall(x => x == d.size / n || x == d.size / n + 1)) } check(prop) } @@ -192,7 +192,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { ("n slices" |: slices.size == n) && ("all ranges" |: slices.forall(_.isInstanceOf[Range])) && ("concat to d" |: Seq.concat(slices: _*).mkString(",") == d.mkString(",")) && - ("equal sizes" |: slices.map(_.size).forall(x => x==d.size/n || x==d.size/n+1)) + ("equal sizes" |: slices.map(_.size).forall(x => x == d.size / n || x == d.size / n + 1)) } check(prop) } @@ -201,7 +201,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { val data = 1L until 100L val slices = ParallelCollectionRDD.slice(data, 3) assert(slices.size === 3) - assert(slices.map(_.size).reduceLeft(_+_) === 99) + assert(slices.map(_.size).reduceLeft(_ + _) === 99) assert(slices.forall(_.isInstanceOf[NumericRange[_]])) } @@ -209,7 +209,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { val data = 1L to 100L val slices = ParallelCollectionRDD.slice(data, 3) assert(slices.size === 3) - assert(slices.map(_.size).reduceLeft(_+_) === 100) + assert(slices.map(_.size).reduceLeft(_ + _) === 100) assert(slices.forall(_.isInstanceOf[NumericRange[_]])) } @@ -217,7 +217,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { val data = 1.0 until 100.0 by 1.0 val slices = ParallelCollectionRDD.slice(data, 3) assert(slices.size === 3) - assert(slices.map(_.size).reduceLeft(_+_) === 99) + assert(slices.map(_.size).reduceLeft(_ + _) === 99) assert(slices.forall(_.isInstanceOf[NumericRange[_]])) } @@ -225,7 +225,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { val data = 1.0 to 100.0 by 1.0 val slices = ParallelCollectionRDD.slice(data, 3) assert(slices.size === 3) - assert(slices.map(_.size).reduceLeft(_+_) === 100) + assert(slices.map(_.size).reduceLeft(_ + _) === 100) assert(slices.forall(_.isInstanceOf[NumericRange[_]])) } diff --git a/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala index 8408d7e785c65..465068c6cbb16 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala @@ -23,7 +23,6 @@ import org.apache.spark.{Partition, SharedSparkContext, TaskContext} class PartitionPruningRDDSuite extends FunSuite with SharedSparkContext { - test("Pruned Partitions inherit locality prefs correctly") { val rdd = new RDD[Int](sc, Nil) { @@ -74,8 +73,6 @@ class PartitionPruningRDDSuite extends FunSuite with SharedSparkContext { } class TestPartition(i: Int, value: Int) extends Partition with Serializable { - def index = i - - def testValue = this.value - + def index: Int = i + def testValue: Int = this.value } diff --git a/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala index a0483886f8db3..0d1369c19c69e 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala @@ -35,7 +35,7 @@ class MockSampler extends RandomSampler[Long, Long] { Iterator(s) } - override def clone = new MockSampler + override def clone: MockSampler = new MockSampler } class PartitionwiseSampledRDDSuite extends FunSuite with SharedSparkContext { diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index bede1ffb3e2d0..df42faab64505 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -82,7 +82,7 @@ class RDDSuite extends FunSuite with SharedSparkContext { test("countApproxDistinct") { - def error(est: Long, size: Long) = math.abs(est - size) / size.toDouble + def error(est: Long, size: Long): Double = math.abs(est - size) / size.toDouble val size = 1000 val uniformDistro = for (i <- 1 to 5000) yield i % size @@ -100,7 +100,7 @@ class RDDSuite extends FunSuite with SharedSparkContext { } test("partitioner aware union") { - def makeRDDWithPartitioner(seq: Seq[Int]) = { + def makeRDDWithPartitioner(seq: Seq[Int]): RDD[Int] = { sc.makeRDD(seq, 1) .map(x => (x, null)) .partitionBy(new HashPartitioner(2)) @@ -159,8 +159,8 @@ class RDDSuite extends FunSuite with SharedSparkContext { test("treeAggregate") { val rdd = sc.makeRDD(-1000 until 1000, 10) - def seqOp = (c: Long, x: Int) => c + x - def combOp = (c1: Long, c2: Long) => c1 + c2 + def seqOp: (Long, Int) => Long = (c: Long, x: Int) => c + x + def combOp: (Long, Long) => Long = (c1: Long, c2: Long) => c1 + c2 for (depth <- 1 until 10) { val sum = rdd.treeAggregate(0L)(seqOp, combOp, depth) assert(sum === -1000L) @@ -204,7 +204,7 @@ class RDDSuite extends FunSuite with SharedSparkContext { assert(empty.collect().size === 0) val thrown = intercept[UnsupportedOperationException]{ - empty.reduce(_+_) + empty.reduce(_ + _) } assert(thrown.getMessage.contains("empty")) @@ -321,7 +321,7 @@ class RDDSuite extends FunSuite with SharedSparkContext { assert(list3.sorted === Array("a","b","c"), "Locality preferences are dropped") // RDD with locality preferences spread (non-randomly) over 6 machines, m0 through m5 - val data = sc.makeRDD((1 to 9).map(i => (i, (i to (i+2)).map{ j => "m" + (j%6)}))) + val data = sc.makeRDD((1 to 9).map(i => (i, (i to (i + 2)).map{ j => "m" + (j%6)}))) val coalesced1 = data.coalesce(3) assert(coalesced1.collect().toList.sorted === (1 to 9).toList, "Data got *lost* in coalescing") @@ -921,15 +921,17 @@ class RDDSuite extends FunSuite with SharedSparkContext { test("task serialization exception should not hang scheduler") { class BadSerializable extends Serializable { @throws(classOf[IOException]) - private def writeObject(out: ObjectOutputStream): Unit = throw new KryoException("Bad serialization") + private def writeObject(out: ObjectOutputStream): Unit = + throw new KryoException("Bad serialization") @throws(classOf[IOException]) private def readObject(in: ObjectInputStream): Unit = {} } - // Note that in the original bug, SPARK-4349, that this verifies, the job would only hang if there were - // more threads in the Spark Context than there were number of objects in this sequence. + // Note that in the original bug, SPARK-4349, that this verifies, the job would only hang if + // there were more threads in the Spark Context than there were number of objects in this + // sequence. intercept[Throwable] { - sc.parallelize(Seq(new BadSerializable, new BadSerializable)).collect + sc.parallelize(Seq(new BadSerializable, new BadSerializable)).collect() } // Check that the context has not crashed sc.parallelize(1 to 100).map(x => x*2).collect diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuiteUtils.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuiteUtils.scala index 4762fc17855ce..fe695d85e29dd 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuiteUtils.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuiteUtils.scala @@ -21,11 +21,11 @@ object RDDSuiteUtils { case class Person(first: String, last: String, age: Int) object AgeOrdering extends Ordering[Person] { - def compare(a:Person, b:Person) = a.age compare b.age + def compare(a:Person, b:Person): Int = a.age.compare(b.age) } object NameOrdering extends Ordering[Person] { - def compare(a:Person, b:Person) = + def compare(a:Person, b:Person): Int = implicitly[Ordering[Tuple2[String,String]]].compare((a.last, a.first), (b.last, b.first)) } } diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index e07bdb9637575..ada07ef11cd7a 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -70,7 +70,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { env.setupEndpoint("send-remotely", new RpcEndpoint { override val rpcEnv = env - override def receive = { + override def receive: PartialFunction[Any, Unit] = { case msg: String => message = msg } }) @@ -109,7 +109,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { val rpcEndpointRef = env.setupEndpoint("ask-locally", new RpcEndpoint { override val rpcEnv = env - override def receiveAndReply(context: RpcCallContext) = { + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case msg: String => { context.reply(msg) } @@ -123,7 +123,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { env.setupEndpoint("ask-remotely", new RpcEndpoint { override val rpcEnv = env - override def receiveAndReply(context: RpcCallContext) = { + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case msg: String => { context.reply(msg) } @@ -146,7 +146,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { env.setupEndpoint("ask-timeout", new RpcEndpoint { override val rpcEnv = env - override def receiveAndReply(context: RpcCallContext) = { + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case msg: String => { Thread.sleep(100) context.reply(msg) @@ -182,7 +182,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { calledMethods += "start" } - override def receive = { + override def receive: PartialFunction[Any, Unit] = { case msg: String => } @@ -206,7 +206,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { throw new RuntimeException("Oops!") } - override def receive = { + override def receive: PartialFunction[Any, Unit] = { case m => } @@ -225,7 +225,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { val endpointRef = env.setupEndpoint("onError-onStop", new RpcEndpoint { override val rpcEnv = env - override def receive = { + override def receive: PartialFunction[Any, Unit] = { case m => } @@ -250,8 +250,8 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { val endpointRef = env.setupEndpoint("onError-receive", new RpcEndpoint { override val rpcEnv = env - override def receive = { - case m => throw new RuntimeException("Oops!") + override def receive: PartialFunction[Any, Unit] = { + case m => throw new RuntimeException("Oops!") } override def onError(cause: Throwable): Unit = { @@ -277,7 +277,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { callSelfSuccessfully = true } - override def receive = { + override def receive: PartialFunction[Any, Unit] = { case m => } }) @@ -294,7 +294,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { val endpointRef = env.setupEndpoint("self-receive", new RpcEndpoint { override val rpcEnv = env - override def receive = { + override def receive: PartialFunction[Any, Unit] = { case m => { self callSelfSuccessfully = true @@ -311,30 +311,28 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { } test("self: call in onStop") { - @volatile var e: Throwable = null + @volatile var selfOption: Option[RpcEndpointRef] = null val endpointRef = env.setupEndpoint("self-onStop", new RpcEndpoint { override val rpcEnv = env - override def receive = { + override def receive: PartialFunction[Any, Unit] = { case m => } override def onStop(): Unit = { - self + selfOption = Option(self) } override def onError(cause: Throwable): Unit = { - e = cause } }) env.stop(endpointRef) eventually(timeout(5 seconds), interval(10 millis)) { - // Calling `self` in `onStop` is invalid - assert(e != null) - assert(e.getMessage.contains("Cannot find RpcEndpointRef")) + // Calling `self` in `onStop` will return null, so selfOption will be None + assert(selfOption == None) } } @@ -342,10 +340,10 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { // If a RpcEnv implementation breaks the `receive` contract, hope this test can expose it for(i <- 0 until 100) { @volatile var result = 0 - val endpointRef = env.setupThreadSafeEndpoint(s"receive-in-sequence-$i", new RpcEndpoint { + val endpointRef = env.setupEndpoint(s"receive-in-sequence-$i", new ThreadSafeRpcEndpoint { override val rpcEnv = env - override def receive = { + override def receive: PartialFunction[Any, Unit] = { case m => result += 1 } @@ -374,7 +372,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { val endpointRef = env.setupEndpoint("stop-reentrant", new RpcEndpoint { override val rpcEnv = env - override def receive = { + override def receive: PartialFunction[Any, Unit] = { case m => } @@ -396,7 +394,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { val endpointRef = env.setupEndpoint("sendWithReply", new RpcEndpoint { override val rpcEnv = env - override def receiveAndReply(context: RpcCallContext) = { + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case m => context.reply("ack") } }) @@ -412,7 +410,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { env.setupEndpoint("sendWithReply-remotely", new RpcEndpoint { override val rpcEnv = env - override def receiveAndReply(context: RpcCallContext) = { + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case m => context.reply("ack") } }) @@ -434,7 +432,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { val endpointRef = env.setupEndpoint("sendWithReply-error", new RpcEndpoint { override val rpcEnv = env - override def receiveAndReply(context: RpcCallContext) = { + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case m => context.sendFailure(new SparkException("Oops")) } }) @@ -452,7 +450,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { env.setupEndpoint("sendWithReply-remotely-error", new RpcEndpoint { override val rpcEnv = env - override def receiveAndReply(context: RpcCallContext) = { + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case msg: String => context.sendFailure(new SparkException("Oops")) } }) @@ -475,10 +473,10 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { test("network events") { val events = new mutable.ArrayBuffer[(Any, Any)] with mutable.SynchronizedBuffer[(Any, Any)] - env.setupThreadSafeEndpoint("network-events", new RpcEndpoint { + env.setupEndpoint("network-events", new ThreadSafeRpcEndpoint { override val rpcEnv = env - override def receive = { + override def receive: PartialFunction[Any, Unit] = { case "hello" => case m => events += "receive" -> m } @@ -516,10 +514,35 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { ("onDisconnected", remoteAddress))) } } -} -case object Start + test("sendWithReply: unserializable error") { + env.setupEndpoint("sendWithReply-unserializable-error", new RpcEndpoint { + override val rpcEnv = env + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case msg: String => context.sendFailure(new UnserializableException) + } + }) + + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345) + // Use anotherEnv to find out the RpcEndpointRef + val rpcEndpointRef = anotherEnv.setupEndpointRef( + "local", env.address, "sendWithReply-unserializable-error") + try { + val f = rpcEndpointRef.sendWithReply[String]("hello") + intercept[TimeoutException] { + Await.result(f, 1 seconds) + } + } finally { + anotherEnv.shutdown() + anotherEnv.awaitTermination() + } + } + +} -case class Ping(id: Int) +class UnserializableClass -case class Pong(id: Int) +class UnserializableException extends Exception { + private val unserializableField = new UnserializableClass +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 63360a0f189a3..3c52a8c4460c6 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -57,20 +57,18 @@ class MyRDD( locations: Seq[Seq[String]] = Nil) extends RDD[(Int, Int)](sc, dependencies) with Serializable { override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = throw new RuntimeException("should not be reached") - override def getPartitions = (0 until numPartitions).map(i => new Partition { - override def index = i + override def getPartitions: Array[Partition] = (0 until numPartitions).map(i => new Partition { + override def index: Int = i }).toArray override def getPreferredLocations(split: Partition): Seq[String] = - if (locations.isDefinedAt(split.index)) - locations(split.index) - else - Nil + if (locations.isDefinedAt(split.index)) locations(split.index) else Nil override def toString: String = "DAGSchedulerSuiteRDD " + id } class DAGSchedulerSuiteDummyException extends Exception -class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSparkContext with Timeouts { +class DAGSchedulerSuite + extends FunSuiteLike with BeforeAndAfter with LocalSparkContext with Timeouts { val conf = new SparkConf /** Set of TaskSets the DAGScheduler has requested executed. */ @@ -209,7 +207,8 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar assert(taskSet.tasks.size >= results.size) for ((result, i) <- results.zipWithIndex) { if (i < taskSet.tasks.size) { - runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2, null, createFakeTaskInfo(), null)) + runEvent(CompletionEvent( + taskSet.tasks(i), result._1, result._2, null, createFakeTaskInfo(), null)) } } } @@ -269,21 +268,23 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar submit(new MyRDD(sc, 1, Nil), Array(0)) complete(taskSets(0), List((Success, 42))) assert(results === Map(0 -> 42)) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("local job") { val rdd = new PairOfIntsRDD(sc, Nil) { override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = Array(42 -> 0).iterator - override def getPartitions = Array( new Partition { override def index = 0 } ) - override def getPreferredLocations(split: Partition) = Nil - override def toString = "DAGSchedulerSuite Local RDD" + override def getPartitions: Array[Partition] = + Array( new Partition { override def index: Int = 0 } ) + override def getPreferredLocations(split: Partition): List[String] = Nil + override def toString: String = "DAGSchedulerSuite Local RDD" } val jobId = scheduler.nextJobId.getAndIncrement() - runEvent(JobSubmitted(jobId, rdd, jobComputeFunc, Array(0), true, CallSite("", ""), jobListener)) + runEvent( + JobSubmitted(jobId, rdd, jobComputeFunc, Array(0), true, CallSite("", ""), jobListener)) assert(results === Map(0 -> 42)) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("local job oom") { @@ -295,9 +296,10 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar override def toString = "DAGSchedulerSuite Local RDD" } val jobId = scheduler.nextJobId.getAndIncrement() - runEvent(JobSubmitted(jobId, rdd, jobComputeFunc, Array(0), true, CallSite("", ""), jobListener)) + runEvent( + JobSubmitted(jobId, rdd, jobComputeFunc, Array(0), true, CallSite("", ""), jobListener)) assert(results.size == 0) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("run trivial job w/ dependency") { @@ -306,7 +308,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar submit(finalRdd, Array(0)) complete(taskSets(0), Seq((Success, 42))) assert(results === Map(0 -> 42)) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("cache location preferences w/ dependency") { @@ -319,7 +321,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar assertLocations(taskSet, Seq(Seq("hostA", "hostB"))) complete(taskSet, Seq((Success, 42))) assert(results === Map(0 -> 42)) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("regression test for getCacheLocs") { @@ -335,7 +337,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar } test("avoid exponential blowup when getting preferred locs list") { - // Build up a complex dependency graph with repeated zip operations, without preferred locations. + // Build up a complex dependency graph with repeated zip operations, without preferred locations var rdd: RDD[_] = new MyRDD(sc, 1, Nil) (1 to 30).foreach(_ => rdd = rdd.zip(rdd)) // getPreferredLocs runs quickly, indicating that exponential graph traversal is avoided. @@ -357,7 +359,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) assert(sparkListener.failedStages.contains(0)) assert(sparkListener.failedStages.size === 1) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("trivial job failure") { @@ -367,7 +369,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) assert(sparkListener.failedStages.contains(0)) assert(sparkListener.failedStages.size === 1) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("trivial job cancellation") { @@ -378,7 +380,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) assert(sparkListener.failedStages.contains(0)) assert(sparkListener.failedStages.size === 1) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("job cancellation no-kill backend") { @@ -387,18 +389,20 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar val noKillTaskScheduler = new TaskScheduler() { override def rootPool: Pool = null override def schedulingMode: SchedulingMode = SchedulingMode.NONE - override def start() = {} - override def stop() = {} - override def submitTasks(taskSet: TaskSet) = { + override def start(): Unit = {} + override def stop(): Unit = {} + override def submitTasks(taskSet: TaskSet): Unit = { taskSets += taskSet } override def cancelTasks(stageId: Int, interruptThread: Boolean) { throw new UnsupportedOperationException } - override def setDAGScheduler(dagScheduler: DAGScheduler) = {} - override def defaultParallelism() = 2 - override def executorHeartbeatReceived(execId: String, taskMetrics: Array[(Long, TaskMetrics)], - blockManagerId: BlockManagerId): Boolean = true + override def setDAGScheduler(dagScheduler: DAGScheduler): Unit = {} + override def defaultParallelism(): Int = 2 + override def executorHeartbeatReceived( + execId: String, + taskMetrics: Array[(Long, TaskMetrics)], + blockManagerId: BlockManagerId): Boolean = true override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {} } val noKillScheduler = new DAGScheduler( @@ -422,7 +426,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar // When the task set completes normally, state should be correctly updated. complete(taskSets(0), Seq((Success, 42))) assert(results === Map(0 -> 42)) - assertDataStructuresEmpty + assertDataStructuresEmpty() assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) assert(sparkListener.failedStages.isEmpty) @@ -442,7 +446,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) complete(taskSets(1), Seq((Success, 42))) assert(results === Map(0 -> 42)) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("run trivial shuffle with fetch failure") { @@ -465,10 +469,11 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar // have the 2nd attempt pass complete(taskSets(2), Seq((Success, makeMapStatus("hostA", 1)))) // we can see both result blocks now - assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) === Array("hostA", "hostB")) + assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) === + Array("hostA", "hostB")) complete(taskSets(3), Seq((Success, 43))) assert(results === Map(0 -> 42, 1 -> 43)) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("trivial shuffle with multiple fetch failures") { @@ -521,19 +526,23 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar assert(newEpoch > oldEpoch) val taskSet = taskSets(0) // should be ignored for being too old - runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), null, createFakeTaskInfo(), null)) + runEvent(CompletionEvent( + taskSet.tasks(0), Success, makeMapStatus("hostA", 1), null, createFakeTaskInfo(), null)) // should work because it's a non-failed host - runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostB", 1), null, createFakeTaskInfo(), null)) + runEvent(CompletionEvent( + taskSet.tasks(0), Success, makeMapStatus("hostB", 1), null, createFakeTaskInfo(), null)) // should be ignored for being too old - runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), null, createFakeTaskInfo(), null)) + runEvent(CompletionEvent( + taskSet.tasks(0), Success, makeMapStatus("hostA", 1), null, createFakeTaskInfo(), null)) // should work because it's a new epoch taskSet.tasks(1).epoch = newEpoch - runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), null, createFakeTaskInfo(), null)) + runEvent(CompletionEvent( + taskSet.tasks(1), Success, makeMapStatus("hostA", 1), null, createFakeTaskInfo(), null)) assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA"))) complete(taskSets(1), Seq((Success, 42), (Success, 43))) assert(results === Map(0 -> 42, 1 -> 43)) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("run shuffle with map stage failure") { @@ -552,7 +561,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) assert(sparkListener.failedStages.toSet === Set(0)) - assertDataStructuresEmpty + assertDataStructuresEmpty() } /** @@ -586,7 +595,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar class FailureRecordingJobListener() extends JobListener { var failureMessage: String = _ override def taskSucceeded(index: Int, result: Any) {} - override def jobFailed(exception: Exception) = { failureMessage = exception.getMessage } + override def jobFailed(exception: Exception): Unit = { failureMessage = exception.getMessage } } val listener1 = new FailureRecordingJobListener() val listener2 = new FailureRecordingJobListener() @@ -606,7 +615,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar assert(listener1.failureMessage === s"Job aborted due to stage failure: $stageFailureMessage") assert(listener2.failureMessage === s"Job aborted due to stage failure: $stageFailureMessage") - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("run trivial shuffle with out-of-band failure and retry") { @@ -629,7 +638,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar Array(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) complete(taskSets(2), Seq((Success, 42))) assert(results === Map(0 -> 42)) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("recursive shuffle failures") { @@ -658,7 +667,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar complete(taskSets(4), Seq((Success, makeMapStatus("hostA", 1)))) complete(taskSets(5), Seq((Success, 42))) assert(results === Map(0 -> 42)) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("cached post-shuffle") { @@ -690,7 +699,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar complete(taskSets(3), Seq((Success, makeMapStatus("hostD", 1)))) complete(taskSets(4), Seq((Success, 42))) assert(results === Map(0 -> 42)) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("misbehaved accumulator should not crash DAGScheduler and SparkContext") { @@ -742,7 +751,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar } test("accumulator not calculated for resubmitted result stage") { - //just for register + // just for register val accum = new Accumulator[Int](0, AccumulatorParam.IntAccumulatorParam) val finalRdd = new MyRDD(sc, 1, Nil) submit(finalRdd, Array(0)) @@ -754,7 +763,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar assert(accVal === 1) - assertDataStructuresEmpty + assertDataStructuresEmpty() } /** @@ -774,7 +783,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar private def makeBlockManagerId(host: String): BlockManagerId = BlockManagerId("exec-" + host, host, 12345) - private def assertDataStructuresEmpty = { + private def assertDataStructuresEmpty(): Unit = { assert(scheduler.activeJobs.isEmpty) assert(scheduler.failedStages.isEmpty) assert(scheduler.jobIdToActiveJob.isEmpty) @@ -783,6 +792,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar assert(scheduler.runningStages.isEmpty) assert(scheduler.shuffleToMapStage.isEmpty) assert(scheduler.waitingStages.isEmpty) + assert(scheduler.outputCommitCoordinator.isEmpty) } // Nothing in this test should break if the task info's fields are null, but diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index 448258a754153..6d25edb7d20dc 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -61,7 +61,7 @@ class EventLoggingListenerSuite extends FunSuite with LocalSparkContext with Bef test("Verify log file exist") { // Verify logging directory exists val conf = getLoggingConf(testDirPath) - val eventLogger = new EventLoggingListener("test", testDirPath.toUri().toString(), conf) + val eventLogger = new EventLoggingListener("test", testDirPath.toUri(), conf) eventLogger.start() val logPath = new Path(eventLogger.logPath + EventLoggingListener.IN_PROGRESS) @@ -95,7 +95,7 @@ class EventLoggingListenerSuite extends FunSuite with LocalSparkContext with Bef } test("Log overwriting") { - val logUri = EventLoggingListener.getLogPath(testDir.getAbsolutePath, "test") + val logUri = EventLoggingListener.getLogPath(testDir.toURI, "test") val logPath = new URI(logUri).getPath // Create file before writing the event log new FileOutputStream(new File(logPath)).close() @@ -107,16 +107,19 @@ class EventLoggingListenerSuite extends FunSuite with LocalSparkContext with Bef test("Event log name") { // without compression - assert(s"file:/base-dir/app1" === EventLoggingListener.getLogPath("/base-dir", "app1")) + assert(s"file:/base-dir/app1" === EventLoggingListener.getLogPath( + Utils.resolveURI("/base-dir"), "app1")) // with compression assert(s"file:/base-dir/app1.lzf" === - EventLoggingListener.getLogPath("/base-dir", "app1", Some("lzf"))) + EventLoggingListener.getLogPath(Utils.resolveURI("/base-dir"), "app1", Some("lzf"))) // illegal characters in app ID assert(s"file:/base-dir/a-fine-mind_dollar_bills__1" === - EventLoggingListener.getLogPath("/base-dir", "a fine:mind$dollar{bills}.1")) + EventLoggingListener.getLogPath(Utils.resolveURI("/base-dir"), + "a fine:mind$dollar{bills}.1")) // illegal characters in app ID with compression assert(s"file:/base-dir/a-fine-mind_dollar_bills__1.lz4" === - EventLoggingListener.getLogPath("/base-dir", "a fine:mind$dollar{bills}.1", Some("lz4"))) + EventLoggingListener.getLogPath(Utils.resolveURI("/base-dir"), + "a fine:mind$dollar{bills}.1", Some("lz4"))) } /* ----------------- * @@ -137,7 +140,7 @@ class EventLoggingListenerSuite extends FunSuite with LocalSparkContext with Bef val conf = getLoggingConf(testDirPath, compressionCodec) extraConf.foreach { case (k, v) => conf.set(k, v) } val logName = compressionCodec.map("test-" + _).getOrElse("test") - val eventLogger = new EventLoggingListener(logName, testDirPath.toUri().toString(), conf) + val eventLogger = new EventLoggingListener(logName, testDirPath.toUri(), conf) val listenerBus = new LiveListenerBus val applicationStart = SparkListenerApplicationStart("Greatest App (N)ever", None, 125L, "Mickey") @@ -173,12 +176,15 @@ class EventLoggingListenerSuite extends FunSuite with LocalSparkContext with Bef * This runs a simple Spark job and asserts that the expected events are logged when expected. */ private def testApplicationEventLogging(compressionCodec: Option[String] = None) { + // Set defaultFS to something that would cause an exception, to make sure we don't run + // into SPARK-6688. val conf = getLoggingConf(testDirPath, compressionCodec) + .set("spark.hadoop.fs.defaultFS", "unsupported://example.com") val sc = new SparkContext("local-cluster[2,2,512]", "test", conf) assert(sc.eventLogger.isDefined) val eventLogger = sc.eventLogger.get val eventLogPath = eventLogger.logPath - val expectedLogDir = testDir.toURI().toString() + val expectedLogDir = testDir.toURI() assert(eventLogPath === EventLoggingListener.getLogPath( expectedLogDir, sc.applicationId, compressionCodec.map(CompressionCodec.getShortName))) @@ -262,7 +268,7 @@ class EventLoggingListenerSuite extends FunSuite with LocalSparkContext with Bef object EventLoggingListenerSuite { /** Get a SparkConf with event logging enabled. */ - def getLoggingConf(logDir: Path, compressionCodec: Option[String] = None) = { + def getLoggingConf(logDir: Path, compressionCodec: Option[String] = None): SparkConf = { val conf = new SparkConf conf.set("spark.eventLog.enabled", "true") conf.set("spark.eventLog.testing", "true") @@ -274,5 +280,5 @@ object EventLoggingListenerSuite { conf } - def getUniqueApplicationId = "test-" + System.currentTimeMillis + def getUniqueApplicationId: String = "test-" + System.currentTimeMillis } diff --git a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala index 6b75c98839e03..9b92f8de56759 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala @@ -24,7 +24,9 @@ import org.apache.spark.TaskContext /** * A Task implementation that fails to serialize. */ -private[spark] class NotSerializableFakeTask(myId: Int, stageId: Int) extends Task[Array[Byte]](stageId, 0) { +private[spark] class NotSerializableFakeTask(myId: Int, stageId: Int) + extends Task[Array[Byte]](stageId, 0) { + override def runTask(context: TaskContext): Array[Byte] = Array.empty[Byte] override def preferredLocations: Seq[TaskLocation] = Seq[TaskLocation]() diff --git a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala index 601694f57aad0..6de6d2fec622a 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.scheduler import java.io.{File, PrintWriter} +import java.net.URI import org.json4s.jackson.JsonMethods._ import org.scalatest.{BeforeAndAfter, FunSuite} @@ -145,7 +146,7 @@ class ReplayListenerSuite extends FunSuite with BeforeAndAfter { * log the events. */ private class EventMonster(conf: SparkConf) - extends EventLoggingListener("test", "testdir", conf) { + extends EventLoggingListener("test", new URI("testdir"), conf) { override def start() { } diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 627c9a4ddfffc..825c616c0c3e0 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -85,7 +85,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers val stopperReturned = new Semaphore(0) class BlockingListener extends SparkListener { - override def onJobEnd(jobEnd: SparkListenerJobEnd) = { + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { listenerStarted.release() listenerWait.acquire() drained = true @@ -206,8 +206,9 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers sc.addSparkListener(new StatsReportListener) // just to make sure some of the tasks take a noticeable amount of time val w = { i: Int => - if (i == 0) + if (i == 0) { Thread.sleep(100) + } i } @@ -247,12 +248,12 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers */ taskInfoMetrics.foreach { case (taskInfo, taskMetrics) => - taskMetrics.resultSize should be > (0l) + taskMetrics.resultSize should be > (0L) if (stageInfo.rddInfos.exists(info => info.name == d2.name || info.name == d3.name)) { taskMetrics.inputMetrics should not be ('defined) taskMetrics.outputMetrics should not be ('defined) taskMetrics.shuffleWriteMetrics should be ('defined) - taskMetrics.shuffleWriteMetrics.get.shuffleBytesWritten should be > (0l) + taskMetrics.shuffleWriteMetrics.get.shuffleBytesWritten should be > (0L) } if (stageInfo.rddInfos.exists(_.name == d4.name)) { taskMetrics.shuffleReadMetrics should be ('defined) @@ -260,7 +261,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers sm.totalBlocksFetched should be (128) sm.localBlocksFetched should be (128) sm.remoteBlocksFetched should be (0) - sm.remoteBytesRead should be (0l) + sm.remoteBytesRead should be (0L) } } } @@ -406,12 +407,12 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers val startedGettingResultTasks = new mutable.HashSet[Int]() val endedTasks = new mutable.HashSet[Int]() - override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized { + override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = synchronized { startedTasks += taskStart.taskInfo.index notify() } - override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { endedTasks += taskEnd.taskInfo.index notify() } @@ -425,7 +426,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers * A simple listener that throws an exception on job end. */ private class BadListener extends SparkListener { - override def onJobEnd(jobEnd: SparkListenerJobEnd) = { throw new Exception } + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { throw new Exception } } } @@ -438,10 +439,10 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers */ private class BasicJobCounter extends SparkListener { var count = 0 - override def onJobEnd(job: SparkListenerJobEnd) = count += 1 + override def onJobEnd(job: SparkListenerJobEnd): Unit = count += 1 } private class ListenerThatAcceptsSparkConf(conf: SparkConf) extends SparkListener { var count = 0 - override def onJobEnd(job: SparkListenerJobEnd) = count += 1 + override def onJobEnd(job: SparkListenerJobEnd): Unit = count += 1 } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index add13f5b21765..ffa4381969b68 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.scheduler -import java.util.Properties - import org.scalatest.FunSuite import org.apache.spark._ @@ -27,7 +25,7 @@ class FakeSchedulerBackend extends SchedulerBackend { def start() {} def stop() {} def reviveOffers() {} - def defaultParallelism() = 1 + def defaultParallelism(): Int = 1 } class TaskSchedulerImplSuite extends FunSuite with LocalSparkContext with Logging { @@ -115,7 +113,8 @@ class TaskSchedulerImplSuite extends FunSuite with LocalSparkContext with Loggin } val numFreeCores = 1 taskScheduler.setDAGScheduler(dagScheduler) - var taskSet = new TaskSet(Array(new NotSerializableFakeTask(1, 0), new NotSerializableFakeTask(0, 1)), 0, 0, 0, null) + val taskSet = new TaskSet( + Array(new NotSerializableFakeTask(1, 0), new NotSerializableFakeTask(0, 1)), 0, 0, 0, null) val multiCoreWorkerOffers = Seq(new WorkerOffer("executor0", "host0", taskCpus), new WorkerOffer("executor1", "host1", numFreeCores)) taskScheduler.submitTasks(taskSet) @@ -123,7 +122,8 @@ class TaskSchedulerImplSuite extends FunSuite with LocalSparkContext with Loggin assert(0 === taskDescriptions.length) // Now check that we can still submit tasks - // Even if one of the tasks has not-serializable tasks, the other task set should still be processed without error + // Even if one of the tasks has not-serializable tasks, the other task set should + // still be processed without error taskScheduler.submitTasks(taskSet) taskScheduler.submitTasks(FakeTask.createTaskSet(1)) taskDescriptions = taskScheduler.resourceOffers(multiCoreWorkerOffers).flatten diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 12330d8f63c40..716d12c0762cf 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -67,7 +67,7 @@ object FakeRackUtil { hostToRack(host) = rack } - def getRackForHost(host: String) = { + def getRackForHost(host: String): Option[String] = { hostToRack.get(host) } } @@ -327,8 +327,8 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { // First offer host1: first task should be chosen assert(manager.resourceOffer("exec1", "host1", ANY).get.index === 0) - // After this, nothing should get chosen, because we have separated tasks with unavailable preference - // from the noPrefPendingTasks + // After this, nothing should get chosen, because we have separated tasks with unavailable + // preference from the noPrefPendingTasks assert(manager.resourceOffer("exec1", "host1", ANY) === None) // Now mark host2 as dead @@ -499,7 +499,8 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { sched.addExecutor("execC", "host2") manager.executorAdded() // Valid locality should contain PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL and ANY - assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, NO_PREF, RACK_LOCAL, ANY))) + assert(manager.myLocalityLevels.sameElements( + Array(PROCESS_LOCAL, NODE_LOCAL, NO_PREF, RACK_LOCAL, ANY))) // test if the valid locality is recomputed when the executor is lost sched.removeExecutor("execC") manager.executorLost("execC", "host2") @@ -569,7 +570,8 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { sc = new SparkContext("local", "test") val sched = new FakeTaskScheduler(sc, ("exec1", "host1")) - val taskSet = new TaskSet(Array(new NotSerializableFakeTask(1, 0), new NotSerializableFakeTask(0, 1)), 0, 0, 0, null) + val taskSet = new TaskSet( + Array(new NotSerializableFakeTask(1, 0), new NotSerializableFakeTask(0, 1)), 0, 0, 0, null) val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES) intercept[TaskNotSerializableException] { @@ -582,7 +584,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { val conf = new SparkConf().set("spark.driver.maxResultSize", "2m") sc = new SparkContext("local", "test", conf) - def genBytes(size: Int) = { (x: Int) => + def genBytes(size: Int): (Int) => Array[Byte] = { (x: Int) => val bytes = Array.ofDim[Byte](size) scala.util.Random.nextBytes(bytes) bytes @@ -605,7 +607,8 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { test("speculative and noPref task should be scheduled after node-local") { sc = new SparkContext("local", "test") - val sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execB", "host2"), ("execC", "host3")) + val sched = new FakeTaskScheduler( + sc, ("execA", "host1"), ("execB", "host2"), ("execC", "host3")) val taskSet = FakeTask.createTaskSet(4, Seq(TaskLocation("host1", "execA")), Seq(TaskLocation("host2"), TaskLocation("host1")), @@ -629,9 +632,11 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { assert(manager.resourceOffer("execB", "host2", ANY).get.index === 3) } - test("node-local tasks should be scheduled right away when there are only node-local and no-preference tasks") { + test("node-local tasks should be scheduled right away " + + "when there are only node-local and no-preference tasks") { sc = new SparkContext("local", "test") - val sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execB", "host2"), ("execC", "host3")) + val sched = new FakeTaskScheduler( + sc, ("execA", "host1"), ("execB", "host2"), ("execC", "host3")) val taskSet = FakeTask.createTaskSet(4, Seq(TaskLocation("host1")), Seq(TaskLocation("host2")), @@ -650,7 +655,8 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { assert(manager.resourceOffer("execA", "host3", NO_PREF).get.index === 2) } - test("SPARK-4939: node-local tasks should be scheduled right after process-local tasks finished") { + test("SPARK-4939: node-local tasks should be scheduled right after process-local tasks finished") + { sc = new SparkContext("local", "test") val sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execB", "host2")) val taskSet = FakeTask.createTaskSet(4, diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala index f1a4380d349b3..a311512e82c5e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala @@ -63,16 +63,18 @@ class MesosSchedulerBackendSuite extends FunSuite with LocalSparkContext with Mo // uri is null. val executorInfo = mesosSchedulerBackend.createExecutorInfo("test-id") - assert(executorInfo.getCommand.getValue === s" /mesos-home/bin/spark-class ${classOf[MesosExecutorBackend].getName}") + assert(executorInfo.getCommand.getValue === + s" /mesos-home/bin/spark-class ${classOf[MesosExecutorBackend].getName}") // uri exists. conf.set("spark.executor.uri", "hdfs:///test-app-1.0.0.tgz") val executorInfo1 = mesosSchedulerBackend.createExecutorInfo("test-id") - assert(executorInfo1.getCommand.getValue === s"cd test-app-1*; ./bin/spark-class ${classOf[MesosExecutorBackend].getName}") + assert(executorInfo1.getCommand.getValue === + s"cd test-app-1*; ./bin/spark-class ${classOf[MesosExecutorBackend].getName}") } test("mesos resource offers result in launching tasks") { - def createOffer(id: Int, mem: Int, cpu: Int) = { + def createOffer(id: Int, mem: Int, cpu: Int): Offer = { val builder = Offer.newBuilder() builder.addResourcesBuilder() .setName("mem") @@ -82,8 +84,10 @@ class MesosSchedulerBackendSuite extends FunSuite with LocalSparkContext with Mo .setName("cpus") .setType(Value.Type.SCALAR) .setScalar(Scalar.newBuilder().setValue(cpu)) - builder.setId(OfferID.newBuilder().setValue(s"o${id.toString}").build()).setFrameworkId(FrameworkID.newBuilder().setValue("f1")) - .setSlaveId(SlaveID.newBuilder().setValue(s"s${id.toString}")).setHostname(s"host${id.toString}").build() + builder.setId(OfferID.newBuilder().setValue(s"o${id.toString}").build()) + .setFrameworkId(FrameworkID.newBuilder().setValue("f1")) + .setSlaveId(SlaveID.newBuilder().setValue(s"s${id.toString}")) + .setHostname(s"host${id.toString}").build() } val driver = mock[SchedulerDriver] diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index 6198df84fab3d..b070a54aa989b 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -106,7 +106,9 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { check(mutable.HashMap(1 -> "one", 2 -> "two")) check(mutable.HashMap("one" -> 1, "two" -> 2)) check(List(Some(mutable.HashMap(1->1, 2->2)), None, Some(mutable.HashMap(3->4)))) - check(List(mutable.HashMap("one" -> 1, "two" -> 2),mutable.HashMap(1->"one",2->"two",3->"three"))) + check(List( + mutable.HashMap("one" -> 1, "two" -> 2), + mutable.HashMap(1->"one",2->"two",3->"three"))) } test("ranges") { @@ -169,7 +171,10 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { test("kryo with collect") { val control = 1 :: 2 :: Nil - val result = sc.parallelize(control, 2).map(new ClassWithoutNoArgConstructor(_)).collect().map(_.x) + val result = sc.parallelize(control, 2) + .map(new ClassWithoutNoArgConstructor(_)) + .collect() + .map(_.x) assert(control === result.toSeq) } @@ -237,7 +242,7 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { // Set a special, broken ClassLoader and make sure we get an exception on deserialization ser.setDefaultClassLoader(new ClassLoader() { - override def loadClass(name: String) = throw new UnsupportedOperationException + override def loadClass(name: String): Class[_] = throw new UnsupportedOperationException }) intercept[UnsupportedOperationException] { ser.newInstance().deserialize[ClassLoaderTestingObject](bytes) @@ -287,14 +292,14 @@ object KryoTest { class ClassWithNoArgConstructor { var x: Int = 0 - override def equals(other: Any) = other match { + override def equals(other: Any): Boolean = other match { case c: ClassWithNoArgConstructor => x == c.x case _ => false } } class ClassWithoutNoArgConstructor(val x: Int) { - override def equals(other: Any) = other match { + override def equals(other: Any): Boolean = other match { case c: ClassWithoutNoArgConstructor => x == c.x case _ => false } diff --git a/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala b/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala index d037e2c19a64d..433fd6bb4a11d 100644 --- a/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala @@ -24,14 +24,16 @@ import org.apache.spark.rdd.RDD /* A trivial (but unserializable) container for trivial functions */ class UnserializableClass { - def op[T](x: T) = x.toString + def op[T](x: T): String = x.toString - def pred[T](x: T) = x.toString.length % 2 == 0 + def pred[T](x: T): Boolean = x.toString.length % 2 == 0 } class ProactiveClosureSerializationSuite extends FunSuite with SharedSparkContext { - def fixture = (sc.parallelize(0 until 1000).map(_.toString), new UnserializableClass) + def fixture: (RDD[String], UnserializableClass) = { + (sc.parallelize(0 until 1000).map(_.toString), new UnserializableClass) + } test("throws expected serialization exceptions on actions") { val (data, uc) = fixture diff --git a/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala b/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala index 0ade1bab18d7e..963264cef3a71 100644 --- a/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala +++ b/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala @@ -27,7 +27,7 @@ import scala.reflect.ClassTag * A serializer implementation that always return a single element in a deserialization stream. */ class TestSerializer extends Serializer { - override def newInstance() = new TestSerializerInstance + override def newInstance(): TestSerializerInstance = new TestSerializerInstance } @@ -36,7 +36,8 @@ class TestSerializerInstance extends SerializerInstance { override def serializeStream(s: OutputStream): SerializationStream = ??? - override def deserializeStream(s: InputStream) = new TestDeserializationStream + override def deserializeStream(s: InputStream): TestDeserializationStream = + new TestDeserializationStream override def deserialize[T: ClassTag](bytes: ByteBuffer): T = ??? diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala index 6790388f96603..7d76435cd75e7 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala @@ -54,7 +54,7 @@ class HashShuffleManagerSuite extends FunSuite with LocalSparkContext { sc = new SparkContext("local", "test", conf) val shuffleBlockManager = - SparkEnv.get.shuffleManager.shuffleBlockManager.asInstanceOf[FileShuffleBlockManager] + SparkEnv.get.shuffleManager.shuffleBlockResolver.asInstanceOf[FileShuffleBlockManager] val shuffle1 = shuffleBlockManager.forMapTask(1, 1, 1, new JavaSerializer(conf), new ShuffleWriteMetrics) @@ -85,8 +85,8 @@ class HashShuffleManagerSuite extends FunSuite with LocalSparkContext { // Now comes the test : // Write to shuffle 3; and close it, but before registering it, check if the file lengths for // previous task (forof shuffle1) is the same as 'segments'. Earlier, we were inferring length - // of block based on remaining data in file : which could mess things up when there is concurrent read - // and writes happening to the same shuffle group. + // of block based on remaining data in file : which could mess things up when there is + // concurrent read and writes happening to the same shuffle group. val shuffle3 = shuffleBlockManager.forMapTask(1, 3, 1, new JavaSerializer(testConf), new ShuffleWriteMetrics) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index c2903c8597997..b4de90b65d545 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -22,11 +22,11 @@ import scala.concurrent.duration._ import scala.language.implicitConversions import scala.language.postfixOps -import akka.actor.{ActorSystem, Props} import org.mockito.Mockito.{mock, when} -import org.scalatest.{BeforeAndAfter, FunSuite, Matchers, PrivateMethodTester} +import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} import org.scalatest.concurrent.Eventually._ +import org.apache.spark.rpc.RpcEnv import org.apache.spark.{MapOutputTrackerMaster, SparkConf, SparkContext, SecurityManager} import org.apache.spark.network.BlockTransferService import org.apache.spark.network.nio.NioBlockTransferService @@ -34,13 +34,12 @@ import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.KryoSerializer import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.storage.StorageLevel._ -import org.apache.spark.util.{AkkaUtils, SizeEstimator} /** Testsuite that tests block replication in BlockManager */ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAndAfter { private val conf = new SparkConf(false) - var actorSystem: ActorSystem = null + var rpcEnv: RpcEnv = null var master: BlockManagerMaster = null val securityMgr = new SecurityManager(conf) val mapOutputTracker = new MapOutputTrackerMaster(conf) @@ -61,7 +60,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd maxMem: Long, name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { val transfer = new NioBlockTransferService(conf, securityMgr) - val store = new BlockManager(name, actorSystem, master, serializer, maxMem, conf, + val store = new BlockManager(name, rpcEnv, master, serializer, maxMem, conf, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) store.initialize("app-id") allStores += store @@ -69,12 +68,10 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd } before { - val (actorSystem, boundPort) = AkkaUtils.createActorSystem( - "test", "localhost", 0, conf = conf, securityManager = securityMgr) - this.actorSystem = actorSystem + rpcEnv = RpcEnv.create("test", "localhost", 0, conf, securityMgr) conf.set("spark.authenticate", "false") - conf.set("spark.driver.port", boundPort.toString) + conf.set("spark.driver.port", rpcEnv.address.port.toString) conf.set("spark.storage.unrollFraction", "0.4") conf.set("spark.storage.unrollMemoryThreshold", "512") @@ -83,18 +80,17 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd // to make cached peers refresh frequently conf.set("spark.storage.cachedPeersTtl", "10") - master = new BlockManagerMaster( - actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf, new LiveListenerBus))), - conf, true) + master = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager", + new BlockManagerMasterEndpoint(rpcEnv, true, conf, new LiveListenerBus)), conf, true) allStores.clear() } after { allStores.foreach { _.stop() } allStores.clear() - actorSystem.shutdown() - actorSystem.awaitTermination() - actorSystem = null + rpcEnv.shutdown() + rpcEnv.awaitTermination() + rpcEnv = null master = null } @@ -262,7 +258,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd val failableTransfer = mock(classOf[BlockTransferService]) // this wont actually work when(failableTransfer.hostName).thenReturn("some-hostname") when(failableTransfer.port).thenReturn(1000) - val failableStore = new BlockManager("failable-store", actorSystem, master, serializer, + val failableStore = new BlockManager("failable-store", rpcEnv, master, serializer, 10000, conf, mapOutputTracker, shuffleManager, failableTransfer, securityMgr, 0) failableStore.initialize("app-id") allStores += failableStore // so that this gets stopped after test diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index ecd1cba5b5abe..545722b050ee8 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -19,24 +19,18 @@ package org.apache.spark.storage import java.nio.{ByteBuffer, MappedByteBuffer} import java.util.Arrays -import java.util.concurrent.TimeUnit import scala.collection.mutable.ArrayBuffer -import scala.concurrent.Await import scala.concurrent.duration._ import scala.language.implicitConversions import scala.language.postfixOps -import akka.actor._ -import akka.pattern.ask -import akka.util.Timeout - import org.mockito.Mockito.{mock, when} - import org.scalatest._ import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.Timeouts._ +import org.apache.spark.rpc.RpcEnv import org.apache.spark.{MapOutputTrackerMaster, SparkConf, SparkContext, SecurityManager} import org.apache.spark.executor.DataReadMethod import org.apache.spark.network.nio.NioBlockTransferService @@ -53,7 +47,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach private val conf = new SparkConf(false) var store: BlockManager = null var store2: BlockManager = null - var actorSystem: ActorSystem = null + var rpcEnv: RpcEnv = null var master: BlockManagerMaster = null conf.set("spark.authenticate", "false") val securityMgr = new SecurityManager(conf) @@ -66,34 +60,31 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach // Implicitly convert strings to BlockIds for test clarity. implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value) - def rdd(rddId: Int, splitId: Int) = RDDBlockId(rddId, splitId) + def rdd(rddId: Int, splitId: Int): RDDBlockId = RDDBlockId(rddId, splitId) private def makeBlockManager( maxMem: Long, name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { val transfer = new NioBlockTransferService(conf, securityMgr) - val manager = new BlockManager(name, actorSystem, master, serializer, maxMem, conf, + val manager = new BlockManager(name, rpcEnv, master, serializer, maxMem, conf, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) manager.initialize("app-id") manager } override def beforeEach(): Unit = { - val (actorSystem, boundPort) = AkkaUtils.createActorSystem( - "test", "localhost", 0, conf = conf, securityManager = securityMgr) - this.actorSystem = actorSystem + rpcEnv = RpcEnv.create("test", "localhost", 0, conf, securityMgr) // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case System.setProperty("os.arch", "amd64") conf.set("os.arch", "amd64") conf.set("spark.test.useCompressedOops", "true") - conf.set("spark.driver.port", boundPort.toString) + conf.set("spark.driver.port", rpcEnv.address.port.toString) conf.set("spark.storage.unrollFraction", "0.4") conf.set("spark.storage.unrollMemoryThreshold", "512") - master = new BlockManagerMaster( - actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf, new LiveListenerBus))), - conf, true) + master = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager", + new BlockManagerMasterEndpoint(rpcEnv, true, conf, new LiveListenerBus)), conf, true) val initialize = PrivateMethod[Unit]('initialize) SizeEstimator invokePrivate initialize() @@ -108,16 +99,18 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach store2.stop() store2 = null } - actorSystem.shutdown() - actorSystem.awaitTermination() - actorSystem = null + rpcEnv.shutdown() + rpcEnv.awaitTermination() + rpcEnv = null master = null } test("StorageLevel object caching") { val level1 = StorageLevel(false, false, false, false, 3) - val level2 = StorageLevel(false, false, false, false, 3) // this should return the same object as level1 - val level3 = StorageLevel(false, false, false, false, 2) // this should return a different object + // this should return the same object as level1 + val level2 = StorageLevel(false, false, false, false, 3) + // this should return a different object + val level3 = StorageLevel(false, false, false, false, 2) assert(level2 === level1, "level2 is not same as level1") assert(level2.eq(level1), "level2 is not the same object as level1") assert(level3 != level1, "level3 is same as level1") @@ -148,6 +141,12 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach assert(id2_.eq(id1), "Deserialized id2 is not the same object as original id1") } + test("BlockManagerId.isDriver() backwards-compatibility with legacy driver ids (SPARK-6716)") { + assert(BlockManagerId(SparkContext.DRIVER_IDENTIFIER, "XXX", 1).isDriver) + assert(BlockManagerId(SparkContext.LEGACY_DRIVER_IDENTIFIER, "XXX", 1).isDriver) + assert(!BlockManagerId("notADriverIdentifier", "XXX", 1).isDriver) + } + test("master + 1 manager interaction") { store = makeBlockManager(20000) val a1 = new Array[Byte](4000) @@ -357,10 +356,8 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach master.removeExecutor(store.blockManagerId.executorId) assert(master.getLocations("a1").size == 0, "a1 was not removed from master") - implicit val timeout = Timeout(30, TimeUnit.SECONDS) - val reregister = !Await.result( - master.driverActor ? BlockManagerHeartbeat(store.blockManagerId), - timeout.duration).asInstanceOf[Boolean] + val reregister = !master.driverEndpoint.askWithReply[Boolean]( + BlockManagerHeartbeat(store.blockManagerId)) assert(reregister == true) } @@ -785,7 +782,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach test("block store put failure") { // Use Java serializer so we can create an unserializable error. val transfer = new NioBlockTransferService(conf, securityMgr) - store = new BlockManager(SparkContext.DRIVER_IDENTIFIER, actorSystem, master, + store = new BlockManager(SparkContext.DRIVER_IDENTIFIER, rpcEnv, master, new JavaSerializer(conf), 1200, conf, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) @@ -807,7 +804,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach // Create a non-trivial (not all zeros) byte array var counter = 0.toByte - def incr = {counter = (counter + 1).toByte; counter;} + def incr: Byte = {counter = (counter + 1).toByte; counter;} val bytes = Array.fill[Byte](1000)(incr) val byteBuffer = ByteBuffer.wrap(bytes) @@ -961,8 +958,10 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach store.putIterator("list3", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) // getLocations and getBlockStatus should yield the same locations - assert(store.master.getMatchingBlockIds(_.toString.contains("list"), askSlaves = false).size === 3) - assert(store.master.getMatchingBlockIds(_.toString.contains("list1"), askSlaves = false).size === 1) + assert(store.master.getMatchingBlockIds(_.toString.contains("list"), askSlaves = false).size + === 3) + assert(store.master.getMatchingBlockIds(_.toString.contains("list1"), askSlaves = false).size + === 1) // insert some more blocks store.putIterator("newlist1", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) @@ -970,8 +969,10 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach store.putIterator("newlist3", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = false) // getLocations and getBlockStatus should yield the same locations - assert(store.master.getMatchingBlockIds(_.toString.contains("newlist"), askSlaves = false).size === 1) - assert(store.master.getMatchingBlockIds(_.toString.contains("newlist"), askSlaves = true).size === 3) + assert(store.master.getMatchingBlockIds(_.toString.contains("newlist"), askSlaves = false).size + === 1) + assert(store.master.getMatchingBlockIds(_.toString.contains("newlist"), askSlaves = true).size + === 3) val blockIds = Seq(RDDBlockId(1, 0), RDDBlockId(1, 1), RDDBlockId(2, 0)) blockIds.foreach { blockId => @@ -1095,8 +1096,8 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach val memoryStore = store.memoryStore val smallList = List.fill(40)(new Array[Byte](100)) val bigList = List.fill(40)(new Array[Byte](1000)) - def smallIterator = smallList.iterator.asInstanceOf[Iterator[Any]] - def bigIterator = bigList.iterator.asInstanceOf[Iterator[Any]] + def smallIterator: Iterator[Any] = smallList.iterator.asInstanceOf[Iterator[Any]] + def bigIterator: Iterator[Any] = bigList.iterator.asInstanceOf[Iterator[Any]] assert(memoryStore.currentUnrollMemoryForThisThread === 0) // Unroll with plenty of space. This should succeed and cache both blocks. @@ -1149,8 +1150,8 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach val diskStore = store.diskStore val smallList = List.fill(40)(new Array[Byte](100)) val bigList = List.fill(40)(new Array[Byte](1000)) - def smallIterator = smallList.iterator.asInstanceOf[Iterator[Any]] - def bigIterator = bigList.iterator.asInstanceOf[Iterator[Any]] + def smallIterator: Iterator[Any] = smallList.iterator.asInstanceOf[Iterator[Any]] + def bigIterator: Iterator[Any] = bigList.iterator.asInstanceOf[Iterator[Any]] assert(memoryStore.currentUnrollMemoryForThisThread === 0) store.putIterator("b1", smallIterator, memAndDisk) @@ -1192,7 +1193,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach val memOnly = StorageLevel.MEMORY_ONLY val memoryStore = store.memoryStore val smallList = List.fill(40)(new Array[Byte](100)) - def smallIterator = smallList.iterator.asInstanceOf[Iterator[Any]] + def smallIterator: Iterator[Any] = smallList.iterator.asInstanceOf[Iterator[Any]] assert(memoryStore.currentUnrollMemoryForThisThread === 0) // All unroll memory used is released because unrollSafely returned an array diff --git a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala index 82a82e23eecf2..b47157f8331cc 100644 --- a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala @@ -47,7 +47,7 @@ class LocalDirsSuite extends FunSuite with BeforeAndAfter { assert(!new File("/NONEXISTENT_DIR").exists()) // SPARK_LOCAL_DIRS is a valid directory: class MySparkConf extends SparkConf(false) { - override def getenv(name: String) = { + override def getenv(name: String): String = { if (name == "SPARK_LOCAL_DIRS") System.getProperty("java.io.tmpdir") else super.getenv(name) } diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index 6c67ade41fe53..cdd0b2898fe3b 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -161,7 +161,7 @@ class UISeleniumSuite extends FunSuite with WebBrowser with Matchers with Before new SparkContext(conf) } - def hasKillLink = find(className("kill-link")).isDefined + def hasKillLink: Boolean = find(className("kill-link")).isDefined def runSlowJob(sc: SparkContext) { sc.parallelize(1 to 10).map{x => Thread.sleep(10000); x}.countAsync() } diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index c0c28cb60e21d..21d8267114133 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -269,7 +269,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc val taskType = Utils.getFormattedClassName(new ShuffleMapTask(0)) val execId = "exe-1" - def makeTaskMetrics(base: Int) = { + def makeTaskMetrics(base: Int): TaskMetrics = { val taskMetrics = new TaskMetrics() val shuffleReadMetrics = new ShuffleReadMetrics() val shuffleWriteMetrics = new ShuffleWriteMetrics() @@ -291,7 +291,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc taskMetrics } - def makeTaskInfo(taskId: Long, finishTime: Int = 0) = { + def makeTaskInfo(taskId: Long, finishTime: Int = 0): TaskInfo = { val taskInfo = new TaskInfo(taskId, 0, 1, 0L, execId, "host1", TaskLocality.NODE_LOCAL, false) taskInfo.finishTime = finishTime diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala index e1bc1379b5d80..3744e479d2f05 100644 --- a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala @@ -107,7 +107,8 @@ class StorageTabSuite extends FunSuite with BeforeAndAfter { val myRddInfo0 = rddInfo0 val myRddInfo1 = rddInfo1 val myRddInfo2 = rddInfo2 - val stageInfo0 = new StageInfo(0, 0, "0", 100, Seq(myRddInfo0, myRddInfo1, myRddInfo2), "details") + val stageInfo0 = new StageInfo( + 0, 0, "0", 100, Seq(myRddInfo0, myRddInfo1, myRddInfo2), "details") bus.postToAll(SparkListenerBlockManagerAdded(1L, bm1, 1000L)) bus.postToAll(SparkListenerStageSubmitted(stageInfo0)) assert(storageListener._rddInfoMap.size === 3) diff --git a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala index 6250d50fb7036..bec79fc4dc8f7 100644 --- a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala @@ -19,14 +19,11 @@ package org.apache.spark.util import java.util.concurrent.TimeoutException -import scala.concurrent.Await -import scala.util.{Failure, Try} - -import akka.actor._ - +import akka.actor.ActorNotFound import org.scalatest.FunSuite import org.apache.spark._ +import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.MapStatus import org.apache.spark.storage.BlockManagerId import org.apache.spark.SSLSampleConfigs._ @@ -39,39 +36,37 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro test("remote fetch security bad password") { val conf = new SparkConf + conf.set("spark.rpc", "akka") conf.set("spark.authenticate", "true") conf.set("spark.authenticate.secret", "good") val securityManager = new SecurityManager(conf) val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, - conf = conf, securityManager = securityManager) - System.setProperty("spark.hostPort", hostname + ":" + boundPort) + val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) + System.setProperty("spark.hostPort", rpcEnv.address.hostPort) assert(securityManager.isAuthenticationEnabled() === true) val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) val badconf = new SparkConf + badconf.set("spark.rpc", "akka") badconf.set("spark.authenticate", "true") badconf.set("spark.authenticate.secret", "bad") val securityManagerBad = new SecurityManager(badconf) assert(securityManagerBad.isAuthenticationEnabled() === true) - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, - conf = conf, securityManager = securityManagerBad) + val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, conf, securityManagerBad) val slaveTracker = new MapOutputTrackerWorker(conf) - val selection = slaveSystem.actorSelection( - AkkaUtils.address(AkkaUtils.protocol(slaveSystem), "spark", "localhost", boundPort, "MapOutputTracker")) - val timeout = AkkaUtils.lookupTimeout(conf) intercept[akka.actor.ActorNotFound] { - slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + slaveTracker.trackerEndpoint = + slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) } - actorSystem.shutdown() - slaveSystem.shutdown() + rpcEnv.shutdown() + slaveRpcEnv.shutdown() } test("remote fetch security off") { @@ -81,28 +76,24 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro val securityManager = new SecurityManager(conf) val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, - conf = conf, securityManager = securityManager) - System.setProperty("spark.hostPort", hostname + ":" + boundPort) + val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) + System.setProperty("spark.hostPort", rpcEnv.address.hostPort) assert(securityManager.isAuthenticationEnabled() === false) val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) val badconf = new SparkConf badconf.set("spark.authenticate", "false") badconf.set("spark.authenticate.secret", "good") val securityManagerBad = new SecurityManager(badconf) - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, - conf = badconf, securityManager = securityManagerBad) + val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, badconf, securityManagerBad) val slaveTracker = new MapOutputTrackerWorker(conf) - val selection = slaveSystem.actorSelection( - AkkaUtils.address(AkkaUtils.protocol(slaveSystem), "spark", "localhost", boundPort, "MapOutputTracker")) - val timeout = AkkaUtils.lookupTimeout(conf) - slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + slaveTracker.trackerEndpoint = + slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) assert(securityManagerBad.isAuthenticationEnabled() === false) @@ -120,8 +111,8 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro assert(slaveTracker.getServerStatuses(10, 0).toSeq === Seq((BlockManagerId("a", "hostA", 1000), size1000))) - actorSystem.shutdown() - slaveSystem.shutdown() + rpcEnv.shutdown() + slaveRpcEnv.shutdown() } test("remote fetch security pass") { @@ -131,15 +122,14 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro val securityManager = new SecurityManager(conf) val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, - conf = conf, securityManager = securityManager) - System.setProperty("spark.hostPort", hostname + ":" + boundPort) + val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) + System.setProperty("spark.hostPort", rpcEnv.address.hostPort) assert(securityManager.isAuthenticationEnabled() === true) val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) val goodconf = new SparkConf goodconf.set("spark.authenticate", "true") @@ -148,13 +138,10 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro assert(securityManagerGood.isAuthenticationEnabled() === true) - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, - conf = goodconf, securityManager = securityManagerGood) + val slaveRpcEnv =RpcEnv.create("spark-slave", hostname, 0, goodconf, securityManagerGood) val slaveTracker = new MapOutputTrackerWorker(conf) - val selection = slaveSystem.actorSelection( - AkkaUtils.address(AkkaUtils.protocol(slaveSystem), "spark", "localhost", boundPort, "MapOutputTracker")) - val timeout = AkkaUtils.lookupTimeout(conf) - slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + slaveTracker.trackerEndpoint = + slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) masterTracker.registerShuffle(10, 1) masterTracker.incrementEpoch() @@ -170,47 +157,45 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro assert(slaveTracker.getServerStatuses(10, 0).toSeq === Seq((BlockManagerId("a", "hostA", 1000), size1000))) - actorSystem.shutdown() - slaveSystem.shutdown() + rpcEnv.shutdown() + slaveRpcEnv.shutdown() } test("remote fetch security off client") { val conf = new SparkConf + conf.set("spark.rpc", "akka") conf.set("spark.authenticate", "true") conf.set("spark.authenticate.secret", "good") val securityManager = new SecurityManager(conf) val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, - conf = conf, securityManager = securityManager) - System.setProperty("spark.hostPort", hostname + ":" + boundPort) + val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) + System.setProperty("spark.hostPort", rpcEnv.address.hostPort) assert(securityManager.isAuthenticationEnabled() === true) val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) val badconf = new SparkConf + badconf.set("spark.rpc", "akka") badconf.set("spark.authenticate", "false") badconf.set("spark.authenticate.secret", "bad") val securityManagerBad = new SecurityManager(badconf) assert(securityManagerBad.isAuthenticationEnabled() === false) - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, - conf = badconf, securityManager = securityManagerBad) + val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, badconf, securityManagerBad) val slaveTracker = new MapOutputTrackerWorker(conf) - val selection = slaveSystem.actorSelection( - AkkaUtils.address(AkkaUtils.protocol(slaveSystem), "spark", "localhost", boundPort, "MapOutputTracker")) - val timeout = AkkaUtils.lookupTimeout(conf) intercept[akka.actor.ActorNotFound] { - slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + slaveTracker.trackerEndpoint = + slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) } - actorSystem.shutdown() - slaveSystem.shutdown() + rpcEnv.shutdown() + slaveRpcEnv.shutdown() } test("remote fetch ssl on") { @@ -218,26 +203,22 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro val securityManager = new SecurityManager(conf) val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, - conf = conf, securityManager = securityManager) - System.setProperty("spark.hostPort", hostname + ":" + boundPort) + val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) + System.setProperty("spark.hostPort", rpcEnv.address.hostPort) assert(securityManager.isAuthenticationEnabled() === false) val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) val slaveConf = sparkSSLConfig() val securityManagerBad = new SecurityManager(slaveConf) - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, - conf = slaveConf, securityManager = securityManagerBad) + val slaveRpcEnv = RpcEnv.create("spark-slaves", hostname, 0, slaveConf, securityManagerBad) val slaveTracker = new MapOutputTrackerWorker(conf) - val selection = slaveSystem.actorSelection( - AkkaUtils.address(AkkaUtils.protocol(slaveSystem), "spark", "localhost", boundPort, "MapOutputTracker")) - val timeout = AkkaUtils.lookupTimeout(conf) - slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + slaveTracker.trackerEndpoint = + slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) assert(securityManagerBad.isAuthenticationEnabled() === false) @@ -255,8 +236,8 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro assert(slaveTracker.getServerStatuses(10, 0).toSeq === Seq((BlockManagerId("a", "hostA", 1000), size1000))) - actorSystem.shutdown() - slaveSystem.shutdown() + rpcEnv.shutdown() + slaveRpcEnv.shutdown() } @@ -267,28 +248,24 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro val securityManager = new SecurityManager(conf) val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, - conf = conf, securityManager = securityManager) - System.setProperty("spark.hostPort", hostname + ":" + boundPort) + val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) + System.setProperty("spark.hostPort", rpcEnv.address.hostPort) assert(securityManager.isAuthenticationEnabled() === true) val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) val slaveConf = sparkSSLConfig() slaveConf.set("spark.authenticate", "true") slaveConf.set("spark.authenticate.secret", "good") val securityManagerBad = new SecurityManager(slaveConf) - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, - conf = slaveConf, securityManager = securityManagerBad) + val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, slaveConf, securityManagerBad) val slaveTracker = new MapOutputTrackerWorker(conf) - val selection = slaveSystem.actorSelection( - AkkaUtils.address(AkkaUtils.protocol(slaveSystem), "spark", "localhost", boundPort, "MapOutputTracker")) - val timeout = AkkaUtils.lookupTimeout(conf) - slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + slaveTracker.trackerEndpoint = + slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) assert(securityManagerBad.isAuthenticationEnabled() === true) @@ -305,45 +282,43 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro assert(slaveTracker.getServerStatuses(10, 0).toSeq === Seq((BlockManagerId("a", "hostA", 1000), size1000))) - actorSystem.shutdown() - slaveSystem.shutdown() + rpcEnv.shutdown() + slaveRpcEnv.shutdown() } test("remote fetch ssl on and security enabled - bad credentials") { val conf = sparkSSLConfig() + conf.set("spark.rpc", "akka") conf.set("spark.authenticate", "true") conf.set("spark.authenticate.secret", "good") val securityManager = new SecurityManager(conf) val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, - conf = conf, securityManager = securityManager) - System.setProperty("spark.hostPort", hostname + ":" + boundPort) + val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) + System.setProperty("spark.hostPort", rpcEnv.address.hostPort) assert(securityManager.isAuthenticationEnabled() === true) val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) val slaveConf = sparkSSLConfig() + slaveConf.set("spark.rpc", "akka") slaveConf.set("spark.authenticate", "true") slaveConf.set("spark.authenticate.secret", "bad") val securityManagerBad = new SecurityManager(slaveConf) - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, - conf = slaveConf, securityManager = securityManagerBad) + val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, slaveConf, securityManagerBad) val slaveTracker = new MapOutputTrackerWorker(conf) - val selection = slaveSystem.actorSelection( - AkkaUtils.address(AkkaUtils.protocol(slaveSystem), "spark", "localhost", boundPort, "MapOutputTracker")) - val timeout = AkkaUtils.lookupTimeout(conf) intercept[akka.actor.ActorNotFound] { - slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + slaveTracker.trackerEndpoint = + slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) } - actorSystem.shutdown() - slaveSystem.shutdown() + rpcEnv.shutdown() + slaveRpcEnv.shutdown() } @@ -352,35 +327,30 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro val securityManager = new SecurityManager(conf) val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, - conf = conf, securityManager = securityManager) - System.setProperty("spark.hostPort", hostname + ":" + boundPort) + val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) + System.setProperty("spark.hostPort", rpcEnv.address.hostPort) assert(securityManager.isAuthenticationEnabled() === false) val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) val slaveConf = sparkSSLConfig() val securityManagerBad = new SecurityManager(slaveConf) - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, - conf = slaveConf, securityManager = securityManagerBad) + val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, slaveConf, securityManagerBad) val slaveTracker = new MapOutputTrackerWorker(conf) - val selection = slaveSystem.actorSelection( - AkkaUtils.address(AkkaUtils.protocol(slaveSystem), "spark", "localhost", boundPort, "MapOutputTracker")) - val timeout = AkkaUtils.lookupTimeout(conf) - val result = Try(Await.result(selection.resolveOne(timeout * 2), timeout)) - - result match { - case Failure(ex: ActorNotFound) => - case Failure(ex: TimeoutException) => - case r => fail(s"$r is neither Failure(ActorNotFound) nor Failure(TimeoutException)") + try { + slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) + fail("should receive either ActorNotFound or TimeoutException") + } catch { + case e: ActorNotFound => + case e: TimeoutException => } - actorSystem.shutdown() - slaveSystem.shutdown() + rpcEnv.shutdown() + slaveRpcEnv.shutdown() } } diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala index 054ef54e746a5..c47162779bbba 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala @@ -83,7 +83,7 @@ object TestObject { class TestClass extends Serializable { var x = 5 - def getX = x + def getX: Int = x def run(): Int = { var nonSer = new NonSerializable @@ -95,7 +95,7 @@ class TestClass extends Serializable { } class TestClassWithoutDefaultConstructor(x: Int) extends Serializable { - def getX = x + def getX: Int = x def run(): Int = { var nonSer = new NonSerializable @@ -164,7 +164,7 @@ object TestObjectWithNesting { } class TestClassWithNesting(val y: Int) extends Serializable { - def getY = y + def getY: Int = y def run(): Int = { var nonSer = new NonSerializable diff --git a/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala b/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala index 1026cb2aa7cae..47b535206c949 100644 --- a/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala @@ -203,4 +203,76 @@ class EventLoopSuite extends FunSuite with Timeouts { assert(!eventLoop.isActive) } } + + test("EventLoop: stop() in onStart should call onStop") { + @volatile var onStopCalled: Boolean = false + val eventLoop = new EventLoop[Int]("test") { + + override def onStart(): Unit = { + stop() + } + + override def onReceive(event: Int): Unit = { + } + + override def onError(e: Throwable): Unit = { + } + + override def onStop(): Unit = { + onStopCalled = true + } + } + eventLoop.start() + eventually(timeout(5 seconds), interval(5 millis)) { + assert(!eventLoop.isActive) + } + assert(onStopCalled) + } + + test("EventLoop: stop() in onReceive should call onStop") { + @volatile var onStopCalled: Boolean = false + val eventLoop = new EventLoop[Int]("test") { + + override def onReceive(event: Int): Unit = { + stop() + } + + override def onError(e: Throwable): Unit = { + } + + override def onStop(): Unit = { + onStopCalled = true + } + } + eventLoop.start() + eventLoop.post(1) + eventually(timeout(5 seconds), interval(5 millis)) { + assert(!eventLoop.isActive) + } + assert(onStopCalled) + } + + test("EventLoop: stop() in onError should call onStop") { + @volatile var onStopCalled: Boolean = false + val eventLoop = new EventLoop[Int]("test") { + + override def onReceive(event: Int): Unit = { + throw new RuntimeException("Oops") + } + + override def onError(e: Throwable): Unit = { + stop() + } + + override def onStop(): Unit = { + onStopCalled = true + } + } + eventLoop.start() + eventLoop.post(1) + eventually(timeout(5 seconds), interval(5 millis)) { + assert(!eventLoop.isActive) + } + assert(onStopCalled) + } } diff --git a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala index 43b6a405cb68c..c05317534cddf 100644 --- a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala @@ -109,7 +109,8 @@ class FileAppenderSuite extends FunSuite with BeforeAndAfter with Logging { // verify whether the earliest file has been deleted val rolledOverFiles = allGeneratedFiles.filter { _ != testFile.toString }.toArray.sorted - logInfo(s"All rolled over files generated:${rolledOverFiles.size}\n" + rolledOverFiles.mkString("\n")) + logInfo(s"All rolled over files generated:${rolledOverFiles.size}\n" + + rolledOverFiles.mkString("\n")) assert(rolledOverFiles.size > 2) val earliestRolledOverFile = rolledOverFiles.head val existingRolledOverFiles = RollingFileAppender.getSortedRolledOverFiles( @@ -135,7 +136,7 @@ class FileAppenderSuite extends FunSuite with BeforeAndAfter with Logging { val testOutputStream = new PipedOutputStream() val testInputStream = new PipedInputStream(testOutputStream) val appender = FileAppender(testInputStream, testFile, conf) - //assert(appender.getClass === classTag[ExpectedAppender].getClass) + // assert(appender.getClass === classTag[ExpectedAppender].getClass) assert(appender.getClass.getSimpleName === classTag[ExpectedAppender].runtimeClass.getSimpleName) if (appender.isInstanceOf[RollingFileAppender]) { @@ -153,9 +154,11 @@ class FileAppenderSuite extends FunSuite with BeforeAndAfter with Logging { import RollingFileAppender._ - def rollingStrategy(strategy: String) = Seq(STRATEGY_PROPERTY -> strategy) - def rollingSize(size: String) = Seq(SIZE_PROPERTY -> size) - def rollingInterval(interval: String) = Seq(INTERVAL_PROPERTY -> interval) + def rollingStrategy(strategy: String): Seq[(String, String)] = + Seq(STRATEGY_PROPERTY -> strategy) + def rollingSize(size: String): Seq[(String, String)] = Seq(SIZE_PROPERTY -> size) + def rollingInterval(interval: String): Seq[(String, String)] = + Seq(INTERVAL_PROPERTY -> interval) val msInDay = 24 * 60 * 60 * 1000L val msInHour = 60 * 60 * 1000L diff --git a/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala b/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala index 72e81f3f1a884..403dcb03bd6e5 100644 --- a/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala @@ -71,7 +71,7 @@ class NextIteratorSuite extends FunSuite with Matchers { class StubIterator(ints: Buffer[Int]) extends NextIterator[Int] { var closeCalled = 0 - override def getNext() = { + override def getNext(): Int = { if (ints.size == 0) { finished = true 0 diff --git a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala index 7424c2e91d4f2..67a9f75ff2187 100644 --- a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala @@ -98,8 +98,10 @@ class SizeEstimatorSuite // If an array contains the *same* element many times, we should only count it once. val d1 = new DummyClass1 - assertResult(72)(SizeEstimator.estimate(Array.fill(10)(d1))) // 10 pointers plus 8-byte object - assertResult(432)(SizeEstimator.estimate(Array.fill(100)(d1))) // 100 pointers plus 8-byte object + // 10 pointers plus 8-byte object + assertResult(72)(SizeEstimator.estimate(Array.fill(10)(d1))) + // 100 pointers plus 8-byte object + assertResult(432)(SizeEstimator.estimate(Array.fill(100)(d1))) // Same thing with huge array containing the same element many times. Note that this won't // return exactly 4032 because it can't tell that *all* the elements will equal the first diff --git a/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala index c1c605cdb487b..8b72fe665c214 100644 --- a/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala @@ -63,7 +63,7 @@ class TimeStampedHashMapSuite extends FunSuite { assert(map1.getTimestamp("k1").get < threshTime1) assert(map1.getTimestamp("k2").isDefined) assert(map1.getTimestamp("k2").get >= threshTime1) - map1.clearOldValues(threshTime1) //should only clear k1 + map1.clearOldValues(threshTime1) // should only clear k1 assert(map1.get("k1") === None) assert(map1.get("k2").isDefined) } @@ -93,7 +93,7 @@ class TimeStampedHashMapSuite extends FunSuite { assert(map1.getTimestamp("k1").get < threshTime1) assert(map1.getTimestamp("k2").isDefined) assert(map1.getTimestamp("k2").get >= threshTime1) - map1.clearOldValues(threshTime1) //should only clear k1 + map1.clearOldValues(threshTime1) // should only clear k1 assert(map1.get("k1") === None) assert(map1.get("k2").isDefined) } diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 5d93086082189..449fb87f111c4 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -106,7 +106,7 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { val second = 1000 val minute = second * 60 val hour = minute * 60 - def str = Utils.msDurationToString(_) + def str: (Long) => String = Utils.msDurationToString(_) val sep = new DecimalFormatSymbols(Locale.getDefault()).getDecimalSeparator() @@ -199,7 +199,8 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { test("doesDirectoryContainFilesNewerThan") { // create some temporary directories and files val parent: File = Utils.createTempDir() - val child1: File = Utils.createTempDir(parent.getCanonicalPath) // The parent directory has two child directories + // The parent directory has two child directories + val child1: File = Utils.createTempDir(parent.getCanonicalPath) val child2: File = Utils.createTempDir(parent.getCanonicalPath) val child3: File = Utils.createTempDir(child1.getCanonicalPath) // set the last modified time of child1 to 30 secs old diff --git a/core/src/test/scala/org/apache/spark/util/VectorSuite.scala b/core/src/test/scala/org/apache/spark/util/VectorSuite.scala index 794a55d61750b..ce2968728a996 100644 --- a/core/src/test/scala/org/apache/spark/util/VectorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/VectorSuite.scala @@ -27,7 +27,7 @@ import org.scalatest.FunSuite @deprecated("suppress compile time deprecation warning", "1.0.0") class VectorSuite extends FunSuite { - def verifyVector(vector: Vector, expectedLength: Int) = { + def verifyVector(vector: Vector, expectedLength: Int): Unit = { assert(vector.length == expectedLength) assert(vector.elements.min > 0.0) assert(vector.elements.max < 1.0) diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala index 48f79ea651018..dff8f3ddc816f 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala @@ -185,7 +185,7 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { // reduceByKey val rdd = sc.parallelize(1 to 10).map(i => (i%2, 1)) - val result1 = rdd.reduceByKey(_+_).collect() + val result1 = rdd.reduceByKey(_ + _).collect() assert(result1.toSet === Set[(Int, Int)]((0, 5), (1, 5))) // groupByKey diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala index 72d96798b1141..9ff067f86af44 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -553,10 +553,10 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe conf.set("spark.shuffle.memoryFraction", "0.001") sc = new SparkContext("local-cluster[1,1,512]", "test", conf) - def createCombiner(i: String) = ArrayBuffer[String](i) - def mergeValue(buffer: ArrayBuffer[String], i: String) = buffer += i - def mergeCombiners(buffer1: ArrayBuffer[String], buffer2: ArrayBuffer[String]) = - buffer1 ++= buffer2 + def createCombiner(i: String): ArrayBuffer[String] = ArrayBuffer[String](i) + def mergeValue(buffer: ArrayBuffer[String], i: String): ArrayBuffer[String] = buffer += i + def mergeCombiners(buffer1: ArrayBuffer[String], buffer2: ArrayBuffer[String]) + : ArrayBuffer[String] = buffer1 ++= buffer2 val agg = new Aggregator[String, String, ArrayBuffer[String]]( createCombiner _, mergeValue _, mergeCombiners _) @@ -633,14 +633,17 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe conf.set("spark.shuffle.memoryFraction", "0.001") sc = new SparkContext("local-cluster[1,1,512]", "test", conf) - def createCombiner(i: Int) = ArrayBuffer[Int](i) - def mergeValue(buffer: ArrayBuffer[Int], i: Int) = buffer += i - def mergeCombiners(buf1: ArrayBuffer[Int], buf2: ArrayBuffer[Int]) = buf1 ++= buf2 + def createCombiner(i: Int): ArrayBuffer[Int] = ArrayBuffer[Int](i) + def mergeValue(buffer: ArrayBuffer[Int], i: Int): ArrayBuffer[Int] = buffer += i + def mergeCombiners(buf1: ArrayBuffer[Int], buf2: ArrayBuffer[Int]): ArrayBuffer[Int] = { + buf1 ++= buf2 + } val agg = new Aggregator[Int, Int, ArrayBuffer[Int]](createCombiner, mergeValue, mergeCombiners) val sorter = new ExternalSorter[Int, Int, ArrayBuffer[Int]](Some(agg), None, None, None) - sorter.insertAll((1 to 100000).iterator.map(i => (i, i)) ++ Iterator((Int.MaxValue, Int.MaxValue))) + sorter.insertAll( + (1 to 100000).iterator.map(i => (i, i)) ++ Iterator((Int.MaxValue, Int.MaxValue))) val it = sorter.iterator while (it.hasNext) { @@ -654,9 +657,10 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe conf.set("spark.shuffle.memoryFraction", "0.001") sc = new SparkContext("local-cluster[1,1,512]", "test", conf) - def createCombiner(i: String) = ArrayBuffer[String](i) - def mergeValue(buffer: ArrayBuffer[String], i: String) = buffer += i - def mergeCombiners(buf1: ArrayBuffer[String], buf2: ArrayBuffer[String]) = buf1 ++= buf2 + def createCombiner(i: String): ArrayBuffer[String] = ArrayBuffer[String](i) + def mergeValue(buffer: ArrayBuffer[String], i: String): ArrayBuffer[String] = buffer += i + def mergeCombiners(buf1: ArrayBuffer[String], buf2: ArrayBuffer[String]): ArrayBuffer[String] = + buf1 ++= buf2 val agg = new Aggregator[String, String, ArrayBuffer[String]]( createCombiner, mergeValue, mergeCombiners) @@ -720,7 +724,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe // Using wrongOrdering to show integer overflow introduced exception. val rand = new Random(100L) val wrongOrdering = new Ordering[String] { - override def compare(a: String, b: String) = { + override def compare(a: String, b: String): Int = { val h1 = if (a == null) 0 else a.hashCode() val h2 = if (b == null) 0 else b.hashCode() h1 - h2 @@ -742,9 +746,10 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe // Using aggregation and external spill to make sure ExternalSorter using // partitionKeyComparator. - def createCombiner(i: String) = ArrayBuffer(i) - def mergeValue(c: ArrayBuffer[String], i: String) = c += i - def mergeCombiners(c1: ArrayBuffer[String], c2: ArrayBuffer[String]) = c1 ++= c2 + def createCombiner(i: String): ArrayBuffer[String] = ArrayBuffer(i) + def mergeValue(c: ArrayBuffer[String], i: String): ArrayBuffer[String] = c += i + def mergeCombiners(c1: ArrayBuffer[String], c2: ArrayBuffer[String]): ArrayBuffer[String] = + c1 ++= c2 val agg = new Aggregator[String, String, ArrayBuffer[String]]( createCombiner, mergeValue, mergeCombiners) diff --git a/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala index ef7178bcdf5c2..03f5f2d1b8528 100644 --- a/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala @@ -28,7 +28,7 @@ import scala.language.reflectiveCalls class XORShiftRandomSuite extends FunSuite with Matchers { - def fixture = new { + def fixture: Object {val seed: Long; val hundMil: Int; val xorRand: XORShiftRandom} = new { val seed = 1L val xorRand = new XORShiftRandom(seed) val hundMil = 1e8.toInt diff --git a/dev/check-license b/dev/check-license index 39943f882b6ca..10740cfdc5242 100755 --- a/dev/check-license +++ b/dev/check-license @@ -24,29 +24,27 @@ acquire_rat_jar () { JAR="$rat_jar" - if [[ ! -f "$rat_jar" ]]; then - # Download rat launch jar if it hasn't been downloaded yet - if [ ! -f "$JAR" ]; then - # Download - printf "Attempting to fetch rat\n" - JAR_DL="${JAR}.part" - if [ $(command -v curl) ]; then - curl -L --silent "${URL}" > "$JAR_DL" && mv "$JAR_DL" "$JAR" - elif [ $(command -v wget) ]; then - wget --quiet ${URL} -O "$JAR_DL" && mv "$JAR_DL" "$JAR" - else - printf "You do not have curl or wget installed, please install rat manually.\n" - exit -1 - fi - fi - - unzip -tq $JAR &> /dev/null - if [ $? -ne 0 ]; then - # We failed to download - printf "Our attempt to download rat locally to ${JAR} failed. Please install rat manually.\n" + # Download rat launch jar if it hasn't been downloaded yet + if [ ! -f "$JAR" ]; then + # Download + printf "Attempting to fetch rat\n" + JAR_DL="${JAR}.part" + if [ $(command -v curl) ]; then + curl -L --silent "${URL}" > "$JAR_DL" && mv "$JAR_DL" "$JAR" + elif [ $(command -v wget) ]; then + wget --quiet ${URL} -O "$JAR_DL" && mv "$JAR_DL" "$JAR" + else + printf "You do not have curl or wget installed, please install rat manually.\n" exit -1 fi - printf "Launching rat from ${JAR}\n" + fi + + unzip -tq "$JAR" &> /dev/null + if [ $? -ne 0 ]; then + # We failed to download + rm "$JAR" + printf "Our attempt to download rat locally to ${JAR} failed. Please install rat manually.\n" + exit -1 fi } @@ -71,6 +69,11 @@ mkdir -p "$FWDIR"/lib $java_cmd -jar "$rat_jar" -E "$FWDIR"/.rat-excludes -d "$FWDIR" > rat-results.txt +if [ $? -ne 0 ]; then + echo "RAT exited abnormally" + exit 1 +fi + ERRORS="$(cat rat-results.txt | grep -e "??")" if test ! -z "$ERRORS"; then diff --git a/dev/run-tests b/dev/run-tests index 561d7fc9e7b1f..bb21ab6c9aa04 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -173,7 +173,7 @@ CURRENT_BLOCK=$BLOCK_BUILD build/mvn $HIVE_BUILD_ARGS clean package -DskipTests else echo -e "q\n" \ - | build/sbt $HIVE_BUILD_ARGS package assembly/assembly \ + | build/sbt $HIVE_BUILD_ARGS package assembly/assembly streaming-kafka-assembly/assembly \ | grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including" fi } @@ -236,3 +236,18 @@ echo "=========================================================================" CURRENT_BLOCK=$BLOCK_PYSPARK_UNIT_TESTS ./python/run-tests + +echo "" +echo "=========================================================================" +echo "Running SparkR tests" +echo "=========================================================================" + +CURRENT_BLOCK=$BLOCK_SPARKR_UNIT_TESTS + +if [ $(command -v R) ]; then + ./R/install-dev.sh + ./R/run-tests.sh +else + echo "Ignoring SparkR tests as R was not found in PATH" +fi + diff --git a/dev/run-tests-codes.sh b/dev/run-tests-codes.sh index 8ab6db6925d6e..154e01255b2ef 100644 --- a/dev/run-tests-codes.sh +++ b/dev/run-tests-codes.sh @@ -25,3 +25,4 @@ readonly BLOCK_BUILD=14 readonly BLOCK_MIMA=15 readonly BLOCK_SPARK_UNIT_TESTS=16 readonly BLOCK_PYSPARK_UNIT_TESTS=17 +readonly BLOCK_SPARKR_UNIT_TESTS=18 diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index f10aa6b59e1af..f6372835a6dbf 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -210,6 +210,8 @@ done failing_test="Spark unit tests" elif [ "$test_result" -eq "$BLOCK_PYSPARK_UNIT_TESTS" ]; then failing_test="PySpark unit tests" + elif [ "$test_result" -eq "$BLOCK_SPARKR_UNIT_TESTS" ]; then + failing_test="SparkR unit tests" else failing_test="some tests" fi diff --git a/dev/scalastyle b/dev/scalastyle index 86919227ed1ab..4e03f89ed5d5d 100755 --- a/dev/scalastyle +++ b/dev/scalastyle @@ -18,9 +18,10 @@ # echo -e "q\n" | build/sbt -Phive -Phive-thriftserver scalastyle > scalastyle.txt +echo -e "q\n" | build/sbt -Phive -Phive-thriftserver test:scalastyle >> scalastyle.txt # Check style with YARN built too -echo -e "q\n" | build/sbt -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 scalastyle \ - >> scalastyle.txt +echo -e "q\n" | build/sbt -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 scalastyle >> scalastyle.txt +echo -e "q\n" | build/sbt -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 test:scalastyle >> scalastyle.txt ERRORS=$(cat scalastyle.txt | awk '{if($1~/error/)print}') rm scalastyle.txt diff --git a/docs/README.md b/docs/README.md index 3773ea25c8b67..5852f972a051d 100644 --- a/docs/README.md +++ b/docs/README.md @@ -58,13 +58,19 @@ phase, use the following sytax: We use Sphinx to generate Python API docs, so you will need to install it by running `sudo pip install sphinx`. -## API Docs (Scaladoc and Sphinx) +## knitr, devtools + +SparkR documentation is written using `roxygen2` and we use `knitr`, `devtools` to generate +documentation. To install these packages you can run `install.packages(c("knitr", "devtools"))` from a +R console. + +## API Docs (Scaladoc, Sphinx, roxygen2) You can build just the Spark scaladoc by running `build/sbt unidoc` from the SPARK_PROJECT_ROOT directory. Similarly, you can build just the PySpark docs by running `make html` from the SPARK_PROJECT_ROOT/python/docs directory. Documentation is only generated for classes that are listed as -public in `__init__.py`. +public in `__init__.py`. The SparkR docs can be built by running SPARK_PROJECT_ROOT/R/create-docs.sh. When you run `jekyll` in the `docs` directory, it will also copy over the scaladoc for the various Spark subprojects into the `docs` directory (and then also into the `_site` directory). We use a @@ -72,5 +78,5 @@ jekyll plugin to run `build/sbt unidoc` before building the site so if you haven may take some time as it generates all of the scaladoc. The jekyll plugin also generates the PySpark docs [Sphinx](http://sphinx-doc.org/). -NOTE: To skip the step of building and copying over the Scala and Python API docs, run `SKIP_API=1 +NOTE: To skip the step of building and copying over the Scala, Python, R API docs, run `SKIP_API=1 jekyll`. diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index 2e88b3093652d..b92c75f90b11c 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -84,6 +84,7 @@
  • Scala
  • Java
  • Python
  • +
  • R
  • diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb index 3c626a0b7f54b..0ea3f8eab461b 100644 --- a/docs/_plugins/copy_api_dirs.rb +++ b/docs/_plugins/copy_api_dirs.rb @@ -78,5 +78,18 @@ puts "cp -r python/docs/_build/html/. docs/api/python" cp_r("python/docs/_build/html/.", "docs/api/python") - cd("..") + # Build SparkR API docs + puts "Moving to R directory and building roxygen docs." + cd("R") + puts `./create-docs.sh` + + puts "Moving back into home dir." + cd("../") + + puts "Making directory api/R" + mkdir_p "docs/api/R" + + puts "cp -r R/pkg/html/. docs/api/R" + cp_r("R/pkg/html/.", "docs/api/R") + end diff --git a/docs/cluster-overview.md b/docs/cluster-overview.md index 6a75d5c457f02..7079de546e2f5 100644 --- a/docs/cluster-overview.md +++ b/docs/cluster-overview.md @@ -33,7 +33,11 @@ There are several useful things to note about this architecture: 2. Spark is agnostic to the underlying cluster manager. As long as it can acquire executor processes, and these communicate with each other, it is relatively easy to run it even on a cluster manager that also supports other applications (e.g. Mesos/YARN). -3. Because the driver schedules tasks on the cluster, it should be run close to the worker +3. The driver program must listen for and accept incoming connections from its executors throughout + its lifetime (e.g., see [spark.driver.port and spark.fileserver.port in the network config + section](configuration.html#networking)). As such, the driver program must be network + addressable from the worker nodes. +4. Because the driver schedules tasks on the cluster, it should be run close to the worker nodes, preferably on the same local area network. If you'd like to send requests to the cluster remotely, it's better to open an RPC to the driver and have it submit operations from nearby than to run a driver far away from the worker nodes. diff --git a/docs/img/cluster-overview.png b/docs/img/cluster-overview.png index 368274068e754..317554c5f2a5b 100644 Binary files a/docs/img/cluster-overview.png and b/docs/img/cluster-overview.png differ diff --git a/docs/img/cluster-overview.pptx b/docs/img/cluster-overview.pptx index af3c462cd904d..1b90d7ec5a7ae 100644 Binary files a/docs/img/cluster-overview.pptx and b/docs/img/cluster-overview.pptx differ diff --git a/docs/ml-guide.md b/docs/ml-guide.md index c08c76d226713..771a07183e26f 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -493,7 +493,7 @@ from pyspark.ml.feature import HashingTF, Tokenizer from pyspark.sql import Row, SQLContext sc = SparkContext(appName="SimpleTextClassificationPipeline") -sqlCtx = SQLContext(sc) +sqlContext = SQLContext(sc) # Prepare training documents, which are labeled. LabeledDocument = Row("id", "text", "label") diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index d9f3eb2b74b18..ed5bb263a5809 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -87,7 +87,8 @@ Most of the configs are the same for Spark on YARN as for other deployment modes spark.yarn.historyServer.address (none) - The address of the Spark history server (i.e. host.com:18080). The address should not contain a scheme (http://). Defaults to not being set since the history server is an optional service. This address is given to the YARN ResourceManager when the Spark application finishes to link the application from the ResourceManager UI to the Spark history server UI. + The address of the Spark history server (i.e. host.com:18080). The address should not contain a scheme (http://). Defaults to not being set since the history server is an optional service. This address is given to the YARN ResourceManager when the Spark application finishes to link the application from the ResourceManager UI to the Spark history server UI. + For this property, YARN properties can be used as variables, and these are substituted by Spark at runtime. For eg, if the Spark history server runs on the same node as the YARN ResourceManager, it can be set to `${hadoopconf-yarn.resourcemanager.hostname}:18080`. @@ -196,6 +197,15 @@ Most of the configs are the same for Spark on YARN as for other deployment modes It should be no larger than the global number of max attempts in the YARN configuration. + + spark.yarn.submit.waitAppCompletion + true + + In YARN cluster mode, controls whether the client waits to exit until the application completes. + If set to true, the client process will stay alive reporting the application's status. + Otherwise, the client process will exit after submission. + + # Launching Spark on YARN diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 4441d6a000a02..332618edf0c55 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1642,7 +1642,7 @@ moved into the udf object in `SQLContext`.
    {% highlight java %} -sqlCtx.udf.register("strLen", (s: String) => s.length()) +sqlContext.udf.register("strLen", (s: String) => s.length()) {% endhighlight %}
    @@ -1650,7 +1650,7 @@ sqlCtx.udf.register("strLen", (s: String) => s.length())
    {% highlight java %} -sqlCtx.udf().register("strLen", (String s) -> { s.length(); }); +sqlContext.udf().register("strLen", (String s) -> { s.length(); }); {% endhighlight %}
    @@ -1784,6 +1784,7 @@ in Hive deployments. **Esoteric Hive Features** + * `UNION` type * Unique join * Column statistics collecting: Spark SQL does not piggyback scans to collect column statistics at diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 5507a9c5a4733..0c1f24761d0de 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -282,6 +282,10 @@ def parse_args(): parser.add_option( "--vpc-id", default=None, help="VPC to launch instances in") + parser.add_option( + "--private-ips", action="store_true", default=False, + help="Use private IPs for instances rather than public if VPC/subnet " + + "requires that.") (opts, args) = parser.parse_args() if len(args) != 2: @@ -707,7 +711,7 @@ def get_instances(group_names): # Deploy configuration files and run setup scripts on a newly launched # or started EC2 cluster. def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): - master = master_nodes[0].public_dns_name + master = get_dns_name(master_nodes[0], opts.private_ips) if deploy_ssh_key: print "Generating cluster's SSH key on master..." key_setup = """ @@ -719,8 +723,9 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): dot_ssh_tar = ssh_read(master, opts, ['tar', 'c', '.ssh']) print "Transferring cluster's SSH key to slaves..." for slave in slave_nodes: - print slave.public_dns_name - ssh_write(slave.public_dns_name, opts, ['tar', 'x'], dot_ssh_tar) + slave_address = get_dns_name(slave, opts.private_ips) + print slave_address + ssh_write(slave_address, opts, ['tar', 'x'], dot_ssh_tar) modules = ['spark', 'ephemeral-hdfs', 'persistent-hdfs', 'mapreduce', 'spark-standalone', 'tachyon'] @@ -809,7 +814,8 @@ def is_cluster_ssh_available(cluster_instances, opts): Check if SSH is available on all the instances in a cluster. """ for i in cluster_instances: - if not is_ssh_available(host=i.ip_address, opts=opts): + dns_name = get_dns_name(i, opts.private_ips) + if not is_ssh_available(host=dns_name, opts=opts): return False else: return True @@ -923,7 +929,7 @@ def get_num_disks(instance_type): # # root_dir should be an absolute path to the directory with the files we want to deploy. def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules): - active_master = master_nodes[0].public_dns_name + active_master = get_dns_name(master_nodes[0], opts.private_ips) num_disks = get_num_disks(opts.instance_type) hdfs_data_dirs = "/mnt/ephemeral-hdfs/data" @@ -948,10 +954,12 @@ def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules): print "Deploying Spark via git hash; Tachyon won't be set up" modules = filter(lambda x: x != "tachyon", modules) + master_addresses = [get_dns_name(i, opts.private_ips) for i in master_nodes] + slave_addresses = [get_dns_name(i, opts.private_ips) for i in slave_nodes] template_vars = { - "master_list": '\n'.join([i.public_dns_name for i in master_nodes]), + "master_list": '\n'.join(master_addresses), "active_master": active_master, - "slave_list": '\n'.join([i.public_dns_name for i in slave_nodes]), + "slave_list": '\n'.join(slave_addresses), "cluster_url": cluster_url, "hdfs_data_dirs": hdfs_data_dirs, "mapred_local_dirs": mapred_local_dirs, @@ -1011,7 +1019,7 @@ def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules): # # root_dir should be an absolute path. def deploy_user_files(root_dir, opts, master_nodes): - active_master = master_nodes[0].public_dns_name + active_master = get_dns_name(master_nodes[0], opts.private_ips) command = [ 'rsync', '-rv', '-e', stringify_command(ssh_command(opts)), @@ -1122,6 +1130,20 @@ def get_partition(total, num_partitions, current_partitions): return num_slaves_this_zone +# Gets the IP address, taking into account the --private-ips flag +def get_ip_address(instance, private_ips=False): + ip = instance.ip_address if not private_ips else \ + instance.private_ip_address + return ip + + +# Gets the DNS name, taking into account the --private-ips flag +def get_dns_name(instance, private_ips=False): + dns = instance.public_dns_name if not private_ips else \ + instance.private_ip_address + return dns + + def real_main(): (opts, action, cluster_name) = parse_args() @@ -1230,7 +1252,7 @@ def real_main(): if any(master_nodes + slave_nodes): print "The following instances will be terminated:" for inst in master_nodes + slave_nodes: - print "> %s" % inst.public_dns_name + print "> %s" % get_dns_name(inst, opts.private_ips) print "ALL DATA ON ALL NODES WILL BE LOST!!" msg = "Are you sure you want to destroy the cluster {c}? (y/N) ".format(c=cluster_name) @@ -1294,13 +1316,17 @@ def real_main(): elif action == "login": (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) - master = master_nodes[0].public_dns_name - print "Logging into master " + master + "..." - proxy_opt = [] - if opts.proxy_port is not None: - proxy_opt = ['-D', opts.proxy_port] - subprocess.check_call( - ssh_command(opts) + proxy_opt + ['-t', '-t', "%s@%s" % (opts.user, master)]) + if not master_nodes[0].public_dns_name and not opts.private_ips: + print "Master has no public DNS name. Maybe you meant to specify " \ + "--private-ips?" + else: + master = get_dns_name(master_nodes[0], opts.private_ips) + print "Logging into master " + master + "..." + proxy_opt = [] + if opts.proxy_port is not None: + proxy_opt = ['-D', opts.proxy_port] + subprocess.check_call( + ssh_command(opts) + proxy_opt + ['-t', '-t', "%s@%s" % (opts.user, master)]) elif action == "reboot-slaves": response = raw_input( @@ -1318,7 +1344,11 @@ def real_main(): elif action == "get-master": (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) - print master_nodes[0].public_dns_name + if not master_nodes[0].public_dns_name and not opts.private_ips: + print "Master has no public DNS name. Maybe you meant to specify " \ + "--private-ips?" + else: + print get_dns_name(master_nodes[0], opts.private_ips) elif action == "stop": response = raw_input( diff --git a/examples/pom.xml b/examples/pom.xml index 7e93f0eec0b91..afd7c6d52f0dd 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -90,6 +90,12 @@ org.apache.spark spark-streaming-zeromq_${scala.binary.version} ${project.version} + + + org.spark-project.protobuf + protobuf-java + + org.apache.hbase @@ -234,6 +240,7 @@ org.apache.commons commons-math3 + provided com.twitter @@ -262,6 +269,22 @@ com.ning compress-lzf + + commons-cli + commons-cli + + + commons-codec + commons-codec + + + commons-lang + commons-lang + + + commons-logging + commons-logging + io.netty netty @@ -270,10 +293,22 @@ jline jline + + net.jpountz.lz4 + lz4 + org.apache.cassandra.deps avro + + org.apache.commons + commons-math3 + + + org.apache.thrift + libthrift + @@ -281,6 +316,17 @@ scopt_${scala.binary.version} 3.2.0 + + + + org.scala-lang + scala-library + provided + + @@ -322,12 +368,6 @@ - - - org.apache.commons.math3 - org.spark-project.commons.math3 - - diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java index dee794840a3e1..8159ffbe2d269 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java @@ -55,7 +55,7 @@ public void setAge(int age) { public static void main(String[] args) throws Exception { SparkConf sparkConf = new SparkConf().setAppName("JavaSparkSQL"); JavaSparkContext ctx = new JavaSparkContext(sparkConf); - SQLContext sqlCtx = new SQLContext(ctx); + SQLContext sqlContext = new SQLContext(ctx); System.out.println("=== Data source: RDD ==="); // Load a text file and convert each line to a Java Bean. @@ -74,11 +74,11 @@ public Person call(String line) { }); // Apply a schema to an RDD of Java Beans and register it as a table. - DataFrame schemaPeople = sqlCtx.createDataFrame(people, Person.class); + DataFrame schemaPeople = sqlContext.createDataFrame(people, Person.class); schemaPeople.registerTempTable("people"); // SQL can be run over RDDs that have been registered as tables. - DataFrame teenagers = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); + DataFrame teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); // The results of SQL queries are DataFrames and support all the normal RDD operations. // The columns of a row in the result can be accessed by ordinal. @@ -99,12 +99,12 @@ public String call(Row row) { // Read in the parquet file created above. // Parquet files are self-describing so the schema is preserved. // The result of loading a parquet file is also a DataFrame. - DataFrame parquetFile = sqlCtx.parquetFile("people.parquet"); + DataFrame parquetFile = sqlContext.parquetFile("people.parquet"); //Parquet files can also be registered as tables and then used in SQL statements. parquetFile.registerTempTable("parquetFile"); DataFrame teenagers2 = - sqlCtx.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); + sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); teenagerNames = teenagers2.toJavaRDD().map(new Function() { @Override public String call(Row row) { @@ -120,7 +120,7 @@ public String call(Row row) { // The path can be either a single text file or a directory storing text files. String path = "examples/src/main/resources/people.json"; // Create a DataFrame from the file(s) pointed by path - DataFrame peopleFromJsonFile = sqlCtx.jsonFile(path); + DataFrame peopleFromJsonFile = sqlContext.jsonFile(path); // Because the schema of a JSON dataset is automatically inferred, to write queries, // it is better to take a look at what is the schema. @@ -133,8 +133,8 @@ public String call(Row row) { // Register this DataFrame as a table. peopleFromJsonFile.registerTempTable("people"); - // SQL statements can be run by using the sql methods provided by sqlCtx. - DataFrame teenagers3 = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); + // SQL statements can be run by using the sql methods provided by sqlContext. + DataFrame teenagers3 = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); // The results of SQL queries are DataFrame and support all the normal RDD operations. // The columns of a row in the result can be accessed by ordinal. @@ -151,7 +151,7 @@ public String call(Row row) { List jsonData = Arrays.asList( "{\"name\":\"Yin\",\"address\":{\"city\":\"Columbus\",\"state\":\"Ohio\"}}"); JavaRDD anotherPeopleRDD = ctx.parallelize(jsonData); - DataFrame peopleFromJsonRDD = sqlCtx.jsonRDD(anotherPeopleRDD.rdd()); + DataFrame peopleFromJsonRDD = sqlContext.jsonRDD(anotherPeopleRDD.rdd()); // Take a look at the schema of this new DataFrame. peopleFromJsonRDD.printSchema(); @@ -164,7 +164,7 @@ public String call(Row row) { peopleFromJsonRDD.registerTempTable("people2"); - DataFrame peopleWithCity = sqlCtx.sql("SELECT name, address.city FROM people2"); + DataFrame peopleWithCity = sqlContext.sql("SELECT name, address.city FROM people2"); List nameAndCity = peopleWithCity.toJavaRDD().map(new Function() { @Override public String call(Row row) { diff --git a/examples/src/main/python/ml/simple_text_classification_pipeline.py b/examples/src/main/python/ml/simple_text_classification_pipeline.py index d281f4fa44282..c73edb7fd6b20 100644 --- a/examples/src/main/python/ml/simple_text_classification_pipeline.py +++ b/examples/src/main/python/ml/simple_text_classification_pipeline.py @@ -33,7 +33,7 @@ if __name__ == "__main__": sc = SparkContext(appName="SimpleTextClassificationPipeline") - sqlCtx = SQLContext(sc) + sqlContext = SQLContext(sc) # Prepare training documents, which are labeled. LabeledDocument = Row("id", "text", "label") diff --git a/examples/src/main/python/mllib/dataset_example.py b/examples/src/main/python/mllib/dataset_example.py index b5a70db2b9a3c..fcbf56cbf0c52 100644 --- a/examples/src/main/python/mllib/dataset_example.py +++ b/examples/src/main/python/mllib/dataset_example.py @@ -44,19 +44,19 @@ def summarize(dataset): print >> sys.stderr, "Usage: dataset_example.py " exit(-1) sc = SparkContext(appName="DatasetExample") - sqlCtx = SQLContext(sc) + sqlContext = SQLContext(sc) if len(sys.argv) == 2: input = sys.argv[1] else: input = "data/mllib/sample_libsvm_data.txt" points = MLUtils.loadLibSVMFile(sc, input) - dataset0 = sqlCtx.inferSchema(points).setName("dataset0").cache() + dataset0 = sqlContext.inferSchema(points).setName("dataset0").cache() summarize(dataset0) tempdir = tempfile.NamedTemporaryFile(delete=False).name os.unlink(tempdir) print "Save dataset as a Parquet file to %s." % tempdir dataset0.saveAsParquetFile(tempdir) print "Load it back and summarize it again." - dataset1 = sqlCtx.parquetFile(tempdir).setName("dataset1").cache() + dataset1 = sqlContext.parquetFile(tempdir).setName("dataset1").cache() summarize(dataset1) shutil.rmtree(tempdir) diff --git a/examples/src/main/r/kmeans.R b/examples/src/main/r/kmeans.R new file mode 100644 index 0000000000000..6e6b5cb93789c --- /dev/null +++ b/examples/src/main/r/kmeans.R @@ -0,0 +1,93 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(SparkR) + +# Logistic regression in Spark. +# Note: unlike the example in Scala, a point here is represented as a vector of +# doubles. + +parseVectors <- function(lines) { + lines <- strsplit(as.character(lines) , " ", fixed = TRUE) + list(matrix(as.numeric(unlist(lines)), ncol = length(lines[[1]]))) +} + +dist.fun <- function(P, C) { + apply( + C, + 1, + function(x) { + colSums((t(P) - x)^2) + } + ) +} + +closestPoint <- function(P, C) { + max.col(-dist.fun(P, C)) +} +# Main program + +args <- commandArgs(trailing = TRUE) + +if (length(args) != 3) { + print("Usage: kmeans ") + q("no") +} + +sc <- sparkR.init(appName = "RKMeans") +K <- as.integer(args[[2]]) +convergeDist <- as.double(args[[3]]) + +lines <- textFile(sc, args[[1]]) +points <- cache(lapplyPartition(lines, parseVectors)) +# kPoints <- take(points, K) +kPoints <- do.call(rbind, takeSample(points, FALSE, K, 16189L)) +tempDist <- 1.0 + +while (tempDist > convergeDist) { + closest <- lapplyPartition( + lapply(points, + function(p) { + cp <- closestPoint(p, kPoints); + mapply(list, unique(cp), split.data.frame(cbind(1, p), cp), SIMPLIFY=FALSE) + }), + function(x) {do.call(c, x) + }) + + pointStats <- reduceByKey(closest, + function(p1, p2) { + t(colSums(rbind(p1, p2))) + }, + 2L) + + newPoints <- do.call( + rbind, + collect(lapply(pointStats, + function(tup) { + point.sum <- tup[[2]][, -1] + point.count <- tup[[2]][, 1] + point.sum/point.count + }))) + + D <- dist.fun(kPoints, newPoints) + tempDist <- sum(D[cbind(1:3, max.col(-D))]) + kPoints <- newPoints + cat("Finished iteration (delta = ", tempDist, ")\n") +} + +cat("Final centers:\n") +writeLines(unlist(lapply(kPoints, paste, collapse = " "))) diff --git a/examples/src/main/r/linear_solver_mnist.R b/examples/src/main/r/linear_solver_mnist.R new file mode 100644 index 0000000000000..c864a4232d010 --- /dev/null +++ b/examples/src/main/r/linear_solver_mnist.R @@ -0,0 +1,107 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Instructions: https://github.com/amplab-extras/SparkR-pkg/wiki/SparkR-Example:-Digit-Recognition-on-EC2 + +library(SparkR) +library(Matrix) + +args <- commandArgs(trailing = TRUE) + +# number of random features; default to 1100 +D <- ifelse(length(args) > 0, as.integer(args[[1]]), 1100) +# number of partitions for training dataset +trainParts <- 12 +# dimension of digits +d <- 784 +# number of test examples +NTrain <- 60000 +# number of training examples +NTest <- 10000 +# scale of features +gamma <- 4e-4 + +sc <- sparkR.init(appName = "SparkR-LinearSolver") + +# You can also use HDFS path to speed things up: +# hdfs:///train-mnist-dense-with-labels.data +file <- textFile(sc, "/data/train-mnist-dense-with-labels.data", trainParts) + +W <- gamma * matrix(nrow=D, ncol=d, data=rnorm(D*d)) +b <- 2 * pi * matrix(nrow=D, ncol=1, data=runif(D)) +broadcastW <- broadcast(sc, W) +broadcastB <- broadcast(sc, b) + +includePackage(sc, Matrix) +numericLines <- lapplyPartitionsWithIndex(file, + function(split, part) { + matList <- sapply(part, function(line) { + as.numeric(strsplit(line, ",", fixed=TRUE)[[1]]) + }, simplify=FALSE) + mat <- Matrix(ncol=d+1, data=unlist(matList, F, F), + sparse=T, byrow=T) + mat + }) + +featureLabels <- cache(lapplyPartition( + numericLines, + function(part) { + label <- part[,1] + mat <- part[,-1] + ones <- rep(1, nrow(mat)) + features <- cos( + mat %*% t(value(broadcastW)) + (matrix(ncol=1, data=ones) %*% t(value(broadcastB)))) + onesMat <- Matrix(ones) + featuresPlus <- cBind(features, onesMat) + labels <- matrix(nrow=nrow(mat), ncol=10, data=-1) + for (i in 1:nrow(mat)) { + labels[i, label[i]] <- 1 + } + list(label=labels, features=featuresPlus) + })) + +FTF <- Reduce("+", collect(lapplyPartition(featureLabels, + function(part) { + t(part$features) %*% part$features + }), flatten=F)) + +FTY <- Reduce("+", collect(lapplyPartition(featureLabels, + function(part) { + t(part$features) %*% part$label + }), flatten=F)) + +# solve for the coefficient matrix +C <- solve(FTF, FTY) + +test <- Matrix(as.matrix(read.csv("/data/test-mnist-dense-with-labels.data", + header=F), sparse=T)) +testData <- test[,-1] +testLabels <- matrix(ncol=1, test[,1]) + +err <- 0 + +# contstruct the feature maps for all examples from this digit +featuresTest <- cos(testData %*% t(value(broadcastW)) + + (matrix(ncol=1, data=rep(1, NTest)) %*% t(value(broadcastB)))) +featuresTest <- cBind(featuresTest, Matrix(rep(1, NTest))) + +# extract the one vs. all assignment +results <- featuresTest %*% C +labelsGot <- apply(results, 1, which.max) +err <- sum(testLabels != labelsGot) / nrow(testLabels) + +cat("\nFinished running. The error rate is: ", err, ".\n") diff --git a/examples/src/main/r/logistic_regression.R b/examples/src/main/r/logistic_regression.R new file mode 100644 index 0000000000000..2a86aa98160d3 --- /dev/null +++ b/examples/src/main/r/logistic_regression.R @@ -0,0 +1,62 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(SparkR) + +args <- commandArgs(trailing = TRUE) + +if (length(args) != 3) { + print("Usage: logistic_regression ") + q("no") +} + +# Initialize Spark context +sc <- sparkR.init(appName = "LogisticRegressionR") +iterations <- as.integer(args[[2]]) +D <- as.integer(args[[3]]) + +readPartition <- function(part){ + part = strsplit(part, " ", fixed = T) + list(matrix(as.numeric(unlist(part)), ncol = length(part[[1]]))) +} + +# Read data points and convert each partition to a matrix +points <- cache(lapplyPartition(textFile(sc, args[[1]]), readPartition)) + +# Initialize w to a random value +w <- runif(n=D, min = -1, max = 1) +cat("Initial w: ", w, "\n") + +# Compute logistic regression gradient for a matrix of data points +gradient <- function(partition) { + partition = partition[[1]] + Y <- partition[, 1] # point labels (first column of input file) + X <- partition[, -1] # point coordinates + + # For each point (x, y), compute gradient function + dot <- X %*% w + logit <- 1 / (1 + exp(-Y * dot)) + grad <- t(X) %*% ((logit - 1) * Y) + list(grad) +} + +for (i in 1:iterations) { + cat("On iteration ", i, "\n") + w <- w - reduce(lapplyPartition(points, gradient), "+") +} + +cat("Final w: ", w, "\n") diff --git a/examples/src/main/r/pi.R b/examples/src/main/r/pi.R new file mode 100644 index 0000000000000..aa7a833e147a0 --- /dev/null +++ b/examples/src/main/r/pi.R @@ -0,0 +1,46 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(SparkR) + +args <- commandArgs(trailing = TRUE) + +sc <- sparkR.init(appName = "PiR") + +slices <- ifelse(length(args) > 1, as.integer(args[[2]]), 2) + +n <- 100000 * slices + +piFunc <- function(elem) { + rands <- runif(n = 2, min = -1, max = 1) + val <- ifelse((rands[1]^2 + rands[2]^2) < 1, 1.0, 0.0) + val +} + + +piFuncVec <- function(elems) { + message(length(elems)) + rands1 <- runif(n = length(elems), min = -1, max = 1) + rands2 <- runif(n = length(elems), min = -1, max = 1) + val <- ifelse((rands1^2 + rands2^2) < 1, 1.0, 0.0) + sum(val) +} + +rdd <- parallelize(sc, 1:n, slices) +count <- reduce(lapplyPartition(rdd, piFuncVec), sum) +cat("Pi is roughly", 4.0 * count / n, "\n") +cat("Num elements in RDD ", count(rdd), "\n") diff --git a/examples/src/main/r/wordcount.R b/examples/src/main/r/wordcount.R new file mode 100644 index 0000000000000..b734cb0ecf55b --- /dev/null +++ b/examples/src/main/r/wordcount.R @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(SparkR) + +args <- commandArgs(trailing = TRUE) + +if (length(args) != 1) { + print("Usage: wordcount ") + q("no") +} + +# Initialize Spark context +sc <- sparkR.init(appName = "RwordCount") +lines <- textFile(sc, args[[1]]) + +words <- flatMap(lines, + function(line) { + strsplit(line, " ")[[1]] + }) +wordCount <- lapply(words, function(word) { list(word, 1L) }) + +counts <- reduceByKey(wordCount, "+", 2L) +output <- collect(counts) + +for (wordcount in output) { + cat(wordcount[[1]], ": ", wordcount[[2]], "\n") +} diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala index 17624c20cff3d..f73eac1e2b906 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala @@ -40,8 +40,8 @@ object LocalKMeans { val convergeDist = 0.001 val rand = new Random(42) - def generateData = { - def generatePoint(i: Int) = { + def generateData: Array[DenseVector[Double]] = { + def generatePoint(i: Int): DenseVector[Double] = { DenseVector.fill(D){rand.nextDouble * R} } Array.tabulate(N)(generatePoint) diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala index 92a683ad57ea1..a55e0dc8d36c2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala @@ -37,8 +37,8 @@ object LocalLR { case class DataPoint(x: Vector[Double], y: Double) - def generateData = { - def generatePoint(i: Int) = { + def generateData: Array[DataPoint] = { + def generatePoint(i: Int): DataPoint = { val y = if(i % 2 == 0) -1 else 1 val x = DenseVector.fill(D){rand.nextGaussian + y * R} DataPoint(x, y) diff --git a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala index 74620ad007d83..32e02eab8b031 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala @@ -54,8 +54,8 @@ object LogQuery { // scalastyle:on /** Tracks the total query count and number of aggregate bytes for a particular group. */ class Stats(val count: Int, val numBytes: Int) extends Serializable { - def merge(other: Stats) = new Stats(count + other.count, numBytes + other.numBytes) - override def toString = "bytes=%s\tn=%s".format(numBytes, count) + def merge(other: Stats): Stats = new Stats(count + other.count, numBytes + other.numBytes) + override def toString: String = "bytes=%s\tn=%s".format(numBytes, count) } def extractKey(line: String): (String, String, String) = { diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala index 257a7d29f922a..8c01a60844620 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala @@ -42,8 +42,8 @@ object SparkLR { case class DataPoint(x: Vector[Double], y: Double) - def generateData = { - def generatePoint(i: Int) = { + def generateData: Array[DataPoint] = { + def generatePoint(i: Int): DataPoint = { val y = if(i % 2 == 0) -1 else 1 val x = DenseVector.fill(D){rand.nextGaussian + y * R} DataPoint(x, y) diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala index f7f83086df3db..772cd897f5140 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala @@ -31,7 +31,7 @@ object SparkTC { val numVertices = 100 val rand = new Random(42) - def generateGraph = { + def generateGraph: Seq[(Int, Int)] = { val edges: mutable.Set[(Int, Int)] = mutable.Set.empty while (edges.size < numEdges) { val from = rand.nextInt(numVertices) diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala index e322d4ce5a745..ab6e63deb3c95 100644 --- a/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala +++ b/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala @@ -90,7 +90,7 @@ class PRMessage() extends Message[String] with Serializable { } class CustomPartitioner(partitions: Int) extends Partitioner { - def numPartitions = partitions + def numPartitions: Int = partitions def getPartition(key: Any): Int = { val hash = key match { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala index 1f4ca4fbe7778..0bc36ea65e1ab 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala @@ -178,7 +178,9 @@ object MovieLensALS { def computeRmse(model: MatrixFactorizationModel, data: RDD[Rating], implicitPrefs: Boolean) : Double = { - def mapPredictedRating(r: Double) = if (implicitPrefs) math.max(math.min(r, 1.0), 0.0) else r + def mapPredictedRating(r: Double): Double = { + if (implicitPrefs) math.max(math.min(r, 1.0), 0.0) else r + } val predictions: RDD[Rating] = model.predict(data.map(x => (x.user, x.product))) val predictionsAndRatings = predictions.map{ x => diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala index b433082dce1a2..92867b44be138 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala @@ -85,13 +85,13 @@ extends Actor with ActorHelper { lazy private val remotePublisher = context.actorSelection(urlOfPublisher) - override def preStart = remotePublisher ! SubscribeReceiver(context.self) + override def preStart(): Unit = remotePublisher ! SubscribeReceiver(context.self) - def receive = { + def receive: PartialFunction[Any, Unit] = { case msg => store(msg.asInstanceOf[T]) } - override def postStop() = remotePublisher ! UnsubscribeReceiver(context.self) + override def postStop(): Unit = remotePublisher ! UnsubscribeReceiver(context.self) } diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala index c3a05c89d817e..751b30ea15782 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala @@ -55,7 +55,8 @@ import org.apache.spark.util.IntParam */ object RecoverableNetworkWordCount { - def createContext(ip: String, port: Int, outputPath: String, checkpointDirectory: String) = { + def createContext(ip: String, port: Int, outputPath: String, checkpointDirectory: String) + : StreamingContext = { // If you do not see this printed, that means the StreamingContext has been loaded // from the new checkpoint diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala index 6510c70bd1866..e99d1baa72b9f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala @@ -35,7 +35,7 @@ import org.apache.spark.SparkConf */ object SimpleZeroMQPublisher { - def main(args: Array[String]) = { + def main(args: Array[String]): Unit = { if (args.length < 2) { System.err.println("Usage: SimpleZeroMQPublisher ") System.exit(1) @@ -45,7 +45,7 @@ object SimpleZeroMQPublisher { val acs: ActorSystem = ActorSystem() val pubSocket = ZeroMQExtension(acs).newSocket(SocketType.Pub, Bind(url)) - implicit def stringToByteString(x: String) = ByteString(x) + implicit def stringToByteString(x: String): ByteString = ByteString(x) val messages: List[ByteString] = List("words ", "may ", "count ") while (true) { Thread.sleep(1000) @@ -86,7 +86,7 @@ object ZeroMQWordCount { // Create the context and set the batch size val ssc = new StreamingContext(sparkConf, Seconds(2)) - def bytesToStringIterator(x: Seq[ByteString]) = (x.map(_.utf8String)).iterator + def bytesToStringIterator(x: Seq[ByteString]): Iterator[String] = x.map(_.utf8String).iterator // For this stream, a zeroMQ publisher should be running. val lines = ZeroMQUtils.createStream(ssc, url, Subscribe(topic), bytesToStringIterator _) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala index 8402491b62671..54d996b8ac990 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala @@ -94,7 +94,7 @@ object PageViewGenerator { while (true) { val socket = listener.accept() new Thread() { - override def run = { + override def run(): Unit = { println("Got client connected from: " + socket.getInetAddress) val out = new PrintWriter(socket.getOutputStream(), true) diff --git a/external/flume-sink/src/test/resources/log4j.properties b/external/flume-sink/src/test/resources/log4j.properties index 2a58e99817224..42df8792f147f 100644 --- a/external/flume-sink/src/test/resources/log4j.properties +++ b/external/flume-sink/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala index 2de2a7926bfd1..60e2994431b38 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala @@ -37,8 +37,7 @@ import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.receiver.Receiver -import org.jboss.netty.channel.ChannelPipelineFactory -import org.jboss.netty.channel.Channels +import org.jboss.netty.channel.{ChannelPipeline, ChannelPipelineFactory, Channels} import org.jboss.netty.channel.socket.nio.NioServerSocketChannelFactory import org.jboss.netty.handler.codec.compression._ @@ -187,8 +186,8 @@ class FlumeReceiver( logInfo("Flume receiver stopped") } - override def preferredLocation = Some(host) - + override def preferredLocation: Option[String] = Option(host) + /** A Netty Pipeline factory that will decompress incoming data from * and the Netty client and compress data going back to the client. * @@ -198,13 +197,12 @@ class FlumeReceiver( */ private[streaming] class CompressionChannelPipelineFactory extends ChannelPipelineFactory { - - def getPipeline() = { + def getPipeline(): ChannelPipeline = { val pipeline = Channels.pipeline() val encoder = new ZlibEncoder(6) pipeline.addFirst("deflater", encoder) pipeline.addFirst("inflater", new ZlibDecoder()) pipeline + } } } -} diff --git a/external/flume/src/test/resources/log4j.properties b/external/flume/src/test/resources/log4j.properties index 9697237bfa1a3..75e3b53a093f6 100644 --- a/external/flume/src/test/resources/log4j.properties +++ b/external/flume/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala index e04d4088df7dc..2edea9b5b69ba 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala @@ -1,21 +1,20 @@ /* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + package org.apache.spark.streaming.flume import java.net.InetSocketAddress @@ -213,7 +212,7 @@ class FlumePollingStreamSuite extends FunSuite with BeforeAndAfter with Logging assert(counter === totalEventsPerChannel * channels.size) } - def assertChannelIsEmpty(channel: MemoryChannel) = { + def assertChannelIsEmpty(channel: MemoryChannel): Unit = { val queueRemaining = channel.getClass.getDeclaredField("queueRemaining") queueRemaining.setAccessible(true) val m = queueRemaining.get(channel).getClass.getDeclaredMethod("availablePermits") diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala index 51d273af8da84..39e6754c81dbf 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala @@ -151,7 +151,9 @@ class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with L } /** Class to create socket channel with compression */ - private class CompressionChannelFactory(compressionLevel: Int) extends NioClientSocketChannelFactory { + private class CompressionChannelFactory(compressionLevel: Int) + extends NioClientSocketChannelFactory { + override def newChannel(pipeline: ChannelPipeline): SocketChannel = { val encoder = new ZlibEncoder(compressionLevel) pipeline.addFirst("deflater", encoder) diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala index 04e65cb3d708c..1b1fc8051d052 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala @@ -129,8 +129,9 @@ class DirectKafkaInputDStream[ private[streaming] class DirectKafkaInputDStreamCheckpointData extends DStreamCheckpointData(this) { - def batchForTime = data.asInstanceOf[mutable.HashMap[ - Time, Array[OffsetRange.OffsetRangeTuple]]] + def batchForTime: mutable.HashMap[Time, Array[(String, Int, Long, Long)]] = { + data.asInstanceOf[mutable.HashMap[Time, Array[OffsetRange.OffsetRangeTuple]]] + } override def update(time: Time) { batchForTime.clear() diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala index 2f7e0ab39fefd..bd767031c1849 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala @@ -123,9 +123,17 @@ class KafkaCluster(val kafkaParams: Map[String, String]) extends Serializable { val errs = new Err withBrokers(Random.shuffle(config.seedBrokers), errs) { consumer => val resp: TopicMetadataResponse = consumer.send(req) - // error codes here indicate missing / just created topic, - // repeating on a different broker wont be useful - return Right(resp.topicsMetadata.toSet) + val respErrs = resp.topicsMetadata.filter(m => m.errorCode != ErrorMapping.NoError) + + if (respErrs.isEmpty) { + return Right(resp.topicsMetadata.toSet) + } else { + respErrs.foreach { m => + val cause = ErrorMapping.exceptionFor(m.errorCode) + val msg = s"Error getting partition metadata for '${m.topic}'. Does the topic exist?" + errs.append(new SparkException(msg, cause)) + } + } } Left(errs) } diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala index 6d465bcb6bfc0..a0b8a0c565210 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala @@ -86,7 +86,7 @@ class KafkaRDD[ val part = thePart.asInstanceOf[KafkaRDDPartition] assert(part.fromOffset <= part.untilOffset, errBeginAfterEnd(part)) if (part.fromOffset == part.untilOffset) { - log.warn(s"Beginning offset ${part.fromOffset} is the same as ending offset " + + log.info(s"Beginning offset ${part.fromOffset} is the same as ending offset " + s"skipping ${part.topic} ${part.partition}") Iterator.empty } else { @@ -155,7 +155,7 @@ class KafkaRDD[ .dropWhile(_.offset < requestOffset) } - override def close() = consumer.close() + override def close(): Unit = consumer.close() override def getNext(): R = { if (iter == null || !iter.hasNext) { @@ -207,7 +207,7 @@ object KafkaRDD { fromOffsets: Map[TopicAndPartition, Long], untilOffsets: Map[TopicAndPartition, LeaderOffset], messageHandler: MessageAndMetadata[K, V] => R - ): KafkaRDD[K, V, U, T, R] = { + ): KafkaRDD[K, V, U, T, R] = { val leaders = untilOffsets.map { case (tp, lo) => tp -> (lo.host, lo.port) }.toMap diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala new file mode 100644 index 0000000000000..13e9475065979 --- /dev/null +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala @@ -0,0 +1,261 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka + +import java.io.File +import java.lang.{Integer => JInt} +import java.net.InetSocketAddress +import java.util.{Map => JMap} +import java.util.Properties +import java.util.concurrent.TimeoutException + +import scala.annotation.tailrec +import scala.language.postfixOps +import scala.util.control.NonFatal + +import kafka.admin.AdminUtils +import kafka.producer.{KeyedMessage, Producer, ProducerConfig} +import kafka.serializer.StringEncoder +import kafka.server.{KafkaConfig, KafkaServer} +import kafka.utils.ZKStringSerializer +import org.apache.zookeeper.server.{NIOServerCnxnFactory, ZooKeeperServer} +import org.I0Itec.zkclient.ZkClient + +import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.streaming.Time +import org.apache.spark.util.Utils + +/** + * This is a helper class for Kafka test suites. This has the functionality to set up + * and tear down local Kafka servers, and to push data using Kafka producers. + * + * The reason to put Kafka test utility class in src is to test Python related Kafka APIs. + */ +private class KafkaTestUtils extends Logging { + + // Zookeeper related configurations + private val zkHost = "localhost" + private var zkPort: Int = 0 + private val zkConnectionTimeout = 6000 + private val zkSessionTimeout = 6000 + + private var zookeeper: EmbeddedZookeeper = _ + + private var zkClient: ZkClient = _ + + // Kafka broker related configurations + private val brokerHost = "localhost" + private var brokerPort = 9092 + private var brokerConf: KafkaConfig = _ + + // Kafka broker server + private var server: KafkaServer = _ + + // Kafka producer + private var producer: Producer[String, String] = _ + + // Flag to test whether the system is correctly started + private var zkReady = false + private var brokerReady = false + + def zkAddress: String = { + assert(zkReady, "Zookeeper not setup yet or already torn down, cannot get zookeeper address") + s"$zkHost:$zkPort" + } + + def brokerAddress: String = { + assert(brokerReady, "Kafka not setup yet or already torn down, cannot get broker address") + s"$brokerHost:$brokerPort" + } + + def zookeeperClient: ZkClient = { + assert(zkReady, "Zookeeper not setup yet or already torn down, cannot get zookeeper client") + Option(zkClient).getOrElse( + throw new IllegalStateException("Zookeeper client is not yet initialized")) + } + + // Set up the Embedded Zookeeper server and get the proper Zookeeper port + private def setupEmbeddedZookeeper(): Unit = { + // Zookeeper server startup + zookeeper = new EmbeddedZookeeper(s"$zkHost:$zkPort") + // Get the actual zookeeper binding port + zkPort = zookeeper.actualPort + zkClient = new ZkClient(s"$zkHost:$zkPort", zkSessionTimeout, zkConnectionTimeout, + ZKStringSerializer) + zkReady = true + } + + // Set up the Embedded Kafka server + private def setupEmbeddedKafkaServer(): Unit = { + assert(zkReady, "Zookeeper should be set up beforehand") + + // Kafka broker startup + Utils.startServiceOnPort(brokerPort, port => { + brokerPort = port + brokerConf = new KafkaConfig(brokerConfiguration) + server = new KafkaServer(brokerConf) + server.startup() + (server, port) + }, new SparkConf(), "KafkaBroker") + + brokerReady = true + } + + /** setup the whole embedded servers, including Zookeeper and Kafka brokers */ + def setup(): Unit = { + setupEmbeddedZookeeper() + setupEmbeddedKafkaServer() + } + + /** Teardown the whole servers, including Kafka broker and Zookeeper */ + def teardown(): Unit = { + brokerReady = false + zkReady = false + + if (producer != null) { + producer.close() + producer = null + } + + if (server != null) { + server.shutdown() + server = null + } + + brokerConf.logDirs.foreach { f => Utils.deleteRecursively(new File(f)) } + + if (zkClient != null) { + zkClient.close() + zkClient = null + } + + if (zookeeper != null) { + zookeeper.shutdown() + zookeeper = null + } + } + + /** Create a Kafka topic and wait until it propagated to the whole cluster */ + def createTopic(topic: String): Unit = { + AdminUtils.createTopic(zkClient, topic, 1, 1) + // wait until metadata is propagated + waitUntilMetadataIsPropagated(topic, 0) + } + + /** Java-friendly function for sending messages to the Kafka broker */ + def sendMessages(topic: String, messageToFreq: JMap[String, JInt]): Unit = { + import scala.collection.JavaConversions._ + sendMessages(topic, Map(messageToFreq.mapValues(_.intValue()).toSeq: _*)) + } + + /** Send the messages to the Kafka broker */ + def sendMessages(topic: String, messageToFreq: Map[String, Int]): Unit = { + val messages = messageToFreq.flatMap { case (s, freq) => Seq.fill(freq)(s) }.toArray + sendMessages(topic, messages) + } + + /** Send the array of messages to the Kafka broker */ + def sendMessages(topic: String, messages: Array[String]): Unit = { + producer = new Producer[String, String](new ProducerConfig(producerConfiguration)) + producer.send(messages.map { new KeyedMessage[String, String](topic, _ ) }: _*) + producer.close() + producer = null + } + + private def brokerConfiguration: Properties = { + val props = new Properties() + props.put("broker.id", "0") + props.put("host.name", "localhost") + props.put("port", brokerPort.toString) + props.put("log.dir", Utils.createTempDir().getAbsolutePath) + props.put("zookeeper.connect", zkAddress) + props.put("log.flush.interval.messages", "1") + props.put("replica.socket.timeout.ms", "1500") + props + } + + private def producerConfiguration: Properties = { + val props = new Properties() + props.put("metadata.broker.list", brokerAddress) + props.put("serializer.class", classOf[StringEncoder].getName) + props + } + + // A simplified version of scalatest eventually, rewritten here to avoid adding extra test + // dependency + def eventually[T](timeout: Time, interval: Time)(func: => T): T = { + def makeAttempt(): Either[Throwable, T] = { + try { + Right(func) + } catch { + case e if NonFatal(e) => Left(e) + } + } + + val startTime = System.currentTimeMillis() + @tailrec + def tryAgain(attempt: Int): T = { + makeAttempt() match { + case Right(result) => result + case Left(e) => + val duration = System.currentTimeMillis() - startTime + if (duration < timeout.milliseconds) { + Thread.sleep(interval.milliseconds) + } else { + throw new TimeoutException(e.getMessage) + } + + tryAgain(attempt + 1) + } + } + + tryAgain(1) + } + + private def waitUntilMetadataIsPropagated(topic: String, partition: Int): Unit = { + eventually(Time(10000), Time(100)) { + assert( + server.apis.metadataCache.containsTopicAndPartition(topic, partition), + s"Partition [$topic, $partition] metadata not propagated after timeout" + ) + } + } + + private class EmbeddedZookeeper(val zkConnect: String) { + val snapshotDir = Utils.createTempDir() + val logDir = Utils.createTempDir() + + val zookeeper = new ZooKeeperServer(snapshotDir, logDir, 500) + val (ip, port) = { + val splits = zkConnect.split(":") + (splits(0), splits(1).toInt) + } + val factory = new NIOServerCnxnFactory() + factory.configure(new InetSocketAddress(ip, port), 16) + factory.startup(zookeeper) + + val actualPort = factory.getLocalPort + + def shutdown() { + factory.shutdown() + Utils.deleteRecursively(snapshotDir) + Utils.deleteRecursively(logDir) + } + } +} + diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java index d6ca6d58b5665..4c1d6a03eb2b8 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java @@ -41,24 +41,28 @@ public class JavaDirectKafkaStreamSuite implements Serializable { private transient JavaStreamingContext ssc = null; - private transient KafkaStreamSuiteBase suiteBase = null; + private transient KafkaTestUtils kafkaTestUtils = null; @Before public void setUp() { - suiteBase = new KafkaStreamSuiteBase() { }; - suiteBase.setupKafka(); - System.clearProperty("spark.driver.port"); - SparkConf sparkConf = new SparkConf() - .setMaster("local[4]").setAppName(this.getClass().getSimpleName()); - ssc = new JavaStreamingContext(sparkConf, Durations.milliseconds(200)); + kafkaTestUtils = new KafkaTestUtils(); + kafkaTestUtils.setup(); + SparkConf sparkConf = new SparkConf() + .setMaster("local[4]").setAppName(this.getClass().getSimpleName()); + ssc = new JavaStreamingContext(sparkConf, Durations.milliseconds(200)); } @After public void tearDown() { + if (ssc != null) { ssc.stop(); ssc = null; - System.clearProperty("spark.driver.port"); - suiteBase.tearDownKafka(); + } + + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown(); + kafkaTestUtils = null; + } } @Test @@ -74,7 +78,7 @@ public void testKafkaStream() throws InterruptedException { sent.addAll(Arrays.asList(topic2data)); HashMap kafkaParams = new HashMap(); - kafkaParams.put("metadata.broker.list", suiteBase.brokerAddress()); + kafkaParams.put("metadata.broker.list", kafkaTestUtils.brokerAddress()); kafkaParams.put("auto.offset.reset", "smallest"); JavaDStream stream1 = KafkaUtils.createDirectStream( @@ -147,8 +151,8 @@ private HashMap topicOffsetToMap(String topic, Long off private String[] createTopicAndSendData(String topic) { String[] data = { topic + "-1", topic + "-2", topic + "-3"}; - suiteBase.createTopic(topic); - suiteBase.sendMessages(topic, data); + kafkaTestUtils.createTopic(topic); + kafkaTestUtils.sendMessages(topic, data); return data; } } diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java index 4477b81827c70..a9dc6e50613ca 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java @@ -37,13 +37,12 @@ public class JavaKafkaRDDSuite implements Serializable { private transient JavaSparkContext sc = null; - private transient KafkaStreamSuiteBase suiteBase = null; + private transient KafkaTestUtils kafkaTestUtils = null; @Before public void setUp() { - suiteBase = new KafkaStreamSuiteBase() { }; - suiteBase.setupKafka(); - System.clearProperty("spark.driver.port"); + kafkaTestUtils = new KafkaTestUtils(); + kafkaTestUtils.setup(); SparkConf sparkConf = new SparkConf() .setMaster("local[4]").setAppName(this.getClass().getSimpleName()); sc = new JavaSparkContext(sparkConf); @@ -51,10 +50,15 @@ public void setUp() { @After public void tearDown() { - sc.stop(); - sc = null; - System.clearProperty("spark.driver.port"); - suiteBase.tearDownKafka(); + if (sc != null) { + sc.stop(); + sc = null; + } + + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown(); + kafkaTestUtils = null; + } } @Test @@ -66,7 +70,7 @@ public void testKafkaRDD() throws InterruptedException { String[] topic2data = createTopicAndSendData(topic2); HashMap kafkaParams = new HashMap(); - kafkaParams.put("metadata.broker.list", suiteBase.brokerAddress()); + kafkaParams.put("metadata.broker.list", kafkaTestUtils.brokerAddress()); OffsetRange[] offsetRanges = { OffsetRange.create(topic1, 0, 0, 1), @@ -75,7 +79,7 @@ public void testKafkaRDD() throws InterruptedException { HashMap emptyLeaders = new HashMap(); HashMap leaders = new HashMap(); - String[] hostAndPort = suiteBase.brokerAddress().split(":"); + String[] hostAndPort = kafkaTestUtils.brokerAddress().split(":"); Broker broker = Broker.create(hostAndPort[0], Integer.parseInt(hostAndPort[1])); leaders.put(new TopicAndPartition(topic1, 0), broker); leaders.put(new TopicAndPartition(topic2, 0), broker); @@ -144,8 +148,8 @@ public String call(MessageAndMetadata msgAndMd) throws Exception private String[] createTopicAndSendData(String topic) { String[] data = { topic + "-1", topic + "-2", topic + "-3"}; - suiteBase.createTopic(topic); - suiteBase.sendMessages(topic, data); + kafkaTestUtils.createTopic(topic); + kafkaTestUtils.sendMessages(topic, data); return data; } } diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java index bad0a93eb2e84..540f4ceabab47 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java @@ -22,9 +22,7 @@ import java.util.List; import java.util.Random; -import scala.Predef; import scala.Tuple2; -import scala.collection.JavaConverters; import kafka.serializer.StringDecoder; import org.junit.After; @@ -44,13 +42,12 @@ public class JavaKafkaStreamSuite implements Serializable { private transient JavaStreamingContext ssc = null; private transient Random random = new Random(); - private transient KafkaStreamSuiteBase suiteBase = null; + private transient KafkaTestUtils kafkaTestUtils = null; @Before public void setUp() { - suiteBase = new KafkaStreamSuiteBase() { }; - suiteBase.setupKafka(); - System.clearProperty("spark.driver.port"); + kafkaTestUtils = new KafkaTestUtils(); + kafkaTestUtils.setup(); SparkConf sparkConf = new SparkConf() .setMaster("local[4]").setAppName(this.getClass().getSimpleName()); ssc = new JavaStreamingContext(sparkConf, new Duration(500)); @@ -58,10 +55,15 @@ public void setUp() { @After public void tearDown() { - ssc.stop(); - ssc = null; - System.clearProperty("spark.driver.port"); - suiteBase.tearDownKafka(); + if (ssc != null) { + ssc.stop(); + ssc = null; + } + + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown(); + kafkaTestUtils = null; + } } @Test @@ -75,15 +77,11 @@ public void testKafkaStream() throws InterruptedException { sent.put("b", 3); sent.put("c", 10); - suiteBase.createTopic(topic); - HashMap tmp = new HashMap(sent); - suiteBase.sendMessages(topic, - JavaConverters.mapAsScalaMapConverter(tmp).asScala().toMap( - Predef.>conforms()) - ); + kafkaTestUtils.createTopic(topic); + kafkaTestUtils.sendMessages(topic, sent); HashMap kafkaParams = new HashMap(); - kafkaParams.put("zookeeper.connect", suiteBase.zkAddress()); + kafkaParams.put("zookeeper.connect", kafkaTestUtils.zkAddress()); kafkaParams.put("group.id", "test-consumer-" + random.nextInt(10000)); kafkaParams.put("auto.offset.reset", "smallest"); @@ -126,6 +124,7 @@ public Void call(JavaPairRDD rdd) throws Exception { ); ssc.start(); + long startTime = System.currentTimeMillis(); boolean sizeMatches = false; while (!sizeMatches && System.currentTimeMillis() - startTime < 20000) { @@ -136,6 +135,5 @@ public Void call(JavaPairRDD rdd) throws Exception { for (String k : sent.keySet()) { Assert.assertEquals(sent.get(k).intValue(), result.get(k).intValue()); } - ssc.stop(); } } diff --git a/external/kafka/src/test/resources/log4j.properties b/external/kafka/src/test/resources/log4j.properties index 9697237bfa1a3..75e3b53a093f6 100644 --- a/external/kafka/src/test/resources/log4j.properties +++ b/external/kafka/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala index 17ca9d145d665..415730f5559c5 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala @@ -27,31 +27,41 @@ import scala.language.postfixOps import kafka.common.TopicAndPartition import kafka.message.MessageAndMetadata import kafka.serializer.StringDecoder -import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} import org.scalatest.concurrent.Eventually -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.{Logging, SparkConf, SparkContext} import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{Milliseconds, StreamingContext, Time} import org.apache.spark.streaming.dstream.DStream import org.apache.spark.util.Utils -class DirectKafkaStreamSuite extends KafkaStreamSuiteBase - with BeforeAndAfter with BeforeAndAfterAll with Eventually { +class DirectKafkaStreamSuite + extends FunSuite + with BeforeAndAfter + with BeforeAndAfterAll + with Eventually + with Logging { val sparkConf = new SparkConf() .setMaster("local[4]") .setAppName(this.getClass.getSimpleName) - var sc: SparkContext = _ - var ssc: StreamingContext = _ - var testDir: File = _ + private var sc: SparkContext = _ + private var ssc: StreamingContext = _ + private var testDir: File = _ + + private var kafkaTestUtils: KafkaTestUtils = _ override def beforeAll { - setupKafka() + kafkaTestUtils = new KafkaTestUtils + kafkaTestUtils.setup() } override def afterAll { - tearDownKafka() + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown() + kafkaTestUtils = null + } } after { @@ -72,12 +82,12 @@ class DirectKafkaStreamSuite extends KafkaStreamSuiteBase val topics = Set("basic1", "basic2", "basic3") val data = Map("a" -> 7, "b" -> 9) topics.foreach { t => - createTopic(t) - sendMessages(t, data) + kafkaTestUtils.createTopic(t) + kafkaTestUtils.sendMessages(t, data) } val totalSent = data.values.sum * topics.size val kafkaParams = Map( - "metadata.broker.list" -> s"$brokerAddress", + "metadata.broker.list" -> kafkaTestUtils.brokerAddress, "auto.offset.reset" -> "smallest" ) @@ -121,9 +131,9 @@ class DirectKafkaStreamSuite extends KafkaStreamSuiteBase val topic = "largest" val topicPartition = TopicAndPartition(topic, 0) val data = Map("a" -> 10) - createTopic(topic) + kafkaTestUtils.createTopic(topic) val kafkaParams = Map( - "metadata.broker.list" -> s"$brokerAddress", + "metadata.broker.list" -> kafkaTestUtils.brokerAddress, "auto.offset.reset" -> "largest" ) val kc = new KafkaCluster(kafkaParams) @@ -132,7 +142,7 @@ class DirectKafkaStreamSuite extends KafkaStreamSuiteBase } // Send some initial messages before starting context - sendMessages(topic, data) + kafkaTestUtils.sendMessages(topic, data) eventually(timeout(10 seconds), interval(20 milliseconds)) { assert(getLatestOffset() > 3) } @@ -154,7 +164,7 @@ class DirectKafkaStreamSuite extends KafkaStreamSuiteBase stream.map { _._2 }.foreachRDD { rdd => collectedData ++= rdd.collect() } ssc.start() val newData = Map("b" -> 10) - sendMessages(topic, newData) + kafkaTestUtils.sendMessages(topic, newData) eventually(timeout(10 seconds), interval(50 milliseconds)) { collectedData.contains("b") } @@ -166,9 +176,9 @@ class DirectKafkaStreamSuite extends KafkaStreamSuiteBase val topic = "offset" val topicPartition = TopicAndPartition(topic, 0) val data = Map("a" -> 10) - createTopic(topic) + kafkaTestUtils.createTopic(topic) val kafkaParams = Map( - "metadata.broker.list" -> s"$brokerAddress", + "metadata.broker.list" -> kafkaTestUtils.brokerAddress, "auto.offset.reset" -> "largest" ) val kc = new KafkaCluster(kafkaParams) @@ -177,7 +187,7 @@ class DirectKafkaStreamSuite extends KafkaStreamSuiteBase } // Send some initial messages before starting context - sendMessages(topic, data) + kafkaTestUtils.sendMessages(topic, data) eventually(timeout(10 seconds), interval(20 milliseconds)) { assert(getLatestOffset() >= 10) } @@ -200,7 +210,7 @@ class DirectKafkaStreamSuite extends KafkaStreamSuiteBase stream.foreachRDD { rdd => collectedData ++= rdd.collect() } ssc.start() val newData = Map("b" -> 10) - sendMessages(topic, newData) + kafkaTestUtils.sendMessages(topic, newData) eventually(timeout(10 seconds), interval(50 milliseconds)) { collectedData.contains("b") } @@ -210,18 +220,18 @@ class DirectKafkaStreamSuite extends KafkaStreamSuiteBase // Test to verify the offset ranges can be recovered from the checkpoints test("offset recovery") { val topic = "recovery" - createTopic(topic) + kafkaTestUtils.createTopic(topic) testDir = Utils.createTempDir() val kafkaParams = Map( - "metadata.broker.list" -> s"$brokerAddress", + "metadata.broker.list" -> kafkaTestUtils.brokerAddress, "auto.offset.reset" -> "smallest" ) // Send data to Kafka and wait for it to be received def sendDataAndWaitForReceive(data: Seq[Int]) { val strings = data.map { _.toString} - sendMessages(topic, strings.map { _ -> 1}.toMap) + kafkaTestUtils.sendMessages(topic, strings.map { _ -> 1}.toMap) eventually(timeout(10 seconds), interval(50 milliseconds)) { assert(strings.forall { DirectKafkaStreamSuite.collectedData.contains }) } diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala index fc9275b7207be..7fb841b79cb65 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala @@ -20,31 +20,41 @@ package org.apache.spark.streaming.kafka import scala.util.Random import kafka.common.TopicAndPartition -import org.scalatest.BeforeAndAfterAll +import org.scalatest.{BeforeAndAfterAll, FunSuite} -class KafkaClusterSuite extends KafkaStreamSuiteBase with BeforeAndAfterAll { - val topic = "kcsuitetopic" + Random.nextInt(10000) - val topicAndPartition = TopicAndPartition(topic, 0) - var kc: KafkaCluster = null +class KafkaClusterSuite extends FunSuite with BeforeAndAfterAll { + private val topic = "kcsuitetopic" + Random.nextInt(10000) + private val topicAndPartition = TopicAndPartition(topic, 0) + private var kc: KafkaCluster = null + + private var kafkaTestUtils: KafkaTestUtils = _ override def beforeAll() { - setupKafka() - createTopic(topic) - sendMessages(topic, Map("a" -> 1)) - kc = new KafkaCluster(Map("metadata.broker.list" -> s"$brokerAddress")) + kafkaTestUtils = new KafkaTestUtils + kafkaTestUtils.setup() + + kafkaTestUtils.createTopic(topic) + kafkaTestUtils.sendMessages(topic, Map("a" -> 1)) + kc = new KafkaCluster(Map("metadata.broker.list" -> kafkaTestUtils.brokerAddress)) } override def afterAll() { - tearDownKafka() + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown() + kafkaTestUtils = null + } } test("metadata apis") { val leader = kc.findLeaders(Set(topicAndPartition)).right.get(topicAndPartition) val leaderAddress = s"${leader._1}:${leader._2}" - assert(leaderAddress === brokerAddress, "didn't get leader") + assert(leaderAddress === kafkaTestUtils.brokerAddress, "didn't get leader") val parts = kc.getPartitions(Set(topic)).right.get assert(parts(topicAndPartition), "didn't get partitions") + + val err = kc.getPartitions(Set(topic + "BAD")) + assert(err.isLeft, "getPartitions for a nonexistant topic should be an error") } test("leader offset apis") { diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala index a223da70b043f..7d26ce50875b3 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala @@ -22,18 +22,22 @@ import scala.util.Random import kafka.serializer.StringDecoder import kafka.common.TopicAndPartition import kafka.message.MessageAndMetadata -import org.scalatest.BeforeAndAfterAll +import org.scalatest.{BeforeAndAfterAll, FunSuite} import org.apache.spark._ -import org.apache.spark.SparkContext._ -class KafkaRDDSuite extends KafkaStreamSuiteBase with BeforeAndAfterAll { - val sparkConf = new SparkConf().setMaster("local[4]").setAppName(this.getClass.getSimpleName) - var sc: SparkContext = _ +class KafkaRDDSuite extends FunSuite with BeforeAndAfterAll { + + private var kafkaTestUtils: KafkaTestUtils = _ + + private val sparkConf = new SparkConf().setMaster("local[4]") + .setAppName(this.getClass.getSimpleName) + private var sc: SparkContext = _ + override def beforeAll { sc = new SparkContext(sparkConf) - - setupKafka() + kafkaTestUtils = new KafkaTestUtils + kafkaTestUtils.setup() } override def afterAll { @@ -41,17 +45,21 @@ class KafkaRDDSuite extends KafkaStreamSuiteBase with BeforeAndAfterAll { sc.stop sc = null } - tearDownKafka() + + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown() + kafkaTestUtils = null + } } test("basic usage") { val topic = "topicbasic" - createTopic(topic) + kafkaTestUtils.createTopic(topic) val messages = Set("the", "quick", "brown", "fox") - sendMessages(topic, messages.toArray) + kafkaTestUtils.sendMessages(topic, messages.toArray) - val kafkaParams = Map("metadata.broker.list" -> brokerAddress, + val kafkaParams = Map("metadata.broker.list" -> kafkaTestUtils.brokerAddress, "group.id" -> s"test-consumer-${Random.nextInt(10000)}") val offsetRanges = Array(OffsetRange(topic, 0, 0, messages.size)) @@ -67,15 +75,15 @@ class KafkaRDDSuite extends KafkaStreamSuiteBase with BeforeAndAfterAll { // the idea is to find e.g. off-by-one errors between what kafka has available and the rdd val topic = "topic1" val sent = Map("a" -> 5, "b" -> 3, "c" -> 10) - createTopic(topic) + kafkaTestUtils.createTopic(topic) - val kafkaParams = Map("metadata.broker.list" -> brokerAddress, + val kafkaParams = Map("metadata.broker.list" -> kafkaTestUtils.brokerAddress, "group.id" -> s"test-consumer-${Random.nextInt(10000)}") val kc = new KafkaCluster(kafkaParams) // this is the "lots of messages" case - sendMessages(topic, sent) + kafkaTestUtils.sendMessages(topic, sent) // rdd defined from leaders after sending messages, should get the number sent val rdd = getRdd(kc, Set(topic)) @@ -92,14 +100,14 @@ class KafkaRDDSuite extends KafkaStreamSuiteBase with BeforeAndAfterAll { // shouldn't get anything, since message is sent after rdd was defined val sentOnlyOne = Map("d" -> 1) - sendMessages(topic, sentOnlyOne) + kafkaTestUtils.sendMessages(topic, sentOnlyOne) assert(rdd2.isDefined) assert(rdd2.get.count === 0, "got messages when there shouldn't be any") // this is the "exactly 1 message" case, namely the single message from sentOnlyOne above val rdd3 = getRdd(kc, Set(topic)) // send lots of messages after rdd was defined, they shouldn't show up - sendMessages(topic, Map("extra" -> 22)) + kafkaTestUtils.sendMessages(topic, Map("extra" -> 22)) assert(rdd3.isDefined) assert(rdd3.get.count === sentOnlyOne.values.sum, "didn't get exactly one message") diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala index e4966eebb9b34..24699dfc33adb 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala @@ -17,209 +17,38 @@ package org.apache.spark.streaming.kafka -import java.io.File -import java.net.InetSocketAddress -import java.util.Properties - import scala.collection.mutable import scala.concurrent.duration._ import scala.language.postfixOps import scala.util.Random -import kafka.admin.AdminUtils -import kafka.common.{KafkaException, TopicAndPartition} -import kafka.producer.{KeyedMessage, Producer, ProducerConfig} -import kafka.serializer.{StringDecoder, StringEncoder} -import kafka.server.{KafkaConfig, KafkaServer} -import kafka.utils.ZKStringSerializer -import org.I0Itec.zkclient.ZkClient -import org.apache.zookeeper.server.{NIOServerCnxnFactory, ZooKeeperServer} -import org.scalatest.{BeforeAndAfter, FunSuite} +import kafka.serializer.StringDecoder +import org.scalatest.{BeforeAndAfterAll, FunSuite} import org.scalatest.concurrent.Eventually -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.SparkConf import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Milliseconds, StreamingContext} -import org.apache.spark.util.Utils - -/** - * This is an abstract base class for Kafka testsuites. This has the functionality to set up - * and tear down local Kafka servers, and to push data using Kafka producers. - */ -abstract class KafkaStreamSuiteBase extends FunSuite with Eventually with Logging { - - private val zkHost = "localhost" - private var zkPort: Int = 0 - private val zkConnectionTimeout = 6000 - private val zkSessionTimeout = 6000 - private var zookeeper: EmbeddedZookeeper = _ - private val brokerHost = "localhost" - private var brokerPort = 9092 - private var brokerConf: KafkaConfig = _ - private var server: KafkaServer = _ - private var producer: Producer[String, String] = _ - private var zkReady = false - private var brokerReady = false - - protected var zkClient: ZkClient = _ - - def zkAddress: String = { - assert(zkReady, "Zookeeper not setup yet or already torn down, cannot get zookeeper address") - s"$zkHost:$zkPort" - } - def brokerAddress: String = { - assert(brokerReady, "Kafka not setup yet or already torn down, cannot get broker address") - s"$brokerHost:$brokerPort" - } - - def setupKafka() { - // Zookeeper server startup - zookeeper = new EmbeddedZookeeper(s"$zkHost:$zkPort") - // Get the actual zookeeper binding port - zkPort = zookeeper.actualPort - zkReady = true - logInfo("==================== Zookeeper Started ====================") +class KafkaStreamSuite extends FunSuite with Eventually with BeforeAndAfterAll { + private var ssc: StreamingContext = _ + private var kafkaTestUtils: KafkaTestUtils = _ - zkClient = new ZkClient(zkAddress, zkSessionTimeout, zkConnectionTimeout, ZKStringSerializer) - logInfo("==================== Zookeeper Client Created ====================") - - // Kafka broker startup - var bindSuccess: Boolean = false - while(!bindSuccess) { - try { - val brokerProps = getBrokerConfig() - brokerConf = new KafkaConfig(brokerProps) - server = new KafkaServer(brokerConf) - server.startup() - logInfo("==================== Kafka Broker Started ====================") - bindSuccess = true - } catch { - case e: KafkaException => - if (e.getMessage != null && e.getMessage.contains("Socket server failed to bind to")) { - brokerPort += 1 - } - case e: Exception => throw new Exception("Kafka server create failed", e) - } - } - - Thread.sleep(2000) - logInfo("==================== Kafka + Zookeeper Ready ====================") - brokerReady = true + override def beforeAll(): Unit = { + kafkaTestUtils = new KafkaTestUtils + kafkaTestUtils.setup() } - def tearDownKafka() { - brokerReady = false - zkReady = false - if (producer != null) { - producer.close() - producer = null - } - - if (server != null) { - server.shutdown() - server = null - } - - brokerConf.logDirs.foreach { f => Utils.deleteRecursively(new File(f)) } - - if (zkClient != null) { - zkClient.close() - zkClient = null - } - - if (zookeeper != null) { - zookeeper.shutdown() - zookeeper = null - } - } - - def createTopic(topic: String) { - AdminUtils.createTopic(zkClient, topic, 1, 1) - // wait until metadata is propagated - waitUntilMetadataIsPropagated(topic, 0) - logInfo(s"==================== Topic $topic Created ====================") - } - - def sendMessages(topic: String, messageToFreq: Map[String, Int]) { - val messages = messageToFreq.flatMap { case (s, freq) => Seq.fill(freq)(s) }.toArray - sendMessages(topic, messages) - } - - def sendMessages(topic: String, messages: Array[String]) { - producer = new Producer[String, String](new ProducerConfig(getProducerConfig())) - producer.send(messages.map { new KeyedMessage[String, String](topic, _ ) }: _*) - producer.close() - logInfo(s"==================== Sent Messages: ${messages.mkString(", ")} ====================") - } - - private def getBrokerConfig(): Properties = { - val props = new Properties() - props.put("broker.id", "0") - props.put("host.name", "localhost") - props.put("port", brokerPort.toString) - props.put("log.dir", Utils.createTempDir().getAbsolutePath) - props.put("zookeeper.connect", zkAddress) - props.put("log.flush.interval.messages", "1") - props.put("replica.socket.timeout.ms", "1500") - props - } - - private def getProducerConfig(): Properties = { - val brokerAddr = brokerConf.hostName + ":" + brokerConf.port - val props = new Properties() - props.put("metadata.broker.list", brokerAddr) - props.put("serializer.class", classOf[StringEncoder].getName) - props - } - - private def waitUntilMetadataIsPropagated(topic: String, partition: Int) { - eventually(timeout(10000 milliseconds), interval(100 milliseconds)) { - assert( - server.apis.metadataCache.containsTopicAndPartition(topic, partition), - s"Partition [$topic, $partition] metadata not propagated after timeout" - ) - } - } - - class EmbeddedZookeeper(val zkConnect: String) { - val random = new Random() - val snapshotDir = Utils.createTempDir() - val logDir = Utils.createTempDir() - - val zookeeper = new ZooKeeperServer(snapshotDir, logDir, 500) - val (ip, port) = { - val splits = zkConnect.split(":") - (splits(0), splits(1).toInt) - } - val factory = new NIOServerCnxnFactory() - factory.configure(new InetSocketAddress(ip, port), 16) - factory.startup(zookeeper) - - val actualPort = factory.getLocalPort - - def shutdown() { - factory.shutdown() - Utils.deleteRecursively(snapshotDir) - Utils.deleteRecursively(logDir) - } - } -} - - -class KafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter { - var ssc: StreamingContext = _ - - before { - setupKafka() - } - - after { + override def afterAll(): Unit = { if (ssc != null) { ssc.stop() ssc = null } - tearDownKafka() + + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown() + kafkaTestUtils = null + } } test("Kafka input stream") { @@ -227,10 +56,10 @@ class KafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter { ssc = new StreamingContext(sparkConf, Milliseconds(500)) val topic = "topic1" val sent = Map("a" -> 5, "b" -> 3, "c" -> 10) - createTopic(topic) - sendMessages(topic, sent) + kafkaTestUtils.createTopic(topic) + kafkaTestUtils.sendMessages(topic, sent) - val kafkaParams = Map("zookeeper.connect" -> zkAddress, + val kafkaParams = Map("zookeeper.connect" -> kafkaTestUtils.zkAddress, "group.id" -> s"test-consumer-${Random.nextInt(10000)}", "auto.offset.reset" -> "smallest") @@ -244,14 +73,14 @@ class KafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter { result.put(kv._1, count) } } + ssc.start() + eventually(timeout(10000 milliseconds), interval(100 milliseconds)) { assert(sent.size === result.size) sent.keys.foreach { k => assert(sent(k) === result(k).toInt) } } - ssc.stop() } } - diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala index 3cd960d1fd1d4..38548dd73b82c 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.streaming.kafka - import java.io.File import scala.collection.mutable @@ -27,7 +26,7 @@ import scala.util.Random import kafka.serializer.StringDecoder import kafka.utils.{ZKGroupTopicDirs, ZkUtils} -import org.scalatest.BeforeAndAfter +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} import org.scalatest.concurrent.Eventually import org.apache.spark.SparkConf @@ -35,47 +34,61 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Milliseconds, StreamingContext} import org.apache.spark.util.Utils -class ReliableKafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter with Eventually { +class ReliableKafkaStreamSuite extends FunSuite + with BeforeAndAfterAll with BeforeAndAfter with Eventually { - val sparkConf = new SparkConf() + private val sparkConf = new SparkConf() .setMaster("local[4]") .setAppName(this.getClass.getSimpleName) .set("spark.streaming.receiver.writeAheadLog.enable", "true") - val data = Map("a" -> 10, "b" -> 10, "c" -> 10) + private val data = Map("a" -> 10, "b" -> 10, "c" -> 10) + private var kafkaTestUtils: KafkaTestUtils = _ - var groupId: String = _ - var kafkaParams: Map[String, String] = _ - var ssc: StreamingContext = _ - var tempDirectory: File = null + private var groupId: String = _ + private var kafkaParams: Map[String, String] = _ + private var ssc: StreamingContext = _ + private var tempDirectory: File = null + + override def beforeAll() : Unit = { + kafkaTestUtils = new KafkaTestUtils + kafkaTestUtils.setup() - before { - setupKafka() groupId = s"test-consumer-${Random.nextInt(10000)}" kafkaParams = Map( - "zookeeper.connect" -> zkAddress, + "zookeeper.connect" -> kafkaTestUtils.zkAddress, "group.id" -> groupId, "auto.offset.reset" -> "smallest" ) - ssc = new StreamingContext(sparkConf, Milliseconds(500)) tempDirectory = Utils.createTempDir() + } + + override def afterAll(): Unit = { + Utils.deleteRecursively(tempDirectory) + + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown() + kafkaTestUtils = null + } + } + + before { + ssc = new StreamingContext(sparkConf, Milliseconds(500)) ssc.checkpoint(tempDirectory.getAbsolutePath) } after { if (ssc != null) { ssc.stop() + ssc = null } - Utils.deleteRecursively(tempDirectory) - tearDownKafka() } - test("Reliable Kafka input stream with single topic") { - var topic = "test-topic" - createTopic(topic) - sendMessages(topic, data) + val topic = "test-topic" + kafkaTestUtils.createTopic(topic) + kafkaTestUtils.sendMessages(topic, data) // Verify whether the offset of this group/topic/partition is 0 before starting. assert(getCommitOffset(groupId, topic, 0) === None) @@ -91,6 +104,7 @@ class ReliableKafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter } } ssc.start() + eventually(timeout(20000 milliseconds), interval(200 milliseconds)) { // A basic process verification for ReliableKafkaReceiver. // Verify whether received message number is equal to the sent message number. @@ -100,14 +114,13 @@ class ReliableKafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter // Verify the offset number whether it is equal to the total message number. assert(getCommitOffset(groupId, topic, 0) === Some(29L)) } - ssc.stop() } test("Reliable Kafka input stream with multiple topics") { val topics = Map("topic1" -> 1, "topic2" -> 1, "topic3" -> 1) topics.foreach { case (t, _) => - createTopic(t) - sendMessages(t, data) + kafkaTestUtils.createTopic(t) + kafkaTestUtils.sendMessages(t, data) } // Before started, verify all the group/topic/partition offsets are 0. @@ -118,19 +131,18 @@ class ReliableKafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter ssc, kafkaParams, topics, StorageLevel.MEMORY_ONLY) stream.foreachRDD(_ => Unit) ssc.start() + eventually(timeout(20000 milliseconds), interval(100 milliseconds)) { // Verify the offset for each group/topic to see whether they are equal to the expected one. topics.foreach { case (t, _) => assert(getCommitOffset(groupId, t, 0) === Some(29L)) } } - ssc.stop() } /** Getting partition offset from Zookeeper. */ private def getCommitOffset(groupId: String, topic: String, partition: Int): Option[Long] = { - assert(zkClient != null, "Zookeeper client is not initialized") val topicDirs = new ZKGroupTopicDirs(groupId, topic) val zkPath = s"${topicDirs.consumerOffsetDir}/$partition" - ZkUtils.readDataMaybeNull(zkClient, zkPath)._1.map(_.toLong) + ZkUtils.readDataMaybeNull(kafkaTestUtils.zookeeperClient, zkPath)._1.map(_.toLong) } } diff --git a/external/mqtt/src/test/resources/log4j.properties b/external/mqtt/src/test/resources/log4j.properties index 9697237bfa1a3..75e3b53a093f6 100644 --- a/external/mqtt/src/test/resources/log4j.properties +++ b/external/mqtt/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala index 24d78ecb3a97d..a19a72c58a705 100644 --- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala +++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala @@ -139,7 +139,8 @@ class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter { msgTopic.publish(message) } catch { case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT => - Thread.sleep(50) // wait for Spark streaming to consume something from the message queue + // wait for Spark streaming to consume something from the message queue + Thread.sleep(50) } } } diff --git a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala index 4eacc47da5699..7cf02d85d73d3 100644 --- a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala +++ b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala @@ -70,7 +70,7 @@ class TwitterReceiver( try { val newTwitterStream = new TwitterStreamFactory().getInstance(twitterAuth) newTwitterStream.addListener(new StatusListener { - def onStatus(status: Status) = { + def onStatus(status: Status): Unit = { store(status) } // Unimplemented diff --git a/external/twitter/src/test/resources/log4j.properties b/external/twitter/src/test/resources/log4j.properties index 64bfc5745088f..9a3569789d2e0 100644 --- a/external/twitter/src/test/resources/log4j.properties +++ b/external/twitter/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQReceiver.scala b/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQReceiver.scala index 554705878ee78..588e6bac7b14a 100644 --- a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQReceiver.scala +++ b/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQReceiver.scala @@ -29,13 +29,16 @@ import org.apache.spark.streaming.receiver.ActorHelper /** * A receiver to subscribe to ZeroMQ stream. */ -private[streaming] class ZeroMQReceiver[T: ClassTag](publisherUrl: String, - subscribe: Subscribe, - bytesToObjects: Seq[ByteString] => Iterator[T]) +private[streaming] class ZeroMQReceiver[T: ClassTag]( + publisherUrl: String, + subscribe: Subscribe, + bytesToObjects: Seq[ByteString] => Iterator[T]) extends Actor with ActorHelper with Logging { - override def preStart() = ZeroMQExtension(context.system) - .newSocket(SocketType.Sub, Listener(self), Connect(publisherUrl), subscribe) + override def preStart(): Unit = { + ZeroMQExtension(context.system) + .newSocket(SocketType.Sub, Listener(self), Connect(publisherUrl), subscribe) + } def receive: Receive = { diff --git a/external/zeromq/src/test/resources/log4j.properties b/external/zeromq/src/test/resources/log4j.properties index 9697237bfa1a3..75e3b53a093f6 100644 --- a/external/zeromq/src/test/resources/log4j.properties +++ b/external/zeromq/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/extras/java8-tests/src/test/resources/log4j.properties b/extras/java8-tests/src/test/resources/log4j.properties index 287c8e3563503..eb3b1999eb996 100644 --- a/extras/java8-tests/src/test/resources/log4j.properties +++ b/extras/java8-tests/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN -org.eclipse.jetty.LEVEL=WARN +log4j.logger.org.spark-project.jetty=WARN +org.spark-project.jetty.LEVEL=WARN diff --git a/extras/kinesis-asl/src/main/resources/log4j.properties b/extras/kinesis-asl/src/main/resources/log4j.properties index 97348fb5b6123..6cdc9286c5d76 100644 --- a/extras/kinesis-asl/src/main/resources/log4j.properties +++ b/extras/kinesis-asl/src/main/resources/log4j.properties @@ -31,7 +31,7 @@ log4j.appender.console.layout=org.apache.log4j.PatternLayout log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n # Settings to quiet third party logs that are too verbose -log4j.logger.org.eclipse.jetty=WARN -log4j.logger.org.eclipse.jetty.util.component.AbstractLifeCycle=ERROR +log4j.logger.org.spark-project.jetty=WARN +log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO \ No newline at end of file diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index 1bd1f324298e7..a7fe4476cacb8 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -23,6 +23,7 @@ import org.apache.spark.Logging import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.Duration import org.apache.spark.streaming.receiver.Receiver +import org.apache.spark.util.Utils import com.amazonaws.auth.AWSCredentialsProvider import com.amazonaws.auth.DefaultAWSCredentialsProviderChain @@ -118,7 +119,7 @@ private[kinesis] class KinesisReceiver( * method. */ override def onStart() { - workerId = InetAddress.getLocalHost.getHostAddress() + ":" + UUID.randomUUID() + workerId = Utils.localHostName() + ":" + UUID.randomUUID() credentialsProvider = new DefaultAWSCredentialsProviderChain() kinesisClientLibConfiguration = new KinesisClientLibConfiguration(appName, streamName, credentialsProvider, workerId).withKinesisEndpoint(endpointUrl) diff --git a/extras/kinesis-asl/src/test/resources/log4j.properties b/extras/kinesis-asl/src/test/resources/log4j.properties index 853ef0ed2986f..edbecdae92096 100644 --- a/extras/kinesis-asl/src/test/resources/log4j.properties +++ b/extras/kinesis-asl/src/test/resources/log4j.properties @@ -24,4 +24,4 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeContext.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeContext.scala index d8be02e2023d5..23430179f12ec 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeContext.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeContext.scala @@ -62,7 +62,6 @@ object EdgeContext { * , _ + _) * }}} */ - def unapply[VD, ED, A](edge: EdgeContext[VD, ED, A]) = + def unapply[VD, ED, A](edge: EdgeContext[VD, ED, A]): Some[(VertexId, VertexId, VD, VD, ED)] = Some(edge.srcId, edge.dstId, edge.srcAttr, edge.dstAttr, edge.attr) } - diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala index 6f03eb1439773..058c8c8aa1b24 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala @@ -34,12 +34,12 @@ class EdgeDirection private (private val name: String) extends Serializable { override def toString: String = "EdgeDirection." + name - override def equals(o: Any) = o match { + override def equals(o: Any): Boolean = o match { case other: EdgeDirection => other.name == name case _ => false } - override def hashCode = name.hashCode + override def hashCode: Int = name.hashCode } @@ -48,14 +48,14 @@ class EdgeDirection private (private val name: String) extends Serializable { */ object EdgeDirection { /** Edges arriving at a vertex. */ - final val In = new EdgeDirection("In") + final val In: EdgeDirection = new EdgeDirection("In") /** Edges originating from a vertex. */ - final val Out = new EdgeDirection("Out") + final val Out: EdgeDirection = new EdgeDirection("Out") /** Edges originating from *or* arriving at a vertex of interest. */ - final val Either = new EdgeDirection("Either") + final val Either: EdgeDirection = new EdgeDirection("Either") /** Edges originating from *and* arriving at a vertex of interest. */ - final val Both = new EdgeDirection("Both") + final val Both: EdgeDirection = new EdgeDirection("Both") } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala index 9d473d5ebda44..c8790cac3d8a0 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala @@ -62,7 +62,7 @@ class EdgeTriplet[VD, ED] extends Edge[ED] { def vertexAttr(vid: VertexId): VD = if (srcId == vid) srcAttr else { assert(dstId == vid); dstAttr } - override def toString = ((srcId, srcAttr), (dstId, dstAttr), attr).toString() + override def toString: String = ((srcId, srcAttr), (dstId, dstAttr), attr).toString() def toTuple: ((VertexId, VD), (VertexId, VD), ED) = ((srcId, srcAttr), (dstId, dstAttr), attr) } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala index 8494d06b1cdb7..36dc7b0f86c89 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala @@ -409,7 +409,7 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * {{{ * val rawGraph: Graph[_, _] = Graph.textFile("twittergraph") * val inDeg: RDD[(VertexId, Int)] = - * aggregateMessages[Int](ctx => ctx.sendToDst(1), _ + _) + * rawGraph.aggregateMessages[Int](ctx => ctx.sendToDst(1), _ + _) * }}} * * @note By expressing computation at the edge level we achieve diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala index 373af75448374..c561570809253 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala @@ -324,7 +324,7 @@ class EdgePartition[ * * @return an iterator over edges in the partition */ - def iterator = new Iterator[Edge[ED]] { + def iterator: Iterator[Edge[ED]] = new Iterator[Edge[ED]] { private[this] val edge = new Edge[ED] private[this] var pos = 0 @@ -351,7 +351,7 @@ class EdgePartition[ override def hasNext: Boolean = pos < EdgePartition.this.size - override def next() = { + override def next(): EdgeTriplet[VD, ED] = { val triplet = new EdgeTriplet[VD, ED] val localSrcId = localSrcIds(pos) val localDstId = localDstIds(pos) @@ -518,11 +518,11 @@ private class AggregatingEdgeContext[VD, ED, A]( _attr = attr } - override def srcId = _srcId - override def dstId = _dstId - override def srcAttr = _srcAttr - override def dstAttr = _dstAttr - override def attr = _attr + override def srcId: VertexId = _srcId + override def dstId: VertexId = _dstId + override def srcAttr: VD = _srcAttr + override def dstAttr: VD = _dstAttr + override def attr: ED = _attr override def sendToSrc(msg: A) { send(_localSrcId, msg) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala index 43a3aea0f6196..c88b2f65a86cd 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala @@ -70,9 +70,9 @@ class EdgeRDDImpl[ED: ClassTag, VD: ClassTag] private[graphx] ( this } - override def getStorageLevel = partitionsRDD.getStorageLevel + override def getStorageLevel: StorageLevel = partitionsRDD.getStorageLevel - override def checkpoint() = { + override def checkpoint(): Unit = { partitionsRDD.checkpoint() } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala index 8ab255bd4038c..1df86449fa0c2 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala @@ -50,7 +50,7 @@ class ReplicatedVertexView[VD: ClassTag, ED: ClassTag]( * Return a new `ReplicatedVertexView` where edges are reversed and shipping levels are swapped to * match. */ - def reverse() = { + def reverse(): ReplicatedVertexView[VD, ED] = { val newEdges = edges.mapEdgePartitions((pid, part) => part.reverse) new ReplicatedVertexView(newEdges, hasDstId, hasSrcId) } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala index 349c8545bf201..33ac7b0ed6095 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala @@ -71,9 +71,9 @@ class VertexRDDImpl[VD] private[graphx] ( this } - override def getStorageLevel = partitionsRDD.getStorageLevel + override def getStorageLevel: StorageLevel = partitionsRDD.getStorageLevel - override def checkpoint() = { + override def checkpoint(): Unit = { partitionsRDD.checkpoint() } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala index e2f6cc138958e..859f896039047 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala @@ -37,7 +37,7 @@ object ConnectedComponents { */ def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]): Graph[VertexId, ED] = { val ccGraph = graph.mapVertices { case (vid, _) => vid } - def sendMessage(edge: EdgeTriplet[VertexId, ED]) = { + def sendMessage(edge: EdgeTriplet[VertexId, ED]): Iterator[(VertexId, VertexId)] = { if (edge.srcAttr < edge.dstAttr) { Iterator((edge.dstId, edge.srcAttr)) } else if (edge.srcAttr > edge.dstAttr) { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala index 82e9e06515179..2bcf8684b8b8e 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala @@ -43,7 +43,7 @@ object LabelPropagation { */ def run[VD, ED: ClassTag](graph: Graph[VD, ED], maxSteps: Int): Graph[VertexId, ED] = { val lpaGraph = graph.mapVertices { case (vid, _) => vid } - def sendMessage(e: EdgeTriplet[VertexId, ED]) = { + def sendMessage(e: EdgeTriplet[VertexId, ED]): Iterator[(VertexId, Map[VertexId, VertexId])] = { Iterator((e.srcId, Map(e.dstAttr -> 1L)), (e.dstId, Map(e.srcAttr -> 1L))) } def mergeMessage(count1: Map[VertexId, Long], count2: Map[VertexId, Long]) @@ -54,7 +54,7 @@ object LabelPropagation { i -> (count1Val + count2Val) }.toMap } - def vertexProgram(vid: VertexId, attr: Long, message: Map[VertexId, Long]) = { + def vertexProgram(vid: VertexId, attr: Long, message: Map[VertexId, Long]): VertexId = { if (message.isEmpty) attr else message.maxBy(_._2)._1 } val initialMessage = Map[VertexId, Long]() diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala index 570440ba4441f..042e366a29f58 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala @@ -156,7 +156,7 @@ object PageRank extends Logging { (newPR, newPR - oldPR) } - def sendMessage(edge: EdgeTriplet[(Double, Double), Double]) = { + def sendMessage(edge: EdgeTriplet[(Double, Double), Double]): Iterator[(VertexId, Double)] = { if (edge.srcAttr._2 > tol) { Iterator((edge.dstId, edge.srcAttr._2 * edge.attr)) } else { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala index 1a7178b82e3af..3b0e1628d86b5 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala @@ -93,7 +93,7 @@ object SVDPlusPlus { val gJoinT0 = g.outerJoinVertices(t0) { (vid: VertexId, vd: (Array[Double], Array[Double], Double, Double), msg: Option[(Long, Double)]) => - (vd._1, vd._2, msg.get._2 / msg.get._1, 1.0 / scala.math.sqrt(msg.get._1)) + (vd._1, vd._2, msg.get._2 / msg.get._1 - u, 1.0 / scala.math.sqrt(msg.get._1)) }.cache() materialize(gJoinT0) g.unpersist() diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/collection/GraphXPrimitiveKeyOpenHashMap.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/collection/GraphXPrimitiveKeyOpenHashMap.scala index 57b01b6f2e1fb..e2754ea699da9 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/collection/GraphXPrimitiveKeyOpenHashMap.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/collection/GraphXPrimitiveKeyOpenHashMap.scala @@ -56,7 +56,7 @@ class GraphXPrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag, private var _oldValues: Array[V] = null - override def size = keySet.size + override def size: Int = keySet.size /** Get the value for a given key */ def apply(k: K): V = { @@ -112,7 +112,7 @@ class GraphXPrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag, } } - override def iterator = new Iterator[(K, V)] { + override def iterator: Iterator[(K, V)] = new Iterator[(K, V)] { var pos = 0 var nextPair: (K, V) = computeNextPair() @@ -128,9 +128,9 @@ class GraphXPrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag, } } - def hasNext = nextPair != null + def hasNext: Boolean = nextPair != null - def next() = { + def next(): (K, V) = { val pair = nextPair nextPair = computeNextPair() pair diff --git a/graphx/src/test/resources/log4j.properties b/graphx/src/test/resources/log4j.properties index 287c8e3563503..eb3b1999eb996 100644 --- a/graphx/src/test/resources/log4j.properties +++ b/graphx/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN -org.eclipse.jetty.LEVEL=WARN +log4j.logger.org.spark-project.jetty=WARN +org.spark-project.jetty.LEVEL=WARN diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala index 8d15150458d26..a570e4ed75fc3 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala @@ -38,12 +38,12 @@ class GraphSuite extends FunSuite with LocalSparkContext { val doubleRing = ring ++ ring val graph = Graph.fromEdgeTuples(sc.parallelize(doubleRing), 1) assert(graph.edges.count() === doubleRing.size) - assert(graph.edges.collect.forall(e => e.attr == 1)) + assert(graph.edges.collect().forall(e => e.attr == 1)) // uniqueEdges option should uniquify edges and store duplicate count in edge attributes val uniqueGraph = Graph.fromEdgeTuples(sc.parallelize(doubleRing), 1, Some(RandomVertexCut)) assert(uniqueGraph.edges.count() === ring.size) - assert(uniqueGraph.edges.collect.forall(e => e.attr == 2)) + assert(uniqueGraph.edges.collect().forall(e => e.attr == 2)) } } @@ -64,7 +64,7 @@ class GraphSuite extends FunSuite with LocalSparkContext { assert( graph.edges.count() === rawEdges.size ) // Vertices not explicitly provided but referenced by edges should be created automatically assert( graph.vertices.count() === 100) - graph.triplets.collect.map { et => + graph.triplets.collect().map { et => assert((et.srcId < 10 && et.srcAttr) || (et.srcId >= 10 && !et.srcAttr)) assert((et.dstId < 10 && et.dstAttr) || (et.dstId >= 10 && !et.dstAttr)) } @@ -75,15 +75,17 @@ class GraphSuite extends FunSuite with LocalSparkContext { withSpark { sc => val n = 5 val star = starGraph(sc, n) - assert(star.triplets.map(et => (et.srcId, et.dstId, et.srcAttr, et.dstAttr)).collect.toSet === - (1 to n).map(x => (0: VertexId, x: VertexId, "v", "v")).toSet) + assert(star.triplets.map(et => (et.srcId, et.dstId, et.srcAttr, et.dstAttr)).collect().toSet + === (1 to n).map(x => (0: VertexId, x: VertexId, "v", "v")).toSet) } } test("partitionBy") { withSpark { sc => - def mkGraph(edges: List[(Long, Long)]) = Graph.fromEdgeTuples(sc.parallelize(edges, 2), 0) - def nonemptyParts(graph: Graph[Int, Int]) = { + def mkGraph(edges: List[(Long, Long)]): Graph[Int, Int] = { + Graph.fromEdgeTuples(sc.parallelize(edges, 2), 0) + } + def nonemptyParts(graph: Graph[Int, Int]): RDD[List[Edge[Int]]] = { graph.edges.partitionsRDD.mapPartitions { iter => Iterator(iter.next()._2.iterator.toList) }.filter(_.nonEmpty) @@ -102,7 +104,8 @@ class GraphSuite extends FunSuite with LocalSparkContext { assert(nonemptyParts(mkGraph(sameSrcEdges).partitionBy(EdgePartition1D)).count === 1) // partitionBy(CanonicalRandomVertexCut) puts edges that are identical modulo direction into // the same partition - assert(nonemptyParts(mkGraph(canonicalEdges).partitionBy(CanonicalRandomVertexCut)).count === 1) + assert( + nonemptyParts(mkGraph(canonicalEdges).partitionBy(CanonicalRandomVertexCut)).count === 1) // partitionBy(EdgePartition2D) puts identical edges in the same partition assert(nonemptyParts(mkGraph(identicalEdges).partitionBy(EdgePartition2D)).count === 1) @@ -140,10 +143,10 @@ class GraphSuite extends FunSuite with LocalSparkContext { val g = Graph( sc.parallelize(List((0L, "a"), (1L, "b"), (2L, "c"))), sc.parallelize(List(Edge(0L, 1L, 1), Edge(0L, 2L, 1)), 2)) - assert(g.triplets.collect.map(_.toTuple).toSet === + assert(g.triplets.collect().map(_.toTuple).toSet === Set(((0L, "a"), (1L, "b"), 1), ((0L, "a"), (2L, "c"), 1))) val gPart = g.partitionBy(EdgePartition2D) - assert(gPart.triplets.collect.map(_.toTuple).toSet === + assert(gPart.triplets.collect().map(_.toTuple).toSet === Set(((0L, "a"), (1L, "b"), 1), ((0L, "a"), (2L, "c"), 1))) } } @@ -154,10 +157,10 @@ class GraphSuite extends FunSuite with LocalSparkContext { val star = starGraph(sc, n) // mapVertices preserving type val mappedVAttrs = star.mapVertices((vid, attr) => attr + "2") - assert(mappedVAttrs.vertices.collect.toSet === (0 to n).map(x => (x: VertexId, "v2")).toSet) + assert(mappedVAttrs.vertices.collect().toSet === (0 to n).map(x => (x: VertexId, "v2")).toSet) // mapVertices changing type val mappedVAttrs2 = star.mapVertices((vid, attr) => attr.length) - assert(mappedVAttrs2.vertices.collect.toSet === (0 to n).map(x => (x: VertexId, 1)).toSet) + assert(mappedVAttrs2.vertices.collect().toSet === (0 to n).map(x => (x: VertexId, 1)).toSet) } } @@ -177,12 +180,12 @@ class GraphSuite extends FunSuite with LocalSparkContext { // Trigger initial vertex replication graph0.triplets.foreach(x => {}) // Change type of replicated vertices, but preserve erased type - val graph1 = graph0.mapVertices { - case (vid, integerOpt) => integerOpt.map((x: java.lang.Integer) => (x.toDouble): java.lang.Double) + val graph1 = graph0.mapVertices { case (vid, integerOpt) => + integerOpt.map((x: java.lang.Integer) => x.toDouble: java.lang.Double) } // Access replicated vertices, exposing the erased type val graph2 = graph1.mapTriplets(t => t.srcAttr.get) - assert(graph2.edges.map(_.attr).collect.toSet === Set[java.lang.Double](1.0, 2.0, 3.0)) + assert(graph2.edges.map(_.attr).collect().toSet === Set[java.lang.Double](1.0, 2.0, 3.0)) } } @@ -202,7 +205,7 @@ class GraphSuite extends FunSuite with LocalSparkContext { withSpark { sc => val n = 5 val star = starGraph(sc, n) - assert(star.mapTriplets(et => et.srcAttr + et.dstAttr).edges.collect.toSet === + assert(star.mapTriplets(et => et.srcAttr + et.dstAttr).edges.collect().toSet === (1L to n).map(x => Edge(0, x, "vv")).toSet) } } @@ -211,7 +214,7 @@ class GraphSuite extends FunSuite with LocalSparkContext { withSpark { sc => val n = 5 val star = starGraph(sc, n) - assert(star.reverse.outDegrees.collect.toSet === (1 to n).map(x => (x: VertexId, 1)).toSet) + assert(star.reverse.outDegrees.collect().toSet === (1 to n).map(x => (x: VertexId, 1)).toSet) } } @@ -221,7 +224,7 @@ class GraphSuite extends FunSuite with LocalSparkContext { val edges: RDD[Edge[Int]] = sc.parallelize(Array(Edge(1L, 2L, 0))) val graph = Graph(vertices, edges).reverse val result = graph.mapReduceTriplets[Int](et => Iterator((et.dstId, et.srcAttr)), _ + _) - assert(result.collect.toSet === Set((1L, 2))) + assert(result.collect().toSet === Set((1L, 2))) } } @@ -237,7 +240,8 @@ class GraphSuite extends FunSuite with LocalSparkContext { assert(subgraph.vertices.collect().toSet === (0 to n by 2).map(x => (x, "v")).toSet) // And 4 edges. - assert(subgraph.edges.map(_.copy()).collect().toSet === (2 to n by 2).map(x => Edge(0, x, 1)).toSet) + assert(subgraph.edges.map(_.copy()).collect().toSet === + (2 to n by 2).map(x => Edge(0, x, 1)).toSet) } } @@ -273,9 +277,9 @@ class GraphSuite extends FunSuite with LocalSparkContext { sc.parallelize((1 to n).flatMap(x => List((0: VertexId, x: VertexId), (0: VertexId, x: VertexId))), 1), "v") val star2 = doubleStar.groupEdges { (a, b) => a} - assert(star2.edges.collect.toArray.sorted(Edge.lexicographicOrdering[Int]) === - star.edges.collect.toArray.sorted(Edge.lexicographicOrdering[Int])) - assert(star2.vertices.collect.toSet === star.vertices.collect.toSet) + assert(star2.edges.collect().toArray.sorted(Edge.lexicographicOrdering[Int]) === + star.edges.collect().toArray.sorted(Edge.lexicographicOrdering[Int])) + assert(star2.vertices.collect().toSet === star.vertices.collect().toSet) } } @@ -300,21 +304,23 @@ class GraphSuite extends FunSuite with LocalSparkContext { throw new Exception("map ran on edge with dst vid %d, which is odd".format(et.dstId)) } Iterator((et.srcId, 1)) - }, (a: Int, b: Int) => a + b, Some((active, EdgeDirection.In))).collect.toSet + }, (a: Int, b: Int) => a + b, Some((active, EdgeDirection.In))).collect().toSet assert(numEvenNeighbors === (1 to n).map(x => (x: VertexId, n / 2)).toSet) // outerJoinVertices followed by mapReduceTriplets(activeSetOpt) - val ringEdges = sc.parallelize((0 until n).map(x => (x: VertexId, (x+1) % n: VertexId)), 3) + val ringEdges = sc.parallelize((0 until n).map(x => (x: VertexId, (x + 1) % n: VertexId)), 3) val ring = Graph.fromEdgeTuples(ringEdges, 0) .mapVertices((vid, attr) => vid).cache() val changed = ring.vertices.filter { case (vid, attr) => attr % 2 == 1 }.mapValues(-_).cache() - val changedGraph = ring.outerJoinVertices(changed) { (vid, old, newOpt) => newOpt.getOrElse(old) } + val changedGraph = ring.outerJoinVertices(changed) { (vid, old, newOpt) => + newOpt.getOrElse(old) + } val numOddNeighbors = changedGraph.mapReduceTriplets(et => { // Map function should only run on edges with source in the active set if (et.srcId % 2 != 1) { throw new Exception("map ran on edge with src vid %d, which is even".format(et.dstId)) } Iterator((et.dstId, 1)) - }, (a: Int, b: Int) => a + b, Some(changed, EdgeDirection.Out)).collect.toSet + }, (a: Int, b: Int) => a + b, Some(changed, EdgeDirection.Out)).collect().toSet assert(numOddNeighbors === (2 to n by 2).map(x => (x: VertexId, 1)).toSet) } @@ -340,17 +346,18 @@ class GraphSuite extends FunSuite with LocalSparkContext { val n = 5 val reverseStar = starGraph(sc, n).reverse.cache() // outerJoinVertices changing type - val reverseStarDegrees = - reverseStar.outerJoinVertices(reverseStar.outDegrees) { (vid, a, bOpt) => bOpt.getOrElse(0) } + val reverseStarDegrees = reverseStar.outerJoinVertices(reverseStar.outDegrees) { + (vid, a, bOpt) => bOpt.getOrElse(0) + } val neighborDegreeSums = reverseStarDegrees.mapReduceTriplets( et => Iterator((et.srcId, et.dstAttr), (et.dstId, et.srcAttr)), - (a: Int, b: Int) => a + b).collect.toSet + (a: Int, b: Int) => a + b).collect().toSet assert(neighborDegreeSums === Set((0: VertexId, n)) ++ (1 to n).map(x => (x: VertexId, 0))) // outerJoinVertices preserving type val messages = reverseStar.vertices.mapValues { (vid, attr) => vid.toString } val newReverseStar = reverseStar.outerJoinVertices(messages) { (vid, a, bOpt) => a + bOpt.getOrElse("") } - assert(newReverseStar.vertices.map(_._2).collect.toSet === + assert(newReverseStar.vertices.map(_._2).collect().toSet === (0 to n).map(x => "v%d".format(x)).toSet) } } @@ -361,7 +368,7 @@ class GraphSuite extends FunSuite with LocalSparkContext { val edges = sc.parallelize(List(Edge(1, 2, 0), Edge(2, 1, 0)), 2) val graph = Graph(verts, edges) val triplets = graph.triplets.map(et => (et.srcId, et.dstId, et.srcAttr, et.dstAttr)) - .collect.toSet + .collect().toSet assert(triplets === Set((1: VertexId, 2: VertexId, "a", "b"), (2: VertexId, 1: VertexId, "b", "a"))) } @@ -417,7 +424,7 @@ class GraphSuite extends FunSuite with LocalSparkContext { val graph = Graph.fromEdgeTuples(edges, 1) val neighborAttrSums = graph.mapReduceTriplets[Int]( et => Iterator((et.dstId, et.srcAttr)), _ + _) - assert(neighborAttrSums.collect.toSet === Set((0: VertexId, n))) + assert(neighborAttrSums.collect().toSet === Set((0: VertexId, n))) } finally { sc.stop() } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala b/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala index a3e28efc75a98..d2ad9be555770 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala @@ -26,7 +26,7 @@ import org.apache.spark.SparkContext */ trait LocalSparkContext { /** Runs `f` on a new SparkContext and ensures that it is stopped afterwards. */ - def withSpark[T](f: SparkContext => T) = { + def withSpark[T](f: SparkContext => T): T = { val conf = new SparkConf() GraphXUtils.registerKryoClasses(conf) val sc = new SparkContext("local", "test", conf) diff --git a/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala index c9443d11c76cf..d0a7198d691d7 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.storage.StorageLevel class VertexRDDSuite extends FunSuite with LocalSparkContext { - def vertices(sc: SparkContext, n: Int) = { + private def vertices(sc: SparkContext, n: Int) = { VertexRDD(sc.parallelize((0 to n).map(x => (x.toLong, x)), 5)) } @@ -52,7 +52,7 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext { val vertexA = VertexRDD(sc.parallelize(0 until 75, 2).map(i => (i.toLong, 0))).cache() val vertexB = VertexRDD(sc.parallelize(25 until 100, 2).map(i => (i.toLong, 1))).cache() val vertexC = vertexA.minus(vertexB) - assert(vertexC.map(_._1).collect.toSet === (0 until 25).toSet) + assert(vertexC.map(_._1).collect().toSet === (0 until 25).toSet) } } @@ -62,7 +62,7 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext { val vertexB: RDD[(VertexId, Int)] = sc.parallelize(25 until 100, 2).map(i => (i.toLong, 1)).cache() val vertexC = vertexA.minus(vertexB) - assert(vertexC.map(_._1).collect.toSet === (0 until 25).toSet) + assert(vertexC.map(_._1).collect().toSet === (0 until 25).toSet) } } @@ -72,7 +72,7 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext { val vertexB = VertexRDD(sc.parallelize(50 until 100, 2).map(i => (i.toLong, 1))) assert(vertexA.partitions.size != vertexB.partitions.size) val vertexC = vertexA.minus(vertexB) - assert(vertexC.map(_._1).collect.toSet === (0 until 50).toSet) + assert(vertexC.map(_._1).collect().toSet === (0 until 50).toSet) } } @@ -106,7 +106,7 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext { val vertexB = VertexRDD(sc.parallelize(8 until 16, 2).map(i => (i.toLong, 1))) assert(vertexA.partitions.size != vertexB.partitions.size) val vertexC = vertexA.diff(vertexB) - assert(vertexC.map(_._1).collect.toSet === (8 until 16).toSet) + assert(vertexC.map(_._1).collect().toSet === (8 until 16).toSet) } } @@ -116,11 +116,11 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext { val verts = vertices(sc, n).cache() val evens = verts.filter(q => ((q._2 % 2) == 0)).cache() // leftJoin with another VertexRDD - assert(verts.leftJoin(evens) { (id, a, bOpt) => a - bOpt.getOrElse(0) }.collect.toSet === + assert(verts.leftJoin(evens) { (id, a, bOpt) => a - bOpt.getOrElse(0) }.collect().toSet === (0 to n by 2).map(x => (x.toLong, 0)).toSet ++ (1 to n by 2).map(x => (x.toLong, x)).toSet) // leftJoin with an RDD val evensRDD = evens.map(identity) - assert(verts.leftJoin(evensRDD) { (id, a, bOpt) => a - bOpt.getOrElse(0) }.collect.toSet === + assert(verts.leftJoin(evensRDD) { (id, a, bOpt) => a - bOpt.getOrElse(0) }.collect().toSet === (0 to n by 2).map(x => (x.toLong, 0)).toSet ++ (1 to n by 2).map(x => (x.toLong, x)).toSet) } } @@ -134,7 +134,7 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext { val vertexC = vertexA.leftJoin(vertexB) { (vid, old, newOpt) => old - newOpt.getOrElse(0) } - assert(vertexC.filter(v => v._2 != 0).map(_._1).collect.toSet == (1 to 99 by 2).toSet) + assert(vertexC.filter(v => v._2 != 0).map(_._1).collect().toSet == (1 to 99 by 2).toSet) } } @@ -144,11 +144,11 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext { val verts = vertices(sc, n).cache() val evens = verts.filter(q => ((q._2 % 2) == 0)).cache() // innerJoin with another VertexRDD - assert(verts.innerJoin(evens) { (id, a, b) => a - b }.collect.toSet === + assert(verts.innerJoin(evens) { (id, a, b) => a - b }.collect().toSet === (0 to n by 2).map(x => (x.toLong, 0)).toSet) // innerJoin with an RDD val evensRDD = evens.map(identity) - assert(verts.innerJoin(evensRDD) { (id, a, b) => a - b }.collect.toSet === + assert(verts.innerJoin(evensRDD) { (id, a, b) => a - b }.collect().toSet === (0 to n by 2).map(x => (x.toLong, 0)).toSet) } } @@ -161,7 +161,7 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext { val vertexC = vertexA.innerJoin(vertexB) { (vid, old, newVal) => old - newVal } - assert(vertexC.filter(v => v._2 == 0).map(_._1).collect.toSet == (0 to 98 by 2).toSet) + assert(vertexC.filter(v => v._2 == 0).map(_._1).collect().toSet == (0 to 98 by 2).toSet) } } @@ -171,7 +171,7 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext { val verts = vertices(sc, n) val messageTargets = (0 to n) ++ (0 to n by 2) val messages = sc.parallelize(messageTargets.map(x => (x.toLong, 1))) - assert(verts.aggregateUsingIndex[Int](messages, _ + _).collect.toSet === + assert(verts.aggregateUsingIndex[Int](messages, _ + _).collect().toSet === (0 to n).map(x => (x.toLong, if (x % 2 == 0) 2 else 1)).toSet) } } @@ -183,7 +183,7 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext { val edges = EdgeRDD.fromEdges(sc.parallelize(List.empty[Edge[Int]])) val rdd = VertexRDD(verts, edges, 0, (a: Int, b: Int) => a + b) // test merge function - assert(rdd.collect.toSet == Set((0L, 0), (1L, 3), (2L, 9))) + assert(rdd.collect().toSet == Set((0L, 0), (1L, 3), (2L, 9))) } } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala index 3915be15b3434..4cc30a96408f8 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala @@ -32,7 +32,7 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext { withSpark { sc => val gridGraph = GraphGenerators.gridGraph(sc, 10, 10) val ccGraph = gridGraph.connectedComponents() - val maxCCid = ccGraph.vertices.map { case (vid, ccId) => ccId }.sum + val maxCCid = ccGraph.vertices.map { case (vid, ccId) => ccId }.sum() assert(maxCCid === 0) } } // end of Grid connected components @@ -42,7 +42,7 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext { withSpark { sc => val gridGraph = GraphGenerators.gridGraph(sc, 10, 10).reverse val ccGraph = gridGraph.connectedComponents() - val maxCCid = ccGraph.vertices.map { case (vid, ccId) => ccId }.sum + val maxCCid = ccGraph.vertices.map { case (vid, ccId) => ccId }.sum() assert(maxCCid === 0) } } // end of Grid connected components @@ -50,8 +50,8 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext { test("Chain Connected Components") { withSpark { sc => - val chain1 = (0 until 9).map(x => (x, x+1) ) - val chain2 = (10 until 20).map(x => (x, x+1) ) + val chain1 = (0 until 9).map(x => (x, x + 1)) + val chain2 = (10 until 20).map(x => (x, x + 1)) val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s,d) => (s.toLong, d.toLong) } val twoChains = Graph.fromEdgeTuples(rawEdges, 1.0) val ccGraph = twoChains.connectedComponents() @@ -73,12 +73,12 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext { test("Reverse Chain Connected Components") { withSpark { sc => - val chain1 = (0 until 9).map(x => (x, x+1) ) - val chain2 = (10 until 20).map(x => (x, x+1) ) + val chain1 = (0 until 9).map(x => (x, x + 1)) + val chain2 = (10 until 20).map(x => (x, x + 1)) val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s,d) => (s.toLong, d.toLong) } val twoChains = Graph.fromEdgeTuples(rawEdges, true).reverse val ccGraph = twoChains.connectedComponents() - val vertices = ccGraph.vertices.collect + val vertices = ccGraph.vertices.collect() for ( (id, cc) <- vertices ) { if (id < 10) { assert(cc === 0) @@ -120,9 +120,9 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext { // Build the initial Graph val graph = Graph(users, relationships, defaultUser) val ccGraph = graph.connectedComponents() - val vertices = ccGraph.vertices.collect + val vertices = ccGraph.vertices.collect() for ( (id, cc) <- vertices ) { - assert(cc == 0) + assert(cc === 0) } } } // end of toy connected components diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala index fc491ae327c2a..95804b07b1db0 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala @@ -19,15 +19,12 @@ package org.apache.spark.graphx.lib import org.scalatest.FunSuite -import org.apache.spark.SparkContext -import org.apache.spark.SparkContext._ import org.apache.spark.graphx._ -import org.apache.spark.graphx.lib._ import org.apache.spark.graphx.util.GraphGenerators -import org.apache.spark.rdd._ + object GridPageRank { - def apply(nRows: Int, nCols: Int, nIter: Int, resetProb: Double) = { + def apply(nRows: Int, nCols: Int, nIter: Int, resetProb: Double): Seq[(VertexId, Double)] = { val inNbrs = Array.fill(nRows * nCols)(collection.mutable.MutableList.empty[Int]) val outDegree = Array.fill(nRows * nCols)(0) // Convert row column address into vertex ids (row major order) @@ -35,13 +32,13 @@ object GridPageRank { // Make the grid graph for (r <- 0 until nRows; c <- 0 until nCols) { val ind = sub2ind(r,c) - if (r+1 < nRows) { + if (r + 1 < nRows) { outDegree(ind) += 1 - inNbrs(sub2ind(r+1,c)) += ind + inNbrs(sub2ind(r + 1,c)) += ind } - if (c+1 < nCols) { + if (c + 1 < nCols) { outDegree(ind) += 1 - inNbrs(sub2ind(r,c+1)) += ind + inNbrs(sub2ind(r,c + 1)) += ind } } // compute the pagerank @@ -64,7 +61,7 @@ class PageRankSuite extends FunSuite with LocalSparkContext { def compareRanks(a: VertexRDD[Double], b: VertexRDD[Double]): Double = { a.leftJoin(b) { case (id, a, bOpt) => (a - bOpt.getOrElse(0.0)) * (a - bOpt.getOrElse(0.0)) } - .map { case (id, error) => error }.sum + .map { case (id, error) => error }.sum() } test("Star PageRank") { @@ -80,12 +77,12 @@ class PageRankSuite extends FunSuite with LocalSparkContext { // Static PageRank should only take 2 iterations to converge val notMatching = staticRanks1.innerZipJoin(staticRanks2) { (vid, pr1, pr2) => if (pr1 != pr2) 1 else 0 - }.map { case (vid, test) => test }.sum + }.map { case (vid, test) => test }.sum() assert(notMatching === 0) val staticErrors = staticRanks2.map { case (vid, pr) => - val correct = (vid > 0 && pr == resetProb) || - (vid == 0 && math.abs(pr - (resetProb + (1.0 - resetProb) * (resetProb * (nVertices - 1)) )) < 1.0E-5) + val p = math.abs(pr - (resetProb + (1.0 - resetProb) * (resetProb * (nVertices - 1)) )) + val correct = (vid > 0 && pr == resetProb) || (vid == 0L && p < 1.0E-5) if (!correct) 1 else 0 } assert(staticErrors.sum === 0) @@ -95,8 +92,6 @@ class PageRankSuite extends FunSuite with LocalSparkContext { } } // end of test Star PageRank - - test("Grid PageRank") { withSpark { sc => val rows = 10 @@ -109,18 +104,18 @@ class PageRankSuite extends FunSuite with LocalSparkContext { val staticRanks = gridGraph.staticPageRank(numIter, resetProb).vertices.cache() val dynamicRanks = gridGraph.pageRank(tol, resetProb).vertices.cache() - val referenceRanks = VertexRDD(sc.parallelize(GridPageRank(rows, cols, numIter, resetProb))).cache() + val referenceRanks = VertexRDD( + sc.parallelize(GridPageRank(rows, cols, numIter, resetProb))).cache() assert(compareRanks(staticRanks, referenceRanks) < errorTol) assert(compareRanks(dynamicRanks, referenceRanks) < errorTol) } } // end of Grid PageRank - test("Chain PageRank") { withSpark { sc => - val chain1 = (0 until 9).map(x => (x, x+1) ) - val rawEdges = sc.parallelize(chain1, 1).map { case (s,d) => (s.toLong, d.toLong) } + val chain1 = (0 until 9).map(x => (x, x + 1)) + val rawEdges = sc.parallelize(chain1, 1).map { case (s, d) => (s.toLong, d.toLong) } val chain = Graph.fromEdgeTuples(rawEdges, 1.0).cache() val resetProb = 0.15 val tol = 0.0001 diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala index df54aa37cad68..1f658c371ffcf 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala @@ -34,8 +34,8 @@ class StronglyConnectedComponentsSuite extends FunSuite with LocalSparkContext { val edges = sc.parallelize(Seq.empty[Edge[Int]]) val graph = Graph(vertices, edges) val sccGraph = graph.stronglyConnectedComponents(5) - for ((id, scc) <- sccGraph.vertices.collect) { - assert(id == scc) + for ((id, scc) <- sccGraph.vertices.collect()) { + assert(id === scc) } } } @@ -45,8 +45,8 @@ class StronglyConnectedComponentsSuite extends FunSuite with LocalSparkContext { val rawEdges = sc.parallelize((0L to 6L).map(x => (x, (x + 1) % 7))) val graph = Graph.fromEdgeTuples(rawEdges, -1) val sccGraph = graph.stronglyConnectedComponents(20) - for ((id, scc) <- sccGraph.vertices.collect) { - assert(0L == scc) + for ((id, scc) <- sccGraph.vertices.collect()) { + assert(0L === scc) } } } @@ -60,13 +60,14 @@ class StronglyConnectedComponentsSuite extends FunSuite with LocalSparkContext { val rawEdges = sc.parallelize(edges) val graph = Graph.fromEdgeTuples(rawEdges, -1) val sccGraph = graph.stronglyConnectedComponents(20) - for ((id, scc) <- sccGraph.vertices.collect) { - if (id < 3) - assert(0L == scc) - else if (id < 6) - assert(3L == scc) - else - assert(id == scc) + for ((id, scc) <- sccGraph.vertices.collect()) { + if (id < 3) { + assert(0L === scc) + } else if (id < 6) { + assert(3L === scc) + } else { + assert(id === scc) + } } } } diff --git a/launcher/pom.xml b/launcher/pom.xml index 0fe2814135d88..182e5f60218db 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -52,11 +52,6 @@ mockito-all test - - org.scalatest - scalatest_${scala.binary.version} - test - org.slf4j slf4j-api diff --git a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java index 9b04732afee14..f4ebc25bdd32b 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java +++ b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java @@ -274,14 +274,14 @@ static String quoteForBatchScript(String arg) { } /** - * Quotes a string so that it can be used in a command string and be parsed back into a single - * argument by python's "shlex.split()" function. - * + * Quotes a string so that it can be used in a command string. * Basically, just add simple escapes. E.g.: * original single argument : ab "cd" ef * after: "ab \"cd\" ef" + * + * This can be parsed back into a single argument by python's "shlex.split()" function. */ - static String quoteForPython(String s) { + static String quoteForCommandString(String s) { StringBuilder quoted = new StringBuilder().append('"'); for (int i = 0; i < s.length(); i++) { int cp = s.codePointAt(i); diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java index 91dcf70f105db..a73c9c87e3126 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java @@ -17,14 +17,9 @@ package org.apache.spark.launcher; +import java.io.File; import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Properties; +import java.util.*; import static org.apache.spark.launcher.CommandBuilderUtils.*; @@ -53,6 +48,20 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { */ static final String PYSPARK_SHELL_RESOURCE = "pyspark-shell"; + /** + * Name of the app resource used to identify the SparkR shell. The command line parser expects + * the resource name to be the very first argument to spark-submit in this case. + * + * NOTE: this cannot be "sparkr-shell" since that identifies the SparkR shell to SparkSubmit + * (see sparkR.R), and can cause this code to enter into an infinite loop. + */ + static final String SPARKR_SHELL = "sparkr-shell-main"; + + /** + * This is the actual resource name that identifies the SparkR shell to SparkSubmit. + */ + static final String SPARKR_SHELL_RESOURCE = "sparkr-shell"; + /** * This map must match the class names for available special classes, since this modifies the way * command line parsing works. This maps the class name to the resource to use when calling @@ -87,6 +96,10 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { this.allowsMixedArguments = true; appResource = PYSPARK_SHELL_RESOURCE; submitArgs = args.subList(1, args.size()); + } else if (args.size() > 0 && args.get(0).equals(SPARKR_SHELL)) { + this.allowsMixedArguments = true; + appResource = SPARKR_SHELL_RESOURCE; + submitArgs = args.subList(1, args.size()); } else { this.allowsMixedArguments = false; } @@ -98,6 +111,8 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { public List buildCommand(Map env) throws IOException { if (PYSPARK_SHELL_RESOURCE.equals(appResource)) { return buildPySparkShellCommand(env); + } else if (SPARKR_SHELL_RESOURCE.equals(appResource)) { + return buildSparkRCommand(env); } else { return buildSparkSubmitCommand(env); } @@ -213,36 +228,62 @@ private List buildPySparkShellCommand(Map env) throws IO return buildCommand(env); } - // When launching the pyspark shell, the spark-submit arguments should be stored in the - // PYSPARK_SUBMIT_ARGS env variable. The executable is the PYSPARK_DRIVER_PYTHON env variable - // set by the pyspark script, followed by PYSPARK_DRIVER_PYTHON_OPTS. checkArgument(appArgs.isEmpty(), "pyspark does not support any application options."); + // When launching the pyspark shell, the spark-submit arguments should be stored in the + // PYSPARK_SUBMIT_ARGS env variable. + constructEnvVarArgs(env, "PYSPARK_SUBMIT_ARGS"); + + // The executable is the PYSPARK_DRIVER_PYTHON env variable set by the pyspark script, + // followed by PYSPARK_DRIVER_PYTHON_OPTS. + List pyargs = new ArrayList(); + pyargs.add(firstNonEmpty(System.getenv("PYSPARK_DRIVER_PYTHON"), "python")); + String pyOpts = System.getenv("PYSPARK_DRIVER_PYTHON_OPTS"); + if (!isEmpty(pyOpts)) { + pyargs.addAll(parseOptionString(pyOpts)); + } + + return pyargs; + } + + private List buildSparkRCommand(Map env) throws IOException { + if (!appArgs.isEmpty() && appArgs.get(0).endsWith(".R")) { + appResource = appArgs.get(0); + appArgs.remove(0); + return buildCommand(env); + } + // When launching the SparkR shell, store the spark-submit arguments in the SPARKR_SUBMIT_ARGS + // env variable. + constructEnvVarArgs(env, "SPARKR_SUBMIT_ARGS"); + + // Set shell.R as R_PROFILE_USER to load the SparkR package when the shell comes up. + String sparkHome = System.getenv("SPARK_HOME"); + env.put("R_PROFILE_USER", + join(File.separator, sparkHome, "R", "lib", "SparkR", "profile", "shell.R")); + + List args = new ArrayList(); + args.add(firstNonEmpty(System.getenv("SPARKR_DRIVER_R"), "R")); + return args; + } + + private void constructEnvVarArgs( + Map env, + String submitArgsEnvVariable) throws IOException { Properties props = loadPropertiesFile(); mergeEnvPathList(env, getLibPathEnvName(), firstNonEmptyValue(SparkLauncher.DRIVER_EXTRA_LIBRARY_PATH, conf, props)); - // Store spark-submit arguments in an environment variable, since there's no way to pass - // them to shell.py on the comand line. StringBuilder submitArgs = new StringBuilder(); for (String arg : buildSparkSubmitArgs()) { if (submitArgs.length() > 0) { submitArgs.append(" "); } - submitArgs.append(quoteForPython(arg)); + submitArgs.append(quoteForCommandString(arg)); } - env.put("PYSPARK_SUBMIT_ARGS", submitArgs.toString()); - - List pyargs = new ArrayList(); - pyargs.add(firstNonEmpty(System.getenv("PYSPARK_DRIVER_PYTHON"), "python")); - String pyOpts = System.getenv("PYSPARK_DRIVER_PYTHON_OPTS"); - if (!isEmpty(pyOpts)) { - pyargs.addAll(parseOptionString(pyOpts)); - } - - return pyargs; + env.put(submitArgsEnvVariable, submitArgs.toString()); } + private boolean isClientMode(Properties userProps) { String userMaster = firstNonEmpty(master, (String) userProps.get(SparkLauncher.SPARK_MASTER)); // Default master is "local[*]", so assume client mode in that case. diff --git a/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java b/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java index dba0203867372..1ae42eed8a3af 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java @@ -79,9 +79,9 @@ public void testWindowsBatchQuoting() { @Test public void testPythonArgQuoting() { - assertEquals("\"abc\"", quoteForPython("abc")); - assertEquals("\"a b c\"", quoteForPython("a b c")); - assertEquals("\"a \\\"b\\\" c\"", quoteForPython("a \"b\" c")); + assertEquals("\"abc\"", quoteForCommandString("abc")); + assertEquals("\"a b c\"", quoteForCommandString("a b c")); + assertEquals("\"a \\\"b\\\" c\"", quoteForCommandString("a \"b\" c")); } private void testOpt(String opts, List expected) { diff --git a/launcher/src/test/resources/log4j.properties b/launcher/src/test/resources/log4j.properties index 00c20ad69cd4d..67a6a98217118 100644 --- a/launcher/src/test/resources/log4j.properties +++ b/launcher/src/test/resources/log4j.properties @@ -27,5 +27,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN -org.eclipse.jetty.LEVEL=WARN +log4j.logger.org.spark-project.jetty=WARN +org.spark-project.jetty.LEVEL=WARN diff --git a/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala b/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala index cd84b05bfb496..a50090671ae48 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala @@ -29,5 +29,5 @@ private[ml] trait Identifiable extends Serializable { * random hex chars. */ private[ml] val uid: String = - this.getClass.getSimpleName + "-" + UUID.randomUUID().toString.take(8) + this.getClass.getSimpleName + "_" + UUID.randomUUID().toString.take(8) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index c4a36103303a2..a455341a1f723 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -47,6 +47,9 @@ abstract class PipelineStage extends Serializable with Logging { /** * Derives the output schema from the input schema and parameters, optionally with logging. + * + * This should be optimistic. If it is unclear whether the schema will be valid, then it should + * be assumed valid until proven otherwise. */ protected def transformSchema( schema: StructType, diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala index 970e6ad5514d1..aa27a668f1695 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala @@ -106,7 +106,7 @@ class AttributeGroup private ( def getAttr(attrIndex: Int): Attribute = this(attrIndex) /** Converts to metadata without name. */ - private[attribute] def toMetadata: Metadata = { + private[attribute] def toMetadataImpl: Metadata = { import AttributeKeys._ val bldr = new MetadataBuilder() if (attributes.isDefined) { @@ -142,17 +142,24 @@ class AttributeGroup private ( bldr.build() } - /** Converts to a StructField with some existing metadata. */ - def toStructField(existingMetadata: Metadata): StructField = { - val newMetadata = new MetadataBuilder() + /** Converts to ML metadata with some existing metadata. */ + def toMetadata(existingMetadata: Metadata): Metadata = { + new MetadataBuilder() .withMetadata(existingMetadata) - .putMetadata(AttributeKeys.ML_ATTR, toMetadata) + .putMetadata(AttributeKeys.ML_ATTR, toMetadataImpl) .build() - StructField(name, new VectorUDT, nullable = false, newMetadata) + } + + /** Converts to ML metadata */ + def toMetadata: Metadata = toMetadata(Metadata.empty) + + /** Converts to a StructField with some existing metadata. */ + def toStructField(existingMetadata: Metadata): StructField = { + StructField(name, new VectorUDT, nullable = false, toMetadata(existingMetadata)) } /** Converts to a StructField. */ - def toStructField(): StructField = toStructField(Metadata.empty) + def toStructField: StructField = toStructField(Metadata.empty) override def equals(other: Any): Boolean = { other match { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 21f61d80dd95a..34625745dd0a8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -31,7 +31,7 @@ import org.apache.spark.storage.StorageLevel * Params for logistic regression. */ private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams - with HasRegParam with HasMaxIter with HasThreshold + with HasRegParam with HasMaxIter with HasFitIntercept with HasThreshold /** @@ -55,6 +55,9 @@ class LogisticRegression /** @group setParam */ def setMaxIter(value: Int): this.type = set(maxIter, value) + /** @group setParam */ + def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value) + /** @group setParam */ def setThreshold(value: Double): this.type = set(threshold, value) @@ -67,7 +70,8 @@ class LogisticRegression } // Train model - val lr = new LogisticRegressionWithLBFGS + val lr = new LogisticRegressionWithLBFGS() + .setIntercept(paramMap(fitIntercept)) lr.optimizer .setRegParam(paramMap(regParam)) .setNumIterations(paramMap(maxIter)) @@ -180,7 +184,6 @@ class LogisticRegressionModel private[ml] ( * The behavior of this can be adjusted using [[threshold]]. */ override protected def predict(features: Vector): Double = { - println(s"LR.predict with threshold: ${paramMap(threshold)}") if (score(features) > paramMap(threshold)) 1 else 0 } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala new file mode 100644 index 0000000000000..61e6742e880d8 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkException +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.attribute.NominalAttribute +import org.apache.spark.ml.param._ +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{StringType, StructType} +import org.apache.spark.util.collection.OpenHashMap + +/** + * Base trait for [[StringIndexer]] and [[StringIndexerModel]]. + */ +private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol { + + /** Validates and transforms the input schema. */ + protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = { + val map = this.paramMap ++ paramMap + checkInputColumn(schema, map(inputCol), StringType) + val inputFields = schema.fields + val outputColName = map(outputCol) + require(inputFields.forall(_.name != outputColName), + s"Output column $outputColName already exists.") + val attr = NominalAttribute.defaultAttr.withName(map(outputCol)) + val outputFields = inputFields :+ attr.toStructField() + StructType(outputFields) + } +} + +/** + * :: AlphaComponent :: + * A label indexer that maps a string column of labels to an ML column of label indices. + * The indices are in [0, numLabels), ordered by label frequencies. + * So the most frequent label gets index 0. + */ +@AlphaComponent +class StringIndexer extends Estimator[StringIndexerModel] with StringIndexerBase { + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + // TODO: handle unseen labels + + override def fit(dataset: DataFrame, paramMap: ParamMap): StringIndexerModel = { + val map = this.paramMap ++ paramMap + val counts = dataset.select(map(inputCol)).map(_.getString(0)).countByValue() + val labels = counts.toSeq.sortBy(-_._2).map(_._1).toArray + val model = new StringIndexerModel(this, map, labels) + Params.inheritValues(map, this, model) + model + } + + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + validateAndTransformSchema(schema, paramMap) + } +} + +/** + * :: AlphaComponent :: + * Model fitted by [[StringIndexer]]. + */ +@AlphaComponent +class StringIndexerModel private[ml] ( + override val parent: StringIndexer, + override val fittingParamMap: ParamMap, + labels: Array[String]) extends Model[StringIndexerModel] with StringIndexerBase { + + private val labelToIndex: OpenHashMap[String, Double] = { + val n = labels.length + val map = new OpenHashMap[String, Double](n) + var i = 0 + while (i < n) { + map.update(labels(i), i) + i += 1 + } + map + } + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { + val map = this.paramMap ++ paramMap + val indexer = udf { label: String => + if (labelToIndex.contains(label)) { + labelToIndex(label) + } else { + // TODO: handle unseen labels + throw new SparkException(s"Unseen label: $label.") + } + } + val outputColName = map(outputCol) + val metadata = NominalAttribute.defaultAttr + .withName(outputColName).withValues(labels).toStructField().metadata + dataset.select(col("*"), indexer(dataset(map(inputCol))).as(outputColName, metadata)) + } + + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + validateAndTransformSchema(schema, paramMap) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala new file mode 100644 index 0000000000000..d1b8f7e6e9295 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import scala.collection.mutable.ArrayBuilder + +import org.apache.spark.SparkException +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.Transformer +import org.apache.spark.ml.param.{HasInputCols, HasOutputCol, ParamMap} +import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} +import org.apache.spark.sql.{Column, DataFrame, Row} +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, CreateStruct} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ + +/** + * :: AlphaComponent :: + * A feature transformer than merge multiple columns into a vector column. + */ +@AlphaComponent +class VectorAssembler extends Transformer with HasInputCols with HasOutputCol { + + /** @group setParam */ + def setInputCols(value: Array[String]): this.type = set(inputCols, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { + val map = this.paramMap ++ paramMap + val assembleFunc = udf { r: Row => + VectorAssembler.assemble(r.toSeq: _*) + } + val schema = dataset.schema + val inputColNames = map(inputCols) + val args = inputColNames.map { c => + schema(c).dataType match { + case DoubleType => UnresolvedAttribute(c) + case t if t.isInstanceOf[VectorUDT] => UnresolvedAttribute(c) + case _: NativeType => Alias(Cast(UnresolvedAttribute(c), DoubleType), s"${c}_double_$uid")() + } + } + dataset.select(col("*"), assembleFunc(new Column(CreateStruct(args))).as(map(outputCol))) + } + + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + val map = this.paramMap ++ paramMap + val inputColNames = map(inputCols) + val outputColName = map(outputCol) + val inputDataTypes = inputColNames.map(name => schema(name).dataType) + inputDataTypes.foreach { + case _: NativeType => + case t if t.isInstanceOf[VectorUDT] => + case other => + throw new IllegalArgumentException(s"Data type $other is not supported.") + } + if (schema.fieldNames.contains(outputColName)) { + throw new IllegalArgumentException(s"Output column $outputColName already exists.") + } + StructType(schema.fields :+ new StructField(outputColName, new VectorUDT, false)) + } +} + +@AlphaComponent +object VectorAssembler { + + private[feature] def assemble(vv: Any*): Vector = { + val indices = ArrayBuilder.make[Int] + val values = ArrayBuilder.make[Double] + var cur = 0 + vv.foreach { + case v: Double => + if (v != 0.0) { + indices += cur + values += v + } + cur += 1 + case vec: Vector => + vec.foreachActive { case (i, v) => + if (v != 0.0) { + indices += cur + i + values += v + } + } + cur += vec.size + case null => + // TODO: output Double.NaN? + throw new SparkException("Values to assemble cannot be null.") + case o => + throw new SparkException(s"$o of type ${o.getClass.getName} is not supported.") + } + Vectors.sparse(cur, indices.result(), values.result()) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala new file mode 100644 index 0000000000000..8760960e19272 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -0,0 +1,393 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.attribute.{BinaryAttribute, NumericAttribute, NominalAttribute, + Attribute, AttributeGroup} +import org.apache.spark.ml.param.{HasInputCol, HasOutputCol, IntParam, ParamMap, Params} +import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector, VectorUDT} +import org.apache.spark.sql.{Row, DataFrame} +import org.apache.spark.sql.functions.callUDF +import org.apache.spark.sql.types.{StructField, StructType} +import org.apache.spark.util.collection.OpenHashSet + + +/** Private trait for params for VectorIndexer and VectorIndexerModel */ +private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOutputCol { + + /** + * Threshold for the number of values a categorical feature can take. + * If a feature is found to have > maxCategories values, then it is declared continuous. + * + * (default = 20) + */ + val maxCategories = new IntParam(this, "maxCategories", + "Threshold for the number of values a categorical feature can take." + + " If a feature is found to have > maxCategories values, then it is declared continuous.", + Some(20)) + + /** @group getParam */ + def getMaxCategories: Int = get(maxCategories) +} + +/** + * :: AlphaComponent :: + * + * Class for indexing categorical feature columns in a dataset of [[Vector]]. + * + * This has 2 usage modes: + * - Automatically identify categorical features (default behavior) + * - This helps process a dataset of unknown vectors into a dataset with some continuous + * features and some categorical features. The choice between continuous and categorical + * is based upon a maxCategories parameter. + * - Set maxCategories to the maximum number of categorical any categorical feature should have. + * - E.g.: Feature 0 has unique values {-1.0, 0.0}, and feature 1 values {1.0, 3.0, 5.0}. + * If maxCategories = 2, then feature 0 will be declared categorical and use indices {0, 1}, + * and feature 1 will be declared continuous. + * - Index all features, if all features are categorical + * - If maxCategories is set to be very large, then this will build an index of unique + * values for all features. + * - Warning: This can cause problems if features are continuous since this will collect ALL + * unique values to the driver. + * - E.g.: Feature 0 has unique values {-1.0, 0.0}, and feature 1 values {1.0, 3.0, 5.0}. + * If maxCategories >= 3, then both features will be declared categorical. + * + * This returns a model which can transform categorical features to use 0-based indices. + * + * Index stability: + * - This is not guaranteed to choose the same category index across multiple runs. + * - If a categorical feature includes value 0, then this is guaranteed to map value 0 to index 0. + * This maintains vector sparsity. + * - More stability may be added in the future. + * + * TODO: Future extensions: The following functionality is planned for the future: + * - Preserve metadata in transform; if a feature's metadata is already present, do not recompute. + * - Specify certain features to not index, either via a parameter or via existing metadata. + * - Add warning if a categorical feature has only 1 category. + * - Add option for allowing unknown categories. + */ +@AlphaComponent +class VectorIndexer extends Estimator[VectorIndexerModel] with VectorIndexerParams { + + /** @group setParam */ + def setMaxCategories(value: Int): this.type = { + require(value > 1, + s"DatasetIndexer given maxCategories = value, but requires maxCategories > 1.") + set(maxCategories, value) + } + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + override def fit(dataset: DataFrame, paramMap: ParamMap): VectorIndexerModel = { + transformSchema(dataset.schema, paramMap, logging = true) + val map = this.paramMap ++ paramMap + val firstRow = dataset.select(map(inputCol)).take(1) + require(firstRow.length == 1, s"VectorIndexer cannot be fit on an empty dataset.") + val numFeatures = firstRow(0).getAs[Vector](0).size + val vectorDataset = dataset.select(map(inputCol)).map { case Row(v: Vector) => v } + val maxCats = map(maxCategories) + val categoryStats: VectorIndexer.CategoryStats = vectorDataset.mapPartitions { iter => + val localCatStats = new VectorIndexer.CategoryStats(numFeatures, maxCats) + iter.foreach(localCatStats.addVector) + Iterator(localCatStats) + }.reduce((stats1, stats2) => stats1.merge(stats2)) + val model = new VectorIndexerModel(this, map, numFeatures, categoryStats.getCategoryMaps) + Params.inheritValues(map, this, model) + model + } + + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + // We do not transfer feature metadata since we do not know what types of features we will + // produce in transform(). + val map = this.paramMap ++ paramMap + val dataType = new VectorUDT + require(map.contains(inputCol), s"VectorIndexer requires input column parameter: $inputCol") + require(map.contains(outputCol), s"VectorIndexer requires output column parameter: $outputCol") + checkInputColumn(schema, map(inputCol), dataType) + addOutputColumn(schema, map(outputCol), dataType) + } +} + +private object VectorIndexer { + + /** + * Helper class for tracking unique values for each feature. + * + * TODO: Track which features are known to be continuous already; do not update counts for them. + * + * @param numFeatures This class fails if it encounters a Vector whose length is not numFeatures. + * @param maxCategories This class caps the number of unique values collected at maxCategories. + */ + class CategoryStats(private val numFeatures: Int, private val maxCategories: Int) + extends Serializable { + + /** featureValueSets[feature index] = set of unique values */ + private val featureValueSets = + Array.fill[OpenHashSet[Double]](numFeatures)(new OpenHashSet[Double]()) + + /** Merge with another instance, modifying this instance. */ + def merge(other: CategoryStats): CategoryStats = { + featureValueSets.zip(other.featureValueSets).foreach { case (thisValSet, otherValSet) => + otherValSet.iterator.foreach { x => + // Once we have found > maxCategories values, we know the feature is continuous + // and do not need to collect more values for it. + if (thisValSet.size <= maxCategories) thisValSet.add(x) + } + } + this + } + + /** Add a new vector to this index, updating sets of unique feature values */ + def addVector(v: Vector): Unit = { + require(v.size == numFeatures, s"VectorIndexer expected $numFeatures features but" + + s" found vector of size ${v.size}.") + v match { + case dv: DenseVector => addDenseVector(dv) + case sv: SparseVector => addSparseVector(sv) + } + } + + /** + * Based on stats collected, decide which features are categorical, + * and choose indices for categories. + * + * Sparsity: This tries to maintain sparsity by treating value 0.0 specially. + * If a categorical feature takes value 0.0, then value 0.0 is given index 0. + * + * @return Feature value index. Keys are categorical feature indices (column indices). + * Values are mappings from original features values to 0-based category indices. + */ + def getCategoryMaps: Map[Int, Map[Double, Int]] = { + // Filter out features which are declared continuous. + featureValueSets.zipWithIndex.filter(_._1.size <= maxCategories).map { + case (featureValues: OpenHashSet[Double], featureIndex: Int) => + var sortedFeatureValues = featureValues.iterator.filter(_ != 0.0).toArray.sorted + val zeroExists = sortedFeatureValues.length + 1 == featureValues.size + if (zeroExists) { + sortedFeatureValues = 0.0 +: sortedFeatureValues + } + val categoryMap: Map[Double, Int] = sortedFeatureValues.zipWithIndex.toMap + (featureIndex, categoryMap) + }.toMap + } + + private def addDenseVector(dv: DenseVector): Unit = { + var i = 0 + while (i < dv.size) { + if (featureValueSets(i).size <= maxCategories) { + featureValueSets(i).add(dv(i)) + } + i += 1 + } + } + + private def addSparseVector(sv: SparseVector): Unit = { + // TODO: This might be able to handle 0's more efficiently. + var vecIndex = 0 // index into vector + var k = 0 // index into non-zero elements + while (vecIndex < sv.size) { + val featureValue = if (k < sv.indices.length && vecIndex == sv.indices(k)) { + k += 1 + sv.values(k - 1) + } else { + 0.0 + } + if (featureValueSets(vecIndex).size <= maxCategories) { + featureValueSets(vecIndex).add(featureValue) + } + vecIndex += 1 + } + } + } +} + +/** + * :: AlphaComponent :: + * + * Transform categorical features to use 0-based indices instead of their original values. + * - Categorical features are mapped to indices. + * - Continuous features (columns) are left unchanged. + * This also appends metadata to the output column, marking features as Numeric (continuous), + * Nominal (categorical), or Binary (either continuous or categorical). + * + * This maintains vector sparsity. + * + * @param numFeatures Number of features, i.e., length of Vectors which this transforms + * @param categoryMaps Feature value index. Keys are categorical feature indices (column indices). + * Values are maps from original features values to 0-based category indices. + * If a feature is not in this map, it is treated as continuous. + */ +@AlphaComponent +class VectorIndexerModel private[ml] ( + override val parent: VectorIndexer, + override val fittingParamMap: ParamMap, + val numFeatures: Int, + val categoryMaps: Map[Int, Map[Double, Int]]) + extends Model[VectorIndexerModel] with VectorIndexerParams { + + /** + * Pre-computed feature attributes, with some missing info. + * In transform(), set attribute name and other info, if available. + */ + private val partialFeatureAttributes: Array[Attribute] = { + val attrs = new Array[Attribute](numFeatures) + var categoricalFeatureCount = 0 // validity check for numFeatures, categoryMaps + var featureIndex = 0 + while (featureIndex < numFeatures) { + if (categoryMaps.contains(featureIndex)) { + // categorical feature + val featureValues: Array[String] = + categoryMaps(featureIndex).toArray.sortBy(_._1).map(_._1).map(_.toString) + if (featureValues.length == 2) { + attrs(featureIndex) = new BinaryAttribute(index = Some(featureIndex), + values = Some(featureValues)) + } else { + attrs(featureIndex) = new NominalAttribute(index = Some(featureIndex), + isOrdinal = Some(false), values = Some(featureValues)) + } + categoricalFeatureCount += 1 + } else { + // continuous feature + attrs(featureIndex) = new NumericAttribute(index = Some(featureIndex)) + } + featureIndex += 1 + } + require(categoricalFeatureCount == categoryMaps.size, "VectorIndexerModel given categoryMaps" + + s" with keys outside expected range [0,...,numFeatures), where numFeatures=$numFeatures") + attrs + } + + // TODO: Check more carefully about whether this whole class will be included in a closure. + + private val transformFunc: Vector => Vector = { + val sortedCategoricalFeatureIndices = categoryMaps.keys.toArray.sorted + val localVectorMap = categoryMaps + val f: Vector => Vector = { + case dv: DenseVector => + val tmpv = dv.copy + localVectorMap.foreach { case (featureIndex: Int, categoryMap: Map[Double, Int]) => + tmpv.values(featureIndex) = categoryMap(tmpv(featureIndex)) + } + tmpv + case sv: SparseVector => + // We use the fact that categorical value 0 is always mapped to index 0. + val tmpv = sv.copy + var catFeatureIdx = 0 // index into sortedCategoricalFeatureIndices + var k = 0 // index into non-zero elements of sparse vector + while (catFeatureIdx < sortedCategoricalFeatureIndices.length && k < tmpv.indices.length) { + val featureIndex = sortedCategoricalFeatureIndices(catFeatureIdx) + if (featureIndex < tmpv.indices(k)) { + catFeatureIdx += 1 + } else if (featureIndex > tmpv.indices(k)) { + k += 1 + } else { + tmpv.values(k) = localVectorMap(featureIndex)(tmpv.values(k)) + catFeatureIdx += 1 + k += 1 + } + } + tmpv + } + f + } + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { + transformSchema(dataset.schema, paramMap, logging = true) + val map = this.paramMap ++ paramMap + val newField = prepOutputField(dataset.schema, map) + val newCol = callUDF(transformFunc, new VectorUDT, dataset(map(inputCol))) + // For now, just check the first row of inputCol for vector length. + val firstRow = dataset.select(map(inputCol)).take(1) + if (firstRow.length != 0) { + val actualNumFeatures = firstRow(0).getAs[Vector](0).size + require(numFeatures == actualNumFeatures, "VectorIndexerModel expected vector of length" + + s" $numFeatures but found length $actualNumFeatures") + } + dataset.withColumn(map(outputCol), newCol.as(map(outputCol), newField.metadata)) + } + + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + val map = this.paramMap ++ paramMap + val dataType = new VectorUDT + require(map.contains(inputCol), + s"VectorIndexerModel requires input column parameter: $inputCol") + require(map.contains(outputCol), + s"VectorIndexerModel requires output column parameter: $outputCol") + checkInputColumn(schema, map(inputCol), dataType) + + val origAttrGroup = AttributeGroup.fromStructField(schema(map(inputCol))) + val origNumFeatures: Option[Int] = if (origAttrGroup.attributes.nonEmpty) { + Some(origAttrGroup.attributes.get.length) + } else { + origAttrGroup.numAttributes + } + require(origNumFeatures.forall(_ == numFeatures), "VectorIndexerModel expected" + + s" $numFeatures features, but input column ${map(inputCol)} had metadata specifying" + + s" ${origAttrGroup.numAttributes.get} features.") + + val newField = prepOutputField(schema, map) + val outputFields = schema.fields :+ newField + StructType(outputFields) + } + + /** + * Prepare the output column field, including per-feature metadata. + * @param schema Input schema + * @param map Parameter map (with this class' embedded parameter map folded in) + * @return Output column field + */ + private def prepOutputField(schema: StructType, map: ParamMap): StructField = { + val origAttrGroup = AttributeGroup.fromStructField(schema(map(inputCol))) + val featureAttributes: Array[Attribute] = if (origAttrGroup.attributes.nonEmpty) { + // Convert original attributes to modified attributes + val origAttrs: Array[Attribute] = origAttrGroup.attributes.get + origAttrs.zip(partialFeatureAttributes).map { + case (origAttr: Attribute, featAttr: BinaryAttribute) => + if (origAttr.name.nonEmpty) { + featAttr.withName(origAttr.name.get) + } else { + featAttr + } + case (origAttr: Attribute, featAttr: NominalAttribute) => + if (origAttr.name.nonEmpty) { + featAttr.withName(origAttr.name.get) + } else { + featAttr + } + case (origAttr: Attribute, featAttr: NumericAttribute) => + origAttr.withIndex(featAttr.index.get) + } + } else { + partialFeatureAttributes + } + val newAttributeGroup = new AttributeGroup(map(outputCol), featureAttributes) + newAttributeGroup.toStructField(schema(map(inputCol)).metadata) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 17ece897a6c55..7d5178d0abb2d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -198,23 +198,31 @@ trait Params extends Identifiable with Serializable { /** * Check whether the given schema contains an input column. - * @param colName Parameter name for the input column. - * @param dataType SQL DataType of the input column. + * @param colName Input column name + * @param dataType Input column DataType */ protected def checkInputColumn(schema: StructType, colName: String, dataType: DataType): Unit = { val actualDataType = schema(colName).dataType - require(actualDataType.equals(dataType), - s"Input column $colName must be of type $dataType" + - s" but was actually $actualDataType. Column param description: ${getParam(colName)}") + require(actualDataType.equals(dataType), s"Input column $colName must be of type $dataType" + + s" but was actually $actualDataType. Column param description: ${getParam(colName)}") } + /** + * Add an output column to the given schema. + * This fails if the given output column already exists. + * @param schema Initial schema (not modified) + * @param colName Output column name. If this column name is an empy String "", this method + * returns the initial schema, unchanged. This allows users to disable output + * columns. + * @param dataType Output column DataType + */ protected def addOutputColumn( schema: StructType, colName: String, dataType: DataType): StructType = { if (colName.length == 0) return schema val fieldNames = schema.fieldNames - require(!fieldNames.contains(colName), s"Prediction column $colName already exists.") + require(!fieldNames.contains(colName), s"Output column $colName already exists.") val outputFields = schema.fields ++ Seq(StructField(colName, dataType, nullable = false)) StructType(outputFields) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala index 5d660d1e151a7..07e6eb417763d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala @@ -106,6 +106,18 @@ private[ml] trait HasProbabilityCol extends Params { def getProbabilityCol: String = get(probabilityCol) } +private[ml] trait HasFitIntercept extends Params { + /** + * param for fitting the intercept term, defaults to true + * @group param + */ + val fitIntercept: BooleanParam = + new BooleanParam(this, "fitIntercept", "indicates whether to fit an intercept term", Some(true)) + + /** @group getParam */ + def getFitIntercept: Boolean = get(fitIntercept) +} + private[ml] trait HasThreshold extends Params { /** * param for threshold in (binary) prediction @@ -128,6 +140,16 @@ private[ml] trait HasInputCol extends Params { def getInputCol: String = get(inputCol) } +private[ml] trait HasInputCols extends Params { + /** + * Param for input column names. + */ + val inputCols: Param[Array[String]] = new Param(this, "inputCols", "input column names") + + /** @group getParam */ + def getInputCols: Array[String] = get(inputCols) +} + private[ml] trait HasOutputCol extends Params { /** * param for output column name diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 514b4ef98dc5b..52c9e95d6012f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -320,7 +320,7 @@ object ALS extends Logging { /** Trait for least squares solvers applied to the normal equation. */ private[recommendation] trait LeastSquaresNESolver extends Serializable { - /** Solves a least squares problem (possibly with other constraints). */ + /** Solves a least squares problem with regularization (possibly with other constraints). */ def solve(ne: NormalEquation, lambda: Double): Array[Float] } @@ -332,20 +332,19 @@ object ALS extends Logging { /** * Solves a least squares problem with L2 regularization: * - * min norm(A x - b)^2^ + lambda * n * norm(x)^2^ + * min norm(A x - b)^2^ + lambda * norm(x)^2^ * * @param ne a [[NormalEquation]] instance that contains AtA, Atb, and n (number of instances) - * @param lambda regularization constant, which will be scaled by n + * @param lambda regularization constant * @return the solution x */ override def solve(ne: NormalEquation, lambda: Double): Array[Float] = { val k = ne.k // Add scaled lambda to the diagonals of AtA. - val scaledlambda = lambda * ne.n var i = 0 var j = 2 while (i < ne.triK) { - ne.ata(i) += scaledlambda + ne.ata(i) += lambda i += j j += 1 } @@ -391,7 +390,7 @@ object ALS extends Logging { override def solve(ne: NormalEquation, lambda: Double): Array[Float] = { val rank = ne.k initialize(rank) - fillAtA(ne.ata, lambda * ne.n) + fillAtA(ne.ata, lambda) val x = NNLS.solve(ata, ne.atb, workspace) ne.reset() x.map(x => x.toFloat) @@ -420,7 +419,15 @@ object ALS extends Logging { } } - /** Representing a normal equation (ALS' subproblem). */ + /** + * Representing a normal equation to solve the following weighted least squares problem: + * + * minimize \sum,,i,, c,,i,, (a,,i,,^T^ x - b,,i,,)^2^ + lambda * x^T^ x. + * + * Its normal equation is given by + * + * \sum,,i,, c,,i,, (a,,i,, a,,i,,^T^ x - b,,i,, a,,i,,) + lambda * x = 0. + */ private[recommendation] class NormalEquation(val k: Int) extends Serializable { /** Number of entries in the upper triangular part of a k-by-k matrix. */ @@ -429,8 +436,6 @@ object ALS extends Logging { val ata = new Array[Double](triK) /** A^T^ * b */ val atb = new Array[Double](k) - /** Number of observations. */ - var n = 0 private val da = new Array[Double](k) private val upper = "U" @@ -444,28 +449,13 @@ object ALS extends Logging { } /** Adds an observation. */ - def add(a: Array[Float], b: Float): this.type = { - require(a.length == k) - copyToDouble(a) - blas.dspr(upper, k, 1.0, da, 1, ata) - blas.daxpy(k, b.toDouble, da, 1, atb, 1) - n += 1 - this - } - - /** - * Adds an observation with implicit feedback. Note that this does not increment the counter. - */ - def addImplicit(a: Array[Float], b: Float, alpha: Double): this.type = { + def add(a: Array[Float], b: Double, c: Double = 1.0): this.type = { + require(c >= 0.0) require(a.length == k) - // Extension to the original paper to handle b < 0. confidence is a function of |b| instead - // so that it is never negative. - val confidence = 1.0 + alpha * math.abs(b) copyToDouble(a) - blas.dspr(upper, k, confidence - 1.0, da, 1, ata) - // For b <= 0, the corresponding preference is 0. So the term below is only added for b > 0. - if (b > 0) { - blas.daxpy(k, confidence, da, 1, atb, 1) + blas.dspr(upper, k, c, da, 1, ata) + if (b != 0.0) { + blas.daxpy(k, c * b, da, 1, atb, 1) } this } @@ -475,7 +465,6 @@ object ALS extends Logging { require(other.k == k) blas.daxpy(ata.length, 1.0, other.ata, 1, ata, 1) blas.daxpy(atb.length, 1.0, other.atb, 1, atb, 1) - n += other.n this } @@ -483,7 +472,6 @@ object ALS extends Logging { def reset(): Unit = { ju.Arrays.fill(ata, 0.0) ju.Arrays.fill(atb, 0.0) - n = 0 } } @@ -1114,6 +1102,7 @@ object ALS extends Logging { ls.merge(YtY.get) } var i = srcPtrs(j) + var numExplicits = 0 while (i < srcPtrs(j + 1)) { val encoded = srcEncodedIndices(i) val blockId = srcEncoder.blockId(encoded) @@ -1121,13 +1110,23 @@ object ALS extends Logging { val srcFactor = sortedSrcFactors(blockId)(localIndex) val rating = ratings(i) if (implicitPrefs) { - ls.addImplicit(srcFactor, rating, alpha) + // Extension to the original paper to handle b < 0. confidence is a function of |b| + // instead so that it is never negative. c1 is confidence - 1.0. + val c1 = alpha * math.abs(rating) + // For rating <= 0, the corresponding preference is 0. So the term below is only added + // for rating > 0. Because YtY is already added, we need to adjust the scaling here. + if (rating > 0) { + numExplicits += 1 + ls.add(srcFactor, (c1 + 1.0) / c1, c1) + } } else { ls.add(srcFactor, rating) + numExplicits += 1 } i += 1 } - dstFactors(j) = solver.solve(ls, regParam) + // Weight lambda by the number of explicit ratings based on the ALS-WR paper. + dstFactors(j) = solver.solve(ls, numExplicits * regParam) j += 1 } dstFactors @@ -1141,7 +1140,7 @@ object ALS extends Logging { private def computeYtY(factorBlocks: RDD[(Int, FactorBlock)], rank: Int): NormalEquation = { factorBlocks.values.aggregate(new NormalEquation(rank))( seqOp = (ne, factors) => { - factors.foreach(ne.add(_, 0.0f)) + factors.foreach(ne.add(_, 0.0)) ne }, combOp = (ne1, ne2) => ne1.merge(ne2)) diff --git a/core/src/main/scala/org/apache/spark/TaskContextHelper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/FPGrowthModelWrapper.scala similarity index 62% rename from core/src/main/scala/org/apache/spark/TaskContextHelper.scala rename to mllib/src/main/scala/org/apache/spark/mllib/api/python/FPGrowthModelWrapper.scala index 4636c4600a01a..ee933f4cfcafd 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextHelper.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/FPGrowthModelWrapper.scala @@ -15,15 +15,19 @@ * limitations under the License. */ -package org.apache.spark +package org.apache.spark.mllib.api.python + +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.fpm.{FPGrowth, FPGrowthModel} +import org.apache.spark.rdd.RDD /** - * This class exists to restrict the visibility of TaskContext setters. + * A Wrapper of FPGrowthModel to provide helper method for Python */ -private [spark] object TaskContextHelper { - - def setTaskContext(tc: TaskContext): Unit = TaskContext.setTaskContext(tc) +private[python] class FPGrowthModelWrapper(model: FPGrowthModel[Any]) + extends FPGrowthModel(model.freqItemsets) { - def unset(): Unit = TaskContext.unset() - + def getFreqItemsets: RDD[Array[Any]] = { + SerDe.fromTuple2RDD(model.freqItemsets.map(x => (x.javaItems, x.freq))) + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 662ec5fbed453..ab15f0f36a14b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -34,6 +34,7 @@ import org.apache.spark.api.python.SerDeUtil import org.apache.spark.mllib.classification._ import org.apache.spark.mllib.clustering._ import org.apache.spark.mllib.feature._ +import org.apache.spark.mllib.fpm.{FPGrowth, FPGrowthModel} import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.random.{RandomRDDs => RG} @@ -358,9 +359,7 @@ private[python] class PythonMLLibAPI extends Serializable { val model = new GaussianMixtureModel(weight, gaussians) model.predictSoft(data) } - - - + /** * Java stub for Python mllib ALS.train(). This stub returns a handle * to the Java object instead of the content of the Java object. Extra care @@ -420,6 +419,24 @@ private[python] class PythonMLLibAPI extends Serializable { new MatrixFactorizationModelWrapper(model) } + /** + * Java stub for Python mllib FPGrowth.train(). This stub returns a handle + * to the Java object instead of the content of the Java object. Extra care + * needs to be taken in the Python code to ensure it gets freed on exit; see + * the Py4J documentation. + */ + def trainFPGrowthModel( + data: JavaRDD[java.lang.Iterable[Any]], + minSupport: Double, + numPartitions: Int): FPGrowthModel[Any] = { + val fpg = new FPGrowth() + .setMinSupport(minSupport) + .setNumPartitions(numPartitions) + + val model = fpg.run(data.rdd.map(_.asScala.toArray)) + new FPGrowthModelWrapper(model) + } + /** * Java stub for Normalizer.transform() */ @@ -433,9 +450,9 @@ private[python] class PythonMLLibAPI extends Serializable { def normalizeVector(p: Double, rdd: JavaRDD[Vector]): JavaRDD[Vector] = { new Normalizer(p).transform(rdd) } - + /** - * Java stub for IDF.fit(). This stub returns a + * Java stub for StandardScaler.fit(). This stub returns a * handle to the Java object instead of the content of the Java object. * Extra care needs to be taken in the Python code to ensure it gets freed on * exit; see the Py4J documentation. @@ -476,13 +493,15 @@ private[python] class PythonMLLibAPI extends Serializable { learningRate: Double, numPartitions: Int, numIterations: Int, - seed: Long): Word2VecModelWrapper = { + seed: Long, + minCount: Int): Word2VecModelWrapper = { val word2vec = new Word2Vec() .setVectorSize(vectorSize) .setLearningRate(learningRate) .setNumPartitions(numPartitions) .setNumIterations(numIterations) .setSeed(seed) + .setMinCount(minCount) try { val model = word2vec.fit(dataJRDD.rdd.persist(StorageLevel.MEMORY_AND_DISK_SER)) new Word2VecModelWrapper(model) @@ -516,6 +535,10 @@ private[python] class PythonMLLibAPI extends Serializable { val words = result.map(_._1) List(words, similarity).map(_.asInstanceOf[Object]).asJava } + + def getVectors: JMap[String, JList[Float]] = { + model.getVectors.map({case (k, v) => (k, v.toList.asJava)}).asJava + } } /** @@ -1113,7 +1136,10 @@ private[spark] object SerDe extends Serializable { iter.flatMap { row => val obj = unpickle.loads(row) if (batched) { - obj.asInstanceOf[JArrayList[_]].asScala + obj match { + case list: JArrayList[_] => list.asScala + case arr: Array[_] => arr + } } else { Seq(obj) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index e7c3599ff619c..057b628c6a586 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -62,6 +62,15 @@ class LogisticRegressionModel ( s" but was given weights of length ${weights.size}") } + private val dataWithBiasSize: Int = weights.size / (numClasses - 1) + + private val weightsArray: Array[Double] = weights match { + case dv: DenseVector => dv.values + case _ => + throw new IllegalArgumentException( + s"weights only supports dense vector but got type ${weights.getClass}.") + } + /** * Constructs a [[LogisticRegressionModel]] with weights and intercept for binary classification. */ @@ -74,6 +83,7 @@ class LogisticRegressionModel ( * Sets the threshold that separates positive predictions from negative predictions * in Binary Logistic Regression. An example with prediction score greater than or equal to * this threshold is identified as an positive, and negative otherwise. The default value is 0.5. + * It is only used for binary classification. */ @Experimental def setThreshold(threshold: Double): this.type = { @@ -84,6 +94,7 @@ class LogisticRegressionModel ( /** * :: Experimental :: * Returns the threshold (if any) used for converting raw prediction scores into 0/1 predictions. + * It is only used for binary classification. */ @Experimental def getThreshold: Option[Double] = threshold @@ -91,6 +102,7 @@ class LogisticRegressionModel ( /** * :: Experimental :: * Clears the threshold so that `predict` will output raw prediction scores. + * It is only used for binary classification. */ @Experimental def clearThreshold(): this.type = { @@ -106,7 +118,6 @@ class LogisticRegressionModel ( // If dataMatrix and weightMatrix have the same dimension, it's binary logistic regression. if (numClasses == 2) { - require(numFeatures == weightMatrix.size) val margin = dot(weightMatrix, dataMatrix) + intercept val score = 1.0 / (1.0 + math.exp(-margin)) threshold match { @@ -114,30 +125,9 @@ class LogisticRegressionModel ( case None => score } } else { - val dataWithBiasSize = weightMatrix.size / (numClasses - 1) - - val weightsArray = weightMatrix match { - case dv: DenseVector => dv.values - case _ => - throw new IllegalArgumentException( - s"weights only supports dense vector but got type ${weightMatrix.getClass}.") - } - - val margins = (0 until numClasses - 1).map { i => - var margin = 0.0 - dataMatrix.foreachActive { (index, value) => - if (value != 0.0) margin += value * weightsArray((i * dataWithBiasSize) + index) - } - // Intercept is required to be added into margin. - if (dataMatrix.size + 1 == dataWithBiasSize) { - margin += weightsArray((i * dataWithBiasSize) + dataMatrix.size) - } - margin - } - /** - * Find the one with maximum margins. If the maxMargin is negative, then the prediction - * result will be the first class. + * Compute and find the one with maximum margins. If the maxMargin is negative, then the + * prediction result will be the first class. * * PS, if you want to compute the probabilities for each outcome instead of the outcome * with maximum probability, remember to subtract the maxMargin from margins if maxMargin @@ -145,13 +135,20 @@ class LogisticRegressionModel ( */ var bestClass = 0 var maxMargin = 0.0 - var i = 0 - while(i < margins.size) { - if (margins(i) > maxMargin) { - maxMargin = margins(i) + val withBias = dataMatrix.size + 1 == dataWithBiasSize + (0 until numClasses - 1).foreach { i => + var margin = 0.0 + dataMatrix.foreachActive { (index, value) => + if (value != 0.0) margin += value * weightsArray((i * dataWithBiasSize) + index) + } + // Intercept is required to be added into margin. + if (withBias) { + margin += weightsArray((i * dataWithBiasSize) + dataMatrix.size) + } + if (margin > maxMargin) { + maxMargin = margin bestClass = i + 1 } - i += 1 } bestClass.toDouble } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala index b89f38cf5aba4..7d33df3221fbf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala @@ -63,6 +63,8 @@ class StreamingLogisticRegressionWithSGD private[mllib] ( protected val algorithm = new LogisticRegressionWithSGD( stepSize, numIterations, regParam, miniBatchFraction) + protected var model: Option[LogisticRegressionModel] = None + /** Set the step size for gradient descent. Default: 0.1. */ def setStepSize(stepSize: Double): this.type = { this.algorithm.optimizer.setStepSize(stepSize) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala index 180023922a9b0..aa53e88d59856 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala @@ -17,15 +17,20 @@ package org.apache.spark.mllib.clustering -import org.apache.spark.{Logging, SparkException} +import org.json4s.JsonDSL._ +import org.json4s._ +import org.json4s.jackson.JsonMethods._ + import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.graphx._ import org.apache.spark.graphx.impl.GraphImpl import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.mllib.util.{Loader, MLUtils, Saveable} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.util.random.XORShiftRandom +import org.apache.spark.{Logging, SparkContext, SparkException} /** * :: Experimental :: @@ -38,7 +43,60 @@ import org.apache.spark.util.random.XORShiftRandom @Experimental class PowerIterationClusteringModel( val k: Int, - val assignments: RDD[PowerIterationClustering.Assignment]) extends Serializable + val assignments: RDD[PowerIterationClustering.Assignment]) extends Saveable with Serializable { + + override def save(sc: SparkContext, path: String): Unit = { + PowerIterationClusteringModel.SaveLoadV1_0.save(sc, this, path) + } + + override protected def formatVersion: String = "1.0" +} + +object PowerIterationClusteringModel extends Loader[PowerIterationClusteringModel] { + override def load(sc: SparkContext, path: String): PowerIterationClusteringModel = { + PowerIterationClusteringModel.SaveLoadV1_0.load(sc, path) + } + + private[clustering] + object SaveLoadV1_0 { + + private val thisFormatVersion = "1.0" + + private[clustering] + val thisClassName = "org.apache.spark.mllib.clustering.PowerIterationClusteringModel" + + def save(sc: SparkContext, model: PowerIterationClusteringModel, path: String): Unit = { + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + val metadata = compact(render( + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + + val dataRDD = model.assignments.toDF() + dataRDD.saveAsParquetFile(Loader.dataPath(path)) + } + + def load(sc: SparkContext, path: String): PowerIterationClusteringModel = { + implicit val formats = DefaultFormats + val sqlContext = new SQLContext(sc) + + val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) + assert(className == thisClassName) + assert(formatVersion == thisFormatVersion) + + val k = (metadata \ "k").extract[Int] + val assignments = sqlContext.parquetFile(Loader.dataPath(path)) + Loader.checkSchema[PowerIterationClustering.Assignment](assignments.schema) + + val assignmentsRDD = assignments.map { + case Row(id: Long, cluster: Int) => PowerIterationClustering.Assignment(id, cluster) + } + + new PowerIterationClusteringModel(k, assignmentsRDD) + } + } +} /** * :: Experimental :: @@ -135,7 +193,7 @@ class PowerIterationClustering private[clustering] ( val v = powerIter(w, maxIterations) val assignments = kMeans(v, k).mapPartitions({ iter => iter.map { case (id, cluster) => - new Assignment(id, cluster) + Assignment(id, cluster) } }, preservesPartitioning = true) new PowerIterationClusteringModel(k, assignments) @@ -152,7 +210,7 @@ object PowerIterationClustering extends Logging { * @param cluster assigned cluster id */ @Experimental - class Assignment(val id: Long, val cluster: Int) extends Serializable + case class Assignment(id: Long, cluster: Int) /** * Normalizes the affinity matrix (A) by row sums and returns the normalized affinity matrix (W). diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 9ee7e4a66b535..b2d9053f70145 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -522,7 +522,7 @@ object Word2VecModel extends Loader[Word2VecModel] { new Word2VecModel(word2VecMap) } - def save(sc: SparkContext, path: String, model: Map[String, Array[Float]]) = { + def save(sc: SparkContext, path: String, model: Map[String, Array[Float]]): Unit = { val sqlContext = new SQLContext(sc) import sqlContext.implicits._ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index d1a174063caba..3fa5e068d16d4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -87,6 +87,9 @@ sealed trait Matrix extends Serializable { /** A human readable representation of the matrix */ override def toString: String = toBreeze.toString() + /** A human readable representation of the matrix with maximum lines and width */ + def toString(maxLines: Int, maxLineWidth: Int): String = toBreeze.toString(maxLines, maxLineWidth) + /** Map the values of this matrix using a function. Generates a new matrix. Performs the * function on only the backing array. For example, an operation such as addition or * subtraction will only be performed on the non-zero values in a `SparseMatrix`. */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 961111507f2c2..9a89a6f3a515f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -531,7 +531,6 @@ class RowMatrix( val rand = new XORShiftRandom(indx) val scaled = new Array[Double](p.size) iter.flatMap { row => - val buf = new ListBuffer[((Int, Int), Double)]() row match { case SparseVector(size, indices, values) => val nnz = indices.size @@ -540,8 +539,9 @@ class RowMatrix( scaled(k) = values(k) / q(indices(k)) k += 1 } - k = 0 - while (k < nnz) { + + Iterator.tabulate (nnz) { k => + val buf = new ListBuffer[((Int, Int), Double)]() val i = indices(k) val iVal = scaled(k) if (iVal != 0 && rand.nextDouble() < p(i)) { @@ -555,8 +555,8 @@ class RowMatrix( l += 1 } } - k += 1 - } + buf + }.flatten case DenseVector(values) => val n = values.size var i = 0 @@ -564,8 +564,8 @@ class RowMatrix( scaled(i) = values(i) / q(i) i += 1 } - i = 0 - while (i < n) { + Iterator.tabulate (n) { i => + val buf = new ListBuffer[((Int, Int), Double)]() val iVal = scaled(i) if (iVal != 0 && rand.nextDouble() < p(i)) { var j = i + 1 @@ -577,10 +577,9 @@ class RowMatrix( j += 1 } } - i += 1 - } + buf + }.flatten } - buf } }.reduceByKey(_ + _).map { case ((i, j), sim) => MatrixEntry(i.toLong, j.toLong, sim) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala index ce95c063db970..cea8f3f47307b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala @@ -60,7 +60,7 @@ abstract class StreamingLinearAlgorithm[ A <: GeneralizedLinearAlgorithm[M]] extends Logging { /** The model to be updated and used for prediction. */ - protected var model: Option[M] = None + protected var model: Option[M] /** The algorithm to use for updating. */ protected val algorithm: A @@ -114,7 +114,7 @@ abstract class StreamingLinearAlgorithm[ if (model.isEmpty) { throw new IllegalArgumentException("Model must be initialized before starting prediction.") } - data.map(model.get.predict) + data.map{x => model.get.predict(x)} } /** Java-friendly version of `predictOn`. */ @@ -132,7 +132,7 @@ abstract class StreamingLinearAlgorithm[ if (model.isEmpty) { throw new IllegalArgumentException("Model must be initialized before starting prediction") } - data.mapValues(model.get.predict) + data.mapValues{x => model.get.predict(x)} } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala index e5e6301127a28..a49153bf73c0d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala @@ -59,6 +59,8 @@ class StreamingLinearRegressionWithSGD private[mllib] ( val algorithm = new LinearRegressionWithSGD(stepSize, numIterations, miniBatchFraction) + protected var model: Option[LinearRegressionModel] = None + /** Set the step size for gradient descent. Default: 0.1. */ def setStepSize(stepSize: Double): this.type = { this.algorithm.optimizer.setStepSize(stepSize) diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java new file mode 100644 index 0000000000000..161100134c92d --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature; + +import java.io.Serializable; +import java.util.List; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import com.google.common.collect.Lists; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.feature.VectorIndexerSuite.FeatureData; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; + + +public class JavaVectorIndexerSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaVectorIndexerSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void vectorIndexerAPI() { + // The tests are to check Java compatibility. + List points = Lists.newArrayList( + new FeatureData(Vectors.dense(0.0, -2.0)), + new FeatureData(Vectors.dense(1.0, 3.0)), + new FeatureData(Vectors.dense(1.0, 4.0)) + ); + SQLContext sqlContext = new SQLContext(sc); + DataFrame data = sqlContext.createDataFrame(sc.parallelize(points, 2), FeatureData.class); + VectorIndexer indexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexed") + .setMaxCategories(2); + VectorIndexerModel model = indexer.fit(data); + Assert.assertEquals(model.numFeatures(), 2); + Assert.assertEquals(model.categoryMaps().size(), 1); + DataFrame indexedData = model.transform(data); + } +} diff --git a/mllib/src/test/resources/log4j.properties b/mllib/src/test/resources/log4j.properties index 9697237bfa1a3..75e3b53a093f6 100644 --- a/mllib/src/test/resources/log4j.properties +++ b/mllib/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala index 3fb6e2ec46468..0dcfe5a2002dc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala @@ -43,8 +43,8 @@ class AttributeGroupSuite extends FunSuite { intercept[NoSuchElementException] { group("abc") } - assert(group === AttributeGroup.fromMetadata(group.toMetadata, group.name)) - assert(group === AttributeGroup.fromStructField(group.toStructField())) + assert(group === AttributeGroup.fromMetadata(group.toMetadataImpl, group.name)) + assert(group === AttributeGroup.fromStructField(group.toStructField)) } test("attribute group without attributes") { @@ -53,8 +53,8 @@ class AttributeGroupSuite extends FunSuite { assert(group0.numAttributes === Some(10)) assert(group0.size === 10) assert(group0.attributes.isEmpty) - assert(group0 === AttributeGroup.fromMetadata(group0.toMetadata, group0.name)) - assert(group0 === AttributeGroup.fromStructField(group0.toStructField())) + assert(group0 === AttributeGroup.fromMetadata(group0.toMetadataImpl, group0.name)) + assert(group0 === AttributeGroup.fromStructField(group0.toStructField)) val group1 = new AttributeGroup("item") assert(group1.name === "item") diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index b3d1bfcfbee0f..35d8c2e16c6cd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -46,6 +46,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { assert(lr.getPredictionCol == "prediction") assert(lr.getRawPredictionCol == "rawPrediction") assert(lr.getProbabilityCol == "probability") + assert(lr.getFitIntercept == true) val model = lr.fit(dataset) model.transform(dataset) .select("label", "probability", "prediction", "rawPrediction") @@ -55,6 +56,14 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { assert(model.getPredictionCol == "prediction") assert(model.getRawPredictionCol == "rawPrediction") assert(model.getProbabilityCol == "probability") + assert(model.intercept !== 0.0) + } + + test("logistic regression doesn't fit intercept when fitIntercept is off") { + val lr = new LogisticRegression + lr.setFitIntercept(false) + val model = lr.fit(dataset) + assert(model.intercept === 0.0) } test("logistic regression with setters") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala index a18c335952b96..9d09f24709e23 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala @@ -24,7 +24,6 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row, SQLContext} -private case class DataSet(features: Vector) class NormalizerSuite extends FunSuite with MLlibTestSparkContext { @@ -63,7 +62,7 @@ class NormalizerSuite extends FunSuite with MLlibTestSparkContext { ) val sqlContext = new SQLContext(sc) - dataFrame = sqlContext.createDataFrame(sc.parallelize(data, 2).map(DataSet)) + dataFrame = sqlContext.createDataFrame(sc.parallelize(data, 2).map(NormalizerSuite.FeatureData)) normalizer = new Normalizer() .setInputCol("features") .setOutputCol("normalized_features") @@ -107,3 +106,7 @@ class NormalizerSuite extends FunSuite with MLlibTestSparkContext { assertValues(result, l1Normalized) } } + +private object NormalizerSuite { + case class FeatureData(features: Vector) +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala new file mode 100644 index 0000000000000..00b5d094d82f1 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.scalatest.FunSuite + +import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.SQLContext + +class StringIndexerSuite extends FunSuite with MLlibTestSparkContext { + private var sqlContext: SQLContext = _ + + override def beforeAll(): Unit = { + super.beforeAll() + sqlContext = new SQLContext(sc) + } + + test("StringIndexer") { + val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2) + val df = sqlContext.createDataFrame(data).toDF("id", "label") + val indexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex") + .fit(df) + val transformed = indexer.transform(df) + val attr = Attribute.fromStructField(transformed.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attr.values.get === Array("a", "c", "b")) + val output = transformed.select("id", "labelIndex").map { r => + (r.getInt(0), r.getDouble(1)) + }.collect().toSet + // a -> 0, b -> 2, c -> 1 + val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)) + assert(output === expected) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala index bf862b912d326..d186ead8f542f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala @@ -25,10 +25,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row, SQLContext} @BeanInfo -case class TokenizerTestData(rawText: String, wantedTokens: Seq[String]) { - /** Constructor used in [[org.apache.spark.ml.feature.JavaTokenizerSuite]] */ - def this(rawText: String, wantedTokens: Array[String]) = this(rawText, wantedTokens.toSeq) -} +case class TokenizerTestData(rawText: String, wantedTokens: Array[String]) class RegexTokenizerSuite extends FunSuite with MLlibTestSparkContext { import org.apache.spark.ml.feature.RegexTokenizerSuite._ @@ -46,14 +43,14 @@ class RegexTokenizerSuite extends FunSuite with MLlibTestSparkContext { .setOutputCol("tokens") val dataset0 = sqlContext.createDataFrame(Seq( - TokenizerTestData("Test for tokenization.", Seq("Test", "for", "tokenization", ".")), - TokenizerTestData("Te,st. punct", Seq("Te", ",", "st", ".", "punct")) + TokenizerTestData("Test for tokenization.", Array("Test", "for", "tokenization", ".")), + TokenizerTestData("Te,st. punct", Array("Te", ",", "st", ".", "punct")) )) testRegexTokenizer(tokenizer, dataset0) val dataset1 = sqlContext.createDataFrame(Seq( - TokenizerTestData("Test for tokenization.", Seq("Test", "for", "tokenization")), - TokenizerTestData("Te,st. punct", Seq("punct")) + TokenizerTestData("Test for tokenization.", Array("Test", "for", "tokenization")), + TokenizerTestData("Te,st. punct", Array("punct")) )) tokenizer.setMinTokenLength(3) @@ -64,8 +61,8 @@ class RegexTokenizerSuite extends FunSuite with MLlibTestSparkContext { .setGaps(true) .setMinTokenLength(0) val dataset2 = sqlContext.createDataFrame(Seq( - TokenizerTestData("Test for tokenization.", Seq("Test", "for", "tokenization.")), - TokenizerTestData("Te,st. punct", Seq("Te,st.", "", "punct")) + TokenizerTestData("Test for tokenization.", Array("Test", "for", "tokenization.")), + TokenizerTestData("Te,st. punct", Array("Te,st.", "", "punct")) )) testRegexTokenizer(tokenizer, dataset2) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala new file mode 100644 index 0000000000000..57d0278e03639 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.scalatest.FunSuite + +import org.apache.spark.SparkException +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{Row, SQLContext} + +class VectorAssemblerSuite extends FunSuite with MLlibTestSparkContext { + + @transient var sqlContext: SQLContext = _ + + override def beforeAll(): Unit = { + super.beforeAll() + sqlContext = new SQLContext(sc) + } + + test("assemble") { + import org.apache.spark.ml.feature.VectorAssembler.assemble + assert(assemble(0.0) === Vectors.sparse(1, Array.empty, Array.empty)) + assert(assemble(0.0, 1.0) === Vectors.sparse(2, Array(1), Array(1.0))) + val dv = Vectors.dense(2.0, 0.0) + assert(assemble(0.0, dv, 1.0) === Vectors.sparse(4, Array(1, 3), Array(2.0, 1.0))) + val sv = Vectors.sparse(2, Array(0, 1), Array(3.0, 4.0)) + assert(assemble(0.0, dv, 1.0, sv) === + Vectors.sparse(6, Array(1, 3, 4, 5), Array(2.0, 1.0, 3.0, 4.0))) + for (v <- Seq(1, "a", null)) { + intercept[SparkException](assemble(v)) + intercept[SparkException](assemble(1.0, v)) + } + } + + test("VectorAssembler") { + val df = sqlContext.createDataFrame(Seq( + (0, 0.0, Vectors.dense(1.0, 2.0), "a", Vectors.sparse(2, Array(1), Array(3.0)), 10L) + )).toDF("id", "x", "y", "name", "z", "n") + val assembler = new VectorAssembler() + .setInputCols(Array("x", "y", "z", "n")) + .setOutputCol("features") + assembler.transform(df).select("features").collect().foreach { + case Row(v: Vector) => + assert(v === Vectors.sparse(6, Array(1, 2, 4, 5), Array(1.0, 2.0, 3.0, 10.0))) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala new file mode 100644 index 0000000000000..81ef831c42e55 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala @@ -0,0 +1,256 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import scala.beans.{BeanInfo, BeanProperty} + +import org.scalatest.FunSuite + +import org.apache.spark.SparkException +import org.apache.spark.ml.attribute._ +import org.apache.spark.ml.util.TestingUtils +import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, SQLContext} + + +class VectorIndexerSuite extends FunSuite with MLlibTestSparkContext { + + import VectorIndexerSuite.FeatureData + + @transient var sqlContext: SQLContext = _ + + // identical, of length 3 + @transient var densePoints1: DataFrame = _ + @transient var sparsePoints1: DataFrame = _ + @transient var point1maxes: Array[Double] = _ + + // identical, of length 2 + @transient var densePoints2: DataFrame = _ + @transient var sparsePoints2: DataFrame = _ + + // different lengths + @transient var badPoints: DataFrame = _ + + override def beforeAll(): Unit = { + super.beforeAll() + + val densePoints1Seq = Seq( + Vectors.dense(1.0, 2.0, 0.0), + Vectors.dense(0.0, 1.0, 2.0), + Vectors.dense(0.0, 0.0, -1.0), + Vectors.dense(1.0, 3.0, 2.0)) + val sparsePoints1Seq = Seq( + Vectors.sparse(3, Array(0, 1), Array(1.0, 2.0)), + Vectors.sparse(3, Array(1, 2), Array(1.0, 2.0)), + Vectors.sparse(3, Array(2), Array(-1.0)), + Vectors.sparse(3, Array(0, 1, 2), Array(1.0, 3.0, 2.0))) + point1maxes = Array(1.0, 3.0, 2.0) + + val densePoints2Seq = Seq( + Vectors.dense(1.0, 1.0, 0.0, 1.0), + Vectors.dense(0.0, 1.0, 1.0, 1.0), + Vectors.dense(-1.0, 1.0, 2.0, 0.0)) + val sparsePoints2Seq = Seq( + Vectors.sparse(4, Array(0, 1, 3), Array(1.0, 1.0, 1.0)), + Vectors.sparse(4, Array(1, 2, 3), Array(1.0, 1.0, 1.0)), + Vectors.sparse(4, Array(0, 1, 2), Array(-1.0, 1.0, 2.0))) + + val badPointsSeq = Seq( + Vectors.sparse(2, Array(0, 1), Array(1.0, 1.0)), + Vectors.sparse(3, Array(2), Array(-1.0))) + + // Sanity checks for assumptions made in tests + assert(densePoints1Seq.head.size == sparsePoints1Seq.head.size) + assert(densePoints2Seq.head.size == sparsePoints2Seq.head.size) + assert(densePoints1Seq.head.size != densePoints2Seq.head.size) + def checkPair(dvSeq: Seq[Vector], svSeq: Seq[Vector]): Unit = { + assert(dvSeq.zip(svSeq).forall { case (dv, sv) => dv.toArray === sv.toArray }, + "typo in unit test") + } + checkPair(densePoints1Seq, sparsePoints1Seq) + checkPair(densePoints2Seq, sparsePoints2Seq) + + sqlContext = new SQLContext(sc) + densePoints1 = sqlContext.createDataFrame(sc.parallelize(densePoints1Seq, 2).map(FeatureData)) + sparsePoints1 = sqlContext.createDataFrame(sc.parallelize(sparsePoints1Seq, 2).map(FeatureData)) + densePoints2 = sqlContext.createDataFrame(sc.parallelize(densePoints2Seq, 2).map(FeatureData)) + sparsePoints2 = sqlContext.createDataFrame(sc.parallelize(sparsePoints2Seq, 2).map(FeatureData)) + badPoints = sqlContext.createDataFrame(sc.parallelize(badPointsSeq, 2).map(FeatureData)) + } + + private def getIndexer: VectorIndexer = + new VectorIndexer().setInputCol("features").setOutputCol("indexed") + + test("Cannot fit an empty DataFrame") { + val rdd = sqlContext.createDataFrame(sc.parallelize(Array.empty[Vector], 2).map(FeatureData)) + val vectorIndexer = getIndexer + intercept[IllegalArgumentException] { + vectorIndexer.fit(rdd) + } + } + + test("Throws error when given RDDs with different size vectors") { + val vectorIndexer = getIndexer + val model = vectorIndexer.fit(densePoints1) // vectors of length 3 + model.transform(densePoints1) // should work + model.transform(sparsePoints1) // should work + intercept[IllegalArgumentException] { + model.transform(densePoints2) + println("Did not throw error when fit, transform were called on vectors of different lengths") + } + intercept[SparkException] { + vectorIndexer.fit(badPoints) + println("Did not throw error when fitting vectors of different lengths in same RDD.") + } + } + + test("Same result with dense and sparse vectors") { + def testDenseSparse(densePoints: DataFrame, sparsePoints: DataFrame): Unit = { + val denseVectorIndexer = getIndexer.setMaxCategories(2) + val sparseVectorIndexer = getIndexer.setMaxCategories(2) + val denseModel = denseVectorIndexer.fit(densePoints) + val sparseModel = sparseVectorIndexer.fit(sparsePoints) + val denseMap = denseModel.categoryMaps + val sparseMap = sparseModel.categoryMaps + assert(denseMap.keys.toSet == sparseMap.keys.toSet, + "Categorical features chosen from dense vs. sparse vectors did not match.") + assert(denseMap == sparseMap, + "Categorical feature value indexes chosen from dense vs. sparse vectors did not match.") + } + testDenseSparse(densePoints1, sparsePoints1) + testDenseSparse(densePoints2, sparsePoints2) + } + + test("Builds valid categorical feature value index, transform correctly, check metadata") { + def checkCategoryMaps( + data: DataFrame, + maxCategories: Int, + categoricalFeatures: Set[Int]): Unit = { + val collectedData = data.collect().map(_.getAs[Vector](0)) + val errMsg = s"checkCategoryMaps failed for input with maxCategories=$maxCategories," + + s" categoricalFeatures=${categoricalFeatures.mkString(", ")}" + try { + val vectorIndexer = getIndexer.setMaxCategories(maxCategories) + val model = vectorIndexer.fit(data) + val categoryMaps = model.categoryMaps + // Chose correct categorical features + assert(categoryMaps.keys.toSet === categoricalFeatures) + val transformed = model.transform(data).select("indexed") + val indexedRDD: RDD[Vector] = transformed.map(_.getAs[Vector](0)) + val featureAttrs = AttributeGroup.fromStructField(transformed.schema("indexed")) + assert(featureAttrs.name === "indexed") + assert(featureAttrs.attributes.get.length === model.numFeatures) + categoricalFeatures.foreach { feature: Int => + val origValueSet = collectedData.map(_(feature)).toSet + val targetValueIndexSet = Range(0, origValueSet.size).toSet + val catMap = categoryMaps(feature) + assert(catMap.keys.toSet === origValueSet) // Correct categories + assert(catMap.values.toSet === targetValueIndexSet) // Correct category indices + if (origValueSet.contains(0.0)) { + assert(catMap(0.0) === 0) // value 0 gets index 0 + } + // Check transformed data + assert(indexedRDD.map(_(feature)).collect().toSet === targetValueIndexSet) + // Check metadata + val featureAttr = featureAttrs(feature) + assert(featureAttr.index.get === feature) + featureAttr match { + case attr: BinaryAttribute => + assert(attr.values.get === origValueSet.toArray.sorted.map(_.toString)) + case attr: NominalAttribute => + assert(attr.values.get === origValueSet.toArray.sorted.map(_.toString)) + assert(attr.isOrdinal.get === false) + case _ => + throw new RuntimeException(errMsg + s". Categorical feature $feature failed" + + s" metadata check. Found feature attribute: $featureAttr.") + } + } + // Check numerical feature metadata. + Range(0, model.numFeatures).filter(feature => !categoricalFeatures.contains(feature)) + .foreach { feature: Int => + val featureAttr = featureAttrs(feature) + featureAttr match { + case attr: NumericAttribute => + assert(featureAttr.index.get === feature) + case _ => + throw new RuntimeException(errMsg + s". Numerical feature $feature failed" + + s" metadata check. Found feature attribute: $featureAttr.") + } + } + } catch { + case e: org.scalatest.exceptions.TestFailedException => + println(errMsg) + throw e + } + } + checkCategoryMaps(densePoints1, maxCategories = 2, categoricalFeatures = Set(0)) + checkCategoryMaps(densePoints1, maxCategories = 3, categoricalFeatures = Set(0, 2)) + checkCategoryMaps(densePoints2, maxCategories = 2, categoricalFeatures = Set(1, 3)) + } + + test("Maintain sparsity for sparse vectors") { + def checkSparsity(data: DataFrame, maxCategories: Int): Unit = { + val points = data.collect().map(_.getAs[Vector](0)) + val vectorIndexer = getIndexer.setMaxCategories(maxCategories) + val model = vectorIndexer.fit(data) + val indexedPoints = model.transform(data).select("indexed").map(_.getAs[Vector](0)).collect() + points.zip(indexedPoints).foreach { + case (orig: SparseVector, indexed: SparseVector) => + assert(orig.indices.length == indexed.indices.length) + case _ => throw new UnknownError("Unit test has a bug in it.") // should never happen + } + } + checkSparsity(sparsePoints1, maxCategories = 2) + checkSparsity(sparsePoints2, maxCategories = 2) + } + + test("Preserve metadata") { + // For continuous features, preserve name and stats. + val featureAttributes: Array[Attribute] = point1maxes.zipWithIndex.map { case (maxVal, i) => + NumericAttribute.defaultAttr.withName(i.toString).withMax(maxVal) + } + val attrGroup = new AttributeGroup("features", featureAttributes) + val densePoints1WithMeta = + densePoints1.select(densePoints1("features").as("features", attrGroup.toMetadata)) + val vectorIndexer = getIndexer.setMaxCategories(2) + val model = vectorIndexer.fit(densePoints1WithMeta) + // Check that ML metadata are preserved. + val indexedPoints = model.transform(densePoints1WithMeta) + val transAttributes: Array[Attribute] = + AttributeGroup.fromStructField(indexedPoints.schema("indexed")).attributes.get + featureAttributes.zip(transAttributes).foreach { case (orig, trans) => + assert(orig.name === trans.name) + (orig, trans) match { + case (orig: NumericAttribute, trans: NumericAttribute) => + assert(orig.max.nonEmpty && orig.max === trans.max) + case _ => + // do nothing + // TODO: Once input features marked as categorical are handled correctly, check that here. + } + } + // Check that non-ML metadata are preserved. + TestingUtils.testPreserveMetadata(densePoints1WithMeta, model, "features", "indexed") + } +} + +private[feature] object VectorIndexerSuite { + @BeanInfo + case class FeatureData(@BeanProperty features: Vector) +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala index 1a65883d78a71..ce52f2f230085 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala @@ -28,7 +28,7 @@ class TestParams extends Params { def setInputCol(value: String): this.type = { set(inputCol, value); this } def getInputCol: String = get(inputCol) - override def validate(paramMap: ParamMap) = { + override def validate(paramMap: ParamMap): Unit = { val m = this.paramMap ++ paramMap require(m(maxIter) >= 0) require(m.contains(inputCol)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index 0bb06e9e8ac9c..fc7349330cf86 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -22,6 +22,7 @@ import java.util.Random import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import scala.language.existentials import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.scalatest.FunSuite @@ -68,39 +69,42 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { } } - test("normal equation construction with explict feedback") { + test("normal equation construction") { val k = 2 val ne0 = new NormalEquation(k) - .add(Array(1.0f, 2.0f), 3.0f) - .add(Array(4.0f, 5.0f), 6.0f) + .add(Array(1.0f, 2.0f), 3.0) + .add(Array(4.0f, 5.0f), 6.0, 2.0) // weighted assert(ne0.k === k) assert(ne0.triK === k * (k + 1) / 2) - assert(ne0.n === 2) // NumPy code that computes the expected values: // A = np.matrix("1 2; 4 5") // b = np.matrix("3; 6") - // ata = A.transpose() * A - // atb = A.transpose() * b - assert(Vectors.dense(ne0.ata) ~== Vectors.dense(17.0, 22.0, 29.0) relTol 1e-8) - assert(Vectors.dense(ne0.atb) ~== Vectors.dense(27.0, 36.0) relTol 1e-8) + // C = np.matrix(np.diag([1, 2])) + // ata = A.transpose() * C * A + // atb = A.transpose() * C * b + assert(Vectors.dense(ne0.ata) ~== Vectors.dense(33.0, 42.0, 54.0) relTol 1e-8) + assert(Vectors.dense(ne0.atb) ~== Vectors.dense(51.0, 66.0) relTol 1e-8) val ne1 = new NormalEquation(2) - .add(Array(7.0f, 8.0f), 9.0f) + .add(Array(7.0f, 8.0f), 9.0) ne0.merge(ne1) - assert(ne0.n === 3) // NumPy code that computes the expected values: // A = np.matrix("1 2; 4 5; 7 8") // b = np.matrix("3; 6; 9") - // ata = A.transpose() * A - // atb = A.transpose() * b - assert(Vectors.dense(ne0.ata) ~== Vectors.dense(66.0, 78.0, 93.0) relTol 1e-8) - assert(Vectors.dense(ne0.atb) ~== Vectors.dense(90.0, 108.0) relTol 1e-8) + // C = np.matrix(np.diag([1, 2, 1])) + // ata = A.transpose() * C * A + // atb = A.transpose() * C * b + assert(Vectors.dense(ne0.ata) ~== Vectors.dense(82.0, 98.0, 118.0) relTol 1e-8) + assert(Vectors.dense(ne0.atb) ~== Vectors.dense(114.0, 138.0) relTol 1e-8) intercept[IllegalArgumentException] { - ne0.add(Array(1.0f), 2.0f) + ne0.add(Array(1.0f), 2.0) } intercept[IllegalArgumentException] { - ne0.add(Array(1.0f, 2.0f, 3.0f), 4.0f) + ne0.add(Array(1.0f, 2.0f, 3.0f), 4.0) + } + intercept[IllegalArgumentException] { + ne0.add(Array(1.0f, 2.0f), 0.0, -1.0) } intercept[IllegalArgumentException] { val ne2 = new NormalEquation(3) @@ -108,41 +112,16 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { } ne0.reset() - assert(ne0.n === 0) assert(ne0.ata.forall(_ == 0.0)) assert(ne0.atb.forall(_ == 0.0)) } - test("normal equation construction with implicit feedback") { - val k = 2 - val alpha = 0.5 - val ne0 = new NormalEquation(k) - .addImplicit(Array(-5.0f, -4.0f), -3.0f, alpha) - .addImplicit(Array(-2.0f, -1.0f), 0.0f, alpha) - .addImplicit(Array(1.0f, 2.0f), 3.0f, alpha) - assert(ne0.k === k) - assert(ne0.triK === k * (k + 1) / 2) - assert(ne0.n === 0) // addImplicit doesn't increase the count. - // NumPy code that computes the expected values: - // alpha = 0.5 - // A = np.matrix("-5 -4; -2 -1; 1 2") - // b = np.matrix("-3; 0; 3") - // b1 = b > 0 - // c = 1.0 + alpha * np.abs(b) - // C = np.diag(c.A1) - // I = np.eye(3) - // ata = A.transpose() * (C - I) * A - // atb = A.transpose() * C * b1 - assert(Vectors.dense(ne0.ata) ~== Vectors.dense(39.0, 33.0, 30.0) relTol 1e-8) - assert(Vectors.dense(ne0.atb) ~== Vectors.dense(2.5, 5.0) relTol 1e-8) - } - test("CholeskySolver") { val k = 2 val ne0 = new NormalEquation(k) - .add(Array(1.0f, 2.0f), 4.0f) - .add(Array(1.0f, 3.0f), 9.0f) - .add(Array(1.0f, 4.0f), 16.0f) + .add(Array(1.0f, 2.0f), 4.0) + .add(Array(1.0f, 3.0f), 9.0) + .add(Array(1.0f, 4.0f), 16.0) val ne1 = new NormalEquation(k) .merge(ne0) @@ -154,13 +133,12 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { // x0 = np.linalg.lstsq(A, b)[0] assert(Vectors.dense(x0) ~== Vectors.dense(-8.333333, 6.0) relTol 1e-6) - assert(ne0.n === 0) assert(ne0.ata.forall(_ == 0.0)) assert(ne0.atb.forall(_ == 0.0)) - val x1 = chol.solve(ne1, 0.5).map(_.toDouble) + val x1 = chol.solve(ne1, 1.5).map(_.toDouble) // NumPy code that computes the expected solution, where lambda is scaled by n: - // x0 = np.linalg.solve(A.transpose() * A + 0.5 * 3 * np.eye(2), A.transpose() * b) + // x0 = np.linalg.solve(A.transpose() * A + 1.5 * np.eye(2), A.transpose() * b) assert(Vectors.dense(x1) ~== Vectors.dense(-0.1155556, 3.28) relTol 1e-6) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala new file mode 100644 index 0000000000000..c44cb61b34171 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import org.apache.spark.ml.Transformer +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.types.MetadataBuilder +import org.scalatest.FunSuite + +private[ml] object TestingUtils extends FunSuite { + + /** + * Test whether unrelated metadata are preserved for this transformer. + * This attaches extra metadata to a column, transforms the column, and check to ensure the + * extra metadata have not changed. + * @param data Input dataset + * @param transformer Transformer to test + * @param inputCol Unique input column for Transformer. This must be the ONLY input column. + * @param outputCol Output column to test for metadata presence. + */ + def testPreserveMetadata( + data: DataFrame, + transformer: Transformer, + inputCol: String, + outputCol: String): Unit = { + // Create some fake metadata + val origMetadata = data.schema(inputCol).metadata + val metaKey = "__testPreserveMetadata__fake_key" + val metaValue = 12345 + assert(!origMetadata.contains(metaKey), + s"Unit test with testPreserveMetadata will fail since metadata key was present: $metaKey") + val newMetadata = + new MetadataBuilder().withMetadata(origMetadata).putLong(metaKey, metaValue).build() + // Add metadata to the inputCol + val withMetadata = data.select(data(inputCol).as(inputCol, newMetadata)) + // Transform, and ensure extra metadata was not affected + val transformed = transformer.transform(withMetadata) + val transMetadata = transformed.schema(outputCol).metadata + assert(transMetadata.contains(metaKey), + "Unit test with testPreserveMetadata failed; extra metadata key was not present.") + assert(transMetadata.getLong(metaKey) === metaValue, + "Unit test with testPreserveMetadata failed; extra metadata value was wrong." + + s" Expected $metaValue but found ${transMetadata.getLong(metaKey)}") + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index f9fe3e006ccb8..ea89b17b7c08f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -102,7 +102,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { def validateModelFit( piData: Array[Double], thetaData: Array[Array[Double]], - model: NaiveBayesModel) = { + model: NaiveBayesModel): Unit = { def closeFit(d1: Double, d2: Double, precision: Double): Boolean = { (d1 - d2).abs <= precision } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala index 8b3e6e5ce9249..5683b55e8500a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.streaming.TestSuiteBase class StreamingLogisticRegressionSuite extends FunSuite with TestSuiteBase { // use longer wait time to ensure job completion - override def maxWaitTimeMillis = 30000 + override def maxWaitTimeMillis: Int = 30000 // Test if we can accurately learn B for Y = logistic(BX) on streaming data test("parameter accuracy") { @@ -132,4 +132,31 @@ class StreamingLogisticRegressionSuite extends FunSuite with TestSuiteBase { assert(errors.forall(x => x <= 0.4)) } + // Test training combined with prediction + test("training and prediction") { + // create model initialized with zero weights + val model = new StreamingLogisticRegressionWithSGD() + .setInitialWeights(Vectors.dense(-0.1)) + .setStepSize(0.01) + .setNumIterations(10) + + // generate sequence of simulated data for testing + val numBatches = 10 + val nPoints = 100 + val testInput = (0 until numBatches).map { i => + LogisticRegressionSuite.generateLogisticInput(0.0, 5.0, nPoints, 42 * (i + 1)) + } + + // train and predict + val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { + model.trainOn(inputDStream) + model.predictOnValues(inputDStream.map(x => (x.label, x.features))) + }) + + val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches) + + // assert that prediction error improves, ensuring that the updated model is being used + val error = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints).toList + assert(error.head > 0.8 & error.last < 0.2) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala index 7bf250eb5a383..0f2b26d462ad2 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala @@ -199,9 +199,13 @@ class KMeansSuite extends FunSuite with MLlibTestSparkContext { test("k-means|| initialization") { case class VectorWithCompare(x: Vector) extends Ordered[VectorWithCompare] { - @Override def compare(that: VectorWithCompare): Int = { - if(this.x.toArray.foldLeft[Double](0.0)((acc, x) => acc + x * x) > - that.x.toArray.foldLeft[Double](0.0)((acc, x) => acc + x * x)) -1 else 1 + override def compare(that: VectorWithCompare): Int = { + if (this.x.toArray.foldLeft[Double](0.0)((acc, x) => acc + x * x) > + that.x.toArray.foldLeft[Double](0.0)((acc, x) => acc + x * x)) { + -1 + } else { + 1 + } } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index 302d751eb8a94..15de10fd13a19 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.clustering import org.scalatest.FunSuite -import org.apache.spark.mllib.linalg.{DenseMatrix, Matrix, Vectors} +import org.apache.spark.mllib.linalg.{Vector, DenseMatrix, Matrix, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ @@ -141,7 +141,7 @@ private[clustering] object LDASuite { (terms.toArray, termWeights.toArray) } - def tinyCorpus = Array( + def tinyCorpus: Array[(Long, Vector)] = Array( Vectors.dense(1, 3, 0, 2, 8), Vectors.dense(0, 2, 1, 0, 4), Vectors.dense(2, 3, 12, 3, 1), diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala index 6315c03a700f1..6d6fe6fe46bab 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala @@ -18,12 +18,15 @@ package org.apache.spark.mllib.clustering import scala.collection.mutable +import scala.util.Random import org.scalatest.FunSuite +import org.apache.spark.SparkContext import org.apache.spark.graphx.{Edge, Graph} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.Utils class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext { @@ -110,4 +113,35 @@ class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext assert(x ~== u1(i.toInt) absTol 1e-14) } } + + test("model save/load") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + val model = PowerIterationClusteringSuite.createModel(sc, 3, 10) + try { + model.save(sc, path) + val sameModel = PowerIterationClusteringModel.load(sc, path) + PowerIterationClusteringSuite.checkEqual(model, sameModel) + } finally { + Utils.deleteRecursively(tempDir) + } + } +} + +object PowerIterationClusteringSuite extends FunSuite { + def createModel(sc: SparkContext, k: Int, nPoints: Int): PowerIterationClusteringModel = { + val assignments = sc.parallelize( + (0 until nPoints).map(p => PowerIterationClustering.Assignment(p, Random.nextInt(k)))) + new PowerIterationClusteringModel(k, assignments) + } + + def checkEqual(a: PowerIterationClusteringModel, b: PowerIterationClusteringModel): Unit = { + assert(a.k === b.k) + + val aAssignments = a.assignments.map(x => (x.id, x.cluster)) + val bAssignments = b.assignments.map(x => (x.id, x.cluster)) + val unequalElements = aAssignments.join(bAssignments).filter { + case (id, (c1, c2)) => c1 != c2 }.count() + assert(unequalElements === 0L) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala index 850c9fce507cd..f90025d535e45 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.util.random.XORShiftRandom class StreamingKMeansSuite extends FunSuite with TestSuiteBase { - override def maxWaitTimeMillis = 30000 + override def maxWaitTimeMillis: Int = 30000 test("accuracy for single center and equivalence to grand average") { // set parameters @@ -59,7 +59,7 @@ class StreamingKMeansSuite extends FunSuite with TestSuiteBase { // estimated center from streaming should exactly match the arithmetic mean of all data points // because the decay factor is set to 1.0 val grandMean = - input.flatten.map(x => x.toBreeze).reduce(_+_) / (numBatches * numPoints).toDouble + input.flatten.map(x => x.toBreeze).reduce(_ + _) / (numBatches * numPoints).toDouble assert(model.latestModel().clusterCenters(0) ~== Vectors.dense(grandMean.toArray) absTol 1E-5) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala index 0d2cec58e2c03..86119ec38101e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala @@ -439,4 +439,20 @@ class MatricesSuite extends FunSuite { assert(mUDT.typeName == "matrix") assert(mUDT.simpleString == "matrix") } + + test("toString") { + val empty = Matrices.ones(0, 0) + empty.toString(0, 0) + + val mat = Matrices.rand(5, 10, new Random()) + mat.toString(-1, -5) + mat.toString(0, 0) + mat.toString(Int.MinValue, Int.MinValue) + mat.toString(Int.MaxValue, Int.MaxValue) + var lines = mat.toString(6, 50).lines.toArray + assert(lines.size == 5 && lines.forall(_.size <= 50)) + + lines = mat.toString(5, 100).lines.toArray + assert(lines.size == 5 && lines.forall(_.size <= 100)) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala index 6395188a0842a..63f2ea916d457 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala @@ -181,7 +181,8 @@ class RandomRDDsSuite extends FunSuite with MLlibTestSparkContext with Serializa val poisson = RandomRDDs.poissonVectorRDD(sc, poissonMean, rows, cols, parts, seed) testGeneratedVectorRDD(poisson, rows, cols, parts, poissonMean, math.sqrt(poissonMean), 0.1) - val exponential = RandomRDDs.exponentialVectorRDD(sc, exponentialMean, rows, cols, parts, seed) + val exponential = + RandomRDDs.exponentialVectorRDD(sc, exponentialMean, rows, cols, parts, seed) testGeneratedVectorRDD(exponential, rows, cols, parts, exponentialMean, exponentialMean, 0.1) val gamma = RandomRDDs.gammaVectorRDD(sc, gammaShape, gammaScale, rows, cols, parts, seed) @@ -197,7 +198,7 @@ private[random] class MockDistro extends RandomDataGenerator[Double] { // This allows us to check that each partition has a different seed override def nextValue(): Double = seed.toDouble - override def setSeed(seed: Long) = this.seed = seed + override def setSeed(seed: Long): Unit = this.seed = seed override def copy(): MockDistro = new MockDistro } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala index 8775c0ca9df84..b3798940ddc38 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala @@ -203,6 +203,7 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext { * @param numProductBlocks number of product blocks to partition products into * @param negativeFactors whether the generated user/product factors can have negative entries */ + // scalastyle:off def testALS( users: Int, products: Int, @@ -216,6 +217,8 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext { numUserBlocks: Int = -1, numProductBlocks: Int = -1, negativeFactors: Boolean = true) { + // scalastyle:on + val (sampledRatings, trueRatings, truePrefs) = ALSSuite.generateRatings(users, products, features, samplingRate, implicitPrefs, negativeWeights, negativeFactors) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala index 43d61151e2471..d6c93cc0e49cd 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala @@ -35,7 +35,7 @@ private object RidgeRegressionSuite { class RidgeRegressionSuite extends FunSuite with MLlibTestSparkContext { - def predictionError(predictions: Seq[Double], input: Seq[LabeledPoint]) = { + def predictionError(predictions: Seq[Double], input: Seq[LabeledPoint]): Double = { predictions.zip(input).map { case (prediction, expected) => (prediction - expected.label) * (prediction - expected.label) }.reduceLeft(_ + _) / predictions.size diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala index 70b43ddb7daf5..26604dbe6c1ef 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.streaming.TestSuiteBase class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase { // use longer wait time to ensure job completion - override def maxWaitTimeMillis = 20000 + override def maxWaitTimeMillis: Int = 20000 // Assert that two values are equal within tolerance epsilon def assertEqual(v1: Double, v2: Double, epsilon: Double) { @@ -139,4 +139,32 @@ class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase { val errors = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints) assert(errors.forall(x => x <= 0.1)) } + + // Test training combined with prediction + test("training and prediction") { + // create model initialized with zero weights + val model = new StreamingLinearRegressionWithSGD() + .setInitialWeights(Vectors.dense(0.0, 0.0)) + .setStepSize(0.2) + .setNumIterations(25) + + // generate sequence of simulated data for testing + val numBatches = 10 + val nPoints = 100 + val testInput = (0 until numBatches).map { i => + LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), nPoints, 42 * (i + 1)) + } + + // train and predict + val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { + model.trainOn(inputDStream) + model.predictOnValues(inputDStream.map(x => (x.label, x.features))) + }) + + val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches) + + // assert that prediction error improves, ensuring that the updated model is being used + val error = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints).toList + assert((error.head - error.last) > 2) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala index e957fa5d25f4c..352193a67860c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala @@ -95,16 +95,16 @@ object TestingUtils { /** * Comparison using absolute tolerance. */ - def absTol(eps: Double): CompareDoubleRightSide = CompareDoubleRightSide(AbsoluteErrorComparison, - x, eps, ABS_TOL_MSG) + def absTol(eps: Double): CompareDoubleRightSide = + CompareDoubleRightSide(AbsoluteErrorComparison, x, eps, ABS_TOL_MSG) /** * Comparison using relative tolerance. */ - def relTol(eps: Double): CompareDoubleRightSide = CompareDoubleRightSide(RelativeErrorComparison, - x, eps, REL_TOL_MSG) + def relTol(eps: Double): CompareDoubleRightSide = + CompareDoubleRightSide(RelativeErrorComparison, x, eps, REL_TOL_MSG) - override def toString = x.toString + override def toString: String = x.toString } case class CompareVectorRightSide( @@ -166,7 +166,7 @@ object TestingUtils { x.toArray.zip(y.toArray).forall(x => x._1 ~= x._2 relTol eps) }, x, eps, REL_TOL_MSG) - override def toString = x.toString + override def toString: String = x.toString } case class CompareMatrixRightSide( @@ -229,7 +229,7 @@ object TestingUtils { x.toArray.zip(y.toArray).forall(x => x._1 ~= x._2 relTol eps) }, x, eps, REL_TOL_MSG) - override def toString = x.toString + override def toString: String = x.toString } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala index b0ecb33c28483..59e6c778806f4 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala @@ -88,16 +88,20 @@ class TestingUtilsSuite extends FunSuite { assert(!(17.8 ~= 17.59 absTol 0.2)) // Comparisons of numbers very close to zero, and both side of zeros - assert(Double.MinPositiveValue ~== 4 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue) - assert(Double.MinPositiveValue !~== 6 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue) - - assert(-Double.MinPositiveValue ~== 3 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue) - assert(Double.MinPositiveValue !~== -4 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue) + assert( + Double.MinPositiveValue ~== 4 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue) + assert( + Double.MinPositiveValue !~== 6 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue) + + assert( + -Double.MinPositiveValue ~== 3 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue) + assert( + Double.MinPositiveValue !~== -4 * Double.MinPositiveValue absTol 5 * Double.MinPositiveValue) } test("Comparing vectors using relative error.") { - //Comparisons of two dense vectors + // Comparisons of two dense vectors assert(Vectors.dense(Array(3.1, 3.5)) ~== Vectors.dense(Array(3.130, 3.534)) relTol 0.01) assert(Vectors.dense(Array(3.1, 3.5)) !~== Vectors.dense(Array(3.135, 3.534)) relTol 0.01) assert(Vectors.dense(Array(3.1, 3.5)) ~= Vectors.dense(Array(3.130, 3.534)) relTol 0.01) @@ -130,7 +134,7 @@ class TestingUtilsSuite extends FunSuite { test("Comparing vectors using absolute error.") { - //Comparisons of two dense vectors + // Comparisons of two dense vectors assert(Vectors.dense(Array(3.1, 3.5, 0.0)) ~== Vectors.dense(Array(3.1 + 1E-8, 3.5 + 2E-7, 1E-8)) absTol 1E-6) diff --git a/network/common/pom.xml b/network/common/pom.xml index 7b51845206f4a..22c738bde6d42 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -80,6 +80,11 @@ mockito-all test + + org.slf4j + slf4j-log4j12 + test + diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java index 91d1e8a538a77..0f999f5dfe8d8 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java @@ -72,9 +72,11 @@ public void encode(ChannelHandlerContext ctx, Message in, List out) { in.encode(header); assert header.writableBytes() == 0; - out.add(header); if (body != null && bodyLength > 0) { - out.add(body); + out.add(new MessageWithHeader(header, body, bodyLength)); + } else { + out.add(header); } } + } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java new file mode 100644 index 0000000000000..d686a951467cf --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import java.io.IOException; +import java.nio.channels.WritableByteChannel; + +import com.google.common.base.Preconditions; +import io.netty.buffer.ByteBuf; +import io.netty.channel.FileRegion; +import io.netty.util.AbstractReferenceCounted; +import io.netty.util.ReferenceCountUtil; + +/** + * A wrapper message that holds two separate pieces (a header and a body). + * + * The header must be a ByteBuf, while the body can be a ByteBuf or a FileRegion. + */ +class MessageWithHeader extends AbstractReferenceCounted implements FileRegion { + + private final ByteBuf header; + private final int headerLength; + private final Object body; + private final long bodyLength; + private long totalBytesTransferred; + + MessageWithHeader(ByteBuf header, Object body, long bodyLength) { + Preconditions.checkArgument(body instanceof ByteBuf || body instanceof FileRegion, + "Body must be a ByteBuf or a FileRegion."); + this.header = header; + this.headerLength = header.readableBytes(); + this.body = body; + this.bodyLength = bodyLength; + } + + @Override + public long count() { + return headerLength + bodyLength; + } + + @Override + public long position() { + return 0; + } + + @Override + public long transfered() { + return totalBytesTransferred; + } + + /** + * This code is more complicated than you would think because we might require multiple + * transferTo invocations in order to transfer a single MessageWithHeader to avoid busy waiting. + * + * The contract is that the caller will ensure position is properly set to the total number + * of bytes transferred so far (i.e. value returned by transfered()). + */ + @Override + public long transferTo(final WritableByteChannel target, final long position) throws IOException { + Preconditions.checkArgument(position == totalBytesTransferred, "Invalid position."); + // Bytes written for header in this call. + long writtenHeader = 0; + if (header.readableBytes() > 0) { + writtenHeader = copyByteBuf(header, target); + totalBytesTransferred += writtenHeader; + if (header.readableBytes() > 0) { + return writtenHeader; + } + } + + // Bytes written for body in this call. + long writtenBody = 0; + if (body instanceof FileRegion) { + writtenBody = ((FileRegion) body).transferTo(target, totalBytesTransferred - headerLength); + } else if (body instanceof ByteBuf) { + writtenBody = copyByteBuf((ByteBuf) body, target); + } + totalBytesTransferred += writtenBody; + + return writtenHeader + writtenBody; + } + + @Override + protected void deallocate() { + header.release(); + ReferenceCountUtil.release(body); + } + + private int copyByteBuf(ByteBuf buf, WritableByteChannel target) throws IOException { + int written = target.write(buf.nioBuffer()); + buf.skipBytes(written); + return written; + } +} diff --git a/project/spark-style/src/main/scala/org/apache/spark/scalastyle/NonASCIICharacterChecker.scala b/network/common/src/test/java/org/apache/spark/network/ByteArrayWritableChannel.java similarity index 55% rename from project/spark-style/src/main/scala/org/apache/spark/scalastyle/NonASCIICharacterChecker.scala rename to network/common/src/test/java/org/apache/spark/network/ByteArrayWritableChannel.java index 3d43c35299555..b525ed69fc9fb 100644 --- a/project/spark-style/src/main/scala/org/apache/spark/scalastyle/NonASCIICharacterChecker.scala +++ b/network/common/src/test/java/org/apache/spark/network/ByteArrayWritableChannel.java @@ -15,25 +15,41 @@ * limitations under the License. */ +package org.apache.spark.network; -package org.apache.spark.scalastyle +import java.nio.ByteBuffer; +import java.nio.channels.WritableByteChannel; -import java.util.regex.Pattern +public class ByteArrayWritableChannel implements WritableByteChannel { -import org.scalastyle.{PositionError, ScalariformChecker, ScalastyleError} + private final byte[] data; + private int offset; -import scalariform.lexer.Token -import scalariform.parser.CompilationUnit + public ByteArrayWritableChannel(int size) { + this.data = new byte[size]; + this.offset = 0; + } + + public byte[] getData() { + return data; + } -class NonASCIICharacterChecker extends ScalariformChecker { - val errorKey: String = "non.ascii.character.disallowed" + @Override + public int write(ByteBuffer src) { + int available = src.remaining(); + src.get(data, offset, available); + offset += available; + return available; + } + + @Override + public void close() { - override def verify(ast: CompilationUnit): List[ScalastyleError] = { - ast.tokens.filter(hasNonAsciiChars).map(x => PositionError(x.offset)).toList } - private def hasNonAsciiChars(x: Token) = - x.rawText.trim.nonEmpty && !Pattern.compile( """\p{ASCII}+""", Pattern.DOTALL) - .matcher(x.text.trim).matches() + @Override + public boolean isOpen() { + return true; + } } diff --git a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java index 43dc0cf8c7194..860dd6d9b3915 100644 --- a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java @@ -17,26 +17,34 @@ package org.apache.spark.network; +import java.util.List; + +import com.google.common.primitives.Ints; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.FileRegion; import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.MessageToMessageEncoder; import org.junit.Test; import static org.junit.Assert.assertEquals; -import org.apache.spark.network.protocol.Message; -import org.apache.spark.network.protocol.StreamChunkId; -import org.apache.spark.network.protocol.ChunkFetchRequest; import org.apache.spark.network.protocol.ChunkFetchFailure; +import org.apache.spark.network.protocol.ChunkFetchRequest; import org.apache.spark.network.protocol.ChunkFetchSuccess; -import org.apache.spark.network.protocol.RpcRequest; -import org.apache.spark.network.protocol.RpcFailure; -import org.apache.spark.network.protocol.RpcResponse; +import org.apache.spark.network.protocol.Message; import org.apache.spark.network.protocol.MessageDecoder; import org.apache.spark.network.protocol.MessageEncoder; +import org.apache.spark.network.protocol.RpcFailure; +import org.apache.spark.network.protocol.RpcRequest; +import org.apache.spark.network.protocol.RpcResponse; +import org.apache.spark.network.protocol.StreamChunkId; import org.apache.spark.network.util.NettyUtils; public class ProtocolSuite { private void testServerToClient(Message msg) { - EmbeddedChannel serverChannel = new EmbeddedChannel(new MessageEncoder()); + EmbeddedChannel serverChannel = new EmbeddedChannel(new FileRegionEncoder(), + new MessageEncoder()); serverChannel.writeOutbound(msg); EmbeddedChannel clientChannel = new EmbeddedChannel( @@ -51,7 +59,8 @@ private void testServerToClient(Message msg) { } private void testClientToServer(Message msg) { - EmbeddedChannel clientChannel = new EmbeddedChannel(new MessageEncoder()); + EmbeddedChannel clientChannel = new EmbeddedChannel(new FileRegionEncoder(), + new MessageEncoder()); clientChannel.writeOutbound(msg); EmbeddedChannel serverChannel = new EmbeddedChannel( @@ -83,4 +92,25 @@ public void responses() { testServerToClient(new RpcFailure(0, "this is an error")); testServerToClient(new RpcFailure(0, "")); } + + /** + * Handler to transform a FileRegion into a byte buffer. EmbeddedChannel doesn't actually transfer + * bytes, but messages, so this is needed so that the frame decoder on the receiving side can + * understand what MessageWithHeader actually contains. + */ + private static class FileRegionEncoder extends MessageToMessageEncoder { + + @Override + public void encode(ChannelHandlerContext ctx, FileRegion in, List out) + throws Exception { + + ByteArrayWritableChannel channel = new ByteArrayWritableChannel(Ints.checkedCast(in.count())); + while (in.transfered() < in.count()) { + in.transferTo(channel, in.transfered()); + } + out.add(Unpooled.wrappedBuffer(channel.getData())); + } + + } + } diff --git a/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java b/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java new file mode 100644 index 0000000000000..ff985096d72d5 --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.WritableByteChannel; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.FileRegion; +import io.netty.util.AbstractReferenceCounted; +import org.junit.Test; + +import static org.junit.Assert.*; + +import org.apache.spark.network.ByteArrayWritableChannel; + +public class MessageWithHeaderSuite { + + @Test + public void testSingleWrite() throws Exception { + testFileRegionBody(8, 8); + } + + @Test + public void testShortWrite() throws Exception { + testFileRegionBody(8, 1); + } + + @Test + public void testByteBufBody() throws Exception { + ByteBuf header = Unpooled.copyLong(42); + ByteBuf body = Unpooled.copyLong(84); + MessageWithHeader msg = new MessageWithHeader(header, body, body.readableBytes()); + + ByteBuf result = doWrite(msg, 1); + assertEquals(msg.count(), result.readableBytes()); + assertEquals(42, result.readLong()); + assertEquals(84, result.readLong()); + } + + private void testFileRegionBody(int totalWrites, int writesPerCall) throws Exception { + ByteBuf header = Unpooled.copyLong(42); + int headerLength = header.readableBytes(); + TestFileRegion region = new TestFileRegion(totalWrites, writesPerCall); + MessageWithHeader msg = new MessageWithHeader(header, region, region.count()); + + ByteBuf result = doWrite(msg, totalWrites / writesPerCall); + assertEquals(headerLength + region.count(), result.readableBytes()); + assertEquals(42, result.readLong()); + for (long i = 0; i < 8; i++) { + assertEquals(i, result.readLong()); + } + } + + private ByteBuf doWrite(MessageWithHeader msg, int minExpectedWrites) throws Exception { + int writes = 0; + ByteArrayWritableChannel channel = new ByteArrayWritableChannel((int) msg.count()); + while (msg.transfered() < msg.count()) { + msg.transferTo(channel, msg.transfered()); + writes++; + } + assertTrue("Not enough writes!", minExpectedWrites <= writes); + return Unpooled.wrappedBuffer(channel.getData()); + } + + private static class TestFileRegion extends AbstractReferenceCounted implements FileRegion { + + private final int writeCount; + private final int writesPerCall; + private int written; + + TestFileRegion(int totalWrites, int writesPerCall) { + this.writeCount = totalWrites; + this.writesPerCall = writesPerCall; + } + + @Override + public long count() { + return 8 * writeCount; + } + + @Override + public long position() { + return 0; + } + + @Override + public long transfered() { + return 8 * written; + } + + @Override + public long transferTo(WritableByteChannel target, long position) throws IOException { + for (int i = 0; i < writesPerCall; i++) { + ByteBuf buf = Unpooled.copyLong((position / 8) + i); + ByteBuffer nio = buf.nioBuffer(); + while (nio.remaining() > 0) { + target.write(nio); + } + buf.release(); + written++; + } + return 8 * writesPerCall; + } + + @Override + protected void deallocate() { + } + + } + +} diff --git a/network/common/src/test/resources/log4j.properties b/network/common/src/test/resources/log4j.properties new file mode 100644 index 0000000000000..e8da774f7ca9e --- /dev/null +++ b/network/common/src/test/resources/log4j.properties @@ -0,0 +1,27 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Set everything to be logged to the file target/unit-tests.log +log4j.rootCategory=DEBUG, file +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=true +log4j.appender.file.file=target/unit-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n + +# Silence verbose logs from 3rd-party libraries. +log4j.logger.io.netty=INFO diff --git a/pom.xml b/pom.xml index 42bd926a2fcb8..d8881c213bf07 100644 --- a/pom.xml +++ b/pom.xml @@ -159,6 +159,8 @@ 1.1.1.6 1.1.2 + ${java.home} + ${test_classpath} + ${test.java.home} true @@ -1224,6 +1227,7 @@ launched by the tests have access to the correct test-time classpath. --> ${test_classpath} + ${test.java.home} true @@ -1716,6 +1720,16 @@ + + test-java-home + + env.JAVA_HOME + + + ${env.JAVA_HOME} + + + scala-2.11 @@ -1749,5 +1763,8 @@ parquet-provided + + sparkr + diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 54500f7c2701f..1564babefa62f 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -60,6 +60,14 @@ object MimaExcludes { ) ++ Seq( // SPARK-6510 Add a Graph#minus method acting as Set#difference ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.VertexRDD.minus") + ) ++ Seq( + // SPARK-6492 Fix deadlock in SparkContext.stop() + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.org$" + + "apache$spark$SparkContext$$SPARK_CONTEXT_CONSTRUCTOR_LOCK") + )++ Seq( + // SPARK-6693 add tostring with max lines and width for matrix + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrix.toString") ) case v if v.startsWith("1.3") => diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index d3faa551a4b14..5f51f4b58f97a 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -119,7 +119,9 @@ object SparkBuild extends PomBuild { lazy val publishLocalBoth = TaskKey[Unit]("publish-local", "publish local for m2 and ivy") lazy val sharedSettings = graphSettings ++ genjavadocSettings ++ Seq ( - javaHome := Properties.envOrNone("JAVA_HOME").map(file), + javaHome := sys.env.get("JAVA_HOME") + .orElse(sys.props.get("java.home").map { p => new File(p).getParentFile().getAbsolutePath() }) + .map(file), incOptions := incOptions.value.withNameHashing(true), retrieveManaged := true, retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]", @@ -426,8 +428,10 @@ object TestSettings { fork := true, // Setting SPARK_DIST_CLASSPATH is a simple way to make sure any child processes // launched by the tests have access to the correct test-time classpath. - envVars in Test += ("SPARK_DIST_CLASSPATH" -> - (fullClasspath in Test).value.files.map(_.getAbsolutePath).mkString(":").stripSuffix(":")), + envVars in Test ++= Map( + "SPARK_DIST_CLASSPATH" -> + (fullClasspath in Test).value.files.map(_.getAbsolutePath).mkString(":").stripSuffix(":"), + "JAVA_HOME" -> sys.env.get("JAVA_HOME").getOrElse(sys.props("java.home"))), javaOptions in Test += "-Dspark.test.home=" + sparkHome, javaOptions in Test += "-Dspark.testing=1", javaOptions in Test += "-Dspark.port.maxRetries=100", diff --git a/project/plugins.sbt b/project/plugins.sbt index ee45b6a51905e..7096b0d3ee7de 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -19,7 +19,7 @@ addSbtPlugin("com.github.mpeltonen" % "sbt-idea" % "1.6.0") addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.7.4") -addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.6.0") +addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.7.0") addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.6") diff --git a/project/project/SparkPluginBuild.scala b/project/project/SparkPluginBuild.scala index 8863f272da415..471d00bd8223f 100644 --- a/project/project/SparkPluginBuild.scala +++ b/project/project/SparkPluginBuild.scala @@ -24,20 +24,6 @@ import sbt.Keys._ * becomes available for scalastyle sbt plugin. */ object SparkPluginDef extends Build { - lazy val root = Project("plugins", file(".")) dependsOn(sparkStyle, sbtPomReader) - lazy val sparkStyle = Project("spark-style", file("spark-style"), settings = styleSettings) + lazy val root = Project("plugins", file(".")) dependsOn(sbtPomReader) lazy val sbtPomReader = uri("https://github.com/ScrapCodes/sbt-pom-reader.git#ignore_artifact_id") - - // There is actually no need to publish this artifact. - def styleSettings = Defaults.defaultSettings ++ Seq ( - name := "spark-style", - organization := "org.apache.spark", - scalaVersion := "2.10.4", - scalacOptions := Seq("-unchecked", "-deprecation"), - libraryDependencies ++= Dependencies.scalaStyle - ) - - object Dependencies { - val scalaStyle = Seq("org.scalastyle" %% "scalastyle" % "0.4.0") - } } diff --git a/python/docs/pyspark.mllib.rst b/python/docs/pyspark.mllib.rst index 15101470afc07..26ece4c2c389a 100644 --- a/python/docs/pyspark.mllib.rst +++ b/python/docs/pyspark.mllib.rst @@ -31,6 +31,13 @@ pyspark.mllib.feature module :undoc-members: :show-inheritance: +pyspark.mllib.fpm module +------------------------ + +.. automodule:: pyspark.mllib.fpm + :members: + :undoc-members: + pyspark.mllib.linalg module --------------------------- diff --git a/python/docs/pyspark.streaming.rst b/python/docs/pyspark.streaming.rst index 7890d9dcaac21..50822c93faba1 100644 --- a/python/docs/pyspark.streaming.rst +++ b/python/docs/pyspark.streaming.rst @@ -10,7 +10,7 @@ Module contents :show-inheritance: pyspark.streaming.kafka module ----------------------------- +------------------------------ .. automodule:: pyspark.streaming.kafka :members: :undoc-members: diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 0a16cbd8bff62..2a5e84a7dfdb4 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -29,11 +29,10 @@ def launch_gateway(): - SPARK_HOME = os.environ["SPARK_HOME"] - if "PYSPARK_GATEWAY_PORT" in os.environ: gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"]) else: + SPARK_HOME = os.environ["SPARK_HOME"] # Launch the Py4j gateway using Spark's run command so that we pick up the # proper classpath and settings from spark-env.sh on_windows = platform.system() == "Windows" diff --git a/python/pyspark/join.py b/python/pyspark/join.py index efc1ef9396412..c3491defb2b29 100644 --- a/python/pyspark/join.py +++ b/python/pyspark/join.py @@ -48,7 +48,7 @@ def dispatch(seq): vbuf.append(v) elif n == 2: wbuf.append(v) - return [(v, w) for v in vbuf for w in wbuf] + return ((v, w) for v in vbuf for w in wbuf) return _do_python_join(rdd, other, numPartitions, dispatch) @@ -62,7 +62,7 @@ def dispatch(seq): wbuf.append(v) if not vbuf: vbuf.append(None) - return [(v, w) for v in vbuf for w in wbuf] + return ((v, w) for v in vbuf for w in wbuf) return _do_python_join(rdd, other, numPartitions, dispatch) @@ -76,7 +76,7 @@ def dispatch(seq): wbuf.append(v) if not wbuf: wbuf.append(None) - return [(v, w) for v in vbuf for w in wbuf] + return ((v, w) for v in vbuf for w in wbuf) return _do_python_join(rdd, other, numPartitions, dispatch) @@ -104,8 +104,9 @@ def make_mapper(i): rdd_len = len(vrdds) def dispatch(seq): - bufs = [[] for i in range(rdd_len)] - for (n, v) in seq: + bufs = [[] for _ in range(rdd_len)] + for n, v in seq: bufs[n].append(v) - return tuple(map(ResultIterable, bufs)) + return tuple(ResultIterable(vs) for vs in bufs) + return union_vrdds.groupByKey(numPartitions).mapValues(dispatch) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 4ff7463498cce..7f42de531f3b4 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -91,9 +91,9 @@ class LogisticRegressionModel(JavaModel): # The small batch size here ensures that we see multiple batches, # even in these small test examples: sc = SparkContext("local[2]", "ml.feature tests") - sqlCtx = SQLContext(sc) + sqlContext = SQLContext(sc) globs['sc'] = sc - globs['sqlCtx'] = sqlCtx + globs['sqlContext'] = sqlContext (failure_count, test_count) = doctest.testmod( globs=globs, optionflags=doctest.ELLIPSIS) sc.stop() diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 433b4fb5d22bf..1cfcd019dfb18 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -117,9 +117,9 @@ def setParams(self, numFeatures=1 << 18, inputCol="input", outputCol="output"): # The small batch size here ensures that we see multiple batches, # even in these small test examples: sc = SparkContext("local[2]", "ml.feature tests") - sqlCtx = SQLContext(sc) + sqlContext = SQLContext(sc) globs['sc'] = sc - globs['sqlCtx'] = sqlCtx + globs['sqlContext'] = sqlContext (failure_count, test_count) = doctest.testmod( globs=globs, optionflags=doctest.ELLIPSIS) sc.stop() diff --git a/python/pyspark/mllib/__init__.py b/python/pyspark/mllib/__init__.py index 6449800d9c120..f2ef573fe9f6f 100644 --- a/python/pyspark/mllib/__init__.py +++ b/python/pyspark/mllib/__init__.py @@ -25,7 +25,7 @@ if numpy.version.version < '1.4': raise Exception("MLlib requires NumPy 1.4+") -__all__ = ['classification', 'clustering', 'feature', 'linalg', 'random', +__all__ = ['classification', 'clustering', 'feature', 'fpm', 'linalg', 'random', 'recommendation', 'regression', 'stat', 'tree', 'util'] import sys diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index 4bfe3014ef748..8be819aceec24 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -132,6 +132,22 @@ def transform(self, vector): """ return JavaVectorTransformer.transform(self, vector) + def setWithMean(self, withMean): + """ + Setter of the boolean which decides + whether it uses mean or not + """ + self.call("setWithMean", withMean) + return self + + def setWithStd(self, withStd): + """ + Setter of the boolean which decides + whether it uses std or not + """ + self.call("setWithStd", withStd) + return self + class StandardScaler(object): """ @@ -337,6 +353,12 @@ def findSynonyms(self, word, num): words, similarity = self.call("findSynonyms", word, num) return zip(words, similarity) + def getVectors(self): + """ + Returns a map of words to their vector representations. + """ + return self.call("getVectors") + class Word2Vec(object): """ @@ -379,6 +401,7 @@ def __init__(self): self.numPartitions = 1 self.numIterations = 1 self.seed = random.randint(0, sys.maxint) + self.minCount = 5 def setVectorSize(self, vectorSize): """ @@ -417,6 +440,14 @@ def setSeed(self, seed): self.seed = seed return self + def setMinCount(self, minCount): + """ + Sets minCount, the minimum number of times a token must appear + to be included in the word2vec model's vocabulary (default: 5). + """ + self.minCount = minCount + return self + def fit(self, data): """ Computes the vector representation of each word in vocabulary. @@ -428,7 +459,8 @@ def fit(self, data): raise TypeError("data should be an RDD of list of string") jmodel = callMLlibFunc("trainWord2Vec", data, int(self.vectorSize), float(self.learningRate), int(self.numPartitions), - int(self.numIterations), long(self.seed)) + int(self.numIterations), long(self.seed), + int(self.minCount)) return Word2VecModel(jmodel) diff --git a/python/pyspark/mllib/fpm.py b/python/pyspark/mllib/fpm.py new file mode 100644 index 0000000000000..3aa6d79d7093c --- /dev/null +++ b/python/pyspark/mllib/fpm.py @@ -0,0 +1,81 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark import SparkContext +from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc + +__all__ = ['FPGrowth', 'FPGrowthModel'] + + +@inherit_doc +class FPGrowthModel(JavaModelWrapper): + + """ + .. note:: Experimental + + A FP-Growth model for mining frequent itemsets + using the Parallel FP-Growth algorithm. + + >>> data = [["a", "b", "c"], ["a", "b", "d", "e"], ["a", "c", "e"], ["a", "c", "f"]] + >>> rdd = sc.parallelize(data, 2) + >>> model = FPGrowth.train(rdd, 0.6, 2) + >>> sorted(model.freqItemsets().collect()) + [([u'a'], 4), ([u'c'], 3), ([u'c', u'a'], 3)] + """ + + def freqItemsets(self): + """ + Get the frequent itemsets of this model + """ + return self.call("getFreqItemsets") + + +class FPGrowth(object): + """ + .. note:: Experimental + + A Parallel FP-growth algorithm to mine frequent itemsets. + """ + + @classmethod + def train(cls, data, minSupport=0.3, numPartitions=-1): + """ + Computes an FP-Growth model that contains frequent itemsets. + :param data: The input data set, each element + contains a transaction. + :param minSupport: The minimal support level + (default: `0.3`). + :param numPartitions: The number of partitions used by parallel + FP-growth (default: same as input data). + """ + model = callMLlibFunc("trainFPGrowthModel", data, float(minSupport), int(numPartitions)) + return FPGrowthModel(model) + + +def _test(): + import doctest + import pyspark.mllib.fpm + globs = pyspark.mllib.fpm.__dict__.copy() + globs['sc'] = SparkContext('local[4]', 'PythonTest') + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index f5aad28afda0f..a80320c52d1d0 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -173,7 +173,24 @@ def toArray(self): class DenseVector(Vector): """ - A dense vector represented by a value array. + A dense vector represented by a value array. We use numpy array for + storage and arithmetics will be delegated to the underlying numpy + array. + + >>> v = Vectors.dense([1.0, 2.0]) + >>> u = Vectors.dense([3.0, 4.0]) + >>> v + u + DenseVector([4.0, 6.0]) + >>> 2 - v + DenseVector([1.0, 0.0]) + >>> v / 2 + DenseVector([0.5, 1.0]) + >>> v * u + DenseVector([3.0, 8.0]) + >>> u / v + DenseVector([3.0, 2.0]) + >>> u % 2 + DenseVector([1.0, 0.0]) """ def __init__(self, ar): if isinstance(ar, basestring): @@ -292,6 +309,25 @@ def __ne__(self, other): def __getattr__(self, item): return getattr(self.array, item) + def _delegate(op): + def func(self, other): + if isinstance(other, DenseVector): + other = other.array + return DenseVector(getattr(self.array, op)(other)) + return func + + __neg__ = _delegate("__neg__") + __add__ = _delegate("__add__") + __sub__ = _delegate("__sub__") + __mul__ = _delegate("__mul__") + __div__ = _delegate("__div__") + __mod__ = _delegate("__mod__") + __radd__ = _delegate("__radd__") + __rsub__ = _delegate("__rsub__") + __rmul__ = _delegate("__rmul__") + __rdiv__ = _delegate("__rdiv__") + __rmod__ = _delegate("__rmod__") + class SparseVector(Vector): """ @@ -604,6 +640,15 @@ def toArray(self): """ raise NotImplementedError + @staticmethod + def _convert_to_array(array_like, dtype): + """ + Convert Matrix attributes which are array-like or buffer to array. + """ + if isinstance(array_like, basestring): + return np.frombuffer(array_like, dtype=dtype) + return np.asarray(array_like, dtype=dtype) + class DenseMatrix(Matrix): """ @@ -611,13 +656,8 @@ class DenseMatrix(Matrix): """ def __init__(self, numRows, numCols, values): Matrix.__init__(self, numRows, numCols) - if isinstance(values, basestring): - values = np.frombuffer(values, dtype=np.float64) - elif not isinstance(values, np.ndarray): - values = np.array(values, dtype=np.float64) + values = self._convert_to_array(values, np.float64) assert len(values) == numRows * numCols - if values.dtype != np.float64: - values.astype(np.float64) self.values = values def __reduce__(self): @@ -634,6 +674,27 @@ def toArray(self): """ return self.values.reshape((self.numRows, self.numCols), order='F') + def toSparse(self): + """Convert to SparseMatrix""" + indices = np.nonzero(self.values)[0] + colCounts = np.bincount(indices / self.numRows) + colPtrs = np.cumsum(np.hstack( + (0, colCounts, np.zeros(self.numCols - colCounts.size)))) + values = self.values[indices] + rowIndices = indices % self.numRows + + return SparseMatrix(self.numRows, self.numCols, colPtrs, rowIndices, values) + + def __getitem__(self, indices): + i, j = indices + if i < 0 or i >= self.numRows: + raise ValueError("Row index %d is out of range [0, %d)" + % (i, self.numRows)) + if j >= self.numCols or j < 0: + raise ValueError("Column index %d is out of range [0, %d)" + % (j, self.numCols)) + return self.values[i + j * self.numRows] + def __eq__(self, other): return (isinstance(other, DenseMatrix) and self.numRows == other.numRows and @@ -641,6 +702,82 @@ def __eq__(self, other): all(self.values == other.values)) +class SparseMatrix(Matrix): + """Sparse Matrix stored in CSC format.""" + def __init__(self, numRows, numCols, colPtrs, rowIndices, values, + isTransposed=False): + Matrix.__init__(self, numRows, numCols) + self.isTransposed = isTransposed + self.colPtrs = self._convert_to_array(colPtrs, np.int32) + self.rowIndices = self._convert_to_array(rowIndices, np.int32) + self.values = self._convert_to_array(values, np.float64) + + if self.isTransposed: + if self.colPtrs.size != numRows + 1: + raise ValueError("Expected colPtrs of size %d, got %d." + % (numRows + 1, self.colPtrs.size)) + else: + if self.colPtrs.size != numCols + 1: + raise ValueError("Expected colPtrs of size %d, got %d." + % (numCols + 1, self.colPtrs.size)) + if self.rowIndices.size != self.values.size: + raise ValueError("Expected rowIndices of length %d, got %d." + % (self.rowIndices.size, self.values.size)) + + def __reduce__(self): + return SparseMatrix, ( + self.numRows, self.numCols, self.colPtrs.tostring(), + self.rowIndices.tostring(), self.values.tostring(), + self.isTransposed) + + def __getitem__(self, indices): + i, j = indices + if i < 0 or i >= self.numRows: + raise ValueError("Row index %d is out of range [0, %d)" + % (i, self.numRows)) + if j < 0 or j >= self.numCols: + raise ValueError("Column index %d is out of range [0, %d)" + % (j, self.numCols)) + + # If a CSR matrix is given, then the row index should be searched + # for in ColPtrs, and the column index should be searched for in the + # corresponding slice obtained from rowIndices. + if self.isTransposed: + j, i = i, j + + colStart = self.colPtrs[j] + colEnd = self.colPtrs[j + 1] + nz = self.rowIndices[colStart: colEnd] + ind = np.searchsorted(nz, i) + colStart + if ind < colEnd and self.rowIndices[ind] == i: + return self.values[ind] + else: + return 0.0 + + def toArray(self): + """ + Return an numpy.ndarray + """ + A = np.zeros((self.numRows, self.numCols), dtype=np.float64, order='F') + for k in xrange(self.colPtrs.size - 1): + startptr = self.colPtrs[k] + endptr = self.colPtrs[k + 1] + if self.isTransposed: + A[k, self.rowIndices[startptr:endptr]] = self.values[startptr:endptr] + else: + A[self.rowIndices[startptr:endptr], k] = self.values[startptr:endptr] + return A + + def toDense(self): + densevals = np.reshape( + self.toArray(), (self.numRows * self.numCols), order='F') + return DenseMatrix(self.numRows, self.numCols, densevals) + + # TODO: More efficient implementation: + def __eq__(self, other): + return np.all(self.toArray == other.toArray) + + class Matrices(object): @staticmethod def dense(numRows, numCols, values): @@ -649,6 +786,13 @@ def dense(numRows, numCols, values): """ return DenseMatrix(numRows, numCols, values) + @staticmethod + def sparse(numRows, numCols, colPtrs, rowIndices, values): + """ + Create a SparseMatrix + """ + return SparseMatrix(numRows, numCols, colPtrs, rowIndices, values) + def _test(): import doctest diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index b094e50856f70..c5c4c13dae105 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -52,7 +52,7 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader): >>> ratings = sc.parallelize([r1, r2, r3]) >>> model = ALS.trainImplicit(ratings, 1, seed=10) >>> model.predict(2, 2) - 0.43... + 0.4... >>> testset = sc.parallelize([(1, 2), (1, 1)]) >>> model = ALS.train(ratings, 2, seed=0) @@ -82,14 +82,14 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader): >>> model = ALS.trainImplicit(ratings, 1, nonnegative=True, seed=10) >>> model.predict(2,2) - 0.43... + 0.4... >>> import os, tempfile >>> path = tempfile.mkdtemp() >>> model.save(sc, path) >>> sameModel = MatrixFactorizationModel.load(sc, path) >>> sameModel.predict(2,2) - 0.43... + 0.4... >>> sameModel.predictAll(testset).collect() [Rating(... >>> try: diff --git a/python/pyspark/mllib/stat/_statistics.py b/python/pyspark/mllib/stat/_statistics.py index 218ac148ca992..1d83e9d483f8e 100644 --- a/python/pyspark/mllib/stat/_statistics.py +++ b/python/pyspark/mllib/stat/_statistics.py @@ -49,6 +49,12 @@ def max(self): def min(self): return self.call("min").toArray() + def normL1(self): + return self.call("normL1").toArray() + + def normL2(self): + return self.call("normL2").toArray() + class Statistics(object): diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 3bb0f0ca68128..8eaddcf8b9b5e 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -24,7 +24,7 @@ import tempfile import array as pyarray -from numpy import array, array_equal +from numpy import array, array_equal, zeros from py4j.protocol import Py4JJavaError if sys.version_info[:2] <= (2, 6): @@ -36,12 +36,15 @@ else: import unittest +from pyspark.mllib.common import _to_java_object_rdd from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\ - DenseMatrix, Vectors, Matrices + DenseMatrix, SparseMatrix, Vectors, Matrices from pyspark.mllib.regression import LabeledPoint from pyspark.mllib.random import RandomRDDs from pyspark.mllib.stat import Statistics +from pyspark.mllib.feature import Word2Vec from pyspark.mllib.feature import IDF +from pyspark.mllib.feature import StandardScaler from pyspark.serializers import PickleSerializer from pyspark.sql import SQLContext from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase @@ -135,6 +138,61 @@ def test_sparse_vector_indexing(self): for ind in [4, -5, 7.8]: self.assertRaises(ValueError, sv.__getitem__, ind) + def test_matrix_indexing(self): + mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) + expected = [[0, 6], [1, 8], [4, 10]] + for i in range(3): + for j in range(2): + self.assertEquals(mat[i, j], expected[i][j]) + + def test_sparse_matrix(self): + # Test sparse matrix creation. + sm1 = SparseMatrix( + 3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0]) + self.assertEquals(sm1.numRows, 3) + self.assertEquals(sm1.numCols, 4) + self.assertEquals(sm1.colPtrs.tolist(), [0, 2, 2, 4, 4]) + self.assertEquals(sm1.rowIndices.tolist(), [1, 2, 1, 2]) + self.assertEquals(sm1.values.tolist(), [1.0, 2.0, 4.0, 5.0]) + + # Test indexing + expected = [ + [0, 0, 0, 0], + [1, 0, 4, 0], + [2, 0, 5, 0]] + + for i in range(3): + for j in range(4): + self.assertEquals(expected[i][j], sm1[i, j]) + self.assertTrue(array_equal(sm1.toArray(), expected)) + + # Test conversion to dense and sparse. + smnew = sm1.toDense().toSparse() + self.assertEquals(sm1.numRows, smnew.numRows) + self.assertEquals(sm1.numCols, smnew.numCols) + self.assertTrue(array_equal(sm1.colPtrs, smnew.colPtrs)) + self.assertTrue(array_equal(sm1.rowIndices, smnew.rowIndices)) + self.assertTrue(array_equal(sm1.values, smnew.values)) + + sm1t = SparseMatrix( + 3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], + isTransposed=True) + self.assertEquals(sm1t.numRows, 3) + self.assertEquals(sm1t.numCols, 4) + self.assertEquals(sm1t.colPtrs.tolist(), [0, 2, 3, 5]) + self.assertEquals(sm1t.rowIndices.tolist(), [0, 1, 2, 0, 2]) + self.assertEquals(sm1t.values.tolist(), [3.0, 2.0, 4.0, 9.0, 8.0]) + + expected = [ + [3, 2, 0, 0], + [0, 0, 4, 0], + [9, 0, 8, 0]] + + for i in range(3): + for j in range(4): + self.assertEquals(expected[i][j], sm1t[i, j]) + self.assertTrue(array_equal(sm1t.toArray(), expected)) + class ListTests(PySparkTestCase): @@ -348,6 +406,19 @@ def test_col_with_different_rdds(self): summary = Statistics.colStats(data) self.assertEqual(10, summary.count()) + def test_col_norms(self): + data = RandomRDDs.normalVectorRDD(self.sc, 1000, 10, 10) + summary = Statistics.colStats(data) + self.assertEqual(10, len(summary.normL1())) + self.assertEqual(10, len(summary.normL2())) + + data2 = self.sc.parallelize(xrange(10)).map(lambda x: Vectors.dense(x)) + summary2 = Statistics.colStats(data2) + self.assertEqual(array([45.0]), summary2.normL1()) + import math + expectedNormL2 = math.sqrt(sum(map(lambda x: x*x, xrange(10)))) + self.assertTrue(math.fabs(summary2.normL2()[0] - expectedNormL2) < 1e-14) + class VectorUDTTests(PySparkTestCase): @@ -622,6 +693,12 @@ def test_right_number_of_results(self): self.assertIsNotNone(chi[1000]) +class SerDeTest(PySparkTestCase): + def test_to_java_object_rdd(self): # SPARK-6660 + data = RandomRDDs.uniformRDD(self.sc, 10, 5, seed=0L) + self.assertEqual(_to_java_object_rdd(data).count(), 10) + + class FeatureTest(PySparkTestCase): def test_idf_model(self): data = [ @@ -634,6 +711,64 @@ def test_idf_model(self): idf = model.idf() self.assertEqual(len(idf), 11) + +class Word2VecTests(PySparkTestCase): + def test_word2vec_setters(self): + data = [ + ["I", "have", "a", "pen"], + ["I", "like", "soccer", "very", "much"], + ["I", "live", "in", "Tokyo"] + ] + model = Word2Vec() \ + .setVectorSize(2) \ + .setLearningRate(0.01) \ + .setNumPartitions(2) \ + .setNumIterations(10) \ + .setSeed(1024) \ + .setMinCount(3) + self.assertEquals(model.vectorSize, 2) + self.assertTrue(model.learningRate < 0.02) + self.assertEquals(model.numPartitions, 2) + self.assertEquals(model.numIterations, 10) + self.assertEquals(model.seed, 1024) + self.assertEquals(model.minCount, 3) + + def test_word2vec_get_vectors(self): + data = [ + ["a", "b", "c", "d", "e", "f", "g"], + ["a", "b", "c", "d", "e", "f"], + ["a", "b", "c", "d", "e"], + ["a", "b", "c", "d"], + ["a", "b", "c"], + ["a", "b"], + ["a"] + ] + model = Word2Vec().fit(self.sc.parallelize(data)) + self.assertEquals(len(model.getVectors()), 3) + + +class StandardScalerTests(PySparkTestCase): + def test_model_setters(self): + data = [ + [1.0, 2.0, 3.0], + [2.0, 3.0, 4.0], + [3.0, 4.0, 5.0] + ] + model = StandardScaler().fit(self.sc.parallelize(data)) + self.assertIsNotNone(model.setWithMean(True)) + self.assertIsNotNone(model.setWithStd(True)) + self.assertEqual(model.transform([1.0, 2.0, 3.0]), DenseVector([-1.0, -1.0, -1.0])) + + def test_model_transform(self): + data = [ + [1.0, 2.0, 3.0], + [2.0, 3.0, 4.0], + [3.0, 4.0, 5.0] + ] + model = StandardScaler().fit(self.sc.parallelize(data)) + self.assertEqual(model.transform([1.0, 2.0, 3.0]), DenseVector([1.0, 2.0, 3.0])) + + if __name__ == "__main__": if not _have_scipy: print "NOTE: Skipping SciPy tests as it does not seem to be installed" diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index bf288d76447bd..a7a4d2aaf855b 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -286,21 +286,18 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, numTrees, :param numTrees: Number of trees in the random forest. :param featureSubsetStrategy: Number of features to consider for splits at each node. - Supported: "auto" (default), "all", "sqrt", "log2", - "onethird". - If "auto" is set, this parameter is set based on - numTrees: - if numTrees == 1, set to "all"; - if numTrees > 1 (forest) set to "sqrt". - :param impurity: Criterion used for information gain - calculation. + Supported: "auto" (default), "all", "sqrt", "log2", "onethird". + If "auto" is set, this parameter is set based on numTrees: + if numTrees == 1, set to "all"; + if numTrees > 1 (forest) set to "sqrt". + :param impurity: Criterion used for information gain calculation. Supported values: "gini" (recommended) or "entropy". :param maxDepth: Maximum depth of the tree. E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. (default: 4) :param maxBins: maximum number of bins used for splitting features - (default: 100) + (default: 100) :param seed: Random seed for bootstrapping and choosing feature subsets. :return: RandomForestModel that can be used for prediction @@ -365,13 +362,10 @@ def trainRegressor(cls, data, categoricalFeaturesInfo, numTrees, featureSubsetSt :param numTrees: Number of trees in the random forest. :param featureSubsetStrategy: Number of features to consider for splits at each node. - Supported: "auto" (default), "all", "sqrt", "log2", - "onethird". - If "auto" is set, this parameter is set based on - numTrees: - if numTrees == 1, set to "all"; - if numTrees > 1 (forest) set to "onethird" for - regression. + Supported: "auto" (default), "all", "sqrt", "log2", "onethird". + If "auto" is set, this parameter is set based on numTrees: + if numTrees == 1, set to "all"; + if numTrees > 1 (forest) set to "onethird" for regression. :param impurity: Criterion used for information gain calculation. Supported values: "variance". diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index c337a43c8a7fc..c9ac95d117574 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -41,7 +41,7 @@ from pyspark.storagelevel import StorageLevel from pyspark.resultiterable import ResultIterable from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \ - get_used_memory, ExternalSorter + get_used_memory, ExternalSorter, ExternalGroupBy from pyspark.traceback_utils import SCCallSiteSync from py4j.java_collections import ListConverter, MapConverter @@ -113,6 +113,7 @@ def _parse_memory(s): def _load_from_socket(port, serializer): sock = socket.socket() + sock.settimeout(3) try: sock.connect(("localhost", port)) rf = sock.makefile("rb", 65536) @@ -572,8 +573,8 @@ def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x): if numPartitions is None: numPartitions = self._defaultReducePartitions() - spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower() == 'true') - memory = _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m")) + spill = self._can_spill() + memory = self._memory_limit() serializer = self._jrdd_deserializer def sortPartition(iterator): @@ -594,7 +595,7 @@ def sortPartition(iterator): maxSampleSize = numPartitions * 20.0 # constant from Spark's RangePartitioner fraction = min(maxSampleSize / max(rddSize, 1), 1.0) samples = self.sample(False, fraction, 1).map(lambda (k, v): k).collect() - samples = sorted(samples, reverse=(not ascending), key=keyfunc) + samples = sorted(samples, key=keyfunc) # we have numPartitions many parts but one of the them has # an implicit boundary @@ -1698,10 +1699,8 @@ def combineByKey(self, createCombiner, mergeValue, mergeCombiners, numPartitions = self._defaultReducePartitions() serializer = self.ctx.serializer - spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower() - == 'true') - memory = _parse_memory(self.ctx._conf.get( - "spark.python.worker.memory", "512m")) + spill = self._can_spill() + memory = self._memory_limit() agg = Aggregator(createCombiner, mergeValue, mergeCombiners) def combineLocally(iterator): @@ -1754,21 +1753,28 @@ def createZero(): return self.combineByKey(lambda v: func(createZero(), v), func, func, numPartitions) + def _can_spill(self): + return self.ctx._conf.get("spark.shuffle.spill", "True").lower() == "true" + + def _memory_limit(self): + return _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m")) + # TODO: support variant with custom partitioner def groupByKey(self, numPartitions=None): """ Group the values for each key in the RDD into a single sequence. - Hash-partitions the resulting RDD with into numPartitions partitions. + Hash-partitions the resulting RDD with numPartitions partitions. Note: If you are grouping in order to perform an aggregation (such as a sum or average) over each key, using reduceByKey or aggregateByKey will provide much better performance. >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) - >>> map((lambda (x,y): (x, list(y))), sorted(x.groupByKey().collect())) + >>> sorted(x.groupByKey().mapValues(len).collect()) + [('a', 2), ('b', 1)] + >>> sorted(x.groupByKey().mapValues(list).collect()) [('a', [1, 1]), ('b', [1])] """ - def createCombiner(x): return [x] @@ -1780,8 +1786,27 @@ def mergeCombiners(a, b): a.extend(b) return a - return self.combineByKey(createCombiner, mergeValue, mergeCombiners, - numPartitions).mapValues(lambda x: ResultIterable(x)) + spill = self._can_spill() + memory = self._memory_limit() + serializer = self._jrdd_deserializer + agg = Aggregator(createCombiner, mergeValue, mergeCombiners) + + def combine(iterator): + merger = ExternalMerger(agg, memory * 0.9, serializer) \ + if spill else InMemoryMerger(agg) + merger.mergeValues(iterator) + return merger.iteritems() + + locally_combined = self.mapPartitions(combine, preservesPartitioning=True) + shuffled = locally_combined.partitionBy(numPartitions) + + def groupByKey(it): + merger = ExternalGroupBy(agg, memory, serializer)\ + if spill else InMemoryMerger(agg) + merger.mergeCombiners(it) + return merger.iteritems() + + return shuffled.mapPartitions(groupByKey, True).mapValues(ResultIterable) def flatMapValues(self, f): """ @@ -2208,7 +2233,7 @@ def toLocalIterator(self): def _prepare_for_python_RDD(sc, command, obj=None): # the serialized command will be compressed by broadcast ser = CloudPickleSerializer() - pickled_command = ser.dumps(command) + pickled_command = ser.dumps((command, sys.version_info[:2])) if len(pickled_command) > (1 << 20): # 1M broadcast = sc.broadcast(pickled_command) pickled_command = ser.dumps(broadcast) diff --git a/python/pyspark/resultiterable.py b/python/pyspark/resultiterable.py index ef04c82866e6c..1ab5ce14c3531 100644 --- a/python/pyspark/resultiterable.py +++ b/python/pyspark/resultiterable.py @@ -15,15 +15,16 @@ # limitations under the License. # -__all__ = ["ResultIterable"] - import collections +__all__ = ["ResultIterable"] + class ResultIterable(collections.Iterable): """ - A special result iterable. This is used because the standard iterator can not be pickled + A special result iterable. This is used because the standard + iterator can not be pickled """ def __init__(self, data): diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 0ffb41d02f6f6..4afa82f4b2973 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -220,6 +220,29 @@ def __repr__(self): return "BatchedSerializer(%s, %d)" % (str(self.serializer), self.batchSize) +class FlattenedValuesSerializer(BatchedSerializer): + + """ + Serializes a stream of list of pairs, split the list of values + which contain more than a certain number of objects to make them + have similar sizes. + """ + def __init__(self, serializer, batchSize=10): + BatchedSerializer.__init__(self, serializer, batchSize) + + def _batched(self, iterator): + n = self.batchSize + for key, values in iterator: + for i in xrange(0, len(values), n): + yield key, values[i:i + n] + + def load_stream(self, stream): + return self.serializer.load_stream(stream) + + def __repr__(self): + return "FlattenedValuesSerializer(%d)" % self.batchSize + + class AutoBatchedSerializer(BatchedSerializer): """ Choose the size of batch automatically based on the size of object @@ -251,7 +274,7 @@ def __eq__(self, other): return (isinstance(other, AutoBatchedSerializer) and other.serializer == self.serializer and other.bestSize == self.bestSize) - def __str__(self): + def __repr__(self): return "AutoBatchedSerializer(%s)" % str(self.serializer) diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py index 1a02fece9c5a5..81aa970a32f76 100644 --- a/python/pyspark/shell.py +++ b/python/pyspark/shell.py @@ -53,9 +53,9 @@ try: # Try to access HiveConf, it will raise exception if Hive is not added sc._jvm.org.apache.hadoop.hive.conf.HiveConf() - sqlCtx = HiveContext(sc) + sqlCtx = sqlContext = HiveContext(sc) except py4j.protocol.Py4JError: - sqlCtx = SQLContext(sc) + sqlCtx = sqlContext = SQLContext(sc) print("""Welcome to ____ __ @@ -68,7 +68,7 @@ platform.python_version(), platform.python_build()[0], platform.python_build()[1])) -print("SparkContext available as sc, %s available as sqlCtx." % sqlCtx.__class__.__name__) +print("SparkContext available as sc, %s available as sqlContext." % sqlContext.__class__.__name__) if add_files is not None: print("Warning: ADD_FILES environment variable is deprecated, use --py-files argument instead") diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index 10a7ccd502000..8a6fc627eb383 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -16,28 +16,35 @@ # import os -import sys import platform import shutil import warnings import gc import itertools +import operator import random import pyspark.heapq3 as heapq -from pyspark.serializers import AutoBatchedSerializer, PickleSerializer +from pyspark.serializers import BatchedSerializer, PickleSerializer, FlattenedValuesSerializer, \ + CompressedSerializer, AutoBatchedSerializer + try: import psutil + process = None + def get_used_memory(): """ Return the used memory in MB """ - process = psutil.Process(os.getpid()) + global process + if process is None or process._pid != os.getpid(): + process = psutil.Process(os.getpid()) if hasattr(process, "memory_info"): info = process.memory_info() else: info = process.get_memory_info() return info.rss >> 20 + except ImportError: def get_used_memory(): @@ -46,6 +53,7 @@ def get_used_memory(): for line in open('/proc/self/status'): if line.startswith('VmRSS:'): return int(line.split()[1]) >> 10 + else: warnings.warn("Please install psutil to have better " "support with spilling") @@ -54,6 +62,7 @@ def get_used_memory(): rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss return rss >> 20 # TODO: support windows + return 0 @@ -148,10 +157,16 @@ def mergeCombiners(self, iterator): d[k] = comb(d[k], v) if k in d else v def iteritems(self): - """ Return the merged items ad iterator """ + """ Return the merged items as iterator """ return self.data.iteritems() +def _compressed_serializer(self, serializer=None): + # always use PickleSerializer to simplify implementation + ser = PickleSerializer() + return AutoBatchedSerializer(CompressedSerializer(ser)) + + class ExternalMerger(Merger): """ @@ -173,7 +188,7 @@ class ExternalMerger(Merger): dict. Repeat this again until combine all the items. - Before return any items, it will load each partition and - combine them seperately. Yield them before loading next + combine them separately. Yield them before loading next partition. - During loading a partition, if the memory goes over limit, @@ -182,7 +197,7 @@ class ExternalMerger(Merger): `data` and `pdata` are used to hold the merged items in memory. At first, all the data are merged into `data`. Once the used - memory goes over limit, the items in `data` are dumped indo + memory goes over limit, the items in `data` are dumped into disks, `data` will be cleared, all rest of items will be merged into `pdata` and then dumped into disks. Before returning, all the items in `pdata` will be dumped into disks. @@ -193,16 +208,16 @@ class ExternalMerger(Merger): >>> agg = SimpleAggregator(lambda x, y: x + y) >>> merger = ExternalMerger(agg, 10) >>> N = 10000 - >>> merger.mergeValues(zip(xrange(N), xrange(N)) * 10) + >>> merger.mergeValues(zip(xrange(N), xrange(N))) >>> assert merger.spills > 0 >>> sum(v for k,v in merger.iteritems()) - 499950000 + 49995000 >>> merger = ExternalMerger(agg, 10) - >>> merger.mergeCombiners(zip(xrange(N), xrange(N)) * 10) + >>> merger.mergeCombiners(zip(xrange(N), xrange(N))) >>> assert merger.spills > 0 >>> sum(v for k,v in merger.iteritems()) - 499950000 + 49995000 """ # the max total partitions created recursively @@ -212,8 +227,7 @@ def __init__(self, aggregator, memory_limit=512, serializer=None, localdirs=None, scale=1, partitions=59, batch=1000): Merger.__init__(self, aggregator) self.memory_limit = memory_limit - # default serializer is only used for tests - self.serializer = serializer or AutoBatchedSerializer(PickleSerializer()) + self.serializer = _compressed_serializer(serializer) self.localdirs = localdirs or _get_local_dirs(str(id(self))) # number of partitions when spill data into disks self.partitions = partitions @@ -221,7 +235,7 @@ def __init__(self, aggregator, memory_limit=512, serializer=None, self.batch = batch # scale is used to scale down the hash of key for recursive hash map self.scale = scale - # unpartitioned merged data + # un-partitioned merged data self.data = {} # partitioned merged data, list of dicts self.pdata = [] @@ -244,72 +258,63 @@ def _next_limit(self): def mergeValues(self, iterator): """ Combine the items by creator and combiner """ - iterator = iter(iterator) # speedup attribute lookup creator, comb = self.agg.createCombiner, self.agg.mergeValue - d, c, batch = self.data, 0, self.batch + c, data, pdata, hfun, batch = 0, self.data, self.pdata, self._partition, self.batch + limit = self.memory_limit for k, v in iterator: + d = pdata[hfun(k)] if pdata else data d[k] = comb(d[k], v) if k in d else creator(v) c += 1 - if c % batch == 0 and get_used_memory() > self.memory_limit: - self._spill() - self._partitioned_mergeValues(iterator, self._next_limit()) - break + if c >= batch: + if get_used_memory() >= limit: + self._spill() + limit = self._next_limit() + batch /= 2 + c = 0 + else: + batch *= 1.5 + + if get_used_memory() >= limit: + self._spill() def _partition(self, key): """ Return the partition for key """ return hash((key, self._seed)) % self.partitions - def _partitioned_mergeValues(self, iterator, limit=0): - """ Partition the items by key, then combine them """ - # speedup attribute lookup - creator, comb = self.agg.createCombiner, self.agg.mergeValue - c, pdata, hfun, batch = 0, self.pdata, self._partition, self.batch - - for k, v in iterator: - d = pdata[hfun(k)] - d[k] = comb(d[k], v) if k in d else creator(v) - if not limit: - continue - - c += 1 - if c % batch == 0 and get_used_memory() > limit: - self._spill() - limit = self._next_limit() + def _object_size(self, obj): + """ How much of memory for this obj, assume that all the objects + consume similar bytes of memory + """ + return 1 - def mergeCombiners(self, iterator, check=True): + def mergeCombiners(self, iterator, limit=None): """ Merge (K,V) pair by mergeCombiner """ - iterator = iter(iterator) + if limit is None: + limit = self.memory_limit # speedup attribute lookup - d, comb, batch = self.data, self.agg.mergeCombiners, self.batch - c = 0 - for k, v in iterator: - d[k] = comb(d[k], v) if k in d else v - if not check: - continue - - c += 1 - if c % batch == 0 and get_used_memory() > self.memory_limit: - self._spill() - self._partitioned_mergeCombiners(iterator, self._next_limit()) - break - - def _partitioned_mergeCombiners(self, iterator, limit=0): - """ Partition the items by key, then merge them """ - comb, pdata = self.agg.mergeCombiners, self.pdata - c, hfun = 0, self._partition + comb, hfun, objsize = self.agg.mergeCombiners, self._partition, self._object_size + c, data, pdata, batch = 0, self.data, self.pdata, self.batch for k, v in iterator: - d = pdata[hfun(k)] + d = pdata[hfun(k)] if pdata else data d[k] = comb(d[k], v) if k in d else v if not limit: continue - c += 1 - if c % self.batch == 0 and get_used_memory() > limit: - self._spill() - limit = self._next_limit() + c += objsize(v) + if c > batch: + if get_used_memory() > limit: + self._spill() + limit = self._next_limit() + batch /= 2 + c = 0 + else: + batch *= 1.5 + + if limit and get_used_memory() >= limit: + self._spill() def _spill(self): """ @@ -335,7 +340,7 @@ def _spill(self): for k, v in self.data.iteritems(): h = self._partition(k) - # put one item in batch, make it compatitable with load_stream + # put one item in batch, make it compatible with load_stream # it will increase the memory if dump them in batch self.serializer.dump_stream([(k, v)], streams[h]) @@ -344,7 +349,7 @@ def _spill(self): s.close() self.data.clear() - self.pdata = [{} for i in range(self.partitions)] + self.pdata.extend([{} for i in range(self.partitions)]) else: for i in range(self.partitions): @@ -370,29 +375,12 @@ def _external_items(self): assert not self.data if any(self.pdata): self._spill() - hard_limit = self._next_limit() + # disable partitioning and spilling when merge combiners from disk + self.pdata = [] try: for i in range(self.partitions): - self.data = {} - for j in range(self.spills): - path = self._get_spill_dir(j) - p = os.path.join(path, str(i)) - # do not check memory during merging - self.mergeCombiners(self.serializer.load_stream(open(p)), - False) - - # limit the total partitions - if (self.scale * self.partitions < self.MAX_TOTAL_PARTITIONS - and j < self.spills - 1 - and get_used_memory() > hard_limit): - self.data.clear() # will read from disk again - gc.collect() # release the memory as much as possible - for v in self._recursive_merged_items(i): - yield v - return - - for v in self.data.iteritems(): + for v in self._merged_items(i): yield v self.data.clear() @@ -400,53 +388,56 @@ def _external_items(self): for j in range(self.spills): path = self._get_spill_dir(j) os.remove(os.path.join(path, str(i))) - finally: self._cleanup() - def _cleanup(self): - """ Clean up all the files in disks """ - for d in self.localdirs: - shutil.rmtree(d, True) + def _merged_items(self, index): + self.data = {} + limit = self._next_limit() + for j in range(self.spills): + path = self._get_spill_dir(j) + p = os.path.join(path, str(index)) + # do not check memory during merging + self.mergeCombiners(self.serializer.load_stream(open(p)), 0) + + # limit the total partitions + if (self.scale * self.partitions < self.MAX_TOTAL_PARTITIONS + and j < self.spills - 1 + and get_used_memory() > limit): + self.data.clear() # will read from disk again + gc.collect() # release the memory as much as possible + return self._recursive_merged_items(index) - def _recursive_merged_items(self, start): + return self.data.iteritems() + + def _recursive_merged_items(self, index): """ merge the partitioned items and return the as iterator If one partition can not be fit in memory, then them will be partitioned and merged recursively. """ - # make sure all the data are dumps into disks. - assert not self.data - if any(self.pdata): - self._spill() - assert self.spills > 0 - - for i in range(start, self.partitions): - subdirs = [os.path.join(d, "parts", str(i)) - for d in self.localdirs] - m = ExternalMerger(self.agg, self.memory_limit, self.serializer, - subdirs, self.scale * self.partitions, self.partitions) - m.pdata = [{} for _ in range(self.partitions)] - limit = self._next_limit() - - for j in range(self.spills): - path = self._get_spill_dir(j) - p = os.path.join(path, str(i)) - m._partitioned_mergeCombiners( - self.serializer.load_stream(open(p))) - - if get_used_memory() > limit: - m._spill() - limit = self._next_limit() + subdirs = [os.path.join(d, "parts", str(index)) for d in self.localdirs] + m = ExternalMerger(self.agg, self.memory_limit, self.serializer, subdirs, + self.scale * self.partitions, self.partitions, self.batch) + m.pdata = [{} for _ in range(self.partitions)] + limit = self._next_limit() + + for j in range(self.spills): + path = self._get_spill_dir(j) + p = os.path.join(path, str(index)) + m.mergeCombiners(self.serializer.load_stream(open(p)), 0) + + if get_used_memory() > limit: + m._spill() + limit = self._next_limit() - for v in m._external_items(): - yield v + return m._external_items() - # remove the merged partition - for j in range(self.spills): - path = self._get_spill_dir(j) - os.remove(os.path.join(path, str(i))) + def _cleanup(self): + """ Clean up all the files in disks """ + for d in self.localdirs: + shutil.rmtree(d, True) class ExternalSorter(object): @@ -457,6 +448,7 @@ class ExternalSorter(object): The spilling will only happen when the used memory goes above the limit. + >>> sorter = ExternalSorter(1) # 1M >>> import random >>> l = range(1024) @@ -469,7 +461,7 @@ class ExternalSorter(object): def __init__(self, memory_limit, serializer=None): self.memory_limit = memory_limit self.local_dirs = _get_local_dirs("sort") - self.serializer = serializer or AutoBatchedSerializer(PickleSerializer()) + self.serializer = _compressed_serializer(serializer) def _get_path(self, n): """ Choose one directory for spill by number n """ @@ -515,6 +507,7 @@ def sorted(self, iterator, key=None, reverse=False): limit = self._next_limit() MemoryBytesSpilled += (used_memory - get_used_memory()) << 20 DiskBytesSpilled += os.path.getsize(path) + os.unlink(path) # data will be deleted after close elif not chunks: batch = min(batch * 2, 10000) @@ -529,6 +522,310 @@ def sorted(self, iterator, key=None, reverse=False): return heapq.merge(chunks, key=key, reverse=reverse) +class ExternalList(object): + """ + ExternalList can have many items which cannot be hold in memory in + the same time. + + >>> l = ExternalList(range(100)) + >>> len(l) + 100 + >>> l.append(10) + >>> len(l) + 101 + >>> for i in range(20240): + ... l.append(i) + >>> len(l) + 20341 + >>> import pickle + >>> l2 = pickle.loads(pickle.dumps(l)) + >>> len(l2) + 20341 + >>> list(l2)[100] + 10 + """ + LIMIT = 10240 + + def __init__(self, values): + self.values = values + self.count = len(values) + self._file = None + self._ser = None + + def __getstate__(self): + if self._file is not None: + self._file.flush() + f = os.fdopen(os.dup(self._file.fileno())) + f.seek(0) + serialized = f.read() + else: + serialized = '' + return self.values, self.count, serialized + + def __setstate__(self, item): + self.values, self.count, serialized = item + if serialized: + self._open_file() + self._file.write(serialized) + else: + self._file = None + self._ser = None + + def __iter__(self): + if self._file is not None: + self._file.flush() + # read all items from disks first + with os.fdopen(os.dup(self._file.fileno()), 'r') as f: + f.seek(0) + for v in self._ser.load_stream(f): + yield v + + for v in self.values: + yield v + + def __len__(self): + return self.count + + def append(self, value): + self.values.append(value) + self.count += 1 + # dump them into disk if the key is huge + if len(self.values) >= self.LIMIT: + self._spill() + + def _open_file(self): + dirs = _get_local_dirs("objects") + d = dirs[id(self) % len(dirs)] + if not os.path.exists(d): + os.makedirs(d) + p = os.path.join(d, str(id)) + self._file = open(p, "w+", 65536) + self._ser = BatchedSerializer(CompressedSerializer(PickleSerializer()), 1024) + os.unlink(p) + + def _spill(self): + """ dump the values into disk """ + global MemoryBytesSpilled, DiskBytesSpilled + if self._file is None: + self._open_file() + + used_memory = get_used_memory() + pos = self._file.tell() + self._ser.dump_stream(self.values, self._file) + self.values = [] + gc.collect() + DiskBytesSpilled += self._file.tell() - pos + MemoryBytesSpilled += (used_memory - get_used_memory()) << 20 + + +class ExternalListOfList(ExternalList): + """ + An external list for list. + + >>> l = ExternalListOfList([[i, i] for i in range(100)]) + >>> len(l) + 200 + >>> l.append(range(10)) + >>> len(l) + 210 + >>> len(list(l)) + 210 + """ + + def __init__(self, values): + ExternalList.__init__(self, values) + self.count = sum(len(i) for i in values) + + def append(self, value): + ExternalList.append(self, value) + # already counted 1 in ExternalList.append + self.count += len(value) - 1 + + def __iter__(self): + for values in ExternalList.__iter__(self): + for v in values: + yield v + + +class GroupByKey(object): + """ + Group a sorted iterator as [(k1, it1), (k2, it2), ...] + + >>> k = [i/3 for i in range(6)] + >>> v = [[i] for i in range(6)] + >>> g = GroupByKey(iter(zip(k, v))) + >>> [(k, list(it)) for k, it in g] + [(0, [0, 1, 2]), (1, [3, 4, 5])] + """ + + def __init__(self, iterator): + self.iterator = iter(iterator) + self.next_item = None + + def __iter__(self): + return self + + def next(self): + key, value = self.next_item if self.next_item else next(self.iterator) + values = ExternalListOfList([value]) + try: + while True: + k, v = next(self.iterator) + if k != key: + self.next_item = (k, v) + break + values.append(v) + except StopIteration: + self.next_item = None + return key, values + + +class ExternalGroupBy(ExternalMerger): + + """ + Group by the items by key. If any partition of them can not been + hold in memory, it will do sort based group by. + + This class works as follows: + + - It repeatedly group the items by key and save them in one dict in + memory. + + - When the used memory goes above memory limit, it will split + the combined data into partitions by hash code, dump them + into disk, one file per partition. If the number of keys + in one partitions is smaller than 1000, it will sort them + by key before dumping into disk. + + - Then it goes through the rest of the iterator, group items + by key into different dict by hash. Until the used memory goes over + memory limit, it dump all the dicts into disks, one file per + dict. Repeat this again until combine all the items. It + also will try to sort the items by key in each partition + before dumping into disks. + + - It will yield the grouped items partitions by partitions. + If the data in one partitions can be hold in memory, then it + will load and combine them in memory and yield. + + - If the dataset in one partition cannot be hold in memory, + it will sort them first. If all the files are already sorted, + it merge them by heap.merge(), so it will do external sort + for all the files. + + - After sorting, `GroupByKey` class will put all the continuous + items with the same key as a group, yield the values as + an iterator. + """ + SORT_KEY_LIMIT = 1000 + + def flattened_serializer(self): + assert isinstance(self.serializer, BatchedSerializer) + ser = self.serializer + return FlattenedValuesSerializer(ser, 20) + + def _object_size(self, obj): + return len(obj) + + def _spill(self): + """ + dump already partitioned data into disks. + """ + global MemoryBytesSpilled, DiskBytesSpilled + path = self._get_spill_dir(self.spills) + if not os.path.exists(path): + os.makedirs(path) + + used_memory = get_used_memory() + if not self.pdata: + # The data has not been partitioned, it will iterator the + # data once, write them into different files, has no + # additional memory. It only called when the memory goes + # above limit at the first time. + + # open all the files for writing + streams = [open(os.path.join(path, str(i)), 'w') + for i in range(self.partitions)] + + # If the number of keys is small, then the overhead of sort is small + # sort them before dumping into disks + self._sorted = len(self.data) < self.SORT_KEY_LIMIT + if self._sorted: + self.serializer = self.flattened_serializer() + for k in sorted(self.data.keys()): + h = self._partition(k) + self.serializer.dump_stream([(k, self.data[k])], streams[h]) + else: + for k, v in self.data.iteritems(): + h = self._partition(k) + self.serializer.dump_stream([(k, v)], streams[h]) + + for s in streams: + DiskBytesSpilled += s.tell() + s.close() + + self.data.clear() + # self.pdata is cached in `mergeValues` and `mergeCombiners` + self.pdata.extend([{} for i in range(self.partitions)]) + + else: + for i in range(self.partitions): + p = os.path.join(path, str(i)) + with open(p, "w") as f: + # dump items in batch + if self._sorted: + # sort by key only (stable) + sorted_items = sorted(self.pdata[i].iteritems(), key=operator.itemgetter(0)) + self.serializer.dump_stream(sorted_items, f) + else: + self.serializer.dump_stream(self.pdata[i].iteritems(), f) + self.pdata[i].clear() + DiskBytesSpilled += os.path.getsize(p) + + self.spills += 1 + gc.collect() # release the memory as much as possible + MemoryBytesSpilled += (used_memory - get_used_memory()) << 20 + + def _merged_items(self, index): + size = sum(os.path.getsize(os.path.join(self._get_spill_dir(j), str(index))) + for j in range(self.spills)) + # if the memory can not hold all the partition, + # then use sort based merge. Because of compression, + # the data on disks will be much smaller than needed memory + if (size >> 20) >= self.memory_limit / 10: + return self._merge_sorted_items(index) + + self.data = {} + for j in range(self.spills): + path = self._get_spill_dir(j) + p = os.path.join(path, str(index)) + # do not check memory during merging + self.mergeCombiners(self.serializer.load_stream(open(p)), 0) + return self.data.iteritems() + + def _merge_sorted_items(self, index): + """ load a partition from disk, then sort and group by key """ + def load_partition(j): + path = self._get_spill_dir(j) + p = os.path.join(path, str(index)) + return self.serializer.load_stream(open(p, 'r', 65536)) + + disk_items = [load_partition(j) for j in range(self.spills)] + + if self._sorted: + # all the partitions are already sorted + sorted_items = heapq.merge(disk_items, key=operator.itemgetter(0)) + + else: + # Flatten the combined values, so it will not consume huge + # memory during merging sort. + ser = self.flattened_serializer() + sorter = ExternalSorter(self.memory_limit, ser) + sorted_items = sorter.sorted(itertools.chain(*disk_items), + key=operator.itemgetter(0)) + return ((k, vs) for k, vs in GroupByKey(sorted_items)) + + if __name__ == "__main__": import doctest doctest.testmod() diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index c2d81ba804110..e8529a8f8e3a4 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -37,12 +37,12 @@ __all__ = ["SQLContext", "HiveContext", "UDFRegistration"] -def _monkey_patch_RDD(sqlCtx): +def _monkey_patch_RDD(sqlContext): def toDF(self, schema=None, sampleRatio=None): """ Converts current :class:`RDD` into a :class:`DataFrame` - This is a shorthand for ``sqlCtx.createDataFrame(rdd, schema, sampleRatio)`` + This is a shorthand for ``sqlContext.createDataFrame(rdd, schema, sampleRatio)`` :param schema: a StructType or list of names of columns :param samplingRatio: the sample ratio of rows used for inferring @@ -51,7 +51,7 @@ def toDF(self, schema=None, sampleRatio=None): >>> rdd.toDF().collect() [Row(name=u'Alice', age=1)] """ - return sqlCtx.createDataFrame(self, schema, sampleRatio) + return sqlContext.createDataFrame(self, schema, sampleRatio) RDD.toDF = toDF @@ -75,13 +75,13 @@ def __init__(self, sparkContext, sqlContext=None): """Creates a new SQLContext. >>> from datetime import datetime - >>> sqlCtx = SQLContext(sc) + >>> sqlContext = SQLContext(sc) >>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1L, ... b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1), ... time=datetime(2014, 8, 1, 14, 1, 5))]) >>> df = allTypes.toDF() >>> df.registerTempTable("allTypes") - >>> sqlCtx.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a ' + >>> sqlContext.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a ' ... 'from allTypes where b and i > 0').collect() [Row(c0=2, c1=2.0, c2=False, c3=2, c4=0...8, 1, 14, 1, 5), a=1)] >>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, @@ -133,18 +133,18 @@ def registerFunction(self, name, f, returnType=StringType()): :param samplingRatio: lambda function :param returnType: a :class:`DataType` object - >>> sqlCtx.registerFunction("stringLengthString", lambda x: len(x)) - >>> sqlCtx.sql("SELECT stringLengthString('test')").collect() + >>> sqlContext.registerFunction("stringLengthString", lambda x: len(x)) + >>> sqlContext.sql("SELECT stringLengthString('test')").collect() [Row(c0=u'4')] >>> from pyspark.sql.types import IntegerType - >>> sqlCtx.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) - >>> sqlCtx.sql("SELECT stringLengthInt('test')").collect() + >>> sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) + >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() [Row(c0=4)] >>> from pyspark.sql.types import IntegerType - >>> sqlCtx.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) - >>> sqlCtx.sql("SELECT stringLengthInt('test')").collect() + >>> sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) + >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() [Row(c0=4)] """ func = lambda _, it: imap(lambda x: f(*x), it) @@ -229,26 +229,26 @@ def createDataFrame(self, data, schema=None, samplingRatio=None): :param samplingRatio: the sample ratio of rows used for inferring >>> l = [('Alice', 1)] - >>> sqlCtx.createDataFrame(l).collect() + >>> sqlContext.createDataFrame(l).collect() [Row(_1=u'Alice', _2=1)] - >>> sqlCtx.createDataFrame(l, ['name', 'age']).collect() + >>> sqlContext.createDataFrame(l, ['name', 'age']).collect() [Row(name=u'Alice', age=1)] >>> d = [{'name': 'Alice', 'age': 1}] - >>> sqlCtx.createDataFrame(d).collect() + >>> sqlContext.createDataFrame(d).collect() [Row(age=1, name=u'Alice')] >>> rdd = sc.parallelize(l) - >>> sqlCtx.createDataFrame(rdd).collect() + >>> sqlContext.createDataFrame(rdd).collect() [Row(_1=u'Alice', _2=1)] - >>> df = sqlCtx.createDataFrame(rdd, ['name', 'age']) + >>> df = sqlContext.createDataFrame(rdd, ['name', 'age']) >>> df.collect() [Row(name=u'Alice', age=1)] >>> from pyspark.sql import Row >>> Person = Row('name', 'age') >>> person = rdd.map(lambda r: Person(*r)) - >>> df2 = sqlCtx.createDataFrame(person) + >>> df2 = sqlContext.createDataFrame(person) >>> df2.collect() [Row(name=u'Alice', age=1)] @@ -256,11 +256,11 @@ def createDataFrame(self, data, schema=None, samplingRatio=None): >>> schema = StructType([ ... StructField("name", StringType(), True), ... StructField("age", IntegerType(), True)]) - >>> df3 = sqlCtx.createDataFrame(rdd, schema) + >>> df3 = sqlContext.createDataFrame(rdd, schema) >>> df3.collect() [Row(name=u'Alice', age=1)] - >>> sqlCtx.createDataFrame(df.toPandas()).collect() # doctest: +SKIP + >>> sqlContext.createDataFrame(df.toPandas()).collect() # doctest: +SKIP [Row(name=u'Alice', age=1)] """ if isinstance(data, DataFrame): @@ -316,7 +316,7 @@ def registerDataFrameAsTable(self, df, tableName): Temporary tables exist only during the lifetime of this instance of :class:`SQLContext`. - >>> sqlCtx.registerDataFrameAsTable(df, "table1") + >>> sqlContext.registerDataFrameAsTable(df, "table1") """ if (df.__class__ is DataFrame): self._ssql_ctx.registerDataFrameAsTable(df._jdf, tableName) @@ -330,7 +330,7 @@ def parquetFile(self, *paths): >>> parquetFile = tempfile.mkdtemp() >>> shutil.rmtree(parquetFile) >>> df.saveAsParquetFile(parquetFile) - >>> df2 = sqlCtx.parquetFile(parquetFile) + >>> df2 = sqlContext.parquetFile(parquetFile) >>> sorted(df.collect()) == sorted(df2.collect()) True """ @@ -352,7 +352,7 @@ def jsonFile(self, path, schema=None, samplingRatio=1.0): >>> shutil.rmtree(jsonFile) >>> with open(jsonFile, 'w') as f: ... f.writelines(jsonStrings) - >>> df1 = sqlCtx.jsonFile(jsonFile) + >>> df1 = sqlContext.jsonFile(jsonFile) >>> df1.printSchema() root |-- field1: long (nullable = true) @@ -365,7 +365,7 @@ def jsonFile(self, path, schema=None, samplingRatio=1.0): ... StructField("field2", StringType()), ... StructField("field3", ... StructType([StructField("field5", ArrayType(IntegerType()))]))]) - >>> df2 = sqlCtx.jsonFile(jsonFile, schema) + >>> df2 = sqlContext.jsonFile(jsonFile, schema) >>> df2.printSchema() root |-- field2: string (nullable = true) @@ -386,11 +386,11 @@ def jsonRDD(self, rdd, schema=None, samplingRatio=1.0): If the schema is provided, applies the given schema to this JSON dataset. Otherwise, it samples the dataset with ratio ``samplingRatio`` to determine the schema. - >>> df1 = sqlCtx.jsonRDD(json) + >>> df1 = sqlContext.jsonRDD(json) >>> df1.first() Row(field1=1, field2=u'row1', field3=Row(field4=11, field5=None), field6=None) - >>> df2 = sqlCtx.jsonRDD(json, df1.schema) + >>> df2 = sqlContext.jsonRDD(json, df1.schema) >>> df2.first() Row(field1=1, field2=u'row1', field3=Row(field4=11, field5=None), field6=None) @@ -400,7 +400,7 @@ def jsonRDD(self, rdd, schema=None, samplingRatio=1.0): ... StructField("field3", ... StructType([StructField("field5", ArrayType(IntegerType()))])) ... ]) - >>> df3 = sqlCtx.jsonRDD(json, schema) + >>> df3 = sqlContext.jsonRDD(json, schema) >>> df3.first() Row(field2=u'row1', field3=Row(field5=None)) """ @@ -480,8 +480,8 @@ def createExternalTable(self, tableName, path=None, source=None, def sql(self, sqlQuery): """Returns a :class:`DataFrame` representing the result of the given query. - >>> sqlCtx.registerDataFrameAsTable(df, "table1") - >>> df2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1") + >>> sqlContext.registerDataFrameAsTable(df, "table1") + >>> df2 = sqlContext.sql("SELECT field1 AS f1, field2 as f2 from table1") >>> df2.collect() [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')] """ @@ -490,8 +490,8 @@ def sql(self, sqlQuery): def table(self, tableName): """Returns the specified table as a :class:`DataFrame`. - >>> sqlCtx.registerDataFrameAsTable(df, "table1") - >>> df2 = sqlCtx.table("table1") + >>> sqlContext.registerDataFrameAsTable(df, "table1") + >>> df2 = sqlContext.table("table1") >>> sorted(df.collect()) == sorted(df2.collect()) True """ @@ -505,8 +505,8 @@ def tables(self, dbName=None): The returned DataFrame has two columns: ``tableName`` and ``isTemporary`` (a column with :class:`BooleanType` indicating if a table is a temporary one or not). - >>> sqlCtx.registerDataFrameAsTable(df, "table1") - >>> df2 = sqlCtx.tables() + >>> sqlContext.registerDataFrameAsTable(df, "table1") + >>> df2 = sqlContext.tables() >>> df2.filter("tableName = 'table1'").first() Row(tableName=u'table1', isTemporary=True) """ @@ -520,10 +520,10 @@ def tableNames(self, dbName=None): If ``dbName`` is not specified, the current database will be used. - >>> sqlCtx.registerDataFrameAsTable(df, "table1") - >>> "table1" in sqlCtx.tableNames() + >>> sqlContext.registerDataFrameAsTable(df, "table1") + >>> "table1" in sqlContext.tableNames() True - >>> "table1" in sqlCtx.tableNames("db") + >>> "table1" in sqlContext.tableNames("db") True """ if dbName is None: @@ -574,15 +574,24 @@ def _ssql_ctx(self): def _get_hive_ctx(self): return self._jvm.HiveContext(self._jsc.sc()) + def refreshTable(self, tableName): + """Invalidate and refresh all the cached the metadata of the given + table. For performance reasons, Spark SQL or the external data source + library it uses might cache certain metadata about a table, such as the + location of blocks. When those change outside of Spark SQL, users should + call this function to invalidate the cache. + """ + self._ssql_ctx.refreshTable(tableName) + class UDFRegistration(object): """Wrapper for user-defined function registration.""" - def __init__(self, sqlCtx): - self.sqlCtx = sqlCtx + def __init__(self, sqlContext): + self.sqlContext = sqlContext def register(self, name, f, returnType=StringType()): - return self.sqlCtx.registerFunction(name, f, returnType) + return self.sqlContext.registerFunction(name, f, returnType) register.__doc__ = SQLContext.registerFunction.__doc__ @@ -595,13 +604,12 @@ def _test(): globs = pyspark.sql.context.__dict__.copy() sc = SparkContext('local[4]', 'PythonTest') globs['sc'] = sc - globs['sqlCtx'] = sqlCtx = SQLContext(sc) + globs['sqlContext'] = SQLContext(sc) globs['rdd'] = rdd = sc.parallelize( [Row(field1=1, field2="row1"), Row(field1=2, field2="row2"), Row(field1=3, field2="row3")] ) - _monkey_patch_RDD(sqlCtx) globs['df'] = rdd.toDF() jsonStrings = [ '{"field1": 1, "field2": "row1", "field3":{"field4":11}}', diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index c30326ebd133e..ef91a9c4f522d 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -110,7 +110,7 @@ def saveAsParquetFile(self, path): >>> parquetFile = tempfile.mkdtemp() >>> shutil.rmtree(parquetFile) >>> df.saveAsParquetFile(parquetFile) - >>> df2 = sqlCtx.parquetFile(parquetFile) + >>> df2 = sqlContext.parquetFile(parquetFile) >>> sorted(df2.collect()) == sorted(df.collect()) True """ @@ -123,7 +123,7 @@ def registerTempTable(self, name): that was used to create this :class:`DataFrame`. >>> df.registerTempTable("people") - >>> df2 = sqlCtx.sql("select * from people") + >>> df2 = sqlContext.sql("select * from people") >>> sorted(df.collect()) == sorted(df2.collect()) True """ @@ -1180,7 +1180,7 @@ def _test(): globs = pyspark.sql.dataframe.__dict__.copy() sc = SparkContext('local[4]', 'PythonTest') globs['sc'] = sc - globs['sqlCtx'] = SQLContext(sc) + globs['sqlContext'] = SQLContext(sc) globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')])\ .toDF(StructType([StructField('age', IntegerType()), StructField('name', StringType())])) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 8a478fddf0e95..daeb6916b58bc 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -123,7 +123,8 @@ def _create_judf(self): pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self) ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) jdt = ssql_ctx.parseDataType(self.returnType.json()) - judf = sc._jvm.UserDefinedPythonFunction(f.__name__, bytearray(pickled_command), env, + fname = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__ + judf = sc._jvm.UserDefinedPythonFunction(fname, bytearray(pickled_command), env, includes, sc.pythonExec, broadcast_vars, sc._javaAccumulator, jdt) return judf @@ -160,7 +161,7 @@ def _test(): globs = pyspark.sql.functions.__dict__.copy() sc = SparkContext('local[4]', 'PythonTest') globs['sc'] = sc - globs['sqlCtx'] = SQLContext(sc) + globs['sqlContext'] = SQLContext(sc) globs['df'] = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]).toDF() (failure_count, test_count) = doctest.testmod( pyspark.sql.functions, globs=globs, diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 258464b7f230d..b3a6a2c6a9229 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -25,6 +25,7 @@ import shutil import tempfile import pickle +import functools import py4j @@ -41,6 +42,7 @@ from pyspark.sql.types import * from pyspark.sql.types import UserDefinedType, _infer_type from pyspark.tests import ReusedPySparkTestCase +from pyspark.sql.functions import UserDefinedFunction class ExamplePointUDT(UserDefinedType): @@ -114,6 +116,35 @@ def tearDownClass(cls): ReusedPySparkTestCase.tearDownClass() shutil.rmtree(cls.tempdir.name, ignore_errors=True) + def test_udf_with_callable(self): + d = [Row(number=i, squared=i**2) for i in range(10)] + rdd = self.sc.parallelize(d) + data = self.sqlCtx.createDataFrame(rdd) + + class PlusFour: + def __call__(self, col): + if col is not None: + return col + 4 + + call = PlusFour() + pudf = UserDefinedFunction(call, LongType()) + res = data.select(pudf(data['number']).alias('plus_four')) + self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85) + + def test_udf_with_partial_function(self): + d = [Row(number=i, squared=i**2) for i in range(10)] + rdd = self.sc.parallelize(d) + data = self.sqlCtx.createDataFrame(rdd) + + def some_func(col, param): + if col is not None: + return col + param + + pfunc = functools.partial(some_func, param=4) + pudf = UserDefinedFunction(pfunc, LongType()) + res = data.select(pudf(data['number']).alias('plus_four')) + self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85) + def test_udf(self): self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType()) [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect() diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 45eb8b945dcb0..ef76d84c00481 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -434,7 +434,7 @@ def _parse_datatype_json_string(json_string): >>> def check_datatype(datatype): ... pickled = pickle.loads(pickle.dumps(datatype)) ... assert datatype == pickled - ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.json()) + ... scala_datatype = sqlContext._ssql_ctx.parseDataType(datatype.json()) ... python_datatype = _parse_datatype_json_string(scala_datatype.json()) ... assert datatype == python_datatype >>> for cls in _all_primitive_types.values(): @@ -567,8 +567,8 @@ def _infer_schema(row): elif isinstance(row, (tuple, list)): if hasattr(row, "_fields"): # namedtuple items = zip(row._fields, tuple(row)) - elif hasattr(row, "__FIELDS__"): # Row - items = zip(row.__FIELDS__, tuple(row)) + elif hasattr(row, "__fields__"): # Row + items = zip(row.__fields__, tuple(row)) else: names = ['_%d' % i for i in range(1, len(row) + 1)] items = zip(names, row) @@ -647,7 +647,7 @@ def converter(obj): if isinstance(obj, dict): return tuple(c(obj.get(n)) for n, c in zip(names, converters)) elif isinstance(obj, tuple): - if hasattr(obj, "_fields") or hasattr(obj, "__FIELDS__"): + if hasattr(obj, "_fields") or hasattr(obj, "__fields__"): return tuple(c(v) for c, v in zip(converters, obj)) elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): # k-v pairs d = dict(obj) @@ -997,12 +997,13 @@ def _restore_object(dataType, obj): # same object in most cases. k = id(dataType) cls = _cached_cls.get(k) - if cls is None: + if cls is None or cls.__datatype is not dataType: # use dataType as key to avoid create multiple class cls = _cached_cls.get(dataType) if cls is None: cls = _create_cls(dataType) _cached_cls[dataType] = cls + cls.__datatype = dataType _cached_cls[k] = cls return cls(obj) @@ -1119,8 +1120,8 @@ def Dict(d): class Row(tuple): """ Row in DataFrame """ - __DATATYPE__ = dataType - __FIELDS__ = tuple(f.name for f in dataType.fields) + __datatype = dataType + __fields__ = tuple(f.name for f in dataType.fields) __slots__ = () # create property for fast access @@ -1128,22 +1129,22 @@ class Row(tuple): def asDict(self): """ Return as a dict """ - return dict((n, getattr(self, n)) for n in self.__FIELDS__) + return dict((n, getattr(self, n)) for n in self.__fields__) def __repr__(self): # call collect __repr__ for nested objects return ("Row(%s)" % ", ".join("%s=%r" % (n, getattr(self, n)) - for n in self.__FIELDS__)) + for n in self.__fields__)) def __reduce__(self): - return (_restore_object, (self.__DATATYPE__, tuple(self))) + return (_restore_object, (self.__datatype, tuple(self))) return Row def _create_row(fields, values): row = Row(*values) - row.__FIELDS__ = fields + row.__fields__ = fields return row @@ -1183,7 +1184,7 @@ def __new__(self, *args, **kwargs): # create row objects names = sorted(kwargs.keys()) row = tuple.__new__(self, [kwargs[n] for n in names]) - row.__FIELDS__ = names + row.__fields__ = names return row else: @@ -1193,11 +1194,11 @@ def asDict(self): """ Return as an dict """ - if not hasattr(self, "__FIELDS__"): + if not hasattr(self, "__fields__"): raise TypeError("Cannot convert a Row class into dict") - return dict(zip(self.__FIELDS__, self)) + return dict(zip(self.__fields__, self)) - # let obect acs like class + # let object acts like class def __call__(self, *args): """create new Row object""" return _create_row(self, args) @@ -1208,21 +1209,21 @@ def __getattr__(self, item): try: # it will be slow when it has many fields, # but this will not be used in normal cases - idx = self.__FIELDS__.index(item) + idx = self.__fields__.index(item) return self[idx] except IndexError: raise AttributeError(item) def __reduce__(self): - if hasattr(self, "__FIELDS__"): - return (_create_row, (self.__FIELDS__, tuple(self))) + if hasattr(self, "__fields__"): + return (_create_row, (self.__fields__, tuple(self))) else: return tuple.__reduce__(self) def __repr__(self): - if hasattr(self, "__FIELDS__"): + if hasattr(self, "__fields__"): return "Row(%s)" % ", ".join("%s=%r" % (k, v) - for k, v in zip(self.__FIELDS__, self)) + for k, v in zip(self.__fields__, tuple(self))) else: return "" % ", ".join(self) @@ -1237,7 +1238,7 @@ def _test(): globs = pyspark.sql.types.__dict__.copy() sc = SparkContext('local[4]', 'PythonTest') globs['sc'] = sc - globs['sqlCtx'] = sqlCtx = SQLContext(sc) + globs['sqlContext'] = SQLContext(sc) globs['ExamplePoint'] = ExamplePoint globs['ExamplePointUDT'] = ExamplePointUDT (failure_count, test_count) = doctest.testmod( diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 608f8e26473a6..9b4635e49020b 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -23,13 +23,16 @@ import tempfile import struct +from py4j.java_collections import MapConverter + from pyspark.context import SparkConf, SparkContext, RDD from pyspark.streaming.context import StreamingContext +from pyspark.streaming.kafka import KafkaUtils class PySparkStreamingTestCase(unittest.TestCase): - timeout = 10 # seconds + timeout = 20 # seconds duration = 1 def setUp(self): @@ -556,5 +559,43 @@ def check_output(n): check_output(3) +class KafkaStreamTests(PySparkStreamingTestCase): + + def setUp(self): + super(KafkaStreamTests, self).setUp() + + kafkaTestUtilsClz = self.ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader()\ + .loadClass("org.apache.spark.streaming.kafka.KafkaTestUtils") + self._kafkaTestUtils = kafkaTestUtilsClz.newInstance() + self._kafkaTestUtils.setup() + + def tearDown(self): + if self._kafkaTestUtils is not None: + self._kafkaTestUtils.teardown() + self._kafkaTestUtils = None + + super(KafkaStreamTests, self).tearDown() + + def test_kafka_stream(self): + """Test the Python Kafka stream API.""" + topic = "topic1" + sendData = {"a": 3, "b": 5, "c": 10} + jSendData = MapConverter().convert(sendData, + self.ssc.sparkContext._gateway._gateway_client) + + self._kafkaTestUtils.createTopic(topic) + self._kafkaTestUtils.sendMessages(topic, jSendData) + + stream = KafkaUtils.createStream(self.ssc, self._kafkaTestUtils.zkAddress(), + "test-streaming-consumer", {topic: 1}, + {"auto.offset.reset": "smallest"}) + + result = {} + for i in chain.from_iterable(self._collect(stream.map(lambda x: x[1]), + sum(sendData.values()))): + result[i] = result.get(i, 0) + 1 + + self.assertEqual(sendData, result) + if __name__ == "__main__": unittest.main() diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index dd8d3b1c53733..b938b9ce12395 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -31,9 +31,12 @@ import time import zipfile import random +import itertools import threading import hashlib +from py4j.protocol import Py4JJavaError + if sys.version_info[:2] <= (2, 6): try: import unittest2 as unittest @@ -76,7 +79,7 @@ class MergerTests(unittest.TestCase): def setUp(self): - self.N = 1 << 14 + self.N = 1 << 12 self.l = [i for i in xrange(self.N)] self.data = zip(self.l, self.l) self.agg = Aggregator(lambda x: [x], @@ -108,7 +111,7 @@ def test_small_dataset(self): sum(xrange(self.N))) def test_medium_dataset(self): - m = ExternalMerger(self.agg, 10) + m = ExternalMerger(self.agg, 30) m.mergeValues(self.data) self.assertTrue(m.spills >= 1) self.assertEqual(sum(sum(v) for k, v in m.iteritems()), @@ -124,10 +127,36 @@ def test_huge_dataset(self): m = ExternalMerger(self.agg, 10, partitions=3) m.mergeCombiners(map(lambda (k, v): (k, [str(v)]), self.data * 10)) self.assertTrue(m.spills >= 1) - self.assertEqual(sum(len(v) for k, v in m._recursive_merged_items(0)), + self.assertEqual(sum(len(v) for k, v in m.iteritems()), self.N * 10) m._cleanup() + def test_group_by_key(self): + + def gen_data(N, step): + for i in range(1, N + 1, step): + for j in range(i): + yield (i, [j]) + + def gen_gs(N, step=1): + return shuffle.GroupByKey(gen_data(N, step)) + + self.assertEqual(1, len(list(gen_gs(1)))) + self.assertEqual(2, len(list(gen_gs(2)))) + self.assertEqual(100, len(list(gen_gs(100)))) + self.assertEqual(range(1, 101), [k for k, _ in gen_gs(100)]) + self.assertTrue(all(range(k) == list(vs) for k, vs in gen_gs(100))) + + for k, vs in gen_gs(50002, 10000): + self.assertEqual(k, len(vs)) + self.assertEqual(range(k), list(vs)) + + ser = PickleSerializer() + l = ser.loads(ser.dumps(list(gen_gs(50002, 30000)))) + for k, vs in l: + self.assertEqual(k, len(vs)) + self.assertEqual(range(k), list(vs)) + class SorterTests(unittest.TestCase): def test_in_memory_sort(self): @@ -702,6 +731,21 @@ def test_distinct(self): self.assertEquals(result.getNumPartitions(), 5) self.assertEquals(result.count(), 3) + def test_external_group_by_key(self): + self.sc._conf.set("spark.python.worker.memory", "5m") + N = 200001 + kv = self.sc.parallelize(range(N)).map(lambda x: (x % 3, x)) + gkv = kv.groupByKey().cache() + self.assertEqual(3, gkv.count()) + filtered = gkv.filter(lambda (k, vs): k == 1) + self.assertEqual(1, filtered.count()) + self.assertEqual([(1, N/3)], filtered.mapValues(len).collect()) + self.assertEqual([(N/3, N/3)], + filtered.values().map(lambda x: (len(x), len(list(x)))).collect()) + result = filtered.collect()[0][1] + self.assertEqual(N/3, len(result)) + self.assertTrue(isinstance(result.data, shuffle.ExternalList)) + def test_sort_on_empty_rdd(self): self.assertEqual([], self.sc.parallelize(zip([], [])).sortByKey().collect()) @@ -752,9 +796,9 @@ def test_narrow_dependency_in_join(self): self.assertEqual(rdd.getNumPartitions() + 2, parted.union(rdd).getNumPartitions()) self.assertEqual(rdd.getNumPartitions() + 2, rdd.union(parted).getNumPartitions()) - self.sc.setJobGroup("test1", "test", True) tracker = self.sc.statusTracker() + self.sc.setJobGroup("test1", "test", True) d = sorted(parted.join(parted).collect()) self.assertEqual(10, len(d)) self.assertEqual((0, (0, 0)), d[0]) @@ -787,6 +831,17 @@ def test_take_on_jrdd(self): rdd = self.sc.parallelize(range(1 << 20)).map(lambda x: str(x)) rdd._jrdd.first() + def test_sortByKey_uses_all_partitions_not_only_first_and_last(self): + # Regression test for SPARK-5969 + seq = [(i * 59 % 101, i) for i in range(101)] # unsorted sequence + rdd = self.sc.parallelize(seq) + for ascending in [True, False]: + sort = rdd.sortByKey(ascending=ascending, numPartitions=5) + self.assertEqual(sort.collect(), sorted(seq, reverse=not ascending)) + sizes = sort.glom().map(len).collect() + for size in sizes: + self.assertGreater(size, 0) + class ProfilerTests(PySparkTestCase): @@ -1441,6 +1496,20 @@ def count(): self.assertTrue(not t.isAlive()) self.assertEqual(100000, rdd.count()) + def test_with_different_versions_of_python(self): + rdd = self.sc.parallelize(range(10)) + rdd.count() + version = sys.version_info + sys.version_info = (2, 0, 0) + log4j = self.sc._jvm.org.apache.log4j + old_level = log4j.LogManager.getRootLogger().getLevel() + log4j.LogManager.getRootLogger().setLevel(log4j.Level.FATAL) + try: + self.assertRaises(Py4JJavaError, lambda: rdd.count()) + finally: + sys.version_info = version + log4j.LogManager.getRootLogger().setLevel(old_level) + class SparkSubmitTests(unittest.TestCase): diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 8a93c320ec5d3..452d6fabdcc17 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -88,7 +88,11 @@ def main(infile, outfile): command = pickleSer._read_with_length(infile) if isinstance(command, Broadcast): command = pickleSer.loads(command.value) - (func, profiler, deserializer, serializer) = command + (func, profiler, deserializer, serializer), version = command + if version != sys.version_info[:2]: + raise Exception(("Python in worker has different version %s than that in " + + "driver %s, PySpark cannot run with different minor versions") % + (sys.version_info[:2], version)) init_time = time.time() def process(): diff --git a/python/run-tests b/python/run-tests index b7630c356cfae..f3a07d8aba562 100755 --- a/python/run-tests +++ b/python/run-tests @@ -21,6 +21,8 @@ # Figure out where the Spark framework is installed FWDIR="$(cd "`dirname "$0"`"; cd ../; pwd)" +. "$FWDIR"/bin/load-spark-env.sh + # CD into the python directory to find things on the right path cd "$FWDIR/python" @@ -57,7 +59,7 @@ function run_core_tests() { PYSPARK_DOC_TEST=1 run_test "pyspark/broadcast.py" PYSPARK_DOC_TEST=1 run_test "pyspark/accumulators.py" run_test "pyspark/serializers.py" - run_test "pyspark/profiler.py" + run_test "pyspark/profiler.py" run_test "pyspark/shuffle.py" run_test "pyspark/tests.py" } @@ -77,6 +79,7 @@ function run_mllib_tests() { run_test "pyspark/mllib/clustering.py" run_test "pyspark/mllib/evaluation.py" run_test "pyspark/mllib/feature.py" + run_test "pyspark/mllib/fpm.py" run_test "pyspark/mllib/linalg.py" run_test "pyspark/mllib/rand.py" run_test "pyspark/mllib/recommendation.py" @@ -96,6 +99,21 @@ function run_ml_tests() { function run_streaming_tests() { echo "Run streaming tests ..." + + KAFKA_ASSEMBLY_DIR="$FWDIR"/external/kafka-assembly + JAR_PATH="${KAFKA_ASSEMBLY_DIR}/target/scala-${SPARK_SCALA_VERSION}" + for f in "${JAR_PATH}"/spark-streaming-kafka-assembly-*.jar; do + if [[ ! -e "$f" ]]; then + echo "Failed to find Spark Streaming Kafka assembly jar in $KAFKA_ASSEMBLY_DIR" 1>&2 + echo "You need to build Spark with " \ + "'build/sbt assembly/assembly streaming-kafka-assembly/assembly' or" \ + "'build/mvn package' before running this program" 1>&2 + exit 1 + fi + KAFKA_ASSEMBLY_JAR="$f" + done + + export PYSPARK_SUBMIT_ARGS="--jars ${KAFKA_ASSEMBLY_JAR} pyspark-shell" run_test "pyspark/streaming/util.py" run_test "pyspark/streaming/tests.py" } diff --git a/repl/src/test/resources/log4j.properties b/repl/src/test/resources/log4j.properties index e7e4a4113174a..e2ee9c963a4da 100644 --- a/repl/src/test/resources/log4j.properties +++ b/repl/src/test/resources/log4j.properties @@ -24,4 +24,4 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/sbin/spark-daemon.sh b/sbin/spark-daemon.sh index 92e76a3fe6ca2..d8e0facb81169 100755 --- a/sbin/spark-daemon.sh +++ b/sbin/spark-daemon.sh @@ -29,7 +29,7 @@ # SPARK_NICENESS The scheduling priority for daemons. Defaults to 0. ## -usage="Usage: spark-daemon.sh [--config ] (start|stop) " +usage="Usage: spark-daemon.sh [--config ] (start|stop|status) " # if no args specified, show usage if [ $# -le 1 ]; then @@ -195,6 +195,23 @@ case $option in fi ;; + (status) + + if [ -f $pid ]; then + TARGET_ID="$(cat "$pid")" + if [[ $(ps -p "$TARGET_ID" -o comm=) =~ "java" ]]; then + echo $command is running. + exit 0 + else + echo $pid file is present but $command not running + exit 1 + fi + else + echo $command not running. + exit 2 + fi + ;; + (*) echo $usage exit 1 diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 459a5035d4984..7168d5b2a8e26 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -137,7 +137,7 @@ - + diff --git a/sql/README.md b/sql/README.md index fbb3200a3a4b4..237620e3fa808 100644 --- a/sql/README.md +++ b/sql/README.md @@ -56,6 +56,6 @@ res2: Array[org.apache.spark.sql.Row] = Array([238,val_238], [86,val_86], [311,v You can also build further queries on top of these `DataFrames` using the query DSL. ``` -scala> query.where('key > 30).select(avg('key)).collect() +scala> query.where(query("key") > 30).select(avg(query("key"))).collect() res3: Array[org.apache.spark.sql.Row] = Array([274.79025423728814]) ``` diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala index 34fedead44db3..f9992185a4563 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala @@ -30,7 +30,7 @@ class AnalysisException protected[sql] ( val startPosition: Option[Int] = None) extends Exception with Serializable { - def withPosition(line: Option[Int], startPosition: Option[Int]) = { + def withPosition(line: Option[Int], startPosition: Option[Int]): AnalysisException = { val newException = new AnalysisException(message, line, startPosition) newException.setStackTrace(getStackTrace) newException diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala new file mode 100644 index 0000000000000..91976fef6dc0d --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -0,0 +1,295 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst + +import java.util.{Map => JavaMap} + +import scala.collection.mutable.HashMap + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ + +/** + * Functions to convert Scala types to Catalyst types and vice versa. + */ +object CatalystTypeConverters { + // The Predef.Map is scala.collection.immutable.Map. + // Since the map values can be mutable, we explicitly import scala.collection.Map at here. + import scala.collection.Map + + /** + * Converts Scala objects to catalyst rows / types. This method is slow, and for batch + * conversion you should be using converter produced by createToCatalystConverter. + * Note: This is always called after schemaFor has been called. + * This ordering is important for UDT registration. + */ + def convertToCatalyst(a: Any, dataType: DataType): Any = (a, dataType) match { + // Check UDT first since UDTs can override other types + case (obj, udt: UserDefinedType[_]) => + udt.serialize(obj) + + case (o: Option[_], _) => + o.map(convertToCatalyst(_, dataType)).orNull + + case (s: Seq[_], arrayType: ArrayType) => + s.map(convertToCatalyst(_, arrayType.elementType)) + + case (s: Array[_], arrayType: ArrayType) => + s.toSeq.map(convertToCatalyst(_, arrayType.elementType)) + + case (m: Map[_, _], mapType: MapType) => + m.map { case (k, v) => + convertToCatalyst(k, mapType.keyType) -> convertToCatalyst(v, mapType.valueType) + } + + case (jmap: JavaMap[_, _], mapType: MapType) => + val iter = jmap.entrySet.iterator + var listOfEntries: List[(Any, Any)] = List() + while (iter.hasNext) { + val entry = iter.next() + listOfEntries :+= (convertToCatalyst(entry.getKey, mapType.keyType), + convertToCatalyst(entry.getValue, mapType.valueType)) + } + listOfEntries.toMap + + case (p: Product, structType: StructType) => + val ar = new Array[Any](structType.size) + val iter = p.productIterator + var idx = 0 + while (idx < structType.size) { + ar(idx) = convertToCatalyst(iter.next(), structType.fields(idx).dataType) + idx += 1 + } + new GenericRowWithSchema(ar, structType) + + case (d: BigDecimal, _) => + Decimal(d) + + case (d: java.math.BigDecimal, _) => + Decimal(d) + + case (d: java.sql.Date, _) => + DateUtils.fromJavaDate(d) + + case (r: Row, structType: StructType) => + val converters = structType.fields.map { + f => (item: Any) => convertToCatalyst(item, f.dataType) + } + convertRowWithConverters(r, structType, converters) + + case (other, _) => + other + } + + /** + * Creates a converter function that will convert Scala objects to the specified catalyst type. + * Typical use case would be converting a collection of rows that have the same schema. You will + * call this function once to get a converter, and apply it to every row. + */ + private[sql] def createToCatalystConverter(dataType: DataType): Any => Any = { + def extractOption(item: Any): Any = item match { + case opt: Option[_] => opt.orNull + case other => other + } + + dataType match { + // Check UDT first since UDTs can override other types + case udt: UserDefinedType[_] => + (item) => extractOption(item) match { + case null => null + case other => udt.serialize(other) + } + + case arrayType: ArrayType => + val elementConverter = createToCatalystConverter(arrayType.elementType) + (item: Any) => { + extractOption(item) match { + case a: Array[_] => a.toSeq.map(elementConverter) + case s: Seq[_] => s.map(elementConverter) + case null => null + } + } + + case mapType: MapType => + val keyConverter = createToCatalystConverter(mapType.keyType) + val valueConverter = createToCatalystConverter(mapType.valueType) + (item: Any) => { + extractOption(item) match { + case m: Map[_, _] => + m.map { case (k, v) => + keyConverter(k) -> valueConverter(v) + } + + case jmap: JavaMap[_, _] => + val iter = jmap.entrySet.iterator + val convertedMap: HashMap[Any, Any] = HashMap() + while (iter.hasNext) { + val entry = iter.next() + convertedMap(keyConverter(entry.getKey)) = valueConverter(entry.getValue) + } + convertedMap + + case null => null + } + } + + case structType: StructType => + val converters = structType.fields.map(f => createToCatalystConverter(f.dataType)) + (item: Any) => { + extractOption(item) match { + case r: Row => + convertRowWithConverters(r, structType, converters) + + case p: Product => + val ar = new Array[Any](structType.size) + val iter = p.productIterator + var idx = 0 + while (idx < structType.size) { + ar(idx) = converters(idx)(iter.next()) + idx += 1 + } + new GenericRowWithSchema(ar, structType) + + case null => + null + } + } + + case dateType: DateType => (item: Any) => extractOption(item) match { + case d: java.sql.Date => DateUtils.fromJavaDate(d) + case other => other + } + + case _ => + (item: Any) => extractOption(item) match { + case d: BigDecimal => Decimal(d) + case d: java.math.BigDecimal => Decimal(d) + case other => other + } + } + } + + /** + * Converts Catalyst types used internally in rows to standard Scala types + * This method is slow, and for batch conversion you should be using converter + * produced by createToScalaConverter. + */ + def convertToScala(a: Any, dataType: DataType): Any = (a, dataType) match { + // Check UDT first since UDTs can override other types + case (d, udt: UserDefinedType[_]) => + udt.deserialize(d) + + case (s: Seq[_], arrayType: ArrayType) => + s.map(convertToScala(_, arrayType.elementType)) + + case (m: Map[_, _], mapType: MapType) => + m.map { case (k, v) => + convertToScala(k, mapType.keyType) -> convertToScala(v, mapType.valueType) + } + + case (r: Row, s: StructType) => + convertRowToScala(r, s) + + case (d: Decimal, _: DecimalType) => + d.toJavaBigDecimal + + case (i: Int, DateType) => + DateUtils.toJavaDate(i) + + case (other, _) => + other + } + + /** + * Creates a converter function that will convert Catalyst types to Scala type. + * Typical use case would be converting a collection of rows that have the same schema. You will + * call this function once to get a converter, and apply it to every row. + */ + private[sql] def createToScalaConverter(dataType: DataType): Any => Any = dataType match { + // Check UDT first since UDTs can override other types + case udt: UserDefinedType[_] => + (item: Any) => if (item == null) null else udt.deserialize(item) + + case arrayType: ArrayType => + val elementConverter = createToScalaConverter(arrayType.elementType) + (item: Any) => if (item == null) null else item.asInstanceOf[Seq[_]].map(elementConverter) + + case mapType: MapType => + val keyConverter = createToScalaConverter(mapType.keyType) + val valueConverter = createToScalaConverter(mapType.valueType) + (item: Any) => if (item == null) { + null + } else { + item.asInstanceOf[Map[_, _]].map { case (k, v) => + keyConverter(k) -> valueConverter(v) + } + } + + case s: StructType => + val converters = s.fields.map(f => createToScalaConverter(f.dataType)) + (item: Any) => { + if (item == null) { + null + } else { + convertRowWithConverters(item.asInstanceOf[Row], s, converters) + } + } + + case _: DecimalType => + (item: Any) => item match { + case d: Decimal => d.toJavaBigDecimal + case other => other + } + + case DateType => + (item: Any) => item match { + case i: Int => DateUtils.toJavaDate(i) + case other => other + } + + case other => + (item: Any) => item + } + + def convertRowToScala(r: Row, schema: StructType): Row = { + val ar = new Array[Any](r.size) + var idx = 0 + while (idx < r.size) { + ar(idx) = convertToScala(r(idx), schema.fields(idx).dataType) + idx += 1 + } + new GenericRowWithSchema(ar, schema) + } + + /** + * Converts a row by applying the provided set of converter functions. It is used for both + * toScala and toCatalyst conversions. + */ + private[sql] def convertRowWithConverters( + row: Row, + schema: StructType, + converters: Array[Any => Any]): Row = { + val ar = new Array[Any](row.size) + var idx = 0 + while (idx < row.size) { + ar(idx) = converters(idx)(row(idx)) + idx += 1 + } + new GenericRowWithSchema(ar, schema) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 2220970085462..01d5c1512201a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -46,56 +46,6 @@ trait ScalaReflection { case class Schema(dataType: DataType, nullable: Boolean) - /** - * Converts Scala objects to catalyst rows / types. - * Note: This is always called after schemaFor has been called. - * This ordering is important for UDT registration. - */ - def convertToCatalyst(a: Any, dataType: DataType): Any = (a, dataType) match { - // Check UDT first since UDTs can override other types - case (obj, udt: UserDefinedType[_]) => udt.serialize(obj) - case (o: Option[_], _) => o.map(convertToCatalyst(_, dataType)).orNull - case (s: Seq[_], arrayType: ArrayType) => s.map(convertToCatalyst(_, arrayType.elementType)) - case (s: Array[_], arrayType: ArrayType) => if (arrayType.elementType.isPrimitive) { - s.toSeq - } else { - s.toSeq.map(convertToCatalyst(_, arrayType.elementType)) - } - case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) => - convertToCatalyst(k, mapType.keyType) -> convertToCatalyst(v, mapType.valueType) - } - case (p: Product, structType: StructType) => - new GenericRow( - p.productIterator.toSeq.zip(structType.fields).map { case (elem, field) => - convertToCatalyst(elem, field.dataType) - }.toArray) - case (d: BigDecimal, _) => Decimal(d) - case (d: java.math.BigDecimal, _) => Decimal(d) - case (d: java.sql.Date, _) => DateUtils.fromJavaDate(d) - case (other, _) => other - } - - /** Converts Catalyst types used internally in rows to standard Scala types */ - def convertToScala(a: Any, dataType: DataType): Any = (a, dataType) match { - // Check UDT first since UDTs can override other types - case (d, udt: UserDefinedType[_]) => udt.deserialize(d) - case (s: Seq[_], arrayType: ArrayType) => s.map(convertToScala(_, arrayType.elementType)) - case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) => - convertToScala(k, mapType.keyType) -> convertToScala(v, mapType.valueType) - } - case (r: Row, s: StructType) => convertRowToScala(r, s) - case (d: Decimal, _: DecimalType) => d.toJavaBigDecimal - case (i: Int, DateType) => DateUtils.toJavaDate(i) - case (other, _) => other - } - - def convertRowToScala(r: Row, schema: StructType): Row = { - // TODO: This is very slow!!! - new GenericRowWithSchema( - r.toSeq.zip(schema.fields.map(_.dataType)) - .map(r_dt => convertToScala(r_dt._1, r_dt._2)).toArray, schema) - } - /** Returns a Sequence of attributes for the given case class type. */ def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match { case Schema(s: StructType, _) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index b176f7e729a42..ee04cb579deb6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -111,6 +111,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected val UPPER = Keyword("UPPER") protected val WHEN = Keyword("WHEN") protected val WHERE = Keyword("WHERE") + protected val WITH = Keyword("WITH") protected def assignAliases(exprs: Seq[Expression]): Seq[NamedExpression] = { exprs.zipWithIndex.map { @@ -127,6 +128,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { | UNION ~ DISTINCT.? ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Distinct(Union(q1, q2)) } ) | insert + | cte ) protected lazy val select: Parser[LogicalPlan] = @@ -156,6 +158,11 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { case o ~ r ~ s => InsertIntoTable(r, Map.empty[String, Option[String]], s, o) } + protected lazy val cte: Parser[LogicalPlan] = + WITH ~> rep1sep(ident ~ ( AS ~ "(" ~> start <~ ")"), ",") ~ start ^^ { + case r ~ s => With(s, r.map({case n ~ s => (n, Subquery(n, s))}).toMap) + } + protected lazy val projection: Parser[Expression] = expression ~ (AS.? ~> ident.?) ^^ { case e ~ a => a.fold(e)(Alias(e, _)()) @@ -316,13 +323,13 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected lazy val literal: Parser[Literal] = ( numericLiteral | booleanLiteral - | stringLit ^^ {case s => Literal(s, StringType) } - | NULL ^^^ Literal(null, NullType) + | stringLit ^^ {case s => Literal.create(s, StringType) } + | NULL ^^^ Literal.create(null, NullType) ) protected lazy val booleanLiteral: Parser[Literal] = - ( TRUE ^^^ Literal(true, BooleanType) - | FALSE ^^^ Literal(false, BooleanType) + ( TRUE ^^^ Literal.create(true, BooleanType) + | FALSE ^^^ Literal.create(false, BooleanType) ) protected lazy val numericLiteral: Parser[Literal] = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index c578d084a45b6..50702ac6832ec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.types._ /** * A trivial [[Analyzer]] with an [[EmptyCatalog]] and [[EmptyFunctionRegistry]]. Used for testing - * when all relations are already filled in and the analyser needs only to resolve attribute + * when all relations are already filled in and the analyzer needs only to resolve attribute * references. */ object SimpleAnalyzer extends Analyzer(EmptyCatalog, EmptyFunctionRegistry, true) @@ -140,10 +140,10 @@ class Analyzer( case x: Expression if nonSelectedGroupExprSet.contains(x) => // if the input attribute in the Invalid Grouping Expression set of for this group // replace it with constant null - Literal(null, expr.dataType) + Literal.create(null, expr.dataType) case x if x == g.gid => // replace the groupingId with concrete value (the bit mask) - Literal(bitmask, IntegerType) + Literal.create(bitmask, IntegerType) }) result += GroupExpression(substitution) @@ -169,21 +169,36 @@ class Analyzer( * Replaces [[UnresolvedRelation]]s with concrete relations from the catalog. */ object ResolveRelations extends Rule[LogicalPlan] { - def getTable(u: UnresolvedRelation): LogicalPlan = { + def getTable(u: UnresolvedRelation, cteRelations: Map[String, LogicalPlan]): LogicalPlan = { try { - catalog.lookupRelation(u.tableIdentifier, u.alias) + // In hive, if there is same table name in database and CTE definition, + // hive will use the table in database, not the CTE one. + // Taking into account the reasonableness and the implementation complexity, + // here use the CTE definition first, check table name only and ignore database name + cteRelations.get(u.tableIdentifier.last) + .map(relation => u.alias.map(Subquery(_, relation)).getOrElse(relation)) + .getOrElse(catalog.lookupRelation(u.tableIdentifier, u.alias)) } catch { case _: NoSuchTableException => u.failAnalysis(s"no such table ${u.tableName}") } } - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case i @ InsertIntoTable(u: UnresolvedRelation, _, _, _) => - i.copy( - table = EliminateSubQueries(getTable(u))) - case u: UnresolvedRelation => - getTable(u) + def apply(plan: LogicalPlan): LogicalPlan = { + val (realPlan, cteRelations) = plan match { + // TODO allow subquery to define CTE + // Add cte table to a temp relation map,drop `with` plan and keep its child + case With(child, relations) => (child, relations) + case other => (other, Map.empty[String, LogicalPlan]) + } + + realPlan transform { + case i@InsertIntoTable(u: UnresolvedRelation, _, _, _) => + i.copy( + table = EliminateSubQueries(getTable(u, cteRelations))) + case u: UnresolvedRelation => + getTable(u, cteRelations) + } } } @@ -293,7 +308,7 @@ class Analyzer( logDebug(s"Resolving $u to $result") result case UnresolvedGetField(child, fieldName) if child.resolved => - resolveGetField(child, fieldName) + GetField(child, fieldName, resolver) } } @@ -313,36 +328,6 @@ class Analyzer( */ protected def containsStar(exprs: Seq[Expression]): Boolean = exprs.exists(_.collect { case _: Star => true }.nonEmpty) - - /** - * Returns the resolved `GetField`, and report error if no desired field or over one - * desired fields are found. - */ - protected def resolveGetField(expr: Expression, fieldName: String): Expression = { - def findField(fields: Array[StructField]): Int = { - val checkField = (f: StructField) => resolver(f.name, fieldName) - val ordinal = fields.indexWhere(checkField) - if (ordinal == -1) { - throw new AnalysisException( - s"No such struct field $fieldName in ${fields.map(_.name).mkString(", ")}") - } else if (fields.indexWhere(checkField, ordinal + 1) != -1) { - throw new AnalysisException( - s"Ambiguous reference to fields ${fields.filter(checkField).mkString(", ")}") - } else { - ordinal - } - } - expr.dataType match { - case StructType(fields) => - val ordinal = findField(fields) - StructGetField(expr, fields(ordinal), ordinal) - case ArrayType(StructType(fields), containsNull) => - val ordinal = findField(fields) - ArrayGetField(expr, fields(ordinal), ordinal, containsNull) - case otherType => - throw new AnalysisException(s"GetField is not valid on fields of type $otherType") - } - } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala index 5eb7dff0cede8..b2f8157a1a61f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala @@ -22,7 +22,7 @@ import scala.collection.mutable import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery} /** - * Thrown by a catalog when a table cannot be found. The analzyer will rethrow the exception + * Thrown by a catalog when a table cannot be found. The analyzer will rethrow the exception * as an AnalysisException with the correct position information. */ class NoSuchTableException extends Exception @@ -201,7 +201,7 @@ trait OverrideCatalog extends Catalog { /** * A trivial catalog that returns an error when a relation is requested. Used for testing when all - * relations are already filled in and the analyser needs only to resolve attribute references. + * relations are already filled in and the analyzer needs only to resolve attribute references. */ object EmptyCatalog extends Catalog { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index c43ea55899695..16ca5bcd57a72 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -57,8 +57,8 @@ class SimpleFunctionRegistry(val caseSensitive: Boolean) extends FunctionRegistr } /** - * A trivial catalog that returns an error when a function is requested. Used for testing when all - * functions are already filled in and the analyser needs only to resolve attribute references. + * A trivial catalog that returns an error when a function is requested. Used for testing when all + * functions are already filled in and the analyzer needs only to resolve attribute references. */ object EmptyFunctionRegistry extends FunctionRegistry { override def registerFunction(name: String, builder: FunctionBuilder): Unit = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 3c7b46e0702a2..3aeb964994d37 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -115,7 +115,7 @@ trait HiveTypeCoercion { * the appropriate numeric equivalent. */ object ConvertNaNs extends Rule[LogicalPlan] { - val stringNaN = Literal("NaN", StringType) + val stringNaN = Literal.create("NaN", StringType) def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressions { @@ -285,6 +285,7 @@ trait HiveTypeCoercion { * Calculates and propagates precision for fixed-precision decimals. Hive has a number of * rules for this based on the SQL standard and MS SQL: * https://cwiki.apache.org/confluence/download/attachments/27362075/Hive_Decimal_Precision_Scale_Support.pdf + * https://msdn.microsoft.com/en-us/library/ms190476.aspx * * In particular, if we have expressions e1 and e2 with precision/scale p1/s2 and p2/s2 * respectively, then the following operations have the following precision / scale: @@ -296,6 +297,7 @@ trait HiveTypeCoercion { * e1 * e2 p1 + p2 + 1 s1 + s2 * e1 / e2 p1 - s1 + s2 + max(6, s1 + p2 + 1) max(6, s1 + p2 + 1) * e1 % e2 min(p1-s1, p2-s2) + max(s1, s2) max(s1, s2) + * e1 union e2 max(s1, s2) + max(p1-s1, p2-s2) max(s1, s2) * sum(e1) p1 + 10 s1 * avg(e1) p1 + 4 s1 + 4 * @@ -311,7 +313,12 @@ trait HiveTypeCoercion { * - SHORT gets turned into DECIMAL(5, 0) * - INT gets turned into DECIMAL(10, 0) * - LONG gets turned into DECIMAL(20, 0) - * - FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE (this is the same as Hive, + * - FLOAT and DOUBLE + * 1. Union operation: + * FLOAT gets turned into DECIMAL(7, 7), DOUBLE gets turned into DECIMAL(15, 15) (this is the + * same as Hive) + * 2. Other operation: + * FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE (this is the same as Hive, * but note that unlimited decimals are considered bigger than doubles in WidenTypes) */ // scalastyle:on @@ -328,76 +335,127 @@ trait HiveTypeCoercion { def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - // Skip nodes whose children have not been resolved yet - case e if !e.childrenResolved => e + // Conversion rules for float and double into fixed-precision decimals + val floatTypeToFixed: Map[DataType, DecimalType] = Map( + FloatType -> DecimalType(7, 7), + DoubleType -> DecimalType(15, 15) + ) - case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - Cast( - Add(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), - DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) - ) - - case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - Cast( - Subtract(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), - DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) - ) - - case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - Cast( - Multiply(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), - DecimalType(p1 + p2 + 1, s1 + s2) - ) - - case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - Cast( - Divide(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), - DecimalType(p1 - s1 + s2 + max(6, s1 + p2 + 1), max(6, s1 + p2 + 1)) - ) - - case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - Cast( - Remainder(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), - DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) - ) - - case LessThan(e1 @ DecimalType.Expression(p1, s1), - e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => - LessThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) - - case LessThanOrEqual(e1 @ DecimalType.Expression(p1, s1), - e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => - LessThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) - - case GreaterThan(e1 @ DecimalType.Expression(p1, s1), - e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => - GreaterThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) - - case GreaterThanOrEqual(e1 @ DecimalType.Expression(p1, s1), - e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => - GreaterThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) - - // Promote integers inside a binary expression with fixed-precision decimals to decimals, - // and fixed-precision decimals in an expression with floats / doubles to doubles - case b: BinaryExpression if b.left.dataType != b.right.dataType => - (b.left.dataType, b.right.dataType) match { - case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) => - b.makeCopy(Array(Cast(b.left, intTypeToFixed(t)), b.right)) - case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) => - b.makeCopy(Array(b.left, Cast(b.right, intTypeToFixed(t)))) - case (t, DecimalType.Fixed(p, s)) if isFloat(t) => - b.makeCopy(Array(b.left, Cast(b.right, DoubleType))) - case (DecimalType.Fixed(p, s), t) if isFloat(t) => - b.makeCopy(Array(Cast(b.left, DoubleType), b.right)) - case _ => - b + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + // fix decimal precision for union + case u @ Union(left, right) if u.childrenResolved && !u.resolved => + val castedInput = left.output.zip(right.output).map { + case (l, r) if l.dataType != r.dataType => + (l.dataType, r.dataType) match { + case (DecimalType.Fixed(p1, s1), DecimalType.Fixed(p2, s2)) => + // Union decimals with precision/scale p1/s2 and p2/s2 will be promoted to + // DecimalType(max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2)) + val fixedType = DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2), max(s1, s2)) + (Alias(Cast(l, fixedType), l.name)(), Alias(Cast(r, fixedType), r.name)()) + case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) => + (Alias(Cast(l, intTypeToFixed(t)), l.name)(), r) + case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) => + (l, Alias(Cast(r, intTypeToFixed(t)), r.name)()) + case (t, DecimalType.Fixed(p, s)) if floatTypeToFixed.contains(t) => + (Alias(Cast(l, floatTypeToFixed(t)), l.name)(), r) + case (DecimalType.Fixed(p, s), t) if floatTypeToFixed.contains(t) => + (l, Alias(Cast(r, floatTypeToFixed(t)), r.name)()) + case _ => (l, r) + } + case other => other } - // TODO: MaxOf, MinOf, etc might want other rules + val (castedLeft, castedRight) = castedInput.unzip - // SUM and AVERAGE are handled by the implementations of those expressions + val newLeft = + if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) { + Project(castedLeft, left) + } else { + left + } + + val newRight = + if (castedRight.map(_.dataType) != right.output.map(_.dataType)) { + Project(castedRight, right) + } else { + right + } + + Union(newLeft, newRight) + + // fix decimal precision for expressions + case q => q.transformExpressions { + // Skip nodes whose children have not been resolved yet + case e if !e.childrenResolved => e + + case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + Cast( + Add(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), + DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) + ) + + case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + Cast( + Subtract(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), + DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) + ) + + case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + Cast( + Multiply(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), + DecimalType(p1 + p2 + 1, s1 + s2) + ) + + case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + Cast( + Divide(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), + DecimalType(p1 - s1 + s2 + max(6, s1 + p2 + 1), max(6, s1 + p2 + 1)) + ) + + case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + Cast( + Remainder(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), + DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + ) + + case LessThan(e1 @ DecimalType.Expression(p1, s1), + e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => + LessThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) + + case LessThanOrEqual(e1 @ DecimalType.Expression(p1, s1), + e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => + LessThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) + + case GreaterThan(e1 @ DecimalType.Expression(p1, s1), + e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => + GreaterThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) + + case GreaterThanOrEqual(e1 @ DecimalType.Expression(p1, s1), + e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => + GreaterThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) + + // Promote integers inside a binary expression with fixed-precision decimals to decimals, + // and fixed-precision decimals in an expression with floats / doubles to doubles + case b: BinaryExpression if b.left.dataType != b.right.dataType => + (b.left.dataType, b.right.dataType) match { + case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) => + b.makeCopy(Array(Cast(b.left, intTypeToFixed(t)), b.right)) + case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) => + b.makeCopy(Array(b.left, Cast(b.right, intTypeToFixed(t)))) + case (t, DecimalType.Fixed(p, s)) if isFloat(t) => + b.makeCopy(Array(b.left, Cast(b.right, DoubleType))) + case (DecimalType.Fixed(p, s), t) if isFloat(t) => + b.makeCopy(Array(Cast(b.left, DoubleType), b.right)) + case _ => + b + } + + // TODO: MaxOf, MinOf, etc might want other rules + + // SUM and AVERAGE are handled by the implementations of those expressions + } } + } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala index c61c395cb4bb1..7731336d247db 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala @@ -44,7 +44,7 @@ package object analysis { } /** Catches any AnalysisExceptions thrown by `f` and attaches `t`'s position if any. */ - def withPosition[A](t: TreeNode[_])(f: => A) = { + def withPosition[A](t: TreeNode[_])(f: => A): A = { try f catch { case a: AnalysisException => throw a.withPosition(t.origin.line, t.origin.startPosition) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala index 389dc4f745723..9a77ca624ebe2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.types.DataType /** @@ -39,12 +39,14 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi (1 to 22).map { x => val anys = (1 to x).map(x => "Any").reduce(_ + ", " + _) - val childs = (0 to x - 1).map(x => s"val child$x = children($x)").reduce(_ + "\n " + _) - val evals = (0 to x - 1).map(x => s"ScalaReflection.convertToScala(child$x.eval(input), child$x.dataType)").reduce(_ + ",\n " + _) + val childs = (0 to x - 1).map(x => s"val child$x = children($x)").reduce(_ + "\n " + _) + lazy val converters = (0 to x - 1).map(x => s"lazy val converter$x = CatalystTypeConverters.createToScalaConverter(child$x.dataType)").reduce(_ + "\n " + _) + val evals = (0 to x - 1).map(x => s"converter$x(child$x.eval(input))").reduce(_ + ",\n " + _) - s""" case $x => + s"""case $x => val func = function.asInstanceOf[($anys) => Any] $childs + $converters (input: Row) => { func( $evals) @@ -60,51 +62,61 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi (input: Row) => { func() } - + case 1 => val func = function.asInstanceOf[(Any) => Any] val child0 = children(0) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType)) + converter0(child0.eval(input))) } - + case 2 => val func = function.asInstanceOf[(Any, Any) => Any] val child0 = children(0) val child1 = children(1) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input))) } - + case 3 => val func = function.asInstanceOf[(Any, Any, Any) => Any] val child0 = children(0) val child1 = children(1) val child2 = children(2) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input))) } - + case 4 => val func = function.asInstanceOf[(Any, Any, Any, Any) => Any] val child0 = children(0) val child1 = children(1) val child2 = children(2) val child3 = children(3) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType), - ScalaReflection.convertToScala(child3.eval(input), child3.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input))) } - + case 5 => val func = function.asInstanceOf[(Any, Any, Any, Any, Any) => Any] val child0 = children(0) @@ -112,15 +124,20 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val child2 = children(2) val child3 = children(3) val child4 = children(4) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType), - ScalaReflection.convertToScala(child3.eval(input), child3.dataType), - ScalaReflection.convertToScala(child4.eval(input), child4.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input))) } - + case 6 => val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any) => Any] val child0 = children(0) @@ -129,16 +146,22 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val child3 = children(3) val child4 = children(4) val child5 = children(5) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType), - ScalaReflection.convertToScala(child3.eval(input), child3.dataType), - ScalaReflection.convertToScala(child4.eval(input), child4.dataType), - ScalaReflection.convertToScala(child5.eval(input), child5.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input))) } - + case 7 => val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any) => Any] val child0 = children(0) @@ -148,17 +171,24 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val child4 = children(4) val child5 = children(5) val child6 = children(6) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType), - ScalaReflection.convertToScala(child3.eval(input), child3.dataType), - ScalaReflection.convertToScala(child4.eval(input), child4.dataType), - ScalaReflection.convertToScala(child5.eval(input), child5.dataType), - ScalaReflection.convertToScala(child6.eval(input), child6.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input))) } - + case 8 => val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any) => Any] val child0 = children(0) @@ -169,18 +199,26 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val child5 = children(5) val child6 = children(6) val child7 = children(7) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType), - ScalaReflection.convertToScala(child3.eval(input), child3.dataType), - ScalaReflection.convertToScala(child4.eval(input), child4.dataType), - ScalaReflection.convertToScala(child5.eval(input), child5.dataType), - ScalaReflection.convertToScala(child6.eval(input), child6.dataType), - ScalaReflection.convertToScala(child7.eval(input), child7.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input))) } - + case 9 => val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] val child0 = children(0) @@ -192,19 +230,28 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val child6 = children(6) val child7 = children(7) val child8 = children(8) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) + lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType), - ScalaReflection.convertToScala(child3.eval(input), child3.dataType), - ScalaReflection.convertToScala(child4.eval(input), child4.dataType), - ScalaReflection.convertToScala(child5.eval(input), child5.dataType), - ScalaReflection.convertToScala(child6.eval(input), child6.dataType), - ScalaReflection.convertToScala(child7.eval(input), child7.dataType), - ScalaReflection.convertToScala(child8.eval(input), child8.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input)), + converter8(child8.eval(input))) } - + case 10 => val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] val child0 = children(0) @@ -217,20 +264,30 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val child7 = children(7) val child8 = children(8) val child9 = children(9) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) + lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) + lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType), - ScalaReflection.convertToScala(child3.eval(input), child3.dataType), - ScalaReflection.convertToScala(child4.eval(input), child4.dataType), - ScalaReflection.convertToScala(child5.eval(input), child5.dataType), - ScalaReflection.convertToScala(child6.eval(input), child6.dataType), - ScalaReflection.convertToScala(child7.eval(input), child7.dataType), - ScalaReflection.convertToScala(child8.eval(input), child8.dataType), - ScalaReflection.convertToScala(child9.eval(input), child9.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input)), + converter8(child8.eval(input)), + converter9(child9.eval(input))) } - + case 11 => val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] val child0 = children(0) @@ -244,21 +301,32 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val child8 = children(8) val child9 = children(9) val child10 = children(10) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) + lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) + lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) + lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType), - ScalaReflection.convertToScala(child3.eval(input), child3.dataType), - ScalaReflection.convertToScala(child4.eval(input), child4.dataType), - ScalaReflection.convertToScala(child5.eval(input), child5.dataType), - ScalaReflection.convertToScala(child6.eval(input), child6.dataType), - ScalaReflection.convertToScala(child7.eval(input), child7.dataType), - ScalaReflection.convertToScala(child8.eval(input), child8.dataType), - ScalaReflection.convertToScala(child9.eval(input), child9.dataType), - ScalaReflection.convertToScala(child10.eval(input), child10.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input)), + converter8(child8.eval(input)), + converter9(child9.eval(input)), + converter10(child10.eval(input))) } - + case 12 => val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] val child0 = children(0) @@ -273,22 +341,34 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val child9 = children(9) val child10 = children(10) val child11 = children(11) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) + lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) + lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) + lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) + lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType), - ScalaReflection.convertToScala(child3.eval(input), child3.dataType), - ScalaReflection.convertToScala(child4.eval(input), child4.dataType), - ScalaReflection.convertToScala(child5.eval(input), child5.dataType), - ScalaReflection.convertToScala(child6.eval(input), child6.dataType), - ScalaReflection.convertToScala(child7.eval(input), child7.dataType), - ScalaReflection.convertToScala(child8.eval(input), child8.dataType), - ScalaReflection.convertToScala(child9.eval(input), child9.dataType), - ScalaReflection.convertToScala(child10.eval(input), child10.dataType), - ScalaReflection.convertToScala(child11.eval(input), child11.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input)), + converter8(child8.eval(input)), + converter9(child9.eval(input)), + converter10(child10.eval(input)), + converter11(child11.eval(input))) } - + case 13 => val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] val child0 = children(0) @@ -304,23 +384,36 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val child10 = children(10) val child11 = children(11) val child12 = children(12) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) + lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) + lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) + lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) + lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) + lazy val converter12 = CatalystTypeConverters.createToScalaConverter(child12.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType), - ScalaReflection.convertToScala(child3.eval(input), child3.dataType), - ScalaReflection.convertToScala(child4.eval(input), child4.dataType), - ScalaReflection.convertToScala(child5.eval(input), child5.dataType), - ScalaReflection.convertToScala(child6.eval(input), child6.dataType), - ScalaReflection.convertToScala(child7.eval(input), child7.dataType), - ScalaReflection.convertToScala(child8.eval(input), child8.dataType), - ScalaReflection.convertToScala(child9.eval(input), child9.dataType), - ScalaReflection.convertToScala(child10.eval(input), child10.dataType), - ScalaReflection.convertToScala(child11.eval(input), child11.dataType), - ScalaReflection.convertToScala(child12.eval(input), child12.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input)), + converter8(child8.eval(input)), + converter9(child9.eval(input)), + converter10(child10.eval(input)), + converter11(child11.eval(input)), + converter12(child12.eval(input))) } - + case 14 => val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] val child0 = children(0) @@ -337,24 +430,38 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val child11 = children(11) val child12 = children(12) val child13 = children(13) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) + lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) + lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) + lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) + lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) + lazy val converter12 = CatalystTypeConverters.createToScalaConverter(child12.dataType) + lazy val converter13 = CatalystTypeConverters.createToScalaConverter(child13.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType), - ScalaReflection.convertToScala(child3.eval(input), child3.dataType), - ScalaReflection.convertToScala(child4.eval(input), child4.dataType), - ScalaReflection.convertToScala(child5.eval(input), child5.dataType), - ScalaReflection.convertToScala(child6.eval(input), child6.dataType), - ScalaReflection.convertToScala(child7.eval(input), child7.dataType), - ScalaReflection.convertToScala(child8.eval(input), child8.dataType), - ScalaReflection.convertToScala(child9.eval(input), child9.dataType), - ScalaReflection.convertToScala(child10.eval(input), child10.dataType), - ScalaReflection.convertToScala(child11.eval(input), child11.dataType), - ScalaReflection.convertToScala(child12.eval(input), child12.dataType), - ScalaReflection.convertToScala(child13.eval(input), child13.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input)), + converter8(child8.eval(input)), + converter9(child9.eval(input)), + converter10(child10.eval(input)), + converter11(child11.eval(input)), + converter12(child12.eval(input)), + converter13(child13.eval(input))) } - + case 15 => val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] val child0 = children(0) @@ -372,25 +479,40 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val child12 = children(12) val child13 = children(13) val child14 = children(14) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) + lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) + lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) + lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) + lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) + lazy val converter12 = CatalystTypeConverters.createToScalaConverter(child12.dataType) + lazy val converter13 = CatalystTypeConverters.createToScalaConverter(child13.dataType) + lazy val converter14 = CatalystTypeConverters.createToScalaConverter(child14.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType), - ScalaReflection.convertToScala(child3.eval(input), child3.dataType), - ScalaReflection.convertToScala(child4.eval(input), child4.dataType), - ScalaReflection.convertToScala(child5.eval(input), child5.dataType), - ScalaReflection.convertToScala(child6.eval(input), child6.dataType), - ScalaReflection.convertToScala(child7.eval(input), child7.dataType), - ScalaReflection.convertToScala(child8.eval(input), child8.dataType), - ScalaReflection.convertToScala(child9.eval(input), child9.dataType), - ScalaReflection.convertToScala(child10.eval(input), child10.dataType), - ScalaReflection.convertToScala(child11.eval(input), child11.dataType), - ScalaReflection.convertToScala(child12.eval(input), child12.dataType), - ScalaReflection.convertToScala(child13.eval(input), child13.dataType), - ScalaReflection.convertToScala(child14.eval(input), child14.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input)), + converter8(child8.eval(input)), + converter9(child9.eval(input)), + converter10(child10.eval(input)), + converter11(child11.eval(input)), + converter12(child12.eval(input)), + converter13(child13.eval(input)), + converter14(child14.eval(input))) } - + case 16 => val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] val child0 = children(0) @@ -409,26 +531,42 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val child13 = children(13) val child14 = children(14) val child15 = children(15) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) + lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) + lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) + lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) + lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) + lazy val converter12 = CatalystTypeConverters.createToScalaConverter(child12.dataType) + lazy val converter13 = CatalystTypeConverters.createToScalaConverter(child13.dataType) + lazy val converter14 = CatalystTypeConverters.createToScalaConverter(child14.dataType) + lazy val converter15 = CatalystTypeConverters.createToScalaConverter(child15.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType), - ScalaReflection.convertToScala(child3.eval(input), child3.dataType), - ScalaReflection.convertToScala(child4.eval(input), child4.dataType), - ScalaReflection.convertToScala(child5.eval(input), child5.dataType), - ScalaReflection.convertToScala(child6.eval(input), child6.dataType), - ScalaReflection.convertToScala(child7.eval(input), child7.dataType), - ScalaReflection.convertToScala(child8.eval(input), child8.dataType), - ScalaReflection.convertToScala(child9.eval(input), child9.dataType), - ScalaReflection.convertToScala(child10.eval(input), child10.dataType), - ScalaReflection.convertToScala(child11.eval(input), child11.dataType), - ScalaReflection.convertToScala(child12.eval(input), child12.dataType), - ScalaReflection.convertToScala(child13.eval(input), child13.dataType), - ScalaReflection.convertToScala(child14.eval(input), child14.dataType), - ScalaReflection.convertToScala(child15.eval(input), child15.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input)), + converter8(child8.eval(input)), + converter9(child9.eval(input)), + converter10(child10.eval(input)), + converter11(child11.eval(input)), + converter12(child12.eval(input)), + converter13(child13.eval(input)), + converter14(child14.eval(input)), + converter15(child15.eval(input))) } - + case 17 => val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] val child0 = children(0) @@ -448,27 +586,44 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val child14 = children(14) val child15 = children(15) val child16 = children(16) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) + lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) + lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) + lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) + lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) + lazy val converter12 = CatalystTypeConverters.createToScalaConverter(child12.dataType) + lazy val converter13 = CatalystTypeConverters.createToScalaConverter(child13.dataType) + lazy val converter14 = CatalystTypeConverters.createToScalaConverter(child14.dataType) + lazy val converter15 = CatalystTypeConverters.createToScalaConverter(child15.dataType) + lazy val converter16 = CatalystTypeConverters.createToScalaConverter(child16.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType), - ScalaReflection.convertToScala(child3.eval(input), child3.dataType), - ScalaReflection.convertToScala(child4.eval(input), child4.dataType), - ScalaReflection.convertToScala(child5.eval(input), child5.dataType), - ScalaReflection.convertToScala(child6.eval(input), child6.dataType), - ScalaReflection.convertToScala(child7.eval(input), child7.dataType), - ScalaReflection.convertToScala(child8.eval(input), child8.dataType), - ScalaReflection.convertToScala(child9.eval(input), child9.dataType), - ScalaReflection.convertToScala(child10.eval(input), child10.dataType), - ScalaReflection.convertToScala(child11.eval(input), child11.dataType), - ScalaReflection.convertToScala(child12.eval(input), child12.dataType), - ScalaReflection.convertToScala(child13.eval(input), child13.dataType), - ScalaReflection.convertToScala(child14.eval(input), child14.dataType), - ScalaReflection.convertToScala(child15.eval(input), child15.dataType), - ScalaReflection.convertToScala(child16.eval(input), child16.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input)), + converter8(child8.eval(input)), + converter9(child9.eval(input)), + converter10(child10.eval(input)), + converter11(child11.eval(input)), + converter12(child12.eval(input)), + converter13(child13.eval(input)), + converter14(child14.eval(input)), + converter15(child15.eval(input)), + converter16(child16.eval(input))) } - + case 18 => val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] val child0 = children(0) @@ -489,28 +644,46 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val child15 = children(15) val child16 = children(16) val child17 = children(17) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) + lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) + lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) + lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) + lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) + lazy val converter12 = CatalystTypeConverters.createToScalaConverter(child12.dataType) + lazy val converter13 = CatalystTypeConverters.createToScalaConverter(child13.dataType) + lazy val converter14 = CatalystTypeConverters.createToScalaConverter(child14.dataType) + lazy val converter15 = CatalystTypeConverters.createToScalaConverter(child15.dataType) + lazy val converter16 = CatalystTypeConverters.createToScalaConverter(child16.dataType) + lazy val converter17 = CatalystTypeConverters.createToScalaConverter(child17.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType), - ScalaReflection.convertToScala(child3.eval(input), child3.dataType), - ScalaReflection.convertToScala(child4.eval(input), child4.dataType), - ScalaReflection.convertToScala(child5.eval(input), child5.dataType), - ScalaReflection.convertToScala(child6.eval(input), child6.dataType), - ScalaReflection.convertToScala(child7.eval(input), child7.dataType), - ScalaReflection.convertToScala(child8.eval(input), child8.dataType), - ScalaReflection.convertToScala(child9.eval(input), child9.dataType), - ScalaReflection.convertToScala(child10.eval(input), child10.dataType), - ScalaReflection.convertToScala(child11.eval(input), child11.dataType), - ScalaReflection.convertToScala(child12.eval(input), child12.dataType), - ScalaReflection.convertToScala(child13.eval(input), child13.dataType), - ScalaReflection.convertToScala(child14.eval(input), child14.dataType), - ScalaReflection.convertToScala(child15.eval(input), child15.dataType), - ScalaReflection.convertToScala(child16.eval(input), child16.dataType), - ScalaReflection.convertToScala(child17.eval(input), child17.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input)), + converter8(child8.eval(input)), + converter9(child9.eval(input)), + converter10(child10.eval(input)), + converter11(child11.eval(input)), + converter12(child12.eval(input)), + converter13(child13.eval(input)), + converter14(child14.eval(input)), + converter15(child15.eval(input)), + converter16(child16.eval(input)), + converter17(child17.eval(input))) } - + case 19 => val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] val child0 = children(0) @@ -532,29 +705,48 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val child16 = children(16) val child17 = children(17) val child18 = children(18) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) + lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) + lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) + lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) + lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) + lazy val converter12 = CatalystTypeConverters.createToScalaConverter(child12.dataType) + lazy val converter13 = CatalystTypeConverters.createToScalaConverter(child13.dataType) + lazy val converter14 = CatalystTypeConverters.createToScalaConverter(child14.dataType) + lazy val converter15 = CatalystTypeConverters.createToScalaConverter(child15.dataType) + lazy val converter16 = CatalystTypeConverters.createToScalaConverter(child16.dataType) + lazy val converter17 = CatalystTypeConverters.createToScalaConverter(child17.dataType) + lazy val converter18 = CatalystTypeConverters.createToScalaConverter(child18.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType), - ScalaReflection.convertToScala(child3.eval(input), child3.dataType), - ScalaReflection.convertToScala(child4.eval(input), child4.dataType), - ScalaReflection.convertToScala(child5.eval(input), child5.dataType), - ScalaReflection.convertToScala(child6.eval(input), child6.dataType), - ScalaReflection.convertToScala(child7.eval(input), child7.dataType), - ScalaReflection.convertToScala(child8.eval(input), child8.dataType), - ScalaReflection.convertToScala(child9.eval(input), child9.dataType), - ScalaReflection.convertToScala(child10.eval(input), child10.dataType), - ScalaReflection.convertToScala(child11.eval(input), child11.dataType), - ScalaReflection.convertToScala(child12.eval(input), child12.dataType), - ScalaReflection.convertToScala(child13.eval(input), child13.dataType), - ScalaReflection.convertToScala(child14.eval(input), child14.dataType), - ScalaReflection.convertToScala(child15.eval(input), child15.dataType), - ScalaReflection.convertToScala(child16.eval(input), child16.dataType), - ScalaReflection.convertToScala(child17.eval(input), child17.dataType), - ScalaReflection.convertToScala(child18.eval(input), child18.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input)), + converter8(child8.eval(input)), + converter9(child9.eval(input)), + converter10(child10.eval(input)), + converter11(child11.eval(input)), + converter12(child12.eval(input)), + converter13(child13.eval(input)), + converter14(child14.eval(input)), + converter15(child15.eval(input)), + converter16(child16.eval(input)), + converter17(child17.eval(input)), + converter18(child18.eval(input))) } - + case 20 => val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] val child0 = children(0) @@ -577,30 +769,50 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val child17 = children(17) val child18 = children(18) val child19 = children(19) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) + lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) + lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) + lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) + lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) + lazy val converter12 = CatalystTypeConverters.createToScalaConverter(child12.dataType) + lazy val converter13 = CatalystTypeConverters.createToScalaConverter(child13.dataType) + lazy val converter14 = CatalystTypeConverters.createToScalaConverter(child14.dataType) + lazy val converter15 = CatalystTypeConverters.createToScalaConverter(child15.dataType) + lazy val converter16 = CatalystTypeConverters.createToScalaConverter(child16.dataType) + lazy val converter17 = CatalystTypeConverters.createToScalaConverter(child17.dataType) + lazy val converter18 = CatalystTypeConverters.createToScalaConverter(child18.dataType) + lazy val converter19 = CatalystTypeConverters.createToScalaConverter(child19.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType), - ScalaReflection.convertToScala(child3.eval(input), child3.dataType), - ScalaReflection.convertToScala(child4.eval(input), child4.dataType), - ScalaReflection.convertToScala(child5.eval(input), child5.dataType), - ScalaReflection.convertToScala(child6.eval(input), child6.dataType), - ScalaReflection.convertToScala(child7.eval(input), child7.dataType), - ScalaReflection.convertToScala(child8.eval(input), child8.dataType), - ScalaReflection.convertToScala(child9.eval(input), child9.dataType), - ScalaReflection.convertToScala(child10.eval(input), child10.dataType), - ScalaReflection.convertToScala(child11.eval(input), child11.dataType), - ScalaReflection.convertToScala(child12.eval(input), child12.dataType), - ScalaReflection.convertToScala(child13.eval(input), child13.dataType), - ScalaReflection.convertToScala(child14.eval(input), child14.dataType), - ScalaReflection.convertToScala(child15.eval(input), child15.dataType), - ScalaReflection.convertToScala(child16.eval(input), child16.dataType), - ScalaReflection.convertToScala(child17.eval(input), child17.dataType), - ScalaReflection.convertToScala(child18.eval(input), child18.dataType), - ScalaReflection.convertToScala(child19.eval(input), child19.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input)), + converter8(child8.eval(input)), + converter9(child9.eval(input)), + converter10(child10.eval(input)), + converter11(child11.eval(input)), + converter12(child12.eval(input)), + converter13(child13.eval(input)), + converter14(child14.eval(input)), + converter15(child15.eval(input)), + converter16(child16.eval(input)), + converter17(child17.eval(input)), + converter18(child18.eval(input)), + converter19(child19.eval(input))) } - + case 21 => val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] val child0 = children(0) @@ -624,31 +836,52 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val child18 = children(18) val child19 = children(19) val child20 = children(20) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) + lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) + lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) + lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) + lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) + lazy val converter12 = CatalystTypeConverters.createToScalaConverter(child12.dataType) + lazy val converter13 = CatalystTypeConverters.createToScalaConverter(child13.dataType) + lazy val converter14 = CatalystTypeConverters.createToScalaConverter(child14.dataType) + lazy val converter15 = CatalystTypeConverters.createToScalaConverter(child15.dataType) + lazy val converter16 = CatalystTypeConverters.createToScalaConverter(child16.dataType) + lazy val converter17 = CatalystTypeConverters.createToScalaConverter(child17.dataType) + lazy val converter18 = CatalystTypeConverters.createToScalaConverter(child18.dataType) + lazy val converter19 = CatalystTypeConverters.createToScalaConverter(child19.dataType) + lazy val converter20 = CatalystTypeConverters.createToScalaConverter(child20.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType), - ScalaReflection.convertToScala(child3.eval(input), child3.dataType), - ScalaReflection.convertToScala(child4.eval(input), child4.dataType), - ScalaReflection.convertToScala(child5.eval(input), child5.dataType), - ScalaReflection.convertToScala(child6.eval(input), child6.dataType), - ScalaReflection.convertToScala(child7.eval(input), child7.dataType), - ScalaReflection.convertToScala(child8.eval(input), child8.dataType), - ScalaReflection.convertToScala(child9.eval(input), child9.dataType), - ScalaReflection.convertToScala(child10.eval(input), child10.dataType), - ScalaReflection.convertToScala(child11.eval(input), child11.dataType), - ScalaReflection.convertToScala(child12.eval(input), child12.dataType), - ScalaReflection.convertToScala(child13.eval(input), child13.dataType), - ScalaReflection.convertToScala(child14.eval(input), child14.dataType), - ScalaReflection.convertToScala(child15.eval(input), child15.dataType), - ScalaReflection.convertToScala(child16.eval(input), child16.dataType), - ScalaReflection.convertToScala(child17.eval(input), child17.dataType), - ScalaReflection.convertToScala(child18.eval(input), child18.dataType), - ScalaReflection.convertToScala(child19.eval(input), child19.dataType), - ScalaReflection.convertToScala(child20.eval(input), child20.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input)), + converter8(child8.eval(input)), + converter9(child9.eval(input)), + converter10(child10.eval(input)), + converter11(child11.eval(input)), + converter12(child12.eval(input)), + converter13(child13.eval(input)), + converter14(child14.eval(input)), + converter15(child15.eval(input)), + converter16(child16.eval(input)), + converter17(child17.eval(input)), + converter18(child18.eval(input)), + converter19(child19.eval(input)), + converter20(child20.eval(input))) } - + case 22 => val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] val child0 = children(0) @@ -673,35 +906,57 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val child19 = children(19) val child20 = children(20) val child21 = children(21) + lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) + lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) + lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) + lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) + lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) + lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) + lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) + lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) + lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) + lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) + lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) + lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) + lazy val converter12 = CatalystTypeConverters.createToScalaConverter(child12.dataType) + lazy val converter13 = CatalystTypeConverters.createToScalaConverter(child13.dataType) + lazy val converter14 = CatalystTypeConverters.createToScalaConverter(child14.dataType) + lazy val converter15 = CatalystTypeConverters.createToScalaConverter(child15.dataType) + lazy val converter16 = CatalystTypeConverters.createToScalaConverter(child16.dataType) + lazy val converter17 = CatalystTypeConverters.createToScalaConverter(child17.dataType) + lazy val converter18 = CatalystTypeConverters.createToScalaConverter(child18.dataType) + lazy val converter19 = CatalystTypeConverters.createToScalaConverter(child19.dataType) + lazy val converter20 = CatalystTypeConverters.createToScalaConverter(child20.dataType) + lazy val converter21 = CatalystTypeConverters.createToScalaConverter(child21.dataType) (input: Row) => { func( - ScalaReflection.convertToScala(child0.eval(input), child0.dataType), - ScalaReflection.convertToScala(child1.eval(input), child1.dataType), - ScalaReflection.convertToScala(child2.eval(input), child2.dataType), - ScalaReflection.convertToScala(child3.eval(input), child3.dataType), - ScalaReflection.convertToScala(child4.eval(input), child4.dataType), - ScalaReflection.convertToScala(child5.eval(input), child5.dataType), - ScalaReflection.convertToScala(child6.eval(input), child6.dataType), - ScalaReflection.convertToScala(child7.eval(input), child7.dataType), - ScalaReflection.convertToScala(child8.eval(input), child8.dataType), - ScalaReflection.convertToScala(child9.eval(input), child9.dataType), - ScalaReflection.convertToScala(child10.eval(input), child10.dataType), - ScalaReflection.convertToScala(child11.eval(input), child11.dataType), - ScalaReflection.convertToScala(child12.eval(input), child12.dataType), - ScalaReflection.convertToScala(child13.eval(input), child13.dataType), - ScalaReflection.convertToScala(child14.eval(input), child14.dataType), - ScalaReflection.convertToScala(child15.eval(input), child15.dataType), - ScalaReflection.convertToScala(child16.eval(input), child16.dataType), - ScalaReflection.convertToScala(child17.eval(input), child17.dataType), - ScalaReflection.convertToScala(child18.eval(input), child18.dataType), - ScalaReflection.convertToScala(child19.eval(input), child19.dataType), - ScalaReflection.convertToScala(child20.eval(input), child20.dataType), - ScalaReflection.convertToScala(child21.eval(input), child21.dataType)) + converter0(child0.eval(input)), + converter1(child1.eval(input)), + converter2(child2.eval(input)), + converter3(child3.eval(input)), + converter4(child4.eval(input)), + converter5(child5.eval(input)), + converter6(child6.eval(input)), + converter7(child7.eval(input)), + converter8(child8.eval(input)), + converter9(child9.eval(input)), + converter10(child10.eval(input)), + converter11(child11.eval(input)), + converter12(child12.eval(input)), + converter13(child13.eval(input)), + converter14(child14.eval(input)), + converter15(child15.eval(input)), + converter16(child16.eval(input)), + converter17(child17.eval(input)), + converter18(child18.eval(input)), + converter19(child19.eval(input)), + converter20(child20.eval(input)), + converter21(child21.eval(input))) } } - + // scalastyle:on - - override def eval(input: Row): Any = ScalaReflection.convertToCatalyst(f(input), dataType) + + override def eval(input: Row): Any = CatalystTypeConverters.convertToCatalyst(f(input), dataType) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 30da4faa3f1c6..14a855054b94d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -189,9 +189,10 @@ case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpress override def children: Seq[Expression] = expressions override def nullable: Boolean = false - override def dataType: ArrayType = ArrayType(expressions.head.dataType) + override def dataType: OpenHashSetUDT = new OpenHashSetUDT(expressions.head.dataType) override def toString: String = s"AddToHashSet(${expressions.mkString(",")})" - override def newInstance(): CollectHashSetFunction = new CollectHashSetFunction(expressions, this) + override def newInstance(): CollectHashSetFunction = + new CollectHashSetFunction(expressions, this) } case class CollectHashSetFunction( @@ -250,11 +251,28 @@ case class CombineSetsAndCountFunction( override def eval(input: Row): Any = seen.size.toLong } +/** The data type of ApproxCountDistinctPartition since its output is a HyperLogLog object. */ +private[sql] case object HyperLogLogUDT extends UserDefinedType[HyperLogLog] { + + override def sqlType: DataType = BinaryType + + /** Since we are using HyperLogLog internally, usually it will not be called. */ + override def serialize(obj: Any): Array[Byte] = + obj.asInstanceOf[HyperLogLog].getBytes + + + /** Since we are using HyperLogLog internally, usually it will not be called. */ + override def deserialize(datum: Any): HyperLogLog = + HyperLogLog.Builder.build(datum.asInstanceOf[Array[Byte]]) + + override def userClass: Class[HyperLogLog] = classOf[HyperLogLog] +} + case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double) extends AggregateExpression with trees.UnaryNode[Expression] { override def nullable: Boolean = false - override def dataType: DataType = child.dataType + override def dataType: DataType = HyperLogLogUDT override def toString: String = s"APPROXIMATE COUNT(DISTINCT $child)" override def newInstance(): ApproxCountDistinctPartitionFunction = { new ApproxCountDistinctPartitionFunction(child, this, relativeSD) @@ -505,7 +523,8 @@ case class AverageFunction(expr: Expression, base: AggregateExpression) private var count: Long = _ private val sum = MutableLiteral(zero.eval(null), calcType) - private def addFunction(value: Any) = Add(sum, Cast(Literal(value, expr.dataType), calcType)) + private def addFunction(value: Any) = Add(sum, + Cast(Literal.create(value, expr.dataType), calcType)) override def eval(input: Row): Any = { if (count == 0L) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index d1abf3c0b64a5..aac56e1568332 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -464,7 +464,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin val itemEval = expressionEvaluator(item) val setEval = expressionEvaluator(set) - val ArrayType(elementType, _) = set.dataType + val elementType = set.dataType.asInstanceOf[OpenHashSetUDT].elementType itemEval.code ++ setEval.code ++ q""" @@ -482,7 +482,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin val leftEval = expressionEvaluator(left) val rightEval = expressionEvaluator(right) - val ArrayType(elementType, _) = left.dataType + val elementType = left.dataType.asInstanceOf[OpenHashSetUDT].elementType leftEval.code ++ rightEval.code ++ q""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala index 3b2b9211268a9..fc1f69655963d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.Map +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.types._ /** @@ -81,6 +83,41 @@ trait GetField extends UnaryExpression { def field: StructField } +object GetField { + /** + * Returns the resolved `GetField`, and report error if no desired field or over one + * desired fields are found. + */ + def apply( + expr: Expression, + fieldName: String, + resolver: Resolver): GetField = { + def findField(fields: Array[StructField]): Int = { + val checkField = (f: StructField) => resolver(f.name, fieldName) + val ordinal = fields.indexWhere(checkField) + if (ordinal == -1) { + throw new AnalysisException( + s"No such struct field $fieldName in ${fields.map(_.name).mkString(", ")}") + } else if (fields.indexWhere(checkField, ordinal + 1) != -1) { + throw new AnalysisException( + s"Ambiguous reference to fields ${fields.filter(checkField).mkString(", ")}") + } else { + ordinal + } + } + expr.dataType match { + case StructType(fields) => + val ordinal = findField(fields) + StructGetField(expr, fields(ordinal), ordinal) + case ArrayType(StructType(fields), containsNull) => + val ordinal = findField(fields) + ArrayGetField(expr, fields(ordinal), ordinal, containsNull) + case otherType => + throw new AnalysisException(s"GetField is not valid on fields of type $otherType") + } + } +} + /** * Returns the value of fields in the Struct `child`. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 19f3fc9c2291a..0e2d593e94124 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -41,6 +41,8 @@ object Literal { case _ => throw new RuntimeException("Unsupported literal type " + v.getClass + " " + v) } + + def create(v: Any, dataType: DataType): Literal = Literal(v, dataType) } /** @@ -62,7 +64,10 @@ object IntegerLiteral { } } -case class Literal(value: Any, dataType: DataType) extends LeafExpression { +/** + * In order to do type checking, use Literal.create() instead of constructor + */ +case class Literal protected (value: Any, dataType: DataType) extends LeafExpression { override def foldable: Boolean = true override def nullable: Boolean = value == null diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index a8983df208318..0a275b84086cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -224,6 +224,7 @@ class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[Row] { n.ordering.asInstanceOf[Ordering[Any]].compare(left, right) case n: NativeType if order.direction == Descending => n.ordering.asInstanceOf[Ordering[Any]].reverse.compare(left, right) + case other => sys.error(s"Type $other does not support ordered operations") } if (comparison != 0) return comparison } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala index 35faa00782e80..4c44182278207 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -20,6 +20,33 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet +/** The data type for expressions returning an OpenHashSet as the result. */ +private[sql] class OpenHashSetUDT( + val elementType: DataType) extends UserDefinedType[OpenHashSet[Any]] { + + override def sqlType: DataType = ArrayType(elementType) + + /** Since we are using OpenHashSet internally, usually it will not be called. */ + override def serialize(obj: Any): Seq[Any] = { + obj.asInstanceOf[OpenHashSet[Any]].iterator.toSeq + } + + /** Since we are using OpenHashSet internally, usually it will not be called. */ + override def deserialize(datum: Any): OpenHashSet[Any] = { + val iterator = datum.asInstanceOf[Seq[Any]].iterator + val set = new OpenHashSet[Any] + while(iterator.hasNext) { + set.add(iterator.next()) + } + + set + } + + override def userClass: Class[OpenHashSet[Any]] = classOf[OpenHashSet[Any]] + + private[spark] override def asNullable: OpenHashSetUDT = this +} + /** * Creates a new set of the specified type */ @@ -28,9 +55,7 @@ case class NewSet(elementType: DataType) extends LeafExpression { override def nullable: Boolean = false - // We are currently only using these Expressions internally for aggregation. However, if we ever - // expose these to users we'll want to create a proper type instead of hijacking ArrayType. - override def dataType: DataType = ArrayType(elementType) + override def dataType: OpenHashSetUDT = new OpenHashSetUDT(elementType) override def eval(input: Row): Any = { new OpenHashSet[Any]() @@ -50,7 +75,7 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression { override def nullable: Boolean = set.nullable - override def dataType: DataType = set.dataType + override def dataType: OpenHashSetUDT = set.dataType.asInstanceOf[OpenHashSetUDT] override def eval(input: Row): Any = { val itemEval = item.eval(input) @@ -80,7 +105,7 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres override def nullable: Boolean = left.nullable || right.nullable - override def dataType: DataType = left.dataType + override def dataType: OpenHashSetUDT = left.dataType.asInstanceOf[OpenHashSetUDT] override def symbol: String = "++=" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 3cdca4e9dd2d1..acfbbace608ef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -156,12 +156,11 @@ case class Lower(child: Expression) extends UnaryExpression with CaseConversionE /** A base trait for functions that compare two strings, returning a boolean. */ trait StringComparison { - self: BinaryExpression => + self: BinaryPredicate => - type EvaluatedType = Any + override type EvaluatedType = Any override def nullable: Boolean = left.nullable || right.nullable - override def dataType: DataType = BooleanType def compare(l: String, r: String): Boolean @@ -184,7 +183,7 @@ trait StringComparison { * A function that returns true if the string `left` contains the string `right`. */ case class Contains(left: Expression, right: Expression) - extends BinaryExpression with StringComparison { + extends BinaryPredicate with StringComparison { override def compare(l: String, r: String): Boolean = l.contains(r) } @@ -192,7 +191,7 @@ case class Contains(left: Expression, right: Expression) * A function that returns true if the string `left` starts with the string `right`. */ case class StartsWith(left: Expression, right: Expression) - extends BinaryExpression with StringComparison { + extends BinaryPredicate with StringComparison { override def compare(l: String, r: String): Boolean = l.startsWith(r) } @@ -200,7 +199,7 @@ case class StartsWith(left: Expression, right: Expression) * A function that returns true if the string `left` ends with the string `right`. */ case class EndsWith(left: Expression, right: Expression) - extends BinaryExpression with StringComparison { + extends BinaryPredicate with StringComparison { override def compare(l: String, r: String): Boolean = l.endsWith(r) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index c23d3b61887c6..93e69d409cb91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -218,12 +218,12 @@ object NullPropagation extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsUp { case e @ Count(Literal(null, _)) => Cast(Literal(0L), e.dataType) - case e @ IsNull(c) if !c.nullable => Literal(false, BooleanType) - case e @ IsNotNull(c) if !c.nullable => Literal(true, BooleanType) - case e @ GetItem(Literal(null, _), _) => Literal(null, e.dataType) - case e @ GetItem(_, Literal(null, _)) => Literal(null, e.dataType) - case e @ StructGetField(Literal(null, _), _, _) => Literal(null, e.dataType) - case e @ ArrayGetField(Literal(null, _), _, _, _) => Literal(null, e.dataType) + case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType) + case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType) + case e @ GetItem(Literal(null, _), _) => Literal.create(null, e.dataType) + case e @ GetItem(_, Literal(null, _)) => Literal.create(null, e.dataType) + case e @ StructGetField(Literal(null, _), _, _) => Literal.create(null, e.dataType) + case e @ ArrayGetField(Literal(null, _), _, _, _) => Literal.create(null, e.dataType) case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r) case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l) case e @ Count(expr) if !expr.nullable => Count(Literal(1)) @@ -235,36 +235,36 @@ object NullPropagation extends Rule[LogicalPlan] { case _ => true } if (newChildren.length == 0) { - Literal(null, e.dataType) + Literal.create(null, e.dataType) } else if (newChildren.length == 1) { newChildren(0) } else { Coalesce(newChildren) } - case e @ Substring(Literal(null, _), _, _) => Literal(null, e.dataType) - case e @ Substring(_, Literal(null, _), _) => Literal(null, e.dataType) - case e @ Substring(_, _, Literal(null, _)) => Literal(null, e.dataType) + case e @ Substring(Literal(null, _), _, _) => Literal.create(null, e.dataType) + case e @ Substring(_, Literal(null, _), _) => Literal.create(null, e.dataType) + case e @ Substring(_, _, Literal(null, _)) => Literal.create(null, e.dataType) // Put exceptional cases above if any case e: BinaryArithmetic => e.children match { - case Literal(null, _) :: right :: Nil => Literal(null, e.dataType) - case left :: Literal(null, _) :: Nil => Literal(null, e.dataType) + case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType) + case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType) case _ => e } case e: BinaryComparison => e.children match { - case Literal(null, _) :: right :: Nil => Literal(null, e.dataType) - case left :: Literal(null, _) :: Nil => Literal(null, e.dataType) + case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType) + case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType) case _ => e } case e: StringRegexExpression => e.children match { - case Literal(null, _) :: right :: Nil => Literal(null, e.dataType) - case left :: Literal(null, _) :: Nil => Literal(null, e.dataType) + case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType) + case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType) case _ => e } case e: StringComparison => e.children match { - case Literal(null, _) :: right :: Nil => Literal(null, e.dataType) - case left :: Literal(null, _) :: Nil => Literal(null, e.dataType) + case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType) + case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType) case _ => e } } @@ -284,13 +284,13 @@ object ConstantFolding extends Rule[LogicalPlan] { case l: Literal => l // Fold expressions that are foldable. - case e if e.foldable => Literal(e.eval(null), e.dataType) + case e if e.foldable => Literal.create(e.eval(null), e.dataType) // Fold "literal in (item1, item2, ..., literal, ...)" into true directly. case In(Literal(v, _), list) if list.exists { case Literal(candidate, _) if candidate == v => true case _ => false - } => Literal(true, BooleanType) + } => Literal.create(true, BooleanType) } } } @@ -647,7 +647,7 @@ object DecimalAggregates extends Rule[LogicalPlan] { case Average(e @ DecimalType.Expression(prec, scale)) if prec + 4 <= MAX_DOUBLE_DIGITS => Cast( - Divide(Average(UnscaledValue(e)), Literal(math.pow(10.0, scale), DoubleType)), + Divide(Average(UnscaledValue(e)), Literal.create(math.pow(10.0, scale), DoubleType)), DecimalType(prec + 4, scale + 4)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 02f7c26a8ab6e..7967189cacb24 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -150,7 +150,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy }.toSeq } - def schema: StructType = StructType.fromAttributes(output) + lazy val schema: StructType = StructType.fromAttributes(output) /** Returns the output schema in the tree format. */ def schemaString: String = schema.treeString diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala index bb79dc340553b..e3e070f0ff307 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala @@ -18,9 +18,9 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.analysis +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, analysis} import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.types.{DataTypeConversions, StructType, StructField} +import org.apache.spark.sql.types.{StructType, StructField} object LocalRelation { def apply(output: Attribute*): LocalRelation = new LocalRelation(output) @@ -31,7 +31,8 @@ object LocalRelation { def fromProduct(output: Seq[Attribute], data: Seq[Product]): LocalRelation = { val schema = StructType.fromAttributes(output) - LocalRelation(output, data.map(row => DataTypeConversions.productToRow(row, schema))) + val converter = CatalystTypeConverters.createToCatalystConverter(schema) + LocalRelation(output, data.map(converter(_).asInstanceOf[Row])) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 2e9f3aa4ec4ad..579a0fb8d3f93 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -205,13 +205,12 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { // One match, but we also need to extract the requested nested field. case Seq((a, nestedFields)) => try { - - // The foldLeft adds UnresolvedGetField for every remaining parts of the name, - // and aliased it with the last part of the name. - // For example, consider name "a.b.c", where "a" is resolved to an existing attribute. - // Then this will add UnresolvedGetField("b") and UnresolvedGetField("c"), and alias + // The foldLeft adds GetFields for every remaining parts of the identifier, + // and aliases it with the last part of the identifier. + // For example, consider "a.b.c", where "a" is resolved to an existing attribute. + // Then this will add GetField("c", GetField("b", a)), and alias // the final expression as "c". - val fieldExprs = nestedFields.foldLeft(a: Expression)(resolveGetField(_, _, resolver)) + val fieldExprs = nestedFields.foldLeft(a: Expression)(GetField(_, _, resolver)) val aliasName = nestedFields.last Some(Alias(fieldExprs, aliasName)()) } catch { @@ -230,41 +229,6 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { s"Reference '$name' is ambiguous, could be: $referenceNames.") } } - - /** - * Returns the resolved `GetField`, and report error if no desired field or over one - * desired fields are found. - * - * TODO: this code is duplicated from Analyzer and should be refactored to avoid this. - */ - protected def resolveGetField( - expr: Expression, - fieldName: String, - resolver: Resolver): Expression = { - def findField(fields: Array[StructField]): Int = { - val checkField = (f: StructField) => resolver(f.name, fieldName) - val ordinal = fields.indexWhere(checkField) - if (ordinal == -1) { - throw new AnalysisException( - s"No such struct field $fieldName in ${fields.map(_.name).mkString(", ")}") - } else if (fields.indexWhere(checkField, ordinal + 1) != -1) { - throw new AnalysisException( - s"Ambiguous reference to fields ${fields.filter(checkField).mkString(", ")}") - } else { - ordinal - } - } - expr.dataType match { - case StructType(fields) => - val ordinal = findField(fields) - StructGetField(expr, fields(ordinal), ordinal) - case ArrayType(StructType(fields), containsNull) => - val ordinal = findField(fields) - ArrayGetField(expr, fields(ordinal), ordinal, containsNull) - case otherType => - throw new AnalysisException(s"GetField is not valid on fields of type $otherType") - } - } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 8633e06093cf3..5d31a6eecfce2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -147,6 +147,18 @@ case class CreateTableAsSelect[T]( override lazy val resolved: Boolean = databaseName != None && childrenResolved } +/** + * A container for holding named common table expressions (CTEs) and a query plan. + * This operator will be removed during analysis and the relations will be substituted into child. + * @param child The final query of this CTE. + * @param cteRelations Queries that this CTE defined, + * key is the alias of the CTE definition, + * value is the CTE definition. + */ +case class With(child: LogicalPlan, cteRelations: Map[String, Subquery]) extends UnaryNode { + override def output: Seq[Attribute] = child.output +} + case class WriteToFile( path: String, child: LogicalPlan) extends UnaryNode { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeConversions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeConversions.scala deleted file mode 100644 index a9d63e784963d..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeConversions.scala +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.types - -import java.text.SimpleDateFormat - -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.expressions.GenericMutableRow - - -private[sql] object DataTypeConversions { - - def productToRow(product: Product, schema: StructType): Row = { - val mutableRow = new GenericMutableRow(product.productArity) - val schemaFields = schema.fields.toArray - - var i = 0 - while (i < mutableRow.length) { - mutableRow(i) = - ScalaReflection.convertToCatalyst(product.productElement(i), schemaFields(i).dataType) - i += 1 - } - - mutableRow - } - - def stringToTime(s: String): java.util.Date = { - if (!s.contains('T')) { - // JDBC escape string - if (s.contains(' ')) { - java.sql.Timestamp.valueOf(s) - } else { - java.sql.Date.valueOf(s) - } - } else if (s.endsWith("Z")) { - // this is zero timezone of ISO8601 - stringToTime(s.substring(0, s.length - 1) + "GMT-00:00") - } else if (s.indexOf("GMT") == -1) { - // timezone with ISO8601 - val inset = "+00.00".length - val s0 = s.substring(0, s.length - inset) - val s1 = s.substring(s.length - inset, s.length) - if (s0.substring(s0.lastIndexOf(':')).contains('.')) { - stringToTime(s0 + "GMT" + s1) - } else { - stringToTime(s0 + ".0GMT" + s1) - } - } else { - // ISO8601 with GMT insert - val ISO8601GMT: SimpleDateFormat = new SimpleDateFormat( "yyyy-MM-dd'T'HH:mm:ss.SSSz" ) - ISO8601GMT.parse(s) - } - } - - /** Converts Java objects to catalyst rows / types */ - def convertJavaToCatalyst(a: Any, dataType: DataType): Any = (a, dataType) match { - case (obj, udt: UserDefinedType[_]) => ScalaReflection.convertToCatalyst(obj, udt) // Scala type - case (d: java.math.BigDecimal, _) => Decimal(d) - case (other, _) => other - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala index 34270d0ca7cd7..5163f05879e42 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala @@ -40,7 +40,7 @@ private[sql] trait DataTypeParser extends StandardTokenParsers { protected lazy val primitiveType: Parser[DataType] = "(?i)string".r ^^^ StringType | "(?i)float".r ^^^ FloatType | - "(?i)int".r ^^^ IntegerType | + "(?i)(?:int|integer)".r ^^^ IntegerType | "(?i)tinyint".r ^^^ ByteType | "(?i)smallint".r ^^^ ShortType | "(?i)double".r ^^^ DoubleType | diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala index 8a1a3b81b3d2c..504fb05842505 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.types import java.sql.Date +import java.text.SimpleDateFormat import java.util.{Calendar, TimeZone} import org.apache.spark.sql.catalyst.expressions.Cast @@ -57,4 +58,32 @@ object DateUtils { } def toString(days: Int): String = Cast.threadLocalDateFormat.get.format(toJavaDate(days)) + + def stringToTime(s: String): java.util.Date = { + if (!s.contains('T')) { + // JDBC escape string + if (s.contains(' ')) { + java.sql.Timestamp.valueOf(s) + } else { + java.sql.Date.valueOf(s) + } + } else if (s.endsWith("Z")) { + // this is zero timezone of ISO8601 + stringToTime(s.substring(0, s.length - 1) + "GMT-00:00") + } else if (s.indexOf("GMT") == -1) { + // timezone with ISO8601 + val inset = "+00.00".length + val s0 = s.substring(0, s.length - inset) + val s1 = s.substring(s.length - inset, s.length) + if (s0.substring(s0.lastIndexOf(':')).contains('.')) { + stringToTime(s0 + "GMT" + s1) + } else { + stringToTime(s0 + ".0GMT" + s1) + } + } else { + // ISO8601 with GMT insert + val ISO8601GMT: SimpleDateFormat = new SimpleDateFormat( "yyyy-MM-dd'T'HH:mm:ss.SSSz" ) + ISO8601GMT.parse(s) + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala index 952cf5c75688d..cdf2bc68d9c5e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.types import java.sql.Timestamp import scala.collection.mutable.ArrayBuffer +import scala.math._ import scala.math.Numeric.{FloatAsIfIntegral, DoubleAsIfIntegral} import scala.reflect.ClassTag import scala.reflect.runtime.universe.{TypeTag, runtimeMirror, typeTag} @@ -934,7 +935,9 @@ object StructType { case (DecimalType.Fixed(leftPrecision, leftScale), DecimalType.Fixed(rightPrecision, rightScale)) => - DecimalType(leftPrecision.max(rightPrecision), leftScale.max(rightScale)) + DecimalType( + max(leftScale, rightScale) + max(leftPrecision - leftScale, rightPrecision - rightScale), + max(leftScale, rightScale)) case (leftUdt: UserDefinedType[_], rightUdt: UserDefinedType[_]) if leftUdt.userClass == rightUdt.userClass => leftUdt diff --git a/sql/catalyst/src/test/resources/log4j.properties b/sql/catalyst/src/test/resources/log4j.properties index 287c8e3563503..eb3b1999eb996 100644 --- a/sql/catalyst/src/test/resources/log4j.properties +++ b/sql/catalyst/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN -org.eclipse.jetty.LEVEL=WARN +log4j.logger.org.spark-project.jetty=WARN +org.spark-project.jetty.LEVEL=WARN diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala index 46b2250aab231..ea82cd2622de9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala @@ -30,7 +30,7 @@ class DistributionSuite extends FunSuite { inputPartitioning: Partitioning, requiredDistribution: Distribution, satisfied: Boolean) { - if (inputPartitioning.satisfies(requiredDistribution) != satisfied) + if (inputPartitioning.satisfies(requiredDistribution) != satisfied) { fail( s""" |== Input Partitioning == @@ -40,6 +40,7 @@ class DistributionSuite extends FunSuite { |== Does input partitioning satisfy required distribution? == |Expected $satisfied got ${inputPartitioning.satisfies(requiredDistribution)} """.stripMargin) + } } test("HashPartitioning is the output partitioning") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index eee00e3f7ea76..bbc0b661a0c0c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -260,7 +260,7 @@ class ScalaReflectionSuite extends FunSuite { val data = PrimitiveData(1, 1, 1, 1, 1, 1, true) val convertedData = Row(1, 1.toLong, 1.toDouble, 1.toFloat, 1.toShort, 1.toByte, true) val dataType = schemaFor[PrimitiveData].dataType - assert(convertToCatalyst(data, dataType) === convertedData) + assert(CatalystTypeConverters.convertToCatalyst(data, dataType) === convertedData) } test("convert Option[Product] to catalyst") { @@ -270,7 +270,7 @@ class ScalaReflectionSuite extends FunSuite { val dataType = schemaFor[OptionalData].dataType val convertedData = Row(2, 2.toLong, 2.toDouble, 2.toFloat, 2.toShort, 2.toByte, true, Row(1, 1, 1, 1, 1, 1, true)) - assert(convertToCatalyst(data, dataType) === convertedData) + assert(CatalystTypeConverters.convertToCatalyst(data, dataType) === convertedData) } test("infer schema from case class with multiple constructors") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index ee7b14c7a157c..6e3d6b9263e86 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -27,6 +27,8 @@ import org.apache.spark.sql.types._ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ +import scala.collection.immutable + class AnalysisSuite extends FunSuite with BeforeAndAfter { val caseSensitiveCatalog = new SimpleCatalog(true) val caseInsensitiveCatalog = new SimpleCatalog(false) @@ -41,10 +43,10 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { } - def caseSensitiveAnalyze(plan: LogicalPlan) = + def caseSensitiveAnalyze(plan: LogicalPlan): Unit = caseSensitiveAnalyzer.checkAnalysis(caseSensitiveAnalyzer(plan)) - def caseInsensitiveAnalyze(plan: LogicalPlan) = + def caseInsensitiveAnalyze(plan: LogicalPlan): Unit = caseInsensitiveAnalyzer.checkAnalysis(caseInsensitiveAnalyzer(plan)) val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)()) @@ -147,7 +149,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { name: String, plan: LogicalPlan, errorMessages: Seq[String], - caseSensitive: Boolean = true) = { + caseSensitive: Boolean = true): Unit = { test(name) { val error = intercept[AnalysisException] { if(caseSensitive) { @@ -202,7 +204,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { case class UnresolvedTestPlan() extends LeafNode { override lazy val resolved = false - override def output = Nil + override def output: Seq[Attribute] = Nil } errorTest( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index bc2ec754d5865..67bec999dfbd1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation} +import org.apache.spark.sql.catalyst.plans.logical.{Union, Project, LocalRelation} import org.apache.spark.sql.types._ import org.scalatest.{BeforeAndAfter, FunSuite} @@ -31,7 +31,8 @@ class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter { AttributeReference("d1", DecimalType(2, 1))(), AttributeReference("d2", DecimalType(5, 2))(), AttributeReference("u", DecimalType.Unlimited)(), - AttributeReference("f", FloatType)() + AttributeReference("f", FloatType)(), + AttributeReference("b", DoubleType)() ) val i: Expression = UnresolvedAttribute("i") @@ -39,6 +40,7 @@ class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter { val d2: Expression = UnresolvedAttribute("d2") val u: Expression = UnresolvedAttribute("u") val f: Expression = UnresolvedAttribute("f") + val b: Expression = UnresolvedAttribute("b") before { catalog.registerTable(Seq("table"), relation) @@ -58,6 +60,17 @@ class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter { assert(comparison.right.dataType === expectedType) } + private def checkUnion(left: Expression, right: Expression, expectedType: DataType): Unit = { + val plan = + Union(Project(Seq(Alias(left, "l")()), relation), + Project(Seq(Alias(right, "r")()), relation)) + val (l, r) = analyzer(plan).collect { + case Union(left, right) => (left.output.head, right.output.head) + }.head + assert(l.dataType === expectedType) + assert(r.dataType === expectedType) + } + test("basic operations") { checkType(Add(d1, d2), DecimalType(6, 2)) checkType(Subtract(d1, d2), DecimalType(6, 2)) @@ -82,6 +95,19 @@ class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter { checkComparison(GreaterThan(d2, d2), DecimalType(5, 2)) } + test("decimal precision for union") { + checkUnion(d1, i, DecimalType(11, 1)) + checkUnion(i, d2, DecimalType(12, 2)) + checkUnion(d1, d2, DecimalType(5, 2)) + checkUnion(d2, d1, DecimalType(5, 2)) + checkUnion(d1, f, DecimalType(8, 7)) + checkUnion(f, d2, DecimalType(10, 7)) + checkUnion(d1, b, DecimalType(16, 15)) + checkUnion(b, d2, DecimalType(18, 15)) + checkUnion(d1, u, DecimalType.Unlimited) + checkUnion(u, d2, DecimalType.Unlimited) + } + test("bringing in primitive types") { checkType(Add(d1, i), DecimalType(12, 1)) checkType(Add(d1, f), DoubleType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index ecbb54218d457..fcd745f43cfbf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -96,7 +96,9 @@ class HiveTypeCoercionSuite extends PlanTest { widenTest(StringType, TimestampType, None) // ComplexType - widenTest(NullType, MapType(IntegerType, StringType, false), Some(MapType(IntegerType, StringType, false))) + widenTest(NullType, + MapType(IntegerType, StringType, false), + Some(MapType(IntegerType, StringType, false))) widenTest(NullType, StructType(Seq()), Some(StructType(Seq()))) widenTest(StringType, MapType(IntegerType, StringType, true), None) widenTest(ArrayType(IntegerType), StructType(Seq()), None) @@ -113,7 +115,9 @@ class HiveTypeCoercionSuite extends PlanTest { // Remove superflous boolean -> boolean casts. ruleTest(Cast(Literal(true), BooleanType), Literal(true)) // Stringify boolean when casting to string. - ruleTest(Cast(Literal(false), StringType), If(Literal(false), Literal("true"), Literal("false"))) + ruleTest( + Cast(Literal(false), StringType), + If(Literal(false), Literal("true"), Literal("false"))) } test("coalesce casts") { @@ -127,11 +131,11 @@ class HiveTypeCoercionSuite extends PlanTest { ruleTest( Coalesce(Literal(1.0) :: Literal(1) - :: Literal(1.0, FloatType) + :: Literal.create(1.0, FloatType) :: Nil), Coalesce(Cast(Literal(1.0), DoubleType) :: Cast(Literal(1), DoubleType) - :: Cast(Literal(1.0, FloatType), DoubleType) + :: Cast(Literal.create(1.0, FloatType), DoubleType) :: Nil)) ruleTest( Coalesce(Literal(1L) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 1183a0d899dda..d2b1090a0cdd5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -82,10 +82,13 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { assert(BitwiseNot(1.toByte).eval(EmptyRow).isInstanceOf[Byte]) } + // scalastyle:off /** * Checks for three-valued-logic. Based on: * http://en.wikipedia.org/wiki/Null_(SQL)#Comparisons_with_NULL_and_the_three-valued_logic_.283VL.29 - * I.e. in flat cpo "False -> Unknown -> True", OR is lowest upper bound, AND is greatest lower bound. + * I.e. in flat cpo "False -> Unknown -> True", + * OR is lowest upper bound, + * AND is greatest lower bound. * p q p OR q p AND q p = q * True True True True True * True False True False False @@ -102,7 +105,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { * False True * Unknown Unknown */ - + // scalastyle:on val notTrueTable = (true, false) :: (false, true) :: @@ -111,7 +114,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { test("3VL Not") { notTrueTable.foreach { case (v, answer) => - checkEvaluation(!Literal(v, BooleanType), answer) + checkEvaluation(!Literal.create(v, BooleanType), answer) } } @@ -155,7 +158,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { test(s"3VL $name") { truthTable.foreach { case (l,r,answer) => - val expr = op(Literal(l, BooleanType), Literal(r, BooleanType)) + val expr = op(Literal.create(l, BooleanType), Literal.create(r, BooleanType)) checkEvaluation(expr, answer) } } @@ -165,7 +168,9 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))), true) checkEvaluation(In(Literal(2), Seq(Literal(1), Literal(2))), true) checkEvaluation(In(Literal(3), Seq(Literal(1), Literal(2))), false) - checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))) && In(Literal(2), Seq(Literal(1), Literal(2))), true) + checkEvaluation( + In(Literal(1), Seq(Literal(1), Literal(2))) && In(Literal(2), Seq(Literal(1), Literal(2))), + true) } test("Divide") { @@ -175,12 +180,13 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation(Divide(Literal(1), Literal(0)), null) checkEvaluation(Divide(Literal(1.0), Literal(0.0)), null) checkEvaluation(Divide(Literal(0.0), Literal(0.0)), null) - checkEvaluation(Divide(Literal(0), Literal(null, IntegerType)), null) - checkEvaluation(Divide(Literal(1), Literal(null, IntegerType)), null) - checkEvaluation(Divide(Literal(null, IntegerType), Literal(0)), null) - checkEvaluation(Divide(Literal(null, DoubleType), Literal(0.0)), null) - checkEvaluation(Divide(Literal(null, IntegerType), Literal(1)), null) - checkEvaluation(Divide(Literal(null, IntegerType), Literal(null, IntegerType)), null) + checkEvaluation(Divide(Literal(0), Literal.create(null, IntegerType)), null) + checkEvaluation(Divide(Literal(1), Literal.create(null, IntegerType)), null) + checkEvaluation(Divide(Literal.create(null, IntegerType), Literal(0)), null) + checkEvaluation(Divide(Literal.create(null, DoubleType), Literal(0.0)), null) + checkEvaluation(Divide(Literal.create(null, IntegerType), Literal(1)), null) + checkEvaluation(Divide(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), + null) } test("Remainder") { @@ -190,12 +196,13 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation(Remainder(Literal(1), Literal(0)), null) checkEvaluation(Remainder(Literal(1.0), Literal(0.0)), null) checkEvaluation(Remainder(Literal(0.0), Literal(0.0)), null) - checkEvaluation(Remainder(Literal(0), Literal(null, IntegerType)), null) - checkEvaluation(Remainder(Literal(1), Literal(null, IntegerType)), null) - checkEvaluation(Remainder(Literal(null, IntegerType), Literal(0)), null) - checkEvaluation(Remainder(Literal(null, DoubleType), Literal(0.0)), null) - checkEvaluation(Remainder(Literal(null, IntegerType), Literal(1)), null) - checkEvaluation(Remainder(Literal(null, IntegerType), Literal(null, IntegerType)), null) + checkEvaluation(Remainder(Literal(0), Literal.create(null, IntegerType)), null) + checkEvaluation(Remainder(Literal(1), Literal.create(null, IntegerType)), null) + checkEvaluation(Remainder(Literal.create(null, IntegerType), Literal(0)), null) + checkEvaluation(Remainder(Literal.create(null, DoubleType), Literal(0.0)), null) + checkEvaluation(Remainder(Literal.create(null, IntegerType), Literal(1)), null) + checkEvaluation(Remainder(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), + null) } test("INSET") { @@ -222,14 +229,14 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation(MaxOf(1L, 2L), 2L) checkEvaluation(MaxOf(2L, 1L), 2L) - checkEvaluation(MaxOf(Literal(null, IntegerType), 2), 2) - checkEvaluation(MaxOf(2, Literal(null, IntegerType)), 2) + checkEvaluation(MaxOf(Literal.create(null, IntegerType), 2), 2) + checkEvaluation(MaxOf(2, Literal.create(null, IntegerType)), 2) } test("LIKE literal Regular Expression") { - checkEvaluation(Literal(null, StringType).like("a"), null) - checkEvaluation(Literal("a", StringType).like(Literal(null, StringType)), null) - checkEvaluation(Literal(null, StringType).like(Literal(null, StringType)), null) + checkEvaluation(Literal.create(null, StringType).like("a"), null) + checkEvaluation(Literal.create("a", StringType).like(Literal.create(null, StringType)), null) + checkEvaluation(Literal.create(null, StringType).like(Literal.create(null, StringType)), null) checkEvaluation("abdef" like "abdef", true) checkEvaluation("a_%b" like "a\\__b", true) checkEvaluation("addb" like "a_%b", true) @@ -264,13 +271,14 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation("ab" like regEx, true, new GenericRow(Array[Any]("a%b"))) checkEvaluation("a\nb" like regEx, true, new GenericRow(Array[Any]("a%b"))) - checkEvaluation(Literal(null, StringType) like regEx, null, new GenericRow(Array[Any]("bc%"))) + checkEvaluation(Literal.create(null, StringType) like regEx, null, + new GenericRow(Array[Any]("bc%"))) } test("RLIKE literal Regular Expression") { - checkEvaluation(Literal(null, StringType) rlike "abdef", null) - checkEvaluation("abdef" rlike Literal(null, StringType), null) - checkEvaluation(Literal(null, StringType) rlike Literal(null, StringType), null) + checkEvaluation(Literal.create(null, StringType) rlike "abdef", null) + checkEvaluation("abdef" rlike Literal.create(null, StringType), null) + checkEvaluation(Literal.create(null, StringType) rlike Literal.create(null, StringType), null) checkEvaluation("abdef" rlike "abdef", true) checkEvaluation("abbbbc" rlike "a.*c", true) @@ -381,7 +389,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { assert(("abcdef" cast DoubleType).nullable === true) assert(("abcdef" cast FloatType).nullable === true) - checkEvaluation(Cast(Literal(null, IntegerType), ShortType), null) + checkEvaluation(Cast(Literal.create(null, IntegerType), ShortType), null) } test("date") { @@ -507,8 +515,10 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { } test("array casting") { - val array = Literal(Seq("123", "abc", "", null), ArrayType(StringType, containsNull = true)) - val array_notNull = Literal(Seq("123", "abc", ""), ArrayType(StringType, containsNull = false)) + val array = Literal.create(Seq("123", "abc", "", null), + ArrayType(StringType, containsNull = true)) + val array_notNull = Literal.create(Seq("123", "abc", ""), + ArrayType(StringType, containsNull = false)) { val cast = Cast(array, ArrayType(IntegerType, containsNull = true)) @@ -556,10 +566,10 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { } test("map casting") { - val map = Literal( + val map = Literal.create( Map("a" -> "123", "b" -> "abc", "c" -> "", "d" -> null), MapType(StringType, StringType, valueContainsNull = true)) - val map_notNull = Literal( + val map_notNull = Literal.create( Map("a" -> "123", "b" -> "abc", "c" -> ""), MapType(StringType, StringType, valueContainsNull = false)) @@ -617,14 +627,14 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { } test("struct casting") { - val struct = Literal( + val struct = Literal.create( Row("123", "abc", "", null), StructType(Seq( StructField("a", StringType, nullable = true), StructField("b", StringType, nullable = true), StructField("c", StringType, nullable = true), StructField("d", StringType, nullable = true)))) - val struct_notNull = Literal( + val struct_notNull = Literal.create( Row("123", "abc", ""), StructType(Seq( StructField("a", StringType, nullable = false), @@ -712,7 +722,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { } test("complex casting") { - val complex = Literal( + val complex = Literal.create( Row( Seq("123", "abc", ""), Map("a" -> "123", "b" -> "abc", "c" -> ""), @@ -755,30 +765,31 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation(c2.isNull, true, row) checkEvaluation(c2.isNotNull, false, row) - checkEvaluation(Literal(1, ShortType).isNull, false) - checkEvaluation(Literal(1, ShortType).isNotNull, true) + checkEvaluation(Literal.create(1, ShortType).isNull, false) + checkEvaluation(Literal.create(1, ShortType).isNotNull, true) - checkEvaluation(Literal(null, ShortType).isNull, true) - checkEvaluation(Literal(null, ShortType).isNotNull, false) + checkEvaluation(Literal.create(null, ShortType).isNull, true) + checkEvaluation(Literal.create(null, ShortType).isNotNull, false) checkEvaluation(Coalesce(c1 :: c2 :: Nil), "^Ba*n", row) - checkEvaluation(Coalesce(Literal(null, StringType) :: Nil), null, row) - checkEvaluation(Coalesce(Literal(null, StringType) :: c1 :: c2 :: Nil), "^Ba*n", row) + checkEvaluation(Coalesce(Literal.create(null, StringType) :: Nil), null, row) + checkEvaluation(Coalesce(Literal.create(null, StringType) :: c1 :: c2 :: Nil), "^Ba*n", row) - checkEvaluation(If(c3, Literal("a", StringType), Literal("b", StringType)), "a", row) + checkEvaluation( + If(c3, Literal.create("a", StringType), Literal.create("b", StringType)), "a", row) checkEvaluation(If(c3, c1, c2), "^Ba*n", row) checkEvaluation(If(c4, c2, c1), "^Ba*n", row) - checkEvaluation(If(Literal(null, BooleanType), c2, c1), "^Ba*n", row) - checkEvaluation(If(Literal(true, BooleanType), c1, c2), "^Ba*n", row) - checkEvaluation(If(Literal(false, BooleanType), c2, c1), "^Ba*n", row) - checkEvaluation(If(Literal(false, BooleanType), - Literal("a", StringType), Literal("b", StringType)), "b", row) + checkEvaluation(If(Literal.create(null, BooleanType), c2, c1), "^Ba*n", row) + checkEvaluation(If(Literal.create(true, BooleanType), c1, c2), "^Ba*n", row) + checkEvaluation(If(Literal.create(false, BooleanType), c2, c1), "^Ba*n", row) + checkEvaluation(If(Literal.create(false, BooleanType), + Literal.create("a", StringType), Literal.create("b", StringType)), "b", row) checkEvaluation(c1 in (c1, c2), true, row) checkEvaluation( - Literal("^Ba*n", StringType) in (Literal("^Ba*n", StringType)), true, row) + Literal.create("^Ba*n", StringType) in (Literal.create("^Ba*n", StringType)), true, row) checkEvaluation( - Literal("^Ba*n", StringType) in (Literal("^Ba*n", StringType), c2), true, row) + Literal.create("^Ba*n", StringType) in (Literal.create("^Ba*n", StringType), c2), true, row) } test("case when") { @@ -793,9 +804,9 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation(CaseWhen(Seq(c1, c4, c6)), "c", row) checkEvaluation(CaseWhen(Seq(c2, c4, c6)), "c", row) checkEvaluation(CaseWhen(Seq(c3, c4, c6)), "a", row) - checkEvaluation(CaseWhen(Seq(Literal(null, BooleanType), c4, c6)), "c", row) - checkEvaluation(CaseWhen(Seq(Literal(false, BooleanType), c4, c6)), "c", row) - checkEvaluation(CaseWhen(Seq(Literal(true, BooleanType), c4, c6)), "a", row) + checkEvaluation(CaseWhen(Seq(Literal.create(null, BooleanType), c4, c6)), "c", row) + checkEvaluation(CaseWhen(Seq(Literal.create(false, BooleanType), c4, c6)), "c", row) + checkEvaluation(CaseWhen(Seq(Literal.create(true, BooleanType), c4, c6)), "a", row) checkEvaluation(CaseWhen(Seq(c3, c4, c2, c5, c6)), "a", row) checkEvaluation(CaseWhen(Seq(c2, c4, c3, c5, c6)), "b", row) @@ -841,19 +852,21 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation(GetItem(BoundReference(3, typeMap, true), Literal("aa")), "bb", row) - checkEvaluation(GetItem(Literal(null, typeMap), Literal("aa")), null, row) - checkEvaluation(GetItem(Literal(null, typeMap), Literal(null, StringType)), null, row) + checkEvaluation(GetItem(Literal.create(null, typeMap), Literal("aa")), null, row) + checkEvaluation( + GetItem(Literal.create(null, typeMap), Literal.create(null, StringType)), null, row) checkEvaluation(GetItem(BoundReference(3, typeMap, true), - Literal(null, StringType)), null, row) + Literal.create(null, StringType)), null, row) checkEvaluation(GetItem(BoundReference(4, typeArray, true), Literal(1)), "bb", row) - checkEvaluation(GetItem(Literal(null, typeArray), Literal(1)), null, row) - checkEvaluation(GetItem(Literal(null, typeArray), Literal(null, IntegerType)), null, row) + checkEvaluation(GetItem(Literal.create(null, typeArray), Literal(1)), null, row) + checkEvaluation( + GetItem(Literal.create(null, typeArray), Literal.create(null, IntegerType)), null, row) checkEvaluation(GetItem(BoundReference(4, typeArray, true), - Literal(null, IntegerType)), null, row) + Literal.create(null, IntegerType)), null, row) - def quickBuildGetField(expr: Expression, fieldName: String) = { + def quickBuildGetField(expr: Expression, fieldName: String): StructGetField = { expr.dataType match { case StructType(fields) => val field = fields.find(_.name == fieldName).get @@ -861,10 +874,12 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { } } - def quickResolve(u: UnresolvedGetField) = quickBuildGetField(u.child, u.fieldName) + def quickResolve(u: UnresolvedGetField): StructGetField = { + quickBuildGetField(u.child, u.fieldName) + } checkEvaluation(quickBuildGetField(BoundReference(2, typeS, nullable = true), "a"), "aa", row) - checkEvaluation(quickBuildGetField(Literal(null, typeS), "a"), null, row) + checkEvaluation(quickBuildGetField(Literal.create(null, typeS), "a"), null, row) val typeS_notNullable = StructType( StructField("a", StringType, nullable = false) @@ -872,10 +887,11 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { ) assert(quickBuildGetField(BoundReference(2,typeS, nullable = true), "a").nullable === true) - assert(quickBuildGetField(BoundReference(2, typeS_notNullable, nullable = false), "a").nullable === false) + assert(quickBuildGetField(BoundReference(2, typeS_notNullable, nullable = false), "a").nullable + === false) - assert(quickBuildGetField(Literal(null, typeS), "a").nullable === true) - assert(quickBuildGetField(Literal(null, typeS_notNullable), "a").nullable === true) + assert(quickBuildGetField(Literal.create(null, typeS), "a").nullable === true) + assert(quickBuildGetField(Literal.create(null, typeS_notNullable), "a").nullable === true) checkEvaluation('c.map(typeMap).at(3).getItem("aa"), "bb", row) checkEvaluation('c.array(typeArray.elementType).at(4).getItem(1), "bb", row) @@ -890,13 +906,14 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { val c4 = 'a.int.at(3) checkEvaluation(UnaryMinus(c1), -1, row) - checkEvaluation(UnaryMinus(Literal(100, IntegerType)), -100) + checkEvaluation(UnaryMinus(Literal.create(100, IntegerType)), -100) checkEvaluation(Add(c1, c4), null, row) checkEvaluation(Add(c1, c2), 3, row) - checkEvaluation(Add(c1, Literal(null, IntegerType)), null, row) - checkEvaluation(Add(Literal(null, IntegerType), c2), null, row) - checkEvaluation(Add(Literal(null, IntegerType), Literal(null, IntegerType)), null, row) + checkEvaluation(Add(c1, Literal.create(null, IntegerType)), null, row) + checkEvaluation(Add(Literal.create(null, IntegerType), c2), null, row) + checkEvaluation( + Add(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) checkEvaluation(-c1, -1, row) checkEvaluation(c1 + c2, 3, row) @@ -914,12 +931,13 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { val c4 = 'a.double.at(3) checkEvaluation(UnaryMinus(c1), -1.1, row) - checkEvaluation(UnaryMinus(Literal(100.0, DoubleType)), -100.0) + checkEvaluation(UnaryMinus(Literal.create(100.0, DoubleType)), -100.0) checkEvaluation(Add(c1, c4), null, row) checkEvaluation(Add(c1, c2), 3.1, row) - checkEvaluation(Add(c1, Literal(null, DoubleType)), null, row) - checkEvaluation(Add(Literal(null, DoubleType), c2), null, row) - checkEvaluation(Add(Literal(null, DoubleType), Literal(null, DoubleType)), null, row) + checkEvaluation(Add(c1, Literal.create(null, DoubleType)), null, row) + checkEvaluation(Add(Literal.create(null, DoubleType), c2), null, row) + checkEvaluation( + Add(Literal.create(null, DoubleType), Literal.create(null, DoubleType)), null, row) checkEvaluation(-c1, -1.1, row) checkEvaluation(c1 + c2, 3.1, row) @@ -940,9 +958,10 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation(LessThan(c1, c4), null, row) checkEvaluation(LessThan(c1, c2), true, row) - checkEvaluation(LessThan(c1, Literal(null, IntegerType)), null, row) - checkEvaluation(LessThan(Literal(null, IntegerType), c2), null, row) - checkEvaluation(LessThan(Literal(null, IntegerType), Literal(null, IntegerType)), null, row) + checkEvaluation(LessThan(c1, Literal.create(null, IntegerType)), null, row) + checkEvaluation(LessThan(Literal.create(null, IntegerType), c2), null, row) + checkEvaluation( + LessThan(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) checkEvaluation(c1 < c2, true, row) checkEvaluation(c1 <= c2, true, row) @@ -954,8 +973,8 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation(c1 <=> c4, false, row) checkEvaluation(c4 <=> c6, true, row) checkEvaluation(c3 <=> c5, true, row) - checkEvaluation(Literal(true) <=> Literal(null, BooleanType), false, row) - checkEvaluation(Literal(null, BooleanType) <=> Literal(true), false, row) + checkEvaluation(Literal(true) <=> Literal.create(null, BooleanType), false, row) + checkEvaluation(Literal.create(null, BooleanType) <=> Literal(true), false, row) } test("StringComparison") { @@ -966,17 +985,17 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation(c1 contains "b", true, row) checkEvaluation(c1 contains "x", false, row) checkEvaluation(c2 contains "b", null, row) - checkEvaluation(c1 contains Literal(null, StringType), null, row) + checkEvaluation(c1 contains Literal.create(null, StringType), null, row) checkEvaluation(c1 startsWith "a", true, row) checkEvaluation(c1 startsWith "b", false, row) checkEvaluation(c2 startsWith "a", null, row) - checkEvaluation(c1 startsWith Literal(null, StringType), null, row) + checkEvaluation(c1 startsWith Literal.create(null, StringType), null, row) checkEvaluation(c1 endsWith "c", true, row) checkEvaluation(c1 endsWith "b", false, row) checkEvaluation(c2 endsWith "b", null, row) - checkEvaluation(c1 endsWith Literal(null, StringType), null, row) + checkEvaluation(c1 endsWith Literal.create(null, StringType), null, row) } test("Substring") { @@ -985,54 +1004,84 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { val s = 'a.string.at(0) // substring from zero position with less-than-full length - checkEvaluation(Substring(s, Literal(0, IntegerType), Literal(2, IntegerType)), "ex", row) - checkEvaluation(Substring(s, Literal(1, IntegerType), Literal(2, IntegerType)), "ex", row) + checkEvaluation( + Substring(s, Literal.create(0, IntegerType), Literal.create(2, IntegerType)), "ex", row) + checkEvaluation( + Substring(s, Literal.create(1, IntegerType), Literal.create(2, IntegerType)), "ex", row) // substring from zero position with full length - checkEvaluation(Substring(s, Literal(0, IntegerType), Literal(7, IntegerType)), "example", row) - checkEvaluation(Substring(s, Literal(1, IntegerType), Literal(7, IntegerType)), "example", row) + checkEvaluation( + Substring(s, Literal.create(0, IntegerType), Literal.create(7, IntegerType)), "example", row) + checkEvaluation( + Substring(s, Literal.create(1, IntegerType), Literal.create(7, IntegerType)), "example", row) // substring from zero position with greater-than-full length - checkEvaluation(Substring(s, Literal(0, IntegerType), Literal(100, IntegerType)), "example", row) - checkEvaluation(Substring(s, Literal(1, IntegerType), Literal(100, IntegerType)), "example", row) + checkEvaluation(Substring(s, Literal.create(0, IntegerType), Literal.create(100, IntegerType)), + "example", row) + checkEvaluation(Substring(s, Literal.create(1, IntegerType), Literal.create(100, IntegerType)), + "example", row) // substring from nonzero position with less-than-full length - checkEvaluation(Substring(s, Literal(2, IntegerType), Literal(2, IntegerType)), "xa", row) + checkEvaluation(Substring(s, Literal.create(2, IntegerType), Literal.create(2, IntegerType)), + "xa", row) // substring from nonzero position with full length - checkEvaluation(Substring(s, Literal(2, IntegerType), Literal(6, IntegerType)), "xample", row) + checkEvaluation(Substring(s, Literal.create(2, IntegerType), Literal.create(6, IntegerType)), + "xample", row) // substring from nonzero position with greater-than-full length - checkEvaluation(Substring(s, Literal(2, IntegerType), Literal(100, IntegerType)), "xample", row) + checkEvaluation(Substring(s, Literal.create(2, IntegerType), Literal.create(100, IntegerType)), + "xample", row) // zero-length substring (within string bounds) - checkEvaluation(Substring(s, Literal(0, IntegerType), Literal(0, IntegerType)), "", row) + checkEvaluation(Substring(s, Literal.create(0, IntegerType), Literal.create(0, IntegerType)), + "", row) // zero-length substring (beyond string bounds) - checkEvaluation(Substring(s, Literal(100, IntegerType), Literal(4, IntegerType)), "", row) + checkEvaluation(Substring(s, Literal.create(100, IntegerType), Literal.create(4, IntegerType)), + "", row) // substring(null, _, _) -> null - checkEvaluation(Substring(s, Literal(100, IntegerType), Literal(4, IntegerType)), null, new GenericRow(Array[Any](null))) + checkEvaluation(Substring(s, Literal.create(100, IntegerType), Literal.create(4, IntegerType)), + null, new GenericRow(Array[Any](null))) // substring(_, null, _) -> null - checkEvaluation(Substring(s, Literal(null, IntegerType), Literal(4, IntegerType)), null, row) + checkEvaluation(Substring(s, Literal.create(null, IntegerType), Literal.create(4, IntegerType)), + null, row) // substring(_, _, null) -> null - checkEvaluation(Substring(s, Literal(100, IntegerType), Literal(null, IntegerType)), null, row) + checkEvaluation( + Substring(s, Literal.create(100, IntegerType), Literal.create(null, IntegerType)), + null, + row) // 2-arg substring from zero position - checkEvaluation(Substring(s, Literal(0, IntegerType), Literal(Integer.MAX_VALUE, IntegerType)), "example", row) - checkEvaluation(Substring(s, Literal(1, IntegerType), Literal(Integer.MAX_VALUE, IntegerType)), "example", row) + checkEvaluation( + Substring(s, Literal.create(0, IntegerType), Literal.create(Integer.MAX_VALUE, IntegerType)), + "example", + row) + checkEvaluation( + Substring(s, Literal.create(1, IntegerType), Literal.create(Integer.MAX_VALUE, IntegerType)), + "example", + row) // 2-arg substring from nonzero position - checkEvaluation(Substring(s, Literal(2, IntegerType), Literal(Integer.MAX_VALUE, IntegerType)), "xample", row) + checkEvaluation( + Substring(s, Literal.create(2, IntegerType), Literal.create(Integer.MAX_VALUE, IntegerType)), + "xample", + row) val s_notNull = 'a.string.notNull.at(0) - assert(Substring(s, Literal(0, IntegerType), Literal(2, IntegerType)).nullable === true) - assert(Substring(s_notNull, Literal(0, IntegerType), Literal(2, IntegerType)).nullable === false) - assert(Substring(s_notNull, Literal(null, IntegerType), Literal(2, IntegerType)).nullable === true) - assert(Substring(s_notNull, Literal(0, IntegerType), Literal(null, IntegerType)).nullable === true) + assert(Substring(s, Literal.create(0, IntegerType), Literal.create(2, IntegerType)).nullable + === true) + assert( + Substring(s_notNull, Literal.create(0, IntegerType), Literal.create(2, IntegerType)).nullable + === false) + assert(Substring(s_notNull, + Literal.create(null, IntegerType), Literal.create(2, IntegerType)).nullable === true) + assert(Substring(s_notNull, + Literal.create(0, IntegerType), Literal.create(null, IntegerType)).nullable === true) checkEvaluation(s.substr(0, 2), "ex", row) checkEvaluation(s.substr(0), "example", row) @@ -1050,7 +1099,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation(Sqrt(d), expected, row) } - checkEvaluation(Sqrt(Literal(null, DoubleType)), null, new GenericRow(Array[Any](null))) + checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, new GenericRow(Array[Any](null))) checkEvaluation(Sqrt(-1), null, EmptyRow) checkEvaluation(Sqrt(-1.5), null, EmptyRow) } @@ -1064,22 +1113,25 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation(BitwiseAnd(c1, c4), null, row) checkEvaluation(BitwiseAnd(c1, c2), 0, row) - checkEvaluation(BitwiseAnd(c1, Literal(null, IntegerType)), null, row) - checkEvaluation(BitwiseAnd(Literal(null, IntegerType), Literal(null, IntegerType)), null, row) + checkEvaluation(BitwiseAnd(c1, Literal.create(null, IntegerType)), null, row) + checkEvaluation( + BitwiseAnd(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) checkEvaluation(BitwiseOr(c1, c4), null, row) checkEvaluation(BitwiseOr(c1, c2), 3, row) - checkEvaluation(BitwiseOr(c1, Literal(null, IntegerType)), null, row) - checkEvaluation(BitwiseOr(Literal(null, IntegerType), Literal(null, IntegerType)), null, row) + checkEvaluation(BitwiseOr(c1, Literal.create(null, IntegerType)), null, row) + checkEvaluation( + BitwiseOr(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) checkEvaluation(BitwiseXor(c1, c4), null, row) checkEvaluation(BitwiseXor(c1, c2), 3, row) - checkEvaluation(BitwiseXor(c1, Literal(null, IntegerType)), null, row) - checkEvaluation(BitwiseXor(Literal(null, IntegerType), Literal(null, IntegerType)), null, row) + checkEvaluation(BitwiseXor(c1, Literal.create(null, IntegerType)), null, row) + checkEvaluation( + BitwiseXor(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) checkEvaluation(BitwiseNot(c4), null, row) checkEvaluation(BitwiseNot(c1), -2, row) - checkEvaluation(BitwiseNot(Literal(null, IntegerType)), null, row) + checkEvaluation(BitwiseNot(Literal.create(null, IntegerType)), null, row) checkEvaluation(c1 & c2, 0, row) checkEvaluation(c1 | c2, 3, row) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index ef10c0aece716..4396bd0dda9a9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -176,40 +176,39 @@ class ConstantFoldingSuite extends PlanTest { } test("Constant folding test: expressions have null literals") { - val originalQuery = - testRelation - .select( - IsNull(Literal(null)) as 'c1, - IsNotNull(Literal(null)) as 'c2, + val originalQuery = testRelation.select( + IsNull(Literal(null)) as 'c1, + IsNotNull(Literal(null)) as 'c2, - GetItem(Literal(null, ArrayType(IntegerType)), 1) as 'c3, - GetItem(Literal(Seq(1), ArrayType(IntegerType)), Literal(null, IntegerType)) as 'c4, - UnresolvedGetField( - Literal(null, StructType(Seq(StructField("a", IntegerType, true)))), - "a") as 'c5, + GetItem(Literal.create(null, ArrayType(IntegerType)), 1) as 'c3, + GetItem( + Literal.create(Seq(1), ArrayType(IntegerType)), Literal.create(null, IntegerType)) as 'c4, + UnresolvedGetField( + Literal.create(null, StructType(Seq(StructField("a", IntegerType, true)))), + "a") as 'c5, - UnaryMinus(Literal(null, IntegerType)) as 'c6, - Cast(Literal(null), IntegerType) as 'c7, - Not(Literal(null, BooleanType)) as 'c8, + UnaryMinus(Literal.create(null, IntegerType)) as 'c6, + Cast(Literal(null), IntegerType) as 'c7, + Not(Literal.create(null, BooleanType)) as 'c8, - Add(Literal(null, IntegerType), 1) as 'c9, - Add(1, Literal(null, IntegerType)) as 'c10, + Add(Literal.create(null, IntegerType), 1) as 'c9, + Add(1, Literal.create(null, IntegerType)) as 'c10, - EqualTo(Literal(null, IntegerType), 1) as 'c11, - EqualTo(1, Literal(null, IntegerType)) as 'c12, + EqualTo(Literal.create(null, IntegerType), 1) as 'c11, + EqualTo(1, Literal.create(null, IntegerType)) as 'c12, - Like(Literal(null, StringType), "abc") as 'c13, - Like("abc", Literal(null, StringType)) as 'c14, + Like(Literal.create(null, StringType), "abc") as 'c13, + Like("abc", Literal.create(null, StringType)) as 'c14, - Upper(Literal(null, StringType)) as 'c15, + Upper(Literal.create(null, StringType)) as 'c15, - Substring(Literal(null, StringType), 0, 1) as 'c16, - Substring("abc", Literal(null, IntegerType), 1) as 'c17, - Substring("abc", 0, Literal(null, IntegerType)) as 'c18, + Substring(Literal.create(null, StringType), 0, 1) as 'c16, + Substring("abc", Literal.create(null, IntegerType), 1) as 'c17, + Substring("abc", 0, Literal.create(null, IntegerType)) as 'c18, - Contains(Literal(null, StringType), "abc") as 'c19, - Contains("abc", Literal(null, StringType)) as 'c20 - ) + Contains(Literal.create(null, StringType), "abc") as 'c19, + Contains("abc", Literal.create(null, StringType)) as 'c20 + ) val optimized = Optimize(originalQuery.analyze) @@ -219,31 +218,31 @@ class ConstantFoldingSuite extends PlanTest { Literal(true) as 'c1, Literal(false) as 'c2, - Literal(null, IntegerType) as 'c3, - Literal(null, IntegerType) as 'c4, - Literal(null, IntegerType) as 'c5, + Literal.create(null, IntegerType) as 'c3, + Literal.create(null, IntegerType) as 'c4, + Literal.create(null, IntegerType) as 'c5, - Literal(null, IntegerType) as 'c6, - Literal(null, IntegerType) as 'c7, - Literal(null, BooleanType) as 'c8, + Literal.create(null, IntegerType) as 'c6, + Literal.create(null, IntegerType) as 'c7, + Literal.create(null, BooleanType) as 'c8, - Literal(null, IntegerType) as 'c9, - Literal(null, IntegerType) as 'c10, + Literal.create(null, IntegerType) as 'c9, + Literal.create(null, IntegerType) as 'c10, - Literal(null, BooleanType) as 'c11, - Literal(null, BooleanType) as 'c12, + Literal.create(null, BooleanType) as 'c11, + Literal.create(null, BooleanType) as 'c12, - Literal(null, BooleanType) as 'c13, - Literal(null, BooleanType) as 'c14, + Literal.create(null, BooleanType) as 'c13, + Literal.create(null, BooleanType) as 'c14, - Literal(null, StringType) as 'c15, + Literal.create(null, StringType) as 'c15, - Literal(null, StringType) as 'c16, - Literal(null, StringType) as 'c17, - Literal(null, StringType) as 'c18, + Literal.create(null, StringType) as 'c16, + Literal.create(null, StringType) as 'c17, + Literal.create(null, StringType) as 'c18, - Literal(null, BooleanType) as 'c19, - Literal(null, BooleanType) as 'c20 + Literal.create(null, BooleanType) as 'c19, + Literal.create(null, BooleanType) as 'c20 ).analyze comparePlans(optimized, correctAnswer) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 55c6766520a1e..1448098c770aa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -432,7 +432,8 @@ class FilterPushdownSuite extends PlanTest { val originalQuery = { z.join(x.join(y)) - .where(("x.b".attr === "y.b".attr) && ("x.a".attr === 1) && ("z.a".attr >= 3) && ("z.a".attr === "x.b".attr)) + .where(("x.b".attr === "y.b".attr) && ("x.a".attr === 1) && + ("z.a".attr >= 3) && ("z.a".attr === "x.b".attr)) } val optimized = Optimize(originalQuery.analyze) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index 233e329cb2038..966bc9ada1e6e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -52,7 +52,7 @@ class OptimizeInSuite extends PlanTest { val optimized = Optimize(originalQuery.analyze) val correctAnswer = testRelation - .where(InSet(UnresolvedAttribute("a"), HashSet[Any]()+1+2)) + .where(InSet(UnresolvedAttribute("a"), HashSet[Any]() + 1 + 2)) .analyze comparePlans(optimized, correctAnswer) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 129d091ca03e3..e7cafcc96de87 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -45,12 +45,13 @@ class PlanTest extends FunSuite { protected def comparePlans(plan1: LogicalPlan, plan2: LogicalPlan) { val normalized1 = normalizeExprIds(plan1) val normalized2 = normalizeExprIds(plan2) - if (normalized1 != normalized2) + if (normalized1 != normalized2) { fail( s""" |== FAIL: Plans do not match === |${sideBySide(normalized1.treeString, normalized2.treeString).mkString("\n")} - """.stripMargin) + """.stripMargin) + } } /** Fails the test if the two expressions do not match */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala index 11e6831b24768..1273921f6394c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala @@ -32,7 +32,7 @@ class SameResultSuite extends FunSuite { val testRelation = LocalRelation('a.int, 'b.int, 'c.int) val testRelation2 = LocalRelation('a.int, 'b.int, 'c.int) - def assertSameResult(a: LogicalPlan, b: LogicalPlan, result: Boolean = true) = { + def assertSameResult(a: LogicalPlan, b: LogicalPlan, result: Boolean = true): Unit = { val aAnalyzed = a.analyze val bAnalyzed = b.analyze diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index e7ce92a2160b6..4eb8708335dcf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -25,12 +25,12 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.{StringType, NullType} case class Dummy(optKey: Option[Expression]) extends Expression { - def children = optKey.toSeq - def nullable = true - def dataType = NullType + def children: Seq[Expression] = optKey.toSeq + def nullable: Boolean = true + def dataType: NullType = NullType override lazy val resolved = true type EvaluatedType = Any - def eval(input: Row) = null.asInstanceOf[Any] + def eval(input: Row): Any = null.asInstanceOf[Any] } class TreeNodeSuite extends FunSuite { @@ -90,7 +90,7 @@ class TreeNodeSuite extends FunSuite { } test("transform works on nodes with Option children") { - val dummy1 = Dummy(Some(Literal("1", StringType))) + val dummy1 = Dummy(Some(Literal.create("1", StringType))) val dummy2 = Dummy(None) val toZero: PartialFunction[Expression, Expression] = { case Literal(_, _) => Literal(0) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala index 1ba21b64603ac..169125264a803 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala @@ -34,10 +34,12 @@ class DataTypeParserSuite extends FunSuite { } checkDataType("int", IntegerType) + checkDataType("integer", IntegerType) checkDataType("BooLean", BooleanType) checkDataType("tinYint", ByteType) checkDataType("smallINT", ShortType) checkDataType("INT", IntegerType) + checkDataType("INTEGER", IntegerType) checkDataType("bigint", LongType) checkDataType("float", FloatType) checkDataType("dOUBle", DoubleType) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 19cfa15f27b09..9b9adf855077a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -33,7 +33,7 @@ import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.python.SerDeUtil import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel -import org.apache.spark.sql.catalyst.{ScalaReflection, SqlParser} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} import org.apache.spark.sql.catalyst.analysis.{UnresolvedRelation, ResolvedStar} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.{JoinType, Inner} @@ -240,8 +240,8 @@ class DataFrame private[sql]( s"Old column names (${schema.size}): " + schema.fields.map(_.name).mkString(", ") + "\n" + s"New column names (${colNames.size}): " + colNames.mkString(", ")) - val newCols = schema.fieldNames.zip(colNames).map { case (oldName, newName) => - apply(oldName).as(newName) + val newCols = logicalPlan.output.zip(colNames).map { case (oldAttribute, newName) => + Column(oldAttribute).as(newName) } select(newCols :_*) } @@ -273,7 +273,7 @@ class DataFrame private[sql]( def printSchema(): Unit = println(schema.treeString) /** - * Prints the plans (logical and physical) to the console for debugging purpose. + * Prints the plans (logical and physical) to the console for debugging purposes. * @group basic */ def explain(extended: Boolean): Unit = { @@ -285,7 +285,7 @@ class DataFrame private[sql]( } /** - * Only prints the physical plan to the console for debugging purpose. + * Only prints the physical plan to the console for debugging purposes. * @group basic */ def explain(): Unit = explain(extended = false) @@ -713,7 +713,7 @@ class DataFrame private[sql]( val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] val attributes = schema.toAttributes val rowFunction = - f.andThen(_.map(ScalaReflection.convertToCatalyst(_, schema).asInstanceOf[Row])) + f.andThen(_.map(CatalystTypeConverters.convertToCatalyst(_, schema).asInstanceOf[Row])) val generator = UserDefinedGenerator(attributes, rowFunction, input.map(_.expr)) Generate(generator, join = true, outer = false, None, logicalPlan) @@ -734,7 +734,7 @@ class DataFrame private[sql]( val dataType = ScalaReflection.schemaFor[B].dataType val attributes = AttributeReference(outputColumn, dataType)() :: Nil def rowFunction(row: Row): TraversableOnce[Row] = { - f(row(0).asInstanceOf[A]).map(o => Row(ScalaReflection.convertToCatalyst(o, dataType))) + f(row(0).asInstanceOf[A]).map(o => Row(CatalystTypeConverters.convertToCatalyst(o, dataType))) } val generator = UserDefinedGenerator(attributes, rowFunction, apply(inputColumn).expr :: Nil) @@ -904,7 +904,8 @@ class DataFrame private[sql]( */ override def repartition(numPartitions: Int): DataFrame = { sqlContext.createDataFrame( - queryExecution.toRdd.map(_.copy()).repartition(numPartitions), schema) + queryExecution.toRdd.map(_.copy()).repartition(numPartitions), + schema, needsConversion = false) } /** @@ -960,7 +961,10 @@ class DataFrame private[sql]( lazy val rdd: RDD[Row] = { // use a local variable to make sure the map closure doesn't capture the whole DataFrame val schema = this.schema - queryExecution.executedPlan.execute().map(ScalaReflection.convertRowToScala(_, schema)) + queryExecution.executedPlan.execute().mapPartitions { rows => + val converter = CatalystTypeConverters.createToScalaConverter(schema) + rows.map(converter(_).asInstanceOf[Row]) + } } /** @@ -976,8 +980,8 @@ class DataFrame private[sql]( def javaRDD: JavaRDD[Row] = toJavaRDD /** - * Registers this RDD as a temporary table using the given name. The lifetime of this temporary - * table is tied to the [[SQLContext]] that was used to create this DataFrame. + * Registers this [[DataFrame]] as a temporary table using the given name. The lifetime of this + * temporary table is tied to the [[SQLContext]] that was used to create this DataFrame. * * @group basic */ @@ -1252,7 +1256,7 @@ class DataFrame private[sql]( //////////////////////////////////////////////////////////////////////////// /** - * Save this RDD to a JDBC database at `url` under the table name `table`. + * Save this [[DataFrame]] to a JDBC database at `url` under the table name `table`. * This will run a `CREATE TABLE` and a bunch of `INSERT INTO` statements. * If you pass `true` for `allowExisting`, it will drop any table with the * given name; if you pass `false`, it will throw if the table already @@ -1276,7 +1280,7 @@ class DataFrame private[sql]( } /** - * Save this RDD to a JDBC database at `url` under the table name `table`. + * Save this [[DataFrame]] to a JDBC database at `url` under the table name `table`. * Assumes the table already exists and has a compatible schema. If you * pass `true` for `overwrite`, it will `TRUNCATE` the table before * performing the `INSERT`s. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index bf3c3fe876873..481ed4924857e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -192,6 +192,127 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { */ def fill(valueMap: Map[String, Any]): DataFrame = fill0(valueMap.toSeq) + /** + * Replaces values matching keys in `replacement` map with the corresponding values. + * Key and value of `replacement` map must have the same type, and can only be doubles or strings. + * If `col` is "*", then the replacement is applied on all string columns or numeric columns. + * + * {{{ + * import com.google.common.collect.ImmutableMap; + * + * // Replaces all occurrences of 1.0 with 2.0 in column "height". + * df.replace("height", ImmutableMap.of(1.0, 2.0)); + * + * // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "name". + * df.replace("name", ImmutableMap.of("UNKNOWN", "unnamed")); + * + * // Replaces all occurrences of "UNKNOWN" with "unnamed" in all string columns. + * df.replace("*", ImmutableMap.of("UNKNOWN", "unnamed")); + * }}} + * + * @param col name of the column to apply the value replacement + * @param replacement value replacement map, as explained above + */ + def replace[T](col: String, replacement: java.util.Map[T, T]): DataFrame = { + replace[T](col, replacement.toMap : Map[T, T]) + } + + /** + * Replaces values matching keys in `replacement` map with the corresponding values. + * Key and value of `replacement` map must have the same type, and can only be doubles or strings. + * + * {{{ + * import com.google.common.collect.ImmutableMap; + * + * // Replaces all occurrences of 1.0 with 2.0 in column "height" and "weight". + * df.replace(new String[] {"height", "weight"}, ImmutableMap.of(1.0, 2.0)); + * + * // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "firstname" and "lastname". + * df.replace(new String[] {"firstname", "lastname"}, ImmutableMap.of("UNKNOWN", "unnamed")); + * }}} + * + * @param cols list of columns to apply the value replacement + * @param replacement value replacement map, as explained above + */ + def replace[T](cols: Array[String], replacement: java.util.Map[T, T]): DataFrame = { + replace(cols.toSeq, replacement.toMap) + } + + /** + * (Scala-specific) Replaces values matching keys in `replacement` map. + * Key and value of `replacement` map must have the same type, and can only be doubles or strings. + * If `col` is "*", then the replacement is applied on all string columns or numeric columns. + * + * {{{ + * // Replaces all occurrences of 1.0 with 2.0 in column "height". + * df.replace("height", Map(1.0 -> 2.0)) + * + * // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "name". + * df.replace("name", Map("UNKNOWN" -> "unnamed") + * + * // Replaces all occurrences of "UNKNOWN" with "unnamed" in all string columns. + * df.replace("*", Map("UNKNOWN" -> "unnamed") + * }}} + * + * @param col name of the column to apply the value replacement + * @param replacement value replacement map, as explained above + */ + def replace[T](col: String, replacement: Map[T, T]): DataFrame = { + if (col == "*") { + replace0(df.columns, replacement) + } else { + replace0(Seq(col), replacement) + } + } + + /** + * (Scala-specific) Replaces values matching keys in `replacement` map. + * Key and value of `replacement` map must have the same type, and can only be doubles or strings. + * + * {{{ + * // Replaces all occurrences of 1.0 with 2.0 in column "height" and "weight". + * df.replace("height" :: "weight" :: Nil, Map(1.0 -> 2.0)); + * + * // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "firstname" and "lastname". + * df.replace("firstname" :: "lastname" :: Nil, Map("UNKNOWN" -> "unnamed"); + * }}} + * + * @param cols list of columns to apply the value replacement + * @param replacement value replacement map, as explained above + */ + def replace[T](cols: Seq[String], replacement: Map[T, T]): DataFrame = replace0(cols, replacement) + + private def replace0[T](cols: Seq[String], replacement: Map[T, T]): DataFrame = { + if (replacement.isEmpty || cols.isEmpty) { + return df + } + + // replacementMap is either Map[String, String] or Map[Double, Double] + val replacementMap: Map[_, _] = replacement.head._2 match { + case v: String => replacement + case _ => replacement.map { case (k, v) => (convertToDouble(k), convertToDouble(v)) } + } + + // targetColumnType is either DoubleType or StringType + val targetColumnType = replacement.head._1 match { + case _: jl.Double | _: jl.Float | _: jl.Integer | _: jl.Long => DoubleType + case _: String => StringType + } + + val columnEquals = df.sqlContext.analyzer.resolver + val projections = df.schema.fields.map { f => + val shouldReplace = cols.exists(colName => columnEquals(colName, f.name)) + if (f.dataType.isInstanceOf[NumericType] && targetColumnType == DoubleType && shouldReplace) { + replaceCol(f, replacementMap) + } else if (f.dataType == targetColumnType && shouldReplace) { + replaceCol(f, replacementMap) + } else { + df.col(f.name) + } + } + df.select(projections : _*) + } + private def fill0(values: Seq[(String, Any)]): DataFrame = { // Error handling values.foreach { case (colName, replaceValue) => @@ -228,4 +349,27 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { private def fillCol[T](col: StructField, replacement: T): Column = { coalesce(df.col(col.name), lit(replacement).cast(col.dataType)).as(col.name) } + + /** + * Returns a [[Column]] expression that replaces value matching key in `replacementMap` with + * value in `replacementMap`, using [[CaseWhen]]. + * + * TODO: This can be optimized to use broadcast join when replacementMap is large. + */ + private def replaceCol(col: StructField, replacementMap: Map[_, _]): Column = { + val branches: Seq[Expression] = replacementMap.flatMap { case (source, target) => + df.col(col.name).equalTo(lit(source).cast(col.dataType)).expr :: + lit(target).cast(col.dataType).expr :: Nil + }.toSeq + new Column(CaseWhen(branches ++ Seq(df.col(col.name).expr))).as(col.name) + } + + private def convertToDouble(v: Any): Double = v match { + case v: Float => v.toDouble + case v: Double => v + case v: Long => v.toDouble + case v: Int => v.toDouble + case v => throw new IllegalArgumentException( + s"Unsupported value type ${v.getClass.getName} ($v).") + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index a5e6b638d2150..53ad67372e024 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.types.NumericType @Experimental class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) { - private[this] implicit def toDF(aggExprs: Seq[NamedExpression]): DataFrame = { + private[sql] implicit def toDF(aggExprs: Seq[NamedExpression]): DataFrame = { val namedGroupingExprs = groupingExprs.map { case expr: NamedExpression => expr case expr: Expression => Alias(expr, expr.prettyString)() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 4815620c6fe57..ee641bdfeb2d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -39,6 +39,8 @@ private[spark] object SQLConf { val PARQUET_FILTER_PUSHDOWN_ENABLED = "spark.sql.parquet.filterPushdown" val PARQUET_USE_DATA_SOURCE_API = "spark.sql.parquet.useDataSourceApi" + val HIVE_VERIFY_PARTITIONPATH = "spark.sql.hive.verifyPartitionPath" + val COLUMN_NAME_OF_CORRUPT_RECORD = "spark.sql.columnNameOfCorruptRecord" val BROADCAST_TIMEOUT = "spark.sql.broadcastTimeout" @@ -119,6 +121,10 @@ private[sql] class SQLConf extends Serializable { private[spark] def parquetUseDataSourceApi = getConf(PARQUET_USE_DATA_SOURCE_API, "true").toBoolean + /** When true uses verifyPartitionPath to prune the path which is not exists. */ + private[spark] def verifyPartitionPath = + getConf(HIVE_VERIFY_PARTITIONPATH, "true").toBoolean + /** When true the planner will use the external sort, which may spill to disk. */ private[spark] def externalSortEnabled: Boolean = getConf(EXTERNAL_SORT, "false").toBoolean diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 1794936a52c6d..c25ef58e6f62a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -31,9 +31,9 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer} -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, OneRowRelation} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.catalyst.{ScalaReflection, expressions} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, expressions} import org.apache.spark.sql.execution.{Filter, _} import org.apache.spark.sql.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation} import org.apache.spark.sql.json._ @@ -392,9 +392,24 @@ class SQLContext(@transient val sparkContext: SparkContext) */ @DeveloperApi def createDataFrame(rowRDD: RDD[Row], schema: StructType): DataFrame = { + createDataFrame(rowRDD, schema, needsConversion = true) + } + + /** + * Creates a DataFrame from an RDD[Row]. User can specify whether the input rows should be + * converted to Catalyst rows. + */ + private[sql] + def createDataFrame(rowRDD: RDD[Row], schema: StructType, needsConversion: Boolean) = { // TODO: use MutableProjection when rowRDD is another DataFrame and the applied // schema differs from the existing schema on any field data type. - val logicalPlan = LogicalRDD(schema.toAttributes, rowRDD)(self) + val catalystRows = if (needsConversion) { + val converter = CatalystTypeConverters.createToCatalystConverter(schema) + rowRDD.map(converter(_).asInstanceOf[Row]) + } else { + rowRDD + } + val logicalPlan = LogicalRDD(schema.toAttributes, catalystRows)(self) DataFrame(this, logicalPlan) } @@ -445,7 +460,7 @@ class SQLContext(@transient val sparkContext: SparkContext) iter.map { row => new GenericRow( extractors.zip(attributeSeq).map { case (e, attr) => - DataTypeConversions.convertJavaToCatalyst(e.invoke(row), attr.dataType) + CatalystTypeConverters.convertToCatalyst(e.invoke(row), attr.dataType) }.toArray[Any] ) : Row } @@ -604,7 +619,7 @@ class SQLContext(@transient val sparkContext: SparkContext) JsonRDD.nullTypeToStringType( JsonRDD.inferSchema(json, 1.0, columnNameOfCorruptJsonRecord))) val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord) - createDataFrame(rowRDD, appliedSchema) + createDataFrame(rowRDD, appliedSchema, needsConversion = false) } /** @@ -633,7 +648,7 @@ class SQLContext(@transient val sparkContext: SparkContext) JsonRDD.nullTypeToStringType( JsonRDD.inferSchema(json, samplingRatio, columnNameOfCorruptJsonRecord)) val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord) - createDataFrame(rowRDD, appliedSchema) + createDataFrame(rowRDD, appliedSchema, needsConversion = false) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala new file mode 100644 index 0000000000000..d1ea7cc3e9162 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.r + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} + +import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} +import org.apache.spark.api.r.SerDe +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression} +import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.{Column, DataFrame, GroupedData, Row, SQLContext, SaveMode} + +private[r] object SQLUtils { + def createSQLContext(jsc: JavaSparkContext): SQLContext = { + new SQLContext(jsc) + } + + def getJavaSparkContext(sqlCtx: SQLContext): JavaSparkContext = { + new JavaSparkContext(sqlCtx.sparkContext) + } + + def toSeq[T](arr: Array[T]): Seq[T] = { + arr.toSeq + } + + def createDF(rdd: RDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = { + val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] + val num = schema.fields.size + val rowRDD = rdd.map(bytesToRow) + sqlContext.createDataFrame(rowRDD, schema) + } + + // A helper to include grouping columns in Agg() + def aggWithGrouping(gd: GroupedData, exprs: Column*): DataFrame = { + val aggExprs = exprs.map { col => + col.expr match { + case expr: NamedExpression => expr + case expr: Expression => Alias(expr, expr.simpleString)() + } + } + gd.toDF(aggExprs) + } + + def dfToRowRDD(df: DataFrame): JavaRDD[Array[Byte]] = { + df.map(r => rowToRBytes(r)) + } + + private[this] def bytesToRow(bytes: Array[Byte]): Row = { + val bis = new ByteArrayInputStream(bytes) + val dis = new DataInputStream(bis) + val num = SerDe.readInt(dis) + Row.fromSeq((0 until num).map { i => + SerDe.readObject(dis) + }.toSeq) + } + + private[this] def rowToRBytes(row: Row): Array[Byte] = { + val bos = new ByteArrayOutputStream() + val dos = new DataOutputStream(bos) + + SerDe.writeInt(dos, row.length) + (0 until row.length).map { idx => + val obj: Object = row(idx).asInstanceOf[Object] + SerDe.writeObject(dos, obj) + } + bos.toByteArray() + } + + def dfToCols(df: DataFrame): Array[Array[Byte]] = { + // localDF is Array[Row] + val localDF = df.collect() + val numCols = df.columns.length + // dfCols is Array[Array[Any]] + val dfCols = convertRowsToColumns(localDF, numCols) + + dfCols.map { col => + colToRBytes(col) + } + } + + def convertRowsToColumns(localDF: Array[Row], numCols: Int): Array[Array[Any]] = { + (0 until numCols).map { colIdx => + localDF.map { row => + row(colIdx) + } + }.toArray + } + + def colToRBytes(col: Array[Any]): Array[Byte] = { + val numRows = col.length + val bos = new ByteArrayOutputStream() + val dos = new DataOutputStream(bos) + + SerDe.writeInt(dos, numRows) + + col.map { item => + val obj: Object = item.asInstanceOf[Object] + SerDe.writeObject(dos, obj) + } + bos.toByteArray() + } + + def saveMode(mode: String): SaveMode = { + mode match { + case "append" => SaveMode.Append + case "overwrite" => SaveMode.Overwrite + case "error" => SaveMode.ErrorIfExists + case "ignore" => SaveMode.Ignore + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index d8955725e59b1..656bdd7212f56 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -20,14 +20,12 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SQLContext} -import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow} +import org.apache.spark.sql.catalyst.expressions.{SpecificMutableRow, Attribute} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} import org.apache.spark.sql.types.StructType -import scala.collection.immutable - /** * :: DeveloperApi :: */ @@ -39,13 +37,15 @@ object RDDConversions { Iterator.empty } else { val bufferedIterator = iterator.buffered - val mutableRow = new GenericMutableRow(bufferedIterator.head.productArity) + val mutableRow = new SpecificMutableRow(schema.fields.map(_.dataType)) val schemaFields = schema.fields.toArray + val converters = schemaFields.map { + f => CatalystTypeConverters.createToCatalystConverter(f.dataType) + } bufferedIterator.map { r => var i = 0 while (i < mutableRow.length) { - mutableRow(i) = - ScalaReflection.convertToCatalyst(r.productElement(i), schemaFields(i).dataType) + mutableRow(i) = converters(i)(r.productElement(i)) i += 1 } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index 89682d25ca7dc..95176e425132d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -68,6 +68,8 @@ case class GeneratedAggregate( a.collect { case agg: AggregateExpression => agg} } + // If you add any new function support, please add tests in org.apache.spark.sql.SQLQuerySuite + // (in test "aggregation with codegen"). val computeFunctions = aggregatesToCompute.map { case c @ Count(expr) => // If we're evaluating UnscaledValue(x), we can do Count on x directly, since its @@ -93,13 +95,16 @@ case class GeneratedAggregate( } val currentSum = AttributeReference("currentSum", calcType, nullable = true)() - val initialValue = Literal(null, calcType) + val initialValue = Literal.create(null, calcType) - // Coalasce avoids double calculation... + // Coalesce avoids double calculation... // but really, common sub expression elimination would be better.... val zero = Cast(Literal(0), calcType) val updateFunction = Coalesce( - Add(Coalesce(currentSum :: zero :: Nil), Cast(expr, calcType)) :: currentSum :: Nil) + Add( + Coalesce(currentSum :: zero :: Nil), + Cast(expr, calcType) + ) :: currentSum :: zero :: Nil) val result = expr.dataType match { case DecimalType.Fixed(_, _) => @@ -109,6 +114,45 @@ case class GeneratedAggregate( AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result) + case cs @ CombineSum(expr) => + val calcType = expr.dataType + expr.dataType match { + case DecimalType.Fixed(_, _) => + DecimalType.Unlimited + case _ => + expr.dataType + } + + val currentSum = AttributeReference("currentSum", calcType, nullable = true)() + val initialValue = Literal.create(null, calcType) + + // Coalasce avoids double calculation... + // but really, common sub expression elimination would be better.... + val zero = Cast(Literal(0), calcType) + // If we're evaluating UnscaledValue(x), we can do Count on x directly, since its + // UnscaledValue will be null if and only if x is null; helps with Average on decimals + val actualExpr = expr match { + case UnscaledValue(e) => e + case _ => expr + } + // partial sum result can be null only when no input rows present + val updateFunction = If( + IsNotNull(actualExpr), + Coalesce( + Add( + Coalesce(currentSum :: zero :: Nil), + Cast(expr, calcType)) :: currentSum :: zero :: Nil), + currentSum) + + val result = + expr.dataType match { + case DecimalType.Fixed(_, _) => + Cast(currentSum, cs.dataType) + case _ => currentSum + } + + AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result) + case a @ Average(expr) => val calcType = expr.dataType match { @@ -137,13 +181,13 @@ case class GeneratedAggregate( expr.dataType match { case DecimalType.Fixed(_, _) => If(EqualTo(currentCount, Literal(0L)), - Literal(null, a.dataType), + Literal.create(null, a.dataType), Cast(Divide( Cast(currentSum, DecimalType.Unlimited), Cast(currentCount, DecimalType.Unlimited)), a.dataType)) case _ => If(EqualTo(currentCount, Literal(0L)), - Literal(null, a.dataType), + Literal.create(null, a.dataType), Divide(Cast(currentSum, a.dataType), Cast(currentCount, a.dataType))) } @@ -156,7 +200,7 @@ case class GeneratedAggregate( case m @ Max(expr) => val currentMax = AttributeReference("currentMax", expr.dataType, nullable = true)() - val initialValue = Literal(null, expr.dataType) + val initialValue = Literal.create(null, expr.dataType) val updateMax = MaxOf(currentMax, expr) AggregateEvaluation( @@ -166,7 +210,8 @@ case class GeneratedAggregate( currentMax) case CollectHashSet(Seq(expr)) => - val set = AttributeReference("hashSet", ArrayType(expr.dataType), nullable = false)() + val set = + AttributeReference("hashSet", new OpenHashSetUDT(expr.dataType), nullable = false)() val initialValue = NewSet(expr.dataType) val addToSet = AddItemToSet(expr, set) @@ -177,9 +222,10 @@ case class GeneratedAggregate( set) case CombineSetsAndCount(inputSet) => - val ArrayType(inputType, _) = inputSet.dataType - val set = AttributeReference("hashSet", inputSet.dataType, nullable = false)() - val initialValue = NewSet(inputType) + val elementType = inputSet.dataType.asInstanceOf[OpenHashSetUDT].elementType + val set = + AttributeReference("hashSet", new OpenHashSetUDT(elementType), nullable = false)() + val initialValue = NewSet(elementType) val collectSets = CombineSets(set, inputSet) AggregateEvaluation( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala index 5bd699a2fa949..8a8c3a404323a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.expressions.Attribute @@ -32,9 +32,15 @@ case class LocalTableScan(output: Seq[Attribute], rows: Seq[Row]) extends LeafNo override def execute(): RDD[Row] = rdd - override def executeCollect(): Array[Row] = - rows.map(ScalaReflection.convertRowToScala(_, schema)).toArray - override def executeTake(limit: Int): Array[Row] = - rows.map(ScalaReflection.convertRowToScala(_, schema)).take(limit).toArray + override def executeCollect(): Array[Row] = { + val converter = CatalystTypeConverters.createToScalaConverter(schema) + rows.map(converter(_).asInstanceOf[Row]).toArray + } + + + override def executeTake(limit: Int): Array[Row] = { + val converter = CatalystTypeConverters.createToScalaConverter(schema) + rows.map(converter(_).asInstanceOf[Row]).take(limit).toArray + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index d239637cd4b4e..fabcf6b4a0570 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -21,7 +21,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.catalyst.{ScalaReflection, trees} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, trees} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.QueryPlan @@ -80,8 +80,12 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ /** * Runs this query returning the result as an array. */ + def executeCollect(): Array[Row] = { - execute().map(ScalaReflection.convertRowToScala(_, schema)).collect() + execute().mapPartitions { iter => + val converter = CatalystTypeConverters.createToScalaConverter(schema) + iter.map(converter(_).asInstanceOf[Row]) + }.collect() } /** @@ -125,7 +129,8 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ partsScanned += numPartsToTry } - buf.toArray.map(ScalaReflection.convertRowToScala(_, this.schema)) + val converter = CatalystTypeConverters.createToScalaConverter(schema) + buf.toArray.map(converter(_).asInstanceOf[Row]) } protected def newProjection( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala index 967bd76b302d8..914f387dec78f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution import java.nio.ByteBuffer +import java.util.{HashMap => JavaHashMap} import org.apache.spark.sql.types.Decimal @@ -26,14 +27,13 @@ import scala.reflect.ClassTag import com.clearspring.analytics.stream.cardinality.HyperLogLog import com.esotericsoftware.kryo.io.{Input, Output} import com.esotericsoftware.kryo.{Serializer, Kryo} -import com.twitter.chill.{AllScalaRegistrar, ResourcePool} +import com.twitter.chill.ResourcePool import org.apache.spark.{SparkEnv, SparkConf} import org.apache.spark.serializer.{SerializerInstance, KryoSerializer} import org.apache.spark.sql.catalyst.expressions.GenericRow import org.apache.spark.util.collection.OpenHashSet import org.apache.spark.util.MutablePair -import org.apache.spark.util.Utils import org.apache.spark.sql.catalyst.expressions.codegen.{IntegerHashSet, LongHashSet} @@ -55,6 +55,7 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co kryo.register(classOf[org.apache.spark.util.collection.OpenHashSet[_]], new OpenHashSetSerializer) kryo.register(classOf[Decimal]) + kryo.register(classOf[JavaHashMap[_, _]]) kryo.setReferences(false) kryo diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index f754fa770d1b5..23f7e5609414b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -155,7 +155,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } def canBeCodeGened(aggs: Seq[AggregateExpression]): Boolean = !aggs.exists { - case _: Sum | _: Count | _: Max | _: CombineSetsAndCount => false + case _: CombineSum | _: Sum | _: Count | _: Max | _: CombineSetsAndCount => false // The generated set implementation is pretty limited ATM. case CollectHashSet(exprs) if exprs.size == 1 && Seq(IntegerType, LongType).contains(exprs.head.dataType) => false diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 1f5251a20376f..6eec520abff53 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -21,7 +21,7 @@ import org.apache.spark.{SparkEnv, HashPartitioner, SparkConf} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.{RDD, ShuffledRDD} import org.apache.spark.shuffle.sort.SortShuffleManager -import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ @@ -139,9 +139,10 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) private def collectData(): Array[Row] = child.execute().map(_.copy()).takeOrdered(limit)(ord) - // TODO: Is this copying for no reason? - override def executeCollect(): Array[Row] = - collectData().map(ScalaReflection.convertRowToScala(_, this.schema)) + override def executeCollect(): Array[Row] = { + val converter = CatalystTypeConverters.createToScalaConverter(schema) + collectData().map(converter(_).asInstanceOf[Row]) + } // TODO: Terminal split should be implemented differently from non-terminal split. // TODO: Pick num splits based on |limit|. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 2fa1cf5add3b5..ab84c123e0c0b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -17,9 +17,11 @@ package org.apache.spark.sql.execution.joins +import java.io.{ObjectInput, ObjectOutput, Externalizable} import java.util.{HashMap => JavaHashMap} import org.apache.spark.sql.catalyst.expressions.{Projection, Row} +import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.util.collection.CompactBuffer @@ -29,16 +31,43 @@ import org.apache.spark.util.collection.CompactBuffer */ private[joins] sealed trait HashedRelation { def get(key: Row): CompactBuffer[Row] + + // This is a helper method to implement Externalizable, and is used by + // GeneralHashedRelation and UniqueKeyHashedRelation + protected def writeBytes(out: ObjectOutput, serialized: Array[Byte]): Unit = { + out.writeInt(serialized.length) // Write the length of serialized bytes first + out.write(serialized) + } + + // This is a helper method to implement Externalizable, and is used by + // GeneralHashedRelation and UniqueKeyHashedRelation + protected def readBytes(in: ObjectInput): Array[Byte] = { + val serializedSize = in.readInt() // Read the length of serialized bytes first + val bytes = new Array[Byte](serializedSize) + in.readFully(bytes) + bytes + } } /** * A general [[HashedRelation]] backed by a hash map that maps the key into a sequence of values. */ -private[joins] final class GeneralHashedRelation(hashTable: JavaHashMap[Row, CompactBuffer[Row]]) - extends HashedRelation with Serializable { +private[joins] final class GeneralHashedRelation( + private var hashTable: JavaHashMap[Row, CompactBuffer[Row]]) + extends HashedRelation with Externalizable { + + def this() = this(null) // Needed for serialization override def get(key: Row): CompactBuffer[Row] = hashTable.get(key) + + override def writeExternal(out: ObjectOutput): Unit = { + writeBytes(out, SparkSqlSerializer.serialize(hashTable)) + } + + override def readExternal(in: ObjectInput): Unit = { + hashTable = SparkSqlSerializer.deserialize(readBytes(in)) + } } @@ -46,8 +75,10 @@ private[joins] final class GeneralHashedRelation(hashTable: JavaHashMap[Row, Com * A specialized [[HashedRelation]] that maps key into a single value. This implementation * assumes the key is unique. */ -private[joins] final class UniqueKeyHashedRelation(hashTable: JavaHashMap[Row, Row]) - extends HashedRelation with Serializable { +private[joins] final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[Row, Row]) + extends HashedRelation with Externalizable { + + def this() = this(null) // Needed for serialization override def get(key: Row): CompactBuffer[Row] = { val v = hashTable.get(key) @@ -55,6 +86,14 @@ private[joins] final class UniqueKeyHashedRelation(hashTable: JavaHashMap[Row, R } def getValue(key: Row): Row = hashTable.get(key) + + override def writeExternal(out: ObjectOutput): Unit = { + writeBytes(out, SparkSqlSerializer.serialize(hashTable)) + } + + override def readExternal(in: ObjectInput): Unit = { + hashTable = SparkSqlSerializer.deserialize(readBytes(in)) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 111e751588a8b..ff91e1d74bc2c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -22,7 +22,7 @@ import scala.reflect.runtime.universe.{TypeTag, typeTag} import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.analysis.Star +import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -605,4 +605,23 @@ object functions { } // scalastyle:on + + /** + * Call an user-defined function. + * Example: + * {{{ + * import org.apache.spark.sql._ + * + * val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") + * val sqlContext = df.sqlContext + * sqlContext.udf.register("simpleUdf", (v: Int) => v * v) + * df.select($"id", callUdf("simpleUdf", $"value")) + * }}} + * + * @group udf_funcs + */ + def callUdf(udfName: String, cols: Column*): Column = { + UnresolvedFunction(udfName, cols.map(_.expr)) + } + } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala index 1704be7fcbd30..0feabc4282f4a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala @@ -49,9 +49,9 @@ private[sql] object DriverQuirks { * Fetch the DriverQuirks class corresponding to a given database url. */ def get(url: String): DriverQuirks = { - if (url.substring(0, 10).equals("jdbc:mysql")) { + if (url.startsWith("jdbc:mysql")) { new MySQLQuirks() - } else if (url.substring(0, 15).equals("jdbc:postgresql")) { + } else if (url.startsWith("jdbc:postgresql")) { new PostgresQuirks() } else { new NoQuirks() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index 0b770f2251943..b1e8521383756 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -391,7 +391,7 @@ private[sql] object JsonRDD extends Logging { value match { // only support string as date case value: java.lang.String => - DateUtils.millisToDays(DataTypeConversions.stringToTime(value).getTime) + DateUtils.millisToDays(DateUtils.stringToTime(value).getTime) case value: java.sql.Date => DateUtils.fromJavaDate(value) } } @@ -400,7 +400,7 @@ private[sql] object JsonRDD extends Logging { value match { case value: java.lang.Integer => new Timestamp(value.asInstanceOf[Int].toLong) case value: java.lang.Long => new Timestamp(value) - case value: java.lang.String => toTimestamp(DataTypeConversions.stringToTime(value).getTime) + case value: java.lang.String => toTimestamp(DateUtils.stringToTime(value).getTime) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala new file mode 100644 index 0000000000000..25a66cb488103 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parquet + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} +import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter + +import parquet.Log +import parquet.hadoop.util.ContextUtil +import parquet.hadoop.{ParquetFileReader, ParquetFileWriter, ParquetOutputCommitter} + +private[parquet] class DirectParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) + extends ParquetOutputCommitter(outputPath, context) { + val LOG = Log.getLog(classOf[ParquetOutputCommitter]) + + override def getWorkPath(): Path = outputPath + override def abortTask(taskContext: TaskAttemptContext): Unit = {} + override def commitTask(taskContext: TaskAttemptContext): Unit = {} + override def needsTaskCommit(taskContext: TaskAttemptContext): Boolean = true + override def setupJob(jobContext: JobContext): Unit = {} + override def setupTask(taskContext: TaskAttemptContext): Unit = {} + + override def commitJob(jobContext: JobContext) { + try { + val configuration = ContextUtil.getConfiguration(jobContext) + val fileSystem = outputPath.getFileSystem(configuration) + val outputStatus = fileSystem.getFileStatus(outputPath) + val footers = ParquetFileReader.readAllFootersInParallel(configuration, outputStatus) + try { + ParquetFileWriter.writeMetadataFile(configuration, outputPath, footers) + if (configuration.getBoolean("mapreduce.fileoutputcommitter.marksuccessfuljobs", true)) { + val successPath = new Path(outputPath, FileOutputCommitter.SUCCEEDED_FILE_NAME) + fileSystem.create(successPath).close() + } + } catch { + case e: Exception => { + LOG.warn("could not write summary file for " + outputPath, e) + val metadataPath = new Path(outputPath, ParquetFileWriter.PARQUET_METADATA_FILE) + if (fileSystem.exists(metadataPath)) { + fileSystem.delete(metadataPath, true) + } + } + } + } catch { + case e: Exception => LOG.warn("could not write summary file for " + outputPath, e) + } + } + +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index 1c868da23e060..3724bda829d30 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -379,6 +379,8 @@ private[sql] case class InsertIntoParquetTable( */ private[parquet] class AppendingParquetOutputFormat(offset: Int) extends parquet.hadoop.ParquetOutputFormat[Row] { + var committer: OutputCommitter = null + // override to accept existing directories as valid output directory override def checkOutputSpecs(job: JobContext): Unit = {} @@ -403,6 +405,26 @@ private[parquet] class AppendingParquetOutputFormat(offset: Int) private def getTaskAttemptID(context: TaskAttemptContext): TaskAttemptID = { context.getClass.getMethod("getTaskAttemptID").invoke(context).asInstanceOf[TaskAttemptID] } + + // override to create output committer from configuration + override def getOutputCommitter(context: TaskAttemptContext): OutputCommitter = { + if (committer == null) { + val output = getOutputPath(context) + val cls = context.getConfiguration.getClass("spark.sql.parquet.output.committer.class", + classOf[ParquetOutputCommitter], classOf[ParquetOutputCommitter]) + val ctor = cls.getDeclaredConstructor(classOf[Path], classOf[TaskAttemptContext]) + committer = ctor.newInstance(output, context).asInstanceOf[ParquetOutputCommitter] + } + committer + } + + // FileOutputFormat.getOutputPath takes JobConf in hadoop-1 but JobContext in hadoop-2 + private def getOutputPath(context: TaskAttemptContext): Path = { + context.getConfiguration().get("mapred.output.dir") match { + case null => null + case name => new Path(name) + } + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index da668f068613b..60e1bec4db8e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -390,6 +390,7 @@ private[parquet] object ParquetTypesConverter extends Logging { def convertFromAttributes(attributes: Seq[Attribute], toThriftSchemaNames: Boolean = false): MessageType = { + checkSpecialCharacters(attributes) val fields = attributes.map( attribute => fromDataType(attribute.dataType, attribute.name, attribute.nullable, @@ -404,7 +405,20 @@ private[parquet] object ParquetTypesConverter extends Logging { } } + private def checkSpecialCharacters(schema: Seq[Attribute]) = { + // ,;{}()\n\t= and space character are special characters in Parquet schema + schema.map(_.name).foreach { name => + if (name.matches(".*[ ,;{}()\n\t=].*")) { + sys.error( + s"""Attribute name "$name" contains invalid character(s) among " ,;{}()\n\t=". + |Please use alias to rename it. + """.stripMargin.split("\n").mkString(" ")) + } + } + } + def convertToString(schema: Seq[Attribute]): String = { + checkSpecialCharacters(schema) StructType.fromAttributes(schema).json } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index 19800ad88c031..20fdf5e58ef82 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -122,7 +122,8 @@ private[sql] class DefaultSource val df = sqlContext.createDataFrame( data.queryExecution.toRdd, - data.schema.asNullable) + data.schema.asNullable, + needsConversion = false) val createdRelation = createRelation(sqlContext, parameters, df.schema).asInstanceOf[ParquetRelation2] createdRelation.insert(df, overwrite = mode == SaveMode.Overwrite) @@ -267,7 +268,8 @@ private[sql] case class ParquetRelation2( // containing Parquet files (e.g. partitioned Parquet table). val baseStatuses = paths.distinct.map { p => val fs = FileSystem.get(URI.create(p), sparkContext.hadoopConfiguration) - val qualified = fs.makeQualified(new Path(p)) + val path = new Path(p) + val qualified = path.makeQualified(fs.getUri, fs.getWorkingDirectory) if (!fs.exists(qualified) && maybeSchema.isDefined) { fs.mkdirs(qualified) @@ -430,20 +432,24 @@ private[sql] case class ParquetRelation2( // FileInputFormat cannot handle empty lists. if (selectedFiles.nonEmpty) { - FileInputFormat.setInputPaths(job, selectedFiles.map(_.getPath): _*) + // In order to encode the authority of a Path containning special characters such as /, + // we need to use the string retruned by the URI of the path to create a new Path. + val selectedPaths = selectedFiles.map(status => new Path(status.getPath.toUri.toString)) + FileInputFormat.setInputPaths(job, selectedPaths: _*) } - // Push down filters when possible. Notice that not all filters can be converted to Parquet - // filter predicate. Here we try to convert each individual predicate and only collect those - // convertible ones. + // Try to push down filters when filter push-down is enabled. if (sqlContext.conf.parquetFilterPushDown) { + val partitionColNames = partitionColumns.map(_.name).toSet predicates // Don't push down predicates which reference partition columns .filter { pred => - val partitionColNames = partitionColumns.map(_.name).toSet val referencedColNames = pred.references.map(_.name).toSet referencedColNames.intersect(partitionColNames).isEmpty } + // Collects all converted Parquet filter predicates. Notice that not all predicates can be + // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` + // is used here. .flatMap(ParquetFilters.createFilter) .reduceOption(FilterApi.and) .foreach(ParquetInputFormat.setFilterPredicate(jobConf, _)) @@ -481,10 +487,31 @@ private[sql] case class ParquetRelation2( val cacheMetadata = useCache @transient - val cachedStatus = selectedFiles + val cachedStatus = selectedFiles.map { st => + // In order to encode the authority of a Path containning special characters such as /, + // we need to use the string retruned by the URI of the path to create a new Path. + val newPath = new Path(st.getPath.toUri.toString) + + new FileStatus( + st.getLen, + st.isDir, + st.getReplication, + st.getBlockSize, + st.getModificationTime, + st.getAccessTime, + st.getPermission, + st.getOwner, + st.getGroup, + newPath) + } @transient - val cachedFooters = selectedFooters + val cachedFooters = selectedFooters.map { f => + // In order to encode the authority of a Path containning special characters such as /, + // we need to use the string retruned by the URI of the path to create a new Path. + new Footer(new Path(f.getFile.toUri.toString), f.getParquetMetadata) + } + // Overridden so we can inject our own cached files statuses. override def getPartitions: Array[SparkPartition] = { @@ -872,9 +899,9 @@ private[sql] object ParquetRelation2 extends Logging { * PartitionValues( * Seq("a", "b", "c"), * Seq( - * Literal(42, IntegerType), - * Literal("hello", StringType), - * Literal(3.14, FloatType))) + * Literal.create(42, IntegerType), + * Literal.create("hello", StringType), + * Literal.create(3.14, FloatType))) * }}} */ private[parquet] def parsePartition( @@ -953,15 +980,16 @@ private[sql] object ParquetRelation2 extends Logging { raw: String, defaultPartitionName: String): Literal = { // First tries integral types - Try(Literal(Integer.parseInt(raw), IntegerType)) - .orElse(Try(Literal(JLong.parseLong(raw), LongType))) + Try(Literal.create(Integer.parseInt(raw), IntegerType)) + .orElse(Try(Literal.create(JLong.parseLong(raw), LongType))) // Then falls back to fractional types - .orElse(Try(Literal(JFloat.parseFloat(raw), FloatType))) - .orElse(Try(Literal(JDouble.parseDouble(raw), DoubleType))) - .orElse(Try(Literal(new JBigDecimal(raw), DecimalType.Unlimited))) + .orElse(Try(Literal.create(JFloat.parseFloat(raw), FloatType))) + .orElse(Try(Literal.create(JDouble.parseDouble(raw), DoubleType))) + .orElse(Try(Literal.create(new JBigDecimal(raw), DecimalType.Unlimited))) // Then falls back to string .getOrElse { - if (raw == defaultPartitionName) Literal(null, NullType) else Literal(raw, StringType) + if (raw == defaultPartitionName) Literal.create(null, NullType) + else Literal.create(raw, StringType) } } @@ -980,7 +1008,7 @@ private[sql] object ParquetRelation2 extends Logging { } literals.map { case l @ Literal(_, dataType) => - Literal(Cast(l, desiredType).eval(), desiredType) + Literal.create(Cast(l, desiredType).eval(), desiredType) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala index 9bbe06e59ba30..dbdb0d39c26a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala @@ -31,7 +31,8 @@ private[sql] case class InsertIntoDataSource( val relation = logicalRelation.relation.asInstanceOf[InsertableRelation] val data = DataFrame(sqlContext, query) // Apply the schema of the existing table to the new data. - val df = sqlContext.createDataFrame(data.queryExecution.toRdd, logicalRelation.schema) + val df = sqlContext.createDataFrame( + data.queryExecution.toRdd, logicalRelation.schema, needsConversion = false) relation.insert(df, overwrite) // Invalidate the cache. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala index eb46b46ca5bf4..319de710fbc3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -204,7 +204,7 @@ private[sql] object ResolvedDataSource { provider: String, options: Map[String, String]): ResolvedDataSource = { val clazz: Class[_] = lookupDataSource(provider) - def className = clazz.getCanonicalName + def className: String = clazz.getCanonicalName val relation = userSpecifiedSchema match { case Some(schema: StructType) => clazz.newInstance() match { case dataSource: SchemaRelationProvider => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala index c11d0ae5bf1cc..2fdd798b44bb6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.types._ * @param y y coordinate */ @SQLUserDefinedType(udt = classOf[ExamplePointUDT]) -private[sql] class ExamplePoint(val x: Double, val y: Double) +private[sql] class ExamplePoint(val x: Double, val y: Double) extends Serializable /** * User-defined type for [[ExamplePoint]]. diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 1ff2d5a190521..6d0fbe83c2f36 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -20,6 +20,8 @@ import java.io.Serializable; import java.util.Arrays; +import scala.collection.Seq; + import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -127,6 +129,12 @@ public void testCreateDataFrameFromJavaBeans() { schema.apply("b")); Row first = df.select("a", "b").first(); Assert.assertEquals(bean.getA(), first.getDouble(0), 0.0); - Assert.assertArrayEquals(bean.getB(), first.getAs(1)); + // Now Java lists and maps are converetd to Scala Seq's and Map's. Once we get a Seq below, + // verify that it has the expected length, and contains expected elements. + Seq result = first.getAs(1); + Assert.assertEquals(bean.getB().length, result.length()); + for (int i = 0; i < result.length(); i++) { + Assert.assertEquals(bean.getB()[i], result.apply(i)); + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index c240f2be955ca..f7b5f08beb92f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -92,7 +92,8 @@ class CachedTableSuite extends QueryTest { test("too big for memory") { val data = "*" * 10000 - sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).toDF().registerTempTable("bigData") + sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).toDF() + .registerTempTable("bigData") table("bigData").persist(StorageLevel.MEMORY_AND_DISK) assert(table("bigData").count() === 200000L) table("bigData").unpersist(blocking = true) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index 0896f175c056f..41b4f02e6a294 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -154,4 +154,38 @@ class DataFrameNaFunctionsSuite extends QueryTest { ))), Row("test", null, 1, 2.2)) } + + test("replace") { + val input = createDF() + + // Replace two numeric columns: age and height + val out = input.na.replace(Seq("age", "height"), Map( + 16 -> 61, + 60 -> 6, + 164.3 -> 461.3 // Alice is really tall + )) + + checkAnswer( + out, + Row("Bob", 61, 176.5) :: + Row("Alice", null, 461.3) :: + Row("David", 6, null) :: + Row("Amy", null, null) :: + Row(null, null, null) :: Nil) + + // Replace only the age column + val out1 = input.na.replace("age", Map( + 16 -> 61, + 60 -> 6, + 164.3 -> 461.3 // Alice is really tall + )) + + checkAnswer( + out1, + Row("Bob", 61, 176.5) :: + Row("Alice", null, 164.3) :: + Row("David", 6, null) :: + Row("Amy", null, null) :: + Row(null, null, null) :: Nil) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 6761d996fd975..b26e22f6229fe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -21,7 +21,7 @@ import scala.language.postfixOps import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, TestSQLContext} import org.apache.spark.sql.test.TestSQLContext.logicalPlanToSparkQuery import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.test.TestSQLContext.sql @@ -60,6 +60,14 @@ class DataFrameSuite extends QueryTest { assert($"test".toString === "test") } + test("rename nested groupby") { + val df = Seq((1,(1,1))).toDF() + + checkAnswer( + df.groupBy("_1").agg(col("_1"), sum("_2._1")).toDF("key", "total"), + Row(1, 1) :: Nil) + } + test("invalid plan toString, debug mode") { val oldSetting = TestSQLContext.conf.dataFrameEagerAnalysis TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "true") @@ -321,8 +329,9 @@ class DataFrameSuite extends QueryTest { checkAnswer( decimalData.agg(avg('a cast DecimalType(10, 2))), Row(new java.math.BigDecimal(2.0))) + // non-partial checkAnswer( - decimalData.agg(avg('a cast DecimalType(10, 2)), sumDistinct('a cast DecimalType(10, 2))), // non-partial + decimalData.agg(avg('a cast DecimalType(10, 2)), sumDistinct('a cast DecimalType(10, 2))), Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil) } @@ -431,6 +440,15 @@ class DataFrameSuite extends QueryTest { ) } + test("call udf in SQLContext") { + val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") + val sqlctx = df.sqlContext + sqlctx.udf.register("simpleUdf", (v: Int) => v * v) + checkAnswer( + df.select($"id", callUdf("simpleUdf", $"value")), + Row("id1", 1) :: Row("id2", 16) :: Row("id3", 25) :: Nil) + } + test("withColumn") { val df = testData.toDF().withColumn("newCol", col("key") + 1) checkAnswer( @@ -506,4 +524,11 @@ class DataFrameSuite extends QueryTest { testData.select($"*").show() testData.select($"*").show(1000) } + + test("createDataFrame(RDD[Row], StructType) should convert UDTs (SPARK-6672)") { + val rowRDD = TestSQLContext.sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0)))) + val schema = StructType(Array(StructField("point", new ExamplePointUDT(), false))) + val df = TestSQLContext.createDataFrame(rowRDD, schema) + df.rdd.collect() + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 9b4dd6c620fec..9a81fc5d72819 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -67,7 +67,7 @@ class QueryTest extends PlanTest { checkAnswer(df, Seq(expectedAnswer)) } - def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext): Unit = { + def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext) { test(sqlString) { checkAnswer(sqlContext.sql(sqlString), expectedAnswer) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 87e7cf8c8af9f..5e453e05e2ac7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -17,9 +17,11 @@ package org.apache.spark.sql +import org.apache.spark.sql.execution.GeneratedAggregate import org.apache.spark.sql.test.TestSQLContext import org.scalatest.BeforeAndAfterAll +import org.apache.spark.sql.execution.GeneratedAggregate import org.apache.spark.sql.functions._ import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -102,11 +104,97 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { sql("SELECT ABS(2.5)"), Row(2.5)) } - + test("aggregation with codegen") { val originalValue = conf.codegenEnabled setConf(SQLConf.CODEGEN_ENABLED, "true") - sql("SELECT key FROM testData GROUP BY key").collect() + // Prepare a table that we can group some rows. + table("testData") + .unionAll(table("testData")) + .unionAll(table("testData")) + .registerTempTable("testData3x") + + def testCodeGen(sqlText: String, expectedResults: Seq[Row]): Unit = { + val df = sql(sqlText) + // First, check if we have GeneratedAggregate. + var hasGeneratedAgg = false + df.queryExecution.executedPlan.foreach { + case generatedAgg: GeneratedAggregate => hasGeneratedAgg = true + case _ => + } + if (!hasGeneratedAgg) { + fail( + s""" + |Codegen is enabled, but query $sqlText does not have GeneratedAggregate in the plan. + |${df.queryExecution.simpleString} + """.stripMargin) + } + // Then, check results. + checkAnswer(df, expectedResults) + } + + // Just to group rows. + testCodeGen( + "SELECT key FROM testData3x GROUP BY key", + (1 to 100).map(Row(_))) + // COUNT + testCodeGen( + "SELECT key, count(value) FROM testData3x GROUP BY key", + (1 to 100).map(i => Row(i, 3))) + testCodeGen( + "SELECT count(key) FROM testData3x", + Row(300) :: Nil) + // COUNT DISTINCT ON int + testCodeGen( + "SELECT value, count(distinct key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, 1))) + testCodeGen( + "SELECT count(distinct key) FROM testData3x", + Row(100) :: Nil) + // SUM + testCodeGen( + "SELECT value, sum(key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, 3 * i))) + testCodeGen( + "SELECT sum(key), SUM(CAST(key as Double)) FROM testData3x", + Row(5050 * 3, 5050 * 3.0) :: Nil) + // AVERAGE + testCodeGen( + "SELECT value, avg(key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, i))) + testCodeGen( + "SELECT avg(key) FROM testData3x", + Row(50.5) :: Nil) + // MAX + testCodeGen( + "SELECT value, max(key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, i))) + testCodeGen( + "SELECT max(key) FROM testData3x", + Row(100) :: Nil) + // Some combinations. + testCodeGen( + """ + |SELECT + | value, + | sum(key), + | max(key), + | avg(key), + | count(key), + | count(distinct key) + |FROM testData3x + |GROUP BY value + """.stripMargin, + (1 to 100).map(i => Row(i.toString, i*3, i, i, 3, 1))) + testCodeGen( + "SELECT max(key), avg(key), count(key), count(distinct key) FROM testData3x", + Row(100, 50.5, 300, 100) :: Nil) + // Aggregate with Code generation handling all null values + testCodeGen( + "SELECT sum('a'), avg('a'), count(null) FROM testData", + Row(0, null, 0) :: Nil) + + dropTempTable("testData3x") setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) } @@ -182,7 +270,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00.002"))) checkAnswer(sql( - "SELECT time FROM timestamps WHERE time IN ('1969-12-31 16:00:00.001','1969-12-31 16:00:00.002')"), + """ + |SELECT time FROM timestamps + |WHERE time IN ('1969-12-31 16:00:00.001','1969-12-31 16:00:00.002') + """.stripMargin), Seq(Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00.001")), Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00.002")))) @@ -248,7 +339,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { Row("1")) } - def sortTest() = { + def sortTest(): Unit = { checkAnswer( sql("SELECT * FROM testData2 ORDER BY a ASC, b ASC"), Seq(Row(1,1), Row(1,2), Row(2,1), Row(2,2), Row(3,1), Row(3,2))) @@ -318,6 +409,20 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { mapData.collect().take(1).map(Row.fromTuple).toSeq) } + test("CTE feature") { + checkAnswer( + sql("with q1 as (select * from testData limit 10) select * from q1"), + testData.take(10).toSeq) + + checkAnswer( + sql(""" + |with q1 as (select * from testData where key= '5'), + |q2 as (select * from testData where key = '4') + |select * from q1 union all select * from q2""".stripMargin), + Row(5, "5") :: Row(4, "4") :: Nil) + + } + test("date row") { checkAnswer(sql( """select cast("2015-01-28" as date) from testData limit 1"""), @@ -327,7 +432,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("from follow multiple brackets") { checkAnswer(sql( - "select key from ((select * from testData limit 1) union all (select * from testData limit 1)) x limit 1"), + """ + |select key from ((select * from testData limit 1) + | union all (select * from testData limit 1)) x limit 1 + """.stripMargin), Row(1) ) @@ -337,7 +445,11 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { ) checkAnswer(sql( - "select key from (select * from testData limit 1 union all select * from testData limit 1) x limit 1"), + """ + |select key from + | (select * from testData limit 1 union all select * from testData limit 1) x + | limit 1 + """.stripMargin), Row(1) ) } @@ -384,7 +496,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { Seq(Row(1, 0), Row(2, 1))) checkAnswer( - sql("SELECT COUNT(a), COUNT(b), COUNT(1), COUNT(DISTINCT a), COUNT(DISTINCT b) FROM testData3"), + sql( + """ + |SELECT COUNT(a), COUNT(b), COUNT(1), COUNT(DISTINCT a), COUNT(DISTINCT b) FROM testData3 + """.stripMargin), Row(2, 1, 2, 2, 1)) } @@ -997,7 +1112,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("SPARK-3483 Special chars in column names") { - val data = sparkContext.parallelize(Seq("""{"key?number1": "value1", "key.number2": "value2"}""")) + val data = sparkContext.parallelize( + Seq("""{"key?number1": "value1", "key.number2": "value2"}""")) jsonRDD(data).registerTempTable("records") sql("SELECT `key?number1` FROM records") } @@ -1082,8 +1198,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("SPARK-6145: ORDER BY test for nested fields") { - jsonRDD(sparkContext.makeRDD( - """{"a": {"b": 1, "a": {"a": 1}}, "c": [{"d": 1}]}""" :: Nil)).registerTempTable("nestedOrder") + jsonRDD(sparkContext.makeRDD("""{"a": {"b": 1, "a": {"a": 1}}, "c": [{"d": 1}]}""" :: Nil)) + .registerTempTable("nestedOrder") checkAnswer(sql("SELECT 1 FROM nestedOrder ORDER BY a.b"), Row(1)) checkAnswer(sql("SELECT a.b FROM nestedOrder ORDER BY a.b"), Row(1)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index 17e923ca48502..3fa00fd9d0ccb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -80,7 +80,7 @@ class ScalaReflectionRelationSuite extends FunSuite { test("query case class RDD") { val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, - new java.math.BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1,2,3)) + new java.math.BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1,2,3)) val rdd = sparkContext.parallelize(data :: Nil) rdd.toDF().registerTempTable("reflectData") @@ -103,7 +103,8 @@ class ScalaReflectionRelationSuite extends FunSuite { val rdd = sparkContext.parallelize(data :: Nil) rdd.toDF().registerTempTable("reflectOptionalData") - assert(sql("SELECT * FROM reflectOptionalData").collect().head === Row.fromSeq(Seq.fill(7)(null))) + assert(sql("SELECT * FROM reflectOptionalData").collect().head === + Row.fromSeq(Seq.fill(7)(null))) } // Equality is broken for Arrays, so we test that separately. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index fe618e0e8e767..2672e20deadc5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -23,13 +23,16 @@ import org.apache.spark.util.Utils import scala.beans.{BeanInfo, BeanProperty} +import com.clearspring.analytics.stream.cardinality.HyperLogLog + import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.{OpenHashSetUDT, HyperLogLogUDT} import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext.{sparkContext, sql} import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ - +import org.apache.spark.util.collection.OpenHashSet @SQLUserDefinedType(udt = classOf[MyDenseVectorUDT]) private[sql] class MyDenseVector(val data: Array[Double]) extends Serializable { @@ -63,7 +66,7 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { } } - override def userClass = classOf[MyDenseVector] + override def userClass: Class[MyDenseVector] = classOf[MyDenseVector] private[spark] override def asNullable: MyDenseVectorUDT = this } @@ -119,4 +122,23 @@ class UserDefinedTypeSuite extends QueryTest { df.limit(1).groupBy('int).agg(first('vec)).collect()(0).getAs[MyDenseVector](0) df.orderBy('int).limit(1).groupBy('int).agg(first('vec)).collect()(0).getAs[MyDenseVector](0) } + + test("HyperLogLogUDT") { + val hyperLogLogUDT = HyperLogLogUDT + val hyperLogLog = new HyperLogLog(0.4) + (1 to 10).foreach(i => hyperLogLog.offer(Row(i))) + + val actual = hyperLogLogUDT.deserialize(hyperLogLogUDT.serialize(hyperLogLog)) + assert(actual.cardinality() === hyperLogLog.cardinality()) + assert(java.util.Arrays.equals(actual.getBytes, hyperLogLog.getBytes)) + } + + test("OpenHashSetUDT") { + val openHashSetUDT = new OpenHashSetUDT(IntegerType) + val set = new OpenHashSet[Int] + (1 to 10).foreach(i => set.add(i)) + + val actual = openHashSetUDT.deserialize(openHashSetUDT.serialize(set)) + assert(actual.iterator.toSet === set.iterator.toSet) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala index c7a40845db16c..b301818a008e7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.types.{Decimal, DataType, NativeType} object ColumnarTestUtils { - def makeNullRow(length: Int) = { + def makeNullRow(length: Int): GenericMutableRow = { val row = new GenericMutableRow(length) (0 until length).foreach(row.setNullAt) row @@ -93,7 +93,7 @@ object ColumnarTestUtils { def makeUniqueValuesAndSingleValueRows[T <: NativeType]( columnType: NativeColumnType[T], - count: Int) = { + count: Int): (Seq[T#JvmType], Seq[GenericMutableRow]) = { val values = makeUniqueRandomValues(columnType, count) val rows = values.map { value => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala index bb305355276bf..a0702144f942c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala @@ -31,7 +31,8 @@ class TestNullableColumnAccessor[T <: DataType, JvmType]( with NullableColumnAccessor object TestNullableColumnAccessor { - def apply[T <: DataType, JvmType](buffer: ByteBuffer, columnType: ColumnType[T, JvmType]) = { + def apply[T <: DataType, JvmType](buffer: ByteBuffer, columnType: ColumnType[T, JvmType]) + : TestNullableColumnAccessor[T, JvmType] = { // Skips the column type ID buffer.getInt() new TestNullableColumnAccessor(buffer, columnType) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala index 75a47498683f4..3a5605d2335d7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala @@ -27,7 +27,8 @@ class TestNullableColumnBuilder[T <: DataType, JvmType](columnType: ColumnType[T with NullableColumnBuilder object TestNullableColumnBuilder { - def apply[T <: DataType, JvmType](columnType: ColumnType[T, JvmType], initialSize: Int = 0) = { + def apply[T <: DataType, JvmType](columnType: ColumnType[T, JvmType], initialSize: Int = 0) + : TestNullableColumnBuilder[T, JvmType] = { val builder = new TestNullableColumnBuilder(columnType) builder.initialize(initialSize) builder diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala index 0b18b4119268f..fc8ff3b41d0e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala @@ -35,7 +35,7 @@ object TestCompressibleColumnBuilder { def apply[T <: NativeType]( columnStats: ColumnStats, columnType: NativeColumnType[T], - scheme: CompressionScheme) = { + scheme: CompressionScheme): TestCompressibleColumnBuilder[T] = { val builder = new TestCompressibleColumnBuilder(columnStats, columnType, Seq(scheme)) builder.initialize(0, "", useCompression = true) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala index 4e9472c60249e..358d8cf06e463 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala @@ -30,4 +30,4 @@ class DebuggingSuite extends FunSuite { test("DataFrame.typeCheck()") { testData.typeCheck() } -} \ No newline at end of file +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 592ed4b23b7d3..3596b183d4328 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -45,10 +45,12 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { conn = DriverManager.getConnection(url, properties) conn.prepareStatement("create schema test").executeUpdate() - conn.prepareStatement("create table test.people (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)").executeUpdate() + conn.prepareStatement( + "create table test.people (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)").executeUpdate() conn.prepareStatement("insert into test.people values ('fred', 1)").executeUpdate() conn.prepareStatement("insert into test.people values ('mary', 2)").executeUpdate() - conn.prepareStatement("insert into test.people values ('joe ''foo'' \"bar\"', 3)").executeUpdate() + conn.prepareStatement( + "insert into test.people values ('joe ''foo'' \"bar\"', 3)").executeUpdate() conn.commit() sql( @@ -132,25 +134,25 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { } test("SELECT *") { - assert(sql("SELECT * FROM foobar").collect().size == 3) + assert(sql("SELECT * FROM foobar").collect().size === 3) } test("SELECT * WHERE (simple predicates)") { - assert(sql("SELECT * FROM foobar WHERE THEID < 1").collect().size == 0) - assert(sql("SELECT * FROM foobar WHERE THEID != 2").collect().size == 2) - assert(sql("SELECT * FROM foobar WHERE THEID = 1").collect().size == 1) - assert(sql("SELECT * FROM foobar WHERE NAME = 'fred'").collect().size == 1) - assert(sql("SELECT * FROM foobar WHERE NAME > 'fred'").collect().size == 2) - assert(sql("SELECT * FROM foobar WHERE NAME != 'fred'").collect().size == 2) + assert(sql("SELECT * FROM foobar WHERE THEID < 1").collect().size === 0) + assert(sql("SELECT * FROM foobar WHERE THEID != 2").collect().size === 2) + assert(sql("SELECT * FROM foobar WHERE THEID = 1").collect().size === 1) + assert(sql("SELECT * FROM foobar WHERE NAME = 'fred'").collect().size === 1) + assert(sql("SELECT * FROM foobar WHERE NAME > 'fred'").collect().size === 2) + assert(sql("SELECT * FROM foobar WHERE NAME != 'fred'").collect().size === 2) } test("SELECT * WHERE (quoted strings)") { - assert(sql("select * from foobar").where('NAME === "joe 'foo' \"bar\"").collect().size == 1) + assert(sql("select * from foobar").where('NAME === "joe 'foo' \"bar\"").collect().size === 1) } test("SELECT first field") { val names = sql("SELECT NAME FROM foobar").collect().map(x => x.getString(0)).sortWith(_ < _) - assert(names.size == 3) + assert(names.size === 3) assert(names(0).equals("fred")) assert(names(1).equals("joe 'foo' \"bar\"")) assert(names(2).equals("mary")) @@ -158,10 +160,10 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { test("SELECT second field") { val ids = sql("SELECT THEID FROM foobar").collect().map(x => x.getInt(0)).sortWith(_ < _) - assert(ids.size == 3) - assert(ids(0) == 1) - assert(ids(1) == 2) - assert(ids(2) == 3) + assert(ids.size === 3) + assert(ids(0) === 1) + assert(ids(1) === 2) + assert(ids(2) === 3) } test("SELECT * partitioned") { @@ -169,46 +171,46 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { } test("SELECT WHERE (simple predicates) partitioned") { - assert(sql("SELECT * FROM parts WHERE THEID < 1").collect().size == 0) - assert(sql("SELECT * FROM parts WHERE THEID != 2").collect().size == 2) - assert(sql("SELECT THEID FROM parts WHERE THEID = 1").collect().size == 1) + assert(sql("SELECT * FROM parts WHERE THEID < 1").collect().size === 0) + assert(sql("SELECT * FROM parts WHERE THEID != 2").collect().size === 2) + assert(sql("SELECT THEID FROM parts WHERE THEID = 1").collect().size === 1) } test("SELECT second field partitioned") { val ids = sql("SELECT THEID FROM parts").collect().map(x => x.getInt(0)).sortWith(_ < _) - assert(ids.size == 3) - assert(ids(0) == 1) - assert(ids(1) == 2) - assert(ids(2) == 3) + assert(ids.size === 3) + assert(ids(0) === 1) + assert(ids(1) === 2) + assert(ids(2) === 3) } test("Basic API") { - assert(TestSQLContext.jdbc(urlWithUserAndPass, "TEST.PEOPLE").collect.size == 3) + assert(TestSQLContext.jdbc(urlWithUserAndPass, "TEST.PEOPLE").collect().size === 3) } test("Partitioning via JDBCPartitioningInfo API") { assert(TestSQLContext.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3) - .collect.size == 3) + .collect.size === 3) } test("Partitioning via list-of-where-clauses API") { val parts = Array[String]("THEID < 2", "THEID >= 2") - assert(TestSQLContext.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts).collect.size == 3) + assert(TestSQLContext.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts).collect().size === 3) } test("H2 integral types") { val rows = sql("SELECT * FROM inttypes WHERE A IS NOT NULL").collect() - assert(rows.size == 1) - assert(rows(0).getInt(0) == 1) - assert(rows(0).getBoolean(1) == false) - assert(rows(0).getInt(2) == 3) - assert(rows(0).getInt(3) == 4) - assert(rows(0).getLong(4) == 1234567890123L) + assert(rows.size === 1) + assert(rows(0).getInt(0) === 1) + assert(rows(0).getBoolean(1) === false) + assert(rows(0).getInt(2) === 3) + assert(rows(0).getInt(3) === 4) + assert(rows(0).getLong(4) === 1234567890123L) } test("H2 null entries") { val rows = sql("SELECT * FROM inttypes WHERE A IS NULL").collect() - assert(rows.size == 1) + assert(rows.size === 1) assert(rows(0).isNullAt(0)) assert(rows(0).isNullAt(1)) assert(rows(0).isNullAt(2)) @@ -230,27 +232,27 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { val rows = sql("SELECT * FROM timetypes").collect() val cal = new GregorianCalendar(java.util.Locale.ROOT) cal.setTime(rows(0).getAs[java.sql.Timestamp](0)) - assert(cal.get(Calendar.HOUR_OF_DAY) == 12) - assert(cal.get(Calendar.MINUTE) == 34) - assert(cal.get(Calendar.SECOND) == 56) + assert(cal.get(Calendar.HOUR_OF_DAY) === 12) + assert(cal.get(Calendar.MINUTE) === 34) + assert(cal.get(Calendar.SECOND) === 56) cal.setTime(rows(0).getAs[java.sql.Timestamp](1)) - assert(cal.get(Calendar.YEAR) == 1996) - assert(cal.get(Calendar.MONTH) == 0) - assert(cal.get(Calendar.DAY_OF_MONTH) == 1) + assert(cal.get(Calendar.YEAR) === 1996) + assert(cal.get(Calendar.MONTH) === 0) + assert(cal.get(Calendar.DAY_OF_MONTH) === 1) cal.setTime(rows(0).getAs[java.sql.Timestamp](2)) - assert(cal.get(Calendar.YEAR) == 2002) - assert(cal.get(Calendar.MONTH) == 1) - assert(cal.get(Calendar.DAY_OF_MONTH) == 20) - assert(cal.get(Calendar.HOUR) == 11) - assert(cal.get(Calendar.MINUTE) == 22) - assert(cal.get(Calendar.SECOND) == 33) - assert(rows(0).getAs[java.sql.Timestamp](2).getNanos == 543543543) + assert(cal.get(Calendar.YEAR) === 2002) + assert(cal.get(Calendar.MONTH) === 1) + assert(cal.get(Calendar.DAY_OF_MONTH) === 20) + assert(cal.get(Calendar.HOUR) === 11) + assert(cal.get(Calendar.MINUTE) === 22) + assert(cal.get(Calendar.SECOND) === 33) + assert(rows(0).getAs[java.sql.Timestamp](2).getNanos === 543543543) } test("H2 floating-point types") { val rows = sql("SELECT * FROM flttypes").collect() - assert(rows(0).getDouble(0) == 1.00000000000000022) // Yes, I meant ==. - assert(rows(0).getDouble(1) == 1.00000011920928955) // Yes, I meant ==. + assert(rows(0).getDouble(0) === 1.00000000000000022) // Yes, I meant ==. + assert(rows(0).getDouble(1) === 1.00000011920928955) // Yes, I meant ==. assert(rows(0).getAs[BigDecimal](2) .equals(new BigDecimal("123456789012345.54321543215432100000"))) } @@ -264,7 +266,7 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { | user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) val rows = sql("SELECT * FROM hack").collect() - assert(rows(0).getDouble(0) == 1.00000011920928955) // Yes, I meant ==. + assert(rows(0).getDouble(0) === 1.00000011920928955) // Yes, I meant ==. // For some reason, H2 computes this square incorrectly... assert(math.abs(rows(0).getDouble(1) - 1.00000023841859331) < 1e-12) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 706c966ee05f5..fd0e2746dc045 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -380,8 +380,10 @@ class JsonSuite extends QueryTest { sql("select * from jsonTable"), Row("true", 11L, null, 1.1, "13.1", "str1") :: Row("12", null, new java.math.BigDecimal("21474836470.9"), null, null, "true") :: - Row("false", 21474836470L, new java.math.BigDecimal("92233720368547758070"), 100, "str1", "false") :: - Row(null, 21474836570L, new java.math.BigDecimal("1.1"), 21474836470L, "92233720368547758070", null) :: Nil + Row("false", 21474836470L, + new java.math.BigDecimal("92233720368547758070"), 100, "str1", "false") :: + Row(null, 21474836570L, + new java.math.BigDecimal("1.1"), 21474836470L, "92233720368547758070", null) :: Nil ) // Number and Boolean conflict: resolve the type as number in this query. @@ -404,7 +406,8 @@ class JsonSuite extends QueryTest { // Widening to DecimalType checkAnswer( sql("select num_num_2 + 1.2 from jsonTable where num_num_2 > 1.1"), - Row(new java.math.BigDecimal("21474836472.1")) :: Row(new java.math.BigDecimal("92233720368547758071.2")) :: Nil + Row(new java.math.BigDecimal("21474836472.1")) :: + Row(new java.math.BigDecimal("92233720368547758071.2")) :: Nil ) // Widening to DoubleType @@ -892,8 +895,7 @@ class JsonSuite extends QueryTest { ) } - test("SPARK-4228 DataFrame to JSON") - { + test("SPARK-4228 DataFrame to JSON") { val schema1 = StructType( StructField("f1", IntegerType, false) :: StructField("f2", StringType, false) :: @@ -913,8 +915,10 @@ class JsonSuite extends QueryTest { df1.registerTempTable("applySchema1") val df2 = df1.toDF val result = df2.toJSON.collect() + // scalastyle:off assert(result(0) === "{\"f1\":1,\"f2\":\"A1\",\"f3\":true,\"f4\":[\"1\",\" A1\",\" true\",\" null\"]}") assert(result(3) === "{\"f1\":4,\"f2\":\"D4\",\"f3\":true,\"f4\":[\"4\",\" D4\",\" true\",\" 2147483644\"],\"f5\":2147483644}") + // scalastyle:on val schema2 = StructType( StructField("f1", StructType( @@ -968,7 +972,8 @@ class JsonSuite extends QueryTest { // Access elements of a BigInteger array (we use DecimalType internally). checkAnswer( - sql("select arrayOfBigInteger[0], arrayOfBigInteger[1], arrayOfBigInteger[2] from complexTable"), + sql("select arrayOfBigInteger[0], arrayOfBigInteger[1], arrayOfBigInteger[2] " + + " from complexTable"), Row(new java.math.BigDecimal("922337203685477580700"), new java.math.BigDecimal("-922337203685477580800"), null) ) @@ -1008,7 +1013,8 @@ class JsonSuite extends QueryTest { // Access elements of an array field of a struct. checkAnswer( - sql("select structWithArrayFields.field1[1], structWithArrayFields.field2[3] from complexTable"), + sql("select structWithArrayFields.field1[1], structWithArrayFields.field2[3] " + + "from complexTable"), Row(5, null) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala index 203bc79f153dd..4d0bf7cf99cdf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala @@ -218,7 +218,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { } test("compression codec") { - def compressionCodecFor(path: String) = { + def compressionCodecFor(path: String): String = { val codecs = ParquetTypesConverter .readMetaData(new Path(path), Some(configuration)) .getBlocks @@ -381,6 +381,27 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { } } } + + test("SPARK-6352 DirectParquetOutputCommitter") { + // Write to a parquet file and let it fail. + // _temporary should be missing if direct output committer works. + try { + configuration.set("spark.sql.parquet.output.committer.class", + "org.apache.spark.sql.parquet.DirectParquetOutputCommitter") + sqlContext.udf.register("div0", (x: Int) => x / 0) + withTempPath { dir => + intercept[org.apache.spark.SparkException] { + sqlContext.sql("select div0(1)").saveAsParquetFile(dir.getCanonicalPath) + } + val path = new Path(dir.getCanonicalPath, "_temporary") + val fs = path.getFileSystem(configuration) + assert(!fs.exists(path)) + } + } + finally { + configuration.unset("spark.sql.parquet.output.committer.class") + } + } } class ParquetDataSourceOnIOSuite extends ParquetIOSuiteBase with BeforeAndAfterAll { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala index adb3c9391f6c2..b7561ce7298cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala @@ -45,11 +45,11 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { assert(inferPartitionColumnValue(raw, defaultPartitionName) === literal) } - check("10", Literal(10, IntegerType)) - check("1000000000000000", Literal(1000000000000000L, LongType)) - check("1.5", Literal(1.5, FloatType)) - check("hello", Literal("hello", StringType)) - check(defaultPartitionName, Literal(null, NullType)) + check("10", Literal.create(10, IntegerType)) + check("1000000000000000", Literal.create(1000000000000000L, LongType)) + check("1.5", Literal.create(1.5, FloatType)) + check("hello", Literal.create("hello", StringType)) + check(defaultPartitionName, Literal.create(null, NullType)) } test("parse partition") { @@ -75,22 +75,22 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { "file://path/a=10", PartitionValues( ArrayBuffer("a"), - ArrayBuffer(Literal(10, IntegerType)))) + ArrayBuffer(Literal.create(10, IntegerType)))) check( "file://path/a=10/b=hello/c=1.5", PartitionValues( ArrayBuffer("a", "b", "c"), ArrayBuffer( - Literal(10, IntegerType), - Literal("hello", StringType), - Literal(1.5, FloatType)))) + Literal.create(10, IntegerType), + Literal.create("hello", StringType), + Literal.create(1.5, FloatType)))) check( "file://path/a=10/b_hello/c=1.5", PartitionValues( ArrayBuffer("c"), - ArrayBuffer(Literal(1.5, FloatType)))) + ArrayBuffer(Literal.create(1.5, FloatType)))) checkThrows[AssertionError]("file://path/=10", "Empty partition column name") checkThrows[AssertionError]("file://path/a=", "Empty partition column value") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala index 61f1cf347ab0f..c964b6d984557 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala @@ -180,10 +180,12 @@ class ParquetSchemaSuite extends FunSuite with ParquetTest { val caseClassString = "StructType(List(StructField(c1,IntegerType,false), StructField(c2,BinaryType,true)))" + // scalastyle:off val jsonString = """ |{"type":"struct","fields":[{"name":"c1","type":"integer","nullable":false,"metadata":{}},{"name":"c2","type":"binary","nullable":true,"metadata":{}}]} """.stripMargin + // scalastyle:on val fromCaseClassString = ParquetTypesConverter.convertFromString(caseClassString) val fromJson = ParquetTypesConverter.convertFromString(jsonString) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala index 54af50c6e10ad..3f24a497390c1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.sources +import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.types._ @@ -31,7 +32,7 @@ class DDLScanSource extends RelationProvider { case class SimpleDDLScan(from: Int, to: Int)(@transient val sqlContext: SQLContext) extends BaseRelation with TableScan { - override def schema = + override def schema: StructType = StructType(Seq( StructField("intType", IntegerType, nullable = false, new MetadataBuilder().putString("comment", "test comment").build()), @@ -57,8 +58,9 @@ case class SimpleDDLScan(from: Int, to: Int)(@transient val sqlContext: SQLConte )) - override def buildScan() = sqlContext.sparkContext.parallelize(from to to). - map(e => Row(s"people$e", e * 2)) + override def buildScan(): RDD[Row] = { + sqlContext.sparkContext.parallelize(from to to).map(e => Row(s"people$e", e * 2)) + } } class DDLTestSuite extends DataSourceTest { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala index 72ddc0ea2c8cb..cb5e5147ff189 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.sources import scala.language.existentials +import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.types._ @@ -41,11 +42,13 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQL StructField("b", IntegerType, nullable = false) :: StructField("c", StringType, nullable = false) :: Nil) - override def buildScan(requiredColumns: Array[String], filters: Array[Filter]) = { + override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { val rowBuilders = requiredColumns.map { case "a" => (i: Int) => Seq(i) case "b" => (i: Int) => Seq(i * 2) - case "c" => (i: Int) => Seq((i - 1 + 'a').toChar.toString * 10) + case "c" => (i: Int) => + val c = (i - 1 + 'a').toChar.toString + Seq(c * 5 + c.toUpperCase() * 5) } FiltersPushed.list = filters @@ -77,7 +80,7 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQL } def eval(a: Int) = { - val c = (a - 1 + 'a').toChar.toString * 10 + val c = (a - 1 + 'a').toChar.toString * 5 + (a - 1 + 'a').toChar.toString.toUpperCase() * 5 !filters.map(translateFilterOnA(_)(a)).contains(false) && !filters.map(translateFilterOnC(_)(c)).contains(false) } @@ -110,7 +113,8 @@ class FilteredScanSuite extends DataSourceTest { sqlTest( "SELECT * FROM oneToTenFiltered", - (1 to 10).map(i => Row(i, i * 2, (i - 1 + 'a').toChar.toString * 10)).toSeq) + (1 to 10).map(i => Row(i, i * 2, (i - 1 + 'a').toChar.toString * 5 + + (i - 1 + 'a').toChar.toString.toUpperCase() * 5)).toSeq) sqlTest( "SELECT a, b FROM oneToTenFiltered", @@ -182,15 +186,15 @@ class FilteredScanSuite extends DataSourceTest { sqlTest( "SELECT a, b, c FROM oneToTenFiltered WHERE c like 'c%'", - Seq(Row(3, 3 * 2, "c" * 10))) + Seq(Row(3, 3 * 2, "c" * 5 + "C" * 5))) sqlTest( - "SELECT a, b, c FROM oneToTenFiltered WHERE c like 'd%'", - Seq(Row(4, 4 * 2, "d" * 10))) + "SELECT a, b, c FROM oneToTenFiltered WHERE c like '%D'", + Seq(Row(4, 4 * 2, "d" * 5 + "D" * 5))) sqlTest( - "SELECT a, b, c FROM oneToTenFiltered WHERE c like '%e%'", - Seq(Row(5, 5 * 2, "e" * 10))) + "SELECT a, b, c FROM oneToTenFiltered WHERE c like '%eE%'", + Seq(Row(5, 5 * 2, "e" * 5 + "E" * 5))) testPushDown("SELECT * FROM oneToTenFiltered WHERE A = 1", 1) testPushDown("SELECT a FROM oneToTenFiltered WHERE A = 1", 1) @@ -222,6 +226,15 @@ class FilteredScanSuite extends DataSourceTest { testPushDown("SELECT * FROM oneToTenFiltered WHERE a < 3 OR a > 8", 4) testPushDown("SELECT * FROM oneToTenFiltered WHERE NOT (a < 6)", 5) + testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like 'c%'", 1) + testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like 'C%'", 0) + + testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%D'", 1) + testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%d'", 0) + + testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%eE%'", 1) + testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%Ee%'", 0) + def testPushDown(sqlString: String, expectedCount: Int): Unit = { test(s"PushDown Returns $expectedCount: $sqlString") { val queryExecution = sql(sqlString).queryExecution diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala index 08fb5380dc026..6a1ddf2f8e98b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.sources import scala.language.existentials +import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.types._ @@ -34,12 +35,12 @@ case class SimplePrunedScan(from: Int, to: Int)(@transient val sqlContext: SQLCo extends BaseRelation with PrunedScan { - override def schema = + override def schema: StructType = StructType( StructField("a", IntegerType, nullable = false) :: StructField("b", IntegerType, nullable = false) :: Nil) - override def buildScan(requiredColumns: Array[String]) = { + override def buildScan(requiredColumns: Array[String]): RDD[Row] = { val rowBuilders = requiredColumns.map { case "a" => (i: Int) => Seq(i) case "b" => (i: Int) => Seq(i * 2) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala index 43bc8eb2d11a7..cb287ba85c1f8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala @@ -114,4 +114,4 @@ class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll { message.contains("Append mode is not supported"), "We should complain that 'Append mode is not supported' for JSON source.") } -} \ No newline at end of file +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index 7928600ac2fb5..60c8c00bda4d5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.sources import java.sql.{Timestamp, Date} +import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.types._ @@ -35,10 +36,10 @@ class SimpleScanSource extends RelationProvider { case class SimpleScan(from: Int, to: Int)(@transient val sqlContext: SQLContext) extends BaseRelation with TableScan { - override def schema = + override def schema: StructType = StructType(StructField("i", IntegerType, nullable = false) :: Nil) - override def buildScan() = sqlContext.sparkContext.parallelize(from to to).map(Row(_)) + override def buildScan(): RDD[Row] = sqlContext.sparkContext.parallelize(from to to).map(Row(_)) } class AllDataTypesScanSource extends SchemaRelationProvider { @@ -57,9 +58,9 @@ case class AllDataTypesScan( extends BaseRelation with TableScan { - override def schema = userSpecifiedSchema + override def schema: StructType = userSpecifiedSchema - override def buildScan() = { + override def buildScan(): RDD[Row] = { sqlContext.sparkContext.parallelize(from to to).map { i => Row( s"str_$i", diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index 158c225159720..97b46a01ba5b4 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -22,6 +22,7 @@ import scala.collection.JavaConversions._ import org.apache.spark.scheduler.StatsReportListener import org.apache.spark.sql.hive.{HiveShim, HiveContext} import org.apache.spark.{Logging, SparkConf, SparkContext} +import org.apache.spark.util.Utils /** A singleton object for the master program. The slaves should not access this. */ private[hive] object SparkSQLEnv extends Logging { @@ -37,7 +38,7 @@ private[hive] object SparkSQLEnv extends Logging { val maybeKryoReferenceTracking = sparkConf.getOption("spark.kryo.referenceTracking") sparkConf - .setAppName(s"SparkSQL::${java.net.InetAddress.getLocalHost.getHostName}") + .setAppName(s"SparkSQL::${Utils.localHostName()}") .set("spark.sql.hive.version", HiveShim.version) .set( "spark.serializer", diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index 75738fa22b572..6d1d7c3a4e698 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -1,13 +1,12 @@ /* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index bf20acecb1f32..4cf95e7bdfb2b 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive.thriftserver import java.io.File +import java.net.URL import java.sql.{Date, DriverManager, Statement} import scala.collection.mutable.ArrayBuffer @@ -41,7 +42,7 @@ import org.apache.spark.sql.hive.HiveShim import org.apache.spark.util.Utils object TestData { - def getTestDataFilePath(name: String) = { + def getTestDataFilePath(name: String): URL = { Thread.currentThread().getContextClassLoader.getResource(s"data/files/$name") } @@ -50,7 +51,7 @@ object TestData { } class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { - override def mode = ServerMode.binary + override def mode: ServerMode.Value = ServerMode.binary private def withCLIServiceClient(f: ThriftCLIServiceClient => Unit): Unit = { // Transport creation logics below mimics HiveConnection.createBinaryTransport @@ -337,7 +338,7 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { } class HiveThriftHttpServerSuite extends HiveThriftJdbcTest { - override def mode = ServerMode.http + override def mode: ServerMode.Value = ServerMode.http test("JDBC query execution") { withJdbcStatement { statement => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 6bb1c47dba920..7c6a7df2bd01e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -181,12 +181,13 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { val tableFullName = relation.hiveQlTable.getDbName + "." + relation.hiveQlTable.getTableName - catalog.client.alterTable(tableFullName, new Table(hiveTTable)) + catalog.synchronized { + catalog.client.alterTable(tableFullName, new Table(hiveTTable)) + } } case otherRelation => - throw new NotImplementedError( - s"Analyze has only implemented for Hive tables, " + - s"but $tableName is a ${otherRelation.nodeName}") + throw new UnsupportedOperationException( + s"Analyze only works for Hive tables, but $tableName is a ${otherRelation.nodeName}") } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 4afa2e71d77cc..921c6194c7b76 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -593,7 +593,7 @@ private[hive] trait HiveInspectors { case Literal(_, dt) => sys.error(s"Hive doesn't support the constant type [$dt].") // ideally, we don't test the foldable here(but in optimizer), however, some of the // Hive UDF / UDAF requires its argument to be constant objectinspector, we do it eagerly. - case _ if expr.foldable => toInspector(Literal(expr.eval(), expr.dataType)) + case _ if expr.foldable => toInspector(Literal.create(expr.eval(), expr.dataType)) // For those non constant expression, map to object inspector according to its data type case _ => toInspector(expr.dataType) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 2b5d031741a63..3ed5c5b031736 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -67,10 +67,11 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with val cacheLoader = new CacheLoader[QualifiedTableName, LogicalPlan]() { override def load(in: QualifiedTableName): LogicalPlan = { logDebug(s"Creating new cached data source for $in") - val table = synchronized { + val table = HiveMetastoreCatalog.this.synchronized { client.getTable(in.database, in.name) } - val userSpecifiedSchema = + + def schemaStringFromParts: Option[String] = { Option(table.getProperty("spark.sql.sources.schema.numParts")).map { numParts => val parts = (0 until numParts.toInt).map { index => val part = table.getProperty(s"spark.sql.sources.schema.part.${index}") @@ -82,10 +83,19 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with part } - // Stick all parts back to a single schema string in the JSON representation - // and convert it back to a StructType. - DataType.fromJson(parts.mkString).asInstanceOf[StructType] + // Stick all parts back to a single schema string. + parts.mkString } + } + + // Originally, we used spark.sql.sources.schema to store the schema of a data source table. + // After SPARK-6024, we removed this flag. + // Although we are not using spark.sql.sources.schema any more, we need to still support. + val schemaString = + Option(table.getProperty("spark.sql.sources.schema")).orElse(schemaStringFromParts) + + val userSpecifiedSchema = + schemaString.map(s => DataType.fromJson(s).asInstanceOf[StructType]) // It does not appear that the ql client for the metastore has a way to enumerate all the // SerDe properties directly... @@ -106,7 +116,15 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with } override def refreshTable(databaseName: String, tableName: String): Unit = { - cachedDataSourceTables.refresh(QualifiedTableName(databaseName, tableName).toLowerCase) + // refreshTable does not eagerly reload the cache. It just invalidate the cache. + // Next time when we use the table, it will be populated in the cache. + // Since we also cache ParquetRealtions converted from Hive Parquet tables and + // adding converted ParquetRealtions into the cache is not defined in the load function + // of the cache (instead, we add the cache entry in convertToParquetRelation), + // it is better at here to invalidate the cache to avoid confusing waring logs from the + // cache loader (e.g. cannot find data source provider, which is only defined for + // data source table.). + invalidateTable(databaseName, tableName) } def invalidateTable(databaseName: String, tableName: String): Unit = { @@ -219,14 +237,49 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with private def convertToParquetRelation(metastoreRelation: MetastoreRelation): LogicalRelation = { val metastoreSchema = StructType.fromAttributes(metastoreRelation.output) val mergeSchema = hive.convertMetastoreParquetWithSchemaMerging - val parquetOptions = Map( - ParquetRelation2.METASTORE_SCHEMA -> metastoreSchema.json, - ParquetRelation2.MERGE_SCHEMA -> mergeSchema.toString) // NOTE: Instead of passing Metastore schema directly to `ParquetRelation2`, we have to // serialize the Metastore schema to JSON and pass it as a data source option because of the // evil case insensitivity issue, which is reconciled within `ParquetRelation2`. - if (metastoreRelation.hiveQlTable.isPartitioned) { + val parquetOptions = Map( + ParquetRelation2.METASTORE_SCHEMA -> metastoreSchema.json, + ParquetRelation2.MERGE_SCHEMA -> mergeSchema.toString) + val tableIdentifier = + QualifiedTableName(metastoreRelation.databaseName, metastoreRelation.tableName) + + def getCached( + tableIdentifier: QualifiedTableName, + pathsInMetastore: Seq[String], + schemaInMetastore: StructType, + partitionSpecInMetastore: Option[PartitionSpec]): Option[LogicalRelation] = { + cachedDataSourceTables.getIfPresent(tableIdentifier) match { + case null => None // Cache miss + case logical@LogicalRelation(parquetRelation: ParquetRelation2) => + // If we have the same paths, same schema, and same partition spec, + // we will use the cached Parquet Relation. + val useCached = + parquetRelation.paths.toSet == pathsInMetastore.toSet && + logical.schema.sameType(metastoreSchema) && + parquetRelation.maybePartitionSpec == partitionSpecInMetastore + + if (useCached) { + Some(logical) + } else { + // If the cached relation is not updated, we invalidate it right away. + cachedDataSourceTables.invalidate(tableIdentifier) + None + } + case other => + logWarning( + s"${metastoreRelation.databaseName}.${metastoreRelation.tableName} should be stored " + + s"as Parquet. However, we are getting a ${other} from the metastore cache. " + + s"This cached entry will be invalidated.") + cachedDataSourceTables.invalidate(tableIdentifier) + None + } + } + + val result = if (metastoreRelation.hiveQlTable.isPartitioned) { val partitionSchema = StructType.fromAttributes(metastoreRelation.partitionKeys) val partitionColumnDataTypes = partitionSchema.map(_.dataType) val partitions = metastoreRelation.hiveQlPartitions.map { p => @@ -238,11 +291,31 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with } val partitionSpec = PartitionSpec(partitionSchema, partitions) val paths = partitions.map(_.path) - LogicalRelation(ParquetRelation2(paths, parquetOptions, None, Some(partitionSpec))(hive)) + + val cached = getCached(tableIdentifier, paths, metastoreSchema, Some(partitionSpec)) + val parquetRelation = cached.getOrElse { + val created = + LogicalRelation(ParquetRelation2(paths, parquetOptions, None, Some(partitionSpec))(hive)) + cachedDataSourceTables.put(tableIdentifier, created) + created + } + + parquetRelation } else { val paths = Seq(metastoreRelation.hiveQlTable.getDataLocation.toString) - LogicalRelation(ParquetRelation2(paths, parquetOptions)(hive)) + + val cached = getCached(tableIdentifier, paths, metastoreSchema, None) + val parquetRelation = cached.getOrElse { + val created = + LogicalRelation(ParquetRelation2(paths, parquetOptions)(hive)) + cachedDataSourceTables.put(tableIdentifier, created) + created + } + + parquetRelation } + + result.newInstance() } override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = synchronized { @@ -790,7 +863,7 @@ private[hive] case class MetastoreRelation /** An attribute map for determining the ordinal for non-partition columns. */ val columnOrdinals = AttributeMap(attributes.zipWithIndex) - override def newInstance() = { + override def newInstance(): MetastoreRelation = { MetastoreRelation(databaseName, tableName, alias)(table, partitions)(sqlContext) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index cd8e7c09eea5b..b2ae74efeb097 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.hive import java.sql.Date +import org.apache.hadoop.hive.ql.exec.{FunctionRegistry, FunctionInfo} + import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.hive.conf.HiveConf @@ -111,13 +113,16 @@ private[hive] object HiveQl { "TOK_REVOKE", + "TOK_SHOW_COMPACTIONS", "TOK_SHOW_CREATETABLE", "TOK_SHOW_GRANT", "TOK_SHOW_ROLE_GRANT", + "TOK_SHOW_ROLE_PRINCIPALS", "TOK_SHOW_ROLES", "TOK_SHOW_SET_ROLE", "TOK_SHOW_TABLESTATUS", "TOK_SHOW_TBLPROPERTIES", + "TOK_SHOW_TRANSACTIONS", "TOK_SHOWCOLUMNS", "TOK_SHOWDATABASES", "TOK_SHOWFUNCTIONS", @@ -574,11 +579,23 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Token("TOK_QUERY", queryArgs) if Seq("TOK_FROM", "TOK_INSERT").contains(queryArgs.head.getText) => - val (fromClause: Option[ASTNode], insertClauses) = queryArgs match { - case Token("TOK_FROM", args: Seq[ASTNode]) :: insertClauses => - (Some(args.head), insertClauses) - case Token("TOK_INSERT", _) :: Nil => (None, queryArgs) - } + val (fromClause: Option[ASTNode], insertClauses, cteRelations) = + queryArgs match { + case Token("TOK_FROM", args: Seq[ASTNode]) :: insertClauses => + // check if has CTE + insertClauses.last match { + case Token("TOK_CTE", cteClauses) => + val cteRelations = cteClauses.map(node => { + val relation = nodeToRelation(node).asInstanceOf[Subquery] + (relation.alias, relation) + }).toMap + (Some(args.head), insertClauses.init, Some(cteRelations)) + + case _ => (Some(args.head), insertClauses, None) + } + + case Token("TOK_INSERT", _) :: Nil => (None, queryArgs, None) + } // Return one query for each insert clause. val queries = insertClauses.map { case Token("TOK_INSERT", singleInsert) => @@ -659,7 +676,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C AttributeReference("value", StringType)()), true) } - def matchSerDe(clause: Seq[ASTNode]) = clause match { + def matchSerDe(clause: Seq[ASTNode]) + : (Seq[(String, String)], String, Seq[(String, String)]) = clause match { case Token("TOK_SERDEPROPS", propsClause) :: Nil => val rowFormat = propsClause.map { case Token(name, Token(value, Nil) :: Nil) => (name, value) @@ -791,7 +809,10 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C } // If there are multiple INSERTS just UNION them together into on query. - queries.reduceLeft(Union) + val query = queries.reduceLeft(Union) + + // return With plan if there is CTE + cteRelations.map(With(query, _)).getOrElse(query) case Token("TOK_UNION", left :: right :: Nil) => Union(nodeToPlan(left), nodeToPlan(right)) @@ -1201,7 +1222,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C CreateArray(children.map(nodeToExpr)) case Token("TOK_FUNCTION", Token(RAND(), Nil) :: Nil) => Rand case Token("TOK_FUNCTION", Token(SUBSTR(), Nil) :: string :: pos :: Nil) => - Substring(nodeToExpr(string), nodeToExpr(pos), Literal(Integer.MAX_VALUE, IntegerType)) + Substring(nodeToExpr(string), nodeToExpr(pos), Literal.create(Integer.MAX_VALUE, IntegerType)) case Token("TOK_FUNCTION", Token(SUBSTR(), Nil) :: string :: pos :: length :: Nil) => Substring(nodeToExpr(string), nodeToExpr(pos), nodeToExpr(length)) case Token("TOK_FUNCTION", Token(COALESCE(), Nil) :: list) => Coalesce(list.map(nodeToExpr)) @@ -1213,9 +1234,9 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C UnresolvedFunction(name, UnresolvedStar(None) :: Nil) /* Literals */ - case Token("TOK_NULL", Nil) => Literal(null, NullType) - case Token(TRUE(), Nil) => Literal(true, BooleanType) - case Token(FALSE(), Nil) => Literal(false, BooleanType) + case Token("TOK_NULL", Nil) => Literal.create(null, NullType) + case Token(TRUE(), Nil) => Literal.create(true, BooleanType) + case Token(FALSE(), Nil) => Literal.create(false, BooleanType) case Token("TOK_STRINGLITERALSEQUENCE", strings) => Literal(strings.map(s => BaseSemanticAnalyzer.unescapeSQLString(s.getText)).mkString) @@ -1226,21 +1247,21 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C try { if (ast.getText.endsWith("L")) { // Literal bigint. - v = Literal(ast.getText.substring(0, ast.getText.length() - 1).toLong, LongType) + v = Literal.create(ast.getText.substring(0, ast.getText.length() - 1).toLong, LongType) } else if (ast.getText.endsWith("S")) { // Literal smallint. - v = Literal(ast.getText.substring(0, ast.getText.length() - 1).toShort, ShortType) + v = Literal.create(ast.getText.substring(0, ast.getText.length() - 1).toShort, ShortType) } else if (ast.getText.endsWith("Y")) { // Literal tinyint. - v = Literal(ast.getText.substring(0, ast.getText.length() - 1).toByte, ByteType) + v = Literal.create(ast.getText.substring(0, ast.getText.length() - 1).toByte, ByteType) } else if (ast.getText.endsWith("BD") || ast.getText.endsWith("D")) { // Literal decimal val strVal = ast.getText.stripSuffix("D").stripSuffix("B") v = Literal(Decimal(strVal)) } else { - v = Literal(ast.getText.toDouble, DoubleType) - v = Literal(ast.getText.toLong, LongType) - v = Literal(ast.getText.toInt, IntegerType) + v = Literal.create(ast.getText.toDouble, DoubleType) + v = Literal.create(ast.getText.toLong, LongType) + v = Literal.create(ast.getText.toInt, IntegerType) } } catch { case nfe: NumberFormatException => // Do nothing @@ -1283,8 +1304,13 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C Explode(attributes, nodeToExpr(child)) case Token("TOK_FUNCTION", Token(functionName, Nil) :: children) => + val functionInfo: FunctionInfo = + Option(FunctionRegistry.getFunctionInfo(functionName.toLowerCase)).getOrElse( + sys.error(s"Couldn't find function $functionName")) + val functionClassName = functionInfo.getFunctionClass.getName + HiveGenericUdtf( - new HiveFunctionWrapper(functionName), + new HiveFunctionWrapper(functionClassName), attributes, children.map(nodeToExpr)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index 3563472c7ae81..d35291543c9f9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -142,7 +142,46 @@ class HadoopTableReader( partitionToDeserializer: Map[HivePartition, Class[_ <: Deserializer]], filterOpt: Option[PathFilter]): RDD[Row] = { - val hivePartitionRDDs = partitionToDeserializer.map { case (partition, partDeserializer) => + + // SPARK-5068:get FileStatus and do the filtering locally when the path is not exists + def verifyPartitionPath( + partitionToDeserializer: Map[HivePartition, Class[_ <: Deserializer]]): + Map[HivePartition, Class[_ <: Deserializer]] = { + if (!sc.conf.verifyPartitionPath) { + partitionToDeserializer + } else { + var existPathSet = collection.mutable.Set[String]() + var pathPatternSet = collection.mutable.Set[String]() + partitionToDeserializer.filter { + case (partition, partDeserializer) => + def updateExistPathSetByPathPattern(pathPatternStr: String) { + val pathPattern = new Path(pathPatternStr) + val fs = pathPattern.getFileSystem(sc.hiveconf) + val matches = fs.globStatus(pathPattern) + matches.foreach(fileStatus => existPathSet += fileStatus.getPath.toString) + } + // convert /demo/data/year/month/day to /demo/data/*/*/*/ + def getPathPatternByPath(parNum: Int, tempPath: Path): String = { + var path = tempPath + for (i <- (1 to parNum)) path = path.getParent + val tails = (1 to parNum).map(_ => "*").mkString("/", "/", "/") + path.toString + tails + } + + val partPath = HiveShim.getDataLocationPath(partition) + val partNum = Utilities.getPartitionDesc(partition).getPartSpec.size(); + var pathPatternStr = getPathPatternByPath(partNum, partPath) + if (!pathPatternSet.contains(pathPatternStr)) { + pathPatternSet += pathPatternStr + updateExistPathSetByPathPattern(pathPatternStr) + } + existPathSet.contains(partPath.toString) + } + } + } + + val hivePartitionRDDs = verifyPartitionPath(partitionToDeserializer) + .map { case (partition, partDeserializer) => val partDesc = Utilities.getPartitionDesc(partition) val partPath = HiveShim.getDataLocationPath(partition) val inputPathStr = applyFilterIfNeeded(partPath, filterOpt) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index cdf012b5117be..6c96747439683 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -50,7 +50,7 @@ case class InsertIntoHiveTable( @transient val sc: HiveContext = sqlContext.asInstanceOf[HiveContext] @transient lazy val outputClass = newSerializer(table.tableDesc).getSerializedClass @transient private lazy val hiveContext = new Context(sc.hiveconf) - @transient private lazy val db = Hive.get(sc.hiveconf) + @transient private lazy val catalog = sc.catalog private def newSerializer(tableDesc: TableDesc): Serializer = { val serializer = tableDesc.getDeserializerClass.newInstance().asInstanceOf[Serializer] @@ -199,38 +199,45 @@ case class InsertIntoHiveTable( orderedPartitionSpec.put(entry.getName,partitionSpec.get(entry.getName).getOrElse("")) } val partVals = MetaStoreUtils.getPvals(table.hiveQlTable.getPartCols, partitionSpec) - db.validatePartitionNameCharacters(partVals) + catalog.synchronized { + catalog.client.validatePartitionNameCharacters(partVals) + } // inheritTableSpecs is set to true. It should be set to false for a IMPORT query // which is currently considered as a Hive native command. val inheritTableSpecs = true // TODO: Correctly set isSkewedStoreAsSubdir. val isSkewedStoreAsSubdir = false if (numDynamicPartitions > 0) { - db.loadDynamicPartitions( - outputPath, - qualifiedTableName, - orderedPartitionSpec, - overwrite, - numDynamicPartitions, - holdDDLTime, - isSkewedStoreAsSubdir - ) + catalog.synchronized { + catalog.client.loadDynamicPartitions( + outputPath, + qualifiedTableName, + orderedPartitionSpec, + overwrite, + numDynamicPartitions, + holdDDLTime, + isSkewedStoreAsSubdir) + } } else { - db.loadPartition( + catalog.synchronized { + catalog.client.loadPartition( + outputPath, + qualifiedTableName, + orderedPartitionSpec, + overwrite, + holdDDLTime, + inheritTableSpecs, + isSkewedStoreAsSubdir) + } + } + } else { + catalog.synchronized { + catalog.client.loadTable( outputPath, qualifiedTableName, - orderedPartitionSpec, overwrite, - holdDDLTime, - inheritTableSpecs, - isSkewedStoreAsSubdir) + holdDDLTime) } - } else { - db.loadTable( - outputPath, - qualifiedTableName, - overwrite, - holdDDLTime) } // Invalidate the cache. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index 4345ffbf30f77..99dc58646ddd6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -58,12 +58,13 @@ case class DropTable( try { hiveContext.cacheManager.tryUncacheQuery(hiveContext.table(tableName)) } catch { - // This table's metadata is not in + // This table's metadata is not in Hive metastore (e.g. the table does not exist). case _: org.apache.hadoop.hive.ql.metadata.InvalidTableException => + case _: org.apache.spark.sql.catalyst.analysis.NoSuchTableException => // Other Throwables can be caused by users providing wrong parameters in OPTIONS // (e.g. invalid paths). We catch it and log a warning message. // Users should be able to drop such kinds of tables regardless if there is an error. - case e: Throwable => log.warn(s"${e.getMessage}") + case e: Throwable => log.warn(s"${e.getMessage}", e) } hiveContext.invalidateTable(tableName) hiveContext.runSqlHive(s"DROP TABLE $ifExistsClause$tableName") diff --git a/sql/hive/src/test/resources/golden/CTE feature #1-0-eedabbfe6ba8799f7b7782fb47a82768 b/sql/hive/src/test/resources/golden/CTE feature #1-0-eedabbfe6ba8799f7b7782fb47a82768 new file mode 100644 index 0000000000000..f6ba75da254ca --- /dev/null +++ b/sql/hive/src/test/resources/golden/CTE feature #1-0-eedabbfe6ba8799f7b7782fb47a82768 @@ -0,0 +1,3 @@ +5 +5 +5 diff --git a/sql/hive/src/test/resources/golden/CTE feature #2-0-aa03d104251f97e36bc52279cb9931c9 b/sql/hive/src/test/resources/golden/CTE feature #2-0-aa03d104251f97e36bc52279cb9931c9 new file mode 100644 index 0000000000000..ca7b591095e28 --- /dev/null +++ b/sql/hive/src/test/resources/golden/CTE feature #2-0-aa03d104251f97e36bc52279cb9931c9 @@ -0,0 +1,4 @@ +val_4 +val_5 +val_5 +val_5 diff --git a/sql/hive/src/test/resources/golden/CTE feature #3-0-b5d4bf3c0ee92b2fda0ca24f422383f2 b/sql/hive/src/test/resources/golden/CTE feature #3-0-b5d4bf3c0ee92b2fda0ca24f422383f2 new file mode 100644 index 0000000000000..b8626c4cff284 --- /dev/null +++ b/sql/hive/src/test/resources/golden/CTE feature #3-0-b5d4bf3c0ee92b2fda0ca24f422383f2 @@ -0,0 +1 @@ +4 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala index 968557c9c4686..d960a30e00738 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala @@ -136,7 +136,7 @@ class ErrorPositionSuite extends QueryTest with BeforeAndAfter { * @param query the query to analyze * @param token a unique token in the string that should be indicated by the exception */ - def positionTest(name: String, query: String, token: String) = { + def positionTest(name: String, query: String, token: String): Unit = { def parseTree = Try(quietly(HiveQl.dumpTree(HiveQl.getAst(query)))).getOrElse("") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala index 3181cfe40016c..2a7374cc172b7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala @@ -79,9 +79,9 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { Literal(Decimal(BigDecimal(123.123))) :: Literal(new java.sql.Timestamp(123123)) :: Literal(Array[Byte](1,2,3)) :: - Literal(Seq[Int](1,2,3), ArrayType(IntegerType)) :: - Literal(Map[Int, Int](1->2, 2->1), MapType(IntegerType, IntegerType)) :: - Literal(Row(1,2.0d,3.0f), + Literal.create(Seq[Int](1,2,3), ArrayType(IntegerType)) :: + Literal.create(Map[Int, Int](1->2, 2->1), MapType(IntegerType, IntegerType)) :: + Literal.create(Row(1,2.0d,3.0f), StructType(StructField("c1", IntegerType) :: StructField("c2", DoubleType) :: StructField("c3", FloatType) :: Nil)) :: @@ -116,21 +116,20 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { } def checkDataType(dt1: Seq[DataType], dt2: Seq[DataType]): Unit = { - dt1.zip(dt2).map { - case (dd1, dd2) => - assert(dd1.getClass === dd2.getClass) // DecimalType doesn't has the default precision info + dt1.zip(dt2).foreach { case (dd1, dd2) => + assert(dd1.getClass === dd2.getClass) // DecimalType doesn't has the default precision info } } def checkValues(row1: Seq[Any], row2: Seq[Any]): Unit = { - row1.zip(row2).map { - case (r1, r2) => checkValue(r1, r2) + row1.zip(row2).foreach { case (r1, r2) => + checkValue(r1, r2) } } def checkValues(row1: Seq[Any], row2: Row): Unit = { - row1.zip(row2.toSeq).map { - case (r1, r2) => checkValue(r1, r2) + row1.zip(row2.toSeq).foreach { case (r1, r2) => + checkValue(r1, r2) } } @@ -141,7 +140,7 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { assert(r1.compare(r2) === 0) case (r1: Array[Byte], r2: Array[Byte]) if r1 != null && r2 != null && r1.length == r2.length => - r1.zip(r2).map { case (b1, b2) => assert(b1 === b2) } + r1.zip(r2).foreach { case (b1, b2) => assert(b1 === b2) } case (r1, r2) => assert(r1 === r2) } } @@ -166,7 +165,8 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { val constantData = constantExprs.map(_.eval()) val constantNullData = constantData.map(_ => null) val constantWritableOIs = constantExprs.map(e => toWritableInspector(e.dataType)) - val constantNullWritableOIs = constantExprs.map(e => toInspector(Literal(null, e.dataType))) + val constantNullWritableOIs = + constantExprs.map(e => toInspector(Literal.create(null, e.dataType))) checkValues(constantData, constantData.zip(constantWritableOIs).map { case (d, oi) => unwrap(wrap(d, oi), oi) @@ -202,7 +202,8 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { case (t, idx) => StructField(s"c_$idx", t) }) - checkValues(row, unwrap(wrap(Row.fromSeq(row), toInspector(dt)), toInspector(dt)).asInstanceOf[Row]) + checkValues(row, + unwrap(wrap(Row.fromSeq(row), toInspector(dt)), toInspector(dt)).asInstanceOf[Row]) checkValue(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt))) } @@ -212,8 +213,10 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { val d = row(0) :: row(0) :: Nil checkValue(d, unwrap(wrap(d, toInspector(dt)), toInspector(dt))) checkValue(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt))) - checkValue(d, unwrap(wrap(d, toInspector(Literal(d, dt))), toInspector(Literal(d, dt)))) - checkValue(d, unwrap(wrap(null, toInspector(Literal(d, dt))), toInspector(Literal(d, dt)))) + checkValue(d, + unwrap(wrap(d, toInspector(Literal.create(d, dt))), toInspector(Literal.create(d, dt)))) + checkValue(d, + unwrap(wrap(null, toInspector(Literal.create(d, dt))), toInspector(Literal.create(d, dt)))) } test("wrap / unwrap Map Type") { @@ -222,7 +225,9 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { val d = Map(row(0) -> row(1)) checkValue(d, unwrap(wrap(d, toInspector(dt)), toInspector(dt))) checkValue(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt))) - checkValue(d, unwrap(wrap(d, toInspector(Literal(d, dt))), toInspector(Literal(d, dt)))) - checkValue(d, unwrap(wrap(null, toInspector(Literal(d, dt))), toInspector(Literal(d, dt)))) + checkValue(d, + unwrap(wrap(d, toInspector(Literal.create(d, dt))), toInspector(Literal.create(d, dt)))) + checkValue(d, + unwrap(wrap(null, toInspector(Literal.create(d, dt))), toInspector(Literal.create(d, dt)))) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index 8011952e0d535..ecb990e8aac91 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -115,11 +115,36 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { test("SPARK-4203:random partition directory order") { sql("CREATE TABLE tmp_table (key int, value string)") val tmpDir = Utils.createTempDir() - sql(s"CREATE TABLE table_with_partition(c1 string) PARTITIONED by (p1 string,p2 string,p3 string,p4 string,p5 string) location '${tmpDir.toURI.toString}' ") - sql("INSERT OVERWRITE TABLE table_with_partition partition (p1='a',p2='b',p3='c',p4='c',p5='1') SELECT 'blarr' FROM tmp_table") - sql("INSERT OVERWRITE TABLE table_with_partition partition (p1='a',p2='b',p3='c',p4='c',p5='2') SELECT 'blarr' FROM tmp_table") - sql("INSERT OVERWRITE TABLE table_with_partition partition (p1='a',p2='b',p3='c',p4='c',p5='3') SELECT 'blarr' FROM tmp_table") - sql("INSERT OVERWRITE TABLE table_with_partition partition (p1='a',p2='b',p3='c',p4='c',p5='4') SELECT 'blarr' FROM tmp_table") + sql( + s""" + |CREATE TABLE table_with_partition(c1 string) + |PARTITIONED by (p1 string,p2 string,p3 string,p4 string,p5 string) + |location '${tmpDir.toURI.toString}' + """.stripMargin) + sql( + """ + |INSERT OVERWRITE TABLE table_with_partition + |partition (p1='a',p2='b',p3='c',p4='c',p5='1') + |SELECT 'blarr' FROM tmp_table + """.stripMargin) + sql( + """ + |INSERT OVERWRITE TABLE table_with_partition + |partition (p1='a',p2='b',p3='c',p4='c',p5='2') + |SELECT 'blarr' FROM tmp_table + """.stripMargin) + sql( + """ + |INSERT OVERWRITE TABLE table_with_partition + |partition (p1='a',p2='b',p3='c',p4='c',p5='3') + |SELECT 'blarr' FROM tmp_table + """.stripMargin) + sql( + """ + |INSERT OVERWRITE TABLE table_with_partition + |partition (p1='a',p2='b',p3='c',p4='c',p5='4') + |SELECT 'blarr' FROM tmp_table + """.stripMargin) def listFolders(path: File, acc: List[String]): List[List[String]] = { val dir = path.listFiles() val folders = dir.filter(_.isDirectory).toList @@ -196,34 +221,42 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { testData.registerTempTable("testData") val testDatawithNull = TestHive.sparkContext.parallelize( - (1 to 10).map(i => ThreeCloumntable(i, i.toString,null))).toDF() + (1 to 10).map(i => ThreeCloumntable(i, i.toString, null))).toDF() val tmpDir = Utils.createTempDir() - sql(s"CREATE TABLE table_with_partition(key int,value string) PARTITIONED by (ds string) location '${tmpDir.toURI.toString}' ") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='1') SELECT key,value FROM testData") + sql( + s""" + |CREATE TABLE table_with_partition(key int,value string) + |PARTITIONED by (ds string) location '${tmpDir.toURI.toString}' + """.stripMargin) + sql( + """ + |INSERT OVERWRITE TABLE table_with_partition + |partition (ds='1') SELECT key,value FROM testData + """.stripMargin) // test schema the same between partition and table sql("ALTER TABLE table_with_partition CHANGE COLUMN key key BIGINT") checkAnswer(sql("select key,value from table_with_partition where ds='1' "), - testData.collect.toSeq + testData.collect().toSeq ) // test difference type of field sql("ALTER TABLE table_with_partition CHANGE COLUMN key key BIGINT") checkAnswer(sql("select key,value from table_with_partition where ds='1' "), - testData.collect.toSeq + testData.collect().toSeq ) // add column to table sql("ALTER TABLE table_with_partition ADD COLUMNS(key1 string)") checkAnswer(sql("select key,value,key1 from table_with_partition where ds='1' "), - testDatawithNull.collect.toSeq + testDatawithNull.collect().toSeq ) // change column name to table sql("ALTER TABLE table_with_partition CHANGE COLUMN key keynew BIGINT") checkAnswer(sql("select keynew,value from table_with_partition where ds='1' "), - testData.collect.toSeq + testData.collect().toSeq ) sql("DROP TABLE table_with_partition") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index e5ad0bf552073..e09c702c8969e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -25,6 +25,8 @@ import org.scalatest.BeforeAndAfterEach import org.apache.commons.io.FileUtils import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.metastore.TableType +import org.apache.hadoop.hive.ql.metadata.Table import org.apache.hadoop.mapred.InvalidInputException import org.apache.spark.sql._ @@ -682,6 +684,27 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { assert(schema === actualSchema) } + test("SPARK-6655 still support a schema stored in spark.sql.sources.schema") { + val tableName = "spark6655" + val schema = StructType(StructField("int", IntegerType, true) :: Nil) + // Manually create the metadata in metastore. + val tbl = new Table("default", tableName) + tbl.setProperty("spark.sql.sources.provider", "json") + tbl.setProperty("spark.sql.sources.schema", schema.json) + tbl.setProperty("EXTERNAL", "FALSE") + tbl.setTableType(TableType.MANAGED_TABLE) + tbl.setSerdeParam("path", catalog.hiveDefaultTableFilePath(tableName)) + catalog.synchronized { + catalog.client.createTable(tbl) + } + + invalidateTable(tableName) + val actualSchema = table(tableName).schema + assert(schema === actualSchema) + sql(s"drop table $tableName") + } + + test("insert into a table") { def createDF(from: Int, to: Int): DataFrame = createDataFrame((from to to).map(i => Tuple2(i, s"str$i"))).toDF("c1", "c2") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala new file mode 100644 index 0000000000000..a787fa5546e76 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import com.google.common.io.Files + +import org.apache.spark.sql.{QueryTest, _} +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.util.Utils + + +class QueryPartitionSuite extends QueryTest { + import org.apache.spark.sql.hive.test.TestHive.implicits._ + + test("SPARK-5068: query data when path doesn't exists"){ + val testData = TestHive.sparkContext.parallelize( + (1 to 10).map(i => TestData(i, i.toString))).toDF() + testData.registerTempTable("testData") + + val tmpDir = Files.createTempDir() + // create the table for test + sql(s"CREATE TABLE table_with_partition(key int,value string) " + + s"PARTITIONED by (ds string) location '${tmpDir.toURI.toString}' ") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='1') " + + "SELECT key,value FROM testData") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='2') " + + "SELECT key,value FROM testData") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='3') " + + "SELECT key,value FROM testData") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='4') " + + "SELECT key,value FROM testData") + + // test for the exist path + checkAnswer(sql("select key,value from table_with_partition"), + testData.toSchemaRDD.collect ++ testData.toSchemaRDD.collect + ++ testData.toSchemaRDD.collect ++ testData.toSchemaRDD.collect) + + // delete the path of one partition + val folders = tmpDir.listFiles.filter(_.isDirectory) + Utils.deleteRecursively(folders(0)) + + // test for after delete the path + checkAnswer(sql("select key,value from table_with_partition"), + testData.toSchemaRDD.collect ++ testData.toSchemaRDD.collect + ++ testData.toSchemaRDD.collect) + + sql("DROP TABLE table_with_partition") + sql("DROP TABLE createAndInsertTest") + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 1e05a024b8807..00a69de9e4262 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -120,7 +120,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { // Try to analyze a temp table sql("""SELECT * FROM src""").registerTempTable("tempTable") - intercept[NotImplementedError] { + intercept[UnsupportedOperationException] { analyze("tempTable") } catalog.unregisterTable(Seq("tempTable")) @@ -142,7 +142,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { after: () => Unit, query: String, expectedAnswer: Seq[Row], - ct: ClassTag[_]) = { + ct: ClassTag[_]): Unit = { before() var df = sql(query) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BigDataBenchmarkSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BigDataBenchmarkSuite.scala index 42a82c1fbf5c7..a3f5921a0cb23 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BigDataBenchmarkSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BigDataBenchmarkSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.hive.test.TestHive._ class BigDataBenchmarkSuite extends HiveComparisonTest { val testDataDirectory = new File("target" + File.separator + "big-data-benchmark-testdata") + val userVisitPath = new File(testDataDirectory, "uservisits").getCanonicalPath val testTables = Seq( TestTable( "rankings", @@ -63,7 +64,7 @@ class BigDataBenchmarkSuite extends HiveComparisonTest { | searchWord STRING, | duration INT) | ROW FORMAT DELIMITED FIELDS TERMINATED BY "," - | STORED AS TEXTFILE LOCATION "${new File(testDataDirectory, "uservisits").getCanonicalPath}" + | STORED AS TEXTFILE LOCATION "$userVisitPath" """.stripMargin.cmd), TestTable( "documents", @@ -83,7 +84,10 @@ class BigDataBenchmarkSuite extends HiveComparisonTest { "SELECT pageURL, pageRank FROM rankings WHERE pageRank > 1") createQueryTest("query2", - "SELECT SUBSTR(sourceIP, 1, 10), SUM(adRevenue) FROM uservisits GROUP BY SUBSTR(sourceIP, 1, 10)") + """ + |SELECT SUBSTR(sourceIP, 1, 10), SUM(adRevenue) FROM uservisits + |GROUP BY SUBSTR(sourceIP, 1, 10) + """.stripMargin) createQueryTest("query3", """ @@ -113,8 +117,8 @@ class BigDataBenchmarkSuite extends HiveComparisonTest { |CREATE TABLE url_counts_total AS | SELECT SUM(count) AS totalCount, destpage | FROM url_counts_partial GROUP BY destpage - |-- The following queries run, but generate different results in HIVE likely because the UDF is not deterministic - |-- given different input splits. + |-- The following queries run, but generate different results in HIVE + |-- likely because the UDF is not deterministic given different input splits. |-- SELECT CAST(SUM(count) AS INT) FROM url_counts_partial |-- SELECT COUNT(*) FROM url_counts_partial |-- SELECT * FROM url_counts_partial diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 8f3285242091c..027056d4b865f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -138,7 +138,7 @@ abstract class HiveComparisonTest case _ => plan.children.iterator.exists(isSorted) } - val orderedAnswer = hiveQuery.logical match { + val orderedAnswer = hiveQuery.analyzed match { // Clean out non-deterministic time schema info. // Hack: Hive simply prints the result of a SET command to screen, // and does not return it as a query answer. @@ -255,8 +255,9 @@ abstract class HiveComparisonTest .filterNot(_ contains "hive.outerjoin.supports.filters") .filterNot(_ contains "hive.exec.post.hooks") - if (allQueries != queryList) + if (allQueries != queryList) { logWarning(s"Simplifications made on unsupported operations for test $testCaseName") + } lazy val consoleTestCase = { val quotes = "\"\"\"" @@ -299,19 +300,22 @@ abstract class HiveComparisonTest val hiveQueries = queryList.map(new TestHive.HiveQLQueryExecution(_)) // Make sure we can at least parse everything before attempting hive execution. - hiveQueries.foreach(_.logical) + hiveQueries.foreach(_.analyzed) val computedResults = (queryList.zipWithIndex, hiveQueries, hiveCacheFiles).zipped.map { case ((queryString, i), hiveQuery, cachedAnswerFile)=> try { // Hooks often break the harness and don't really affect our test anyway, don't // even try running them. - if (installHooksCommand.findAllMatchIn(queryString).nonEmpty) + if (installHooksCommand.findAllMatchIn(queryString).nonEmpty) { sys.error("hive exec hooks not supported for tests.") + } - logWarning(s"Running query ${i+1}/${queryList.size} with hive.") + logWarning(s"Running query ${i + 1}/${queryList.size} with hive.") // Analyze the query with catalyst to ensure test tables are loaded. val answer = hiveQuery.analyzed match { - case _: ExplainCommand => Nil // No need to execute EXPLAIN queries as we don't check the output. + case _: ExplainCommand => + // No need to execute EXPLAIN queries as we don't check the output. + Nil case _ => TestHive.runSqlHive(queryString) } @@ -394,21 +398,24 @@ abstract class HiveComparisonTest case tf: org.scalatest.exceptions.TestFailedException => throw tf case originalException: Exception => if (System.getProperty("spark.hive.canarytest") != null) { - // When we encounter an error we check to see if the environment is still okay by running a simple query. - // If this fails then we halt testing since something must have gone seriously wrong. + // When we encounter an error we check to see if the environment is still + // okay by running a simple query. If this fails then we halt testing since + // something must have gone seriously wrong. try { new TestHive.HiveQLQueryExecution("SELECT key FROM src").stringResult() TestHive.runSqlHive("SELECT key FROM src") } catch { case e: Exception => - logError(s"FATAL ERROR: Canary query threw $e This implies that the testing environment has likely been corrupted.") - // The testing setup traps exits so wait here for a long time so the developer can see when things started - // to go wrong. + logError(s"FATAL ERROR: Canary query threw $e This implies that the " + + "testing environment has likely been corrupted.") + // The testing setup traps exits so wait here for a long time so the developer + // can see when things started to go wrong. Thread.sleep(1000000) } } - // If the canary query didn't fail then the environment is still okay, so just throw the original exception. + // If the canary query didn't fail then the environment is still okay, + // so just throw the original exception. throw originalException } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala index 02518d516261b..f7b37dae0a5f3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala @@ -24,8 +24,9 @@ import org.apache.spark.sql.catalyst.util._ /** * A framework for running the query tests that are listed as a set of text files. * - * TestSuites that derive from this class must provide a map of testCaseName -> testCaseFiles that should be included. - * Additionally, there is support for whitelisting and blacklisting tests as development progresses. + * TestSuites that derive from this class must provide a map of testCaseName -> testCaseFiles + * that should be included. Additionally, there is support for whitelisting and blacklisting + * tests as development progresses. */ abstract class HiveQueryFileTest extends HiveComparisonTest { /** A list of tests deemed out of scope and thus completely disregarded */ @@ -54,15 +55,17 @@ abstract class HiveQueryFileTest extends HiveComparisonTest { case (testCaseName, testCaseFile) => if (blackList.map(_.r.pattern.matcher(testCaseName).matches()).reduceLeft(_||_)) { logDebug(s"Blacklisted test skipped $testCaseName") - } else if (realWhiteList.map(_.r.pattern.matcher(testCaseName).matches()).reduceLeft(_||_) || runAll) { + } else if (realWhiteList.map(_.r.pattern.matcher(testCaseName).matches()).reduceLeft(_||_) || + runAll) { // Build a test case and submit it to scala test framework... val queriesString = fileToString(testCaseFile) createQueryTest(testCaseName, queriesString) } else { // Only output warnings for the built in whitelist as this clutters the output when the user // trying to execute a single test from the commandline. - if(System.getProperty(whiteListProperty) == null && !runAll) + if (System.getProperty(whiteListProperty) == null && !runAll) { ignore(testCaseName) {} + } } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index de140fc72a2c3..1222fbabd8b33 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -37,7 +37,8 @@ import org.apache.spark.sql.hive.test.TestHive._ case class TestData(a: Int, b: String) /** - * A set of test cases expressed in Hive QL that are not covered by the tests included in the hive distribution. + * A set of test cases expressed in Hive QL that are not covered by the tests + * included in the hive distribution. */ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { private val originalTimeZone = TimeZone.getDefault @@ -237,7 +238,8 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } createQueryTest("modulus", - "SELECT 11 % 10, IF((101.1 % 100.0) BETWEEN 1.01 AND 1.11, \"true\", \"false\"), (101 / 2) % 10 FROM src LIMIT 1") + "SELECT 11 % 10, IF((101.1 % 100.0) BETWEEN 1.01 AND 1.11, \"true\", \"false\"), " + + "(101 / 2) % 10 FROM src LIMIT 1") test("Query expressed in SQL") { setConf("spark.sql.dialect", "sql") @@ -309,7 +311,8 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { "SELECT * FROM src a JOIN src b ON a.key = b.key") createQueryTest("small.cartesian", - "SELECT a.key, b.key FROM (SELECT key FROM src WHERE key < 1) a JOIN (SELECT key FROM src WHERE key = 2) b") + "SELECT a.key, b.key FROM (SELECT key FROM src WHERE key < 1) a JOIN " + + "(SELECT key FROM src WHERE key = 2) b") createQueryTest("length.udf", "SELECT length(\"test\") FROM src LIMIT 1") @@ -457,6 +460,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { createQueryTest("lateral view3", "FROM src SELECT key, D.* lateral view explode(array(key+3, key+4)) D as CX") + // scalastyle:off createQueryTest("lateral view4", """ |create table src_lv1 (key string, value string); @@ -466,6 +470,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { |insert overwrite table src_lv1 SELECT key, D.* lateral view explode(array(key+3, key+4)) D as CX |insert overwrite table src_lv2 SELECT key, D.* lateral view explode(array(key+3, key+4)) D as CX """.stripMargin) + // scalastyle:on createQueryTest("lateral view5", "FROM src SELECT explode(array(key+3, key+4))") @@ -537,6 +542,21 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { createQueryTest("select null from table", "SELECT null FROM src LIMIT 1") + createQueryTest("CTE feature #1", + "with q1 as (select key from src) select * from q1 where key = 5") + + createQueryTest("CTE feature #2", + """with q1 as (select * from src where key= 5), + |q2 as (select * from src s2 where key = 4) + |select value from q1 union all select value from q2 + """.stripMargin) + + createQueryTest("CTE feature #3", + """with q1 as (select key from src) + |from q1 + |select * where key = 4 + """.stripMargin) + test("predicates contains an empty AttributeSet() references") { sql( """ @@ -584,7 +604,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } } - def isExplanation(result: DataFrame) = { + def isExplanation(result: DataFrame): Boolean = { val explanation = result.select('plan).collect().map { case Row(plan: String) => plan } explanation.contains("== Physical Plan ==") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala index f4440e5b7846a..8ad3627504229 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala @@ -25,7 +25,8 @@ case class Nested(a: Int, B: Int) case class Data(a: Int, B: Int, n: Nested, nestedArray: Seq[Nested]) /** - * A set of test cases expressed in Hive QL that are not covered by the tests included in the hive distribution. + * A set of test cases expressed in Hive QL that are not covered by the tests + * included in the hive distribution. */ class HiveResolutionSuite extends HiveComparisonTest { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala index 7486bfa82b00b..d05e11fcf281b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala @@ -26,8 +26,9 @@ import org.apache.spark.sql.hive.test.TestHive */ class HiveSerDeSuite extends HiveComparisonTest with BeforeAndAfterAll { - override def beforeAll() = { + override def beforeAll(): Unit = { TestHive.cacheTables = false + super.beforeAll() } createQueryTest( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala index ab0e0443c7faa..f0f04f8c73fb4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala @@ -35,8 +35,10 @@ class HiveTypeCoercionSuite extends HiveComparisonTest { val nullVal = "null" baseTypes.init.foreach { i => - createQueryTest(s"case when then $i else $nullVal end ", s"SELECT case when true then $i else $nullVal end FROM src limit 1") - createQueryTest(s"case when then $nullVal else $i end ", s"SELECT case when true then $nullVal else $i end FROM src limit 1") + createQueryTest(s"case when then $i else $nullVal end ", + s"SELECT case when true then $i else $nullVal end FROM src limit 1") + createQueryTest(s"case when then $nullVal else $i end ", + s"SELECT case when true then $nullVal else $i end FROM src limit 1") } test("[SPARK-2210] boolean cast on boolean value should be removed") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala index d7c5d1a25a82b..7f49eac490572 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala @@ -123,9 +123,10 @@ class HiveUdfSuite extends QueryTest { IntegerCaseClass(1) :: IntegerCaseClass(2) :: Nil).toDF() testData.registerTempTable("integerTable") - sql(s"CREATE TEMPORARY FUNCTION testUDFIntegerToString AS '${classOf[UDFIntegerToString].getName}'") + val udfName = classOf[UDFIntegerToString].getName + sql(s"CREATE TEMPORARY FUNCTION testUDFIntegerToString AS '$udfName'") checkAnswer( - sql("SELECT testUDFIntegerToString(i) FROM integerTable"), //.collect(), + sql("SELECT testUDFIntegerToString(i) FROM integerTable"), Seq(Row("1"), Row("2"))) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFIntegerToString") @@ -141,7 +142,7 @@ class HiveUdfSuite extends QueryTest { sql(s"CREATE TEMPORARY FUNCTION testUDFListListInt AS '${classOf[UDFListListInt].getName}'") checkAnswer( - sql("SELECT testUDFListListInt(lli) FROM listListIntTable"), //.collect(), + sql("SELECT testUDFListListInt(lli) FROM listListIntTable"), Seq(Row(0), Row(2), Row(13))) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListListInt") @@ -156,7 +157,7 @@ class HiveUdfSuite extends QueryTest { sql(s"CREATE TEMPORARY FUNCTION testUDFListString AS '${classOf[UDFListString].getName}'") checkAnswer( - sql("SELECT testUDFListString(l) FROM listStringTable"), //.collect(), + sql("SELECT testUDFListString(l) FROM listStringTable"), Seq(Row("a,b,c"), Row("d,e"))) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListString") @@ -170,7 +171,7 @@ class HiveUdfSuite extends QueryTest { sql(s"CREATE TEMPORARY FUNCTION testStringStringUdf AS '${classOf[UDFStringString].getName}'") checkAnswer( - sql("SELECT testStringStringUdf(\"hello\", s) FROM stringTable"), //.collect(), + sql("SELECT testStringStringUdf(\"hello\", s) FROM stringTable"), Seq(Row("hello world"), Row("hello goodbye"))) sql("DROP TEMPORARY FUNCTION IF EXISTS testStringStringUdf") @@ -187,7 +188,7 @@ class HiveUdfSuite extends QueryTest { sql(s"CREATE TEMPORARY FUNCTION testUDFTwoListList AS '${classOf[UDFTwoListList].getName}'") checkAnswer( - sql("SELECT testUDFTwoListList(lli, lli) FROM TwoListTable"), //.collect(), + sql("SELECT testUDFTwoListList(lli, lli) FROM TwoListTable"), Seq(Row("0, 0"), Row("2, 2"), Row("13, 13"))) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFTwoListList") @@ -247,7 +248,8 @@ class PairUdf extends GenericUDF { override def initialize(p1: Array[ObjectInspector]): ObjectInspector = ObjectInspectorFactory.getStandardStructObjectInspector( Seq("id", "value"), - Seq(PrimitiveObjectInspectorFactory.javaIntObjectInspector, PrimitiveObjectInspectorFactory.javaIntObjectInspector) + Seq(PrimitiveObjectInspectorFactory.javaIntObjectInspector, + PrimitiveObjectInspectorFactory.javaIntObjectInspector) ) override def evaluate(args: Array[DeferredObject]): AnyRef = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala index 8474d850c9c6c..067b577f1560e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala @@ -143,7 +143,7 @@ class PruningSuite extends HiveComparisonTest with BeforeAndAfter { sql: String, expectedOutputColumns: Seq[String], expectedScannedColumns: Seq[String], - expectedPartValues: Seq[Seq[String]]) = { + expectedPartValues: Seq[Seq[String]]): Unit = { test(s"$testCaseName - pruning test") { val plan = new TestHive.HiveQLQueryExecution(sql).executedPlan val actualOutputColumns = plan.output.map(_.name) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 2065f0d60d92f..47b4cb9ca61ff 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -34,12 +34,95 @@ case class Nested3(f3: Int) case class NestedArray2(b: Seq[Int]) case class NestedArray1(a: NestedArray2) +case class Order( + id: Int, + make: String, + `type`: String, + price: Int, + pdate: String, + customer: String, + city: String, + state: String, + month: Int) + /** * A collection of hive query tests where we generate the answers ourselves instead of depending on * Hive to generate them (in contrast to HiveQuerySuite). Often this is because the query is * valid, but Hive currently cannot execute it. */ class SQLQuerySuite extends QueryTest { + test("SPARK-6835: udtf in lateral view") { + val df = Seq((1, 1)).toDF("c1", "c2") + df.registerTempTable("table1") + val query = sql("SELECT c1, v FROM table1 LATERAL VIEW stack(3, 1, c1 + 1, c1 + 2) d AS v") + checkAnswer(query, Row(1, 1) :: Row(1, 2) :: Row(1, 3) :: Nil) + } + + test("SPARK-6851: Self-joined converted parquet tables") { + val orders = Seq( + Order(1, "Atlas", "MTB", 234, "2015-01-07", "John D", "Pacifica", "CA", 20151), + Order(3, "Swift", "MTB", 285, "2015-01-17", "John S", "Redwood City", "CA", 20151), + Order(4, "Atlas", "Hybrid", 303, "2015-01-23", "Jones S", "San Mateo", "CA", 20151), + Order(7, "Next", "MTB", 356, "2015-01-04", "Jane D", "Daly City", "CA", 20151), + Order(10, "Next", "YFlikr", 187, "2015-01-09", "John D", "Fremont", "CA", 20151), + Order(11, "Swift", "YFlikr", 187, "2015-01-23", "John D", "Hayward", "CA", 20151), + Order(2, "Next", "Hybrid", 324, "2015-02-03", "Jane D", "Daly City", "CA", 20152), + Order(5, "Next", "Street", 187, "2015-02-08", "John D", "Fremont", "CA", 20152), + Order(6, "Atlas", "Street", 154, "2015-02-09", "John D", "Pacifica", "CA", 20152), + Order(8, "Swift", "Hybrid", 485, "2015-02-19", "John S", "Redwood City", "CA", 20152), + Order(9, "Atlas", "Split", 303, "2015-02-28", "Jones S", "San Mateo", "CA", 20152)) + + val orderUpdates = Seq( + Order(1, "Atlas", "MTB", 434, "2015-01-07", "John D", "Pacifica", "CA", 20151), + Order(11, "Swift", "YFlikr", 137, "2015-01-23", "John D", "Hayward", "CA", 20151)) + + orders.toDF.registerTempTable("orders1") + orderUpdates.toDF.registerTempTable("orderupdates1") + + sql( + """CREATE TABLE orders( + | id INT, + | make String, + | type String, + | price INT, + | pdate String, + | customer String, + | city String) + |PARTITIONED BY (state STRING, month INT) + |STORED AS PARQUET + """.stripMargin) + + sql( + """CREATE TABLE orderupdates( + | id INT, + | make String, + | type String, + | price INT, + | pdate String, + | customer String, + | city String) + |PARTITIONED BY (state STRING, month INT) + |STORED AS PARQUET + """.stripMargin) + + sql("set hive.exec.dynamic.partition.mode=nonstrict") + sql("INSERT INTO TABLE orders PARTITION(state, month) SELECT * FROM orders1") + sql("INSERT INTO TABLE orderupdates PARTITION(state, month) SELECT * FROM orderupdates1") + + checkAnswer( + sql( + """ + |select orders.state, orders.month + |from orders + |join ( + | select distinct orders.state,orders.month + | from orders + | join orderupdates + | on orderupdates.id = orders.id) ao + | on ao.state = orders.state and ao.month = orders.month + """.stripMargin), + (1 to 6).map(_ => Row("CA", 20151))) + } test("SPARK-5371: union with null and sum") { val df = Seq((1, 1)).toDF("c1", "c2") @@ -422,7 +505,7 @@ class SQLQuerySuite extends QueryTest { } test("resolve udtf with single alias") { - val rdd = sparkContext.makeRDD((1 to 5).map(i => s"""{"a":[$i, ${i+1}]}""")) + val rdd = sparkContext.makeRDD((1 to 5).map(i => s"""{"a":[$i, ${i + 1}]}""")) jsonRDD(rdd).registerTempTable("data") val df = sql("SELECT explode(a) AS val FROM data") val col = df("val") @@ -435,7 +518,7 @@ class SQLQuerySuite extends QueryTest { // is not in a valid state (cannot be executed). Because of this bug, the analysis rule of // PreInsertionCasts will actually start to work before ImplicitGenerate and then // generates an invalid query plan. - val rdd = sparkContext.makeRDD((1 to 5).map(i => s"""{"a":[$i, ${i+1}]}""")) + val rdd = sparkContext.makeRDD((1 to 5).map(i => s"""{"a":[$i, ${i + 1}]}""")) jsonRDD(rdd).registerTempTable("data") val originalConf = getConf("spark.sql.hive.convertCTAS", "false") setConf("spark.sql.hive.convertCTAS", "false") @@ -468,4 +551,14 @@ class SQLQuerySuite extends QueryTest { sql(s"DROP TABLE $tableName") } } + + test("SPARK-5203 union with different decimal precision") { + Seq.empty[(Decimal, Decimal)] + .toDF("d1", "d2") + .select($"d1".cast(DecimalType(10, 15)).as("d")) + .registerTempTable("dn") + + sql("select d from dn union all select d * 2 from dn") + .queryExecution.analyzed + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index 432d65a874518..d5dd0bf58e702 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -1,4 +1,3 @@ - /* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with @@ -26,8 +25,10 @@ import org.apache.spark.sql.{QueryTest, SQLConf, SaveMode} import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.execution.{ExecutedCommand, PhysicalRDD} import org.apache.spark.sql.hive.execution.HiveTableScan +import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.sql.json.JSONRelation import org.apache.spark.sql.sources.{InsertIntoDataSource, LogicalRelation} import org.apache.spark.sql.parquet.{ParquetRelation2, ParquetTableScan} import org.apache.spark.sql.SaveMode @@ -390,6 +391,114 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { sql("DROP TABLE ms_convert") } + + test("Caching converted data source Parquet Relations") { + def checkCached(tableIdentifer: catalog.QualifiedTableName): Unit = { + // Converted test_parquet should be cached. + catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) match { + case null => fail("Converted test_parquet should be cached in the cache.") + case logical @ LogicalRelation(parquetRelation: ParquetRelation2) => // OK + case other => + fail( + "The cached test_parquet should be a Parquet Relation. " + + s"However, $other is returned form the cache.") + } + } + + sql("DROP TABLE IF EXISTS test_insert_parquet") + sql("DROP TABLE IF EXISTS test_parquet_partitioned_cache_test") + + sql( + """ + |create table test_insert_parquet + |( + | intField INT, + | stringField STRING + |) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + """.stripMargin) + + var tableIdentifer = catalog.QualifiedTableName("default", "test_insert_parquet") + + // First, make sure the converted test_parquet is not cached. + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) + // Table lookup will make the table cached. + table("test_insert_parquet") + checkCached(tableIdentifer) + // For insert into non-partitioned table, we will do the conversion, + // so the converted test_insert_parquet should be cached. + invalidateTable("test_insert_parquet") + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) + sql( + """ + |INSERT INTO TABLE test_insert_parquet + |select a, b from jt + """.stripMargin) + checkCached(tableIdentifer) + // Make sure we can read the data. + checkAnswer( + sql("select * from test_insert_parquet"), + sql("select a, b from jt").collect()) + // Invalidate the cache. + invalidateTable("test_insert_parquet") + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) + + // Create a partitioned table. + sql( + """ + |create table test_parquet_partitioned_cache_test + |( + | intField INT, + | stringField STRING + |) + |PARTITIONED BY (date string) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + """.stripMargin) + + tableIdentifer = catalog.QualifiedTableName("default", "test_parquet_partitioned_cache_test") + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) + sql( + """ + |INSERT INTO TABLE test_parquet_partitioned_cache_test + |PARTITION (date='2015-04-01') + |select a, b from jt + """.stripMargin) + // Right now, insert into a partitioned Parquet is not supported in data source Parquet. + // So, we expect it is not cached. + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) + sql( + """ + |INSERT INTO TABLE test_parquet_partitioned_cache_test + |PARTITION (date='2015-04-02') + |select a, b from jt + """.stripMargin) + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) + + // Make sure we can cache the partitioned table. + table("test_parquet_partitioned_cache_test") + checkCached(tableIdentifer) + // Make sure we can read the data. + checkAnswer( + sql("select STRINGField, date, intField from test_parquet_partitioned_cache_test"), + sql( + """ + |select b, '2015-04-01', a FROM jt + |UNION ALL + |select b, '2015-04-02', a FROM jt + """.stripMargin).collect()) + + invalidateTable("test_parquet_partitioned_cache_test") + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) + + sql("DROP TABLE test_insert_parquet") + sql("DROP TABLE test_parquet_partitioned_cache_test") + } } class ParquetDataSourceOffMetastoreSuite extends ParquetMetastoreSuiteBase { @@ -578,6 +687,22 @@ class ParquetDataSourceOnSourceSuite extends ParquetSourceSuiteBase { sql("DROP TABLE alwaysNullable") } + + test("Aggregation attribute names can't contain special chars \" ,;{}()\\n\\t=\"") { + val tempDir = Utils.createTempDir() + val filePath = new File(tempDir, "testParquet").getCanonicalPath + val filePath2 = new File(tempDir, "testParquet2").getCanonicalPath + + val df = Seq(1,2,3).map(i => (i, i.toString)).toDF("int", "str") + val df2 = df.as('x).join(df.as('y), $"x.str" === $"y.str").groupBy("y.str").max("y.int") + intercept[RuntimeException](df2.saveAsParquetFile(filePath)) + + val df3 = df2.toDF("str", "max_int") + df3.saveAsParquetFile(filePath2) + val df4 = parquetFile(filePath2) + checkAnswer(df4, Row("1", 1) :: Row("2", 2) :: Row("3", 3) :: Nil) + assert(df4.columns === Array("str", "max_int")) + } } class ParquetDataSourceOffSourceSuite extends ParquetSourceSuiteBase { @@ -761,7 +886,11 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll test(s"SPARK-5775 read struct from $table") { checkAnswer( - sql(s"SELECT p, structField.intStructField, structField.stringStructField FROM $table WHERE p = 1"), + sql( + s""" + |SELECT p, structField.intStructField, structField.stringStructField + |FROM $table WHERE p = 1 + """.stripMargin), (1 to 10).map(i => Row(1, i, f"${i}_string"))) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index f73b463d07779..0a50485118588 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -26,7 +26,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.{SparkException, SparkConf, Logging} import org.apache.spark.io.CompressionCodec -import org.apache.spark.util.MetadataCleaner +import org.apache.spark.util.{MetadataCleaner, Utils} import org.apache.spark.streaming.scheduler.JobGenerator @@ -139,8 +139,11 @@ class CheckpointWriter( // Write checkpoint to temp file fs.delete(tempFile, true) // just in case it exists val fos = fs.create(tempFile) - fos.write(bytes) - fos.close() + Utils.tryWithSafeFinally { + fos.write(bytes) + } { + fos.close() + } // If the checkpoint file exists, back it up // If the backup exists as well, just delete it, otherwise rename will fail @@ -187,9 +190,11 @@ class CheckpointWriter( val bos = new ByteArrayOutputStream() val zos = compressionCodec.compressedOutputStream(bos) val oos = new ObjectOutputStream(zos) - oos.writeObject(checkpoint) - oos.close() - bos.close() + Utils.tryWithSafeFinally { + oos.writeObject(checkpoint) + } { + oos.close() + } try { executor.execute(new CheckpointWriteHandler( checkpoint.checkpointTime, bos.toByteArray, clearCheckpointDataLater)) @@ -234,7 +239,7 @@ object CheckpointReader extends Logging { val checkpointPath = new Path(checkpointDir) // TODO(rxin): Why is this a def?! - def fs = checkpointPath.getFileSystem(hadoopConf) + def fs: FileSystem = checkpointPath.getFileSystem(hadoopConf) // Try to find the checkpoint files val checkpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, fs).reverse @@ -248,18 +253,24 @@ object CheckpointReader extends Logging { checkpointFiles.foreach(file => { logInfo("Attempting to load checkpoint from file " + file) try { - val fis = fs.open(file) - // ObjectInputStream uses the last defined user-defined class loader in the stack - // to find classes, which maybe the wrong class loader. Hence, a inherited version - // of ObjectInputStream is used to explicitly use the current thread's default class - // loader to find and load classes. This is a well know Java issue and has popped up - // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627) - val zis = compressionCodec.compressedInputStream(fis) - val ois = new ObjectInputStreamWithLoader(zis, - Thread.currentThread().getContextClassLoader) - val cp = ois.readObject.asInstanceOf[Checkpoint] - ois.close() - fs.close() + var ois: ObjectInputStreamWithLoader = null + var cp: Checkpoint = null + Utils.tryWithSafeFinally { + val fis = fs.open(file) + // ObjectInputStream uses the last defined user-defined class loader in the stack + // to find classes, which maybe the wrong class loader. Hence, a inherited version + // of ObjectInputStream is used to explicitly use the current thread's default class + // loader to find and load classes. This is a well know Java issue and has popped up + // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627) + val zis = compressionCodec.compressedInputStream(fis) + ois = new ObjectInputStreamWithLoader(zis, + Thread.currentThread().getContextClassLoader) + cp = ois.readObject.asInstanceOf[Checkpoint] + } { + if (ois != null) { + ois.close() + } + } cp.validate() logInfo("Checkpoint successfully loaded from file " + file) logInfo("Checkpoint was generated at time " + cp.checkpointTime) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala index 73030e15c5661..808dcc174cf9a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala @@ -169,7 +169,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T */ def flatMap[U](f: FlatMapFunction[T, U]): JavaDStream[U] = { import scala.collection.JavaConverters._ - def fn = (x: T) => f.call(x).asScala + def fn: (T) => Iterable[U] = (x: T) => f.call(x).asScala new JavaDStream(dstream.flatMap(fn)(fakeClassTag[U]))(fakeClassTag[U]) } @@ -179,7 +179,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T */ def flatMapToPair[K2, V2](f: PairFlatMapFunction[T, K2, V2]): JavaPairDStream[K2, V2] = { import scala.collection.JavaConverters._ - def fn = (x: T) => f.call(x).asScala + def fn: (T) => Iterable[(K2, V2)] = (x: T) => f.call(x).asScala def cm: ClassTag[(K2, V2)] = fakeClassTag new JavaPairDStream(dstream.flatMap(fn)(cm))(fakeClassTag[K2], fakeClassTag[V2]) } @@ -190,7 +190,9 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * of the RDD. */ def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U]): JavaDStream[U] = { - def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + def fn: (Iterator[T]) => Iterator[U] = { + (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + } new JavaDStream(dstream.mapPartitions(fn)(fakeClassTag[U]))(fakeClassTag[U]) } @@ -201,7 +203,9 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T */ def mapPartitionsToPair[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2]) : JavaPairDStream[K2, V2] = { - def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + def fn: (Iterator[T]) => Iterator[(K2, V2)] = { + (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + } new JavaPairDStream(dstream.mapPartitions(fn))(fakeClassTag[K2], fakeClassTag[V2]) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala index f94f2d0e8bd31..93baad19e3ee1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala @@ -526,7 +526,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( */ def flatMapValues[U](f: JFunction[V, java.lang.Iterable[U]]): JavaPairDStream[K, U] = { import scala.collection.JavaConverters._ - def fn = (x: V) => f.apply(x).asScala + def fn: (V) => Iterable[U] = (x: V) => f.apply(x).asScala implicit val cm: ClassTag[U] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[U]] dstream.flatMapValues(fn) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index e3db01c1e12c6..4095a7cc84946 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -192,7 +192,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { converter: JFunction[InputStream, java.lang.Iterable[T]], storageLevel: StorageLevel) : JavaReceiverInputDStream[T] = { - def fn = (x: InputStream) => converter.call(x).toIterator + def fn: (InputStream) => Iterator[T] = (x: InputStream) => converter.call(x).toIterator implicit val cmt: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] ssc.socketStream(hostname, port, fn, storageLevel) @@ -313,7 +313,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { implicit val cmk: ClassTag[K] = ClassTag(kClass) implicit val cmv: ClassTag[V] = ClassTag(vClass) implicit val cmf: ClassTag[F] = ClassTag(fClass) - def fn = (x: Path) => filter.call(x).booleanValue() + def fn: (Path) => Boolean = (x: Path) => filter.call(x).booleanValue() ssc.fileStream[K, V, F](directory, fn, newFilesOnly) } @@ -344,7 +344,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { implicit val cmk: ClassTag[K] = ClassTag(kClass) implicit val cmv: ClassTag[V] = ClassTag(vClass) implicit val cmf: ClassTag[F] = ClassTag(fClass) - def fn = (x: Path) => filter.call(x).booleanValue() + def fn: (Path) => Boolean = (x: Path) => filter.call(x).booleanValue() ssc.fileStream[K, V, F](directory, fn, newFilesOnly, conf) } @@ -625,7 +625,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * Stop the execution of the streams. * @param stopSparkContext Stop the associated SparkContext or not */ - def stop(stopSparkContext: Boolean) = ssc.stop(stopSparkContext) + def stop(stopSparkContext: Boolean): Unit = ssc.stop(stopSparkContext) /** * Stop the execution of the streams. @@ -633,7 +633,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * @param stopGracefully Stop gracefully by waiting for the processing of all * received data to be completed */ - def stop(stopSparkContext: Boolean, stopGracefully: Boolean) = { + def stop(stopSparkContext: Boolean, stopGracefully: Boolean): Unit = { ssc.stop(stopSparkContext, stopGracefully) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index 795c5aa6d585b..24f99a2b929f5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -839,7 +839,7 @@ object DStream { /** Filtering function that excludes non-user classes for a streaming application */ def streamingExclustionFunction(className: String): Boolean = { - def doesMatch(r: Regex) = r.findFirstIn(className).isDefined + def doesMatch(r: Regex): Boolean = r.findFirstIn(className).isDefined val isSparkClass = doesMatch(SPARK_CLASS_REGEX) val isSparkExampleClass = doesMatch(SPARK_EXAMPLES_CLASS_REGEX) val isSparkStreamingTestClass = doesMatch(SPARK_STREAMING_TESTCLASS_REGEX) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index d6a93acbe711b..95f1857b4c377 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -105,6 +105,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { if (jobSet.jobs.isEmpty) { logInfo("No jobs added for time " + jobSet.time) } else { + listenerBus.post(StreamingListenerBatchSubmitted(jobSet.toBatchInfo)) jobSets.put(jobSet.time, jobSet) jobSet.jobs.foreach(job => jobExecutor.execute(new JobHandler(job))) logInfo("Added jobs for time " + jobSet.time) @@ -134,10 +135,13 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { private def handleJobStart(job: Job) { val jobSet = jobSets.get(job.time) - if (!jobSet.hasStarted) { + val isFirstJobOfJobSet = !jobSet.hasStarted + jobSet.handleJobStart(job) + if (isFirstJobOfJobSet) { + // "StreamingListenerBatchStarted" should be posted after calling "handleJobStart" to get the + // correct "jobSet.processingStartTime". listenerBus.post(StreamingListenerBatchStarted(jobSet.toBatchInfo)) } - jobSet.handleJobStart(job) logInfo("Starting job " + job.id + " from job set of time " + jobSet.time) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala index e4bd067cacb77..be1e8686cf9fa 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala @@ -33,7 +33,7 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) private val waitingBatchInfos = new HashMap[Time, BatchInfo] private val runningBatchInfos = new HashMap[Time, BatchInfo] - private val completedaBatchInfos = new Queue[BatchInfo] + private val completedBatchInfos = new Queue[BatchInfo] private val batchInfoLimit = ssc.conf.getInt("spark.streaming.ui.retainedBatches", 100) private var totalCompletedBatches = 0L private var totalReceivedRecords = 0L @@ -62,7 +62,7 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) override def onBatchSubmitted(batchSubmitted: StreamingListenerBatchSubmitted): Unit = { synchronized { - runningBatchInfos(batchSubmitted.batchInfo.batchTime) = batchSubmitted.batchInfo + waitingBatchInfos(batchSubmitted.batchInfo.batchTime) = batchSubmitted.batchInfo } } @@ -79,8 +79,8 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) synchronized { waitingBatchInfos.remove(batchCompleted.batchInfo.batchTime) runningBatchInfos.remove(batchCompleted.batchInfo.batchTime) - completedaBatchInfos.enqueue(batchCompleted.batchInfo) - if (completedaBatchInfos.size > batchInfoLimit) completedaBatchInfos.dequeue() + completedBatchInfos.enqueue(batchCompleted.batchInfo) + if (completedBatchInfos.size > batchInfoLimit) completedBatchInfos.dequeue() totalCompletedBatches += 1L batchCompleted.batchInfo.receivedBlockInfo.foreach { case (_, infos) => @@ -118,7 +118,7 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) } def retainedCompletedBatches: Seq[BatchInfo] = synchronized { - completedaBatchInfos.toSeq + completedBatchInfos.toSeq } def processingDelayDistribution: Option[Distribution] = synchronized { @@ -149,7 +149,7 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) }.toMap } - def lastReceivedBatchRecords: Map[Int, Long] = { + def lastReceivedBatchRecords: Map[Int, Long] = synchronized { val lastReceivedBlockInfoOption = lastReceivedBatch.map(_.receivedBlockInfo) lastReceivedBlockInfoOption.map { lastReceivedBlockInfo => (0 until numReceivers).map { receiverId => @@ -160,24 +160,24 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) } } - def receiverInfo(receiverId: Int): Option[ReceiverInfo] = { + def receiverInfo(receiverId: Int): Option[ReceiverInfo] = synchronized { receiverInfos.get(receiverId) } - def lastCompletedBatch: Option[BatchInfo] = { - completedaBatchInfos.sortBy(_.batchTime)(Time.ordering).lastOption + def lastCompletedBatch: Option[BatchInfo] = synchronized { + completedBatchInfos.sortBy(_.batchTime)(Time.ordering).lastOption } - def lastReceivedBatch: Option[BatchInfo] = { + def lastReceivedBatch: Option[BatchInfo] = synchronized { retainedBatches.lastOption } - private def retainedBatches: Seq[BatchInfo] = synchronized { + private def retainedBatches: Seq[BatchInfo] = { (waitingBatchInfos.values.toSeq ++ - runningBatchInfos.values.toSeq ++ completedaBatchInfos).sortBy(_.batchTime)(Time.ordering) + runningBatchInfos.values.toSeq ++ completedBatchInfos).sortBy(_.batchTime)(Time.ordering) } private def extractDistribution(getMetric: BatchInfo => Option[Long]): Option[Distribution] = { - Distribution(completedaBatchInfos.flatMap(getMetric(_)).map(_.toDouble)) + Distribution(completedBatchInfos.flatMap(getMetric(_)).map(_.toDouble)) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala index bfe8086fcf8fe..b6dcb62bfeec8 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala @@ -37,11 +37,12 @@ private[ui] class StreamingPage(parent: StreamingTab) /** Render the page */ def render(request: HttpServletRequest): Seq[Node] = { - val content = + val content = listener.synchronized { generateBasicStats() ++

    ++

    Statistics over last {listener.retainedCompletedBatches.size} processed batches

    ++ generateReceiverStats() ++ generateBatchStatsTable() + } UIUtils.headerSparkPage("Streaming", content, parent, Some(5000)) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala index a7850812bd612..ca2f319f174a2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala @@ -72,7 +72,8 @@ object RawTextSender extends Logging { } catch { case e: IOException => logError("Client disconnected") - socket.close() + } finally { + socket.close() } } } diff --git a/streaming/src/test/resources/log4j.properties b/streaming/src/test/resources/log4j.properties index 9697237bfa1a3..75e3b53a093f6 100644 --- a/streaming/src/test/resources/log4j.properties +++ b/streaming/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index cf191715d29d6..87bc20f79c3cd 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -171,7 +171,9 @@ class BasicOperationsSuite extends TestSuiteBase { test("flatMapValues") { testOperation( Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ), - (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _).flatMapValues(x => Seq(x, x + 10)), + (s: DStream[String]) => { + s.map(x => (x, 1)).reduceByKey(_ + _).flatMapValues(x => Seq(x, x + 10)) + }, Seq( Seq(("a", 2), ("a", 12), ("b", 1), ("b", 11)), Seq(("", 2), ("", 12)), Seq() ), true ) @@ -474,7 +476,7 @@ class BasicOperationsSuite extends TestSuiteBase { stream.foreachRDD(_ => {}) // Dummy output stream ssc.start() Thread.sleep(2000) - def getInputFromSlice(fromMillis: Long, toMillis: Long) = { + def getInputFromSlice(fromMillis: Long, toMillis: Long): Set[Int] = { stream.slice(new Time(fromMillis), new Time(toMillis)).flatMap(_.collect()).toSet } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 91a2b2bba461d..54c30440a6e8d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -43,7 +43,7 @@ class CheckpointSuite extends TestSuiteBase { var ssc: StreamingContext = null - override def batchDuration = Milliseconds(500) + override def batchDuration: Duration = Milliseconds(500) override def beforeFunction() { super.beforeFunction() @@ -72,7 +72,7 @@ class CheckpointSuite extends TestSuiteBase { val input = (1 to 10).map(_ => Seq("a")).toSeq val operation = (st: DStream[String]) => { val updateFunc = (values: Seq[Int], state: Option[Int]) => { - Some((values.sum + state.getOrElse(0))) + Some(values.sum + state.getOrElse(0)) } st.map(x => (x, 1)) .updateStateByKey(updateFunc) @@ -199,7 +199,12 @@ class CheckpointSuite extends TestSuiteBase { testCheckpointedOperation( Seq( Seq("a", "a", "b"), Seq("", ""), Seq(), Seq("a", "a", "b"), Seq("", ""), Seq() ), (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _), - Seq( Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq() ), + Seq( + Seq(("a", 2), ("b", 1)), + Seq(("", 2)), + Seq(), + Seq(("a", 2), ("b", 1)), + Seq(("", 2)), Seq() ), 3 ) } @@ -212,7 +217,8 @@ class CheckpointSuite extends TestSuiteBase { val n = 10 val w = 4 val input = (1 to n).map(_ => Seq("a")).toSeq - val output = Seq(Seq(("a", 1)), Seq(("a", 2)), Seq(("a", 3))) ++ (1 to (n - w + 1)).map(x => Seq(("a", 4))) + val output = Seq( + Seq(("a", 1)), Seq(("a", 2)), Seq(("a", 3))) ++ (1 to (n - w + 1)).map(x => Seq(("a", 4))) val operation = (st: DStream[String]) => { st.map(x => (x, 1)) .reduceByKeyAndWindow(_ + _, _ - _, batchDuration * w, batchDuration) @@ -236,7 +242,13 @@ class CheckpointSuite extends TestSuiteBase { classOf[TextOutputFormat[Text, IntWritable]]) output }, - Seq(Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq()), + Seq( + Seq(("a", 2), ("b", 1)), + Seq(("", 2)), + Seq(), + Seq(("a", 2), ("b", 1)), + Seq(("", 2)), + Seq()), 3 ) } finally { @@ -259,7 +271,13 @@ class CheckpointSuite extends TestSuiteBase { classOf[NewTextOutputFormat[Text, IntWritable]]) output }, - Seq(Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq()), + Seq( + Seq(("a", 2), ("b", 1)), + Seq(("", 2)), + Seq(), + Seq(("a", 2), ("b", 1)), + Seq(("", 2)), + Seq()), 3 ) } finally { @@ -298,7 +316,13 @@ class CheckpointSuite extends TestSuiteBase { output } }, - Seq(Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq()), + Seq( + Seq(("a", 2), ("b", 1)), + Seq(("", 2)), + Seq(), + Seq(("a", 2), ("b", 1)), + Seq(("", 2)), + Seq()), 3 ) } finally { @@ -533,7 +557,8 @@ class CheckpointSuite extends TestSuiteBase { * Advances the manual clock on the streaming scheduler by given number of batches. * It also waits for the expected amount of time for each batch. */ - def advanceTimeWithRealDelay[V: ClassTag](ssc: StreamingContext, numBatches: Long): Seq[Seq[V]] = { + def advanceTimeWithRealDelay[V: ClassTag](ssc: StreamingContext, numBatches: Long): Seq[Seq[V]] = + { val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] logInfo("Manual clock before advancing = " + clock.getTimeMillis()) for (i <- 1 to numBatches.toInt) { @@ -543,7 +568,7 @@ class CheckpointSuite extends TestSuiteBase { logInfo("Manual clock after advancing = " + clock.getTimeMillis()) Thread.sleep(batchDuration.milliseconds) - val outputStream = ssc.graph.getOutputStreams.filter { dstream => + val outputStream = ssc.graph.getOutputStreams().filter { dstream => dstream.isInstanceOf[TestOutputStreamWithPartitions[V]] }.head.asInstanceOf[TestOutputStreamWithPartitions[V]] outputStream.output.map(_.flatten) @@ -552,4 +577,4 @@ class CheckpointSuite extends TestSuiteBase { private object CheckpointSuite extends Serializable { var batchThreeShouldBlockIndefinitely: Boolean = true -} \ No newline at end of file +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala index 26435d8515815..0c4c06534a693 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala @@ -29,9 +29,9 @@ class FailureSuite extends TestSuiteBase with Logging { val directory = Utils.createTempDir() val numBatches = 30 - override def batchDuration = Milliseconds(1000) + override def batchDuration: Duration = Milliseconds(1000) - override def useManualClock = false + override def useManualClock: Boolean = false override def afterFunction() { Utils.deleteRecursively(directory) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index 7ed6320a3d0bc..e6ac4975c5e68 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -52,7 +52,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { "localhost", testServer.port, StorageLevel.MEMORY_AND_DISK) val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] val outputStream = new TestOutputStream(networkStream, outputBuffer) - def output = outputBuffer.flatMap(x => x) + def output: ArrayBuffer[String] = outputBuffer.flatMap(x => x) outputStream.register() ssc.start() @@ -164,7 +164,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val countStream = networkStream.count val outputBuffer = new ArrayBuffer[Seq[Long]] with SynchronizedBuffer[Seq[Long]] val outputStream = new TestOutputStream(countStream, outputBuffer) - def output = outputBuffer.flatMap(x => x) + def output: ArrayBuffer[Long] = outputBuffer.flatMap(x => x) outputStream.register() ssc.start() @@ -196,7 +196,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val queueStream = ssc.queueStream(queue, oneAtATime = true) val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] val outputStream = new TestOutputStream(queueStream, outputBuffer) - def output = outputBuffer.filter(_.size > 0) + def output: ArrayBuffer[Seq[String]] = outputBuffer.filter(_.size > 0) outputStream.register() ssc.start() @@ -204,7 +204,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] val input = Seq("1", "2", "3", "4", "5") val expectedOutput = input.map(Seq(_)) - //Thread.sleep(1000) + val inputIterator = input.toIterator for (i <- 0 until input.size) { // Enqueue more than 1 item per tick but they should dequeue one at a time @@ -239,7 +239,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val queueStream = ssc.queueStream(queue, oneAtATime = false) val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] val outputStream = new TestOutputStream(queueStream, outputBuffer) - def output = outputBuffer.filter(_.size > 0) + def output: ArrayBuffer[Seq[String]] = outputBuffer.filter(_.size > 0) outputStream.register() ssc.start() @@ -352,7 +352,8 @@ class TestServer(portToBind: Int = 0) extends Logging { logInfo("New connection") try { clientSocket.setTcpNoDelay(true) - val outputStream = new BufferedWriter(new OutputStreamWriter(clientSocket.getOutputStream)) + val outputStream = new BufferedWriter( + new OutputStreamWriter(clientSocket.getOutputStream)) while(clientSocket.isConnected) { val msg = queue.poll(100, TimeUnit.MILLISECONDS) @@ -384,7 +385,7 @@ class TestServer(portToBind: Int = 0) extends Logging { def stop() { servingThread.interrupt() } - def port = serverSocket.getLocalPort + def port: Int = serverSocket.getLocalPort } /** This is a receiver to test multiple threads inserting data using block generator */ diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index 18a477f92094d..c090eaec2928d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -24,20 +24,20 @@ import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ import scala.language.postfixOps -import akka.actor.{ActorSystem, Props} import org.apache.hadoop.conf.Configuration import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} import org.scalatest.concurrent.Eventually._ import org.apache.spark._ import org.apache.spark.network.nio.NioBlockTransferService +import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.KryoSerializer import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.storage._ import org.apache.spark.streaming.receiver._ import org.apache.spark.streaming.util._ -import org.apache.spark.util.{AkkaUtils, ManualClock, Utils} +import org.apache.spark.util.{ManualClock, Utils} import WriteAheadLogBasedBlockHandler._ import WriteAheadLogSuite._ @@ -54,22 +54,19 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche val manualClock = new ManualClock val blockManagerSize = 10000000 - var actorSystem: ActorSystem = null + var rpcEnv: RpcEnv = null var blockManagerMaster: BlockManagerMaster = null var blockManager: BlockManager = null var tempDirectory: File = null before { - val (actorSystem, boundPort) = AkkaUtils.createActorSystem( - "test", "localhost", 0, conf = conf, securityManager = securityMgr) - this.actorSystem = actorSystem - conf.set("spark.driver.port", boundPort.toString) + rpcEnv = RpcEnv.create("test", "localhost", 0, conf, securityMgr) + conf.set("spark.driver.port", rpcEnv.address.port.toString) - blockManagerMaster = new BlockManagerMaster( - actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf, new LiveListenerBus))), - conf, true) + blockManagerMaster = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager", + new BlockManagerMasterEndpoint(rpcEnv, true, conf, new LiveListenerBus)), conf, true) - blockManager = new BlockManager("bm", actorSystem, blockManagerMaster, serializer, + blockManager = new BlockManager("bm", rpcEnv, blockManagerMaster, serializer, blockManagerSize, conf, mapOutputTracker, shuffleManager, new NioBlockTransferService(conf, securityMgr), securityMgr, 0) blockManager.initialize("app-id") @@ -87,9 +84,9 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche blockManagerMaster.stop() blockManagerMaster = null } - actorSystem.shutdown() - actorSystem.awaitTermination() - actorSystem = null + rpcEnv.shutdown() + rpcEnv.awaitTermination() + rpcEnv = null Utils.deleteRecursively(tempDirectory) } @@ -99,7 +96,7 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche testBlockStoring(handler) { case (data, blockIds, storeResults) => // Verify the data in block manager is correct val storedData = blockIds.flatMap { blockId => - blockManager.getLocal(blockId).map { _.data.map {_.toString}.toList }.getOrElse(List.empty) + blockManager.getLocal(blockId).map(_.data.map(_.toString).toList).getOrElse(List.empty) }.toList storedData shouldEqual data @@ -123,7 +120,7 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche testBlockStoring(handler) { case (data, blockIds, storeResults) => // Verify the data in block manager is correct val storedData = blockIds.flatMap { blockId => - blockManager.getLocal(blockId).map { _.data.map {_.toString}.toList }.getOrElse(List.empty) + blockManager.getLocal(blockId).map(_.data.map(_.toString).toList).getOrElse(List.empty) }.toList storedData shouldEqual data diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala index 42fad769f0c1a..b63b37d9f9cef 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala @@ -228,7 +228,8 @@ class ReceivedBlockTrackerSuite * Get all the data written in the given write ahead log files. By default, it will read all * files in the test log directory. */ - def getWrittenLogData(logFiles: Seq[String] = getWriteAheadLogFiles): Seq[ReceivedBlockTrackerLogEvent] = { + def getWrittenLogData(logFiles: Seq[String] = getWriteAheadLogFiles) + : Seq[ReceivedBlockTrackerLogEvent] = { logFiles.flatMap { file => new WriteAheadLogReader(file, hadoopConf).toSeq }.map { byteBuffer => @@ -244,7 +245,8 @@ class ReceivedBlockTrackerSuite } /** Create batch allocation object from the given info */ - def createBatchAllocation(time: Long, blockInfos: Seq[ReceivedBlockInfo]): BatchAllocationEvent = { + def createBatchAllocation(time: Long, blockInfos: Seq[ReceivedBlockInfo]) + : BatchAllocationEvent = { BatchAllocationEvent(time, AllocatedBlocks(Map((streamId -> blockInfos)))) } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala index aa20ad0b5374e..10c35cba8dc53 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala @@ -308,7 +308,7 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { val errors = new ArrayBuffer[Throwable] /** Check if all data structures are clean */ - def isAllEmpty = { + def isAllEmpty: Boolean = { singles.isEmpty && byteBuffers.isEmpty && iterators.isEmpty && arrayBuffers.isEmpty && errors.isEmpty } @@ -320,24 +320,21 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { def pushBytes( bytes: ByteBuffer, optionalMetadata: Option[Any], - optionalBlockId: Option[StreamBlockId] - ) { + optionalBlockId: Option[StreamBlockId]) { byteBuffers += bytes } def pushIterator( iterator: Iterator[_], optionalMetadata: Option[Any], - optionalBlockId: Option[StreamBlockId] - ) { + optionalBlockId: Option[StreamBlockId]) { iterators += iterator } def pushArrayBuffer( arrayBuffer: ArrayBuffer[_], optionalMetadata: Option[Any], - optionalBlockId: Option[StreamBlockId] - ) { + optionalBlockId: Option[StreamBlockId]) { arrayBuffers += arrayBuffer } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 2e5005ef6ff14..d1bbf39dc7897 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -213,7 +213,7 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w ssc = new StreamingContext(sc, Milliseconds(100)) var runningCount = 0 SlowTestReceiver.receivedAllRecords = false - //Create test receiver that sleeps in onStop() + // Create test receiver that sleeps in onStop() val totalNumRecords = 15 val recordsPerSecond = 1 val input = ssc.receiverStream(new SlowTestReceiver(totalNumRecords, recordsPerSecond)) @@ -370,7 +370,8 @@ object TestReceiver { } /** Custom receiver for testing whether a slow receiver can be shutdown gracefully or not */ -class SlowTestReceiver(totalRecords: Int, recordsPerSecond: Int) extends Receiver[Int](StorageLevel.MEMORY_ONLY) with Logging { +class SlowTestReceiver(totalRecords: Int, recordsPerSecond: Int) + extends Receiver[Int](StorageLevel.MEMORY_ONLY) with Logging { var receivingThreadOption: Option[Thread] = None diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala index f52562b0a0f73..7210439509541 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala @@ -38,18 +38,46 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { // To make sure that the processing start and end times in collected // information are different for successive batches - override def batchDuration = Milliseconds(100) - override def actuallyWait = true + override def batchDuration: Duration = Milliseconds(100) + override def actuallyWait: Boolean = true test("batch info reporting") { val ssc = setupStreams(input, operation) val collector = new BatchInfoCollector ssc.addStreamingListener(collector) runStreams(ssc, input.size, input.size) - val batchInfos = collector.batchInfos - batchInfos should have size 4 - batchInfos.foreach(info => { + // SPARK-6766: batch info should be submitted + val batchInfosSubmitted = collector.batchInfosSubmitted + batchInfosSubmitted should have size 4 + + batchInfosSubmitted.foreach(info => { + info.schedulingDelay should be (None) + info.processingDelay should be (None) + info.totalDelay should be (None) + }) + + isInIncreasingOrder(batchInfosSubmitted.map(_.submissionTime)) should be (true) + + // SPARK-6766: processingStartTime of batch info should not be None when starting + val batchInfosStarted = collector.batchInfosStarted + batchInfosStarted should have size 4 + + batchInfosStarted.foreach(info => { + info.schedulingDelay should not be None + info.schedulingDelay.get should be >= 0L + info.processingDelay should be (None) + info.totalDelay should be (None) + }) + + isInIncreasingOrder(batchInfosStarted.map(_.submissionTime)) should be (true) + isInIncreasingOrder(batchInfosStarted.map(_.processingStartTime.get)) should be (true) + + // test onBatchCompleted + val batchInfosCompleted = collector.batchInfosCompleted + batchInfosCompleted should have size 4 + + batchInfosCompleted.foreach(info => { info.schedulingDelay should not be None info.processingDelay should not be None info.totalDelay should not be None @@ -58,9 +86,9 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { info.totalDelay.get should be >= 0L }) - isInIncreasingOrder(batchInfos.map(_.submissionTime)) should be (true) - isInIncreasingOrder(batchInfos.map(_.processingStartTime.get)) should be (true) - isInIncreasingOrder(batchInfos.map(_.processingEndTime.get)) should be (true) + isInIncreasingOrder(batchInfosCompleted.map(_.submissionTime)) should be (true) + isInIncreasingOrder(batchInfosCompleted.map(_.processingStartTime.get)) should be (true) + isInIncreasingOrder(batchInfosCompleted.map(_.processingEndTime.get)) should be (true) } test("receiver info reporting") { @@ -99,9 +127,20 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { /** Listener that collects information on processed batches */ class BatchInfoCollector extends StreamingListener { - val batchInfos = new ArrayBuffer[BatchInfo] + val batchInfosCompleted = new ArrayBuffer[BatchInfo] + val batchInfosStarted = new ArrayBuffer[BatchInfo] + val batchInfosSubmitted = new ArrayBuffer[BatchInfo] + + override def onBatchSubmitted(batchSubmitted: StreamingListenerBatchSubmitted) { + batchInfosSubmitted += batchSubmitted.batchInfo + } + + override def onBatchStarted(batchStarted: StreamingListenerBatchStarted) { + batchInfosStarted += batchStarted.batchInfo + } + override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) { - batchInfos += batchCompleted.batchInfo + batchInfosCompleted += batchCompleted.batchInfo } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index 3565d621e8a6c..c3cae8aeb6d15 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -53,8 +53,9 @@ class TestInputStream[T: ClassTag](ssc_ : StreamingContext, input: Seq[Seq[T]], val selectedInput = if (index < input.size) input(index) else Seq[T]() // lets us test cases where RDDs are not created - if (selectedInput == null) + if (selectedInput == null) { return None + } val rdd = ssc.sc.makeRDD(selectedInput, numPartitions) logInfo("Created RDD " + rdd.id + " with " + selectedInput) @@ -104,7 +105,9 @@ class TestOutputStreamWithPartitions[T: ClassTag](parent: DStream[T], output.clear() } - def toTestOutputStream = new TestOutputStream[T](this.parent, this.output.map(_.flatten)) + def toTestOutputStream: TestOutputStream[T] = { + new TestOutputStream[T](this.parent, this.output.map(_.flatten)) + } } /** @@ -148,34 +151,34 @@ class BatchCounter(ssc: StreamingContext) { trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { // Name of the framework for Spark context - def framework = this.getClass.getSimpleName + def framework: String = this.getClass.getSimpleName // Master for Spark context - def master = "local[2]" + def master: String = "local[2]" // Batch duration - def batchDuration = Seconds(1) + def batchDuration: Duration = Seconds(1) // Directory where the checkpoint data will be saved - lazy val checkpointDir = { + lazy val checkpointDir: String = { val dir = Utils.createTempDir() logDebug(s"checkpointDir: $dir") dir.toString } // Number of partitions of the input parallel collections created for testing - def numInputPartitions = 2 + def numInputPartitions: Int = 2 // Maximum time to wait before the test times out - def maxWaitTimeMillis = 10000 + def maxWaitTimeMillis: Int = 10000 // Whether to use manual clock or not - def useManualClock = true + def useManualClock: Boolean = true // Whether to actually wait in real time before changing manual clock - def actuallyWait = false + def actuallyWait: Boolean = false - //// A SparkConf to use in tests. Can be modified before calling setupStreams to configure things. + // A SparkConf to use in tests. Can be modified before calling setupStreams to configure things. val conf = new SparkConf() .setMaster(master) .setAppName(framework) @@ -346,7 +349,8 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { // Wait until expected number of output items have been generated val startTime = System.currentTimeMillis() - while (output.size < numExpectedOutput && System.currentTimeMillis() - startTime < maxWaitTimeMillis) { + while (output.size < numExpectedOutput && + System.currentTimeMillis() - startTime < maxWaitTimeMillis) { logInfo("output.size = " + output.size + ", numExpectedOutput = " + numExpectedOutput) ssc.awaitTerminationOrTimeout(50) } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala index 87a0395efbf2a..998426ebb82e5 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala @@ -32,7 +32,8 @@ import org.apache.spark._ /** * Selenium tests for the Spark Web UI. */ -class UISeleniumSuite extends FunSuite with WebBrowser with Matchers with BeforeAndAfterAll with TestSuiteBase { +class UISeleniumSuite + extends FunSuite with WebBrowser with Matchers with BeforeAndAfterAll with TestSuiteBase { implicit var webDriver: WebDriver = _ diff --git a/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala index a5d2bb2fde16c..c39ad05f41520 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala @@ -22,9 +22,9 @@ import org.apache.spark.storage.StorageLevel class WindowOperationsSuite extends TestSuiteBase { - override def maxWaitTimeMillis = 20000 // large window tests can sometimes take longer + override def maxWaitTimeMillis: Int = 20000 // large window tests can sometimes take longer - override def batchDuration = Seconds(1) // making sure its visible in this class + override def batchDuration: Duration = Seconds(1) // making sure its visible in this class val largerSlideInput = Seq( Seq(("a", 1)), diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala index 7a6a2f3e577dd..c3602a5b73732 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala @@ -28,10 +28,13 @@ import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, StreamBloc import org.apache.spark.streaming.util.{WriteAheadLogFileSegment, WriteAheadLogWriter} import org.apache.spark.util.Utils -class WriteAheadLogBackedBlockRDDSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfterEach { +class WriteAheadLogBackedBlockRDDSuite + extends FunSuite with BeforeAndAfterAll with BeforeAndAfterEach { + val conf = new SparkConf() .setMaster("local[2]") .setAppName(this.getClass.getSimpleName) + val hadoopConf = new Configuration() var sparkContext: SparkContext = null @@ -86,7 +89,8 @@ class WriteAheadLogBackedBlockRDDSuite extends FunSuite with BeforeAndAfterAll w * @param numPartitionsInWAL Number of partitions to write to the Write Ahead Log * @param testStoreInBM Test whether blocks read from log are stored back into block manager */ - private def testRDD(numPartitionsInBM: Int, numPartitionsInWAL: Int, testStoreInBM: Boolean = false) { + private def testRDD( + numPartitionsInBM: Int, numPartitionsInWAL: Int, testStoreInBM: Boolean = false) { val numBlocks = numPartitionsInBM + numPartitionsInWAL val data = Seq.fill(numBlocks, 10)(scala.util.Random.nextString(50)) @@ -110,7 +114,7 @@ class WriteAheadLogBackedBlockRDDSuite extends FunSuite with BeforeAndAfterAll w "Unexpected blocks in BlockManager" ) - // Make sure that the right `numPartitionsInWAL` blocks are in write ahead logs, and other are not + // Make sure that the right `numPartitionsInWAL` blocks are in WALs, and other are not require( segments.takeRight(numPartitionsInWAL).forall(s => new File(s.path.stripPrefix("file://")).exists()), @@ -152,6 +156,6 @@ class WriteAheadLogBackedBlockRDDSuite extends FunSuite with BeforeAndAfterAll w } private def generateFakeSegments(count: Int): Seq[WriteAheadLogFileSegment] = { - Array.fill(count)(new WriteAheadLogFileSegment("random", 0l, 0)) + Array.fill(count)(new WriteAheadLogFileSegment("random", 0L, 0)) } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala index 4150b60635ed6..7865b06c2e3c2 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala @@ -90,7 +90,7 @@ class JobGeneratorSuite extends TestSuiteBase { val receiverTracker = ssc.scheduler.receiverTracker // Get the blocks belonging to a batch - def getBlocksOfBatch(batchTime: Long) = { + def getBlocksOfBatch(batchTime: Long): Seq[ReceivedBlockInfo] = { receiverTracker.getBlocksOfBatchAndStream(Time(batchTime), inputStream.id) } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala new file mode 100644 index 0000000000000..94b1985116feb --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.ui + +import org.scalatest.Matchers + +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.streaming.scheduler._ +import org.apache.spark.streaming.{Duration, Time, Milliseconds, TestSuiteBase} + +class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { + + val input = (1 to 4).map(Seq(_)).toSeq + val operation = (d: DStream[Int]) => d.map(x => x) + + override def batchDuration: Duration = Milliseconds(100) + + test("onBatchSubmitted, onBatchStarted, onBatchCompleted, " + + "onReceiverStarted, onReceiverError, onReceiverStopped") { + val ssc = setupStreams(input, operation) + val listener = new StreamingJobProgressListener(ssc) + + val receivedBlockInfo = Map( + 0 -> Array(ReceivedBlockInfo(0, 100, null), ReceivedBlockInfo(0, 200, null)), + 1 -> Array(ReceivedBlockInfo(1, 300, null)) + ) + + // onBatchSubmitted + val batchInfoSubmitted = BatchInfo(Time(1000), receivedBlockInfo, 1000, None, None) + listener.onBatchSubmitted(StreamingListenerBatchSubmitted(batchInfoSubmitted)) + listener.waitingBatches should be (List(batchInfoSubmitted)) + listener.runningBatches should be (Nil) + listener.retainedCompletedBatches should be (Nil) + listener.lastCompletedBatch should be (None) + listener.numUnprocessedBatches should be (1) + listener.numTotalCompletedBatches should be (0) + listener.numTotalProcessedRecords should be (0) + listener.numTotalReceivedRecords should be (0) + + // onBatchStarted + val batchInfoStarted = BatchInfo(Time(1000), receivedBlockInfo, 1000, Some(2000), None) + listener.onBatchStarted(StreamingListenerBatchStarted(batchInfoStarted)) + listener.waitingBatches should be (Nil) + listener.runningBatches should be (List(batchInfoStarted)) + listener.retainedCompletedBatches should be (Nil) + listener.lastCompletedBatch should be (None) + listener.numUnprocessedBatches should be (1) + listener.numTotalCompletedBatches should be (0) + listener.numTotalProcessedRecords should be (0) + listener.numTotalReceivedRecords should be (600) + + // onBatchCompleted + val batchInfoCompleted = BatchInfo(Time(1000), receivedBlockInfo, 1000, Some(2000), None) + listener.onBatchCompleted(StreamingListenerBatchCompleted(batchInfoCompleted)) + listener.waitingBatches should be (Nil) + listener.runningBatches should be (Nil) + listener.retainedCompletedBatches should be (List(batchInfoCompleted)) + listener.lastCompletedBatch should be (Some(batchInfoCompleted)) + listener.numUnprocessedBatches should be (0) + listener.numTotalCompletedBatches should be (1) + listener.numTotalProcessedRecords should be (600) + listener.numTotalReceivedRecords should be (600) + + // onReceiverStarted + val receiverInfoStarted = ReceiverInfo(0, "test", null, true, "localhost") + listener.onReceiverStarted(StreamingListenerReceiverStarted(receiverInfoStarted)) + listener.receiverInfo(0) should be (Some(receiverInfoStarted)) + listener.receiverInfo(1) should be (None) + + // onReceiverError + val receiverInfoError = ReceiverInfo(1, "test", null, true, "localhost") + listener.onReceiverError(StreamingListenerReceiverError(receiverInfoError)) + listener.receiverInfo(0) should be (Some(receiverInfoStarted)) + listener.receiverInfo(1) should be (Some(receiverInfoError)) + listener.receiverInfo(2) should be (None) + + // onReceiverStopped + val receiverInfoStopped = ReceiverInfo(2, "test", null, true, "localhost") + listener.onReceiverStopped(StreamingListenerReceiverStopped(receiverInfoStopped)) + listener.receiverInfo(0) should be (Some(receiverInfoStarted)) + listener.receiverInfo(1) should be (Some(receiverInfoError)) + listener.receiverInfo(2) should be (Some(receiverInfoStopped)) + listener.receiverInfo(3) should be (None) + } + + test("Remove the old completed batches when exceeding the limit") { + val ssc = setupStreams(input, operation) + val limit = ssc.conf.getInt("spark.streaming.ui.retainedBatches", 100) + val listener = new StreamingJobProgressListener(ssc) + + val receivedBlockInfo = Map( + 0 -> Array(ReceivedBlockInfo(0, 100, null), ReceivedBlockInfo(0, 200, null)), + 1 -> Array(ReceivedBlockInfo(1, 300, null)) + ) + val batchInfoCompleted = BatchInfo(Time(1000), receivedBlockInfo, 1000, Some(2000), None) + + for(_ <- 0 until (limit + 10)) { + listener.onBatchCompleted(StreamingListenerBatchCompleted(batchInfoCompleted)) + } + + listener.retainedCompletedBatches.size should be (limit) + listener.numTotalCompletedBatches should be(limit + 10) + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index 8335659667f22..a3919c43b95b4 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -291,7 +291,7 @@ object WriteAheadLogSuite { manager } - /** Read data from a segments of a log file directly and return the list of byte buffers.*/ + /** Read data from a segments of a log file directly and return the list of byte buffers. */ def readDataManually(segments: Seq[WriteAheadLogFileSegment]): Seq[String] = { segments.map { segment => val reader = HdfsUtils.getInputStream(segment.path, hadoopConf) diff --git a/streaming/src/test/scala/org/apache/spark/streamingtest/ImplicitSuite.scala b/streaming/src/test/scala/org/apache/spark/streamingtest/ImplicitSuite.scala index d0bf328f2b74d..d66750463033a 100644 --- a/streaming/src/test/scala/org/apache/spark/streamingtest/ImplicitSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streamingtest/ImplicitSuite.scala @@ -25,7 +25,8 @@ package org.apache.spark.streamingtest */ class ImplicitSuite { - // We only want to test if `implict` works well with the compiler, so we don't need a real DStream. + // We only want to test if `implicit` works well with the compiler, + // so we don't need a real DStream. def mockDStream[T]: org.apache.spark.streaming.dstream.DStream[T] = null def testToPairDStreamFunctions(): Unit = { diff --git a/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala b/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala index 8d0f09933c8d3..583823c90c5c6 100644 --- a/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala +++ b/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala @@ -17,7 +17,7 @@ package org.apache.spark.tools -import java.lang.reflect.Method +import java.lang.reflect.{Type, Method} import scala.collection.mutable.ArrayBuffer import scala.language.existentials @@ -302,7 +302,7 @@ object JavaAPICompletenessChecker { private def isExcludedByInterface(method: Method): Boolean = { val excludedInterfaces = Set("org.apache.spark.Logging", "org.apache.hadoop.mapreduce.HadoopMapReduceUtil") - def toComparisionKey(method: Method) = + def toComparisionKey(method: Method): (Class[_], String, Type) = (method.getReturnType, method.getName, method.getGenericReturnType) val interfaces = method.getDeclaringClass.getInterfaces.filter { i => excludedInterfaces.contains(i.getName) diff --git a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala index 15ee95070a3d3..f2d135397ce2f 100644 --- a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala +++ b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala @@ -32,7 +32,7 @@ import org.apache.spark.util.Utils * Writes simulated shuffle output from several threads and records the observed throughput. */ object StoragePerfTester { - def main(args: Array[String]) = { + def main(args: Array[String]): Unit = { /** Total amount of data to generate. Distributed evenly amongst maps and reduce splits. */ val dataSizeMb = Utils.memoryStringToMb(sys.env.getOrElse("OUTPUT_DATA", "1g")) @@ -58,8 +58,8 @@ object StoragePerfTester { val sc = new SparkContext("local[4]", "Write Tester", conf) val hashShuffleManager = sc.env.shuffleManager.asInstanceOf[HashShuffleManager] - def writeOutputBytes(mapId: Int, total: AtomicLong) = { - val shuffle = hashShuffleManager.shuffleBlockManager.forMapTask(1, mapId, numOutputSplits, + def writeOutputBytes(mapId: Int, total: AtomicLong): Unit = { + val shuffle = hashShuffleManager.shuffleBlockResolver.forMapTask(1, mapId, numOutputSplits, new KryoSerializer(sc.conf), new ShuffleWriteMetrics()) val writers = shuffle.writers for (i <- 1 to recordsPerMap) { @@ -78,7 +78,7 @@ object StoragePerfTester { val totalBytes = new AtomicLong() for (task <- 1 to numMaps) { executor.submit(new Runnable() { - override def run() = { + override def run(): Unit = { try { writeOutputBytes(task, totalBytes) latch.countDown() diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 3d18690cd9cbf..26259cee77151 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -24,22 +24,20 @@ import java.lang.reflect.InvocationTargetException import java.net.{Socket, URL} import java.util.concurrent.atomic.AtomicReference -import akka.actor._ -import akka.remote._ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.util.ShutdownHookManager import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.spark.rpc._ import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkEnv} import org.apache.spark.SparkException import org.apache.spark.deploy.{PythonRunner, SparkHadoopUtil} import org.apache.spark.deploy.history.HistoryServer import org.apache.spark.scheduler.cluster.YarnSchedulerBackend import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ -import org.apache.spark.util.{AkkaUtils, ChildFirstURLClassLoader, MutableURLClassLoader, - SignalLogger, Utils} +import org.apache.spark.util._ /** * Common application master functionality for Spark on Yarn. @@ -72,8 +70,8 @@ private[spark] class ApplicationMaster( @volatile private var allocator: YarnAllocator = _ // Fields used in client mode. - private var actorSystem: ActorSystem = null - private var actor: ActorRef = _ + private var rpcEnv: RpcEnv = null + private var amEndpoint: RpcEndpointRef = _ // Fields used in cluster mode. private val sparkContextRef = new AtomicReference[SparkContext](null) @@ -162,7 +160,7 @@ private[spark] class ApplicationMaster( * status to SUCCEEDED in cluster mode to handle if the user calls System.exit * from the application code. */ - final def getDefaultFinalStatus() = { + final def getDefaultFinalStatus(): FinalApplicationStatus = { if (isClusterMode) { FinalApplicationStatus.SUCCEEDED } else { @@ -175,31 +173,35 @@ private[spark] class ApplicationMaster( * This means the ResourceManager will not retry the application attempt on your behalf if * a failure occurred. */ - final def unregister(status: FinalApplicationStatus, diagnostics: String = null) = synchronized { - if (!unregistered) { - logInfo(s"Unregistering ApplicationMaster with $status" + - Option(diagnostics).map(msg => s" (diag message: $msg)").getOrElse("")) - unregistered = true - client.unregister(status, Option(diagnostics).getOrElse("")) + final def unregister(status: FinalApplicationStatus, diagnostics: String = null): Unit = { + synchronized { + if (!unregistered) { + logInfo(s"Unregistering ApplicationMaster with $status" + + Option(diagnostics).map(msg => s" (diag message: $msg)").getOrElse("")) + unregistered = true + client.unregister(status, Option(diagnostics).getOrElse("")) + } } } - final def finish(status: FinalApplicationStatus, code: Int, msg: String = null) = synchronized { - if (!finished) { - val inShutdown = Utils.inShutdown() - logInfo(s"Final app status: ${status}, exitCode: ${code}" + - Option(msg).map(msg => s", (reason: $msg)").getOrElse("")) - exitCode = code - finalStatus = status - finalMsg = msg - finished = true - if (!inShutdown && Thread.currentThread() != reporterThread && reporterThread != null) { - logDebug("shutting down reporter thread") - reporterThread.interrupt() - } - if (!inShutdown && Thread.currentThread() != userClassThread && userClassThread != null) { - logDebug("shutting down user thread") - userClassThread.interrupt() + final def finish(status: FinalApplicationStatus, code: Int, msg: String = null): Unit = { + synchronized { + if (!finished) { + val inShutdown = Utils.inShutdown() + logInfo(s"Final app status: $status, exitCode: $code" + + Option(msg).map(msg => s", (reason: $msg)").getOrElse("")) + exitCode = code + finalStatus = status + finalMsg = msg + finished = true + if (!inShutdown && Thread.currentThread() != reporterThread && reporterThread != null) { + logDebug("shutting down reporter thread") + reporterThread.interrupt() + } + if (!inShutdown && Thread.currentThread() != userClassThread && userClassThread != null) { + logDebug("shutting down user thread") + userClassThread.interrupt() + } } } } @@ -221,6 +223,7 @@ private[spark] class ApplicationMaster( val appId = client.getAttemptId().getApplicationId().toString() val historyAddress = sparkConf.getOption("spark.yarn.historyServer.address") + .map { text => SparkHadoopUtil.get.substituteHadoopVariables(text, yarnConf) } .map { address => s"${address}${HistoryServer.UI_PATH_PREFIX}/${appId}" } .getOrElse("") @@ -236,22 +239,21 @@ private[spark] class ApplicationMaster( } /** - * Create an actor that communicates with the driver. + * Create an [[RpcEndpoint]] that communicates with the driver. * * In cluster mode, the AM and the driver belong to same process - * so the AM actor need not monitor lifecycle of the driver. + * so the AMEndpoint need not monitor lifecycle of the driver. */ - private def runAMActor( + private def runAMEndpoint( host: String, port: String, isClusterMode: Boolean): Unit = { - val driverUrl = AkkaUtils.address( - AkkaUtils.protocol(actorSystem), + val driverEndpont = rpcEnv.setupEndpointRef( SparkEnv.driverActorSystemName, - host, - port, - YarnSchedulerBackend.ACTOR_NAME) - actor = actorSystem.actorOf(Props(new AMActor(driverUrl, isClusterMode)), name = "YarnAM") + RpcAddress(host, port.toInt), + YarnSchedulerBackend.ENDPOINT_NAME) + amEndpoint = + rpcEnv.setupEndpoint("YarnAM", new AMEndpoint(rpcEnv, driverEndpont, isClusterMode)) } private def runDriver(securityMgr: SecurityManager): Unit = { @@ -268,8 +270,8 @@ private[spark] class ApplicationMaster( ApplicationMaster.EXIT_SC_NOT_INITED, "Timed out waiting for SparkContext.") } else { - actorSystem = sc.env.actorSystem - runAMActor( + rpcEnv = sc.env.rpcEnv + runAMEndpoint( sc.getConf.get("spark.driver.host"), sc.getConf.get("spark.driver.port"), isClusterMode = true) @@ -279,8 +281,7 @@ private[spark] class ApplicationMaster( } private def runExecutorLauncher(securityMgr: SecurityManager): Unit = { - actorSystem = AkkaUtils.createActorSystem("sparkYarnAM", Utils.localHostName, 0, - conf = sparkConf, securityManager = securityMgr)._1 + rpcEnv = RpcEnv.create("sparkYarnAM", Utils.localHostName, 0, sparkConf, securityMgr) waitForSparkDriver() addAmIpFilter() registerAM(sparkConf.get("spark.driver.appUIAddress", ""), securityMgr) @@ -427,7 +428,7 @@ private[spark] class ApplicationMaster( sparkConf.set("spark.driver.host", driverHost) sparkConf.set("spark.driver.port", driverPort.toString) - runAMActor(driverHost, driverPort.toString, isClusterMode = false) + runAMEndpoint(driverHost, driverPort.toString, isClusterMode = false) } /** Add the Yarn IP filter that is required for properly securing the UI. */ @@ -439,7 +440,7 @@ private[spark] class ApplicationMaster( System.setProperty("spark.ui.filters", amFilter) params.foreach { case (k, v) => System.setProperty(s"spark.$amFilter.param.$k", v) } } else { - actor ! AddWebUIFilter(amFilter, params.toMap, proxyBase) + amEndpoint.send(AddWebUIFilter(amFilter, params.toMap, proxyBase)) } } @@ -469,6 +470,9 @@ private[spark] class ApplicationMaster( System.setProperty("spark.submit.pyFiles", PythonRunner.formatPaths(args.pyFiles).mkString(",")) } + if (args.primaryRFile != null && args.primaryRFile.endsWith(".R")) { + // TODO(davies): add R dependencies here + } val mainMethod = userClassLoader.loadClass(args.userClass) .getMethod("main", classOf[Array[String]]) @@ -501,44 +505,29 @@ private[spark] class ApplicationMaster( } /** - * An actor that communicates with the driver's scheduler backend. + * An [[RpcEndpoint]] that communicates with the driver's scheduler backend. */ - private class AMActor(driverUrl: String, isClusterMode: Boolean) extends Actor { - var driver: ActorSelection = _ - - override def preStart() = { - logInfo("Listen to driver: " + driverUrl) - driver = context.actorSelection(driverUrl) - // Send a hello message to establish the connection, after which - // we can monitor Lifecycle Events. - driver ! "Hello" - driver ! RegisterClusterManager - // In cluster mode, the AM can directly monitor the driver status instead - // of trying to deduce it from the lifecycle of the driver's actor - if (!isClusterMode) { - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) - } - } + private class AMEndpoint( + override val rpcEnv: RpcEnv, driver: RpcEndpointRef, isClusterMode: Boolean) + extends RpcEndpoint with Logging { - override def receive = { - case x: DisassociatedEvent => - logInfo(s"Driver terminated or disconnected! Shutting down. $x") - // In cluster mode, do not rely on the disassociated event to exit - // This avoids potentially reporting incorrect exit codes if the driver fails - if (!isClusterMode) { - finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS) - } + override def onStart(): Unit = { + driver.send(RegisterClusterManager(self)) + } + override def receive: PartialFunction[Any, Unit] = { case x: AddWebUIFilter => logInfo(s"Add WebUI Filter. $x") - driver ! x + driver.send(x) + } + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case RequestExecutors(requestedTotal) => Option(allocator) match { case Some(a) => a.requestTotalExecutors(requestedTotal) case None => logWarning("Container allocator is not ready to request executors yet.") } - sender ! true + context.reply(true) case KillExecutors(executorIds) => logInfo(s"Driver requested to kill executor(s) ${executorIds.mkString(", ")}.") @@ -546,7 +535,16 @@ private[spark] class ApplicationMaster( case Some(a) => executorIds.foreach(a.killExecutor) case None => logWarning("Container allocator is not ready to kill executors yet.") } - sender ! true + context.reply(true) + } + + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + logInfo(s"Driver terminated or disconnected! Shutting down. $remoteAddress") + // In cluster mode, do not rely on the disassociated event to exit + // This avoids potentially reporting incorrect exit codes if the driver fails + if (!isClusterMode) { + finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS) + } } } @@ -567,7 +565,7 @@ object ApplicationMaster extends Logging { private var master: ApplicationMaster = _ - def main(args: Array[String]) = { + def main(args: Array[String]): Unit = { SignalLogger.register(log) val amArgs = new ApplicationMasterArguments(args) SparkHadoopUtil.get.runAsSparkUser { () => @@ -576,11 +574,11 @@ object ApplicationMaster extends Logging { } } - private[spark] def sparkContextInitialized(sc: SparkContext) = { + private[spark] def sparkContextInitialized(sc: SparkContext): Unit = { master.sparkContextInitialized(sc) } - private[spark] def sparkContextStopped(sc: SparkContext) = { + private[spark] def sparkContextStopped(sc: SparkContext): Boolean = { master.sparkContextStopped(sc) } @@ -592,7 +590,7 @@ object ApplicationMaster extends Logging { */ object ExecutorLauncher { - def main(args: Array[String]) = { + def main(args: Array[String]): Unit = { ApplicationMaster.main(args) } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala index e1a992af3aae7..ae6dc1094d724 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala @@ -25,6 +25,7 @@ class ApplicationMasterArguments(val args: Array[String]) { var userJar: String = null var userClass: String = null var primaryPyFile: String = null + var primaryRFile: String = null var pyFiles: String = null var userArgs: Seq[String] = Seq[String]() var executorMemory = 1024 @@ -54,6 +55,10 @@ class ApplicationMasterArguments(val args: Array[String]) { primaryPyFile = value args = tail + case ("--primary-r-file") :: value :: tail => + primaryRFile = value + args = tail + case ("--py-files") :: value :: tail => pyFiles = value args = tail @@ -79,6 +84,11 @@ class ApplicationMasterArguments(val args: Array[String]) { } } + if (primaryPyFile != null && primaryRFile != null) { + System.err.println("Cannot have primary-py-file and primary-r-file at the same time") + System.exit(-1) + } + userArgs = userArgsBuffer.readOnly } @@ -92,6 +102,7 @@ class ApplicationMasterArguments(val args: Array[String]) { | --jar JAR_PATH Path to your application's JAR file | --class CLASS_NAME Name of your application's main class | --primary-py-file A main Python file + | --primary-r-file A main R file | --py-files PY_FILES Comma-separated list of .zip, .egg, or .py files to | place on the PYTHONPATH for Python apps. | --args ARGS Arguments to be passed to your application's main class. diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 61f8fc3f5a014..1091ff54b0463 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -22,17 +22,21 @@ import java.nio.ByteBuffer import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap, ListBuffer, Map} +import scala.reflect.runtime.universe import scala.util.{Try, Success, Failure} import com.google.common.base.Objects import org.apache.hadoop.io.DataOutputBuffer import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hdfs.security.token.delegation.DelegationTokenIdentifier import org.apache.hadoop.fs._ import org.apache.hadoop.fs.permission.FsPermission +import org.apache.hadoop.io.Text import org.apache.hadoop.mapred.Master import org.apache.hadoop.mapreduce.MRJobConfig import org.apache.hadoop.security.{Credentials, UserGroupInformation} +import org.apache.hadoop.security.token.Token import org.apache.hadoop.util.StringUtils import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.ApplicationConstants.Environment @@ -40,6 +44,7 @@ import org.apache.hadoop.yarn.api.protocolrecords._ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.client.api.{YarnClient, YarnClientApplication} import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.hadoop.yarn.exceptions.ApplicationNotFoundException import org.apache.hadoop.yarn.util.Records import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkException} @@ -66,6 +71,8 @@ private[spark] class Client( private val executorMemoryOverhead = args.executorMemoryOverhead // MB private val distCacheMgr = new ClientDistributedCacheManager() private val isClusterMode = args.isClusterMode + private val fireAndForget = isClusterMode && + !sparkConf.getBoolean("spark.yarn.submit.waitAppCompletion", true) def stop(): Unit = yarnClient.stop() @@ -217,6 +224,7 @@ private[spark] class Client( val dst = new Path(fs.getHomeDirectory(), appStagingDir) val nns = getNameNodesToAccess(sparkConf) + dst obtainTokensForNamenodes(nns, hadoopConf, credentials) + obtainTokenForHiveMetastore(hadoopConf, credentials) val replication = sparkConf.getInt("spark.yarn.submit.file.replication", fs.getDefaultReplication(dst)).toShort @@ -488,6 +496,12 @@ private[spark] class Client( } else { Nil } + val primaryRFile = + if (args.primaryRFile != null) { + Seq("--primary-r-file", args.primaryRFile) + } else { + Nil + } val amClass = if (isClusterMode) { Class.forName("org.apache.spark.deploy.yarn.ApplicationMaster").getName @@ -497,12 +511,15 @@ private[spark] class Client( if (args.primaryPyFile != null && args.primaryPyFile.endsWith(".py")) { args.userArgs = ArrayBuffer(args.primaryPyFile, args.pyFiles) ++ args.userArgs } + if (args.primaryRFile != null && args.primaryRFile.endsWith(".R")) { + args.userArgs = ArrayBuffer(args.primaryRFile) ++ args.userArgs + } val userArgs = args.userArgs.flatMap { arg => Seq("--arg", YarnSparkHadoopUtil.escapeForShell(arg)) } val amArgs = - Seq(amClass) ++ userClass ++ userJar ++ primaryPyFile ++ pyFiles ++ userArgs ++ - Seq( + Seq(amClass) ++ userClass ++ userJar ++ primaryPyFile ++ pyFiles ++ primaryRFile ++ + userArgs ++ Seq( "--executor-memory", args.executorMemory.toString + "m", "--executor-cores", args.executorCores.toString, "--num-executors ", args.numExecutors.toString) @@ -559,36 +576,25 @@ private[spark] class Client( var lastState: YarnApplicationState = null while (true) { Thread.sleep(interval) - val report = getApplicationReport(appId) + val report: ApplicationReport = + try { + getApplicationReport(appId) + } catch { + case e: ApplicationNotFoundException => + logError(s"Application $appId not found.") + return (YarnApplicationState.KILLED, FinalApplicationStatus.KILLED) + } val state = report.getYarnApplicationState if (logApplicationReport) { logInfo(s"Application report for $appId (state: $state)") - val details = Seq[(String, String)]( - ("client token", getClientToken(report)), - ("diagnostics", report.getDiagnostics), - ("ApplicationMaster host", report.getHost), - ("ApplicationMaster RPC port", report.getRpcPort.toString), - ("queue", report.getQueue), - ("start time", report.getStartTime.toString), - ("final status", report.getFinalApplicationStatus.toString), - ("tracking URL", report.getTrackingUrl), - ("user", report.getUser) - ) - - // Use more loggable format if value is null or empty - val formattedDetails = details - .map { case (k, v) => - val newValue = Option(v).filter(_.nonEmpty).getOrElse("N/A") - s"\n\t $k: $newValue" } - .mkString("") // If DEBUG is enabled, log report details every iteration // Otherwise, log them every time the application changes state if (log.isDebugEnabled) { - logDebug(formattedDetails) + logDebug(formatReportDetails(report)) } else if (lastState != state) { - logInfo(formattedDetails) + logInfo(formatReportDetails(report)) } } @@ -609,24 +615,57 @@ private[spark] class Client( throw new SparkException("While loop is depleted! This should never happen...") } + private def formatReportDetails(report: ApplicationReport): String = { + val details = Seq[(String, String)]( + ("client token", getClientToken(report)), + ("diagnostics", report.getDiagnostics), + ("ApplicationMaster host", report.getHost), + ("ApplicationMaster RPC port", report.getRpcPort.toString), + ("queue", report.getQueue), + ("start time", report.getStartTime.toString), + ("final status", report.getFinalApplicationStatus.toString), + ("tracking URL", report.getTrackingUrl), + ("user", report.getUser) + ) + + // Use more loggable format if value is null or empty + details.map { case (k, v) => + val newValue = Option(v).filter(_.nonEmpty).getOrElse("N/A") + s"\n\t $k: $newValue" + }.mkString("") + } + /** - * Submit an application to the ResourceManager and monitor its state. - * This continues until the application has exited for any reason. + * Submit an application to the ResourceManager. + * If set spark.yarn.submit.waitAppCompletion to true, it will stay alive + * reporting the application's status until the application has exited for any reason. + * Otherwise, the client process will exit after submission. * If the application finishes with a failed, killed, or undefined status, * throw an appropriate SparkException. */ def run(): Unit = { - val (yarnApplicationState, finalApplicationStatus) = monitorApplication(submitApplication()) - if (yarnApplicationState == YarnApplicationState.FAILED || - finalApplicationStatus == FinalApplicationStatus.FAILED) { - throw new SparkException("Application finished with failed status") - } - if (yarnApplicationState == YarnApplicationState.KILLED || - finalApplicationStatus == FinalApplicationStatus.KILLED) { - throw new SparkException("Application is killed") - } - if (finalApplicationStatus == FinalApplicationStatus.UNDEFINED) { - throw new SparkException("The final status of application is undefined") + val appId = submitApplication() + if (fireAndForget) { + val report = getApplicationReport(appId) + val state = report.getYarnApplicationState + logInfo(s"Application report for $appId (state: $state)") + logInfo(formatReportDetails(report)) + if (state == YarnApplicationState.FAILED || state == YarnApplicationState.KILLED) { + throw new SparkException(s"Application $appId finished with status: $state") + } + } else { + val (yarnApplicationState, finalApplicationStatus) = monitorApplication(appId) + if (yarnApplicationState == YarnApplicationState.FAILED || + finalApplicationStatus == FinalApplicationStatus.FAILED) { + throw new SparkException(s"Application $appId finished with failed status") + } + if (yarnApplicationState == YarnApplicationState.KILLED || + finalApplicationStatus == FinalApplicationStatus.KILLED) { + throw new SparkException(s"Application $appId is killed") + } + if (finalApplicationStatus == FinalApplicationStatus.UNDEFINED) { + throw new SparkException(s"The final status of application $appId is undefined") + } } } } @@ -902,6 +941,64 @@ object Client extends Logging { } } + /** + * Obtains token for the Hive metastore and adds them to the credentials. + */ + private def obtainTokenForHiveMetastore(conf: Configuration, credentials: Credentials) { + if (UserGroupInformation.isSecurityEnabled) { + val mirror = universe.runtimeMirror(getClass.getClassLoader) + + try { + val hiveClass = mirror.classLoader.loadClass("org.apache.hadoop.hive.ql.metadata.Hive") + val hive = hiveClass.getMethod("get").invoke(null) + + val hiveConf = hiveClass.getMethod("getConf").invoke(hive) + val hiveConfClass = mirror.classLoader.loadClass("org.apache.hadoop.hive.conf.HiveConf") + + val hiveConfGet = (param:String) => Option(hiveConfClass + .getMethod("get", classOf[java.lang.String]) + .invoke(hiveConf, param)) + + val metastore_uri = hiveConfGet("hive.metastore.uris") + + // Check for local metastore + if (metastore_uri != None && metastore_uri.get.toString.size > 0) { + val metastore_kerberos_principal_conf_var = mirror.classLoader + .loadClass("org.apache.hadoop.hive.conf.HiveConf$ConfVars") + .getField("METASTORE_KERBEROS_PRINCIPAL").get("varname").toString + + val principal = hiveConfGet(metastore_kerberos_principal_conf_var) + + val username = Option(UserGroupInformation.getCurrentUser().getUserName) + if (principal != None && username != None) { + val tokenStr = hiveClass.getMethod("getDelegationToken", + classOf[java.lang.String], classOf[java.lang.String]) + .invoke(hive, username.get, principal.get).asInstanceOf[java.lang.String] + + val hive2Token = new Token[DelegationTokenIdentifier]() + hive2Token.decodeFromUrlString(tokenStr) + credentials.addToken(new Text("hive.server2.delegation.token"),hive2Token) + logDebug("Added hive.Server2.delegation.token to conf.") + hiveClass.getMethod("closeCurrent").invoke(null) + } else { + logError("Username or principal == NULL") + logError(s"""username=${username.getOrElse("(NULL)")}""") + logError(s"""principal=${principal.getOrElse("(NULL)")}""") + throw new IllegalArgumentException("username and/or principal is equal to null!") + } + } else { + logDebug("HiveMetaStore configured in localmode") + } + } catch { + case e:java.lang.NoSuchMethodException => { logInfo("Hive Method not found " + e); return } + case e:java.lang.ClassNotFoundException => { logInfo("Hive Class not found " + e); return } + case e:Exception => { logError("Unexpected Exception " + e) + throw new RuntimeException("Unexpected exception", e) + } + } + } + } + /** * Return whether the two file systems are the same. */ diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index 3bc7eb1abf341..da6798cb1b279 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -32,6 +32,7 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) var userClass: String = null var pyFiles: String = null var primaryPyFile: String = null + var primaryRFile: String = null var userArgs: ArrayBuffer[String] = new ArrayBuffer[String]() var executorMemory = 1024 // MB var executorCores = 1 @@ -150,6 +151,10 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) primaryPyFile = value args = tail + case ("--primary-r-file") :: value :: tail => + primaryRFile = value + args = tail + case ("--args" | "--arg") :: value :: tail => if (args(0) == "--args") { println("--args is deprecated. Use --arg instead.") @@ -228,6 +233,11 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) throw new IllegalArgumentException(getUsageMessage(args)) } } + + if (primaryPyFile != null && primaryRFile != null) { + throw new IllegalArgumentException("Cannot have primary-py-file and primary-r-file" + + " at the same time") + } } private def getUsageMessage(unknownParam: List[String] = null): String = { @@ -240,6 +250,7 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) | mode) | --class CLASS_NAME Name of your application's main class (required) | --primary-py-file A main Python file + | --primary-r-file A main R file | --arg ARG Argument to be passed to your application's main class. | Multiple invocations are possible, each will be passed in order. | --num-executors NUM Number of executors to start (Default: 2) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index c1d3f7320f53c..b06069c07f451 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -59,15 +59,15 @@ class ExecutorRunnable( val yarnConf: YarnConfiguration = new YarnConfiguration(conf) lazy val env = prepareEnvironment(container) - def run = { + override def run(): Unit = { logInfo("Starting Executor Container") nmClient = NMClient.createNMClient() nmClient.init(yarnConf) nmClient.start() - startContainer + startContainer() } - def startContainer = { + def startContainer(): java.util.Map[String, ByteBuffer] = { logInfo("Setting up ContainerLaunchContext") val ctx = Records.newRecord(classOf[ContainerLaunchContext]) @@ -290,10 +290,19 @@ class ExecutorRunnable( YarnSparkHadoopUtil.setEnvFromInputString(env, userEnvs) } + // lookup appropriate http scheme for container log urls + val yarnHttpPolicy = yarnConf.get( + YarnConfiguration.YARN_HTTP_POLICY_KEY, + YarnConfiguration.YARN_HTTP_POLICY_DEFAULT + ) + val httpScheme = if (yarnHttpPolicy == "HTTPS_ONLY") "https://" else "http://" + // Add log urls sys.env.get("SPARK_USER").foreach { user => - val baseUrl = "http://%s/node/containerlogs/%s/%s" - .format(container.getNodeHttpAddress, ConverterUtils.toString(container.getId), user) + val containerId = ConverterUtils.toString(container.getId) + val address = container.getNodeHttpAddress + val baseUrl = s"$httpScheme$address/node/containerlogs/$containerId/$user" + env("SPARK_LOG_URL_STDERR") = s"$baseUrl/stderr?start=0" env("SPARK_LOG_URL_STDOUT") = s"$baseUrl/stdout?start=0" } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index c98763e15b58f..b8f42dadcb464 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -112,7 +112,7 @@ private[yarn] class YarnAllocator( SparkEnv.driverActorSystemName, sparkConf.get("spark.driver.host"), sparkConf.get("spark.driver.port"), - CoarseGrainedSchedulerBackend.ACTOR_NAME) + CoarseGrainedSchedulerBackend.ENDPOINT_NAME) // For testing private val launchContainers = sparkConf.getBoolean("spark.yarn.launchContainers", true) diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index 8abdc26b43806..99c05329b4d73 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -34,7 +34,7 @@ private[spark] class YarnClientSchedulerBackend( private var client: Client = null private var appId: ApplicationId = null - @volatile private var stopping: Boolean = false + private var monitorThread: Thread = null /** * Create a Yarn client to submit an application to the ResourceManager. @@ -57,7 +57,8 @@ private[spark] class YarnClientSchedulerBackend( client = new Client(args, conf) appId = client.submitApplication() waitForApplication() - asyncMonitorApplication() + monitorThread = asyncMonitorApplication() + monitorThread.start() } /** @@ -123,34 +124,22 @@ private[spark] class YarnClientSchedulerBackend( * If the application has exited for any reason, stop the SparkContext. * This assumes both `client` and `appId` have already been set. */ - private def asyncMonitorApplication(): Unit = { + private def asyncMonitorApplication(): Thread = { assert(client != null && appId != null, "Application has not been submitted yet!") val t = new Thread { override def run() { - while (!stopping) { - var state: YarnApplicationState = null - try { - val report = client.getApplicationReport(appId) - state = report.getYarnApplicationState() - } catch { - case e: ApplicationNotFoundException => - state = YarnApplicationState.KILLED - } - if (state == YarnApplicationState.FINISHED || - state == YarnApplicationState.KILLED || - state == YarnApplicationState.FAILED) { - logError(s"Yarn application has already exited with state $state!") - sc.stop() - stopping = true - } - Thread.sleep(1000L) + try { + val (state, _) = client.monitorApplication(appId, logApplicationReport = false) + logError(s"Yarn application has already exited with state $state!") + sc.stop() + } catch { + case e: InterruptedException => logInfo("Interrupting monitor thread") } - Thread.currentThread().interrupt() } } t.setName("Yarn application state monitor") t.setDaemon(true) - t.start() + t } /** @@ -158,7 +147,7 @@ private[spark] class YarnClientSchedulerBackend( */ override def stop() { assert(client != null, "Attempted to stop this scheduler before starting it!") - stopping = true + monitorThread.interrupt() super.stop() client.stop() logInfo("Stopped") diff --git a/yarn/src/test/resources/log4j.properties b/yarn/src/test/resources/log4j.properties index aab41fa49430f..6b8a5dbf6373e 100644 --- a/yarn/src/test/resources/log4j.properties +++ b/yarn/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.spark-project.jetty=WARN log4j.logger.org.apache.hadoop=WARN diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala index 92f04b4b859b3..c1b94ac9c5bdd 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala @@ -232,19 +232,26 @@ class ClientSuite extends FunSuite with Matchers with BeforeAndAfterAll { testCode(conf) } - def newEnv = MutableHashMap[String, String]() + def newEnv: MutableHashMap[String, String] = MutableHashMap[String, String]() - def classpath(env: MutableHashMap[String, String]) = env(Environment.CLASSPATH.name).split(":|;|") + def classpath(env: MutableHashMap[String, String]): Array[String] = + env(Environment.CLASSPATH.name).split(":|;|") - def flatten(a: Option[Seq[String]], b: Option[Seq[String]]) = (a ++ b).flatten.toArray + def flatten(a: Option[Seq[String]], b: Option[Seq[String]]): Array[String] = + (a ++ b).flatten.toArray - def getFieldValue[A, B](clazz: Class[_], field: String, defaults: => B)(mapTo: A => B): B = - Try(clazz.getField(field)).map(_.get(null).asInstanceOf[A]).toOption.map(mapTo).getOrElse(defaults) + def getFieldValue[A, B](clazz: Class[_], field: String, defaults: => B)(mapTo: A => B): B = { + Try(clazz.getField(field)) + .map(_.get(null).asInstanceOf[A]) + .toOption + .map(mapTo) + .getOrElse(defaults) + } def getFieldValue2[A: ClassTag, A1: ClassTag, B]( clazz: Class[_], field: String, - defaults: => B)(mapTo: A => B)(mapTo1: A1 => B) : B = { + defaults: => B)(mapTo: A => B)(mapTo1: A1 => B): B = { Try(clazz.getField(field)).map(_.get(null)).map { case v: A => mapTo(v) case v1: A1 => mapTo1(v1) diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index c09b01bafce37..455f1019d86dd 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -79,7 +79,7 @@ class YarnAllocatorSuite extends FunSuite with Matchers with BeforeAndAfterEach } class MockSplitInfo(host: String) extends SplitInfo(null, host, null, 1, null) { - override def equals(other: Any) = false + override def equals(other: Any): Boolean = false } def createAllocator(maxExecutors: Int = 5): YarnAllocator = { @@ -118,7 +118,9 @@ class YarnAllocatorSuite extends FunSuite with Matchers with BeforeAndAfterEach handler.getNumExecutorsRunning should be (1) handler.allocatedContainerToHostMap.get(container.getId).get should be ("host1") handler.allocatedHostToContainersMap.get("host1").get should contain (container.getId) - rmClient.getMatchingRequests(container.getPriority, "host1", containerResource).size should be (0) + + val size = rmClient.getMatchingRequests(container.getPriority, "host1", containerResource).size + size should be (0) } test("some containers allocated") { diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 0e37276ba724b..c06c0105670c0 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -143,7 +143,8 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit } } - test("run Python application in yarn-cluster mode") { + // Enable this once fix SPARK-6700 + ignore("run Python application in yarn-cluster mode") { val primaryPyFile = new File(tempDir, "test.py") Files.write(TEST_PYFILE, primaryPyFile, UTF_8) val pyFile = new File(tempDir, "test2.py") diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala index 4194f36499e66..9395316b71ff4 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala @@ -46,7 +46,7 @@ class YarnSparkHadoopUtilSuite extends FunSuite with Matchers with Logging { logWarning("Cannot execute bash, skipping bash tests.") } - def bashTest(name: String)(fn: => Unit) = + def bashTest(name: String)(fn: => Unit): Unit = if (hasBash) test(name)(fn) else ignore(name)(fn) bashTest("shell script escaping") {