diff --git a/LICENSE b/LICENSE index 9d1b00beff748..d0cd0dcb4bdb7 100644 --- a/LICENSE +++ b/LICENSE @@ -853,6 +853,52 @@ and Vis.js may be distributed under either license. +======================================================================== +For dagre-d3 (core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js): +======================================================================== +Copyright (c) 2013 Chris Pettitt + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. + +======================================================================== +For graphlib-dot (core/src/main/resources/org/apache/spark/ui/static/graphlib-dot.min.js): +======================================================================== +Copyright (c) 2012-2013 Chris Pettitt + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. + ======================================================================== BSD-style licenses ======================================================================== diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 411126a377950..f9447f6c3288d 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -19,9 +19,11 @@ exportMethods("arrange", "count", "describe", "distinct", + "dropna", "dtypes", "except", "explain", + "fillna", "filter", "first", "group_by", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index ed8093c80d360..0af5cb8881e35 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1314,9 +1314,8 @@ setMethod("except", #' write.df(df, "myfile", "parquet", "overwrite") #' } setMethod("write.df", - signature(df = "DataFrame", path = 'character', source = 'character', - mode = 'character'), - function(df, path = NULL, source = NULL, mode = "append", ...){ + signature(df = "DataFrame", path = 'character'), + function(df, path, source = NULL, mode = "append", ...){ if (is.null(source)) { sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv) source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default", @@ -1338,9 +1337,8 @@ setMethod("write.df", #' @aliases saveDF #' @export setMethod("saveDF", - signature(df = "DataFrame", path = 'character', source = 'character', - mode = 'character'), - function(df, path = NULL, source = NULL, mode = "append", ...){ + signature(df = "DataFrame", path = 'character'), + function(df, path, source = NULL, mode = "append", ...){ write.df(df, path, source, mode, ...) }) @@ -1431,3 +1429,128 @@ setMethod("describe", sdf <- callJMethod(x@sdf, "describe", listToSeq(colList)) dataFrame(sdf) }) + +#' dropna +#' +#' Returns a new DataFrame omitting rows with null values. +#' +#' @param x A SparkSQL DataFrame. +#' @param how "any" or "all". +#' if "any", drop a row if it contains any nulls. +#' if "all", drop a row only if all its values are null. +#' if minNonNulls is specified, how is ignored. +#' @param minNonNulls If specified, drop rows that have less than +#' minNonNulls non-null values. +#' This overwrites the how parameter. +#' @param cols Optional list of column names to consider. +#' @return A DataFrame +#' +#' @rdname nafunctions +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' dropna(df) +#' } +setMethod("dropna", + signature(x = "DataFrame"), + function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) { + how <- match.arg(how) + if (is.null(cols)) { + cols <- columns(x) + } + if (is.null(minNonNulls)) { + minNonNulls <- if (how == "any") { length(cols) } else { 1 } + } + + naFunctions <- callJMethod(x@sdf, "na") + sdf <- callJMethod(naFunctions, "drop", + as.integer(minNonNulls), listToSeq(as.list(cols))) + dataFrame(sdf) + }) + +#' @aliases dropna +#' @export +setMethod("na.omit", + signature(x = "DataFrame"), + function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) { + dropna(x, how, minNonNulls, cols) + }) + +#' fillna +#' +#' Replace null values. +#' +#' @param x A SparkSQL DataFrame. +#' @param value Value to replace null values with. +#' Should be an integer, numeric, character or named list. +#' If the value is a named list, then cols is ignored and +#' value must be a mapping from column name (character) to +#' replacement value. The replacement value must be an +#' integer, numeric or character. +#' @param cols optional list of column names to consider. +#' Columns specified in cols that do not have matching data +#' type are ignored. For example, if value is a character, and +#' subset contains a non-character column, then the non-character +#' column is simply ignored. +#' @return A DataFrame +#' +#' @rdname nafunctions +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' fillna(df, 1) +#' fillna(df, list("age" = 20, "name" = "unknown")) +#' } +setMethod("fillna", + signature(x = "DataFrame"), + function(x, value, cols = NULL) { + if (!(class(value) %in% c("integer", "numeric", "character", "list"))) { + stop("value should be an integer, numeric, charactor or named list.") + } + + if (class(value) == "list") { + # Check column names in the named list + colNames <- names(value) + if (length(colNames) == 0 || !all(colNames != "")) { + stop("value should be an a named list with each name being a column name.") + } + + # Convert to the named list to an environment to be passed to JVM + valueMap <- new.env() + for (col in colNames) { + # Check each item in the named list is of valid type + v <- value[[col]] + if (!(class(v) %in% c("integer", "numeric", "character"))) { + stop("Each item in value should be an integer, numeric or charactor.") + } + valueMap[[col]] <- v + } + + # When value is a named list, caller is expected not to pass in cols + if (!is.null(cols)) { + warning("When value is a named list, cols is ignored!") + cols <- NULL + } + + value <- valueMap + } else if (is.integer(value)) { + # Cast an integer to a numeric + value <- as.numeric(value) + } + + naFunctions <- callJMethod(x@sdf, "na") + sdf <- if (length(cols) == 0) { + callJMethod(naFunctions, "fill", value) + } else { + callJMethod(naFunctions, "fill", value, listToSeq(as.list(cols))) + } + dataFrame(sdf) + }) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 36cc612875879..88e1a508f37c4 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -457,6 +457,11 @@ read.df <- function(sqlContext, path = NULL, source = NULL, ...) { if (!is.null(path)) { options[['path']] <- path } + if (is.null(source)) { + sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv) + source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default", + "org.apache.spark.sql.parquet") + } sdf <- callJMethod(sqlContext, "load", source, options) dataFrame(sdf) } diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index a23d3b217b2fd..12e09176c9f92 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -396,6 +396,20 @@ setGeneric("columns", function(x) {standardGeneric("columns") }) #' @export setGeneric("describe", function(x, col, ...) { standardGeneric("describe") }) +#' @rdname nafunctions +#' @export +setGeneric("dropna", + function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) { + standardGeneric("dropna") + }) + +#' @rdname nafunctions +#' @export +setGeneric("na.omit", + function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) { + standardGeneric("na.omit") + }) + #' @rdname schema #' @export setGeneric("dtypes", function(x) { standardGeneric("dtypes") }) @@ -408,6 +422,10 @@ setGeneric("explain", function(x, ...) { standardGeneric("explain") }) #' @export setGeneric("except", function(x, y) { standardGeneric("except") }) +#' @rdname nafunctions +#' @export +setGeneric("fillna", function(x, value, cols = NULL) { standardGeneric("fillna") }) + #' @rdname filter #' @export setGeneric("filter", function(x, condition) { standardGeneric("filter") }) @@ -482,11 +500,11 @@ setGeneric("saveAsTable", function(df, tableName, source, mode, ...) { #' @rdname write.df #' @export -setGeneric("write.df", function(df, path, source, mode, ...) { standardGeneric("write.df") }) +setGeneric("write.df", function(df, path, ...) { standardGeneric("write.df") }) #' @rdname write.df #' @export -setGeneric("saveDF", function(df, path, source, mode, ...) { standardGeneric("saveDF") }) +setGeneric("saveDF", function(df, path, ...) { standardGeneric("saveDF") }) #' @rdname schema #' @export diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R index c53d0a961016f..2081786e6f833 100644 --- a/R/pkg/R/serialize.R +++ b/R/pkg/R/serialize.R @@ -160,6 +160,14 @@ writeList <- function(con, arr) { } } +# Used to pass arrays where the elements can be of different types +writeGenericList <- function(con, list) { + writeInt(con, length(list)) + for (elem in list) { + writeObject(con, elem) + } +} + # Used to pass in hash maps required on Java side. writeEnv <- function(con, env) { len <- length(env) @@ -168,7 +176,7 @@ writeEnv <- function(con, env) { if (len > 0) { writeList(con, as.list(ls(env))) vals <- lapply(ls(env), function(x) { env[[x]] }) - writeList(con, as.list(vals)) + writeGenericList(con, as.list(vals)) } } diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 68387f0f5365d..5ced7c688f98a 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -225,14 +225,21 @@ sparkR.init <- function( #' sqlContext <- sparkRSQL.init(sc) #'} -sparkRSQL.init <- function(jsc) { +sparkRSQL.init <- function(jsc = NULL) { if (exists(".sparkRSQLsc", envir = .sparkREnv)) { return(get(".sparkRSQLsc", envir = .sparkREnv)) } + # If jsc is NULL, create a Spark Context + sc <- if (is.null(jsc)) { + sparkR.init() + } else { + jsc + } + sqlContext <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", - "createSQLContext", - jsc) + "createSQLContext", + sc) assign(".sparkRSQLsc", sqlContext, envir = .sparkREnv) sqlContext } @@ -249,12 +256,19 @@ sparkRSQL.init <- function(jsc) { #' sqlContext <- sparkRHive.init(sc) #'} -sparkRHive.init <- function(jsc) { +sparkRHive.init <- function(jsc = NULL) { if (exists(".sparkRHivesc", envir = .sparkREnv)) { return(get(".sparkRHivesc", envir = .sparkREnv)) } - ssc <- callJMethod(jsc, "sc") + # If jsc is NULL, create a Spark Context + sc <- if (is.null(jsc)) { + sparkR.init() + } else { + jsc + } + + ssc <- callJMethod(sc, "sc") hiveCtx <- tryCatch({ newJObject("org.apache.spark.sql.hive.HiveContext", ssc) }, error = function(err) { diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 1857e636e8577..d2d82e791e876 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -32,6 +32,15 @@ jsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp") parquetPath <- tempfile(pattern="sparkr-test", fileext=".parquet") writeLines(mockLines, jsonPath) +# For test nafunctions, like dropna(), fillna(),... +mockLinesNa <- c("{\"name\":\"Bob\",\"age\":16,\"height\":176.5}", + "{\"name\":\"Alice\",\"age\":null,\"height\":164.3}", + "{\"name\":\"David\",\"age\":60,\"height\":null}", + "{\"name\":\"Amy\",\"age\":null,\"height\":null}", + "{\"name\":null,\"age\":null,\"height\":null}") +jsonPathNa <- tempfile(pattern="sparkr-test", fileext=".tmp") +writeLines(mockLinesNa, jsonPathNa) + test_that("infer types", { expect_equal(infer_type(1L), "integer") expect_equal(infer_type(1.0), "double") @@ -765,5 +774,105 @@ test_that("describe() on a DataFrame", { expect_equal(collect(stats)[5, "age"], "30") }) +test_that("dropna() on a DataFrame", { + df <- jsonFile(sqlContext, jsonPathNa) + rows <- collect(df) + + # drop with columns + + expected <- rows[!is.na(rows$name),] + actual <- collect(dropna(df, cols = "name")) + expect_true(identical(expected, actual)) + + expected <- rows[!is.na(rows$age),] + actual <- collect(dropna(df, cols = "age")) + row.names(expected) <- row.names(actual) + # identical on two dataframes does not work here. Don't know why. + # use identical on all columns as a workaround. + expect_true(identical(expected$age, actual$age)) + expect_true(identical(expected$height, actual$height)) + expect_true(identical(expected$name, actual$name)) + + expected <- rows[!is.na(rows$age) & !is.na(rows$height),] + actual <- collect(dropna(df, cols = c("age", "height"))) + expect_true(identical(expected, actual)) + + expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] + actual <- collect(dropna(df)) + expect_true(identical(expected, actual)) + + # drop with how + + expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] + actual <- collect(dropna(df)) + expect_true(identical(expected, actual)) + + expected <- rows[!is.na(rows$age) | !is.na(rows$height) | !is.na(rows$name),] + actual <- collect(dropna(df, "all")) + expect_true(identical(expected, actual)) + + expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] + actual <- collect(dropna(df, "any")) + expect_true(identical(expected, actual)) + + expected <- rows[!is.na(rows$age) & !is.na(rows$height),] + actual <- collect(dropna(df, "any", cols = c("age", "height"))) + expect_true(identical(expected, actual)) + + expected <- rows[!is.na(rows$age) | !is.na(rows$height),] + actual <- collect(dropna(df, "all", cols = c("age", "height"))) + expect_true(identical(expected, actual)) + + # drop with threshold + + expected <- rows[as.integer(!is.na(rows$age)) + as.integer(!is.na(rows$height)) >= 2,] + actual <- collect(dropna(df, minNonNulls = 2, cols = c("age", "height"))) + expect_true(identical(expected, actual)) + + expected <- rows[as.integer(!is.na(rows$age)) + + as.integer(!is.na(rows$height)) + + as.integer(!is.na(rows$name)) >= 3,] + actual <- collect(dropna(df, minNonNulls = 3, cols = c("name", "age", "height"))) + expect_true(identical(expected, actual)) +}) + +test_that("fillna() on a DataFrame", { + df <- jsonFile(sqlContext, jsonPathNa) + rows <- collect(df) + + # fill with value + + expected <- rows + expected$age[is.na(expected$age)] <- 50 + expected$height[is.na(expected$height)] <- 50.6 + actual <- collect(fillna(df, 50.6)) + expect_true(identical(expected, actual)) + + expected <- rows + expected$name[is.na(expected$name)] <- "unknown" + actual <- collect(fillna(df, "unknown")) + expect_true(identical(expected, actual)) + + expected <- rows + expected$age[is.na(expected$age)] <- 50 + actual <- collect(fillna(df, 50.6, "age")) + expect_true(identical(expected, actual)) + + expected <- rows + expected$name[is.na(expected$name)] <- "unknown" + actual <- collect(fillna(df, "unknown", c("age", "name"))) + expect_true(identical(expected, actual)) + + # fill with named list + + expected <- rows + expected$age[is.na(expected$age)] <- 50 + expected$height[is.na(expected$height)] <- 50.6 + expected$name[is.na(expected$name)] <- "unknown" + actual <- collect(fillna(df, list("age" = 50, "height" = 50.6, "name" = "unknown"))) + expect_true(identical(expected, actual)) +}) + unlink(parquetPath) unlink(jsonPath) +unlink(jsonPathNa) diff --git a/README.md b/README.md index 9c09d40e2bdae..380422ca00dbe 100644 --- a/README.md +++ b/README.md @@ -3,8 +3,8 @@ Spark is a fast and general cluster computing system for Big Data. It provides high-level APIs in Scala, Java, and Python, and an optimized engine that supports general computation graphs for data analysis. It also supports a -rich set of higher-level tools including Spark SQL for SQL and structured -data processing, MLlib for machine learning, GraphX for graph processing, +rich set of higher-level tools including Spark SQL for SQL and DataFrames, +MLlib for machine learning, GraphX for graph processing, and Spark Streaming for stream processing. @@ -22,7 +22,7 @@ This README file only contains basic setup instructions. Spark is built using [Apache Maven](http://maven.apache.org/). To build Spark and its example programs, run: - mvn -DskipTests clean package + build/mvn -DskipTests clean package (You do not need to do this if you downloaded a pre-built package.) More detailed documentation is available from the project site, at @@ -43,7 +43,7 @@ Try the following command, which should return 1000: Alternatively, if you prefer Python, you can use the Python shell: ./bin/pyspark - + And run the following command, which should also return 1000: >>> sc.parallelize(range(1000)).count() @@ -58,9 +58,9 @@ To run one of them, use `./bin/run-example [params]`. For example: will run the Pi example locally. You can set the MASTER environment variable when running examples to submit -examples to a cluster. This can be a mesos:// or spark:// URL, -"yarn-cluster" or "yarn-client" to run on YARN, and "local" to run -locally with one thread, or "local[N]" to run locally with N threads. You +examples to a cluster. This can be a mesos:// or spark:// URL, +"yarn-cluster" or "yarn-client" to run on YARN, and "local" to run +locally with one thread, or "local[N]" to run locally with N threads. You can also use an abbreviated class name if the class is in the `examples` package. For instance: @@ -75,7 +75,7 @@ can be run using: ./dev/run-tests -Please see the guidance on how to +Please see the guidance on how to [run tests for a module, or individual tests](https://cwiki.apache.org/confluence/display/SPARK/Useful+Developer+Tools). ## A Note About Hadoop Versions diff --git a/bagel/pom.xml b/bagel/pom.xml index 1f3dec91314f2..132cd433d78a2 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -40,6 +40,13 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.scalacheck scalacheck_${scala.binary.version} diff --git a/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala b/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala index ccb262a4ee02a..fb10d734ac74b 100644 --- a/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala +++ b/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.bagel -import org.scalatest.{BeforeAndAfter, FunSuite, Assertions} +import org.scalatest.{BeforeAndAfter, Assertions} import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ @@ -27,7 +27,7 @@ import org.apache.spark.storage.StorageLevel class TestVertex(val active: Boolean, val age: Int) extends Vertex with Serializable class TestMessage(val targetId: String) extends Message[String] with Serializable -class BagelSuite extends FunSuite with Assertions with BeforeAndAfter with Timeouts { +class BagelSuite extends SparkFunSuite with Assertions with BeforeAndAfter with Timeouts { var sc: SparkContext = _ diff --git a/bin/pyspark b/bin/pyspark index 8acad6113797d..7cb19c51b43a2 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -90,11 +90,7 @@ if [[ -n "$SPARK_TESTING" ]]; then unset YARN_CONF_DIR unset HADOOP_CONF_DIR export PYTHONHASHSEED=0 - if [[ -n "$PYSPARK_DOC_TEST" ]]; then - exec "$PYSPARK_DRIVER_PYTHON" -m doctest $1 - else - exec "$PYSPARK_DRIVER_PYTHON" $1 - fi + exec "$PYSPARK_DRIVER_PYTHON" -m $1 exit fi diff --git a/core/pom.xml b/core/pom.xml index e58efe495e36d..a02184222e9f0 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -338,6 +338,12 @@ org.seleniumhq.selenium selenium-java + + + com.google.guava + guava + + test @@ -475,6 +481,29 @@ + + sparkr-docs + + + + org.codehaus.mojo + exec-maven-plugin + + + sparkr-pkg-docs + compile + + exec + + + + + ..${path.separator}R${path.separator}create-docs${script.extension} + + + + + diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java new file mode 100644 index 0000000000000..d3d6280284beb --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort; + +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; + +import scala.Product2; +import scala.Tuple2; +import scala.collection.Iterator; + +import com.google.common.io.Closeables; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.Partitioner; +import org.apache.spark.SparkConf; +import org.apache.spark.TaskContext; +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.serializer.Serializer; +import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.storage.*; +import org.apache.spark.util.Utils; + +/** + * This class implements sort-based shuffle's hash-style shuffle fallback path. This write path + * writes incoming records to separate files, one file per reduce partition, then concatenates these + * per-partition files to form a single output file, regions of which are served to reducers. + * Records are not buffered in memory. This is essentially identical to + * {@link org.apache.spark.shuffle.hash.HashShuffleWriter}, except that it writes output in a format + * that can be served / consumed via {@link org.apache.spark.shuffle.IndexShuffleBlockResolver}. + *

+ * This write path is inefficient for shuffles with large numbers of reduce partitions because it + * simultaneously opens separate serializers and file streams for all partitions. As a result, + * {@link SortShuffleManager} only selects this write path when + *

    + *
  • no Ordering is specified,
  • + *
  • no Aggregator is specific, and
  • + *
  • the number of partitions is less than + * spark.shuffle.sort.bypassMergeThreshold.
  • + *
+ * + * This code used to be part of {@link org.apache.spark.util.collection.ExternalSorter} but was + * refactored into its own class in order to reduce code complexity; see SPARK-7855 for details. + *

+ * There have been proposals to completely remove this code path; see SPARK-6026 for details. + */ +final class BypassMergeSortShuffleWriter implements SortShuffleFileWriter { + + private final Logger logger = LoggerFactory.getLogger(BypassMergeSortShuffleWriter.class); + + private final int fileBufferSize; + private final boolean transferToEnabled; + private final int numPartitions; + private final BlockManager blockManager; + private final Partitioner partitioner; + private final ShuffleWriteMetrics writeMetrics; + private final Serializer serializer; + + /** Array of file writers, one for each partition */ + private BlockObjectWriter[] partitionWriters; + + public BypassMergeSortShuffleWriter( + SparkConf conf, + BlockManager blockManager, + Partitioner partitioner, + ShuffleWriteMetrics writeMetrics, + Serializer serializer) { + // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided + this.fileBufferSize = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; + this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true); + this.numPartitions = partitioner.numPartitions(); + this.blockManager = blockManager; + this.partitioner = partitioner; + this.writeMetrics = writeMetrics; + this.serializer = serializer; + } + + @Override + public void insertAll(Iterator> records) throws IOException { + assert (partitionWriters == null); + if (!records.hasNext()) { + return; + } + final SerializerInstance serInstance = serializer.newInstance(); + final long openStartTime = System.nanoTime(); + partitionWriters = new BlockObjectWriter[numPartitions]; + for (int i = 0; i < numPartitions; i++) { + final Tuple2 tempShuffleBlockIdPlusFile = + blockManager.diskBlockManager().createTempShuffleBlock(); + final File file = tempShuffleBlockIdPlusFile._2(); + final BlockId blockId = tempShuffleBlockIdPlusFile._1(); + partitionWriters[i] = + blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics).open(); + } + // Creating the file to write to and creating a disk writer both involve interacting with + // the disk, and can take a long time in aggregate when we open many files, so should be + // included in the shuffle write time. + writeMetrics.incShuffleWriteTime(System.nanoTime() - openStartTime); + + while (records.hasNext()) { + final Product2 record = records.next(); + final K key = record._1(); + partitionWriters[partitioner.getPartition(key)].write(key, record._2()); + } + + for (BlockObjectWriter writer : partitionWriters) { + writer.commitAndClose(); + } + } + + @Override + public long[] writePartitionedFile( + BlockId blockId, + TaskContext context, + File outputFile) throws IOException { + // Track location of the partition starts in the output file + final long[] lengths = new long[numPartitions]; + if (partitionWriters == null) { + // We were passed an empty iterator + return lengths; + } + + final FileOutputStream out = new FileOutputStream(outputFile, true); + final long writeStartTime = System.nanoTime(); + boolean threwException = true; + try { + for (int i = 0; i < numPartitions; i++) { + final FileInputStream in = new FileInputStream(partitionWriters[i].fileSegment().file()); + boolean copyThrewException = true; + try { + lengths[i] = Utils.copyStream(in, out, false, transferToEnabled); + copyThrewException = false; + } finally { + Closeables.close(in, copyThrewException); + } + if (!blockManager.diskBlockManager().getFile(partitionWriters[i].blockId()).delete()) { + logger.error("Unable to delete file for partition {}", i); + } + } + threwException = false; + } finally { + Closeables.close(out, threwException); + writeMetrics.incShuffleWriteTime(System.nanoTime() - writeStartTime); + } + partitionWriters = null; + return lengths; + } + + @Override + public void stop() throws IOException { + if (partitionWriters != null) { + try { + final DiskBlockManager diskBlockManager = blockManager.diskBlockManager(); + for (BlockObjectWriter writer : partitionWriters) { + // This method explicitly does _not_ throw exceptions: + writer.revertPartialWritesAndClose(); + if (!diskBlockManager.getFile(writer.blockId()).delete()) { + logger.error("Error while deleting file for block {}", writer.blockId()); + } + } + } finally { + partitionWriters = null; + } + } + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java new file mode 100644 index 0000000000000..656ea0401a144 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort; + +import java.io.File; +import java.io.IOException; + +import scala.Product2; +import scala.collection.Iterator; + +import org.apache.spark.annotation.Private; +import org.apache.spark.TaskContext; +import org.apache.spark.storage.BlockId; + +/** + * Interface for objects that {@link SortShuffleWriter} uses to write its output files. + */ +@Private +public interface SortShuffleFileWriter { + + void insertAll(Iterator> records) throws IOException; + + /** + * Write all the data added into this shuffle sorter into a file in the disk store. This is + * called by the SortShuffleWriter and can go through an efficient path of just concatenating + * binary files if we decided to avoid merge-sorting. + * + * @param blockId block ID to write to. The index file will be blockId.name + ".index". + * @param context a TaskContext for a running Spark task, for us to update shuffle metrics. + * @return array of lengths, in bytes, of each partition of the file (used by map output tracker) + */ + long[] writePartitionedFile( + BlockId blockId, + TaskContext context, + File outputFile) throws IOException; + + void stop() throws IOException; +} diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala index b8a5f5016860f..ceeb58075d345 100644 --- a/core/src/main/scala/org/apache/spark/Aggregator.scala +++ b/core/src/main/scala/org/apache/spark/Aggregator.scala @@ -34,8 +34,8 @@ case class Aggregator[K, V, C] ( mergeValue: (C, V) => C, mergeCombiners: (C, C) => C) { - // When spilling is enabled sorting will happen externally, but not necessarily with an - // ExternalSorter. + // When spilling is enabled sorting will happen externally, but not necessarily with an + // ExternalSorter. private val isSpillEnabled = SparkEnv.get.conf.getBoolean("spark.shuffle.spill", true) @deprecated("use combineValuesByKey with TaskContext argument", "0.9.0") diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala index 91f9ef8ce7185..48792a958130c 100644 --- a/core/src/main/scala/org/apache/spark/FutureAction.scala +++ b/core/src/main/scala/org/apache/spark/FutureAction.scala @@ -150,7 +150,7 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: } override def isCompleted: Boolean = jobWaiter.jobFinished - + override def isCancelled: Boolean = _cancelled override def value: Option[Try[T]] = { diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index f2b024ff6cb67..6909015ff66e6 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -29,7 +29,7 @@ import org.apache.spark.util.{ThreadUtils, Utils} /** * A heartbeat from executors to the driver. This is a shared message used by several internal - * components to convey liveness or execution information for in-progress tasks. It will also + * components to convey liveness or execution information for in-progress tasks. It will also * expire the hosts that have not heartbeated for more than spark.network.timeout. */ private[spark] case class Heartbeat( @@ -43,8 +43,8 @@ private[spark] case class Heartbeat( */ private[spark] case object TaskSchedulerIsSet -private[spark] case object ExpireDeadHosts - +private[spark] case object ExpireDeadHosts + private[spark] case class HeartbeatResponse(reregisterBlockManager: Boolean) /** @@ -62,18 +62,18 @@ private[spark] class HeartbeatReceiver(sc: SparkContext) // "spark.network.timeout" uses "seconds", while `spark.storage.blockManagerSlaveTimeoutMs` uses // "milliseconds" - private val slaveTimeoutMs = + private val slaveTimeoutMs = sc.conf.getTimeAsMs("spark.storage.blockManagerSlaveTimeoutMs", "120s") - private val executorTimeoutMs = + private val executorTimeoutMs = sc.conf.getTimeAsSeconds("spark.network.timeout", s"${slaveTimeoutMs}ms") * 1000 - + // "spark.network.timeoutInterval" uses "seconds", while // "spark.storage.blockManagerTimeoutIntervalMs" uses "milliseconds" - private val timeoutIntervalMs = + private val timeoutIntervalMs = sc.conf.getTimeAsMs("spark.storage.blockManagerTimeoutIntervalMs", "60s") - private val checkTimeoutIntervalMs = + private val checkTimeoutIntervalMs = sc.conf.getTimeAsSeconds("spark.network.timeoutInterval", s"${timeoutIntervalMs}ms") * 1000 - + private var timeoutCheckingTask: ScheduledFuture[_] = null // "eventLoopThread" is used to run some pretty fast actions. The actions running in it should not @@ -140,7 +140,7 @@ private[spark] class HeartbeatReceiver(sc: SparkContext) } } } - + override def onStop(): Unit = { if (timeoutCheckingTask != null) { timeoutCheckingTask.cancel(true) diff --git a/core/src/main/scala/org/apache/spark/HttpFileServer.scala b/core/src/main/scala/org/apache/spark/HttpFileServer.scala index 7e706bcc42f04..7cf7bc0dc6810 100644 --- a/core/src/main/scala/org/apache/spark/HttpFileServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpFileServer.scala @@ -50,8 +50,8 @@ private[spark] class HttpFileServer( def stop() { httpServer.stop() - - // If we only stop sc, but the driver process still run as a services then we need to delete + + // If we only stop sc, but the driver process still run as a services then we need to delete // the tmp dir, if not, it will create too many tmp dirs try { Utils.deleteRecursively(baseDir) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 4b5bcb54aa873..46d72841dccce 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -227,7 +227,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { def getSizeAsBytes(key: String, defaultValue: String): Long = { Utils.byteStringAsBytes(get(key, defaultValue)) } - + /** * Get a size parameter as Kibibytes; throws a NoSuchElementException if it's not set. If no * suffix is provided then Kibibytes are assumed. @@ -244,7 +244,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { def getSizeAsKb(key: String, defaultValue: String): Long = { Utils.byteStringAsKb(get(key, defaultValue)) } - + /** * Get a size parameter as Mebibytes; throws a NoSuchElementException if it's not set. If no * suffix is provided then Mebibytes are assumed. @@ -261,7 +261,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { def getSizeAsMb(key: String, defaultValue: String): Long = { Utils.byteStringAsMb(get(key, defaultValue)) } - + /** * Get a size parameter as Gibibytes; throws a NoSuchElementException if it's not set. If no * suffix is provided then Gibibytes are assumed. @@ -278,7 +278,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { def getSizeAsGb(key: String, defaultValue: String): Long = { Utils.byteStringAsGb(get(key, defaultValue)) } - + /** Get a parameter as an Option */ def getOption(key: String): Option[String] = { Option(settings.get(key)).orElse(getDeprecatedConfig(key, this)) @@ -480,7 +480,7 @@ private[spark] object SparkConf extends Logging { "spark.kryoserializer.buffer.mb was previously specified as '0.064'. Fractional values " + "are no longer accepted. To specify the equivalent now, one may use '64k'.") ) - + Map(configs.map { cfg => (cfg.key -> cfg) } : _*) } @@ -508,7 +508,7 @@ private[spark] object SparkConf extends Logging { "spark.reducer.maxSizeInFlight" -> Seq( AlternateConfig("spark.reducer.maxMbInFlight", "1.4")), "spark.kryoserializer.buffer" -> - Seq(AlternateConfig("spark.kryoserializer.buffer.mb", "1.4", + Seq(AlternateConfig("spark.kryoserializer.buffer.mb", "1.4", translation = s => s"${(s.toDouble * 1000).toInt}k")), "spark.kryoserializer.buffer.max" -> Seq( AlternateConfig("spark.kryoserializer.buffer.max.mb", "1.4")), diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index fe6320b504e15..a1ebbecf93b7b 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -51,7 +51,7 @@ private[spark] object TestUtils { classpathUrls: Seq[URL] = Seq()): URL = { val tempDir = Utils.createTempDir() val files1 = for (name <- classNames) yield { - createCompiledClass(name, tempDir, toStringValue, classpathUrls = classpathUrls) + createCompiledClass(name, tempDir, toStringValue, classpathUrls = classpathUrls) } val files2 = for ((childName, baseName) <- classNamesWithBase) yield { createCompiledClass(childName, tempDir, toStringValue, baseName, classpathUrls) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala index 61af867b11b9c..a650df605b92e 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala @@ -137,7 +137,7 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) */ def sample(withReplacement: Boolean, fraction: JDouble): JavaDoubleRDD = sample(withReplacement, fraction, Utils.random.nextLong) - + /** * Return a sampled subset of this RDD. */ diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala index db4e996feb31c..ed312770ee131 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala @@ -101,7 +101,7 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) /** * Return a sampled subset of this RDD. - * + * * @param withReplacement can elements be sampled multiple times (replaced when sampled out) * @param fraction expected size of the sample as a fraction of this RDD's size * without replacement: probability that each element is chosen; fraction must be [0, 1] @@ -109,10 +109,10 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) */ def sample(withReplacement: Boolean, fraction: Double): JavaRDD[T] = sample(withReplacement, fraction, Utils.random.nextLong) - + /** * Return a sampled subset of this RDD. - * + * * @param withReplacement can elements be sampled multiple times (replaced when sampled out) * @param fraction expected size of the sample as a fraction of this RDD's size * without replacement: probability that each element is chosen; fraction must be [0, 1] diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index b8e15f38a20d2..c95615a5a9307 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -60,10 +60,13 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { @deprecated("Use partitions() instead.", "1.1.0") def splits: JList[Partition] = new java.util.ArrayList(rdd.partitions.toSeq) - + /** Set of partitions in this RDD. */ def partitions: JList[Partition] = new java.util.ArrayList(rdd.partitions.toSeq) + /** The partitioner of this RDD. */ + def partitioner: Optional[Partitioner] = JavaUtils.optionToOptional(rdd.partitioner) + /** The [[org.apache.spark.SparkContext]] that this RDD was created on. */ def context: SparkContext = rdd.context @@ -492,9 +495,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { new java.util.ArrayList(arr) } - def takeSample(withReplacement: Boolean, num: Int): JList[T] = + def takeSample(withReplacement: Boolean, num: Int): JList[T] = takeSample(withReplacement, num, Utils.random.nextLong) - + def takeSample(withReplacement: Boolean, num: Int, seed: Long): JList[T] = { import scala.collection.JavaConversions._ val arr: java.util.Collection[T] = rdd.takeSample(withReplacement, num, seed).toSeq diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index a77bf42ce1d38..55a37f8c944b2 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -797,10 +797,10 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort: val bufferSize = SparkEnv.get.conf.getInt("spark.buffer.size", 65536) - /** + /** * We try to reuse a single Socket to transfer accumulator updates, as they are all added * by the DAGScheduler's single-threaded actor anyway. - */ + */ @transient var socket: Socket = _ def openSocket(): Socket = synchronized { @@ -843,6 +843,7 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort: * An Wrapper for Python Broadcast, which is written into disk by Python. It also will * write the data into disk after deserialization, then Python can read it from disks. */ +// scalastyle:off no.finalize private[spark] class PythonBroadcast(@transient var path: String) extends Serializable { /** @@ -884,3 +885,4 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial } } } +// scalastyle:on no.finalize diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala index 0a91977928cee..d24c650d37bb0 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala @@ -44,11 +44,11 @@ private[spark] class RBackend { bossGroup = new NioEventLoopGroup(2) val workerGroup = bossGroup val handler = new RBackendHandler(this) - + bootstrap = new ServerBootstrap() .group(bossGroup, workerGroup) .channel(classOf[NioServerSocketChannel]) - + bootstrap.childHandler(new ChannelInitializer[SocketChannel]() { def initChannel(ch: SocketChannel): Unit = { ch.pipeline() diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala index 026a1b9380357..2e86984c66b3a 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala @@ -77,7 +77,7 @@ private[r] class RBackendHandler(server: RBackend) val reply = bos.toByteArray ctx.write(reply) } - + override def channelReadComplete(ctx: ChannelHandlerContext): Unit = { ctx.flush() } diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala index e020458888e4a..4dfa7325934ff 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala @@ -355,7 +355,6 @@ private[r] object RRDD { val sparkConf = new SparkConf().setAppName(appName) .setSparkHome(sparkHome) - .setJars(jars) // Override `master` if we have a user-specified value if (master != "") { @@ -373,7 +372,11 @@ private[r] object RRDD { sparkConf.setExecutorEnv(name.asInstanceOf[String], value.asInstanceOf[String]) } - new JavaSparkContext(sparkConf) + val jsc = new JavaSparkContext(sparkConf) + jars.foreach { jar => + jsc.addJar(jar) + } + jsc } /** diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala index 371dfe454d1a2..f8e3f1a79082e 100644 --- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala +++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala @@ -157,9 +157,11 @@ private[spark] object SerDe { val keysLen = readInt(in) val keys = (0 until keysLen).map(_ => readTypedObject(in, keysType)) - val valuesType = readObjectType(in) val valuesLen = readInt(in) - val values = (0 until valuesLen).map(_ => readTypedObject(in, valuesType)) + val values = (0 until valuesLen).map(_ => { + val valueType = readObjectType(in) + readTypedObject(in, valueType) + }) mapAsJavaMap(keys.zip(values).toMap) } else { new java.util.HashMap[Object, Object]() diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 92bb5059a0313..8cf4d58847d8e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -428,6 +428,8 @@ object SparkSubmit { OptionAssigner(args.executorCores, YARN, CLIENT, sysProp = "spark.executor.cores"), OptionAssigner(args.files, YARN, CLIENT, sysProp = "spark.yarn.dist.files"), OptionAssigner(args.archives, YARN, CLIENT, sysProp = "spark.yarn.dist.archives"), + OptionAssigner(args.principal, YARN, CLIENT, sysProp = "spark.yarn.principal"), + OptionAssigner(args.keytab, YARN, CLIENT, sysProp = "spark.yarn.keytab"), // Yarn cluster only OptionAssigner(args.name, YARN, CLUSTER, clOption = "--name"), @@ -440,10 +442,8 @@ object SparkSubmit { OptionAssigner(args.files, YARN, CLUSTER, clOption = "--files"), OptionAssigner(args.archives, YARN, CLUSTER, clOption = "--archives"), OptionAssigner(args.jars, YARN, CLUSTER, clOption = "--addJars"), - - // Yarn client or cluster - OptionAssigner(args.principal, YARN, ALL_DEPLOY_MODES, clOption = "--principal"), - OptionAssigner(args.keytab, YARN, ALL_DEPLOY_MODES, clOption = "--keytab"), + OptionAssigner(args.principal, YARN, CLUSTER, clOption = "--principal"), + OptionAssigner(args.keytab, YARN, CLUSTER, clOption = "--keytab"), // Other options OptionAssigner(args.executorCores, STANDALONE, ALL_DEPLOY_MODES, @@ -869,7 +869,7 @@ private[spark] object SparkSubmitUtils { md.addDependency(dd) } } - + /** Add exclusion rules for dependencies already included in the spark-assembly */ def addExclusionRules( ivySettings: IvySettings, diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index c0e4c771908b3..cc6a7bd9f4119 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -169,6 +169,8 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S deployMode = Option(deployMode).orElse(env.get("DEPLOY_MODE")).orNull numExecutors = Option(numExecutors) .getOrElse(sparkProperties.get("spark.executor.instances").orNull) + keytab = Option(keytab).orElse(sparkProperties.get("spark.yarn.keytab")).orNull + principal = Option(principal).orElse(sparkProperties.get("spark.yarn.principal")).orNull // Try to set main class from JAR if no --class argument is given if (mainClass == null && !isPython && !isR && primaryResource != null) { diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala index a2a97a7877ce7..4692d22651c93 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala @@ -23,7 +23,7 @@ import org.apache.spark.util.Utils /** * Command-line parser for the master. */ -private[history] class HistoryServerArguments(conf: SparkConf, args: Array[String]) +private[history] class HistoryServerArguments(conf: SparkConf, args: Array[String]) extends Logging { private var propertiesFile: String = null diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala index 80db6d474b5c1..328d95a7a0c68 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala @@ -32,7 +32,7 @@ import org.apache.spark.deploy.SparkCuratorUtil private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serialization: Serialization) extends PersistenceEngine with Logging { - + private val WORKING_DIR = conf.get("spark.deploy.zookeeper.dir", "/spark") + "/master_status" private val zk: CuratorFramework = SparkCuratorUtil.newClient(conf) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index 756927682cd24..6a7c74020bace 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -75,6 +75,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { val workerHeaders = Seq("Worker Id", "Address", "State", "Cores", "Memory") val workers = state.workers.sortBy(_.id) + val aliveWorkers = state.workers.filter(_.state == WorkerState.ALIVE) val workerTable = UIUtils.listingTable(workerHeaders, workerRow, workers) val appHeaders = Seq("Application ID", "Name", "Cores", "Memory per Node", "Submitted Time", @@ -108,12 +109,12 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { }.getOrElse { Seq.empty } } -

  • Workers: {state.workers.size}
  • -
  • Cores: {state.workers.map(_.cores).sum} Total, - {state.workers.map(_.coresUsed).sum} Used
  • -
  • Memory: - {Utils.megabytesToString(state.workers.map(_.memory).sum)} Total, - {Utils.megabytesToString(state.workers.map(_.memoryUsed).sum)} Used
  • +
  • Alive Workers: {aliveWorkers.size}
  • +
  • Cores in use: {aliveWorkers.map(_.cores).sum} Total, + {aliveWorkers.map(_.coresUsed).sum} Used
  • +
  • Memory in use: + {Utils.megabytesToString(aliveWorkers.map(_.memory).sum)} Total, + {Utils.megabytesToString(aliveWorkers.map(_.memoryUsed).sum)} Used
  • Applications: {state.activeApps.size} Running, {state.completedApps.size} Completed
  • diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala index 6078f50518ba4..1fe956320a1b8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala @@ -57,7 +57,11 @@ private[spark] class RestSubmissionClient(master: String) extends Logging { private val supportedMasterPrefixes = Seq("spark://", "mesos://") - private val masters: Array[String] = Utils.parseStandaloneMasterUrls(master) + private val masters: Array[String] = if (master.startsWith("spark://")) { + Utils.parseStandaloneMasterUrls(master) + } else { + Array(master) + } // Set of masters that lost contact with us, used to keep track of // whether there are masters still alive for us to communicate with diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala index 88170d4df3053..dc2bee6f2bdca 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala @@ -29,6 +29,7 @@ import org.apache.spark.util.logging.RollingFileAppender private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with Logging { private val worker = parent.worker private val workDir = parent.workDir + private val supportedLogTypes = Set("stderr", "stdout") def renderLog(request: HttpServletRequest): String = { val defaultBytes = 100 * 1024 @@ -129,6 +130,11 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with offsetOption: Option[Long], byteLength: Int ): (String, Long, Long, Long) = { + + if (!supportedLogTypes.contains(logType)) { + return ("Error: Log type must be one of " + supportedLogTypes.mkString(", "), 0, 0, 0) + } + try { val files = RollingFileAppender.getSortedRolledOverFiles(logDirectory, logType) logDebug(s"Sorted log files of type $logType in $logDirectory:\n${files.mkString("\n")}") diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index d90ae405a0849..38b61d7242fce 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -43,22 +43,22 @@ class TaskMetrics extends Serializable { private var _hostname: String = _ def hostname: String = _hostname private[spark] def setHostname(value: String) = _hostname = value - + /** * Time taken on the executor to deserialize this task */ private var _executorDeserializeTime: Long = _ def executorDeserializeTime: Long = _executorDeserializeTime private[spark] def setExecutorDeserializeTime(value: Long) = _executorDeserializeTime = value - - + + /** * Time the executor spends actually running the task (including fetching shuffle data) */ private var _executorRunTime: Long = _ def executorRunTime: Long = _executorRunTime private[spark] def setExecutorRunTime(value: Long) = _executorRunTime = value - + /** * The number of bytes this task transmitted back to the driver as the TaskResult */ @@ -315,7 +315,7 @@ class ShuffleReadMetrics extends Serializable { def remoteBlocksFetched: Int = _remoteBlocksFetched private[spark] def incRemoteBlocksFetched(value: Int) = _remoteBlocksFetched += value private[spark] def decRemoteBlocksFetched(value: Int) = _remoteBlocksFetched -= value - + /** * Number of local blocks fetched in this shuffle by this task */ @@ -333,7 +333,7 @@ class ShuffleReadMetrics extends Serializable { def fetchWaitTime: Long = _fetchWaitTime private[spark] def incFetchWaitTime(value: Long) = _fetchWaitTime += value private[spark] def decFetchWaitTime(value: Long) = _fetchWaitTime -= value - + /** * Total number of remote bytes read from the shuffle by this task */ @@ -381,7 +381,7 @@ class ShuffleWriteMetrics extends Serializable { def shuffleBytesWritten: Long = _shuffleBytesWritten private[spark] def incShuffleBytesWritten(value: Long) = _shuffleBytesWritten += value private[spark] def decShuffleBytesWritten(value: Long) = _shuffleBytesWritten -= value - + /** * Time the task spent blocking on writes to disk or buffer cache, in nanoseconds */ @@ -389,7 +389,7 @@ class ShuffleWriteMetrics extends Serializable { def shuffleWriteTime: Long = _shuffleWriteTime private[spark] def incShuffleWriteTime(value: Long) = _shuffleWriteTime += value private[spark] def decShuffleWriteTime(value: Long) = _shuffleWriteTime -= value - + /** * Total number of records written to the shuffle by this task */ diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala index e8b3074e8f1a6..11dfcfe2f04e1 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala @@ -26,9 +26,9 @@ import org.apache.spark.SecurityManager import org.apache.spark.metrics.MetricsSystem private[spark] class Slf4jSink( - val property: Properties, + val property: Properties, val registry: MetricRegistry, - securityMgr: SecurityManager) + securityMgr: SecurityManager) extends Sink { val SLF4J_DEFAULT_PERIOD = 10 val SLF4J_DEFAULT_UNIT = "SECONDS" diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/package.scala b/core/src/main/scala/org/apache/spark/metrics/sink/package.scala index 90e3aa70b99ef..670e683663324 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/package.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/package.scala @@ -20,4 +20,4 @@ package org.apache.spark.metrics /** * Sinks used in Spark's metrics system. */ -package object sink +package object sink diff --git a/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala b/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala index 1a92a799d004a..67a376102994c 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala @@ -155,7 +155,7 @@ private[nio] class BlockMessage() { override def toString: String = { "BlockMessage [type = " + typ + ", id = " + id + ", level = " + level + - ", data = " + (if (data != null) data.remaining.toString else "null") + "]" + ", data = " + (if (data != null) data.remaining.toString else "null") + "]" } } diff --git a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala index 6b898bd4bfc1b..1499da07bb83b 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala @@ -326,15 +326,14 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, // MUST be called within the selector loop def connect() { - try{ + try { channel.register(selector, SelectionKey.OP_CONNECT) channel.connect(address) logInfo("Initiating connection to [" + address + "]") } catch { - case e: Exception => { + case e: Exception => logError("Error connecting to " + address, e) callOnExceptionCallbacks(e) - } } } diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index 497871ed6d5e5..c0bca2c4bc994 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -635,12 +635,11 @@ private[nio] class ConnectionManager( val message = securityMsgResp.toBufferMessage if (message == null) throw new IOException("Error creating security message") sendSecurityMessage(waitingConn.getRemoteConnectionManagerId(), message) - } catch { - case e: Exception => { + } catch { + case e: Exception => logError("Error handling sasl client authentication", e) waitingConn.close() throw new IOException("Error evaluating sasl response: ", e) - } } } } diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index bbf1b83af0795..ca1eb1f4e4a9a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -85,9 +85,9 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi numPartsToTry = partsScanned * 4 } else { // the left side of max is >=1 whenever partsScanned >= 2 - numPartsToTry = Math.max(1, + numPartsToTry = Math.max(1, (1.5 * num * partsScanned / results.size).toInt - partsScanned) - numPartsToTry = Math.min(numPartsToTry, partsScanned * 4) + numPartsToTry = Math.min(numPartsToTry, partsScanned * 4) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 2ab967f4bb313..84456d6d868dc 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -196,7 +196,7 @@ class NewHadoopRDD[K, V]( override def getPreferredLocations(hsplit: Partition): Seq[String] = { val split = hsplit.asInstanceOf[NewHadoopPartition].serializableHadoopSplit.value val locs = HadoopRDD.SPLIT_INFO_REFLECTIONS match { - case Some(c) => + case Some(c) => try { val infos = c.newGetLocationInfo.invoke(split).asInstanceOf[Array[AnyRef]] Some(HadoopRDD.convertSplitLocationInfo(infos)) diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 004899f27b7a6..cfd3e26faf2b9 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -328,7 +328,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) reduceByKeyLocally(func) } - /** + /** * Count the number of elements for each key, collecting the results to a local Map. * * Note that this method should only be used if the resulting map is expected to be small, as diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala index 7598ff617b399..9e3880714a79f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala @@ -86,7 +86,7 @@ class PartitionerAwareUnionRDD[T: ClassTag]( } val location = if (locations.isEmpty) { None - } else { + } else { // Find the location that maximum number of parent partitions prefer Some(locations.groupBy(x => x).maxBy(_._2.length)._1) } diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 5fcef255e13af..10610f4b6f1ff 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -434,11 +434,11 @@ abstract class RDD[T: ClassTag]( * @return A random sub-sample of the RDD without replacement. */ private[spark] def randomSampleWithRange(lb: Double, ub: Double, seed: Long): RDD[T] = { - this.mapPartitionsWithIndex { case (index, partition) => + this.mapPartitionsWithIndex( { (index, partition) => val sampler = new BernoulliCellSampler[T](lb, ub) sampler.setSeed(seed + index) sampler.sample(partition) - } + }, preservesPartitioning = true) } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala index 86f357abb8723..c6d957b65f3fb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala @@ -41,7 +41,7 @@ private[spark] class ReplayListenerBus extends SparkListenerBus with Logging { * * @param logData Stream containing event log data. * @param sourceName Filename (or other source identifier) from whence @logData is being read - * @param maybeTruncated Indicate whether log file might be truncated (some abnormal situations + * @param maybeTruncated Indicate whether log file might be truncated (some abnormal situations * encountered, log file might not finished writing) or not */ def replay( @@ -62,7 +62,7 @@ private[spark] class ReplayListenerBus extends SparkListenerBus with Logging { if (!maybeTruncated || lines.hasNext) { throw jpe } else { - logWarning(s"Got JsonParseException from log file $sourceName" + + logWarning(s"Got JsonParseException from log file $sourceName" + s" at line $lineNumber, the file might not have finished writing cleanly.") } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 586d1e06204c1..15101c64f0503 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -125,7 +125,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex if (interruptThread && taskThread != null) { taskThread.interrupt() } - } + } } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index d473e51abab80..673cd0e19eba2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -861,9 +861,9 @@ private[spark] class TaskSetManager( case TaskLocality.RACK_LOCAL => "spark.locality.wait.rack" case _ => null } - + if (localityWaitKey != null) { - conf.getTimeAsMs(localityWaitKey, defaultWait) + conf.getTimeAsMs(localityWaitKey, defaultWait) } else { 0L } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index c5bc6294a5577..fcad959540f5a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -84,7 +84,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp override def onStart() { // Periodically revive offers to allow delay scheduling to work val reviveIntervalMs = conf.getTimeAsMs("spark.scheduler.revive.interval", "1s") - + reviveThread.scheduleAtFixedRate(new Runnable { override def run(): Unit = Utils.tryLogNonFatalError { Option(self).foreach(_.send(ReviveOffers)) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala index 2f2934c249eb0..e79c543a9de27 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala @@ -37,14 +37,14 @@ private[mesos] object MesosSchedulerBackendUtil extends Logging { .newBuilder() .setMode(Volume.Mode.RW) spec match { - case Array(container_path) => + case Array(container_path) => Some(vol.setContainerPath(container_path)) case Array(container_path, "rw") => Some(vol.setContainerPath(container_path)) case Array(container_path, "ro") => Some(vol.setContainerPath(container_path) .setMode(Volume.Mode.RO)) - case Array(host_path, container_path) => + case Array(host_path, container_path) => Some(vol.setContainerPath(container_path) .setHostPath(host_path)) case Array(host_path, container_path, "rw") => diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 3f909885dbd66..cd8a82347a1e9 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -52,7 +52,7 @@ class KryoSerializer(conf: SparkConf) with Serializable { private val bufferSizeKb = conf.getSizeAsKb("spark.kryoserializer.buffer", "64k") - + if (bufferSizeKb >= ByteUnit.GiB.toKiB(2)) { throw new IllegalArgumentException("spark.kryoserializer.buffer must be less than " + s"2048 mb, got: + ${ByteUnit.KiB.toMiB(bufferSizeKb)} mb.") diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala index 80374adc44296..597d46a3d2223 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala @@ -80,7 +80,7 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { blocksByAddress, serializer, // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility - SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024) + SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024) val itr = blockFetcherItr.flatMap(unpackBlock) val completionIter = CompletionIterator[T, Iterator[T]](itr, { diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index c9dd6bfc4c219..5865e7640c1cf 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -17,9 +17,10 @@ package org.apache.spark.shuffle.sort -import org.apache.spark.{MapOutputTracker, SparkEnv, Logging, TaskContext} +import org.apache.spark._ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.scheduler.MapStatus +import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriter, BaseShuffleHandle} import org.apache.spark.storage.ShuffleBlockId import org.apache.spark.util.collection.ExternalSorter @@ -35,7 +36,7 @@ private[spark] class SortShuffleWriter[K, V, C]( private val blockManager = SparkEnv.get.blockManager - private var sorter: ExternalSorter[K, V, _] = null + private var sorter: SortShuffleFileWriter[K, V] = null // Are we in the process of stopping? Because map tasks can call stop() with success = true // and then call stop() with success = false if they get an exception, we want to make sure @@ -49,18 +50,27 @@ private[spark] class SortShuffleWriter[K, V, C]( /** Write a bunch of records to this task's output */ override def write(records: Iterator[Product2[K, V]]): Unit = { - if (dep.mapSideCombine) { + sorter = if (dep.mapSideCombine) { require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!") - sorter = new ExternalSorter[K, V, C]( + new ExternalSorter[K, V, C]( dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer) - sorter.insertAll(records) + } else if (SortShuffleWriter.shouldBypassMergeSort( + SparkEnv.get.conf, dep.partitioner.numPartitions, aggregator = None, keyOrdering = None)) { + // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't + // need local aggregation and sorting, write numPartitions files directly and just concatenate + // them at the end. This avoids doing serialization and deserialization twice to merge + // together the spilled files, which would happen with the normal code path. The downside is + // having multiple files open at a time and thus more memory allocated to buffers. + new BypassMergeSortShuffleWriter[K, V](SparkEnv.get.conf, blockManager, dep.partitioner, + writeMetrics, Serializer.getSerializer(dep.serializer)) } else { // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't // care whether the keys get sorted in each partition; that will be done on the reduce side // if the operation being run is sortByKey. - sorter = new ExternalSorter[K, V, V](None, Some(dep.partitioner), None, dep.serializer) - sorter.insertAll(records) + new ExternalSorter[K, V, V]( + aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer) } + sorter.insertAll(records) // Don't bother including the time to open the merged output file in the shuffle write time, // because it just opens a single file, so is typically too fast to measure accurately @@ -100,3 +110,13 @@ private[spark] class SortShuffleWriter[K, V, C]( } } +private[spark] object SortShuffleWriter { + def shouldBypassMergeSort( + conf: SparkConf, + numPartitions: Int, + aggregator: Option[Aggregator[_, _, _]], + keyOrdering: Option[Ordering[_]]): Boolean = { + val bypassMergeThreshold: Int = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) + numPartitions <= bypassMergeThreshold && aggregator.isEmpty && keyOrdering.isEmpty + } +} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneStageResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneStageResource.scala index fd24aea63a8a1..f9812f06cf527 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/OneStageResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/OneStageResource.scala @@ -83,7 +83,7 @@ private[v1] class OneStageResource(ui: SparkUI) { withStageAttempt(stageId, stageAttemptId) { stage => val tasks = stage.ui.taskData.values.map{AllStagesResource.convertTaskData}.toIndexedSeq .sorted(OneStageResource.ordering(sortBy)) - tasks.slice(offset, offset + length) + tasks.slice(offset, offset + length) } } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala b/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala index cee29786c3019..0c71cd2382225 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala @@ -16,40 +16,33 @@ */ package org.apache.spark.status.api.v1 -import java.text.SimpleDateFormat +import java.text.{ParseException, SimpleDateFormat} import java.util.TimeZone import javax.ws.rs.WebApplicationException import javax.ws.rs.core.Response import javax.ws.rs.core.Response.Status -import scala.util.Try - private[v1] class SimpleDateParam(val originalValue: String) { - val timestamp: Long = { - SimpleDateParam.formats.collectFirst { - case fmt if Try(fmt.parse(originalValue)).isSuccess => - fmt.parse(originalValue).getTime() - }.getOrElse( - throw new WebApplicationException( - Response - .status(Status.BAD_REQUEST) - .entity("Couldn't parse date: " + originalValue) - .build() - ) - ) - } -} -private[v1] object SimpleDateParam { - - val formats: Seq[SimpleDateFormat] = { - - val gmtDay = new SimpleDateFormat("yyyy-MM-dd") - gmtDay.setTimeZone(TimeZone.getTimeZone("GMT")) - - Seq( - new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSz"), - gmtDay - ) + val timestamp: Long = { + val format = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSz") + try { + format.parse(originalValue).getTime() + } catch { + case _: ParseException => + val gmtDay = new SimpleDateFormat("yyyy-MM-dd") + gmtDay.setTimeZone(TimeZone.getTimeZone("GMT")) + try { + gmtDay.parse(originalValue).getTime() + } catch { + case _: ParseException => + throw new WebApplicationException( + Response + .status(Status.BAD_REQUEST) + .entity("Couldn't parse date: " + originalValue) + .build() + ) + } + } } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 3afb4c3c02e2d..2cd8c5297b741 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -292,16 +292,16 @@ class BlockManagerMasterEndpoint( blockManagerIdByExecutor.get(id.executorId) match { case Some(oldId) => // A block manager of the same executor already exists, so remove it (assumed dead) - logError("Got two different block manager registrations on same executor - " + logError("Got two different block manager registrations on same executor - " + s" will replace old one $oldId with new one $id") - removeExecutor(id.executorId) + removeExecutor(id.executorId) case None => } logInfo("Registering block manager %s with %s RAM, %s".format( id.hostPort, Utils.bytesToString(maxMemSize), id)) - + blockManagerIdByExecutor(id.executorId) = id - + blockManagerInfo(id) = new BlockManagerInfo( id, System.currentTimeMillis(), maxMemSize, slaveEndpoint) } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala index a33f22ef52687..7eeabd1e0489c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala @@ -95,6 +95,7 @@ private[spark] class DiskBlockObjectWriter( private var objOut: SerializationStream = null private var initialized = false private var hasBeenClosed = false + private var commitAndCloseHasBeenCalled = false /** * Cursors used to represent positions in the file. @@ -167,20 +168,22 @@ private[spark] class DiskBlockObjectWriter( objOut.flush() bs.flush() close() + finalPosition = file.length() + // In certain compression codecs, more bytes are written after close() is called + writeMetrics.incShuffleBytesWritten(finalPosition - reportedPosition) + } else { + finalPosition = file.length() } - finalPosition = file.length() - // In certain compression codecs, more bytes are written after close() is called - writeMetrics.incShuffleBytesWritten(finalPosition - reportedPosition) + commitAndCloseHasBeenCalled = true } // Discard current writes. We do this by flushing the outstanding writes and then // truncating the file to its initial position. override def revertPartialWritesAndClose() { try { - writeMetrics.decShuffleBytesWritten(reportedPosition - initialPosition) - writeMetrics.decShuffleRecordsWritten(numRecordsWritten) - if (initialized) { + writeMetrics.decShuffleBytesWritten(reportedPosition - initialPosition) + writeMetrics.decShuffleRecordsWritten(numRecordsWritten) objOut.flush() bs.flush() close() @@ -228,6 +231,10 @@ private[spark] class DiskBlockObjectWriter( } override def fileSegment(): FileSegment = { + if (!commitAndCloseHasBeenCalled) { + throw new IllegalStateException( + "fileSegment() is only valid after commitAndClose() has been called") + } new FileSegment(file, initialPosition, finalPosition - initialPosition) } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index d441a4d31b954..91ef86389a0c3 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -151,7 +151,7 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon try { Utils.removeShutdownHook(shutdownHook) } catch { - case e: Exception => + case e: Exception => logError(s"Exception while removing shutdown hook.", e) } doStop() diff --git a/core/src/main/scala/org/apache/spark/storage/FileSegment.scala b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala index 95e2d688d9b17..021a9facfb0b2 100644 --- a/core/src/main/scala/org/apache/spark/storage/FileSegment.scala +++ b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala @@ -24,6 +24,8 @@ import java.io.File * based off an offset and a length. */ private[spark] class FileSegment(val file: File, val offset: Long, val length: Long) { + require(offset >= 0, s"File segment offset cannot be negative (got $offset)") + require(length >= 0, s"File segment length cannot be negative (got $length)") override def toString: String = { "(name=%s, offset=%d, length=%d)".format(file.getName, offset, length) } diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala index fb4ba0eac9d9a..b53c86e89a273 100644 --- a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala @@ -100,7 +100,7 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log try { os.write(bytes.array()) } catch { - case NonFatal(e) => + case NonFatal(e) => logWarning(s"Failed to put bytes of block $blockId into Tachyon", e) os.cancel() } finally { @@ -114,7 +114,7 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log try { blockManager.dataSerializeStream(blockId, os, values) } catch { - case NonFatal(e) => + case NonFatal(e) => logWarning(s"Failed to put values of block $blockId into Tachyon", e) os.cancel() } finally { diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index 594df15e9cc85..2c84e4485996e 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -62,12 +62,12 @@ private[spark] abstract class WebUI( tab.pages.foreach(attachPage) tabs += tab } - + def detachTab(tab: WebUITab) { tab.pages.foreach(detachPage) tabs -= tab } - + def detachPage(page: WebUIPage) { pageToHandlers.remove(page).foreach(_.foreach(detachHandler)) } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index 246e191d64776..f39e961772c46 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -119,7 +119,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { "failedStages" -> failedStages.size ) } - + // These collections may grow arbitrarily, but once Spark becomes idle they should shrink back to // some bound based on the `spark.ui.retainedStages` and `spark.ui.retainedJobs` settings: private[spark] def getSizesOfSoftSizeLimitedCollections: Map[String, Int] = { diff --git a/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala b/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala index ce7887b76ff96..1861d38640102 100644 --- a/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala @@ -40,7 +40,7 @@ private[spark] abstract class AsynchronousListenerBus[L <: AnyRef, E](name: Stri self => private var sparkContext: SparkContext = null - + /* Cap the capacity of the event queue so we get an explicit error (rather than * an OOM exception) if it's perpetually being added to more quickly than it's being drained. */ private val EVENT_QUEUE_CAPACITY = 10000 diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala index f1f6b5e1f93d8..0180399c9dad5 100644 --- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala +++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala @@ -236,7 +236,7 @@ object SizeEstimator extends Logging { val s1 = sampleArray(array, state, rand, drawn, length) val s2 = sampleArray(array, state, rand, drawn, length) val size = math.min(s1, s2) - state.size += math.max(s1, s2) + + state.size += math.max(s1, s2) + (size * ((length - ARRAY_SAMPLE_SIZE) / (ARRAY_SAMPLE_SIZE))).toLong } } @@ -244,7 +244,7 @@ object SizeEstimator extends Logging { private def sampleArray( array: AnyRef, - state: SearchState, + state: SearchState, rand: Random, drawn: OpenHashSet[Int], length: Int): Long = { diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index df2d6ad3b41a4..1e4531ef395ae 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -89,9 +89,9 @@ class ExternalAppendOnlyMap[K, V, C]( // Number of bytes spilled in total private var _diskBytesSpilled = 0L - + // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided - private val fileBufferSize = + private val fileBufferSize = sparkConf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 // Write metrics for current spill diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 3b9d14f9372b6..757dec66c203b 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -23,12 +23,14 @@ import java.util.Comparator import scala.collection.mutable.ArrayBuffer import scala.collection.mutable +import com.google.common.annotations.VisibleForTesting import com.google.common.io.ByteStreams import org.apache.spark._ import org.apache.spark.serializer._ import org.apache.spark.executor.ShuffleWriteMetrics -import org.apache.spark.storage.{BlockObjectWriter, BlockId} +import org.apache.spark.shuffle.sort.{SortShuffleFileWriter, SortShuffleWriter} +import org.apache.spark.storage.{BlockId, BlockObjectWriter} /** * Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner @@ -84,35 +86,40 @@ import org.apache.spark.storage.{BlockObjectWriter, BlockId} * each other for equality to merge values. * * - Users are expected to call stop() at the end to delete all the intermediate files. - * - * As a special case, if no Ordering and no Aggregator is given, and the number of partitions is - * less than spark.shuffle.sort.bypassMergeThreshold, we bypass the merge-sort and just write to - * separate files for each partition each time we spill, similar to the HashShuffleWriter. We can - * then concatenate these files to produce a single sorted file, without having to serialize and - * de-serialize each item twice (as is needed during the merge). This speeds up the map side of - * groupBy, sort, etc operations since they do no partial aggregation. */ private[spark] class ExternalSorter[K, V, C]( aggregator: Option[Aggregator[K, V, C]] = None, partitioner: Option[Partitioner] = None, ordering: Option[Ordering[K]] = None, serializer: Option[Serializer] = None) - extends Logging with Spillable[WritablePartitionedPairCollection[K, C]] { + extends Logging + with Spillable[WritablePartitionedPairCollection[K, C]] + with SortShuffleFileWriter[K, V] { + + private val conf = SparkEnv.get.conf private val numPartitions = partitioner.map(_.numPartitions).getOrElse(1) private val shouldPartition = numPartitions > 1 + private def getPartition(key: K): Int = { + if (shouldPartition) partitioner.get.getPartition(key) else 0 + } + + // Since SPARK-7855, bypassMergeSort optimization is no longer performed as part of this class. + // As a sanity check, make sure that we're not handling a shuffle which should use that path. + if (SortShuffleWriter.shouldBypassMergeSort(conf, numPartitions, aggregator, ordering)) { + throw new IllegalArgumentException("ExternalSorter should not be used to handle " + + " a sort that the BypassMergeSortShuffleWriter should handle") + } private val blockManager = SparkEnv.get.blockManager private val diskBlockManager = blockManager.diskBlockManager private val ser = Serializer.getSerializer(serializer) private val serInstance = ser.newInstance() - private val conf = SparkEnv.get.conf private val spillingEnabled = conf.getBoolean("spark.shuffle.spill", true) - + // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided private val fileBufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 - private val transferToEnabled = conf.getBoolean("spark.file.transferTo", true) // Size of object batches when reading/writing from serializers. // @@ -123,43 +130,28 @@ private[spark] class ExternalSorter[K, V, C]( // grow internal data structures by growing + copying every time the number of objects doubles. private val serializerBatchSize = conf.getLong("spark.shuffle.spill.batchSize", 10000) - private def getPartition(key: K): Int = { - if (shouldPartition) partitioner.get.getPartition(key) else 0 - } - - private val metaInitialRecords = 256 - private val kvChunkSize = conf.getInt("spark.shuffle.sort.kvChunkSize", 1 << 22) // 4 MB private val useSerializedPairBuffer = - !ordering.isDefined && conf.getBoolean("spark.shuffle.sort.serializeMapOutputs", true) && - ser.supportsRelocationOfSerializedObjects - + ordering.isEmpty && + conf.getBoolean("spark.shuffle.sort.serializeMapOutputs", true) && + ser.supportsRelocationOfSerializedObjects + private val kvChunkSize = conf.getInt("spark.shuffle.sort.kvChunkSize", 1 << 22) // 4 MB + private def newBuffer(): WritablePartitionedPairCollection[K, C] with SizeTracker = { + if (useSerializedPairBuffer) { + new PartitionedSerializedPairBuffer(metaInitialRecords = 256, kvChunkSize, serInstance) + } else { + new PartitionedPairBuffer[K, C] + } + } // Data structures to store in-memory objects before we spill. Depending on whether we have an // Aggregator set, we either put objects into an AppendOnlyMap where we combine them, or we // store them in an array buffer. private var map = new PartitionedAppendOnlyMap[K, C] - private var buffer = if (useSerializedPairBuffer) { - new PartitionedSerializedPairBuffer[K, C](metaInitialRecords, kvChunkSize, serInstance) - } else { - new PartitionedPairBuffer[K, C] - } + private var buffer = newBuffer() // Total spilling statistics private var _diskBytesSpilled = 0L + def diskBytesSpilled: Long = _diskBytesSpilled - // Write metrics for current spill - private var curWriteMetrics: ShuffleWriteMetrics = _ - - // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't need - // local aggregation and sorting, write numPartitions files directly and just concatenate them - // at the end. This avoids doing serialization and deserialization twice to merge together the - // spilled files, which would happen with the normal code path. The downside is having multiple - // files open at a time and thus more memory allocated to buffers. - private val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) - private val bypassMergeSort = - (numPartitions <= bypassMergeThreshold && aggregator.isEmpty && ordering.isEmpty) - - // Array of file writers for each partition, used if bypassMergeSort is true and we've spilled - private var partitionWriters: Array[BlockObjectWriter] = null // A comparator for keys K that orders them within a partition to allow aggregation or sorting. // Can be a partial ordering by hash code if a total ordering is not provided through by the @@ -174,6 +166,14 @@ private[spark] class ExternalSorter[K, V, C]( } }) + private def comparator: Option[Comparator[K]] = { + if (ordering.isDefined || aggregator.isDefined) { + Some(keyComparator) + } else { + None + } + } + // Information about a spilled file. Includes sizes in bytes of "batches" written by the // serializer as we periodically reset its stream, as well as number of elements in each // partition, used to efficiently keep track of partitions when merging. @@ -182,9 +182,10 @@ private[spark] class ExternalSorter[K, V, C]( blockId: BlockId, serializerBatchSizes: Array[Long], elementsPerPartition: Array[Long]) + private val spills = new ArrayBuffer[SpilledFile] - def insertAll(records: Iterator[_ <: Product2[K, V]]): Unit = { + override def insertAll(records: Iterator[Product2[K, V]]): Unit = { // TODO: stop combining if we find that the reduction factor isn't high val shouldCombine = aggregator.isDefined @@ -202,15 +203,6 @@ private[spark] class ExternalSorter[K, V, C]( map.changeValue((getPartition(kv._1), kv._1), update) maybeSpillCollection(usingMap = true) } - } else if (bypassMergeSort) { - // SPARK-4479: Also bypass buffering if merge sort is bypassed to avoid defensive copies - if (records.hasNext) { - spillToPartitionFiles( - WritablePartitionedIterator.fromIterator(records.map { kv => - ((getPartition(kv._1), kv._1), kv._2.asInstanceOf[C]) - }) - ) - } } else { // Stick values into our buffer while (records.hasNext) { @@ -238,46 +230,33 @@ private[spark] class ExternalSorter[K, V, C]( } } else { if (maybeSpill(buffer, buffer.estimateSize())) { - buffer = if (useSerializedPairBuffer) { - new PartitionedSerializedPairBuffer[K, C](metaInitialRecords, kvChunkSize, serInstance) - } else { - new PartitionedPairBuffer[K, C] - } + buffer = newBuffer() } } } /** - * Spill the current in-memory collection to disk, adding a new file to spills, and clear it. - */ - override protected[this] def spill(collection: WritablePartitionedPairCollection[K, C]): Unit = { - if (bypassMergeSort) { - spillToPartitionFiles(collection) - } else { - spillToMergeableFile(collection) - } - } - - /** - * Spill our in-memory collection to a sorted file that we can merge later (normal code path). - * We add this file into spilledFiles to find it later. - * - * This should not be invoked if bypassMergeSort is true. In that case, spillToPartitionedFiles() - * is used to write files for each partition. + * Spill our in-memory collection to a sorted file that we can merge later. + * We add this file into `spilledFiles` to find it later. * * @param collection whichever collection we're using (map or buffer) */ - private def spillToMergeableFile(collection: WritablePartitionedPairCollection[K, C]): Unit = { - assert(!bypassMergeSort) - + override protected[this] def spill(collection: WritablePartitionedPairCollection[K, C]): Unit = { // Because these files may be read during shuffle, their compression must be controlled by // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use // createTempShuffleBlock here; see SPARK-3426 for more context. val (blockId, file) = diskBlockManager.createTempShuffleBlock() - curWriteMetrics = new ShuffleWriteMetrics() - var writer = blockManager.getDiskWriter( - blockId, file, serInstance, fileBufferSize, curWriteMetrics) - var objectsWritten = 0 // Objects written since the last flush + + // These variables are reset after each flush + var objectsWritten: Long = 0 + var spillMetrics: ShuffleWriteMetrics = null + var writer: BlockObjectWriter = null + def openWriter(): Unit = { + assert (writer == null && spillMetrics == null) + spillMetrics = new ShuffleWriteMetrics + writer = blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, spillMetrics) + } + openWriter() // List of batch sizes (bytes) in the order they are written to disk val batchSizes = new ArrayBuffer[Long] @@ -291,8 +270,9 @@ private[spark] class ExternalSorter[K, V, C]( val w = writer writer = null w.commitAndClose() - _diskBytesSpilled += curWriteMetrics.shuffleBytesWritten - batchSizes.append(curWriteMetrics.shuffleBytesWritten) + _diskBytesSpilled += spillMetrics.shuffleBytesWritten + batchSizes.append(spillMetrics.shuffleBytesWritten) + spillMetrics = null objectsWritten = 0 } @@ -307,9 +287,7 @@ private[spark] class ExternalSorter[K, V, C]( if (objectsWritten == serializerBatchSize) { flush() - curWriteMetrics = new ShuffleWriteMetrics() - writer = blockManager.getDiskWriter( - blockId, file, serInstance, fileBufferSize, curWriteMetrics) + openWriter() } } if (objectsWritten > 0) { @@ -336,46 +314,6 @@ private[spark] class ExternalSorter[K, V, C]( spills.append(SpilledFile(file, blockId, batchSizes.toArray, elementsPerPartition)) } - /** - * Spill our in-memory collection to separate files, one for each partition. This is used when - * there's no aggregator and ordering and the number of partitions is small, because it allows - * writePartitionedFile to just concatenate files without deserializing data. - * - * @param collection whichever collection we're using (map or buffer) - */ - private def spillToPartitionFiles(collection: WritablePartitionedPairCollection[K, C]): Unit = { - spillToPartitionFiles(collection.writablePartitionedIterator()) - } - - private def spillToPartitionFiles(iterator: WritablePartitionedIterator): Unit = { - assert(bypassMergeSort) - - // Create our file writers if we haven't done so yet - if (partitionWriters == null) { - curWriteMetrics = new ShuffleWriteMetrics() - val openStartTime = System.nanoTime - partitionWriters = Array.fill(numPartitions) { - // Because these files may be read during shuffle, their compression must be controlled by - // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use - // createTempShuffleBlock here; see SPARK-3426 for more context. - val (blockId, file) = diskBlockManager.createTempShuffleBlock() - val writer = blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, - curWriteMetrics) - writer.open() - } - // Creating the file to write to and creating a disk writer both involve interacting with - // the disk, and can take a long time in aggregate when we open many files, so should be - // included in the shuffle write time. - curWriteMetrics.incShuffleWriteTime(System.nanoTime - openStartTime) - } - - // No need to sort stuff, just write each element out - while (iterator.hasNext) { - val partitionId = iterator.nextPartition() - iterator.writeNext(partitionWriters(partitionId)) - } - } - /** * Merge a sequence of sorted files, giving an iterator over partitions and then over elements * inside each partition. This can be used to either write out a new file or return data to @@ -665,8 +603,6 @@ private[spark] class ExternalSorter[K, V, C]( } /** - * Exposed for testing purposes. - * * Return an iterator over all the data written to this object, grouped by partition and * aggregated by the requested aggregator. For each partition we then have an iterator over its * contents, and these are expected to be accessed in order (you can't "skip ahead" to one @@ -676,10 +612,11 @@ private[spark] class ExternalSorter[K, V, C]( * For now, we just merge all the spilled files in once pass, but this can be modified to * support hierarchical merging. */ - def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = { + @VisibleForTesting + def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = { val usingMap = aggregator.isDefined val collection: WritablePartitionedPairCollection[K, C] = if (usingMap) map else buffer - if (spills.isEmpty && partitionWriters == null) { + if (spills.isEmpty) { // Special case: if we have only in-memory data, we don't need to merge streams, and perhaps // we don't even need to sort by anything other than partition ID if (!ordering.isDefined) { @@ -689,13 +626,6 @@ private[spark] class ExternalSorter[K, V, C]( // We do need to sort by both partition ID and key groupByPartition(collection.partitionedDestructiveSortedIterator(Some(keyComparator))) } - } else if (bypassMergeSort) { - // Read data from each partition file and merge it together with the data in memory; - // note that there's no ordering or aggregator in this case -- we just partition objects - val collIter = groupByPartition(collection.partitionedDestructiveSortedIterator(None)) - collIter.map { case (partitionId, values) => - (partitionId, values ++ readPartitionFile(partitionWriters(partitionId))) - } } else { // Merge spilled and in-memory data merge(spills, collection.partitionedDestructiveSortedIterator(comparator)) @@ -709,14 +639,13 @@ private[spark] class ExternalSorter[K, V, C]( /** * Write all the data added into this ExternalSorter into a file in the disk store. This is - * called by the SortShuffleWriter and can go through an efficient path of just concatenating - * binary files if we decided to avoid merge-sorting. + * called by the SortShuffleWriter. * * @param blockId block ID to write to. The index file will be blockId.name + ".index". * @param context a TaskContext for a running Spark task, for us to update shuffle metrics. * @return array of lengths, in bytes, of each partition of the file (used by map output tracker) */ - def writePartitionedFile( + override def writePartitionedFile( blockId: BlockId, context: TaskContext, outputFile: File): Array[Long] = { @@ -724,28 +653,7 @@ private[spark] class ExternalSorter[K, V, C]( // Track location of each range in the output file val lengths = new Array[Long](numPartitions) - if (bypassMergeSort && partitionWriters != null) { - // We decided to write separate files for each partition, so just concatenate them. To keep - // this simple we spill out the current in-memory collection so that everything is in files. - spillToPartitionFiles(if (aggregator.isDefined) map else buffer) - partitionWriters.foreach(_.commitAndClose()) - val out = new FileOutputStream(outputFile, true) - val writeStartTime = System.nanoTime - util.Utils.tryWithSafeFinally { - for (i <- 0 until numPartitions) { - val in = new FileInputStream(partitionWriters(i).fileSegment().file) - util.Utils.tryWithSafeFinally { - lengths(i) = org.apache.spark.util.Utils.copyStream(in, out, false, transferToEnabled) - } { - in.close() - } - } - } { - out.close() - context.taskMetrics.shuffleWriteMetrics.foreach( - _.incShuffleWriteTime(System.nanoTime - writeStartTime)) - } - } else if (spills.isEmpty && partitionWriters == null) { + if (spills.isEmpty) { // Case where we only have in-memory data val collection = if (aggregator.isDefined) map else buffer val it = collection.destructiveSortedWritablePartitionedIterator(comparator) @@ -761,7 +669,7 @@ private[spark] class ExternalSorter[K, V, C]( lengths(partitionId) = segment.length } } else { - // Not bypassing merge-sort; get an iterator by partition and just write everything directly. + // We must perform merge-sort; get an iterator by partition and write everything directly. for ((id, elements) <- this.partitionedIterator) { if (elements.hasNext) { val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize, @@ -778,41 +686,15 @@ private[spark] class ExternalSorter[K, V, C]( context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled) context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled) - context.taskMetrics.shuffleWriteMetrics.filter(_ => bypassMergeSort).foreach { m => - if (curWriteMetrics != null) { - m.incShuffleBytesWritten(curWriteMetrics.shuffleBytesWritten) - m.incShuffleWriteTime(curWriteMetrics.shuffleWriteTime) - m.incShuffleRecordsWritten(curWriteMetrics.shuffleRecordsWritten) - } - } lengths } - /** - * Read a partition file back as an iterator (used in our iterator method) - */ - private def readPartitionFile(writer: BlockObjectWriter): Iterator[Product2[K, C]] = { - if (writer.isOpen) { - writer.commitAndClose() - } - new PairIterator[K, C](blockManager.diskStore.getValues(writer.blockId, ser).get) - } - def stop(): Unit = { spills.foreach(s => s.file.delete()) spills.clear() - if (partitionWriters != null) { - partitionWriters.foreach { w => - w.revertPartialWritesAndClose() - diskBlockManager.getFile(w.blockId).delete() - } - partitionWriters = null - } } - def diskBytesSpilled: Long = _diskBytesSpilled - /** * Given a stream of ((partition, key), combiner) pairs *assumed to be sorted by partition ID*, * group together the pairs for each partition into a sub-iterator. @@ -826,14 +708,6 @@ private[spark] class ExternalSorter[K, V, C]( (0 until numPartitions).iterator.map(p => (p, new IteratorForPartition(p, buffered))) } - private def comparator: Option[Comparator[K]] = { - if (ordering.isDefined || aggregator.isDefined) { - Some(keyComparator) - } else { - None - } - } - /** * An iterator that reads only the elements for a given partition ID from an underlying buffered * stream, assuming this partition is the next one to be read. Used to make it easier to return diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala index e2e2f1faae9d1..d0d25b43d0477 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala @@ -34,10 +34,6 @@ private[spark] class PartitionedAppendOnlyMap[K, V] destructiveSortedIterator(comparator) } - def writablePartitionedIterator(): WritablePartitionedIterator = { - WritablePartitionedIterator.fromIterator(super.iterator) - } - def insert(partition: Int, key: K, value: V): Unit = { update((partition, key), value) } diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala index e8332e1a87eac..5a6e9a9580e9b 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala @@ -71,10 +71,6 @@ private[spark] class PartitionedPairBuffer[K, V](initialCapacity: Int = 64) iterator } - override def writablePartitionedIterator(): WritablePartitionedIterator = { - WritablePartitionedIterator.fromIterator(iterator) - } - private def iterator(): Iterator[((Int, K), V)] = new Iterator[((Int, K), V)] { var pos = 0 diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala index 554d88206e221..862408b7a4d21 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala @@ -122,10 +122,6 @@ private[spark] class PartitionedSerializedPairBuffer[K, V]( override def destructiveSortedWritablePartitionedIterator(keyComparator: Option[Comparator[K]]) : WritablePartitionedIterator = { sort(keyComparator) - writablePartitionedIterator - } - - override def writablePartitionedIterator(): WritablePartitionedIterator = { new WritablePartitionedIterator { // current position in the meta buffer in ints var pos = 0 diff --git a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala index f26d1618c9200..7bc59898658e4 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala @@ -47,13 +47,20 @@ private[spark] trait WritablePartitionedPairCollection[K, V] { */ def destructiveSortedWritablePartitionedIterator(keyComparator: Option[Comparator[K]]) : WritablePartitionedIterator = { - WritablePartitionedIterator.fromIterator(partitionedDestructiveSortedIterator(keyComparator)) - } + val it = partitionedDestructiveSortedIterator(keyComparator) + new WritablePartitionedIterator { + private[this] var cur = if (it.hasNext) it.next() else null - /** - * Iterate through the data and write out the elements instead of returning them. - */ - def writablePartitionedIterator(): WritablePartitionedIterator + def writeNext(writer: BlockObjectWriter): Unit = { + writer.write(cur._1._2, cur._2) + cur = if (it.hasNext) it.next() else null + } + + def hasNext(): Boolean = cur != null + + def nextPartition(): Int = cur._1._1 + } + } } private[spark] object WritablePartitionedPairCollection { @@ -94,20 +101,3 @@ private[spark] trait WritablePartitionedIterator { def nextPartition(): Int } - -private[spark] object WritablePartitionedIterator { - def fromIterator(it: Iterator[((Int, _), _)]): WritablePartitionedIterator = { - new WritablePartitionedIterator { - var cur = if (it.hasNext) it.next() else null - - def writeNext(writer: BlockObjectWriter): Unit = { - writer.write(cur._1._2, cur._2) - cur = if (it.hasNext) it.next() else null - } - - def hasNext(): Boolean = cur != null - - def nextPartition(): Int = cur._1._1 - } - } -} diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index c2089b0e56a1f..dfd86d3e51e7d 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -212,6 +212,8 @@ public int getPartition(Object key) { JavaPairRDD repartitioned = rdd.repartitionAndSortWithinPartitions(partitioner); + Assert.assertTrue(repartitioned.partitioner().isPresent()); + Assert.assertEquals(repartitioned.partitioner().get(), partitioner); List>> partitions = repartitioned.glom().collect(); Assert.assertEquals(partitions.get(0), Arrays.asList(new Tuple2(0, 5), new Tuple2(0, 8), new Tuple2(2, 6))); diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index 746a40a21bf9e..e942d6579b2fd 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -20,11 +20,10 @@ package org.apache.spark import scala.collection.mutable import scala.ref.WeakReference -import org.scalatest.FunSuite import org.scalatest.Matchers -class AccumulatorSuite extends FunSuite with Matchers with LocalSparkContext { +class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContext { implicit def setAccum[A]: AccumulableParam[mutable.Set[A], A] = diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala index 668ddf9f5f0a9..af81e46a657d3 100644 --- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark import org.mockito.Mockito._ -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter import org.scalatest.mock.MockitoSugar import org.apache.spark.executor.DataReadMethod @@ -26,7 +26,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.storage._ // TODO: Test the CacheManager's thread-safety aspects -class CacheManagerSuite extends FunSuite with LocalSparkContext with BeforeAndAfter +class CacheManagerSuite extends SparkFunSuite with LocalSparkContext with BeforeAndAfter with MockitoSugar { var blockManager: BlockManager = _ diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index 91d8fdedbe0f3..d1761a48babbc 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -21,13 +21,11 @@ import java.io.File import scala.reflect.ClassTag -import org.scalatest.FunSuite - import org.apache.spark.rdd._ import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId} import org.apache.spark.util.Utils -class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { +class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging { var checkpointDir: File = _ val partitioner = new HashPartitioner(2) diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index 4a48f6580c78e..501fe186bfd7c 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -23,7 +23,7 @@ import scala.collection.mutable.{HashSet, SynchronizedSet} import scala.language.existentials import scala.util.Random -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.{PatienceConfiguration, Eventually} import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ @@ -44,7 +44,7 @@ import org.apache.spark.storage.ShuffleIndexBlockId * config options, in particular, a different shuffle manager class */ abstract class ContextCleanerSuiteBase(val shuffleManager: Class[_] = classOf[HashShuffleManager]) - extends FunSuite with BeforeAndAfter with LocalSparkContext + extends SparkFunSuite with BeforeAndAfter with LocalSparkContext { implicit val defaultTimeout = timeout(10000 millis) val conf = new SparkConf() diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 96a9c207ad022..9c191ed52206d 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark -import org.scalatest.FunSuite import org.scalatest.concurrent.Timeouts._ import org.scalatest.Matchers import org.scalatest.time.{Millis, Span} @@ -28,7 +27,7 @@ class NotSerializableClass class NotSerializableExn(val notSer: NotSerializableClass) extends Throwable() {} -class DistributedSuite extends FunSuite with Matchers with LocalSparkContext { +class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContext { val clusterUrl = "local-cluster[2,1,512]" diff --git a/core/src/test/scala/org/apache/spark/DriverSuite.scala b/core/src/test/scala/org/apache/spark/DriverSuite.scala index c42dfbc82ada4..b2262033ca238 100644 --- a/core/src/test/scala/org/apache/spark/DriverSuite.scala +++ b/core/src/test/scala/org/apache/spark/DriverSuite.scala @@ -19,14 +19,13 @@ package org.apache.spark import java.io.File -import org.scalatest.FunSuite import org.scalatest.concurrent.Timeouts import org.scalatest.prop.TableDrivenPropertyChecks._ import org.scalatest.time.SpanSugar._ import org.apache.spark.util.Utils -class DriverSuite extends FunSuite with Timeouts { +class DriverSuite extends SparkFunSuite with Timeouts { ignore("driver should exit after finishing without cleanup (SPARK-530)") { val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index 84f787ee3715d..1c2b681f0b843 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark import scala.collection.mutable -import org.scalatest.{BeforeAndAfter, FunSuite, PrivateMethodTester} +import org.scalatest.{BeforeAndAfter, PrivateMethodTester} import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.ExecutorInfo @@ -28,7 +28,11 @@ import org.apache.spark.util.ManualClock /** * Test add and remove behavior of ExecutorAllocationManager. */ -class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext with BeforeAndAfter { +class ExecutorAllocationManagerSuite + extends SparkFunSuite + with LocalSparkContext + with BeforeAndAfter { + import ExecutorAllocationManager._ import ExecutorAllocationManagerSuite._ diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala index cade1fda2c7be..a8c8c6f73fb5a 100644 --- a/core/src/test/scala/org/apache/spark/FailureSuite.scala +++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark -import org.scalatest.FunSuite - import org.apache.spark.util.NonSerializable import java.io.NotSerializableException @@ -38,7 +36,7 @@ object FailureSuiteState { } } -class FailureSuite extends FunSuite with LocalSparkContext { +class FailureSuite extends SparkFunSuite with LocalSparkContext { // Run a 3-task map job in which task 1 deterministically fails once, and check // whether the job completes successfully and we ran 4 tasks in total. @@ -119,7 +117,7 @@ class FailureSuite extends FunSuite with LocalSparkContext { sc.parallelize(1 to 10, 2).map(x => a).count() } assert(thrown.getClass === classOf[SparkException]) - assert(thrown.getMessage.contains("NotSerializableException") || + assert(thrown.getMessage.contains("NotSerializableException") || thrown.getCause.getClass === classOf[NotSerializableException]) // Non-serializable closure in an earlier stage @@ -127,7 +125,7 @@ class FailureSuite extends FunSuite with LocalSparkContext { sc.parallelize(1 to 10, 2).map(x => (x, a)).partitionBy(new HashPartitioner(3)).count() } assert(thrown1.getClass === classOf[SparkException]) - assert(thrown1.getMessage.contains("NotSerializableException") || + assert(thrown1.getMessage.contains("NotSerializableException") || thrown1.getCause.getClass === classOf[NotSerializableException]) // Non-serializable closure in foreach function @@ -135,7 +133,7 @@ class FailureSuite extends FunSuite with LocalSparkContext { sc.parallelize(1 to 10, 2).foreach(x => println(a)) } assert(thrown2.getClass === classOf[SparkException]) - assert(thrown2.getMessage.contains("NotSerializableException") || + assert(thrown2.getMessage.contains("NotSerializableException") || thrown2.getCause.getClass === classOf[NotSerializableException]) FailureSuiteState.clear() diff --git a/core/src/test/scala/org/apache/spark/FileServerSuite.scala b/core/src/test/scala/org/apache/spark/FileServerSuite.scala index bff2d10b9946c..6e65b0a8f6c76 100644 --- a/core/src/test/scala/org/apache/spark/FileServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileServerSuite.scala @@ -24,13 +24,12 @@ import javax.net.ssl.SSLException import com.google.common.io.{ByteStreams, Files} import org.apache.commons.lang3.RandomUtils -import org.scalatest.FunSuite import org.apache.spark.util.Utils import SSLSampleConfigs._ -class FileServerSuite extends FunSuite with LocalSparkContext { +class FileServerSuite extends SparkFunSuite with LocalSparkContext { @transient var tmpDir: File = _ @transient var tmpFile: File = _ diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index d67de8692df62..1d8fade90f398 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -30,12 +30,11 @@ import org.apache.hadoop.mapred.{JobConf, FileAlreadyExistsException, FileSplit, import org.apache.hadoop.mapreduce.Job import org.apache.hadoop.mapreduce.lib.input.{FileSplit => NewFileSplit, TextInputFormat => NewTextInputFormat} import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat} -import org.scalatest.FunSuite import org.apache.spark.rdd.{NewHadoopRDD, HadoopRDD} import org.apache.spark.util.Utils -class FileSuite extends FunSuite with LocalSparkContext { +class FileSuite extends SparkFunSuite with LocalSparkContext { var tempDir: File = _ override def beforeEach() { diff --git a/core/src/test/scala/org/apache/spark/FutureActionSuite.scala b/core/src/test/scala/org/apache/spark/FutureActionSuite.scala index f5cdb01ec9504..1102aea96b548 100644 --- a/core/src/test/scala/org/apache/spark/FutureActionSuite.scala +++ b/core/src/test/scala/org/apache/spark/FutureActionSuite.scala @@ -20,10 +20,14 @@ package org.apache.spark import scala.concurrent.Await import scala.concurrent.duration.Duration -import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} +import org.scalatest.{BeforeAndAfter, Matchers} -class FutureActionSuite extends FunSuite with BeforeAndAfter with Matchers with LocalSparkContext { +class FutureActionSuite + extends SparkFunSuite + with BeforeAndAfter + with Matchers + with LocalSparkContext { before { sc = new SparkContext("local", "FutureActionSuite") diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala index b789912e9ebef..911b3bddd1836 100644 --- a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala +++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala @@ -22,7 +22,6 @@ import scala.language.postfixOps import org.apache.spark.executor.TaskMetrics import org.apache.spark.storage.BlockManagerId -import org.scalatest.FunSuite import org.mockito.Mockito.{mock, spy, verify, when} import org.mockito.Matchers import org.mockito.Matchers._ @@ -31,7 +30,7 @@ import org.apache.spark.scheduler.TaskScheduler import org.apache.spark.util.RpcUtils import org.scalatest.concurrent.Eventually._ -class HeartbeatReceiverSuite extends FunSuite with LocalSparkContext { +class HeartbeatReceiverSuite extends SparkFunSuite with LocalSparkContext { test("HeartbeatReceiver") { sc = spy(new SparkContext("local[2]", "test")) diff --git a/core/src/test/scala/org/apache/spark/ImplicitOrderingSuite.scala b/core/src/test/scala/org/apache/spark/ImplicitOrderingSuite.scala index 69314deda1f03..4399f25626472 100644 --- a/core/src/test/scala/org/apache/spark/ImplicitOrderingSuite.scala +++ b/core/src/test/scala/org/apache/spark/ImplicitOrderingSuite.scala @@ -17,11 +17,9 @@ package org.apache.spark -import org.scalatest.FunSuite - import org.apache.spark.rdd.RDD -class ImplicitOrderingSuite extends FunSuite with LocalSparkContext { +class ImplicitOrderingSuite extends SparkFunSuite with LocalSparkContext { // Tests that PairRDDFunctions grabs an implicit Ordering in various cases where it should. test("basic inference of Orderings"){ sc = new SparkContext("local", "test") @@ -29,11 +27,11 @@ class ImplicitOrderingSuite extends FunSuite with LocalSparkContext { // These RDD methods are in the companion object so that the unserializable ScalaTest Engine // won't be reachable from the closure object - + // Infer orderings after basic maps to particular types val basicMapExpectations = ImplicitOrderingSuite.basicMapExpectations(rdd) basicMapExpectations.map({case (met, explain) => assert(met, explain)}) - + // Infer orderings for other RDD methods val otherRDDMethodExpectations = ImplicitOrderingSuite.otherRDDMethodExpectations(rdd) otherRDDMethodExpectations.map({case (met, explain) => assert(met, explain)}) @@ -50,30 +48,30 @@ private object ImplicitOrderingSuite { class OrderedClass extends Ordered[OrderedClass] { override def compare(o: OrderedClass): Int = throw new UnsupportedOperationException } - + def basicMapExpectations(rdd: RDD[Int]): List[(Boolean, String)] = { - List((rdd.map(x => (x, x)).keyOrdering.isDefined, + List((rdd.map(x => (x, x)).keyOrdering.isDefined, "rdd.map(x => (x, x)).keyOrdering.isDefined"), - (rdd.map(x => (1, x)).keyOrdering.isDefined, + (rdd.map(x => (1, x)).keyOrdering.isDefined, "rdd.map(x => (1, x)).keyOrdering.isDefined"), - (rdd.map(x => (x.toString, x)).keyOrdering.isDefined, + (rdd.map(x => (x.toString, x)).keyOrdering.isDefined, "rdd.map(x => (x.toString, x)).keyOrdering.isDefined"), - (rdd.map(x => (null, x)).keyOrdering.isDefined, + (rdd.map(x => (null, x)).keyOrdering.isDefined, "rdd.map(x => (null, x)).keyOrdering.isDefined"), - (rdd.map(x => (new NonOrderedClass, x)).keyOrdering.isEmpty, + (rdd.map(x => (new NonOrderedClass, x)).keyOrdering.isEmpty, "rdd.map(x => (new NonOrderedClass, x)).keyOrdering.isEmpty"), - (rdd.map(x => (new ComparableClass, x)).keyOrdering.isDefined, + (rdd.map(x => (new ComparableClass, x)).keyOrdering.isDefined, "rdd.map(x => (new ComparableClass, x)).keyOrdering.isDefined"), - (rdd.map(x => (new OrderedClass, x)).keyOrdering.isDefined, + (rdd.map(x => (new OrderedClass, x)).keyOrdering.isDefined, "rdd.map(x => (new OrderedClass, x)).keyOrdering.isDefined")) } - + def otherRDDMethodExpectations(rdd: RDD[Int]): List[(Boolean, String)] = { - List((rdd.groupBy(x => x).keyOrdering.isDefined, + List((rdd.groupBy(x => x).keyOrdering.isDefined, "rdd.groupBy(x => x).keyOrdering.isDefined"), - (rdd.groupBy(x => new NonOrderedClass).keyOrdering.isEmpty, + (rdd.groupBy(x => new NonOrderedClass).keyOrdering.isEmpty, "rdd.groupBy(x => new NonOrderedClass).keyOrdering.isEmpty"), - (rdd.groupBy(x => new ComparableClass).keyOrdering.isDefined, + (rdd.groupBy(x => new ComparableClass).keyOrdering.isDefined, "rdd.groupBy(x => new ComparableClass).keyOrdering.isDefined"), (rdd.groupBy(x => new OrderedClass).keyOrdering.isDefined, "rdd.groupBy(x => new OrderedClass).keyOrdering.isDefined"), diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala index ae17fc60e4a43..340a9e327107e 100644 --- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -24,7 +24,7 @@ import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration._ import scala.concurrent.future -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter import org.scalatest.Matchers import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart} @@ -34,7 +34,7 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart} * (e.g. count) as well as multi-job action (e.g. take). We test the local and cluster schedulers * in both FIFO and fair scheduling modes. */ -class JobCancellationSuite extends FunSuite with Matchers with BeforeAndAfter +class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAfter with LocalSparkContext { override def afterEach() { diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 6ed057a7cab97..1fab69678d040 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -19,14 +19,13 @@ package org.apache.spark import org.mockito.Mockito._ import org.mockito.Matchers.{any, isA} -import org.scalatest.FunSuite import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcCallContext, RpcEnv} import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus} import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.BlockManagerId -class MapOutputTrackerSuite extends FunSuite { +class MapOutputTrackerSuite extends SparkFunSuite { private val conf = new SparkConf def createRpcEnv(name: String, host: String = "localhost", port: Int = 0, diff --git a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala index 47e3bf6e1ac41..3316f561a4949 100644 --- a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala +++ b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala @@ -20,12 +20,12 @@ package org.apache.spark import scala.collection.mutable.ArrayBuffer import scala.math.abs -import org.scalatest.{FunSuite, PrivateMethodTester} +import org.scalatest.PrivateMethodTester import org.apache.spark.rdd.RDD import org.apache.spark.util.StatCounter -class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMethodTester { +class PartitioningSuite extends SparkFunSuite with SharedSparkContext with PrivateMethodTester { test("HashPartitioner equality") { val p2 = new HashPartitioner(2) diff --git a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala index 93f46ef11c0e2..376481ba541fa 100644 --- a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala @@ -21,9 +21,9 @@ import java.io.File import com.google.common.io.Files import org.apache.spark.util.Utils -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.BeforeAndAfterAll -class SSLOptionsSuite extends FunSuite with BeforeAndAfterAll { +class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { test("test resolving property file as spark conf ") { val keyStorePath = new File(this.getClass.getResource("/keystore").toURI).getAbsolutePath diff --git a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala b/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala index 308b9ea17708d..1a099da2c6c8e 100644 --- a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala +++ b/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala @@ -34,7 +34,7 @@ object SSLSampleConfigs { conf.set("spark.ssl.trustStore", trustStorePath) conf.set("spark.ssl.trustStorePassword", "password") conf.set("spark.ssl.enabledAlgorithms", - "TLS_RSA_WITH_AES_128_CBC_SHA, SSL_RSA_WITH_DES_CBC_SHA") + "SSL_RSA_WITH_RC4_128_SHA, SSL_RSA_WITH_DES_CBC_SHA") conf.set("spark.ssl.protocol", "TLSv1") conf } @@ -48,7 +48,7 @@ object SSLSampleConfigs { conf.set("spark.ssl.trustStore", trustStorePath) conf.set("spark.ssl.trustStorePassword", "password") conf.set("spark.ssl.enabledAlgorithms", - "TLS_RSA_WITH_AES_128_CBC_SHA, SSL_RSA_WITH_DES_CBC_SHA") + "SSL_RSA_WITH_RC4_128_SHA, SSL_RSA_WITH_DES_CBC_SHA") conf.set("spark.ssl.protocol", "TLSv1") conf } diff --git a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala index 62cb7649c0284..e9b64aa82a17a 100644 --- a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala @@ -19,11 +19,9 @@ package org.apache.spark import java.io.File -import org.scalatest.FunSuite - import org.apache.spark.util.Utils -class SecurityManagerSuite extends FunSuite { +class SecurityManagerSuite extends SparkFunSuite { test("set security with conf") { val conf = new SparkConf @@ -147,7 +145,7 @@ class SecurityManagerSuite extends FunSuite { assert(securityManager.fileServerSSLOptions.keyPassword === Some("password")) assert(securityManager.fileServerSSLOptions.protocol === Some("TLSv1")) assert(securityManager.fileServerSSLOptions.enabledAlgorithms === - Set("TLS_RSA_WITH_AES_128_CBC_SHA", "SSL_RSA_WITH_DES_CBC_SHA")) + Set("SSL_RSA_WITH_RC4_128_SHA", "SSL_RSA_WITH_DES_CBC_SHA")) assert(securityManager.akkaSSLOptions.trustStore.isDefined === true) assert(securityManager.akkaSSLOptions.trustStore.get.getName === "truststore") @@ -158,7 +156,7 @@ class SecurityManagerSuite extends FunSuite { assert(securityManager.akkaSSLOptions.keyPassword === Some("password")) assert(securityManager.akkaSSLOptions.protocol === Some("TLSv1")) assert(securityManager.akkaSSLOptions.enabledAlgorithms === - Set("TLS_RSA_WITH_AES_128_CBC_SHA", "SSL_RSA_WITH_DES_CBC_SHA")) + Set("SSL_RSA_WITH_RC4_128_SHA", "SSL_RSA_WITH_DES_CBC_SHA")) } test("ssl off setup") { diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index d7180516029d5..c3c2b1ffc1efa 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -17,16 +17,16 @@ package org.apache.spark -import org.scalatest.FunSuite import org.scalatest.Matchers import org.apache.spark.ShuffleSuite.NonJavaSerializableClass import org.apache.spark.rdd.{CoGroupedRDD, OrderedRDDFunctions, RDD, ShuffledRDD, SubtractedRDD} +import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} import org.apache.spark.serializer.KryoSerializer import org.apache.spark.storage.{ShuffleDataBlockId, ShuffleBlockId} import org.apache.spark.util.MutablePair -abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext { +abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkContext { val conf = new SparkConf(loadDefaults = false) @@ -282,6 +282,39 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex // This count should retry the execution of the previous stage and rerun shuffle. rdd.count() } + + test("metrics for shuffle without aggregation") { + sc = new SparkContext("local", "test", conf.clone()) + val numRecords = 10000 + + val metrics = ShuffleSuite.runAndReturnMetrics(sc) { + sc.parallelize(1 to numRecords, 4) + .map(key => (key, 1)) + .groupByKey() + .collect() + } + + assert(metrics.recordsRead === numRecords) + assert(metrics.recordsWritten === numRecords) + assert(metrics.bytesWritten === metrics.byresRead) + assert(metrics.bytesWritten > 0) + } + + test("metrics for shuffle with aggregation") { + sc = new SparkContext("local", "test", conf.clone()) + val numRecords = 10000 + + val metrics = ShuffleSuite.runAndReturnMetrics(sc) { + sc.parallelize(1 to numRecords, 4) + .flatMap(key => Array.fill(100)((key, 1))) + .countByKey() + } + + assert(metrics.recordsRead === numRecords) + assert(metrics.recordsWritten === numRecords) + assert(metrics.bytesWritten === metrics.byresRead) + assert(metrics.bytesWritten > 0) + } } object ShuffleSuite { @@ -295,4 +328,35 @@ object ShuffleSuite { value - o.value } } + + case class AggregatedShuffleMetrics( + recordsWritten: Long, + recordsRead: Long, + bytesWritten: Long, + byresRead: Long) + + def runAndReturnMetrics(sc: SparkContext)(job: => Unit): AggregatedShuffleMetrics = { + @volatile var recordsWritten: Long = 0 + @volatile var recordsRead: Long = 0 + @volatile var bytesWritten: Long = 0 + @volatile var bytesRead: Long = 0 + val listener = new SparkListener { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { + taskEnd.taskMetrics.shuffleWriteMetrics.foreach { m => + recordsWritten += m.shuffleRecordsWritten + bytesWritten += m.shuffleBytesWritten + } + taskEnd.taskMetrics.shuffleReadMetrics.foreach { m => + recordsRead += m.recordsRead + bytesRead += m.totalBytesRead + } + } + } + sc.addSparkListener(listener) + + job + + sc.listenerBus.waitUntilEmpty(500) + AggregatedShuffleMetrics(recordsWritten, recordsRead, bytesWritten, bytesRead) + } } diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala index fafc9d47503b7..9fbaeb33f97cd 100644 --- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala @@ -23,13 +23,12 @@ import scala.concurrent.duration._ import scala.language.postfixOps import scala.util.{Try, Random} -import org.scalatest.FunSuite import org.apache.spark.network.util.ByteUnit import org.apache.spark.serializer.{KryoRegistrator, KryoSerializer} import org.apache.spark.util.{RpcUtils, ResetSystemProperties} import com.esotericsoftware.kryo.Kryo -class SparkConfSuite extends FunSuite with LocalSparkContext with ResetSystemProperties { +class SparkConfSuite extends SparkFunSuite with LocalSparkContext with ResetSystemProperties { test("Test byteString conversion") { val conf = new SparkConf() // Simply exercise the API, we don't need a complete conversion test since that's handled in diff --git a/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala index e6ab538d77bcc..2bdbd70c638a5 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark -import org.scalatest.{Assertions, FunSuite} +import org.scalatest.Assertions import org.apache.spark.storage.StorageLevel -class SparkContextInfoSuite extends FunSuite with LocalSparkContext { +class SparkContextInfoSuite extends SparkFunSuite with LocalSparkContext { test("getPersistentRDDs only returns RDDs that are marked as cached") { sc = new SparkContext("local", "test") assert(sc.getPersistentRDDs.isEmpty === true) diff --git a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala index 9343f4fff89da..f89e3d0a49920 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark -import org.scalatest.{FunSuite, PrivateMethodTester} +import org.scalatest.PrivateMethodTester import org.apache.spark.scheduler.{SchedulerBackend, TaskScheduler, TaskSchedulerImpl} import org.apache.spark.scheduler.cluster.{SimrSchedulerBackend, SparkDeploySchedulerBackend} @@ -25,7 +25,7 @@ import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, Me import org.apache.spark.scheduler.local.LocalBackend class SparkContextSchedulerCreationSuite - extends FunSuite with LocalSparkContext with PrivateMethodTester with Logging { + extends SparkFunSuite with LocalSparkContext with PrivateMethodTester with Logging { def createTaskScheduler(master: String): TaskSchedulerImpl = createTaskScheduler(master, new SparkConf()) diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index 31ef5cd75bd4a..6838b35ab4cc8 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -23,8 +23,6 @@ import java.util.concurrent.TimeUnit import com.google.common.base.Charsets._ import com.google.common.io.Files -import org.scalatest.FunSuite - import org.apache.hadoop.io.{BytesWritable, LongWritable, Text} import org.apache.hadoop.mapred.TextInputFormat import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat} @@ -33,7 +31,7 @@ import org.apache.spark.util.Utils import scala.concurrent.Await import scala.concurrent.duration.Duration -class SparkContextSuite extends FunSuite with LocalSparkContext { +class SparkContextSuite extends SparkFunSuite with LocalSparkContext { test("Only one SparkContext may be active at a time") { // Regression test for SPARK-4180 @@ -73,22 +71,22 @@ class SparkContextSuite extends FunSuite with LocalSparkContext { var sc2: SparkContext = null SparkContext.clearActiveContext() val conf = new SparkConf().setAppName("test").setMaster("local") - + sc = SparkContext.getOrCreate(conf) - + assert(sc.getConf.get("spark.app.name").equals("test")) sc2 = SparkContext.getOrCreate(new SparkConf().setAppName("test2").setMaster("local")) assert(sc2.getConf.get("spark.app.name").equals("test")) assert(sc === sc2) assert(sc eq sc2) - + // Try creating second context to confirm that it's still possible, if desired sc2 = new SparkContext(new SparkConf().setAppName("test3").setMaster("local") .set("spark.driver.allowMultipleContexts", "true")) - + sc2.stop() } - + test("BytesWritable implicit conversion is correct") { // Regression test for SPARK-3121 val bytesWritable = new BytesWritable() diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala new file mode 100644 index 0000000000000..8cb344332668f --- /dev/null +++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +// scalastyle:off +import org.scalatest.{FunSuite, Outcome} + +/** + * Base abstract class for all unit tests in Spark for handling common functionality. + */ +private[spark] abstract class SparkFunSuite extends FunSuite with Logging { +// scalastyle:on + + /** + * Log the suite name and the test name before and after each test. + * + * Subclasses should never override this method. If they wish to run + * custom code before and after each test, they should should mix in + * the {{org.scalatest.BeforeAndAfter}} trait instead. + */ + final protected override def withFixture(test: NoArgTest): Outcome = { + val testName = test.text + val suiteName = this.getClass.getName + val shortSuiteName = suiteName.replaceAll("org.apache.spark", "o.a.s") + try { + logInfo(s"\n\n===== TEST OUTPUT FOR $shortSuiteName: '$testName' =====\n") + test() + } finally { + logInfo(s"\n\n===== FINISHED $shortSuiteName: '$testName' =====\n") + } + } + +} diff --git a/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala b/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala index 084eb237d70d1..46516e8d25298 100644 --- a/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala @@ -21,12 +21,12 @@ import scala.concurrent.duration._ import scala.language.implicitConversions import scala.language.postfixOps -import org.scalatest.{Matchers, FunSuite} +import org.scalatest.Matchers import org.scalatest.concurrent.Eventually._ import org.apache.spark.JobExecutionStatus._ -class StatusTrackerSuite extends FunSuite with Matchers with LocalSparkContext { +class StatusTrackerSuite extends SparkFunSuite with Matchers with LocalSparkContext { test("basic status API usage") { sc = new SparkContext("local", "test", new SparkConf(false)) diff --git a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala index 10917c866cc7d..6580139df6c60 100644 --- a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala +++ b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala @@ -22,7 +22,6 @@ import java.util.concurrent.atomic.AtomicBoolean import java.util.concurrent.atomic.AtomicInteger import org.apache.spark.scheduler._ -import org.scalatest.FunSuite /** * Holds state shared across task threads in some ThreadingSuite tests. @@ -37,7 +36,7 @@ object ThreadingSuiteState { } } -class ThreadingSuite extends FunSuite with LocalSparkContext { +class ThreadingSuite extends SparkFunSuite with LocalSparkContext { test("accessing SparkContext form a different thread") { sc = new SparkContext("local", "test") diff --git a/core/src/test/scala/org/apache/spark/UnpersistSuite.scala b/core/src/test/scala/org/apache/spark/UnpersistSuite.scala index 42ff059e018a3..f7a13ab3996d8 100644 --- a/core/src/test/scala/org/apache/spark/UnpersistSuite.scala +++ b/core/src/test/scala/org/apache/spark/UnpersistSuite.scala @@ -17,11 +17,10 @@ package org.apache.spark -import org.scalatest.FunSuite import org.scalatest.concurrent.Timeouts._ import org.scalatest.time.{Millis, Span} -class UnpersistSuite extends FunSuite with LocalSparkContext { +class UnpersistSuite extends SparkFunSuite with LocalSparkContext { test("unpersist RDD") { sc = new SparkContext("local", "test") val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache() diff --git a/core/src/test/scala/org/apache/spark/api/python/PythonBroadcastSuite.scala b/core/src/test/scala/org/apache/spark/api/python/PythonBroadcastSuite.scala index 8959a843dbd7d..135c56bf5bc9d 100644 --- a/core/src/test/scala/org/apache/spark/api/python/PythonBroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/api/python/PythonBroadcastSuite.scala @@ -21,15 +21,15 @@ import scala.io.Source import java.io.{PrintWriter, File} -import org.scalatest.{Matchers, FunSuite} +import org.scalatest.Matchers -import org.apache.spark.{SharedSparkContext, SparkConf} +import org.apache.spark.{SharedSparkContext, SparkConf, SparkFunSuite} import org.apache.spark.serializer.KryoSerializer import org.apache.spark.util.Utils // This test suite uses SharedSparkContext because we need a SparkEnv in order to deserialize // a PythonBroadcast: -class PythonBroadcastSuite extends FunSuite with Matchers with SharedSparkContext { +class PythonBroadcastSuite extends SparkFunSuite with Matchers with SharedSparkContext { test("PythonBroadcast can be serialized with Kryo (SPARK-4882)") { val tempDir = Utils.createTempDir() val broadcastedString = "Hello, world!" diff --git a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala index c63d834f9048b..41f2a5c972b6b 100644 --- a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala @@ -19,9 +19,9 @@ package org.apache.spark.api.python import java.io.{ByteArrayOutputStream, DataOutputStream} -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class PythonRDDSuite extends FunSuite { +class PythonRDDSuite extends SparkFunSuite { test("Writing large strings to the worker") { val input: List[String] = List("a"*100000) diff --git a/core/src/test/scala/org/apache/spark/api/python/SerDeUtilSuite.scala b/core/src/test/scala/org/apache/spark/api/python/SerDeUtilSuite.scala index f8c39326145e1..267a79fa63782 100644 --- a/core/src/test/scala/org/apache/spark/api/python/SerDeUtilSuite.scala +++ b/core/src/test/scala/org/apache/spark/api/python/SerDeUtilSuite.scala @@ -17,11 +17,9 @@ package org.apache.spark.api.python -import org.scalatest.FunSuite +import org.apache.spark.{SharedSparkContext, SparkFunSuite} -import org.apache.spark.SharedSparkContext - -class SerDeUtilSuite extends FunSuite with SharedSparkContext { +class SerDeUtilSuite extends SparkFunSuite with SharedSparkContext { test("Converting an empty pair RDD to python does not throw an exception (SPARK-5441)") { val emptyRdd = sc.makeRDD(Seq[(Any, Any)]()) diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index c38e306b6ac40..c05e8bb6538ba 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -20,10 +20,10 @@ package org.apache.spark.broadcast import scala.concurrent.duration._ import scala.util.Random -import org.scalatest.{Assertions, FunSuite} +import org.scalatest.Assertions import org.scalatest.concurrent.Eventually._ -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException, SparkEnv} +import org.apache.spark._ import org.apache.spark.io.SnappyCompressionCodec import org.apache.spark.rdd.RDD import org.apache.spark.serializer.JavaSerializer @@ -45,7 +45,7 @@ class DummyBroadcastClass(rdd: RDD[Int]) extends Serializable { } } -class BroadcastSuite extends FunSuite with LocalSparkContext { +class BroadcastSuite extends SparkFunSuite with LocalSparkContext { private val httpConf = broadcastConf("HttpBroadcastFactory") private val torrentConf = broadcastConf("TorrentBroadcastFactory") diff --git a/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala b/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala index 745f9eeee7536..6a99dbca64f4b 100644 --- a/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala @@ -17,10 +17,11 @@ package org.apache.spark.deploy -import org.scalatest.FunSuite import org.scalatest.Matchers -class ClientSuite extends FunSuite with Matchers { +import org.apache.spark.SparkFunSuite + +class ClientSuite extends SparkFunSuite with Matchers { test("correctly validates driver jar URL's") { ClientArguments.isValidJarUrl("http://someHost:8080/foo.jar") should be (true) ClientArguments.isValidJarUrl("https://someHost:8080/foo.jar") should be (true) diff --git a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala index e04a79284175c..08529e0ef2806 100644 --- a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala @@ -23,14 +23,13 @@ import java.util.Date import com.fasterxml.jackson.core.JsonParseException import org.json4s._ import org.json4s.jackson.JsonMethods -import org.scalatest.FunSuite import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, WorkerStateResponse} import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, RecoveryState, WorkerInfo} import org.apache.spark.deploy.worker.{DriverRunner, ExecutorRunner} -import org.apache.spark.{JsonTestUtils, SecurityManager, SparkConf} +import org.apache.spark.{JsonTestUtils, SecurityManager, SparkConf, SparkFunSuite} -class JsonProtocolSuite extends FunSuite with JsonTestUtils { +class JsonProtocolSuite extends SparkFunSuite with JsonTestUtils { test("writeApplicationInfo") { val output = JsonProtocol.writeApplicationInfo(createAppInfo()) diff --git a/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala index c93d16f8a1586..c215b0582889f 100644 --- a/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala @@ -23,13 +23,11 @@ import scala.collection.JavaConversions._ import scala.collection.mutable import scala.io.Source -import org.scalatest.FunSuite - import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.scheduler.{SparkListenerExecutorAdded, SparkListener} -import org.apache.spark.{SparkConf, SparkContext, LocalSparkContext} +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} -class LogUrlsStandaloneSuite extends FunSuite with LocalSparkContext { +class LogUrlsStandaloneSuite extends SparkFunSuite with LocalSparkContext { /** Length of time to wait while draining listener events. */ private val WAIT_TIMEOUT_MILLIS = 10000 diff --git a/core/src/test/scala/org/apache/spark/deploy/PythonRunnerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/PythonRunnerSuite.scala index 80f2cc02516fe..473a2d7b2a258 100644 --- a/core/src/test/scala/org/apache/spark/deploy/PythonRunnerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/PythonRunnerSuite.scala @@ -17,11 +17,10 @@ package org.apache.spark.deploy -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.util.Utils -class PythonRunnerSuite extends FunSuite { +class PythonRunnerSuite extends SparkFunSuite { // Test formatting a single path to be added to the PYTHONPATH test("format path") { diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index ea9227a7e9af5..46369457f000a 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -23,7 +23,6 @@ import scala.collection.mutable.ArrayBuffer import com.google.common.base.Charsets.UTF_8 import com.google.common.io.ByteStreams -import org.scalatest.FunSuite import org.scalatest.Matchers import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ @@ -35,7 +34,12 @@ import org.apache.spark.util.{ResetSystemProperties, Utils} // Note: this suite mixes in ResetSystemProperties because SparkSubmit.main() sets a bunch // of properties that neeed to be cleared after tests. -class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties with Timeouts { +class SparkSubmitSuite + extends SparkFunSuite + with Matchers + with ResetSystemProperties + with Timeouts { + def beforeAll() { System.setProperty("spark.testing", "true") } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala index 088ca3cb93b49..8fda5c8b472c9 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala @@ -20,15 +20,16 @@ package org.apache.spark.deploy import java.io.{File, PrintStream, OutputStream} import scala.collection.mutable.ArrayBuffer -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.BeforeAndAfterAll import org.apache.ivy.core.module.descriptor.MDArtifact import org.apache.ivy.core.settings.IvySettings import org.apache.ivy.plugins.resolver.IBiblioResolver +import org.apache.spark.SparkFunSuite import org.apache.spark.deploy.SparkSubmitUtils.MavenCoordinate -class SparkSubmitUtilsSuite extends FunSuite with BeforeAndAfterAll { +class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { private val noOpOutputStream = new OutputStream { def write(b: Int) = {} diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index a0a0afa48833e..0f6933df9e6bc 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -25,15 +25,15 @@ import scala.io.Source import org.apache.hadoop.fs.Path import org.json4s.jackson.JsonMethods._ -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter import org.scalatest.Matchers -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.{Logging, SparkConf, SparkFunSuite} import org.apache.spark.io._ import org.apache.spark.scheduler._ import org.apache.spark.util.{JsonProtocol, ManualClock, Utils} -class FsHistoryProviderSuite extends FunSuite with BeforeAndAfter with Matchers with Logging { +class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matchers with Logging { private var testDir: File = null diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index e10dd4cf837aa..14f2d1a5894b8 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -22,10 +22,10 @@ import javax.servlet.http.{HttpServletRequest, HttpServletResponse} import org.apache.commons.io.{FileUtils, IOUtils} import org.mockito.Mockito.when -import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} +import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.mock.MockitoSugar -import org.apache.spark.{JsonTestUtils, SecurityManager, SparkConf} +import org.apache.spark.{JsonTestUtils, SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.ui.SparkUI /** @@ -39,7 +39,7 @@ import org.apache.spark.ui.SparkUI * expectations. However, in general this should be done with extreme caution, as the metrics * are considered part of Spark's public api. */ -class HistoryServerSuite extends FunSuite with BeforeAndAfter with Matchers with MockitoSugar +class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers with MockitoSugar with JsonTestUtils { private val logDir = new File("src/test/resources/spark-events") diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala index f97e5ff6db31d..014e87bb40254 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala @@ -27,14 +27,14 @@ import scala.language.postfixOps import akka.actor.Address import org.json4s._ import org.json4s.jackson.JsonMethods._ -import org.scalatest.{FunSuite, Matchers} +import org.scalatest.Matchers import org.scalatest.concurrent.Eventually import other.supplier.{CustomPersistenceEngine, CustomRecoveryModeFactory} -import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} import org.apache.spark.deploy._ -class MasterSuite extends FunSuite with Matchers with Eventually { +class MasterSuite extends SparkFunSuite with Matchers with Eventually { test("toAkkaUrl") { val conf = new SparkConf(loadDefaults = false) diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala index f4d548d9e7720..197f68e7ec5ed 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala @@ -25,7 +25,7 @@ import scala.collection.mutable import akka.actor.{Actor, ActorRef, ActorSystem, Props} import com.google.common.base.Charsets -import org.scalatest.{BeforeAndAfterEach, FunSuite} +import org.scalatest.BeforeAndAfterEach import org.json4s.JsonAST._ import org.json4s.jackson.JsonMethods._ @@ -38,7 +38,7 @@ import org.apache.spark.deploy.master.DriverState._ /** * Tests for the REST application submission protocol used in standalone cluster mode. */ -class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { +class StandaloneRestSubmitSuite extends SparkFunSuite with BeforeAndAfterEach { private var actorSystem: Option[ActorSystem] = None private var server: Option[RestSubmissionServer] = None diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala index 61071ee17256c..115ac0534a1b4 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala @@ -21,14 +21,13 @@ import java.lang.Boolean import java.lang.Integer import org.json4s.jackson.JsonMethods._ -import org.scalatest.FunSuite -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} /** * Tests for the REST application submission protocol. */ -class SubmitRestProtocolSuite extends FunSuite { +class SubmitRestProtocolSuite extends SparkFunSuite { test("validate") { val request = new DummyRequest diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/CommandUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/CommandUtilsSuite.scala index 1c27d83cf876c..5b3930c0b0132 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/CommandUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/CommandUtilsSuite.scala @@ -17,11 +17,12 @@ package org.apache.spark.deploy.worker +import org.apache.spark.SparkFunSuite import org.apache.spark.deploy.Command import org.apache.spark.util.Utils -import org.scalatest.{FunSuite, Matchers} +import org.scalatest.Matchers -class CommandUtilsSuite extends FunSuite with Matchers { +class CommandUtilsSuite extends SparkFunSuite with Matchers { test("set libraryPath correctly") { val appId = "12345-worker321-9876" diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala index 2159fd8c16c6f..6258c18d177fd 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala @@ -23,13 +23,12 @@ import org.mockito.Mockito._ import org.mockito.Matchers._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer -import org.scalatest.FunSuite -import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.deploy.{Command, DriverDescription} import org.apache.spark.util.Clock -class DriverRunnerTest extends FunSuite { +class DriverRunnerTest extends SparkFunSuite { private def createDriverRunner() = { val command = new Command("mainClass", Seq(), Map(), Seq(), Seq(), Seq()) val driverDescription = new DriverDescription("jarUrl", 512, 1, true, command) diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala index a8b9df227c996..3da992788962b 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala @@ -21,12 +21,10 @@ import java.io.File import scala.collection.JavaConversions._ -import org.scalatest.FunSuite - import org.apache.spark.deploy.{ApplicationDescription, Command, ExecutorState} -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} -class ExecutorRunnerTest extends FunSuite { +class ExecutorRunnerTest extends SparkFunSuite { test("command includes appId") { val appId = "12345-worker321-9876" val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala index e432b8e94654a..15f7ca4a6dacc 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala @@ -18,11 +18,10 @@ package org.apache.spark.deploy.worker -import org.apache.spark.SparkConf -import org.scalatest.FunSuite +import org.apache.spark.{SparkConf, SparkFunSuite} -class WorkerArgumentsTest extends FunSuite { +class WorkerArgumentsTest extends SparkFunSuite { test("Memory can't be set to 0 when cmd line args leave off M or G") { val conf = new SparkConf diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala index 93a779d5ce6f2..0f4d3b28d09df 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala @@ -17,12 +17,12 @@ package org.apache.spark.deploy.worker -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.Command -import org.scalatest.{Matchers, FunSuite} +import org.scalatest.Matchers -class WorkerSuite extends FunSuite with Matchers { +class WorkerSuite extends SparkFunSuite with Matchers { def cmd(javaOpts: String*): Command = { Command("", Seq.empty, Map.empty, Seq.empty, Seq.empty, Seq(javaOpts : _*)) diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala index 6a6f29dd613cd..ac18f04a11475 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala @@ -18,12 +18,11 @@ package org.apache.spark.deploy.worker import akka.actor.AddressFromURIString -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.SecurityManager import org.apache.spark.rpc.{RpcAddress, RpcEnv} -import org.scalatest.FunSuite -class WorkerWatcherSuite extends FunSuite { +class WorkerWatcherSuite extends SparkFunSuite { test("WorkerWatcher shuts down on valid disassociation") { val conf = new SparkConf() val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/ui/LogPageSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/ui/LogPageSuite.scala new file mode 100644 index 0000000000000..572360ddb95d4 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/worker/ui/LogPageSuite.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.worker.ui + +import java.io.{File, FileWriter} + +import org.mockito.Mockito.mock +import org.scalatest.PrivateMethodTester + +import org.apache.spark.SparkFunSuite + +class LogPageSuite extends SparkFunSuite with PrivateMethodTester { + + test("get logs simple") { + val webui = mock(classOf[WorkerWebUI]) + val logPage = new LogPage(webui) + + // Prepare some fake log files to read later + val out = "some stdout here" + val err = "some stderr here" + val tmpDir = new File(sys.props("java.io.tmpdir")) + val tmpOut = new File(tmpDir, "stdout") + val tmpErr = new File(tmpDir, "stderr") + val tmpRand = new File(tmpDir, "random") + write(tmpOut, out) + write(tmpErr, err) + write(tmpRand, "1 6 4 5 2 7 8") + + // Get the logs. All log types other than "stderr" or "stdout" will be rejected + val getLog = PrivateMethod[(String, Long, Long, Long)]('getLog) + val (stdout, _, _, _) = + logPage invokePrivate getLog(tmpDir.getAbsolutePath, "stdout", None, 100) + val (stderr, _, _, _) = + logPage invokePrivate getLog(tmpDir.getAbsolutePath, "stderr", None, 100) + val (error1, _, _, _) = + logPage invokePrivate getLog(tmpDir.getAbsolutePath, "random", None, 100) + val (error2, _, _, _) = + logPage invokePrivate getLog(tmpDir.getAbsolutePath, "does-not-exist.txt", None, 100) + assert(stdout === out) + assert(stderr === err) + assert(error1.startsWith("Error")) + assert(error2.startsWith("Error")) + } + + /** Write the specified string to the file. */ + private def write(f: File, s: String): Unit = { + val writer = new FileWriter(f) + try { + writer.write(s) + } finally { + writer.close() + } + } + +} diff --git a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala index 326e203afe136..8275fd87764cd 100644 --- a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.executor -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class TaskMetricsSuite extends FunSuite { +class TaskMetricsSuite extends SparkFunSuite { test("[SPARK-5701] updateShuffleReadMetrics: ShuffleReadMetrics not added when no shuffle deps") { val taskMetrics = new TaskMetrics() taskMetrics.updateShuffleReadMetrics() diff --git a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala index 2e58c159a2ed8..63947df3d43a2 100644 --- a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala @@ -24,11 +24,10 @@ import java.io.FileOutputStream import scala.collection.immutable.IndexedSeq import org.scalatest.BeforeAndAfterAll -import org.scalatest.FunSuite import org.apache.hadoop.io.Text -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.util.Utils import org.apache.hadoop.io.compress.{DefaultCodec, CompressionCodecFactory, GzipCodec} @@ -37,7 +36,7 @@ import org.apache.hadoop.io.compress.{DefaultCodec, CompressionCodecFactory, Gzi * [[org.apache.spark.input.WholeTextFileRecordReader WholeTextFileRecordReader]]. A temporary * directory is created as fake input. Temporal storage would be deleted in the end. */ -class WholeTextFileRecordReaderSuite extends FunSuite with BeforeAndAfterAll { +class WholeTextFileRecordReaderSuite extends SparkFunSuite with BeforeAndAfterAll { private var sc: SparkContext = _ private var factory: CompressionCodecFactory = _ diff --git a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala index cf6a143537889..cbdb33c89d0fb 100644 --- a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala +++ b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala @@ -20,11 +20,10 @@ package org.apache.spark.io import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import com.google.common.io.ByteStreams -import org.scalatest.FunSuite -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} -class CompressionCodecSuite extends FunSuite { +class CompressionCodecSuite extends SparkFunSuite { val conf = new SparkConf(false) def testCodec(codec: CompressionCodec) { diff --git a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala index 60dba3b2d6719..9e4d34fb7d382 100644 --- a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala @@ -36,14 +36,14 @@ import org.apache.hadoop.mapreduce.lib.input.{CombineFileInputFormat => NewCombi import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat} import org.apache.hadoop.mapreduce.{TaskAttemptContext, InputSplit => NewInputSplit, RecordReader => NewRecordReader} -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter -import org.apache.spark.SharedSparkContext +import org.apache.spark.{SharedSparkContext, SparkFunSuite} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} import org.apache.spark.util.Utils -class InputOutputMetricsSuite extends FunSuite with SharedSparkContext +class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext with BeforeAndAfter { @transient var tmpDir: File = _ @@ -193,26 +193,6 @@ class InputOutputMetricsSuite extends FunSuite with SharedSparkContext assert(records == numRecords) } - test("shuffle records read metrics") { - val recordsRead = runAndReturnShuffleRecordsRead { - sc.textFile(tmpFilePath, 4) - .map(key => (key, 1)) - .groupByKey() - .collect() - } - assert(recordsRead == numRecords) - } - - test("shuffle records written metrics") { - val recordsWritten = runAndReturnShuffleRecordsWritten { - sc.textFile(tmpFilePath, 4) - .map(key => (key, 1)) - .groupByKey() - .collect() - } - assert(recordsWritten == numRecords) - } - /** * Tests the metrics from end to end. * 1) reading a hadoop file @@ -301,14 +281,6 @@ class InputOutputMetricsSuite extends FunSuite with SharedSparkContext runAndReturnMetrics(job, _.taskMetrics.outputMetrics.map(_.recordsWritten)) } - private def runAndReturnShuffleRecordsRead(job: => Unit): Long = { - runAndReturnMetrics(job, _.taskMetrics.shuffleReadMetrics.map(_.recordsRead)) - } - - private def runAndReturnShuffleRecordsWritten(job: => Unit): Long = { - runAndReturnMetrics(job, _.taskMetrics.shuffleWriteMetrics.map(_.shuffleRecordsWritten)) - } - private def runAndReturnMetrics(job: => Unit, collector: (SparkListenerTaskEnd) => Option[Long]): Long = { val taskMetrics = new ArrayBuffer[Long]() diff --git a/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala b/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala index 100ac77dec1f7..a901a069d9bfe 100644 --- a/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala @@ -17,9 +17,11 @@ package org.apache.spark.metrics -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter -class MetricsConfigSuite extends FunSuite with BeforeAndAfter { +import org.apache.spark.SparkFunSuite + +class MetricsConfigSuite extends SparkFunSuite with BeforeAndAfter { var filePath: String = _ before { diff --git a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala index bbdc9568a6ddb..9c389c76bf3bd 100644 --- a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.metrics -import org.scalatest.{BeforeAndAfter, FunSuite, PrivateMethodTester} +import org.scalatest.{BeforeAndAfter, PrivateMethodTester} -import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.deploy.master.MasterSource import org.apache.spark.metrics.source.Source @@ -27,7 +27,7 @@ import com.codahale.metrics.MetricRegistry import scala.collection.mutable.ArrayBuffer -class MetricsSystemSuite extends FunSuite with BeforeAndAfter with PrivateMethodTester{ +class MetricsSystemSuite extends SparkFunSuite with BeforeAndAfter with PrivateMethodTester{ var filePath: String = _ var conf: SparkConf = null var securityMgr: SecurityManager = null diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala index 46d2e5173acae..3940527fb874e 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala @@ -31,12 +31,12 @@ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.shuffle.BlockFetchingListener import org.apache.spark.network.{BlockDataManager, BlockTransferService} import org.apache.spark.storage.{BlockId, ShuffleBlockId} -import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.mockito.Mockito._ import org.scalatest.mock.MockitoSugar -import org.scalatest.{FunSuite, ShouldMatchers} +import org.scalatest.ShouldMatchers -class NettyBlockTransferSecuritySuite extends FunSuite with MockitoSugar with ShouldMatchers { +class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar with ShouldMatchers { test("security default off") { val conf = new SparkConf() .set("spark.app.id", "app-id") diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala index a41f8b7ce5ce0..6f8e8a7ac6033 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala @@ -18,11 +18,15 @@ package org.apache.spark.network.netty import org.apache.spark.network.BlockDataManager -import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.mockito.Mockito.mock import org.scalatest._ -class NettyBlockTransferServiceSuite extends FunSuite with BeforeAndAfterEach with ShouldMatchers { +class NettyBlockTransferServiceSuite + extends SparkFunSuite + with BeforeAndAfterEach + with ShouldMatchers { + private var service0: NettyBlockTransferService = _ private var service1: NettyBlockTransferService = _ diff --git a/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala index 02424c59d6831..5e364cc0edeb2 100644 --- a/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala @@ -24,15 +24,13 @@ import scala.concurrent.duration._ import scala.concurrent.{Await, TimeoutException} import scala.language.postfixOps -import org.scalatest.FunSuite - -import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.util.Utils /** * Test the ConnectionManager with various security settings. */ -class ConnectionManagerSuite extends FunSuite { +class ConnectionManagerSuite extends SparkFunSuite { test("security default off") { val conf = new SparkConf diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala index f2b0ea1063a72..ec99f2a1bad66 100644 --- a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala @@ -23,13 +23,13 @@ import scala.concurrent.{Await, TimeoutException} import scala.concurrent.duration.Duration import scala.concurrent.ExecutionContext.Implicits.global -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.BeforeAndAfterAll import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ -import org.apache.spark.{SparkContext, SparkException, LocalSparkContext} +import org.apache.spark.{LocalSparkContext, SparkContext, SparkException, SparkFunSuite} -class AsyncRDDActionsSuite extends FunSuite with BeforeAndAfterAll with Timeouts { +class AsyncRDDActionsSuite extends SparkFunSuite with BeforeAndAfterAll with Timeouts { @transient private var sc: SparkContext = _ diff --git a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala index 01039b9449daf..4e72b89bfcc40 100644 --- a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala @@ -17,11 +17,9 @@ package org.apache.spark.rdd -import org.scalatest.FunSuite - import org.apache.spark._ -class DoubleRDDSuite extends FunSuite with SharedSparkContext { +class DoubleRDDSuite extends SparkFunSuite with SharedSparkContext { test("sum") { assert(sc.parallelize(Seq.empty[Double]).sum() === 0.0) assert(sc.parallelize(Seq(1.0)).sum() === 1.0) diff --git a/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala index be8467354b222..08215a2bafc09 100644 --- a/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala @@ -19,11 +19,11 @@ package org.apache.spark.rdd import java.sql._ -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter -import org.apache.spark.{LocalSparkContext, SparkContext} +import org.apache.spark.{LocalSparkContext, SparkContext, SparkFunSuite} -class JdbcRDDSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { +class JdbcRDDSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext { before { Class.forName("org.apache.derby.jdbc.EmbeddedDriver") @@ -82,7 +82,7 @@ class JdbcRDDSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { assert(rdd.count === 100) assert(rdd.reduce(_ + _) === 10100) } - + test("large id overflow") { sc = new SparkContext("local", "test") val rdd = new JdbcRDD( diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index 6564232986cfa..dfa102f432a02 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -28,12 +28,10 @@ import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.mapreduce.{JobContext => NewJobContext, OutputCommitter => NewOutputCommitter, OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, TaskAttemptContext => NewTaskAttempContext} -import org.apache.spark.{Partitioner, SharedSparkContext} +import org.apache.spark.{Partitioner, SharedSparkContext, SparkFunSuite} import org.apache.spark.util.Utils -import org.scalatest.FunSuite - -class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { +class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { test("aggregateByKey") { val pairs = sc.parallelize(Array((1, 1), (1, 1), (3, 2), (5, 1), (5, 3)), 2) diff --git a/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala b/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala index 1880364581c1a..e7cc1617cdf1c 100644 --- a/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala @@ -22,10 +22,11 @@ import scala.collection.immutable.NumericRange import org.scalacheck.Arbitrary._ import org.scalacheck.Gen import org.scalacheck.Prop._ -import org.scalatest.FunSuite import org.scalatest.prop.Checkers -class ParallelCollectionSplitSuite extends FunSuite with Checkers { +import org.apache.spark.SparkFunSuite + +class ParallelCollectionSplitSuite extends SparkFunSuite with Checkers { test("one element per slice") { val data = Array(1, 2, 3) val slices = ParallelCollectionRDD.slice(data, 3) diff --git a/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala index 465068c6cbb16..b1544a6106110 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala @@ -17,11 +17,9 @@ package org.apache.spark.rdd -import org.scalatest.FunSuite +import org.apache.spark.{Partition, SharedSparkContext, SparkFunSuite, TaskContext} -import org.apache.spark.{Partition, SharedSparkContext, TaskContext} - -class PartitionPruningRDDSuite extends FunSuite with SharedSparkContext { +class PartitionPruningRDDSuite extends SparkFunSuite with SharedSparkContext { test("Pruned Partitions inherit locality prefs correctly") { diff --git a/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala index 0d1369c19c69e..132a5fa9a80fb 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala @@ -17,9 +17,7 @@ package org.apache.spark.rdd -import org.scalatest.FunSuite - -import org.apache.spark.SharedSparkContext +import org.apache.spark.{SharedSparkContext, SparkFunSuite} import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, RandomSampler} /** a sampler that outputs its seed */ @@ -38,7 +36,7 @@ class MockSampler extends RandomSampler[Long, Long] { override def clone: MockSampler = new MockSampler } -class PartitionwiseSampledRDDSuite extends FunSuite with SharedSparkContext { +class PartitionwiseSampledRDDSuite extends SparkFunSuite with SharedSparkContext { test("seed distribution") { val rdd = sc.makeRDD(Array(1L, 2L, 3L, 4L), 2) diff --git a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala index 85eb2a1d07ba4..32f04d54eff94 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala @@ -22,7 +22,6 @@ import java.io.File import org.apache.hadoop.fs.Path import org.apache.hadoop.io.{LongWritable, Text} import org.apache.hadoop.mapred.{FileSplit, JobConf, TextInputFormat} -import org.scalatest.FunSuite import scala.collection.Map import scala.language.postfixOps @@ -32,7 +31,7 @@ import scala.util.Try import org.apache.spark._ import org.apache.spark.util.Utils -class PipedRDDSuite extends FunSuite with SharedSparkContext { +class PipedRDDSuite extends SparkFunSuite with SharedSparkContext { test("basic pipe") { if (testCommandAvailable("cat")) { diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDOperationScopeSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDOperationScopeSuite.scala index 4434ed858c60c..f65349e3e3585 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDOperationScopeSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDOperationScopeSuite.scala @@ -17,14 +17,14 @@ package org.apache.spark.rdd -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter -import org.apache.spark.{TaskContext, Partition, SparkContext} +import org.apache.spark.{Partition, SparkContext, SparkFunSuite, TaskContext} /** * Tests whether scopes are passed from the RDD operation to the RDDs correctly. */ -class RDDOperationScopeSuite extends FunSuite with BeforeAndAfter { +class RDDOperationScopeSuite extends SparkFunSuite with BeforeAndAfter { private var sc: SparkContext = null private val scope1 = new RDDOperationScope("scope1") private val scope2 = new RDDOperationScope("scope2", Some(scope1)) diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 8079d5dcaea81..f6da9f98ad253 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -25,14 +25,12 @@ import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.collection.JavaConverters._ import scala.reflect.ClassTag -import org.scalatest.FunSuite - import org.apache.spark._ import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.rdd.RDDSuiteUtils._ import org.apache.spark.util.Utils -class RDDSuite extends FunSuite with SharedSparkContext { +class RDDSuite extends SparkFunSuite with SharedSparkContext { test("basic operations") { val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) diff --git a/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala index 54fc914722b46..a7de9cabe7cc9 100644 --- a/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.rdd -import org.scalatest.FunSuite import org.scalatest.Matchers -import org.apache.spark.{Logging, SharedSparkContext} +import org.apache.spark.{Logging, SharedSparkContext, SparkFunSuite} -class SortingSuite extends FunSuite with SharedSparkContext with Matchers with Logging { +class SortingSuite extends SparkFunSuite with SharedSparkContext with Matchers with Logging { test("sortByKey") { val pairs = sc.parallelize(Array((1, 0), (2, 0), (0, 0), (3, 0)), 2) diff --git a/core/src/test/scala/org/apache/spark/rdd/ZippedPartitionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/ZippedPartitionsSuite.scala index 72596e86865b2..5d7b973fbd9ac 100644 --- a/core/src/test/scala/org/apache/spark/rdd/ZippedPartitionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/ZippedPartitionsSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.rdd -import org.apache.spark.SharedSparkContext -import org.scalatest.FunSuite +import org.apache.spark.{SharedSparkContext, SparkFunSuite} object ZippedPartitionsSuite { def procZippedData(i: Iterator[Int], s: Iterator[String], d: Iterator[Double]) : Iterator[Int] = { @@ -26,7 +25,7 @@ object ZippedPartitionsSuite { } } -class ZippedPartitionsSuite extends FunSuite with SharedSparkContext { +class ZippedPartitionsSuite extends SparkFunSuite with SharedSparkContext { test("print sizes") { val data1 = sc.makeRDD(Array(1, 2, 3, 4), 2) val data2 = sc.makeRDD(Array("1", "2", "3", "4", "5", "6"), 2) diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 21eb71d9acfbd..1f0aa759b08da 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -24,15 +24,15 @@ import scala.concurrent.Await import scala.concurrent.duration._ import scala.language.postfixOps -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.BeforeAndAfterAll import org.scalatest.concurrent.Eventually._ -import org.apache.spark.{SparkException, SparkConf} +import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} /** * Common tests for an RpcEnv implementation. */ -abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { +abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { var env: RpcEnv = _ diff --git a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala index 3821166386fa6..34145691153ce 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala @@ -17,12 +17,10 @@ package org.apache.spark.scheduler -import org.apache.spark.{LocalSparkContext, SparkConf, SparkException, SparkContext} +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException, SparkFunSuite} import org.apache.spark.util.{SerializableBuffer, AkkaUtils} -import org.scalatest.FunSuite - -class CoarseGrainedSchedulerBackendSuite extends FunSuite with LocalSparkContext { +class CoarseGrainedSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext { test("serialized task larger than akka frame size") { val conf = new SparkConf diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index eea7a600841cc..bfcf918e06162 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -21,7 +21,7 @@ import scala.collection.mutable.{ArrayBuffer, HashSet, HashMap, Map} import scala.language.reflectiveCalls import scala.util.control.NonFatal -import org.scalatest.{BeforeAndAfter, FunSuiteLike} +import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ @@ -68,7 +68,7 @@ class MyRDD( class DAGSchedulerSuiteDummyException extends Exception class DAGSchedulerSuite - extends FunSuiteLike with BeforeAndAfter with LocalSparkContext with Timeouts { + extends SparkFunSuite with BeforeAndAfter with LocalSparkContext with Timeouts { val conf = new SparkConf /** Set of TaskSets the DAGScheduler has requested executed. */ diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index b52a8d11d147d..f681f21b6205e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -25,7 +25,7 @@ import scala.io.Source import org.apache.hadoop.fs.Path import org.json4s.jackson.JsonMethods._ -import org.scalatest.{FunSuiteLike, BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil @@ -39,7 +39,7 @@ import org.apache.spark.util.{JsonProtocol, Utils} * logging events, whether the parsing of the file names is correct, and whether the logged events * can be read and deserialized into actual SparkListenerEvents. */ -class EventLoggingListenerSuite extends FunSuite with LocalSparkContext with BeforeAndAfter +class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext with BeforeAndAfter with Logging { import EventLoggingListenerSuite._ diff --git a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala index 950c6dc58e332..b8e466fab4506 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala @@ -18,14 +18,13 @@ package org.apache.spark.scheduler import org.apache.spark.storage.BlockManagerId -import org.scalatest.FunSuite -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer import scala.util.Random -class MapStatusSuite extends FunSuite { +class MapStatusSuite extends SparkFunSuite { test("compressSize") { assert(MapStatus.compressSize(0L) === 0) diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala index 7078a7a12232a..a9036da9cc93d 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala @@ -24,7 +24,7 @@ import org.mockito.Matchers import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter import org.apache.hadoop.mapred.{TaskAttemptID, JobConf, TaskAttemptContext, OutputCommitter} @@ -64,7 +64,7 @@ import scala.language.postfixOps * increments would be captured even though the commit in both tasks was executed * erroneously. */ -class OutputCommitCoordinatorSuite extends FunSuite with BeforeAndAfter { +class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { var outputCommitCoordinator: OutputCommitCoordinator = null var tempDir: File = null diff --git a/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala index 456451b676bed..467796d7c24b0 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala @@ -19,15 +19,13 @@ package org.apache.spark.scheduler import java.util.Properties -import org.scalatest.FunSuite - -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext} +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} /** * Tests that pools and the associated scheduling algorithms for FIFO and fair scheduling work * correctly. */ -class PoolSuite extends FunSuite with LocalSparkContext { +class PoolSuite extends SparkFunSuite with LocalSparkContext { def createTaskSetManager(stageId: Int, numTasks: Int, taskScheduler: TaskSchedulerImpl) : TaskSetManager = { diff --git a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala index dabe4574b6456..ff3fa95ec32ae 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala @@ -21,10 +21,10 @@ import java.io.{File, PrintWriter} import java.net.URI import org.json4s.jackson.JsonMethods._ -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter import org.apache.spark.{SparkConf, SparkContext, SPARK_VERSION} -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.io.CompressionCodec import org.apache.spark.util.{JsonProtocol, Utils} @@ -32,7 +32,7 @@ import org.apache.spark.util.{JsonProtocol, Utils} /** * Test whether ReplayListenerBus replays events from logs correctly. */ -class ReplayListenerSuite extends FunSuite with BeforeAndAfter { +class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter { private val fileSystem = Utils.getHadoopFileSystem("/", SparkHadoopUtil.get.newConfiguration(new SparkConf())) private var testDir: File = _ diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 825c616c0c3e0..06fb909bf5419 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -22,13 +22,13 @@ import java.util.concurrent.Semaphore import scala.collection.mutable import scala.collection.JavaConversions._ -import org.scalatest.{FunSuite, Matchers} +import org.scalatest.Matchers import org.apache.spark.executor.TaskMetrics import org.apache.spark.util.ResetSystemProperties -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext} +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} -class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers +class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Matchers with ResetSystemProperties { /** Length of time to wait while draining listener events. */ diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala index 623a687c359a2..c7f179e1483a5 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala @@ -18,16 +18,16 @@ package org.apache.spark.scheduler import org.apache.spark.scheduler.cluster.ExecutorInfo -import org.apache.spark.{SparkContext, LocalSparkContext} +import org.apache.spark.{LocalSparkContext, SparkContext, SparkFunSuite} -import org.scalatest.{FunSuite, BeforeAndAfter, BeforeAndAfterAll} +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} import scala.collection.mutable /** * Unit tests for SparkListener that require a local cluster. */ -class SparkListenerWithClusterSuite extends FunSuite with LocalSparkContext +class SparkListenerWithClusterSuite extends SparkFunSuite with LocalSparkContext with BeforeAndAfter with BeforeAndAfterAll { /** Length of time to wait while draining listener events. */ diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index 83ae8701243e5..7c1adc1aef1b6 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -20,7 +20,6 @@ package org.apache.spark.scheduler import org.mockito.Mockito._ import org.mockito.Matchers.any -import org.scalatest.FunSuite import org.scalatest.BeforeAndAfter import org.apache.spark._ @@ -28,7 +27,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.util.{TaskCompletionListenerException, TaskCompletionListener} -class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { +class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext { test("calls TaskCompletionListener after failure") { TaskContextSuite.completed = false diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala index e3a3803e6483a..815caa79ff529 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala @@ -23,10 +23,10 @@ import scala.concurrent.duration._ import scala.language.postfixOps import scala.util.control.NonFatal -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.Eventually._ -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv} +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, SparkFunSuite} import org.apache.spark.storage.TaskResultBlockId /** @@ -71,7 +71,7 @@ class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedule /** * Tests related to handling task results (both direct and indirect). */ -class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { +class TaskResultGetterSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext { // Set the Akka frame size to be as small as possible (it must be an integer, so 1 is as small // as we can make it) so the tests don't take too long. diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index ffa4381969b68..a6d5232feb8de 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.scheduler -import org.scalatest.FunSuite - import org.apache.spark._ class FakeSchedulerBackend extends SchedulerBackend { @@ -28,7 +26,7 @@ class FakeSchedulerBackend extends SchedulerBackend { def defaultParallelism(): Int = 1 } -class TaskSchedulerImplSuite extends FunSuite with LocalSparkContext with Logging { +class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with Logging { test("Scheduler does not always schedule tasks on the same workers") { sc = new SparkContext("local", "TaskSchedulerImplSuite") diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 6198cea46ddf8..0060f3396dcde 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -22,8 +22,6 @@ import java.util.Random import scala.collection.mutable.ArrayBuffer import scala.collection.mutable -import org.scalatest.FunSuite - import org.apache.spark._ import org.apache.spark.executor.TaskMetrics import org.apache.spark.util.{ManualClock, Utils} @@ -146,7 +144,7 @@ class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0) { override def preferredLocations: Seq[TaskLocation] = Seq[TaskLocation]() } -class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { +class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logging { import TaskLocality.{ANY, PROCESS_LOCAL, NO_PREF, NODE_LOCAL, RACK_LOCAL} private val conf = new SparkConf diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtilsSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtilsSuite.scala index 3fa0115e68259..e72285d03d3ee 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtilsSuite.scala @@ -18,22 +18,21 @@ package org.apache.spark.scheduler.cluster.mesos import org.mockito.Mockito._ -import org.scalatest.FunSuite import org.scalatest.mock.MockitoSugar -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} -class MemoryUtilsSuite extends FunSuite with MockitoSugar { +class MemoryUtilsSuite extends SparkFunSuite with MockitoSugar { test("MesosMemoryUtils should always override memoryOverhead when it's set") { val sparkConf = new SparkConf val sc = mock[SparkContext] when(sc.conf).thenReturn(sparkConf) - + // 384 > sc.executorMemory * 0.1 => 512 + 384 = 896 when(sc.executorMemory).thenReturn(512) assert(MemoryUtils.calculateTotalMemory(sc) === 896) - + // 384 < sc.executorMemory * 0.1 => 4096 + (4096 * 0.1) = 4505.6 when(sc.executorMemory).thenReturn(4096) assert(MemoryUtils.calculateTotalMemory(sc) === 4505) diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala index ab863f3d8d672..68df46a41ddc8 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala @@ -30,16 +30,15 @@ import org.apache.mesos.SchedulerDriver import org.mockito.Matchers._ import org.mockito.Mockito._ import org.mockito.{ArgumentCaptor, Matchers} -import org.scalatest.FunSuite import org.scalatest.mock.MockitoSugar import org.apache.spark.executor.MesosExecutorBackend import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.scheduler.{LiveListenerBus, SparkListenerExecutorAdded, TaskDescription, TaskSchedulerImpl, WorkerOffer} -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext} +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} -class MesosSchedulerBackendSuite extends FunSuite with LocalSparkContext with MockitoSugar { +class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar { test("check spark-class location correctly") { val conf = new SparkConf @@ -80,11 +79,11 @@ class MesosSchedulerBackendSuite extends FunSuite with LocalSparkContext with Mo .set("spark.mesos.executor.docker.image", "spark/mock") .set("spark.mesos.executor.docker.volumes", "/a,/b:/b,/c:/c:rw,/d:ro,/e:/e:ro") .set("spark.mesos.executor.docker.portmaps", "80:8080,53:53:tcp") - + val listenerBus = mock[LiveListenerBus] listenerBus.post( SparkListenerExecutorAdded(anyLong, "s1", new ExecutorInfo("host1", 2, Map.empty))) - + val sc = mock[SparkContext] when(sc.executorMemory).thenReturn(100) when(sc.getSparkHome()).thenReturn(Option("/spark-home")) diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchDataSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchDataSuite.scala index eebcba40f8a1c..5a81bb335fdb7 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchDataSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchDataSuite.scala @@ -19,9 +19,9 @@ package org.apache.spark.scheduler.cluster.mesos import java.nio.ByteBuffer -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class MesosTaskLaunchDataSuite extends FunSuite { +class MesosTaskLaunchDataSuite extends SparkFunSuite { test("serialize and deserialize data must be same") { val serializedTask = ByteBuffer.allocate(40) (Range(100, 110).map(serializedTask.putInt(_))) diff --git a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosClusterSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosClusterSchedulerSuite.scala index f28e29e9b8d8e..f5cef1caaf1ac 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosClusterSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosClusterSchedulerSuite.scala @@ -19,16 +19,15 @@ package org.apache.spark.scheduler.mesos import java.util.Date -import org.scalatest.FunSuite import org.scalatest.mock.MockitoSugar import org.apache.spark.deploy.Command import org.apache.spark.deploy.mesos.MesosDriverDescription import org.apache.spark.scheduler.cluster.mesos._ -import org.apache.spark.{LocalSparkContext, SparkConf} +import org.apache.spark.{LocalSparkContext, SparkConf, SparkFunSuite} -class MesosClusterSchedulerSuite extends FunSuite with LocalSparkContext with MockitoSugar { +class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar { private val command = new Command("mainClass", Seq("arg"), null, null, null, null) diff --git a/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala index ed4d8ce632e16..329a2b6dad831 100644 --- a/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala @@ -17,10 +17,9 @@ package org.apache.spark.serializer -import org.apache.spark.SparkConf -import org.scalatest.FunSuite +import org.apache.spark.{SparkConf, SparkFunSuite} -class JavaSerializerSuite extends FunSuite { +class JavaSerializerSuite extends SparkFunSuite { test("JavaSerializer instances are serializable") { val serializer = new JavaSerializer(new SparkConf()) val instance = serializer.newInstance() diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala index 054a4c64897a9..63a8480c9b57b 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala @@ -20,12 +20,11 @@ package org.apache.spark.serializer import org.apache.spark.util.Utils import com.esotericsoftware.kryo.Kryo -import org.scalatest.FunSuite -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, TestUtils} +import org.apache.spark._ import org.apache.spark.serializer.KryoDistributedTest._ -class KryoSerializerDistributedSuite extends FunSuite { +class KryoSerializerDistributedSuite extends SparkFunSuite { test("kryo objects are serialised consistently in different processes") { val conf = new SparkConf(false) diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala index da98d09184735..a9b209ccfc76e 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala @@ -17,15 +17,13 @@ package org.apache.spark.serializer -import org.scalatest.FunSuite - -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.SparkContext import org.apache.spark.LocalSparkContext import org.apache.spark.SparkException -class KryoSerializerResizableOutputSuite extends FunSuite { +class KryoSerializerResizableOutputSuite extends SparkFunSuite { // trial and error showed this will not serialize with 1mb buffer val x = (1 to 400000).toArray diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index 14c0172fa96ab..23a1fdb0f5009 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -23,21 +23,20 @@ import scala.collection.mutable import scala.reflect.ClassTag import com.esotericsoftware.kryo.Kryo -import org.scalatest.FunSuite -import org.apache.spark.{SharedSparkContext, SparkConf} +import org.apache.spark.{SharedSparkContext, SparkConf, SparkFunSuite} import org.apache.spark.scheduler.HighlyCompressedMapStatus import org.apache.spark.serializer.KryoTest._ import org.apache.spark.storage.BlockManagerId -class KryoSerializerSuite extends FunSuite with SharedSparkContext { +class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") conf.set("spark.kryo.registrator", classOf[MyRegistrator].getName) test("SPARK-7392 configuration limits") { val kryoBufferProperty = "spark.kryoserializer.buffer" val kryoBufferMaxProperty = "spark.kryoserializer.buffer.max" - + def newKryoInstance( conf: SparkConf, bufferSize: String = "64k", @@ -47,7 +46,7 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { kryoConf.set(kryoBufferMaxProperty, maxBufferSize) new KryoSerializer(kryoConf).newInstance() } - + // test default values newKryoInstance(conf, "64k", "64m") // 2048m = 2097152k @@ -70,7 +69,7 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { // test configuration with mb is supported properly newKryoInstance(conf, "8m", "9m") } - + test("basic types") { val ser = new KryoSerializer(conf).newInstance() def check[T: ClassTag](t: T) { @@ -361,7 +360,7 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { } } -class KryoSerializerAutoResetDisabledSuite extends FunSuite with SharedSparkContext { +class KryoSerializerAutoResetDisabledSuite extends SparkFunSuite with SharedSparkContext { conf.set("spark.serializer", classOf[KryoSerializer].getName) conf.set("spark.kryo.registrator", classOf[RegistratorWithoutAutoReset].getName) conf.set("spark.kryo.referenceTracking", "true") diff --git a/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala b/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala index 673948d84d82b..c657414e9e5c3 100644 --- a/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala @@ -17,19 +17,17 @@ package org.apache.spark.serializer -import org.scalatest.FunSuite - -import org.apache.spark.{SharedSparkContext, SparkException} +import org.apache.spark.{SharedSparkContext, SparkException, SparkFunSuite} import org.apache.spark.rdd.RDD /* A trivial (but unserializable) container for trivial functions */ class UnserializableClass { def op[T](x: T): String = x.toString - + def pred[T](x: T): Boolean = x.toString.length % 2 == 0 } -class ProactiveClosureSerializationSuite extends FunSuite with SharedSparkContext { +class ProactiveClosureSerializationSuite extends SparkFunSuite with SharedSparkContext { def fixture: (RDD[String], UnserializableClass) = { (sc.parallelize(0 until 1000).map(_.toString), new UnserializableClass) @@ -47,7 +45,7 @@ class ProactiveClosureSerializationSuite extends FunSuite with SharedSparkContex // iterating over a map from transformation names to functions that perform that // transformation on a given RDD, creating one test case for each - for (transformation <- + for (transformation <- Map("map" -> xmap _, "flatMap" -> xflatMap _, "filter" -> xfilter _, @@ -60,24 +58,24 @@ class ProactiveClosureSerializationSuite extends FunSuite with SharedSparkContex val ex = intercept[SparkException] { xf(data, uc) } - assert(ex.getMessage.contains("Task not serializable"), + assert(ex.getMessage.contains("Task not serializable"), s"RDD.$name doesn't proactively throw NotSerializableException") } } - private def xmap(x: RDD[String], uc: UnserializableClass): RDD[String] = + private def xmap(x: RDD[String], uc: UnserializableClass): RDD[String] = x.map(y => uc.op(y)) - private def xflatMap(x: RDD[String], uc: UnserializableClass): RDD[String] = + private def xflatMap(x: RDD[String], uc: UnserializableClass): RDD[String] = x.flatMap(y => Seq(uc.op(y))) - private def xfilter(x: RDD[String], uc: UnserializableClass): RDD[String] = + private def xfilter(x: RDD[String], uc: UnserializableClass): RDD[String] = x.filter(y => uc.pred(y)) - private def xmapPartitions(x: RDD[String], uc: UnserializableClass): RDD[String] = + private def xmapPartitions(x: RDD[String], uc: UnserializableClass): RDD[String] = x.mapPartitions(_.map(y => uc.op(y))) - private def xmapPartitionsWithIndex(x: RDD[String], uc: UnserializableClass): RDD[String] = + private def xmapPartitionsWithIndex(x: RDD[String], uc: UnserializableClass): RDD[String] = x.mapPartitionsWithIndex((_, it) => it.map(y => uc.op(y))) - + } diff --git a/core/src/test/scala/org/apache/spark/serializer/SerializationDebuggerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/SerializationDebuggerSuite.scala index e62828c4fbac6..2707bb53bc383 100644 --- a/core/src/test/scala/org/apache/spark/serializer/SerializationDebuggerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/SerializationDebuggerSuite.scala @@ -19,10 +19,12 @@ package org.apache.spark.serializer import java.io.{ObjectOutput, ObjectInput} -import org.scalatest.{BeforeAndAfterEach, FunSuite} +import org.scalatest.BeforeAndAfterEach +import org.apache.spark.SparkFunSuite -class SerializationDebuggerSuite extends FunSuite with BeforeAndAfterEach { + +class SerializationDebuggerSuite extends SparkFunSuite with BeforeAndAfterEach { import SerializationDebugger.find diff --git a/core/src/test/scala/org/apache/spark/serializer/SerializerPropertiesSuite.scala b/core/src/test/scala/org/apache/spark/serializer/SerializerPropertiesSuite.scala index bb34033fe9e7e..4ce3b941bea55 100644 --- a/core/src/test/scala/org/apache/spark/serializer/SerializerPropertiesSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/SerializerPropertiesSuite.scala @@ -21,9 +21,9 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import scala.util.Random -import org.scalatest.{Assertions, FunSuite} +import org.scalatest.Assertions -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.KryoTest.RegistratorWithoutAutoReset /** @@ -31,7 +31,7 @@ import org.apache.spark.serializer.KryoTest.RegistratorWithoutAutoReset * describe properties of the serialized stream, such as * [[Serializer.supportsRelocationOfSerializedObjects]]. */ -class SerializerPropertiesSuite extends FunSuite { +class SerializerPropertiesSuite extends SparkFunSuite { import SerializerPropertiesSuite._ diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala index e0e646f0a3652..96778c9ebafb1 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala @@ -17,13 +17,14 @@ package org.apache.spark.shuffle -import org.scalatest.FunSuite import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ import java.util.concurrent.atomic.AtomicBoolean import java.util.concurrent.CountDownLatch -class ShuffleMemoryManagerSuite extends FunSuite with Timeouts { +import org.apache.spark.SparkFunSuite + +class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { /** Launch a thread with the given body block and return it. */ private def startThread(name: String)(body: => Unit): Thread = { val thread = new Thread("ShuffleMemorySuite " + name) { diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala index 0537bf66ad020..491dc3659e184 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala @@ -21,16 +21,14 @@ import java.io.{File, FileWriter} import scala.language.reflectiveCalls -import org.scalatest.FunSuite - -import org.apache.spark.{SparkEnv, SparkContext, LocalSparkContext, SparkConf} +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, SparkFunSuite} import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.serializer.JavaSerializer import org.apache.spark.shuffle.FileShuffleBlockResolver import org.apache.spark.storage.{ShuffleBlockId, FileSegment} -class HashShuffleManagerSuite extends FunSuite with LocalSparkContext { +class HashShuffleManagerSuite extends SparkFunSuite with LocalSparkContext { private val testConf = new SparkConf(false) private def checkSegments(expected: FileSegment, buffer: ManagedBuffer) { diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala new file mode 100644 index 0000000000000..542f8f45125a4 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort + +import java.io.File +import java.util.UUID + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import org.mockito.Answers.RETURNS_SMART_NULLS +import org.mockito.{Mock, MockitoAnnotations} +import org.mockito.Matchers._ +import org.mockito.Mockito._ +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark._ +import org.apache.spark.executor.{TaskMetrics, ShuffleWriteMetrics} +import org.apache.spark.serializer.{SerializerInstance, Serializer, JavaSerializer} +import org.apache.spark.storage._ +import org.apache.spark.util.Utils + +class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfterEach { + + @Mock(answer = RETURNS_SMART_NULLS) private var blockManager: BlockManager = _ + @Mock(answer = RETURNS_SMART_NULLS) private var diskBlockManager: DiskBlockManager = _ + @Mock(answer = RETURNS_SMART_NULLS) private var taskContext: TaskContext = _ + + private var taskMetrics: TaskMetrics = _ + private var shuffleWriteMetrics: ShuffleWriteMetrics = _ + private var tempDir: File = _ + private var outputFile: File = _ + private val conf: SparkConf = new SparkConf(loadDefaults = false) + private val temporaryFilesCreated: mutable.Buffer[File] = new ArrayBuffer[File]() + private val blockIdToFileMap: mutable.Map[BlockId, File] = new mutable.HashMap[BlockId, File] + private val shuffleBlockId: ShuffleBlockId = new ShuffleBlockId(0, 0, 0) + private val serializer: Serializer = new JavaSerializer(conf) + + override def beforeEach(): Unit = { + tempDir = Utils.createTempDir() + outputFile = File.createTempFile("shuffle", null, tempDir) + shuffleWriteMetrics = new ShuffleWriteMetrics + taskMetrics = new TaskMetrics + taskMetrics.shuffleWriteMetrics = Some(shuffleWriteMetrics) + MockitoAnnotations.initMocks(this) + when(taskContext.taskMetrics()).thenReturn(taskMetrics) + when(blockManager.diskBlockManager).thenReturn(diskBlockManager) + when(blockManager.getDiskWriter( + any[BlockId], + any[File], + any[SerializerInstance], + anyInt(), + any[ShuffleWriteMetrics] + )).thenAnswer(new Answer[BlockObjectWriter] { + override def answer(invocation: InvocationOnMock): BlockObjectWriter = { + val args = invocation.getArguments + new DiskBlockObjectWriter( + args(0).asInstanceOf[BlockId], + args(1).asInstanceOf[File], + args(2).asInstanceOf[SerializerInstance], + args(3).asInstanceOf[Int], + compressStream = identity, + syncWrites = false, + args(4).asInstanceOf[ShuffleWriteMetrics] + ) + } + }) + when(diskBlockManager.createTempShuffleBlock()).thenAnswer( + new Answer[(TempShuffleBlockId, File)] { + override def answer(invocation: InvocationOnMock): (TempShuffleBlockId, File) = { + val blockId = new TempShuffleBlockId(UUID.randomUUID) + val file = File.createTempFile(blockId.toString, null, tempDir) + blockIdToFileMap.put(blockId, file) + temporaryFilesCreated.append(file) + (blockId, file) + } + }) + when(diskBlockManager.getFile(any[BlockId])).thenAnswer( + new Answer[File] { + override def answer(invocation: InvocationOnMock): File = { + blockIdToFileMap.get(invocation.getArguments.head.asInstanceOf[BlockId]).get + } + }) + } + + override def afterEach(): Unit = { + Utils.deleteRecursively(tempDir) + blockIdToFileMap.clear() + temporaryFilesCreated.clear() + } + + test("write empty iterator") { + val writer = new BypassMergeSortShuffleWriter[Int, Int]( + new SparkConf(loadDefaults = false), + blockManager, + new HashPartitioner(7), + shuffleWriteMetrics, + serializer + ) + writer.insertAll(Iterator.empty) + val partitionLengths = writer.writePartitionedFile(shuffleBlockId, taskContext, outputFile) + assert(partitionLengths.sum === 0) + assert(outputFile.exists()) + assert(outputFile.length() === 0) + assert(temporaryFilesCreated.isEmpty) + assert(shuffleWriteMetrics.shuffleBytesWritten === 0) + assert(shuffleWriteMetrics.shuffleRecordsWritten === 0) + assert(taskMetrics.diskBytesSpilled === 0) + assert(taskMetrics.memoryBytesSpilled === 0) + } + + test("write with some empty partitions") { + def records: Iterator[(Int, Int)] = + Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2)) + val writer = new BypassMergeSortShuffleWriter[Int, Int]( + new SparkConf(loadDefaults = false), + blockManager, + new HashPartitioner(7), + shuffleWriteMetrics, + serializer + ) + writer.insertAll(records) + assert(temporaryFilesCreated.nonEmpty) + val partitionLengths = writer.writePartitionedFile(shuffleBlockId, taskContext, outputFile) + assert(partitionLengths.sum === outputFile.length()) + assert(temporaryFilesCreated.count(_.exists()) === 0) // check that temporary files were deleted + assert(shuffleWriteMetrics.shuffleBytesWritten === outputFile.length()) + assert(shuffleWriteMetrics.shuffleRecordsWritten === records.length) + assert(taskMetrics.diskBytesSpilled === 0) + assert(taskMetrics.memoryBytesSpilled === 0) + } + + test("cleanup of intermediate files after errors") { + val writer = new BypassMergeSortShuffleWriter[Int, Int]( + new SparkConf(loadDefaults = false), + blockManager, + new HashPartitioner(7), + shuffleWriteMetrics, + serializer + ) + intercept[SparkException] { + writer.insertAll((0 until 100000).iterator.map(i => { + if (i == 99990) { + throw new SparkException("Intentional failure") + } + (i, i) + })) + } + assert(temporaryFilesCreated.nonEmpty) + writer.stop() + assert(temporaryFilesCreated.count(_.exists()) === 0) + } + +} diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala new file mode 100644 index 0000000000000..34b4984f12c09 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort + +import org.mockito.Mockito._ + +import org.apache.spark.{Aggregator, SparkConf, SparkFunSuite} + +class SortShuffleWriterSuite extends SparkFunSuite { + + import SortShuffleWriter._ + + test("conditions for bypassing merge-sort") { + val conf = new SparkConf(loadDefaults = false) + val agg = mock(classOf[Aggregator[_, _, _]], RETURNS_SMART_NULLS) + val ord = implicitly[Ordering[Int]] + + // Numbers of partitions that are above and below the default bypassMergeThreshold + val FEW_PARTITIONS = 50 + val MANY_PARTITIONS = 10000 + + // Shuffles with no ordering or aggregator: should bypass unless # of partitions is high + assert(shouldBypassMergeSort(conf, FEW_PARTITIONS, None, None)) + assert(!shouldBypassMergeSort(conf, MANY_PARTITIONS, None, None)) + + // Shuffles with an ordering or aggregator: should not bypass even if they have few partitions + assert(!shouldBypassMergeSort(conf, FEW_PARTITIONS, None, Some(ord))) + assert(!shouldBypassMergeSort(conf, FEW_PARTITIONS, Some(agg), None)) + } +} diff --git a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala index 49a04a2a45280..a73e94e05575e 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.shuffle.unsafe import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer -import org.scalatest.{FunSuite, Matchers} +import org.scalatest.Matchers import org.apache.spark._ import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, Serializer} @@ -29,7 +29,7 @@ import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, Serializer} * Tests for the fallback logic in UnsafeShuffleManager. Actual tests of shuffling data are * performed in other suites. */ -class UnsafeShuffleManagerSuite extends FunSuite with Matchers { +class UnsafeShuffleManagerSuite extends SparkFunSuite with Matchers { import UnsafeShuffleManager.canUseUnsafeShuffle diff --git a/core/src/test/scala/org/apache/spark/status/api/v1/SimpleDateParamSuite.scala b/core/src/test/scala/org/apache/spark/status/api/v1/SimpleDateParamSuite.scala index 731d1f557ed33..63b0e77629dde 100644 --- a/core/src/test/scala/org/apache/spark/status/api/v1/SimpleDateParamSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/api/v1/SimpleDateParamSuite.scala @@ -16,14 +16,21 @@ */ package org.apache.spark.status.api.v1 -import org.scalatest.{Matchers, FunSuite} +import javax.ws.rs.WebApplicationException -class SimpleDateParamSuite extends FunSuite with Matchers { +import org.scalatest.Matchers + +import org.apache.spark.SparkFunSuite + +class SimpleDateParamSuite extends SparkFunSuite with Matchers { test("date parsing") { new SimpleDateParam("2015-02-20T23:21:17.190GMT").timestamp should be (1424474477190L) new SimpleDateParam("2015-02-20T17:21:17.190EST").timestamp should be (1424470877190L) new SimpleDateParam("2015-02-20").timestamp should be (1424390400000L) // GMT + intercept[WebApplicationException] { + new SimpleDateParam("invalid date") + } } } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala index b647e8a6728ec..89ed031b6fcd1 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.storage -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class BlockIdSuite extends FunSuite { +class BlockIdSuite extends SparkFunSuite { def assertSame(id1: BlockId, id2: BlockId) { assert(id1.name === id2.name) assert(id1.hashCode === id2.hashCode) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index f647200402ecb..0f5ba46f69c2f 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -23,11 +23,11 @@ import scala.language.implicitConversions import scala.language.postfixOps import org.mockito.Mockito.{mock, when} -import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} +import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ import org.apache.spark.rpc.RpcEnv -import org.apache.spark.{MapOutputTrackerMaster, SparkConf, SparkContext, SecurityManager} +import org.apache.spark._ import org.apache.spark.network.BlockTransferService import org.apache.spark.network.nio.NioBlockTransferService import org.apache.spark.scheduler.LiveListenerBus @@ -36,7 +36,7 @@ import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.storage.StorageLevel._ /** Testsuite that tests block replication in BlockManager */ -class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAndAfter { +class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with BeforeAndAfter { private val conf = new SparkConf(false) var rpcEnv: RpcEnv = null diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 151955ef7f435..bcee901f5dd5f 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -31,7 +31,7 @@ import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.Timeouts._ import org.apache.spark.rpc.RpcEnv -import org.apache.spark.{MapOutputTrackerMaster, SparkConf, SparkContext, SecurityManager} +import org.apache.spark._ import org.apache.spark.executor.DataReadMethod import org.apache.spark.network.nio.NioBlockTransferService import org.apache.spark.scheduler.LiveListenerBus @@ -41,7 +41,7 @@ import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat import org.apache.spark.util._ -class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach +class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach with PrivateMethodTester with ResetSystemProperties { private val conf = new SparkConf(false) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala index 43ef469c1fd48..7bdea724fea58 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala @@ -18,16 +18,28 @@ package org.apache.spark.storage import java.io.File -import org.scalatest.FunSuite +import org.scalatest.BeforeAndAfterEach import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.serializer.JavaSerializer import org.apache.spark.util.Utils -class BlockObjectWriterSuite extends FunSuite { +class BlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { + + var tempDir: File = _ + + override def beforeEach(): Unit = { + tempDir = Utils.createTempDir() + } + + override def afterEach(): Unit = { + Utils.deleteRecursively(tempDir) + } + test("verify write metrics") { - val file = new File(Utils.createTempDir(), "somefile") + val file = new File(tempDir, "somefile") val writeMetrics = new ShuffleWriteMetrics() val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) @@ -49,7 +61,7 @@ class BlockObjectWriterSuite extends FunSuite { } test("verify write metrics on revert") { - val file = new File(Utils.createTempDir(), "somefile") + val file = new File(tempDir, "somefile") val writeMetrics = new ShuffleWriteMetrics() val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) @@ -72,7 +84,7 @@ class BlockObjectWriterSuite extends FunSuite { } test("Reopening a closed block writer") { - val file = new File(Utils.createTempDir(), "somefile") + val file = new File(tempDir, "somefile") val writeMetrics = new ShuffleWriteMetrics() val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) @@ -83,4 +95,79 @@ class BlockObjectWriterSuite extends FunSuite { writer.open() } } + + test("calling revertPartialWritesAndClose() on a closed block writer should have no effect") { + val file = new File(tempDir, "somefile") + val writeMetrics = new ShuffleWriteMetrics() + val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, + new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + for (i <- 1 to 1000) { + writer.write(i, i) + } + writer.commitAndClose() + val bytesWritten = writeMetrics.shuffleBytesWritten + assert(writeMetrics.shuffleRecordsWritten === 1000) + writer.revertPartialWritesAndClose() + assert(writeMetrics.shuffleRecordsWritten === 1000) + assert(writeMetrics.shuffleBytesWritten === bytesWritten) + } + + test("commitAndClose() should be idempotent") { + val file = new File(tempDir, "somefile") + val writeMetrics = new ShuffleWriteMetrics() + val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, + new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + for (i <- 1 to 1000) { + writer.write(i, i) + } + writer.commitAndClose() + val bytesWritten = writeMetrics.shuffleBytesWritten + val writeTime = writeMetrics.shuffleWriteTime + assert(writeMetrics.shuffleRecordsWritten === 1000) + writer.commitAndClose() + assert(writeMetrics.shuffleRecordsWritten === 1000) + assert(writeMetrics.shuffleBytesWritten === bytesWritten) + assert(writeMetrics.shuffleWriteTime === writeTime) + } + + test("revertPartialWritesAndClose() should be idempotent") { + val file = new File(tempDir, "somefile") + val writeMetrics = new ShuffleWriteMetrics() + val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, + new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + for (i <- 1 to 1000) { + writer.write(i, i) + } + writer.revertPartialWritesAndClose() + val bytesWritten = writeMetrics.shuffleBytesWritten + val writeTime = writeMetrics.shuffleWriteTime + assert(writeMetrics.shuffleRecordsWritten === 0) + writer.revertPartialWritesAndClose() + assert(writeMetrics.shuffleRecordsWritten === 0) + assert(writeMetrics.shuffleBytesWritten === bytesWritten) + assert(writeMetrics.shuffleWriteTime === writeTime) + } + + test("fileSegment() can only be called after commitAndClose() has been called") { + val file = new File(tempDir, "somefile") + val writeMetrics = new ShuffleWriteMetrics() + val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, + new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + for (i <- 1 to 1000) { + writer.write(i, i) + } + intercept[IllegalStateException] { + writer.fileSegment() + } + writer.close() + } + + test("commitAndClose() without ever opening or writing") { + val file = new File(tempDir, "somefile") + val writeMetrics = new ShuffleWriteMetrics() + val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, + new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + writer.commitAndClose() + assert(writer.fileSegment().length === 0) + } } diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala index bc5c74c126b74..688f56f4665f3 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala @@ -22,12 +22,12 @@ import java.io.{File, FileWriter} import scala.language.reflectiveCalls import org.mockito.Mockito.{mock, when} -import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite} +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.util.Utils -class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfterAll { +class DiskBlockManagerSuite extends SparkFunSuite with BeforeAndAfterEach with BeforeAndAfterAll { private val testConf = new SparkConf(false) private var rootDir0: File = _ private var rootDir1: File = _ diff --git a/core/src/test/scala/org/apache/spark/storage/FlatmapIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/FlatmapIteratorSuite.scala index 47341b74e9c0f..b21c91f75d5c7 100644 --- a/core/src/test/scala/org/apache/spark/storage/FlatmapIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/FlatmapIteratorSuite.scala @@ -16,11 +16,10 @@ */ package org.apache.spark.storage -import org.scalatest.FunSuite -import org.apache.spark.{SharedSparkContext, SparkConf, LocalSparkContext, SparkContext} +import org.apache.spark._ -class FlatmapIteratorSuite extends FunSuite with LocalSparkContext { +class FlatmapIteratorSuite extends SparkFunSuite with LocalSparkContext { /* Tests the ability of Spark to deal with user provided iterators from flatMap * calls, that may generate more data then available memory. In any * memory based persistance Spark will unroll the iterator into an ArrayBuffer diff --git a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala index b47157f8331cc..ac6fec56bbf4f 100644 --- a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala @@ -20,15 +20,15 @@ package org.apache.spark.storage import java.io.File import org.apache.spark.util.Utils -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} /** * Tests for the spark.local.dir and SPARK_LOCAL_DIRS configuration options. */ -class LocalDirsSuite extends FunSuite with BeforeAndAfter { +class LocalDirsSuite extends SparkFunSuite with BeforeAndAfter { before { Utils.clearLocalRootDirs() diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 2080c432d77db..2a7fe67ad8585 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -26,15 +26,14 @@ import org.mockito.Matchers.{any, eq => meq} import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer -import org.scalatest.FunSuite -import org.apache.spark.{SparkConf, TaskContextImpl} +import org.apache.spark.{SparkConf, SparkFunSuite, TaskContextImpl} import org.apache.spark.network._ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.shuffle.BlockFetchingListener import org.apache.spark.serializer.TestSerializer -class ShuffleBlockFetcherIteratorSuite extends FunSuite { +class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { // Some of the tests are quite tricky because we are testing the cleanup behavior // in the presence of faults. diff --git a/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala index 3a45875391e29..1a199beb3558f 100644 --- a/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.storage -import org.scalatest.FunSuite -import org.apache.spark.Success +import org.apache.spark.{SparkFunSuite, Success} import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ /** * Test the behavior of StorageStatusListener in response to all relevant events. */ -class StorageStatusListenerSuite extends FunSuite { +class StorageStatusListenerSuite extends SparkFunSuite { private val bm1 = BlockManagerId("big", "dog", 1) private val bm2 = BlockManagerId("fat", "duck", 2) private val taskInfo1 = new TaskInfo(0, 0, 0, 0, "big", "dog", TaskLocality.ANY, false) diff --git a/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala b/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala index 17193ddbfd894..1d5a813a4d336 100644 --- a/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala @@ -17,12 +17,12 @@ package org.apache.spark.storage -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite /** * Test various functionalities in StorageUtils and StorageStatus. */ -class StorageSuite extends FunSuite { +class StorageSuite extends SparkFunSuite { private val memAndDisk = StorageLevel.MEMORY_AND_DISK // For testing add, update, and remove (for non-RDD blocks) diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index a727a43f44dfc..33712f1bfa782 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -42,7 +42,7 @@ import org.apache.spark.status.api.v1.{JacksonMessageWriter, StageStatus} /** * Selenium tests for the Spark Web UI. */ -class UISeleniumSuite extends FunSuite with WebBrowser with Matchers with BeforeAndAfterAll { +class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with BeforeAndAfterAll { implicit var webDriver: WebDriver = _ implicit val formats = DefaultFormats diff --git a/core/src/test/scala/org/apache/spark/ui/UISuite.scala b/core/src/test/scala/org/apache/spark/ui/UISuite.scala index 77a038dc1720d..8f9502b5673d1 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISuite.scala @@ -23,14 +23,13 @@ import scala.io.Source import scala.util.{Failure, Success, Try} import org.eclipse.jetty.servlet.ServletContextHandler -import org.scalatest.FunSuite import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ import org.apache.spark.LocalSparkContext._ -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} -class UISuite extends FunSuite { +class UISuite extends SparkFunSuite { /** * Create a test SparkContext with the SparkUI enabled. diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index 967dd0821ebd0..56f7b9cf1f358 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.ui.jobs import java.util.Properties -import org.scalatest.FunSuite import org.scalatest.Matchers import org.apache.spark._ @@ -28,7 +27,7 @@ import org.apache.spark.executor._ import org.apache.spark.scheduler._ import org.apache.spark.util.Utils -class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matchers { +class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with Matchers { val jobSubmissionTime = 1421191042750L val jobCompletionTime = 1421191296660L diff --git a/core/src/test/scala/org/apache/spark/ui/scope/RDDOperationGraphListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/scope/RDDOperationGraphListenerSuite.scala index c1126f3af52e6..86b078851851f 100644 --- a/core/src/test/scala/org/apache/spark/ui/scope/RDDOperationGraphListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/scope/RDDOperationGraphListenerSuite.scala @@ -17,9 +17,7 @@ package org.apache.spark.ui.scope -import org.scalatest.FunSuite - -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.scheduler._ import org.apache.spark.scheduler.SparkListenerStageSubmitted import org.apache.spark.scheduler.SparkListenerStageCompleted @@ -28,7 +26,7 @@ import org.apache.spark.scheduler.SparkListenerJobStart /** * Tests that this listener populates and cleans up its data structures properly. */ -class RDDOperationGraphListenerSuite extends FunSuite { +class RDDOperationGraphListenerSuite extends SparkFunSuite { private var jobIdCounter = 0 private var stageIdCounter = 0 private val maxRetainedJobs = 10 diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala index 8778042e34657..37e2670de9685 100644 --- a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala @@ -17,8 +17,8 @@ package org.apache.spark.ui.storage -import org.scalatest.{BeforeAndAfter, FunSuite} -import org.apache.spark.Success +import org.scalatest.BeforeAndAfter +import org.apache.spark.{SparkFunSuite, Success} import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ import org.apache.spark.storage._ @@ -26,7 +26,7 @@ import org.apache.spark.storage._ /** * Test various functionality in the StorageListener that supports the StorageTab. */ -class StorageTabSuite extends FunSuite with BeforeAndAfter { +class StorageTabSuite extends SparkFunSuite with BeforeAndAfter { private var bus: LiveListenerBus = _ private var storageStatusListener: StorageStatusListener = _ private var storageListener: StorageListener = _ diff --git a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala index ccdb3f571429d..6c40685484ed4 100644 --- a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala @@ -20,7 +20,6 @@ package org.apache.spark.util import java.util.concurrent.TimeoutException import akka.actor.ActorNotFound -import org.scalatest.FunSuite import org.apache.spark._ import org.apache.spark.rpc.RpcEnv @@ -32,7 +31,7 @@ import org.apache.spark.SSLSampleConfigs._ /** * Test the AkkaUtils with various security settings. */ -class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemProperties { +class AkkaUtilsSuite extends SparkFunSuite with LocalSparkContext with ResetSystemProperties { test("remote fetch security bad password") { val conf = new SparkConf diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala index 7b165fe28bdd3..70cd27b04347d 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala @@ -20,14 +20,12 @@ package org.apache.spark.util import java.io.NotSerializableException import java.util.Random -import org.scalatest.FunSuite - import org.apache.spark.LocalSparkContext._ -import org.apache.spark.{TaskContext, SparkContext, SparkException} +import org.apache.spark.{SparkContext, SparkException, SparkFunSuite, TaskContext} import org.apache.spark.partial.CountEvaluator import org.apache.spark.rdd.RDD -class ClosureCleanerSuite extends FunSuite { +class ClosureCleanerSuite extends SparkFunSuite { test("closures inside an object") { assert(TestObject.run() === 30) // 6 + 7 + 8 + 9 } @@ -203,7 +201,7 @@ object TestObjectWithNestedReturns { def run(): Int = { withSpark(new SparkContext("local", "test")) { sc => val nums = sc.parallelize(Array(1, 2, 3, 4)) - nums.map {x => + nums.map {x => // this return is fine since it will not transfer control outside the closure def foo(): Int = { return 5; 1 } foo() diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala index 59456790e89f0..3147c937769d2 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala @@ -21,16 +21,16 @@ import java.io.NotSerializableException import scala.collection.mutable -import org.scalatest.{BeforeAndAfterAll, FunSuite, PrivateMethodTester} +import org.scalatest.{BeforeAndAfterAll, PrivateMethodTester} -import org.apache.spark.{SparkContext, SparkException} +import org.apache.spark.{SparkContext, SparkException, SparkFunSuite} import org.apache.spark.serializer.SerializerInstance /** * Another test suite for the closure cleaner that is finer-grained. * For tests involving end-to-end Spark jobs, see {{ClosureCleanerSuite}}. */ -class ClosureCleanerSuite2 extends FunSuite with BeforeAndAfterAll with PrivateMethodTester { +class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with PrivateMethodTester { // Start a SparkContext so that the closure serializer is accessible // We do not actually use this explicitly otherwise diff --git a/core/src/test/scala/org/apache/spark/util/CompletionIteratorSuite.scala b/core/src/test/scala/org/apache/spark/util/CompletionIteratorSuite.scala index 3755d43e25ea8..688fcd9f9aaba 100644 --- a/core/src/test/scala/org/apache/spark/util/CompletionIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/CompletionIteratorSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.util -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class CompletionIteratorSuite extends FunSuite { +class CompletionIteratorSuite extends SparkFunSuite { test("basic test") { var numTimesCompleted = 0 val iter = List(1, 2, 3).iterator diff --git a/core/src/test/scala/org/apache/spark/util/DistributionSuite.scala b/core/src/test/scala/org/apache/spark/util/DistributionSuite.scala index 090d48ec921a1..cdd6555697c23 100644 --- a/core/src/test/scala/org/apache/spark/util/DistributionSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/DistributionSuite.scala @@ -17,14 +17,15 @@ package org.apache.spark.util -import org.scalatest.FunSuite import org.scalatest.Matchers +import org.apache.spark.SparkFunSuite + /** * */ -class DistributionSuite extends FunSuite with Matchers { +class DistributionSuite extends SparkFunSuite with Matchers { test("summary") { val d = new Distribution((1 to 100).toArray.map{_.toDouble}) val stats = d.statCounter diff --git a/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala b/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala index 47b535206c949..b207d497f33c2 100644 --- a/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala @@ -25,9 +25,10 @@ import scala.language.postfixOps import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.Timeouts -import org.scalatest.FunSuite -class EventLoopSuite extends FunSuite with Timeouts { +import org.apache.spark.SparkFunSuite + +class EventLoopSuite extends SparkFunSuite with Timeouts { test("EventLoop") { val buffer = new mutable.ArrayBuffer[Int] with mutable.SynchronizedBuffer[Int] diff --git a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala index c05317534cddf..2b76ae1f8a24b 100644 --- a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala @@ -22,15 +22,15 @@ import java.io._ import scala.collection.mutable.HashSet import scala.reflect._ -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.{Logging, SparkConf, SparkFunSuite} import org.apache.spark.util.logging.{RollingFileAppender, SizeBasedRollingPolicy, TimeBasedRollingPolicy, FileAppender} -class FileAppenderSuite extends FunSuite with BeforeAndAfter with Logging { +class FileAppenderSuite extends SparkFunSuite with BeforeAndAfter with Logging { val testFile = new File(Utils.createTempDir(), "FileAppenderSuite-test").getAbsoluteFile diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 0d9126f23ccc5..e0ef9c70a5fc3 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -25,7 +25,6 @@ import org.apache.spark.shuffle.MetadataFetchFailedException import scala.collection.Map import org.json4s.jackson.JsonMethods._ -import org.scalatest.FunSuite import org.apache.spark._ import org.apache.spark.executor._ @@ -33,7 +32,7 @@ import org.apache.spark.rdd.RDDOperationScope import org.apache.spark.scheduler._ import org.apache.spark.storage._ -class JsonProtocolSuite extends FunSuite { +class JsonProtocolSuite extends SparkFunSuite { val jobSubmissionTime = 1421191042750L val jobCompletionTime = 1421191296660L diff --git a/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala b/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala index 87de90bb0dfb0..42125547436cb 100644 --- a/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala @@ -19,11 +19,9 @@ package org.apache.spark.util import java.net.URLClassLoader -import org.scalatest.FunSuite +import org.apache.spark.{SparkContext, SparkException, SparkFunSuite, TestUtils} -import org.apache.spark.{SparkContext, SparkException, TestUtils} - -class MutableURLClassLoaderSuite extends FunSuite { +class MutableURLClassLoaderSuite extends SparkFunSuite { val urls2 = List(TestUtils.createJarWithClasses( classNames = Seq("FakeClass1", "FakeClass2", "FakeClass3"), diff --git a/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala b/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala index 403dcb03bd6e5..4b7164d8acbce 100644 --- a/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala @@ -21,10 +21,11 @@ import java.util.NoSuchElementException import scala.collection.mutable.Buffer -import org.scalatest.FunSuite import org.scalatest.Matchers -class NextIteratorSuite extends FunSuite with Matchers { +import org.apache.spark.SparkFunSuite + +class NextIteratorSuite extends SparkFunSuite with Matchers { test("one iteration") { val i = new StubIterator(Buffer(1)) i.hasNext should be (true) diff --git a/core/src/test/scala/org/apache/spark/util/ResetSystemProperties.scala b/core/src/test/scala/org/apache/spark/util/ResetSystemProperties.scala index bad1aa99952cf..c58db5e606f7c 100644 --- a/core/src/test/scala/org/apache/spark/util/ResetSystemProperties.scala +++ b/core/src/test/scala/org/apache/spark/util/ResetSystemProperties.scala @@ -22,12 +22,14 @@ import java.util.Properties import org.apache.commons.lang3.SerializationUtils import org.scalatest.{BeforeAndAfterEach, Suite} +import org.apache.spark.SparkFunSuite + /** * Mixin for automatically resetting system properties that are modified in ScalaTest tests. * This resets the properties after each individual test. * * The order in which fixtures are mixed in affects the order in which they are invoked by tests. - * If we have a suite `MySuite extends FunSuite with Foo with Bar`, then + * If we have a suite `MySuite extends SparkFunSuite with Foo with Bar`, then * Bar's `super` is Foo, so Bar's beforeEach() will and afterEach() methods will be invoked first * by the rest runner. * diff --git a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala index 04f0f3749d6b9..20550178fb1bd 100644 --- a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala @@ -19,7 +19,9 @@ package org.apache.spark.util import scala.collection.mutable.ArrayBuffer -import org.scalatest.{BeforeAndAfterEach, BeforeAndAfterAll, FunSuite, PrivateMethodTester} +import org.scalatest.{BeforeAndAfterEach, BeforeAndAfterAll, PrivateMethodTester} + +import org.apache.spark.SparkFunSuite class DummyClass1 {} @@ -59,7 +61,10 @@ class DummyString(val arr: Array[Char]) { } class SizeEstimatorSuite - extends FunSuite with BeforeAndAfterEach with PrivateMethodTester with ResetSystemProperties { + extends SparkFunSuite + with BeforeAndAfterEach + with PrivateMethodTester + with ResetSystemProperties { override def beforeEach() { // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case diff --git a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala index 751d3df9cc8f7..8c51e6b14b7fc 100644 --- a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala @@ -23,9 +23,9 @@ import java.util.concurrent.{CountDownLatch, TimeUnit} import scala.concurrent.{Await, Future} import scala.concurrent.duration._ -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class ThreadUtilsSuite extends FunSuite { +class ThreadUtilsSuite extends SparkFunSuite { test("newDaemonSingleThreadExecutor") { val executor = ThreadUtils.newDaemonSingleThreadExecutor("this-is-a-thread-name") diff --git a/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala index 8b72fe665c214..9b3169026cda3 100644 --- a/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala @@ -23,9 +23,9 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.util.Random -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class TimeStampedHashMapSuite extends FunSuite { +class TimeStampedHashMapSuite extends SparkFunSuite { // Test the testMap function - a Scala HashMap should obviously pass testMap(new mutable.HashMap[String, String]()) diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index afa5cdc819746..a867cf83dc3f1 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -29,16 +29,15 @@ import scala.util.Random import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files -import org.scalatest.FunSuite import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.spark.network.util.ByteUnit -import org.apache.spark.Logging +import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.SparkConf -class UtilsSuite extends FunSuite with ResetSystemProperties with Logging { +class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { test("timeConversion") { // Test -1 diff --git a/core/src/test/scala/org/apache/spark/util/VectorSuite.scala b/core/src/test/scala/org/apache/spark/util/VectorSuite.scala index ce2968728a996..11194cd22a419 100644 --- a/core/src/test/scala/org/apache/spark/util/VectorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/VectorSuite.scala @@ -19,13 +19,13 @@ package org.apache.spark.util import scala.util.Random -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite /** * Tests org.apache.spark.util.Vector functionality */ @deprecated("suppress compile time deprecation warning", "1.0.0") -class VectorSuite extends FunSuite { +class VectorSuite extends SparkFunSuite { def verifyVector(vector: Vector, expectedLength: Int): Unit = { assert(vector.length == expectedLength) diff --git a/core/src/test/scala/org/apache/spark/util/collection/AppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/AppendOnlyMapSuite.scala index cb99d14b27af4..a2a6d703860f2 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/AppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/AppendOnlyMapSuite.scala @@ -21,9 +21,9 @@ import java.util.Comparator import scala.collection.mutable.HashSet -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class AppendOnlyMapSuite extends FunSuite { +class AppendOnlyMapSuite extends SparkFunSuite { test("initialization") { val goodMap1 = new AppendOnlyMap[Int, Int](1) assert(goodMap1.size === 0) diff --git a/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala index ffc206991906a..69dbfa9cd7141 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.util.collection -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class BitSetSuite extends FunSuite { +class BitSetSuite extends SparkFunSuite { test("basic set and get") { val setBits = Seq(0, 9, 1, 10, 90, 96) diff --git a/core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala index c0c38cd4ac4ad..05306f408847d 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala @@ -19,10 +19,11 @@ package org.apache.spark.util.collection import java.nio.ByteBuffer -import org.scalatest.FunSuite import org.scalatest.Matchers._ -class ChainedBufferSuite extends FunSuite { +import org.apache.spark.SparkFunSuite + +class ChainedBufferSuite extends SparkFunSuite { test("write and read at start") { // write from start of source array val buffer = new ChainedBuffer(8) diff --git a/core/src/test/scala/org/apache/spark/util/collection/CompactBufferSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/CompactBufferSuite.scala index 6c956d93dc80d..bc5479991a99d 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/CompactBufferSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/CompactBufferSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.util.collection -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class CompactBufferSuite extends FunSuite { +class CompactBufferSuite extends SparkFunSuite { test("empty buffer") { val b = new CompactBuffer[Int] assert(b.size === 0) diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala index dff8f3ddc816f..79eba61a87251 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala @@ -19,12 +19,10 @@ package org.apache.spark.util.collection import scala.collection.mutable.ArrayBuffer -import org.scalatest.FunSuite - import org.apache.spark._ import org.apache.spark.io.CompressionCodec -class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { +class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { private val allCompressionCodecs = CompressionCodec.ALL_COMPRESSION_CODECS private def createCombiner[T](i: T) = ArrayBuffer[T](i) private def mergeValue[T](buffer: ArrayBuffer[T], i: T): ArrayBuffer[T] = buffer += i diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala index 7a98723bc6472..9cefa612f5491 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -19,14 +19,12 @@ package org.apache.spark.util.collection import scala.collection.mutable.ArrayBuffer -import org.scalatest.{FunSuite, PrivateMethodTester} - import scala.util.Random import org.apache.spark._ import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} -class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMethodTester { +class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { private def createSparkConf(loadDefaults: Boolean, kryo: Boolean): SparkConf = { val conf = new SparkConf(loadDefaults) if (kryo) { @@ -37,21 +35,12 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe conf.set("spark.serializer.objectStreamReset", "1") conf.set("spark.serializer", classOf[JavaSerializer].getName) } + conf.set("spark.shuffle.sort.bypassMergeThreshold", "0") // Ensure that we actually have multiple batches per spill file conf.set("spark.shuffle.spill.batchSize", "10") conf } - private def assertBypassedMergeSort(sorter: ExternalSorter[_, _, _]): Unit = { - val bypassMergeSort = PrivateMethod[Boolean]('bypassMergeSort) - assert(sorter.invokePrivate(bypassMergeSort()), "sorter did not bypass merge-sort") - } - - private def assertDidNotBypassMergeSort(sorter: ExternalSorter[_, _, _]): Unit = { - val bypassMergeSort = PrivateMethod[Boolean]('bypassMergeSort) - assert(!sorter.invokePrivate(bypassMergeSort()), "sorter bypassed merge-sort") - } - test("empty data stream with kryo ser") { emptyDataStream(createSparkConf(false, true)) } @@ -161,39 +150,6 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe val sorter = new ExternalSorter[Int, Int, Int]( None, Some(new HashPartitioner(7)), Some(ord), None) - assertDidNotBypassMergeSort(sorter) - sorter.insertAll(elements) - assert(sc.env.blockManager.diskBlockManager.getAllFiles().length > 0) // Make sure it spilled - val iter = sorter.partitionedIterator.map(p => (p._1, p._2.toList)) - assert(iter.next() === (0, Nil)) - assert(iter.next() === (1, List((1, 1)))) - assert(iter.next() === (2, (0 until 100000).map(x => (2, 2)).toList)) - assert(iter.next() === (3, Nil)) - assert(iter.next() === (4, Nil)) - assert(iter.next() === (5, List((5, 5)))) - assert(iter.next() === (6, Nil)) - sorter.stop() - } - - test("empty partitions with spilling, bypass merge-sort with kryo ser") { - emptyPartitionerWithSpillingBypassMergeSort(createSparkConf(false, true)) - } - - test("empty partitions with spilling, bypass merge-sort with java ser") { - emptyPartitionerWithSpillingBypassMergeSort(createSparkConf(false, false)) - } - - def emptyPartitionerWithSpillingBypassMergeSort(conf: SparkConf) { - conf.set("spark.shuffle.memoryFraction", "0.001") - conf.set("spark.shuffle.spill.initialMemoryThreshold", "512") - conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") - sc = new SparkContext("local", "test", conf) - - val elements = Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2)) - - val sorter = new ExternalSorter[Int, Int, Int]( - None, Some(new HashPartitioner(7)), None, None) - assertBypassedMergeSort(sorter) sorter.insertAll(elements) assert(sc.env.blockManager.diskBlockManager.getAllFiles().length > 0) // Make sure it spilled val iter = sorter.partitionedIterator.map(p => (p._1, p._2.toList)) @@ -376,7 +332,6 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe val sorter = new ExternalSorter[Int, Int, Int]( None, Some(new HashPartitioner(3)), Some(ord), None) - assertDidNotBypassMergeSort(sorter) sorter.insertAll((0 until 120000).iterator.map(i => (i, i))) assert(diskBlockManager.getAllFiles().length > 0) sorter.stop() @@ -384,7 +339,6 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe val sorter2 = new ExternalSorter[Int, Int, Int]( None, Some(new HashPartitioner(3)), Some(ord), None) - assertDidNotBypassMergeSort(sorter2) sorter2.insertAll((0 until 120000).iterator.map(i => (i, i))) assert(diskBlockManager.getAllFiles().length > 0) assert(sorter2.iterator.toSet === (0 until 120000).map(i => (i, i)).toSet) @@ -392,29 +346,6 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe assert(diskBlockManager.getAllBlocks().length === 0) } - test("cleanup of intermediate files in sorter, bypass merge-sort") { - val conf = createSparkConf(true, false) // Load defaults, otherwise SPARK_HOME is not found - conf.set("spark.shuffle.memoryFraction", "0.001") - conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") - sc = new SparkContext("local", "test", conf) - val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager - - val sorter = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None) - assertBypassedMergeSort(sorter) - sorter.insertAll((0 until 100000).iterator.map(i => (i, i))) - assert(diskBlockManager.getAllFiles().length > 0) - sorter.stop() - assert(diskBlockManager.getAllBlocks().length === 0) - - val sorter2 = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None) - assertBypassedMergeSort(sorter2) - sorter2.insertAll((0 until 100000).iterator.map(i => (i, i))) - assert(diskBlockManager.getAllFiles().length > 0) - assert(sorter2.iterator.toSet === (0 until 100000).map(i => (i, i)).toSet) - sorter2.stop() - assert(diskBlockManager.getAllBlocks().length === 0) - } - test("cleanup of intermediate files in sorter if there are errors") { val conf = createSparkConf(true, false) // Load defaults, otherwise SPARK_HOME is not found conf.set("spark.shuffle.memoryFraction", "0.001") @@ -426,7 +357,6 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe val sorter = new ExternalSorter[Int, Int, Int]( None, Some(new HashPartitioner(3)), Some(ord), None) - assertDidNotBypassMergeSort(sorter) intercept[SparkException] { sorter.insertAll((0 until 120000).iterator.map(i => { if (i == 119990) { @@ -440,28 +370,6 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe assert(diskBlockManager.getAllBlocks().length === 0) } - test("cleanup of intermediate files in sorter if there are errors, bypass merge-sort") { - val conf = createSparkConf(true, false) // Load defaults, otherwise SPARK_HOME is not found - conf.set("spark.shuffle.memoryFraction", "0.001") - conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") - sc = new SparkContext("local", "test", conf) - val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager - - val sorter = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None) - assertBypassedMergeSort(sorter) - intercept[SparkException] { - sorter.insertAll((0 until 100000).iterator.map(i => { - if (i == 99990) { - throw new SparkException("Intentional failure") - } - (i, i) - })) - } - assert(diskBlockManager.getAllFiles().length > 0) - sorter.stop() - assert(diskBlockManager.getAllBlocks().length === 0) - } - test("cleanup of intermediate files in shuffle") { val conf = createSparkConf(false, false) conf.set("spark.shuffle.memoryFraction", "0.001") @@ -776,40 +684,6 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe } } - test("conditions for bypassing merge-sort") { - val conf = createSparkConf(false, false) - conf.set("spark.shuffle.memoryFraction", "0.001") - conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") - sc = new SparkContext("local", "test", conf) - - val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j) - val ord = implicitly[Ordering[Int]] - - // Numbers of partitions that are above and below the default bypassMergeThreshold - val FEW_PARTITIONS = 50 - val MANY_PARTITIONS = 10000 - - // Sorters with no ordering or aggregator: should bypass unless # of partitions is high - - val sorter1 = new ExternalSorter[Int, Int, Int]( - None, Some(new HashPartitioner(FEW_PARTITIONS)), None, None) - assertBypassedMergeSort(sorter1) - - val sorter2 = new ExternalSorter[Int, Int, Int]( - None, Some(new HashPartitioner(MANY_PARTITIONS)), None, None) - assertDidNotBypassMergeSort(sorter2) - - // Sorters with an ordering or aggregator: should not bypass even if they have few partitions - - val sorter3 = new ExternalSorter[Int, Int, Int]( - None, Some(new HashPartitioner(FEW_PARTITIONS)), Some(ord), None) - assertDidNotBypassMergeSort(sorter3) - - val sorter4 = new ExternalSorter[Int, Int, Int]( - Some(agg), Some(new HashPartitioner(FEW_PARTITIONS)), None, None) - assertDidNotBypassMergeSort(sorter4) - } - test("sort without breaking sorting contracts with kryo ser") { sortWithoutBreakingSortingContracts(createSparkConf(true, true)) } diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala index ef890d2ba60f3..94e011799921b 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala @@ -19,12 +19,12 @@ package org.apache.spark.util.collection import scala.collection.mutable.HashSet -import org.scalatest.FunSuite import org.scalatest.Matchers +import org.apache.spark.SparkFunSuite import org.apache.spark.util.SizeEstimator -class OpenHashMapSuite extends FunSuite with Matchers { +class OpenHashMapSuite extends SparkFunSuite with Matchers { test("size for specialized, primitive value (int)") { val capacity = 1024 diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala index 68a03e3a0970f..2607a543dd614 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala @@ -17,12 +17,12 @@ package org.apache.spark.util.collection -import org.scalatest.FunSuite import org.scalatest.Matchers +import org.apache.spark.SparkFunSuite import org.apache.spark.util.SizeEstimator -class OpenHashSetSuite extends FunSuite with Matchers { +class OpenHashSetSuite extends SparkFunSuite with Matchers { test("size for specialized, primitive int") { val loadFactor = 0.7 diff --git a/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala index b5a2d9ef720c1..6d2459d48d326 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala @@ -21,14 +21,13 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, InputStream} import com.google.common.io.ByteStreams -import org.scalatest.FunSuite import org.scalatest.Matchers._ -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.KryoSerializer import org.apache.spark.storage.{FileSegment, BlockObjectWriter} -class PartitionedSerializedPairBufferSuite extends FunSuite { +class PartitionedSerializedPairBufferSuite extends SparkFunSuite { test("OrderedInputStream single record") { val serializerInstance = new KryoSerializer(new SparkConf()).newInstance diff --git a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala index caf378fec8b3e..462bc2f29f9f8 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala @@ -19,12 +19,12 @@ package org.apache.spark.util.collection import scala.collection.mutable.HashSet -import org.scalatest.FunSuite import org.scalatest.Matchers +import org.apache.spark.SparkFunSuite import org.apache.spark.util.SizeEstimator -class PrimitiveKeyOpenHashMapSuite extends FunSuite with Matchers { +class PrimitiveKeyOpenHashMapSuite extends SparkFunSuite with Matchers { test("size for specialized, primitive key, value (int, int)") { val capacity = 1024 diff --git a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveVectorSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveVectorSuite.scala index 970dade628fe4..ae0eebc26f01b 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveVectorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveVectorSuite.scala @@ -17,11 +17,10 @@ package org.apache.spark.util.collection -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.util.SizeEstimator -class PrimitiveVectorSuite extends FunSuite { +class PrimitiveVectorSuite extends SparkFunSuite { test("primitive value") { val vector = new PrimitiveVector[Int] diff --git a/core/src/test/scala/org/apache/spark/util/collection/SizeTrackerSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/SizeTrackerSuite.scala index 1f33967249654..5a5919fca2469 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/SizeTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/SizeTrackerSuite.scala @@ -20,11 +20,10 @@ package org.apache.spark.util.collection import scala.reflect.ClassTag import scala.util.Random -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.util.SizeEstimator -class SizeTrackerSuite extends FunSuite { +class SizeTrackerSuite extends SparkFunSuite { val NORMAL_ERROR = 0.20 val HIGH_ERROR = 0.30 diff --git a/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala index e0d6cc16bde05..72fd6daba8de0 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala @@ -20,11 +20,10 @@ package org.apache.spark.util.collection import java.lang.{Float => JFloat, Integer => JInteger} import java.util.{Arrays, Comparator} -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.util.random.XORShiftRandom -class SorterSuite extends FunSuite { +class SorterSuite extends SparkFunSuite { test("equivalent to Arrays.sort") { val rand = new XORShiftRandom(123) diff --git a/core/src/test/scala/org/apache/spark/util/io/ByteArrayChunkOutputStreamSuite.scala b/core/src/test/scala/org/apache/spark/util/io/ByteArrayChunkOutputStreamSuite.scala index f855831b8e367..361ec95654f47 100644 --- a/core/src/test/scala/org/apache/spark/util/io/ByteArrayChunkOutputStreamSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/io/ByteArrayChunkOutputStreamSuite.scala @@ -19,10 +19,10 @@ package org.apache.spark.util.io import scala.util.Random -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class ByteArrayChunkOutputStreamSuite extends FunSuite { +class ByteArrayChunkOutputStreamSuite extends SparkFunSuite { test("empty output") { val o = new ByteArrayChunkOutputStream(1024) diff --git a/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala b/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala index 20944b62473c5..d6af0aebde733 100644 --- a/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala @@ -21,9 +21,11 @@ import java.util.Random import scala.collection.mutable.ArrayBuffer import org.apache.commons.math3.distribution.PoissonDistribution -import org.scalatest.{FunSuite, Matchers} +import org.scalatest.Matchers -class RandomSamplerSuite extends FunSuite with Matchers { +import org.apache.spark.SparkFunSuite + +class RandomSamplerSuite extends SparkFunSuite with Matchers { /** * My statistical testing methodology is to run a Kolmogorov-Smirnov (KS) test * between the random samplers and simple reference samplers (known to work correctly). @@ -76,7 +78,7 @@ class RandomSamplerSuite extends FunSuite with Matchers { } // Returns iterator over gap lengths between samples. - // This function assumes input data is integers sampled from the sequence of + // This function assumes input data is integers sampled from the sequence of // increasing integers: {0, 1, 2, ...}. This works because that is how I generate them, // and the samplers preserve their input order def gaps(data: Iterator[Int]): Iterator[Int] = { diff --git a/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala index 73a9d029b0248..667a4db6f7bb6 100644 --- a/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala @@ -20,9 +20,10 @@ package org.apache.spark.util.random import scala.util.Random import org.apache.commons.math3.distribution.{BinomialDistribution, PoissonDistribution} -import org.scalatest.FunSuite -class SamplingUtilsSuite extends FunSuite { +import org.apache.spark.SparkFunSuite + +class SamplingUtilsSuite extends SparkFunSuite { test("reservoirSampleAndCount") { val input = Seq.fill(100)(Random.nextInt()) diff --git a/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala index 03f5f2d1b8528..d26667bf720cf 100644 --- a/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala @@ -17,18 +17,18 @@ package org.apache.spark.util.random -import org.scalatest.FunSuite import org.scalatest.Matchers import org.apache.commons.math3.stat.inference.ChiSquareTest +import org.apache.spark.SparkFunSuite import org.apache.spark.util.Utils.times import scala.language.reflectiveCalls -class XORShiftRandomSuite extends FunSuite with Matchers { +class XORShiftRandomSuite extends SparkFunSuite with Matchers { - def fixture: Object {val seed: Long; val hundMil: Int; val xorRand: XORShiftRandom} = new { + private def fixture = new { val seed = 1L val xorRand = new XORShiftRandom(seed) val hundMil = 1e8.toInt diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh index 54274a83f6d66..0b14a618e755c 100755 --- a/dev/create-release/create-release.sh +++ b/dev/create-release/create-release.sh @@ -228,14 +228,14 @@ if [[ ! "$@" =~ --skip-package ]]; then # We increment the Zinc port each time to avoid OOM's and other craziness if multiple builds # share the same Zinc server. - make_binary_release "hadoop1" "-Psparkr -Phadoop-1 -Phive -Phive-thriftserver" "3030" & - make_binary_release "hadoop1-scala2.11" "-Psparkr -Phadoop-1 -Phive -Dscala-2.11" "3031" & - make_binary_release "cdh4" "-Psparkr -Phadoop-1 -Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0" "3032" & - make_binary_release "hadoop2.3" "-Psparkr -Phadoop-2.3 -Phive -Phive-thriftserver -Pyarn" "3033" & - make_binary_release "hadoop2.4" "-Psparkr -Phadoop-2.4 -Phive -Phive-thriftserver -Pyarn" "3034" & - make_binary_release "mapr3" "-Pmapr3 -Psparkr -Phive -Phive-thriftserver" "3035" & - make_binary_release "mapr4" "-Pmapr4 -Psparkr -Pyarn -Phive -Phive-thriftserver" "3036" & - make_binary_release "hadoop2.4-without-hive" "-Psparkr -Phadoop-2.4 -Pyarn" "3037" & + make_binary_release "hadoop1" "-Psparkr -Psparkr-docs -Phadoop-1 -Phive -Phive-thriftserver" "3030" & + make_binary_release "hadoop1-scala2.11" "-Psparkr -Psparkr-docs -Phadoop-1 -Phive -Dscala-2.11" "3031" & + make_binary_release "cdh4" "-Psparkr -Psparkr-docs -Phadoop-1 -Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0" "3032" & + make_binary_release "hadoop2.3" "-Psparkr -Psparkr-docs -Phadoop-2.3 -Phive -Phive-thriftserver -Pyarn" "3033" & + make_binary_release "hadoop2.4" "-Psparkr -Psparkr-docs -Phadoop-2.4 -Phive -Phive-thriftserver -Pyarn" "3034" & + make_binary_release "mapr3" "-Pmapr3 -Psparkr -Psparkr-docs -Phive -Phive-thriftserver" "3035" & + make_binary_release "mapr4" "-Pmapr4 -Psparkr -Psparkr-docs -Pyarn -Phive -Phive-thriftserver" "3036" & + make_binary_release "hadoop2.4-without-hive" "-Psparkr -Psparkr-docs -Phadoop-2.4 -Pyarn" "3037" & wait rm -rf spark-$RELEASE_VERSION-bin-*/ diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index b92c75f90b11c..eebb3faf90fc0 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -75,6 +75,7 @@
  • MLlib (Machine Learning)
  • GraphX (Graph Processing)
  • Bagel (Pregel on Spark)
  • +
  • SparkR (R on Spark)
  • diff --git a/docs/building-spark.md b/docs/building-spark.md index b2649d1ee2a53..78cb9086f95e8 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -80,6 +80,7 @@ Because HDFS is not protocol-compatible across versions, if you want to read fro 2.2.xhadoop-2.2 2.3.xhadoop-2.3 2.4.xhadoop-2.4 + 2.6.x and later 2.xhadoop-2.6 @@ -130,9 +131,7 @@ To produce a Spark package compiled with Scala 2.11, use the `-Dscala-2.11` prop dev/change-version-to-2.11.sh mvn -Pyarn -Phadoop-2.4 -Dscala-2.11 -DskipTests clean package -Scala 2.11 support in Spark does not support a few features due to dependencies -which are themselves not Scala 2.11 ready. Specifically, Spark's external -Kafka library and JDBC component are not yet supported in Scala 2.11 builds. +Spark does not yet support its JDBC component for Scala 2.11. # Spark Tests in Maven diff --git a/docs/configuration.md b/docs/configuration.md index 30508a617fdd8..3a48da4592dd9 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1,4 +1,4 @@ --- +--- layout: global displayTitle: Spark Configuration title: Configuration diff --git a/docs/index.md b/docs/index.md index 5ef6d983c45a5..fac071da81e60 100644 --- a/docs/index.md +++ b/docs/index.md @@ -54,7 +54,7 @@ Example applications are also provided in Python. For example, ./bin/spark-submit examples/src/main/python/pi.py 10 -Spark also provides an experimental R API since 1.4 (only DataFrames APIs included). +Spark also provides an experimental [R API](sparkr.html) since 1.4 (only DataFrames APIs included). To run Spark interactively in a R interpreter, use `bin/sparkR`: ./bin/sparkR --master local[2] diff --git a/docs/ml-features.md b/docs/ml-features.md index d7851a55fabfe..f88c0248c1a8a 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -456,6 +456,122 @@ for expanded in polyDF.select("polyFeatures").take(3): +## StringIndexer + +`StringIndexer` encodes a string column of labels to a column of label indices. +The indices are in `[0, numLabels)`, ordered by label frequencies. +So the most frequent label gets index `0`. +If the input column is numeric, we cast it to string and index the string values. + +**Examples** + +Assume that we have the following DataFrame with columns `id` and `category`: + +~~~~ + id | category +----|---------- + 0 | a + 1 | b + 2 | c + 3 | a + 4 | a + 5 | c +~~~~ + +`category` is a string column with three labels: "a", "b", and "c". +Applying `StringIndexer` with `category` as the input column and `categoryIndex` as the output +column, we should get the following: + +~~~~ + id | category | categoryIndex +----|----------|--------------- + 0 | a | 0.0 + 1 | b | 2.0 + 2 | c | 1.0 + 3 | a | 0.0 + 4 | a | 0.0 + 5 | c | 1.0 +~~~~ + +"a" gets index `0` because it is the most frequent, followed by "c" with index `1` and "b" with +index `2`. + +
    + +
    + +[`StringIndexer`](api/scala/index.html#org.apache.spark.ml.feature.StringIndexer) takes an input +column name and an output column name. + +{% highlight scala %} +import org.apache.spark.ml.feature.StringIndexer + +val df = sqlContext.createDataFrame( + Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")) +).toDF("id", "category") +val indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex") +val indexed = indexer.fit(df).transform(df) +indexed.show() +{% endhighlight %} +
    + +
    +[`StringIndexer`](api/java/org/apache/spark/ml/feature/StringIndexer.html) takes an input column +name and an output column name. + +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.StringIndexer; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import static org.apache.spark.sql.types.DataTypes.*; + +JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, "a"), + RowFactory.create(1, "b"), + RowFactory.create(2, "c"), + RowFactory.create(3, "a"), + RowFactory.create(4, "a"), + RowFactory.create(5, "c") +)); +StructType schema = new StructType(new StructField[] { + createStructField("id", DoubleType, false), + createStructField("category", StringType, false) +}); +DataFrame df = sqlContext.createDataFrame(jrdd, schema); +StringIndexer indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex"); +DataFrame indexed = indexer.fit(df).transform(df); +indexed.show(); +{% endhighlight %} +
    + +
    + +[`StringIndexer`](api/python/pyspark.ml.html#pyspark.ml.feature.StringIndexer) takes an input +column name and an output column name. + +{% highlight python %} +from pyspark.ml.feature import StringIndexer + +df = sqlContext.createDataFrame( + [(0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")], + ["id", "category"]) +indexer = StringIndexer(inputCol="category", outputCol="categoryIndex") +indexed = indexer.fit(df).transform(df) +indexed.show() +{% endhighlight %} +
    +
    + ## OneHotEncoder [One-hot encoding](http://en.wikipedia.org/wiki/One-hot) maps a column of label indices to a column of binary vectors, with at most a single one-value. This encoding allows algorithms which expect continuous features, such as Logistic Regression, to use categorical features @@ -876,5 +992,207 @@ bucketedData = bucketizer.transform(dataFrame) +## ElementwiseProduct + +ElementwiseProduct multiplies each input vector by a provided "weight" vector, using element-wise multiplication. In other words, it scales each column of the dataset by a scalar multiplier. This represents the [Hadamard product](https://en.wikipedia.org/wiki/Hadamard_product_%28matrices%29) between the input vector, `v` and transforming vector, `w`, to yield a result vector. + +`\[ \begin{pmatrix} +v_1 \\ +\vdots \\ +v_N +\end{pmatrix} \circ \begin{pmatrix} + w_1 \\ + \vdots \\ + w_N + \end{pmatrix} += \begin{pmatrix} + v_1 w_1 \\ + \vdots \\ + v_N w_N + \end{pmatrix} +\]` + +[`ElementwiseProduct`](api/scala/index.html#org.apache.spark.ml.feature.ElementwiseProduct) takes the following parameter: + +* `scalingVec`: the transforming vector. + +This example below demonstrates how to transform vectors using a transforming vector value. + +
    +
    +{% highlight scala %} +import org.apache.spark.ml.feature.ElementwiseProduct +import org.apache.spark.mllib.linalg.Vectors + +// Create some vector data; also works for sparse vectors +val dataFrame = sqlContext.createDataFrame(Seq( + ("a", Vectors.dense(1.0, 2.0, 3.0)), + ("b", Vectors.dense(4.0, 5.0, 6.0)))).toDF("id", "vector") + +val transformingVector = Vectors.dense(0.0, 1.0, 2.0) +val transformer = new ElementwiseProduct() + .setScalingVec(transformingVector) + .setInputCol("vector") + .setOutputCol("transformedVector") + +// Batch transform the vectors to create new column: +val transformedData = transformer.transform(dataFrame) + +{% endhighlight %} +
    + +
    +{% highlight java %} +import com.google.common.collect.Lists; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.ElementwiseProduct; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +// Create some vector data; also works for sparse vectors +JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( + RowFactory.create("a", Vectors.dense(1.0, 2.0, 3.0)), + RowFactory.create("b", Vectors.dense(4.0, 5.0, 6.0)) +)); +List fields = new ArrayList(2); +fields.add(DataTypes.createStructField("id", DataTypes.StringType, false)); +fields.add(DataTypes.createStructField("vector", DataTypes.StringType, false)); +StructType schema = DataTypes.createStructType(fields); +DataFrame dataFrame = sqlContext.createDataFrame(jrdd, schema); +Vector transformingVector = Vectors.dense(0.0, 1.0, 2.0); +ElementwiseProduct transformer = new ElementwiseProduct() + .setScalingVec(transformingVector) + .setInputCol("vector") + .setOutputCol("transformedVector"); +// Batch transform the vectors to create new column: +DataFrame transformedData = transformer.transform(dataFrame); + +{% endhighlight %} +
    +
    + +## VectorAssembler + +`VectorAssembler` is a transformer that combines a given list of columns into a single vector +column. +It is useful for combining raw features and features generated by different feature transformers +into a single feature vector, in order to train ML models like logistic regression and decision +trees. +`VectorAssembler` accepts the following input column types: all numeric types, boolean type, +and vector type. +In each row, the values of the input columns will be concatenated into a vector in the specified +order. + +**Examples** + +Assume that we have a DataFrame with the columns `id`, `hour`, `mobile`, `userFeatures`, +and `clicked`: + +~~~ + id | hour | mobile | userFeatures | clicked +----|------|--------|------------------|--------- + 0 | 18 | 1.0 | [0.0, 10.0, 0.5] | 1.0 +~~~ + +`userFeatures` is a vector column that contains three user features. +We want to combine `hour`, `mobile`, and `userFeatures` into a single feature vector +called `features` and use it to predict `clicked` or not. +If we set `VectorAssembler`'s input columns to `hour`, `mobile`, and `userFeatures` and +output column to `features`, after transformation we should get the following DataFrame: + +~~~ + id | hour | mobile | userFeatures | clicked | features +----|------|--------|------------------|---------|----------------------------- + 0 | 18 | 1.0 | [0.0, 10.0, 0.5] | 1.0 | [18.0, 1.0, 0.0, 10.0, 0.5] +~~~ + +
    +
    + +[`VectorAssembler`](api/scala/index.html#org.apache.spark.ml.feature.VectorAssembler) takes an array +of input column names and an output column name. + +{% highlight scala %} +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.ml.feature.VectorAssembler + +val dataset = sqlContext.createDataFrame( + Seq((0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0)) +).toDF("id", "hour", "mobile", "userFeatures", "clicked") +val assembler = new VectorAssembler() + .setInputCols(Array("hour", "mobile", "userFeatures")) + .setOutputCol("features") +val output = assembler.transform(dataset) +println(output.select("features", "clicked").first()) +{% endhighlight %} +
    + +
    + +[`VectorAssembler`](api/java/org/apache/spark/ml/feature/VectorAssembler.html) takes an array +of input column names and an output column name. + +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.*; +import static org.apache.spark.sql.types.DataTypes.*; + +StructType schema = createStructType(new StructField[] { + createStructField("id", IntegerType, false), + createStructField("hour", IntegerType, false), + createStructField("mobile", DoubleType, false), + createStructField("userFeatures", new VectorUDT(), false), + createStructField("clicked", DoubleType, false) +}); +Row row = RowFactory.create(0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0); +JavaRDD rdd = jsc.parallelize(Arrays.asList(row)); +DataFrame dataset = sqlContext.createDataFrame(rdd, schema); + +VectorAssembler assembler = new VectorAssembler() + .setInputCols(new String[] {"hour", "mobile", "userFeatures"}) + .setOutputCol("features"); + +DataFrame output = assembler.transform(dataset); +System.out.println(output.select("features", "clicked").first()); +{% endhighlight %} +
    + +
    + +[`VectorAssembler`](api/python/pyspark.ml.html#pyspark.ml.feature.VectorAssembler) takes a list +of input column names and an output column name. + +{% highlight python %} +from pyspark.mllib.linalg import Vectors +from pyspark.ml.feature import VectorAssembler + +dataset = sqlContext.createDataFrame( + [(0, 18, 1.0, Vectors.dense([0.0, 10.0, 0.5]), 1.0)], + ["id", "hour", "mobile", "userFeatures", "clicked"]) +assembler = VectorAssembler( + inputCols=["hour", "mobile", "userFeatures"], + outputCol="features") +output = assembler.transform(dataset) +print(output.select("features", "clicked").first()) +{% endhighlight %} +
    +
    + # Feature Selectors diff --git a/docs/ml-guide.md b/docs/ml-guide.md index c5f50ed7990f1..4eb622d4b95e8 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -207,7 +207,7 @@ val model1 = lr.fit(training.toDF) // we can view the parameters it used during fit(). // This prints the parameter (name: value) pairs, where names are unique IDs for this // LogisticRegression instance. -println("Model 1 was fit using parameters: " + model1.fittingParamMap) +println("Model 1 was fit using parameters: " + model1.parent.extractParamMap) // We may alternatively specify parameters using a ParamMap, // which supports several methods for specifying parameters. @@ -222,7 +222,7 @@ val paramMapCombined = paramMap ++ paramMap2 // Now learn a new model using the paramMapCombined parameters. // paramMapCombined overrides all parameters set earlier via lr.set* methods. val model2 = lr.fit(training.toDF, paramMapCombined) -println("Model 2 was fit using parameters: " + model2.fittingParamMap) +println("Model 2 was fit using parameters: " + model2.parent.extractParamMap) // Prepare test data. val test = sc.parallelize(Seq( @@ -289,7 +289,7 @@ LogisticRegressionModel model1 = lr.fit(training); // we can view the parameters it used during fit(). // This prints the parameter (name: value) pairs, where names are unique IDs for this // LogisticRegression instance. -System.out.println("Model 1 was fit using parameters: " + model1.fittingParamMap()); +System.out.println("Model 1 was fit using parameters: " + model1.parent().extractParamMap()); // We may alternatively specify parameters using a ParamMap. ParamMap paramMap = new ParamMap(); @@ -305,7 +305,7 @@ ParamMap paramMapCombined = paramMap.$plus$plus(paramMap2); // Now learn a new model using the paramMapCombined parameters. // paramMapCombined overrides all parameters set earlier via lr.set* methods. LogisticRegressionModel model2 = lr.fit(training, paramMapCombined); -System.out.println("Model 2 was fit using parameters: " + model2.fittingParamMap()); +System.out.println("Model 2 was fit using parameters: " + model2.parent().extractParamMap()); // Prepare test documents. List localTest = Lists.newArrayList( diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index f41ca70952eb7..dac22f736e8cb 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -47,7 +47,7 @@ Set Sum of Squared Error (WSSSE). You can reduce this error measure by increasin optimal *k* is usually one where there is an "elbow" in the WSSSE graph. {% highlight scala %} -import org.apache.spark.mllib.clustering.KMeans +import org.apache.spark.mllib.clustering.{KMeans, KMeansModel} import org.apache.spark.mllib.linalg.Vectors // Load and parse the data @@ -62,6 +62,10 @@ val clusters = KMeans.train(parsedData, numClusters, numIterations) // Evaluate clustering by computing Within Set Sum of Squared Errors val WSSSE = clusters.computeCost(parsedData) println("Within Set Sum of Squared Errors = " + WSSSE) + +// Save and load model +clusters.save(sc, "myModelPath") +val sameModel = KMeansModel.load(sc, "myModelPath") {% endhighlight %} @@ -110,6 +114,10 @@ public class KMeansExample { // Evaluate clustering by computing Within Set Sum of Squared Errors double WSSSE = clusters.computeCost(parsedData.rdd()); System.out.println("Within Set Sum of Squared Errors = " + WSSSE); + + // Save and load model + clusters.save(sc.sc(), "myModelPath"); + KMeansModel sameModel = KMeansModel.load(sc.sc(), "myModelPath"); } } {% endhighlight %} @@ -124,7 +132,7 @@ Within Set Sum of Squared Error (WSSSE). You can reduce this error measure by in fact the optimal *k* is usually one where there is an "elbow" in the WSSSE graph. {% highlight python %} -from pyspark.mllib.clustering import KMeans +from pyspark.mllib.clustering import KMeans, KMeansModel from numpy import array from math import sqrt @@ -143,6 +151,10 @@ def error(point): WSSSE = parsedData.map(lambda point: error(point)).reduce(lambda x, y: x + y) print("Within Set Sum of Squared Error = " + str(WSSSE)) + +# Save and load model +clusters.save(sc, "myModelPath") +sameModel = KMeansModel.load(sc, "myModelPath") {% endhighlight %} @@ -312,12 +324,12 @@ Calling `PowerIterationClustering.run` returns a which contains the computed clustering assignments. {% highlight scala %} -import org.apache.spark.mllib.clustering.PowerIterationClustering +import org.apache.spark.mllib.clustering.{PowerIterationClustering, PowerIterationClusteringModel} import org.apache.spark.mllib.linalg.Vectors val similarities: RDD[(Long, Long, Double)] = ... -val pic = new PowerIteartionClustering() +val pic = new PowerIterationClustering() .setK(3) .setMaxIterations(20) val model = pic.run(similarities) @@ -325,6 +337,10 @@ val model = pic.run(similarities) model.assignments.foreach { a => println(s"${a.id} -> ${a.cluster}") } + +// Save and load model +model.save(sc, "myModelPath") +val sameModel = PowerIterationClusteringModel.load(sc, "myModelPath") {% endhighlight %} A full example that produces the experiment described in the PIC paper can be found under @@ -360,6 +376,10 @@ PowerIterationClusteringModel model = pic.run(similarities); for (PowerIterationClustering.Assignment a: model.assignments().toJavaRDD().collect()) { System.out.println(a.id() + " -> " + a.cluster()); } + +// Save and load model +model.save(sc.sc(), "myModelPath"); +PowerIterationClusteringModel sameModel = PowerIterationClusteringModel.load(sc.sc(), "myModelPath"); {% endhighlight %} diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index f723cd6b9dfab..4fe470a8de810 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -188,7 +188,7 @@ Here we assume the extracted file is `text8` and in same directory as you run th import org.apache.spark._ import org.apache.spark.rdd._ import org.apache.spark.SparkContext._ -import org.apache.spark.mllib.feature.Word2Vec +import org.apache.spark.mllib.feature.{Word2Vec, Word2VecModel} val input = sc.textFile("text8").map(line => line.split(" ").toSeq) @@ -201,6 +201,10 @@ val synonyms = model.findSynonyms("china", 40) for((synonym, cosineSimilarity) <- synonyms) { println(s"$synonym $cosineSimilarity") } + +// Save and load model +model.save(sc, "myModelPath") +val sameModel = Word2VecModel.load(sc, "myModelPath") {% endhighlight %}
    @@ -410,6 +414,7 @@ import org.apache.spark.SparkContext._ import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.mllib.feature.ChiSqSelector // Load some data in libsvm format val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") @@ -505,7 +510,7 @@ v_N ### Example -This example below demonstrates how to load a simple vectors file, extract a set of vectors, then transform those vectors using a transforming vector value. +This example below demonstrates how to transform vectors using a transforming vector value.
    @@ -514,16 +519,44 @@ import org.apache.spark.SparkContext._ import org.apache.spark.mllib.feature.ElementwiseProduct import org.apache.spark.mllib.linalg.Vectors -// Load and parse the data: -val data = sc.textFile("data/mllib/kmeans_data.txt") -val parsedData = data.map(s => Vectors.dense(s.split(' ').map(_.toDouble))) +// Create some vector data; also works for sparse vectors +val data = sc.parallelize(Array(Vectors.dense(1.0, 2.0, 3.0), Vectors.dense(4.0, 5.0, 6.0))) val transformingVector = Vectors.dense(0.0, 1.0, 2.0) val transformer = new ElementwiseProduct(transformingVector) // Batch transform and per-row transform give the same results: -val transformedData = transformer.transform(parsedData) -val transformedData2 = parsedData.map(x => transformer.transform(x)) +val transformedData = transformer.transform(data) +val transformedData2 = data.map(x => transformer.transform(x)) + +{% endhighlight %} +
    + +
    +{% highlight java %} +import java.util.Arrays; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.feature.ElementwiseProduct; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; + +// Create some vector data; also works for sparse vectors +JavaRDD data = sc.parallelize(Arrays.asList( + Vectors.dense(1.0, 2.0, 3.0), Vectors.dense(4.0, 5.0, 6.0))); +Vector transformingVector = Vectors.dense(0.0, 1.0, 2.0); +ElementwiseProduct transformer = new ElementwiseProduct(transformingVector); + +// Batch transform and per-row transform give the same results: +JavaRDD transformedData = transformer.transform(data); +JavaRDD transformedData2 = data.map( + new Function() { + @Override + public Vector call(Vector v) { + return transformer.transform(v); + } + } +); {% endhighlight %}
    diff --git a/docs/mllib-isotonic-regression.md b/docs/mllib-isotonic-regression.md index b521c2f27cd6e..5732bc4c7e79e 100644 --- a/docs/mllib-isotonic-regression.md +++ b/docs/mllib-isotonic-regression.md @@ -60,7 +60,7 @@ Model is created using the training set and a mean squared error is calculated f labels and real labels in the test set. {% highlight scala %} -import org.apache.spark.mllib.regression.IsotonicRegression +import org.apache.spark.mllib.regression.{IsotonicRegression, IsotonicRegressionModel} val data = sc.textFile("data/mllib/sample_isotonic_regression_data.txt") @@ -88,6 +88,10 @@ val predictionAndLabel = test.map { point => // Calculate mean squared error between predicted and real labels. val meanSquaredError = predictionAndLabel.map{case(p, l) => math.pow((p - l), 2)}.mean() println("Mean Squared Error = " + meanSquaredError) + +// Save and load model +model.save(sc, "myModelPath") +val sameModel = IsotonicRegressionModel.load(sc, "myModelPath") {% endhighlight %}
    @@ -150,6 +154,10 @@ Double meanSquaredError = new JavaDoubleRDD(predictionAndLabel.map( ).rdd()).mean(); System.out.println("Mean Squared Error = " + meanSquaredError); + +// Save and load model +model.save(sc.sc(), "myModelPath"); +IsotonicRegressionModel sameModel = IsotonicRegressionModel.load(sc.sc(), "myModelPath"); {% endhighlight %}
    diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 9d55f435e80ad..96cf612c54fdd 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -242,6 +242,22 @@ Most of the configs are the same for Spark on YARN as for other deployment modes running against earlier versions, this property will be ignored. + + spark.yarn.keytab + (none) + + The full path to the file that contains the keytab for the principal specified above. + This keytab will be copied to the node running the Application Master via the Secure Distributed Cache, + for renewing the login tickets and the delegation tokens periodically. + + + + spark.yarn.principal + (none) + + Principal to be used to login to KDC, while running on secure HDFS. + + # Launching Spark on YARN diff --git a/docs/sparkr.md b/docs/sparkr.md new file mode 100644 index 0000000000000..4d82129921a37 --- /dev/null +++ b/docs/sparkr.md @@ -0,0 +1,223 @@ +--- +layout: global +displayTitle: SparkR (R on Spark) +title: SparkR (R on Spark) +--- + +* This will become a table of contents (this text will be scraped). +{:toc} + +# Overview +SparkR is an R package that provides a light-weight frontend to use Apache Spark from R. +In Spark {{site.SPARK_VERSION}}, SparkR provides a distributed data frame implementation that +supports operations like selection, filtering, aggregation etc. (similar to R data frames, +[dplyr](https://github.com/hadley/dplyr)) but on large datasets. + +# SparkR DataFrames + +A DataFrame is a distributed collection of data organized into named columns. It is conceptually +equivalent to a table in a relational database or a data frame in R, but with richer +optimizations under the hood. DataFrames can be constructed from a wide array of sources such as: +structured data files, tables in Hive, external databases, or existing local R data frames. + +All of the examples on this page use sample data included in R or the Spark distribution and can be run using the `./bin/sparkR` shell. + +## Starting Up: SparkContext, SQLContext + +
    +The entry point into SparkR is the `SparkContext` which connects your R program to a Spark cluster. +You can create a `SparkContext` using `sparkR.init` and pass in options such as the application name +etc. Further, to work with DataFrames we will need a `SQLContext`, which can be created from the +SparkContext. If you are working from the SparkR shell, the `SQLContext` and `SparkContext` should +already be created for you. + +{% highlight r %} +sc <- sparkR.init() +sqlContext <- sparkRSQL.init(sc) +{% endhighlight %} + +
    + +## Creating DataFrames +With a `SQLContext`, applications can create `DataFrame`s from a local R data frame, from a [Hive table](sql-programming-guide.html#hive-tables), or from other [data sources](sql-programming-guide.html#data-sources). + +### From local data frames +The simplest way to create a data frame is to convert a local R data frame into a SparkR DataFrame. Specifically we can use `createDataFrame` and pass in the local R data frame to create a SparkR DataFrame. As an example, the following creates a `DataFrame` based using the `faithful` dataset from R. + +
    +{% highlight r %} +df <- createDataFrame(sqlContext, faithful) + +# Displays the content of the DataFrame to stdout +head(df) +## eruptions waiting +##1 3.600 79 +##2 1.800 54 +##3 3.333 74 + +{% endhighlight %} +
    + +### From Data Sources + +SparkR supports operating on a variety of data sources through the `DataFrame` interface. This section describes the general methods for loading and saving data using Data Sources. You can check the Spark SQL programming guide for more [specific options](sql-programming-guide.html#manually-specifying-options) that are available for the built-in data sources. + +The general method for creating DataFrames from data sources is `read.df`. This method takes in the `SQLContext`, the path for the file to load and the type of data source. SparkR supports reading JSON and Parquet files natively and through [Spark Packages](http://spark-packages.org/) you can find data source connectors for popular file formats like [CSV](http://spark-packages.org/package/databricks/spark-csv) and [Avro](http://spark-packages.org/package/databricks/spark-avro). + +We can see how to use data sources using an example JSON input file. Note that the file that is used here is _not_ a typical JSON file. Each line in the file must contain a separate, self-contained valid JSON object. As a consequence, a regular multi-line JSON file will most often fail. + +
    + +{% highlight r %} +people <- read.df(sqlContext, "./examples/src/main/resources/people.json", "json") +head(people) +## age name +##1 NA Michael +##2 30 Andy +##3 19 Justin + +# SparkR automatically infers the schema from the JSON file +printSchema(people) +# root +# |-- age: integer (nullable = true) +# |-- name: string (nullable = true) + +{% endhighlight %} +
    + +The data sources API can also be used to save out DataFrames into multiple file formats. For example we can save the DataFrame from the previous example +to a Parquet file using `write.df` + +
    +{% highlight r %} +write.df(people, path="people.parquet", source="parquet", mode="overwrite") +{% endhighlight %} +
    + +### From Hive tables + +You can also create SparkR DataFrames from Hive tables. To do this we will need to create a HiveContext which can access tables in the Hive MetaStore. Note that Spark should have been built with [Hive support](building-spark.html#building-with-hive-and-jdbc-support) and more details on the difference between SQLContext and HiveContext can be found in the [SQL programming guide](sql-programming-guide.html#starting-point-sqlcontext). + +
    +{% highlight r %} +# sc is an existing SparkContext. +hiveContext <- sparkRHive.init(sc) + +sql(hiveContext, "CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") +sql(hiveContext, "LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") + +# Queries can be expressed in HiveQL. +results <- hiveContext.sql("FROM src SELECT key, value") + +# results is now a DataFrame +head(results) +## key value +## 1 238 val_238 +## 2 86 val_86 +## 3 311 val_311 + +{% endhighlight %} +
    + +## DataFrame Operations + +SparkR DataFrames support a number of functions to do structured data processing. +Here we include some basic examples and a complete list can be found in the [API](api/R/index.html) docs: + +### Selecting rows, columns + +
    +{% highlight r %} +# Create the DataFrame +df <- createDataFrame(sqlContext, faithful) + +# Get basic information about the DataFrame +df +## DataFrame[eruptions:double, waiting:double] + +# Select only the "eruptions" column +head(select(df, df$eruptions)) +## eruptions +##1 3.600 +##2 1.800 +##3 3.333 + +# You can also pass in column name as strings +head(select(df, "eruptions")) + +# Filter the DataFrame to only retain rows with wait times shorter than 50 mins +head(filter(df, df$waiting < 50)) +## eruptions waiting +##1 1.750 47 +##2 1.750 47 +##3 1.867 48 + +{% endhighlight %} + +
    + +### Grouping, Aggregation + +SparkR data frames support a number of commonly used functions to aggregate data after grouping. For example we can compute a histogram of the `waiting` time in the `faithful` dataset as shown below + +
    +{% highlight r %} + +# We use the `n` operator to count the number of times each waiting time appears +head(summarize(groupBy(df, df$waiting), count = n(df$waiting))) +## waiting count +##1 81 13 +##2 60 6 +##3 68 1 + +# We can also sort the output from the aggregation to get the most common waiting times +waiting_counts <- summarize(groupBy(df, df$waiting), count = n(df$waiting)) +head(arrange(waiting_counts, desc(waiting_counts$count))) + +## waiting count +##1 78 15 +##2 83 14 +##3 81 13 + +{% endhighlight %} +
    + +### Operating on Columns + +SparkR also provides a number of functions that can directly applied to columns for data processing and during aggregation. The example below shows the use of basic arithmetic functions. + +
    +{% highlight r %} + +# Convert waiting time from hours to seconds. +# Note that we can assign this to a new column in the same DataFrame +df$waiting_secs <- df$waiting * 60 +head(df) +## eruptions waiting waiting_secs +##1 3.600 79 4740 +##2 1.800 54 3240 +##3 3.333 74 4440 + +{% endhighlight %} +
    + +## Running SQL Queries from SparkR +A SparkR DataFrame can also be registered as a temporary table in Spark SQL and registering a DataFrame as a table allows you to run SQL queries over its data. +The `sql` function enables applications to run SQL queries programmatically and returns the result as a `DataFrame`. + +
    +{% highlight r %} +# Load a JSON file +people <- read.df(sqlContext, "./examples/src/main/resources/people.json", "json") + +# Register this DataFrame as a table. +registerTempTable(people, "people") + +# SQL statements can be run by using the sql method +teenagers <- sql(sqlContext, "SELECT name FROM people WHERE age >= 13 AND age <= 19") +head(teenagers) +## name +##1 Justin + +{% endhighlight %} +
    diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index ab646f65bb5eb..282ea75e1e785 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -11,6 +11,7 @@ title: Spark SQL and DataFrames Spark SQL is a Spark module for structured data processing. It provides a programming abstraction called DataFrames and can also act as distributed SQL query engine. +For how to enable Hive support, please refer to the [Hive Tables](#hive-tables) section. # DataFrames @@ -108,7 +109,7 @@ As an example, the following creates a `DataFrame` based on the content of a JSO val sc: SparkContext // An existing SparkContext. val sqlContext = new org.apache.spark.sql.SQLContext(sc) -val df = sqlContext.jsonFile("examples/src/main/resources/people.json") +val df = sqlContext.read.json("examples/src/main/resources/people.json") // Displays the content of the DataFrame to stdout df.show() @@ -121,7 +122,7 @@ df.show() JavaSparkContext sc = ...; // An existing JavaSparkContext. SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc); -DataFrame df = sqlContext.jsonFile("examples/src/main/resources/people.json"); +DataFrame df = sqlContext.read().json("examples/src/main/resources/people.json"); // Displays the content of the DataFrame to stdout df.show(); @@ -134,7 +135,7 @@ df.show(); from pyspark.sql import SQLContext sqlContext = SQLContext(sc) -df = sqlContext.jsonFile("examples/src/main/resources/people.json") +df = sqlContext.read.json("examples/src/main/resources/people.json") # Displays the content of the DataFrame to stdout df.show() @@ -170,7 +171,7 @@ val sc: SparkContext // An existing SparkContext. val sqlContext = new org.apache.spark.sql.SQLContext(sc) // Create the DataFrame -val df = sqlContext.jsonFile("examples/src/main/resources/people.json") +val df = sqlContext.read.json("examples/src/main/resources/people.json") // Show the content of the DataFrame df.show() @@ -220,7 +221,7 @@ JavaSparkContext sc // An existing SparkContext. SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc) // Create the DataFrame -DataFrame df = sqlContext.jsonFile("examples/src/main/resources/people.json"); +DataFrame df = sqlContext.read().json("examples/src/main/resources/people.json"); // Show the content of the DataFrame df.show(); @@ -276,7 +277,7 @@ from pyspark.sql import SQLContext sqlContext = SQLContext(sc) # Create the DataFrame -df = sqlContext.jsonFile("examples/src/main/resources/people.json") +df = sqlContext.read.json("examples/src/main/resources/people.json") # Show the content of the DataFrame df.show() @@ -776,8 +777,8 @@ In the simplest form, the default data source (`parquet` unless otherwise config
    {% highlight scala %} -val df = sqlContext.load("examples/src/main/resources/users.parquet") -df.select("name", "favorite_color").save("namesAndFavColors.parquet") +val df = sqlContext.read.load("examples/src/main/resources/users.parquet") +df.select("name", "favorite_color").write.save("namesAndFavColors.parquet") {% endhighlight %}
    @@ -786,8 +787,8 @@ df.select("name", "favorite_color").save("namesAndFavColors.parquet") {% highlight java %} -DataFrame df = sqlContext.load("examples/src/main/resources/users.parquet"); -df.select("name", "favorite_color").save("namesAndFavColors.parquet"); +DataFrame df = sqlContext.read().load("examples/src/main/resources/users.parquet"); +df.select("name", "favorite_color").write().save("namesAndFavColors.parquet"); {% endhighlight %} @@ -797,8 +798,8 @@ df.select("name", "favorite_color").save("namesAndFavColors.parquet"); {% highlight python %} -df = sqlContext.load("examples/src/main/resources/users.parquet") -df.select("name", "favorite_color").save("namesAndFavColors.parquet") +df = sqlContext.read.load("examples/src/main/resources/users.parquet") +df.select("name", "favorite_color").write.save("namesAndFavColors.parquet") {% endhighlight %} @@ -826,8 +827,8 @@ using this syntax.
    {% highlight scala %} -val df = sqlContext.load("examples/src/main/resources/people.json", "json") -df.select("name", "age").save("namesAndAges.parquet", "parquet") +val df = sqlContext.read.format("json").load("examples/src/main/resources/people.json") +df.select("name", "age").write.format("json").save("namesAndAges.parquet") {% endhighlight %}
    @@ -836,8 +837,8 @@ df.select("name", "age").save("namesAndAges.parquet", "parquet") {% highlight java %} -DataFrame df = sqlContext.load("examples/src/main/resources/people.json", "json"); -df.select("name", "age").save("namesAndAges.parquet", "parquet"); +DataFrame df = sqlContext.read().format("json").load("examples/src/main/resources/people.json"); +df.select("name", "age").write().format("parquet").save("namesAndAges.parquet"); {% endhighlight %} @@ -847,8 +848,8 @@ df.select("name", "age").save("namesAndAges.parquet", "parquet"); {% highlight python %} -df = sqlContext.load("examples/src/main/resources/people.json", "json") -df.select("name", "age").save("namesAndAges.parquet", "parquet") +df = sqlContext.read.load("examples/src/main/resources/people.json", format="json") +df.select("name", "age").write.save("namesAndAges.parquet", format="parquet") {% endhighlight %} @@ -906,7 +907,7 @@ new data. Ignore mode means that when saving a DataFrame to a data source, if data already exists, the save operation is expected to not save the contents of the DataFrame and to not - change the existing data. This is similar to a `CREATE TABLE IF NOT EXISTS` in SQL. + change the existing data. This is similar to a CREATE TABLE IF NOT EXISTS in SQL. @@ -946,11 +947,11 @@ import sqlContext.implicits._ val people: RDD[Person] = ... // An RDD of case class objects, from the previous example. // The RDD is implicitly converted to a DataFrame by implicits, allowing it to be stored using Parquet. -people.saveAsParquetFile("people.parquet") +people.write.parquet("people.parquet") // Read in the parquet file created above. Parquet files are self-describing so the schema is preserved. // The result of loading a Parquet file is also a DataFrame. -val parquetFile = sqlContext.parquetFile("people.parquet") +val parquetFile = sqlContext.read.parquet("people.parquet") //Parquet files can also be registered as tables and then used in SQL statements. parquetFile.registerTempTable("parquetFile") @@ -968,11 +969,11 @@ teenagers.map(t => "Name: " + t(0)).collect().foreach(println) DataFrame schemaPeople = ... // The DataFrame from the previous example. // DataFrames can be saved as Parquet files, maintaining the schema information. -schemaPeople.saveAsParquetFile("people.parquet"); +schemaPeople.write().parquet("people.parquet"); // Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. // The result of loading a parquet file is also a DataFrame. -DataFrame parquetFile = sqlContext.parquetFile("people.parquet"); +DataFrame parquetFile = sqlContext.read().parquet("people.parquet"); //Parquet files can also be registered as tables and then used in SQL statements. parquetFile.registerTempTable("parquetFile"); @@ -994,11 +995,11 @@ List teenagerNames = teenagers.javaRDD().map(new Function() schemaPeople # The DataFrame from the previous example. # DataFrames can be saved as Parquet files, maintaining the schema information. -schemaPeople.saveAsParquetFile("people.parquet") +schemaPeople.read.parquet("people.parquet") # Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. # The result of loading a parquet file is also a DataFrame. -parquetFile = sqlContext.parquetFile("people.parquet") +parquetFile = sqlContext.write.parquet("people.parquet") # Parquet files can also be registered as tables and then used in SQL statements. parquetFile.registerTempTable("parquetFile"); @@ -1030,7 +1031,7 @@ teenagers <- sql(sqlContext, "SELECT name FROM parquetFile WHERE age >= 13 AND a teenNames <- map(teenagers, function(p) { paste("Name:", p$name)}) for (teenName in collect(teenNames)) { cat(teenName, "\n") -} +} {% endhighlight %} @@ -1086,9 +1087,9 @@ path {% endhighlight %} -By passing `path/to/table` to either `SQLContext.parquetFile` or `SQLContext.load`, Spark SQL will -automatically extract the partitioning information from the paths. Now the schema of the returned -DataFrame becomes: +By passing `path/to/table` to either `SQLContext.read.parquet` or `SQLContext.read.load`, Spark SQL +will automatically extract the partitioning information from the paths. +Now the schema of the returned DataFrame becomes: {% highlight text %} @@ -1121,15 +1122,15 @@ import sqlContext.implicits._ // Create a simple DataFrame, stored into a partition directory val df1 = sparkContext.makeRDD(1 to 5).map(i => (i, i * 2)).toDF("single", "double") -df1.saveAsParquetFile("data/test_table/key=1") +df1.write.parquet("data/test_table/key=1") // Create another DataFrame in a new partition directory, // adding a new column and dropping an existing column val df2 = sparkContext.makeRDD(6 to 10).map(i => (i, i * 3)).toDF("single", "triple") -df2.saveAsParquetFile("data/test_table/key=2") +df2.write.parquet("data/test_table/key=2") // Read the partitioned table -val df3 = sqlContext.parquetFile("data/test_table") +val df3 = sqlContext.read.parquet("data/test_table") df3.printSchema() // The final schema consists of all 3 columns in the Parquet files together @@ -1268,12 +1269,10 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext`
    Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame. -This conversion can be done using one of two methods in a `SQLContext`: - -* `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object. -* `jsonRDD` - loads data from an existing RDD where each element of the RDD is a string containing a JSON object. +This conversion can be done using `SQLContext.read.json()` on either an RDD of String, +or a JSON file. -Note that the file that is offered as _jsonFile_ is not a typical JSON file. Each +Note that the file that is offered as _a json file_ is not a typical JSON file. Each line must contain a separate, self-contained valid JSON object. As a consequence, a regular multi-line JSON file will most often fail. @@ -1284,8 +1283,7 @@ val sqlContext = new org.apache.spark.sql.SQLContext(sc) // A JSON dataset is pointed to by path. // The path can be either a single text file or a directory storing text files. val path = "examples/src/main/resources/people.json" -// Create a DataFrame from the file(s) pointed to by path -val people = sqlContext.jsonFile(path) +val people = sqlContext.read.json(path) // The inferred schema can be visualized using the printSchema() method. people.printSchema() @@ -1303,19 +1301,17 @@ val teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age // an RDD[String] storing one JSON object per string. val anotherPeopleRDD = sc.parallelize( """{"name":"Yin","address":{"city":"Columbus","state":"Ohio"}}""" :: Nil) -val anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD) +val anotherPeople = sqlContext.read.json(anotherPeopleRDD) {% endhighlight %}
    Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame. -This conversion can be done using one of two methods in a `SQLContext` : - -* `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object. -* `jsonRDD` - loads data from an existing RDD where each element of the RDD is a string containing a JSON object. +This conversion can be done using `SQLContext.read().json()` on either an RDD of String, +or a JSON file. -Note that the file that is offered as _jsonFile_ is not a typical JSON file. Each +Note that the file that is offered as _a json file_ is not a typical JSON file. Each line must contain a separate, self-contained valid JSON object. As a consequence, a regular multi-line JSON file will most often fail. @@ -1325,9 +1321,7 @@ SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc); // A JSON dataset is pointed to by path. // The path can be either a single text file or a directory storing text files. -String path = "examples/src/main/resources/people.json"; -// Create a DataFrame from the file(s) pointed to by path -DataFrame people = sqlContext.jsonFile(path); +DataFrame people = sqlContext.read().json("examples/src/main/resources/people.json"); // The inferred schema can be visualized using the printSchema() method. people.printSchema(); @@ -1346,18 +1340,15 @@ DataFrame teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AN List jsonData = Arrays.asList( "{\"name\":\"Yin\",\"address\":{\"city\":\"Columbus\",\"state\":\"Ohio\"}}"); JavaRDD anotherPeopleRDD = sc.parallelize(jsonData); -DataFrame anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD); +DataFrame anotherPeople = sqlContext.read().json(anotherPeopleRDD); {% endhighlight %}
    Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame. -This conversion can be done using one of two methods in a `SQLContext`: +This conversion can be done using `SQLContext.read.json` on a JSON file. -* `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object. -* `jsonRDD` - loads data from an existing RDD where each element of the RDD is a string containing a JSON object. - -Note that the file that is offered as _jsonFile_ is not a typical JSON file. Each +Note that the file that is offered as _a json file_ is not a typical JSON file. Each line must contain a separate, self-contained valid JSON object. As a consequence, a regular multi-line JSON file will most often fail. @@ -1368,9 +1359,7 @@ sqlContext = SQLContext(sc) # A JSON dataset is pointed to by path. # The path can be either a single text file or a directory storing text files. -path = "examples/src/main/resources/people.json" -# Create a DataFrame from the file(s) pointed to by path -people = sqlContext.jsonFile(path) +people = sqlContext.read.json("examples/src/main/resources/people.json") # The inferred schema can be visualized using the printSchema() method. people.printSchema() @@ -1393,12 +1382,11 @@ anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD)
    -Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame. -This conversion can be done using one of two methods in a `SQLContext`: - -* `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object. +Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame. using +the `jsonFile` function, which loads data from a directory of JSON files where each line of the +files is a JSON object. -Note that the file that is offered as _jsonFile_ is not a typical JSON file. Each +Note that the file that is offered as _a json file_ is not a typical JSON file. Each line must contain a separate, self-contained valid JSON object. As a consequence, a regular multi-line JSON file will most often fail. @@ -1502,7 +1490,7 @@ Row[] results = sqlContext.sql("FROM src SELECT key, value").collect();
    When working with Hive one must construct a `HiveContext`, which inherits from `SQLContext`, and -adds support for finding tables in the MetaStore and writing queries using HiveQL. +adds support for finding tables in the MetaStore and writing queries using HiveQL. {% highlight python %} # sc is an existing SparkContext. from pyspark.sql import HiveContext @@ -1526,8 +1514,8 @@ adds support for finding tables in the MetaStore and writing queries using HiveQ # sc is an existing SparkContext. sqlContext <- sparkRHive.init(sc) -hql(sqlContext, "CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") -hql(sqlContext, "LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") +sql(sqlContext, "CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") +sql(sqlContext, "LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") # Queries can be expressed in HiveQL. results = sqlContext.sql("FROM src SELECT key, value").collect() @@ -1537,6 +1525,70 @@ results = sqlContext.sql("FROM src SELECT key, value").collect()
    +### Interacting with Different Versions of Hive Metastore + +One of the most important pieces of Spark SQL's Hive support is interaction with Hive metastore, +which enables Spark SQL to access metadata of Hive tables. Starting from Spark 1.4.0, a single binary build of Spark SQL can be used to query different versions of Hive metastores, using the configuration described below. + +Internally, Spark SQL uses two Hive clients, one for executing native Hive commands like `SET` +and `DESCRIBE`, the other dedicated for communicating with Hive metastore. The former uses Hive +jars of version 0.13.1, which are bundled with Spark 1.4.0. The latter uses Hive jars of the +version specified by users. An isolated classloader is used here to avoid dependency conflicts. + + + + + + + + + + + + + + + + + + + + + + + +
    Property NameDefaultMeaning
    spark.sql.hive.metastore.version0.13.1 + Version of the Hive metastore. Available + options are 0.12.0 and 0.13.1. Support for more versions is coming in the future. +
    spark.sql.hive.metastore.jarsbuiltin + Location of the jars that should be used to instantiate the HiveMetastoreClient. This + property can be one of three options: +
      +
    1. builtin
    2. + Use Hive 0.13.1, which is bundled with the Spark assembly jar when -Phive is + enabled. When this option is chosen, spark.sql.hive.metastore.version must be + either 0.13.1 or not defined. +
    3. maven
    4. + Use Hive jars of specified version downloaded from Maven repositories. +
    5. A classpath in the standard format for both Hive and Hadoop.
    6. +
    +
    spark.sql.hive.metastore.sharedPrefixescom.mysql.jdbc,
    org.postgresql,
    com.microsoft.sqlserver,
    oracle.jdbc
    +

    + A comma separated list of class prefixes that should be loaded using the classloader that is + shared between Spark SQL and a specific version of Hive. An example of classes that should + be shared is JDBC drivers that are needed to talk to the metastore. Other classes that need + to be shared are those that interact with classes that are already shared. For example, + custom appenders that are used by log4j. +

    +
    spark.sql.hive.metastore.barrierPrefixes(empty) +

    + A comma separated list of class prefixes that should explicitly be reloaded for each version + of Hive that Spark SQL is communicating with. For example, Hive UDFs that are declared in a + prefix that typically would be shared (i.e. org.apache.spark.*). +

    +
    + + ## JDBC To Other Databases Spark SQL also includes a data source that can read data from other databases using JDBC. This @@ -1570,7 +1622,7 @@ the Data Sources API. The following options are supported: dbtable - The JDBC table that should be read. Note that anything that is valid in a `FROM` clause of + The JDBC table that should be read. Note that anything that is valid in a FROM clause of a SQL query can be used. For example, instead of a full table you could also use a subquery in parentheses. @@ -1714,7 +1766,7 @@ that these options will be deprecated in future release as more optimizations ar Configures the maximum size in bytes for a table that will be broadcast to all worker nodes when performing a join. By setting this value to -1 broadcasting can be disabled. Note that currently statistics are only supported for Hive Metastore tables where the command - `ANALYZE TABLE <tableName> COMPUTE STATISTICS noscan` has been run. + ANALYZE TABLE <tableName> COMPUTE STATISTICS noscan has been run. @@ -1737,7 +1789,9 @@ that these options will be deprecated in future release as more optimizations ar # Distributed SQL Engine -Spark SQL can also act as a distributed query engine using its JDBC/ODBC or command-line interface. In this mode, end-users or applications can interact with Spark SQL directly to run SQL queries, without the need to write any code. +Spark SQL can also act as a distributed query engine using its JDBC/ODBC or command-line interface. +In this mode, end-users or applications can interact with Spark SQL directly to run SQL queries, +without the need to write any code. ## Running the Thrift JDBC/ODBC server @@ -1816,6 +1870,25 @@ options. ## Upgrading from Spark SQL 1.3 to 1.4 +#### DataFrame data reader/writer interface + +Based on user feedback, we created a new, more fluid API for reading data in (`SQLContext.read`) +and writing data out (`DataFrame.write`), +and deprecated the old APIs (e.g. `SQLContext.parquetFile`, `SQLContext.jsonFile`). + +See the API docs for `SQLContext.read` ( + Scala, + Java, + Python +) and `DataFrame.write` ( + Scala, + Java, + Python +) more information. + + +#### DataFrame.groupBy retains grouping columns + Based on user feedback, we changed the default behavior of `DataFrame.groupBy().agg()` to retain the grouping columns in the resulting `DataFrame`. To keep the behavior in 1.3, set `spark.sql.retainGroupColumns` to `false`.
    diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index bd863d48d53e3..42b33947873b0 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -1946,10 +1946,10 @@ creates a single receiver (running on a worker machine) that receives a single s Receiving multiple data streams can therefore be achieved by creating multiple input DStreams and configuring them to receive different partitions of the data stream from the source(s). For example, a single Kafka input DStream receiving two topics of data can be split into two -Kafka input streams, each receiving only one topic. This would run two receivers on two workers, -thus allowing data to be received in parallel, and increasing overall throughput. These multiple -DStream can be unioned together to create a single DStream. Then the transformations that was -being applied on the single input DStream can applied on the unified stream. This is done as follows. +Kafka input streams, each receiving only one topic. This would run two receivers, +allowing data to be received in parallel, and increasing overall throughput. These multiple +DStreams can be unioned together to create a single DStream. Then the transformations that were +being applied on a single input DStream can be applied on the unified stream. This is done as follows.
    diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java index 29158d5c85651..dac649d1d5ae6 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java @@ -97,7 +97,7 @@ public static void main(String[] args) { DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), LabeledPoint.class); // Make predictions on test documents using the Transformer.transform() method. - // LogisticRegression.transform will only use the 'features' column. + // LogisticRegressionModel.transform will only use the 'features' column. // Note that model2.transform() outputs a 'myProbability' column instead of the usual // 'probability' column since we renamed the lr.probabilityCol parameter previously. DataFrame results = model2.transform(test); diff --git a/examples/src/main/python/ml/gradient_boosted_trees.py b/examples/src/main/python/ml/gradient_boosted_trees.py new file mode 100644 index 0000000000000..6446f0fe5eeab --- /dev/null +++ b/examples/src/main/python/ml/gradient_boosted_trees.py @@ -0,0 +1,83 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +import sys + +from pyspark import SparkContext +from pyspark.ml.classification import GBTClassifier +from pyspark.ml.feature import StringIndexer +from pyspark.ml.regression import GBTRegressor +from pyspark.mllib.evaluation import BinaryClassificationMetrics, RegressionMetrics +from pyspark.mllib.util import MLUtils +from pyspark.sql import Row, SQLContext + +""" +A simple example demonstrating a Gradient Boosted Trees Classification/Regression Pipeline. +Note: GBTClassifier only supports binary classification currently +Run with: + bin/spark-submit examples/src/main/python/ml/gradient_boosted_trees.py +""" + + +def testClassification(train, test): + # Train a GradientBoostedTrees model. + + rf = GBTClassifier(maxIter=30, maxDepth=4, labelCol="indexedLabel") + + model = rf.fit(train) + predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \ + .map(lambda x: (x.prediction, x.indexedLabel)) + + metrics = BinaryClassificationMetrics(predictionAndLabels) + print("AUC %.3f" % metrics.areaUnderROC) + + +def testRegression(train, test): + # Train a GradientBoostedTrees model. + + rf = GBTRegressor(maxIter=30, maxDepth=4, labelCol="indexedLabel") + + model = rf.fit(train) + predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \ + .map(lambda x: (x.prediction, x.indexedLabel)) + + metrics = RegressionMetrics(predictionAndLabels) + print("rmse %.3f" % metrics.rootMeanSquaredError) + print("r2 %.3f" % metrics.r2) + print("mae %.3f" % metrics.meanAbsoluteError) + + +if __name__ == "__main__": + if len(sys.argv) > 1: + print("Usage: gradient_boosted_trees", file=sys.stderr) + exit(1) + sc = SparkContext(appName="PythonGBTExample") + sqlContext = SQLContext(sc) + + # Load and parse the data file into a dataframe. + df = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + + # Map labels into an indexed column of labels in [0, numLabels) + stringIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel") + si_model = stringIndexer.fit(df) + td = si_model.transform(df) + [train, test] = td.randomSplit([0.7, 0.3]) + testClassification(train, test) + testRegression(train, test) + sc.stop() diff --git a/examples/src/main/python/ml/random_forest_example.py b/examples/src/main/python/ml/random_forest_example.py new file mode 100644 index 0000000000000..c7730e1bfacd9 --- /dev/null +++ b/examples/src/main/python/ml/random_forest_example.py @@ -0,0 +1,87 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +import sys + +from pyspark import SparkContext +from pyspark.ml.classification import RandomForestClassifier +from pyspark.ml.feature import StringIndexer +from pyspark.ml.regression import RandomForestRegressor +from pyspark.mllib.evaluation import MulticlassMetrics, RegressionMetrics +from pyspark.mllib.util import MLUtils +from pyspark.sql import Row, SQLContext + +""" +A simple example demonstrating a RandomForest Classification/Regression Pipeline. +Run with: + bin/spark-submit examples/src/main/python/ml/random_forest_example.py +""" + + +def testClassification(train, test): + # Train a RandomForest model. + # Setting featureSubsetStrategy="auto" lets the algorithm choose. + # Note: Use larger numTrees in practice. + + rf = RandomForestClassifier(labelCol="indexedLabel", numTrees=3, maxDepth=4) + + model = rf.fit(train) + predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \ + .map(lambda x: (x.prediction, x.indexedLabel)) + + metrics = MulticlassMetrics(predictionAndLabels) + print("weighted f-measure %.3f" % metrics.weightedFMeasure()) + print("precision %s" % metrics.precision()) + print("recall %s" % metrics.recall()) + + +def testRegression(train, test): + # Train a RandomForest model. + # Note: Use larger numTrees in practice. + + rf = RandomForestRegressor(labelCol="indexedLabel", numTrees=3, maxDepth=4) + + model = rf.fit(train) + predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \ + .map(lambda x: (x.prediction, x.indexedLabel)) + + metrics = RegressionMetrics(predictionAndLabels) + print("rmse %.3f" % metrics.rootMeanSquaredError) + print("r2 %.3f" % metrics.r2) + print("mae %.3f" % metrics.meanAbsoluteError) + + +if __name__ == "__main__": + if len(sys.argv) > 1: + print("Usage: random_forest_example", file=sys.stderr) + exit(1) + sc = SparkContext(appName="PythonRandomForestExample") + sqlContext = SQLContext(sc) + + # Load and parse the data file into a dataframe. + df = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + + # Map labels into an indexed column of labels in [0, numLabels) + stringIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel") + si_model = stringIndexer.fit(df) + td = si_model.transform(df) + [train, test] = td.randomSplit([0.7, 0.3]) + testClassification(train, test) + testRegression(train, test) + sc.stop() diff --git a/examples/src/main/python/ml/simple_params_example.py b/examples/src/main/python/ml/simple_params_example.py new file mode 100644 index 0000000000000..3933d59b52cd1 --- /dev/null +++ b/examples/src/main/python/ml/simple_params_example.py @@ -0,0 +1,98 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +import pprint +import sys + +from pyspark import SparkContext +from pyspark.ml.classification import LogisticRegression +from pyspark.mllib.linalg import DenseVector +from pyspark.mllib.regression import LabeledPoint +from pyspark.sql import SQLContext + +""" +A simple example demonstrating ways to specify parameters for Estimators and Transformers. +Run with: + bin/spark-submit examples/src/main/python/ml/simple_params_example.py +""" + +if __name__ == "__main__": + if len(sys.argv) > 1: + print("Usage: simple_params_example", file=sys.stderr) + exit(1) + sc = SparkContext(appName="PythonSimpleParamsExample") + sqlContext = SQLContext(sc) + + # prepare training data. + # We create an RDD of LabeledPoints and convert them into a DataFrame. + # Spark DataFrames can automatically infer the schema from named tuples + # and LabeledPoint implements __reduce__ to behave like a named tuple. + training = sc.parallelize([ + LabeledPoint(1.0, DenseVector([0.0, 1.1, 0.1])), + LabeledPoint(0.0, DenseVector([2.0, 1.0, -1.0])), + LabeledPoint(0.0, DenseVector([2.0, 1.3, 1.0])), + LabeledPoint(1.0, DenseVector([0.0, 1.2, -0.5]))]).toDF() + + # Create a LogisticRegression instance with maxIter = 10. + # This instance is an Estimator. + lr = LogisticRegression(maxIter=10) + # Print out the parameters, documentation, and any default values. + print("LogisticRegression parameters:\n" + lr.explainParams() + "\n") + + # We may also set parameters using setter methods. + lr.setRegParam(0.01) + + # Learn a LogisticRegression model. This uses the parameters stored in lr. + model1 = lr.fit(training) + + # Since model1 is a Model (i.e., a Transformer produced by an Estimator), + # we can view the parameters it used during fit(). + # This prints the parameter (name: value) pairs, where names are unique IDs for this + # LogisticRegression instance. + print("Model 1 was fit using parameters:\n") + pprint.pprint(model1.extractParamMap()) + + # We may alternatively specify parameters using a parameter map. + # paramMap overrides all lr parameters set earlier. + paramMap = {lr.maxIter: 20, lr.threshold: 0.55, lr.probabilityCol: "myProbability"} + + # Now learn a new model using the new parameters. + model2 = lr.fit(training, paramMap) + print("Model 2 was fit using parameters:\n") + pprint.pprint(model2.extractParamMap()) + + # prepare test data. + test = sc.parallelize([ + LabeledPoint(1.0, DenseVector([-1.0, 1.5, 1.3])), + LabeledPoint(0.0, DenseVector([3.0, 2.0, -0.1])), + LabeledPoint(0.0, DenseVector([0.0, 2.2, -1.5]))]).toDF() + + # Make predictions on test data using the Transformer.transform() method. + # LogisticRegressionModel.transform will only use the 'features' column. + # Note that model2.transform() outputs a 'myProbability' column instead of the usual + # 'probability' column since we renamed the lr.probabilityCol parameter previously. + result = model2.transform(test) \ + .select("features", "label", "myProbability", "prediction") \ + .collect() + + for row in result: + print("features=%s,label=%s -> prob=%s, prediction=%s" + % (row.features, row.label, row.myProbability, row.prediction)) + + sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala index 32e02eab8b031..75c82117cbad2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala @@ -22,7 +22,7 @@ import org.apache.spark.SparkContext._ /** * Executes a roll up-style query against Apache logs. - * + * * Usage: LogQuery [logFile] */ object LogQuery { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala index e8a991f50e338..a0561e2573fc9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala @@ -87,7 +87,7 @@ object SimpleParamsExample { LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)))) // Make predictions on test data using the Transformer.transform() method. - // LogisticRegression.transform will only use the 'features' column. + // LogisticRegressionModel.transform will only use the 'features' column. // Note that model2.transform() outputs a 'myProbability' column instead of the usual // 'probability' column since we renamed the lr.probabilityCol parameter previously. model2.transform(test.toDF()) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index b0613632c9946..3381941673db8 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -22,7 +22,6 @@ import scala.language.reflectiveCalls import scopt.OptionParser import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.SparkContext._ import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint @@ -354,7 +353,11 @@ object DecisionTreeRunner { /** * Calculates the mean squared error for regression. + * + * This is just for demo purpose. In general, don't copy this code because it is NOT efficient + * due to the use of structural types, which leads to one reflection call per record. */ + // scalastyle:off structural.type private[mllib] def meanSquaredError( model: { def predict(features: Vector): Double }, data: RDD[LabeledPoint]): Double = { @@ -363,4 +366,5 @@ object DecisionTreeRunner { err * err }.mean() } + // scalastyle:on structural.type } diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala index 9a1aab036aa0f..f8c71ccabc43b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala @@ -41,22 +41,22 @@ object DenseGaussianMixture { private def run(inputFile: String, k: Int, convergenceTol: Double, maxIterations: Int) { val conf = new SparkConf().setAppName("Gaussian Mixture Model EM example") val ctx = new SparkContext(conf) - + val data = ctx.textFile(inputFile).map { line => Vectors.dense(line.trim.split(' ').map(_.toDouble)) }.cache() - + val clusters = new GaussianMixture() .setK(k) .setConvergenceTol(convergenceTol) .setMaxIterations(maxIterations) .run(data) - + for (i <- 0 until clusters.k) { - println("weight=%f\nmu=%s\nsigma=\n%s\n" format + println("weight=%f\nmu=%s\nsigma=\n%s\n" format (clusters.weights(i), clusters.gaussians(i).mu, clusters.gaussians(i).sigma)) } - + println("Cluster labels (first <= 100):") val clusterLabels = clusters.predict(data) clusterLabels.take(100).foreach { x => diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala index b336751d81616..813c8554f5193 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala @@ -40,7 +40,7 @@ object MQTTPublisher { StreamingExamples.setStreamingLogLevels() val Seq(brokerUrl, topic) = args.toSeq - + var client: MqttClient = null try { @@ -59,10 +59,10 @@ object MQTTPublisher { println(s"Published data. topic: ${msgtopic.getName()}; Message: $message") } catch { case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT => - Thread.sleep(10) + Thread.sleep(10) println("Queue is full, wait for to consume data from the message queue") - } - } + } + } } catch { case e: MqttException => println("Exception Caught: " + e) } finally { @@ -107,7 +107,7 @@ object MQTTWordCount { val lines = MQTTUtils.createStream(ssc, brokerUrl, topic, StorageLevel.MEMORY_ONLY_SER_2) val words = lines.flatMap(x => x.split(" ")) val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) - + wordCounts.print() ssc.start() ssc.awaitTermination() diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index 1f3e619d97a24..71f2b6fe18bd1 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -42,15 +42,46 @@ org.apache.flume flume-ng-sdk + + + + com.google.guava + guava + + + + org.apache.thrift + libthrift + + org.apache.flume flume-ng-core + + + com.google.guava + guava + + + org.apache.thrift + libthrift + + org.scala-lang scala-library + + + com.google.guava + guava + test + + + + diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala index fd01807fc3ac4..dc2a4ab138e18 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala @@ -21,7 +21,6 @@ import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable -import com.google.common.util.concurrent.ThreadFactoryBuilder import org.apache.flume.Channel import org.apache.commons.lang3.RandomStringUtils @@ -45,8 +44,7 @@ import org.apache.commons.lang3.RandomStringUtils private[flume] class SparkAvroCallbackHandler(val threads: Int, val channel: Channel, val transactionTimeout: Int, val backOffInterval: Int) extends SparkFlumeProtocol with Logging { val transactionExecutorOpt = Option(Executors.newFixedThreadPool(threads, - new ThreadFactoryBuilder().setDaemon(true) - .setNameFormat("Spark Sink Processor Thread - %d").build())) + new SparkSinkThreadFactory("Spark Sink Processor Thread - %d"))) // Protected by `sequenceNumberToProcessor` private val sequenceNumberToProcessor = mutable.HashMap[CharSequence, TransactionProcessor]() // This sink will not persist sequence numbers and reuses them if it gets restarted. diff --git a/core/src/main/scala/org/apache/spark/util/collection/PairIterator.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkThreadFactory.scala similarity index 61% rename from core/src/main/scala/org/apache/spark/util/collection/PairIterator.scala rename to external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkThreadFactory.scala index d75959f480756..845fc8debda75 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/PairIterator.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkThreadFactory.scala @@ -14,11 +14,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +package org.apache.spark.streaming.flume.sink -package org.apache.spark.util.collection +import java.util.concurrent.ThreadFactory +import java.util.concurrent.atomic.AtomicLong -private[spark] class PairIterator[K, V](iter: Iterator[Any]) extends Iterator[(K, V)] { - def hasNext: Boolean = iter.hasNext +/** + * Thread factory that generates daemon threads with a specified name format. + */ +private[sink] class SparkSinkThreadFactory(nameFormat: String) extends ThreadFactory { + + private val threadId = new AtomicLong() + + override def newThread(r: Runnable): Thread = { + val t = new Thread(r, nameFormat.format(threadId.incrementAndGet())) + t.setDaemon(true) + t + } - def next(): (K, V) = (iter.next().asInstanceOf[K], iter.next().asInstanceOf[V]) } diff --git a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala index 650b2fbe1c142..fa43629d49771 100644 --- a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala +++ b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala @@ -24,16 +24,24 @@ import scala.collection.JavaConversions._ import scala.concurrent.{ExecutionContext, Future} import scala.util.{Failure, Success} -import com.google.common.util.concurrent.ThreadFactoryBuilder import org.apache.avro.ipc.NettyTransceiver import org.apache.avro.ipc.specific.SpecificRequestor import org.apache.flume.Context import org.apache.flume.channel.MemoryChannel import org.apache.flume.event.EventBuilder import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory + +// Due to MNG-1378, there is not a way to include test dependencies transitively. +// We cannot include Spark core tests as a dependency here because it depends on +// Spark core main, which has too many dependencies to require here manually. +// For this reason, we continue to use FunSuite and ignore the scalastyle checks +// that fail if this is detected. +//scalastyle:off import org.scalatest.FunSuite class SparkSinkSuite extends FunSuite { +//scalastyle:on + val eventsPerBatch = 1000 val channelCapacity = 5000 @@ -185,9 +193,8 @@ class SparkSinkSuite extends FunSuite { count: Int): Seq[(NettyTransceiver, SparkFlumeProtocol.Callback)] = { (1 to count).map(_ => { - lazy val channelFactoryExecutor = - Executors.newCachedThreadPool(new ThreadFactoryBuilder().setDaemon(true). - setNameFormat("Flume Receiver Channel Thread - %d").build()) + lazy val channelFactoryExecutor = Executors.newCachedThreadPool( + new SparkSinkThreadFactory("Flume Receiver Channel Thread - %d")) lazy val channelFactory = new NioClientSocketChannelFactory(channelFactoryExecutor, channelFactoryExecutor) val transceiver = new NettyTransceiver(address, channelFactory) diff --git a/external/flume/pom.xml b/external/flume/pom.xml index 8df7edbdcad33..a345c03582ad6 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -41,6 +41,13 @@ ${project.version} provided + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-streaming-flume-sink_${scala.binary.version} diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala index 60e2994431b38..1e32a365a1eee 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala @@ -152,9 +152,9 @@ class FlumeReceiver( val channelFactory = new NioServerSocketChannelFactory(Executors.newCachedThreadPool(), Executors.newCachedThreadPool()) val channelPipelineFactory = new CompressionChannelPipelineFactory() - + new NettyServer( - responder, + responder, new InetSocketAddress(host, port), channelFactory, channelPipelineFactory, @@ -188,12 +188,12 @@ class FlumeReceiver( override def preferredLocation: Option[String] = Option(host) - /** A Netty Pipeline factory that will decompress incoming data from + /** A Netty Pipeline factory that will decompress incoming data from * and the Netty client and compress data going back to the client. * * The compression on the return is required because Flume requires - * a successful response to indicate it can remove the event/batch - * from the configured channel + * a successful response to indicate it can remove the event/batch + * from the configured channel */ private[streaming] class CompressionChannelPipelineFactory extends ChannelPipelineFactory { diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala index 92fa5b41be89e..583e7dca317ad 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala @@ -110,7 +110,7 @@ private[streaming] class FlumePollingReceiver( } /** - * A wrapper around the transceiver and the Avro IPC API. + * A wrapper around the transceiver and the Avro IPC API. * @param transceiver The transceiver to use for communication with Flume * @param client The client that the callbacks are received on. */ diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala index 93afe50c2134f..d772b9ca9b570 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala @@ -31,16 +31,16 @@ import org.apache.flume.conf.Configurables import org.apache.flume.event.EventBuilder import org.scalatest.concurrent.Eventually._ -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter -import org.apache.spark.{SparkConf, Logging} +import org.apache.spark.{Logging, SparkConf, SparkFunSuite} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.ReceiverInputDStream import org.apache.spark.streaming.{Seconds, TestOutputStream, StreamingContext} import org.apache.spark.streaming.flume.sink._ import org.apache.spark.util.{ManualClock, Utils} -class FlumePollingStreamSuite extends FunSuite with BeforeAndAfter with Logging { +class FlumePollingStreamSuite extends SparkFunSuite with BeforeAndAfter with Logging { val batchCount = 5 val eventsPerBatch = 100 diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala index 39e6754c81dbf..c926359987d89 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala @@ -35,15 +35,15 @@ import org.jboss.netty.channel.ChannelPipeline import org.jboss.netty.channel.socket.SocketChannel import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory import org.jboss.netty.handler.codec.compression._ -import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} +import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.{Logging, SparkConf, SparkFunSuite} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Milliseconds, StreamingContext, TestOutputStream} import org.apache.spark.util.Utils -class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with Logging { +class FlumeStreamSuite extends SparkFunSuite with BeforeAndAfter with Matchers with Logging { val conf = new SparkConf().setMaster("local[4]").setAppName("FlumeStreamSuite") var ssc: StreamingContext = null @@ -138,7 +138,7 @@ class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with L val status = client.appendBatch(inputEvents.toList) status should be (avro.Status.OK) } - + eventually(timeout(10 seconds), interval(100 milliseconds)) { val outputEvents = outputBuffer.flatten.map { _.event } outputEvents.foreach { diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index 243ce6eaca658..5734d55bf4784 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -41,6 +41,13 @@ ${project.version} provided + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.kafka kafka_${scala.binary.version} diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala index 6cf254a7b69cb..65d51d87f8486 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala @@ -113,7 +113,7 @@ class KafkaCluster(val kafkaParams: Map[String, String]) extends Serializable { r.flatMap { tm: TopicMetadata => tm.partitionsMetadata.map { pm: PartitionMetadata => TopicAndPartition(tm.topic, pm.partitionId) - } + } } } } diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index 8be2707528d93..0b8a391a2c569 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -315,7 +315,7 @@ object KafkaUtils { * Points to note: * - No receivers: This stream does not use any receiver. It directly queries Kafka * - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked - * by the stream itself. For interoperability with Kafka monitoring tools that depend on + * by the stream itself. For interoperability with Kafka monitoring tools that depend on * Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application. * You can access the offsets used in each batch from the generated RDDs (see * [[org.apache.spark.streaming.kafka.HasOffsetRanges]]). @@ -363,7 +363,7 @@ object KafkaUtils { * Points to note: * - No receivers: This stream does not use any receiver. It directly queries Kafka * - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked - * by the stream itself. For interoperability with Kafka monitoring tools that depend on + * by the stream itself. For interoperability with Kafka monitoring tools that depend on * Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application. * You can access the offsets used in each batch from the generated RDDs (see * [[org.apache.spark.streaming.kafka.HasOffsetRanges]]). @@ -427,7 +427,7 @@ object KafkaUtils { * Points to note: * - No receivers: This stream does not use any receiver. It directly queries Kafka * - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked - * by the stream itself. For interoperability with Kafka monitoring tools that depend on + * by the stream itself. For interoperability with Kafka monitoring tools that depend on * Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application. * You can access the offsets used in each batch from the generated RDDs (see * [[org.apache.spark.streaming.kafka.HasOffsetRanges]]). @@ -489,7 +489,7 @@ object KafkaUtils { * Points to note: * - No receivers: This stream does not use any receiver. It directly queries Kafka * - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked - * by the stream itself. For interoperability with Kafka monitoring tools that depend on + * by the stream itself. For interoperability with Kafka monitoring tools that depend on * Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application. * You can access the offsets used in each batch from the generated RDDs (see * [[org.apache.spark.streaming.kafka.HasOffsetRanges]]). diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala index b6d314dfc7783..47bbfb605850a 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala @@ -28,10 +28,10 @@ import scala.language.postfixOps import kafka.common.TopicAndPartition import kafka.message.MessageAndMetadata import kafka.serializer.StringDecoder -import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} import org.scalatest.concurrent.Eventually -import org.apache.spark.{Logging, SparkConf, SparkContext} +import org.apache.spark.{Logging, SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{Milliseconds, StreamingContext, Time} import org.apache.spark.streaming.dstream.DStream @@ -39,7 +39,7 @@ import org.apache.spark.streaming.scheduler._ import org.apache.spark.util.Utils class DirectKafkaStreamSuite - extends FunSuite + extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterAll with Eventually diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala index 7fb841b79cb65..d66830cbacdee 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala @@ -20,9 +20,11 @@ package org.apache.spark.streaming.kafka import scala.util.Random import kafka.common.TopicAndPartition -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.BeforeAndAfterAll -class KafkaClusterSuite extends FunSuite with BeforeAndAfterAll { +import org.apache.spark.SparkFunSuite + +class KafkaClusterSuite extends SparkFunSuite with BeforeAndAfterAll { private val topic = "kcsuitetopic" + Random.nextInt(10000) private val topicAndPartition = TopicAndPartition(topic, 0) private var kc: KafkaCluster = null diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala index 3c875cb766513..054487269a935 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala @@ -22,11 +22,11 @@ import scala.util.Random import kafka.serializer.StringDecoder import kafka.common.TopicAndPartition import kafka.message.MessageAndMetadata -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.BeforeAndAfterAll import org.apache.spark._ -class KafkaRDDSuite extends FunSuite with BeforeAndAfterAll { +class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll { private var kafkaTestUtils: KafkaTestUtils = _ diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala index 24699dfc33adb..8ee2cc660f849 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala @@ -23,14 +23,14 @@ import scala.language.postfixOps import scala.util.Random import kafka.serializer.StringDecoder -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.BeforeAndAfterAll import org.scalatest.concurrent.Eventually -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Milliseconds, StreamingContext} -class KafkaStreamSuite extends FunSuite with Eventually with BeforeAndAfterAll { +class KafkaStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfterAll { private var ssc: StreamingContext = _ private var kafkaTestUtils: KafkaTestUtils = _ diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala index 38548dd73b82c..80e2df62de3fe 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala @@ -26,15 +26,15 @@ import scala.util.Random import kafka.serializer.StringDecoder import kafka.utils.{ZKGroupTopicDirs, ZkUtils} -import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} import org.scalatest.concurrent.Eventually -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Milliseconds, StreamingContext} import org.apache.spark.util.Utils -class ReliableKafkaStreamSuite extends FunSuite +class ReliableKafkaStreamSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter with Eventually { private val sparkConf = new SparkConf() diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index 98f95a9a64fa0..7d102e10ab60f 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -41,6 +41,13 @@ ${project.version} provided + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.eclipse.paho org.eclipse.paho.client.mqttv3 diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala index a19a72c58a705..c4bf5aa7869bb 100644 --- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala +++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala @@ -29,7 +29,7 @@ import org.apache.commons.lang3.RandomUtils import org.eclipse.paho.client.mqttv3._ import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.Eventually import org.apache.spark.streaming.{Milliseconds, StreamingContext} @@ -37,10 +37,10 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.ReceiverInputDStream import org.apache.spark.streaming.scheduler.StreamingListener import org.apache.spark.streaming.scheduler.StreamingListenerReceiverStarted -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.util.Utils -class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter { +class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter { private val batchDuration = Milliseconds(500) private val master = "local[2]" diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml index 8b6a8959ac4cf..d28e3e1846d70 100644 --- a/external/twitter/pom.xml +++ b/external/twitter/pom.xml @@ -41,6 +41,13 @@ ${project.version} provided + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.twitter4j twitter4j-stream diff --git a/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala b/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala index 9ee57d7581d85..d9acb568879fe 100644 --- a/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala +++ b/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala @@ -18,16 +18,16 @@ package org.apache.spark.streaming.twitter -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter import twitter4j.Status import twitter4j.auth.{NullAuthorization, Authorization} -import org.apache.spark.Logging +import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.ReceiverInputDStream -class TwitterStreamSuite extends FunSuite with BeforeAndAfter with Logging { +class TwitterStreamSuite extends SparkFunSuite with BeforeAndAfter with Logging { val batchDuration = Seconds(1) diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml index a50d378b34335..9998c11c85171 100644 --- a/external/zeromq/pom.xml +++ b/external/zeromq/pom.xml @@ -41,6 +41,13 @@ ${project.version} provided + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + ${akka.group} akka-zeromq_${scala.binary.version} diff --git a/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala b/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala index a7566e733d891..35d2e62c68480 100644 --- a/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala +++ b/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala @@ -20,13 +20,13 @@ package org.apache.spark.streaming.zeromq import akka.actor.SupervisorStrategy import akka.util.ByteString import akka.zeromq.Subscribe -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.streaming.dstream.ReceiverInputDStream -class ZeroMQStreamSuite extends FunSuite { +class ZeroMQStreamSuite extends SparkFunSuite { val batchDuration = Seconds(1) diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala index 97c3476049289..be8b62d3cc6ba 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala @@ -119,7 +119,7 @@ object KinesisWordCountASL extends Logging { val batchInterval = Milliseconds(2000) // Kinesis checkpoint interval is the interval at which the DynamoDB is updated with information - // on sequence number of records that have been received. Same as batchInterval for this + // on sequence number of records that have been received. Same as batchInterval for this // example. val kinesisCheckpointInterval = batchInterval @@ -145,7 +145,7 @@ object KinesisWordCountASL extends Logging { // Map each word to a (word, 1) tuple so we can reduce by key to count the words val wordCounts = words.map(word => (word, 1)).reduceByKey(_ + _) - + // Print the first 10 wordCounts wordCounts.print() @@ -210,14 +210,14 @@ object KinesisWordProducerASL { val randomWords = List("spark", "you", "are", "my", "father") val totals = scala.collection.mutable.Map[String, Int]() - + // Create the low-level Kinesis Client from the AWS Java SDK. val kinesisClient = new AmazonKinesisClient(new DefaultAWSCredentialsProviderChain()) kinesisClient.setEndpoint(endpoint) println(s"Putting records onto stream $stream and endpoint $endpoint at a rate of" + s" $recordsPerSecond records per second and $wordsPerRecord words per record") - + // Iterate and put records onto the stream per the given recordPerSec and wordsPerRecord for (i <- 1 to 10) { // Generate recordsPerSec records to put onto the stream @@ -255,8 +255,8 @@ object KinesisWordProducerASL { } } -/** - * Utility functions for Spark Streaming examples. +/** + * Utility functions for Spark Streaming examples. * This has been lifted from the examples/ project to remove the circular dependency. */ private[streaming] object StreamingExamples extends Logging { diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala index 1c9b0c218ae18..83a4537559512 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala @@ -23,20 +23,20 @@ import org.apache.spark.util.{Clock, ManualClock, SystemClock} /** * This is a helper class for managing checkpoint clocks. * - * @param checkpointInterval + * @param checkpointInterval * @param currentClock. Default to current SystemClock if none is passed in (mocking purposes) */ private[kinesis] class KinesisCheckpointState( - checkpointInterval: Duration, + checkpointInterval: Duration, currentClock: Clock = new SystemClock()) extends Logging { - + /* Initialize the checkpoint clock using the given currentClock + checkpointInterval millis */ val checkpointClock = new ManualClock() checkpointClock.setTime(currentClock.getTimeMillis() + checkpointInterval.milliseconds) /** - * Check if it's time to checkpoint based on the current time and the derived time + * Check if it's time to checkpoint based on the current time and the derived time * for the next checkpoint * * @return true if it's time to checkpoint diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index 7dd8bfdc2a6db..1a8a4cecc1141 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -44,12 +44,12 @@ case class SerializableAWSCredentials(accessKeyId: String, secretKey: String) * https://github.com/awslabs/amazon-kinesis-client * This is a custom receiver used with StreamingContext.receiverStream(Receiver) as described here: * http://spark.apache.org/docs/latest/streaming-custom-receivers.html - * Instances of this class will get shipped to the Spark Streaming Workers to run within a + * Instances of this class will get shipped to the Spark Streaming Workers to run within a * Spark Executor. * * @param appName Kinesis application name. Kinesis Apps are mapped to Kinesis Streams * by the Kinesis Client Library. If you change the App name or Stream name, - * the KCL will throw errors. This usually requires deleting the backing + * the KCL will throw errors. This usually requires deleting the backing * DynamoDB table with the same name this Kinesis application. * @param streamName Kinesis stream name * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) @@ -87,7 +87,7 @@ private[kinesis] class KinesisReceiver( */ /** - * workerId is used by the KCL should be based on the ip address of the actual Spark Worker + * workerId is used by the KCL should be based on the ip address of the actual Spark Worker * where this code runs (not the driver's IP address.) */ private var workerId: String = null @@ -121,7 +121,7 @@ private[kinesis] class KinesisReceiver( /* * RecordProcessorFactory creates impls of IRecordProcessor. - * IRecordProcessor adapts the KCL to our Spark KinesisReceiver via the + * IRecordProcessor adapts the KCL to our Spark KinesisReceiver via the * IRecordProcessor.processRecords() method. * We're using our custom KinesisRecordProcessor in this case. */ diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala index f65e743c4e2a3..fe9e3a0c793e2 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala @@ -35,9 +35,9 @@ import com.amazonaws.services.kinesis.model.Record /** * Kinesis-specific implementation of the Kinesis Client Library (KCL) IRecordProcessor. * This implementation operates on the Array[Byte] from the KinesisReceiver. - * The Kinesis Worker creates an instance of this KinesisRecordProcessor for each - * shard in the Kinesis stream upon startup. This is normally done in separate threads, - * but the KCLs within the KinesisReceivers will balance themselves out if you create + * The Kinesis Worker creates an instance of this KinesisRecordProcessor for each + * shard in the Kinesis stream upon startup. This is normally done in separate threads, + * but the KCLs within the KinesisReceivers will balance themselves out if you create * multiple Receivers. * * @param receiver Kinesis receiver @@ -69,14 +69,14 @@ private[kinesis] class KinesisRecordProcessor( * and Spark Streaming's Receiver.store(). * * @param batch list of records from the Kinesis stream shard - * @param checkpointer used to update Kinesis when this batch has been processed/stored + * @param checkpointer used to update Kinesis when this batch has been processed/stored * in the DStream */ override def processRecords(batch: List[Record], checkpointer: IRecordProcessorCheckpointer) { if (!receiver.isStopped()) { try { /* - * Notes: + * Notes: * 1) If we try to store the raw ByteBuffer from record.getData(), the Spark Streaming * Receiver.store(ByteBuffer) attempts to deserialize the ByteBuffer using the * internally-configured Spark serializer (kryo, etc). @@ -84,19 +84,19 @@ private[kinesis] class KinesisRecordProcessor( * ourselves from Spark's internal serialization strategy. * 3) For performance, the BlockGenerator is asynchronously queuing elements within its * memory before creating blocks. This prevents the small block scenario, but requires - * that you register callbacks to know when a block has been generated and stored + * that you register callbacks to know when a block has been generated and stored * (WAL is sufficient for storage) before can checkpoint back to the source. */ batch.foreach(record => receiver.store(record.getData().array())) - + logDebug(s"Stored: Worker $workerId stored ${batch.size} records for shardId $shardId") /* - * Checkpoint the sequence number of the last record successfully processed/stored + * Checkpoint the sequence number of the last record successfully processed/stored * in the batch. * In this implementation, we're checkpointing after the given checkpointIntervalMillis. - * Note that this logic requires that processRecords() be called AND that it's time to - * checkpoint. I point this out because there is no background thread running the + * Note that this logic requires that processRecords() be called AND that it's time to + * checkpoint. I point this out because there is no background thread running the * checkpointer. Checkpointing is tested and trigger only when a new batch comes in. * If the worker is shutdown cleanly, checkpoint will happen (see shutdown() below). * However, if the worker dies unexpectedly, a checkpoint may not happen. @@ -130,16 +130,16 @@ private[kinesis] class KinesisRecordProcessor( } } else { /* RecordProcessor has been stopped. */ - logInfo(s"Stopped: The Spark KinesisReceiver has stopped for workerId $workerId" + + logInfo(s"Stopped: The Spark KinesisReceiver has stopped for workerId $workerId" + s" and shardId $shardId. No more records will be processed.") } } /** * Kinesis Client Library is shutting down this Worker for 1 of 2 reasons: - * 1) the stream is resharding by splitting or merging adjacent shards + * 1) the stream is resharding by splitting or merging adjacent shards * (ShutdownReason.TERMINATE) - * 2) the failed or latent Worker has stopped sending heartbeats for whatever reason + * 2) the failed or latent Worker has stopped sending heartbeats for whatever reason * (ShutdownReason.ZOMBIE) * * @param checkpointer used to perform a Kinesis checkpoint for ShutdownReason.TERMINATE @@ -153,7 +153,7 @@ private[kinesis] class KinesisRecordProcessor( * Checkpoint to indicate that all records from the shard have been drained and processed. * It's now OK to read from the new shards that resulted from a resharding event. */ - case ShutdownReason.TERMINATE => + case ShutdownReason.TERMINATE => KinesisRecordProcessor.retryRandom(checkpointer.checkpoint(), 4, 100) /* diff --git a/graphx/pom.xml b/graphx/pom.xml index d38a3aa8256b7..28b41228feb3d 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -40,6 +40,13 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + com.google.guava guava diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala index cc70b396a8dd4..4611a3ace219b 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala @@ -41,14 +41,16 @@ abstract class EdgeRDD[ED]( @transient sc: SparkContext, @transient deps: Seq[Dependency[_]]) extends RDD[Edge[ED]](sc, deps) { + // scalastyle:off structural.type private[graphx] def partitionsRDD: RDD[(PartitionID, EdgePartition[ED, VD])] forSome { type VD } + // scalastyle:on structural.type override protected def getPartitions: Array[Partition] = partitionsRDD.partitions override def compute(part: Partition, context: TaskContext): Iterator[Edge[ED]] = { val p = firstParent[(PartitionID, EdgePartition[ED, _])].iterator(part, context) if (p.hasNext) { - p.next._2.iterator.map(_.copy()) + p.next()._2.iterator.map(_.copy()) } else { Iterator.empty } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/EdgeRDDSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/EdgeRDDSuite.scala index eb1dbe52c2fda..f1ecc9e2219d1 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/EdgeRDDSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/EdgeRDDSuite.scala @@ -17,11 +17,10 @@ package org.apache.spark.graphx -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.storage.StorageLevel -class EdgeRDDSuite extends FunSuite with LocalSparkContext { +class EdgeRDDSuite extends SparkFunSuite with LocalSparkContext { test("cache, getStorageLevel") { // test to see if getStorageLevel returns correct value after caching diff --git a/graphx/src/test/scala/org/apache/spark/graphx/EdgeSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/EdgeSuite.scala index 5a2c73b414279..094a63472eaab 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/EdgeSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/EdgeSuite.scala @@ -17,21 +17,21 @@ package org.apache.spark.graphx -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class EdgeSuite extends FunSuite { +class EdgeSuite extends SparkFunSuite { test ("compare") { // decending order val testEdges: Array[Edge[Int]] = Array( - Edge(0x7FEDCBA987654321L, -0x7FEDCBA987654321L, 1), - Edge(0x2345L, 0x1234L, 1), - Edge(0x1234L, 0x5678L, 1), - Edge(0x1234L, 0x2345L, 1), + Edge(0x7FEDCBA987654321L, -0x7FEDCBA987654321L, 1), + Edge(0x2345L, 0x1234L, 1), + Edge(0x1234L, 0x5678L, 1), + Edge(0x1234L, 0x2345L, 1), Edge(-0x7FEDCBA987654321L, 0x7FEDCBA987654321L, 1) ) // to ascending order val sortedEdges = testEdges.sorted(Edge.lexicographicOrdering[Int]) - + for (i <- 0 until testEdges.length) { assert(sortedEdges(i) == testEdges(testEdges.length - i - 1)) } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala index 68fe83739e399..57a8b95dd12e9 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.graphx -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.graphx.Graph._ import org.apache.spark.graphx.impl.EdgePartition import org.apache.spark.rdd._ -import org.scalatest.FunSuite -class GraphOpsSuite extends FunSuite with LocalSparkContext { +class GraphOpsSuite extends SparkFunSuite with LocalSparkContext { test("joinVertices") { withSpark { sc => diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala index 2b1d8e47326f8..1f5e27d5508b8 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala @@ -17,16 +17,14 @@ package org.apache.spark.graphx -import org.scalatest.FunSuite - -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.graphx.Graph._ import org.apache.spark.graphx.PartitionStrategy._ import org.apache.spark.rdd._ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils -class GraphSuite extends FunSuite with LocalSparkContext { +class GraphSuite extends SparkFunSuite with LocalSparkContext { def starGraph(sc: SparkContext, n: Int): Graph[String, Int] = { Graph.fromEdgeTuples(sc.parallelize((1 to n).map(x => (0: VertexId, x: VertexId)), 3), "v") diff --git a/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala index 490b94429ea1f..8afa2d403b53f 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala @@ -17,12 +17,10 @@ package org.apache.spark.graphx -import org.scalatest.FunSuite - -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.rdd._ -class PregelSuite extends FunSuite with LocalSparkContext { +class PregelSuite extends SparkFunSuite with LocalSparkContext { test("1 iteration") { withSpark { sc => diff --git a/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala index d0a7198d691d7..f1aa685a79c98 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala @@ -17,13 +17,11 @@ package org.apache.spark.graphx -import org.scalatest.FunSuite - -import org.apache.spark.{HashPartitioner, SparkContext} +import org.apache.spark.{HashPartitioner, SparkContext, SparkFunSuite} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel -class VertexRDDSuite extends FunSuite with LocalSparkContext { +class VertexRDDSuite extends SparkFunSuite with LocalSparkContext { private def vertices(sc: SparkContext, n: Int) = { VertexRDD(sc.parallelize((0 to n).map(x => (x.toLong, x)), 5)) diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala index 515f3a9cd02eb..7435647c6d9ee 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala @@ -20,15 +20,13 @@ package org.apache.spark.graphx.impl import scala.reflect.ClassTag import scala.util.Random -import org.scalatest.FunSuite - -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer import org.apache.spark.serializer.KryoSerializer import org.apache.spark.graphx._ -class EdgePartitionSuite extends FunSuite { +class EdgePartitionSuite extends SparkFunSuite { def makeEdgePartition[A: ClassTag](xs: Iterable[(Int, Int, A)]): EdgePartition[A, Int] = { val builder = new EdgePartitionBuilder[A, Int] diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala index fe8304c1cdc32..1203f8959f506 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala @@ -17,15 +17,13 @@ package org.apache.spark.graphx.impl -import org.scalatest.FunSuite - -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer import org.apache.spark.serializer.KryoSerializer import org.apache.spark.graphx._ -class VertexPartitionSuite extends FunSuite { +class VertexPartitionSuite extends SparkFunSuite { test("isDefined, filter") { val vp = VertexPartition(Iterator((0L, 1), (1L, 1))).filter { (vid, attr) => vid == 0 } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala index accccfc232cd3..c965a6eb8df13 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala @@ -17,16 +17,14 @@ package org.apache.spark.graphx.lib -import org.scalatest.FunSuite - -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.SparkContext._ import org.apache.spark.graphx._ import org.apache.spark.graphx.util.GraphGenerators import org.apache.spark.rdd._ -class ConnectedComponentsSuite extends FunSuite with LocalSparkContext { +class ConnectedComponentsSuite extends SparkFunSuite with LocalSparkContext { test("Grid Connected Components") { withSpark { sc => diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/LabelPropagationSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/LabelPropagationSuite.scala index 61fd0c4605568..808877f0590f8 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/LabelPropagationSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/LabelPropagationSuite.scala @@ -17,11 +17,10 @@ package org.apache.spark.graphx.lib -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.graphx._ -class LabelPropagationSuite extends FunSuite with LocalSparkContext { +class LabelPropagationSuite extends SparkFunSuite with LocalSparkContext { test("Label Propagation") { withSpark { sc => // Construct a graph with two cliques connected by a single edge diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala index 39c6ace912b00..45f1e3011035e 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.graphx.lib -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.graphx._ import org.apache.spark.graphx.util.GraphGenerators @@ -57,7 +56,7 @@ object GridPageRank { } -class PageRankSuite extends FunSuite with LocalSparkContext { +class PageRankSuite extends SparkFunSuite with LocalSparkContext { def compareRanks(a: VertexRDD[Double], b: VertexRDD[Double]): Double = { a.leftJoin(b) { case (id, a, bOpt) => (a - bOpt.getOrElse(0.0)) * (a - bOpt.getOrElse(0.0)) } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala index 7bd6b7f3c4ab2..2991438f5e57e 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.graphx.lib -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.graphx._ -class SVDPlusPlusSuite extends FunSuite with LocalSparkContext { +class SVDPlusPlusSuite extends SparkFunSuite with LocalSparkContext { test("Test SVD++ with mean square error on training set") { withSpark { sc => diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala index f2c38e79c452c..d7eaa70ce6407 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala @@ -17,16 +17,14 @@ package org.apache.spark.graphx.lib -import org.scalatest.FunSuite - -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.SparkContext._ import org.apache.spark.graphx._ import org.apache.spark.graphx.lib._ import org.apache.spark.graphx.util.GraphGenerators import org.apache.spark.rdd._ -class ShortestPathsSuite extends FunSuite with LocalSparkContext { +class ShortestPathsSuite extends SparkFunSuite with LocalSparkContext { test("Shortest Path Computations") { withSpark { sc => diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala index 1f658c371ffcf..d6b03208180db 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala @@ -17,16 +17,14 @@ package org.apache.spark.graphx.lib -import org.scalatest.FunSuite - -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.SparkContext._ import org.apache.spark.graphx._ import org.apache.spark.graphx.util.GraphGenerators import org.apache.spark.rdd._ -class StronglyConnectedComponentsSuite extends FunSuite with LocalSparkContext { +class StronglyConnectedComponentsSuite extends SparkFunSuite with LocalSparkContext { test("Island Strongly Connected Components") { withSpark { sc => diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala index 79bf4e6cd18ee..c47552cf3a3bd 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.graphx.lib -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.graphx._ import org.apache.spark.graphx.PartitionStrategy.RandomVertexCut -class TriangleCountSuite extends FunSuite with LocalSparkContext { +class TriangleCountSuite extends SparkFunSuite with LocalSparkContext { test("Count a single triangle") { withSpark { sc => diff --git a/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala index f3b3738db0dad..186d0cc2a977b 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.graphx.util -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class BytecodeUtilsSuite extends FunSuite { +class BytecodeUtilsSuite extends SparkFunSuite { import BytecodeUtilsSuite.TestClass diff --git a/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala index 8d9c8ddccbb3c..32e0c841c6997 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala @@ -17,11 +17,10 @@ package org.apache.spark.graphx.util -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.graphx.LocalSparkContext -class GraphGeneratorsSuite extends FunSuite with LocalSparkContext { +class GraphGeneratorsSuite extends SparkFunSuite with LocalSparkContext { test("GraphGenerators.generateRandomEdges") { val src = 5 diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java index 33fd813f7a86c..33d65d13f0d25 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java @@ -296,6 +296,9 @@ Properties loadPropertiesFile() throws IOException { try { fd = new FileInputStream(propsFile); props.load(new InputStreamReader(fd, "UTF-8")); + for (Map.Entry e : props.entrySet()) { + e.setValue(e.getValue().toString().trim()); + } } finally { if (fd != null) { try { diff --git a/mllib/pom.xml b/mllib/pom.xml index 0c07ca1a62fd3..65c647a91d192 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -40,6 +40,13 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-streaming_${scala.binary.version} diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index d8592eb2d947d..62f4b51f770e9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -208,7 +208,7 @@ private[ml] object GBTClassificationModel { require(oldModel.algo == OldAlgo.Classification, "Cannot convert GradientBoostedTreesModel" + s" with algo=${oldModel.algo} (old API) to GBTClassificationModel (new API).") val newTrees = oldModel.trees.map { tree => - // parent, fittingParamMap for each tree is null since there are no good ways to set these. + // parent for each tree is null since there is no good way to set this. DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures) } val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtc") diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index b8c7f3c5bc3b9..825f9ed1b54b2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -37,11 +37,13 @@ import org.apache.spark.storage.StorageLevel */ private[ml] trait OneVsRestParams extends PredictorParams { + // scalastyle:off structural.type type ClassifierType = Classifier[F, E, M] forSome { type F type M <: ClassificationModel[F, M] type E <: Classifier[F, E, M] } + // scalastyle:on structural.type /** * param for the base binary classifier that we reduce multiclass classification into. @@ -129,6 +131,7 @@ final class OneVsRestModel private[ml] ( // output label and label metadata as prediction val labelUdf = callUDF(label, DoubleType, col(accColName)) aggregatedDataset.withColumn($(predictionCol), labelUdf.as($(predictionCol), labelMetadata)) + .drop(accColName) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 67600ebd7b38e..852a67e066322 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -170,7 +170,7 @@ private[ml] object RandomForestClassificationModel { require(oldModel.algo == OldAlgo.Classification, "Cannot convert RandomForestModel" + s" with algo=${oldModel.algo} (old API) to RandomForestClassificationModel (new API).") val newTrees = oldModel.trees.map { tree => - // parent, fittingParamMap for each tree is null since there are no good ways to set these. + // parent for each tree is null since there is no good way to set this. DecisionTreeClassificationModel.fromOld(tree, null, categoricalFeatures) } val uid = if (parent != null) parent.uid else Identifiable.randomUID("rfc") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index eb6ec49f854be..8f34878c8d329 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -17,94 +17,152 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkException import org.apache.spark.annotation.Experimental -import org.apache.spark.ml.UnaryTransformer -import org.apache.spark.ml.attribute.{Attribute, BinaryAttribute, NominalAttribute} +import org.apache.spark.ml.Transformer +import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util.{Identifiable, SchemaUtils} -import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} -import org.apache.spark.sql.types.{DataType, DoubleType, StructType} +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.types.{DoubleType, StructType} /** * :: Experimental :: - * A one-hot encoder that maps a column of label indices to a column of binary vectors, with - * at most a single one-value. By default, the binary vector has an element for each category, so - * with 5 categories, an input value of 2.0 would map to an output vector of - * (0.0, 0.0, 1.0, 0.0, 0.0). If includeFirst is set to false, the first category is omitted, so the - * output vector for the previous example would be (0.0, 1.0, 0.0, 0.0) and an input value - * of 0.0 would map to a vector of all zeros. Including the first category makes the vector columns - * linearly dependent because they sum up to one. + * A one-hot encoder that maps a column of category indices to a column of binary vectors, with + * at most a single one-value per row that indicates the input category index. + * For example with 5 categories, an input value of 2.0 would map to an output vector of + * `[0.0, 0.0, 1.0, 0.0]`. + * The last category is not included by default (configurable via [[OneHotEncoder!.dropLast]] + * because it makes the vector entries sum up to one, and hence linearly dependent. + * So an input value of 4.0 maps to `[0.0, 0.0, 0.0, 0.0]`. + * Note that this is different from scikit-learn's OneHotEncoder, which keeps all categories. + * The output vectors are sparse. + * + * @see [[StringIndexer]] for converting categorical values into category indices */ @Experimental -class OneHotEncoder(override val uid: String) - extends UnaryTransformer[Double, Vector, OneHotEncoder] with HasInputCol with HasOutputCol { +class OneHotEncoder(override val uid: String) extends Transformer + with HasInputCol with HasOutputCol { def this() = this(Identifiable.randomUID("oneHot")) /** - * Whether to include a component in the encoded vectors for the first category, defaults to true. + * Whether to drop the last category in the encoded vector (default: true) * @group param */ - final val includeFirst: BooleanParam = - new BooleanParam(this, "includeFirst", "include first category") - setDefault(includeFirst -> true) - - private var categories: Array[String] = _ + final val dropLast: BooleanParam = + new BooleanParam(this, "dropLast", "whether to drop the last category") + setDefault(dropLast -> true) /** @group setParam */ - def setIncludeFirst(value: Boolean): this.type = set(includeFirst, value) + def setDropLast(value: Boolean): this.type = set(dropLast, value) /** @group setParam */ - override def setInputCol(value: String): this.type = set(inputCol, value) + def setInputCol(value: String): this.type = set(inputCol, value) /** @group setParam */ - override def setOutputCol(value: String): this.type = set(outputCol, value) + def setOutputCol(value: String): this.type = set(outputCol, value) override def transformSchema(schema: StructType): StructType = { - SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType) - val inputFields = schema.fields + val is = "_is_" + val inputColName = $(inputCol) val outputColName = $(outputCol) - require(inputFields.forall(_.name != $(outputCol)), - s"Output column ${$(outputCol)} already exists.") - val inputColAttr = Attribute.fromStructField(schema($(inputCol))) - categories = inputColAttr match { + SchemaUtils.checkColumnType(schema, inputColName, DoubleType) + val inputFields = schema.fields + require(!inputFields.exists(_.name == outputColName), + s"Output column $outputColName already exists.") + + val inputAttr = Attribute.fromStructField(schema(inputColName)) + val outputAttrNames: Option[Array[String]] = inputAttr match { case nominal: NominalAttribute => - nominal.values.getOrElse((0 until nominal.numValues.get).map(_.toString).toArray) - case binary: BinaryAttribute => binary.values.getOrElse(Array("0", "1")) + if (nominal.values.isDefined) { + nominal.values.map(_.map(v => inputColName + is + v)) + } else if (nominal.numValues.isDefined) { + nominal.numValues.map(n => Array.tabulate(n)(i => inputColName + is + i)) + } else { + None + } + case binary: BinaryAttribute => + if (binary.values.isDefined) { + binary.values.map(_.map(v => inputColName + is + v)) + } else { + Some(Array.tabulate(2)(i => inputColName + is + i)) + } + case _: NumericAttribute => + throw new RuntimeException( + s"The input column $inputColName cannot be numeric.") case _ => - throw new SparkException(s"OneHotEncoder input column ${$(inputCol)} is not nominal") + None // optimistic about unknown attributes } - val attrValues = (if ($(includeFirst)) categories else categories.drop(1)).toArray - val attr = NominalAttribute.defaultAttr.withName(outputColName).withValues(attrValues) - val outputFields = inputFields :+ attr.toStructField() + val filteredOutputAttrNames = outputAttrNames.map { names => + if ($(dropLast)) { + require(names.length > 1, + s"The input column $inputColName should have at least two distinct values.") + names.dropRight(1) + } else { + names + } + } + + val outputAttrGroup = if (filteredOutputAttrNames.isDefined) { + val attrs: Array[Attribute] = filteredOutputAttrNames.get.map { name => + BinaryAttribute.defaultAttr.withName(name) + } + new AttributeGroup($(outputCol), attrs) + } else { + new AttributeGroup($(outputCol)) + } + + val outputFields = inputFields :+ outputAttrGroup.toStructField() StructType(outputFields) } - protected override def createTransformFunc(): (Double) => Vector = { - val first = $(includeFirst) - val vecLen = if (first) categories.length else categories.length - 1 + override def transform(dataset: DataFrame): DataFrame = { + // schema transformation + val is = "_is_" + val inputColName = $(inputCol) + val outputColName = $(outputCol) + val shouldDropLast = $(dropLast) + var outputAttrGroup = AttributeGroup.fromStructField( + transformSchema(dataset.schema)(outputColName)) + if (outputAttrGroup.size < 0) { + // If the number of attributes is unknown, we check the values from the input column. + val numAttrs = dataset.select(col(inputColName).cast(DoubleType)).map(_.getDouble(0)) + .aggregate(0.0)( + (m, x) => { + assert(x >=0.0 && x == x.toInt, + s"Values from column $inputColName must be indices, but got $x.") + math.max(m, x) + }, + (m0, m1) => { + math.max(m0, m1) + } + ).toInt + 1 + val outputAttrNames = Array.tabulate(numAttrs)(i => inputColName + is + i) + val filtered = if (shouldDropLast) outputAttrNames.dropRight(1) else outputAttrNames + val outputAttrs: Array[Attribute] = + filtered.map(name => BinaryAttribute.defaultAttr.withName(name)) + outputAttrGroup = new AttributeGroup(outputColName, outputAttrs) + } + val metadata = outputAttrGroup.toMetadata() + + // data transformation + val size = outputAttrGroup.size val oneValue = Array(1.0) val emptyValues = Array[Double]() val emptyIndices = Array[Int]() - label: Double => { - val values = if (first || label != 0.0) oneValue else emptyValues - val indices = if (first) { - Array(label.toInt) - } else if (label != 0.0) { - Array(label.toInt - 1) + val encode = udf { label: Double => + if (label < size) { + Vectors.sparse(size, Array(label.toInt), oneValue) } else { - emptyIndices + Vectors.sparse(size, emptyIndices, emptyValues) } - Vectors.sparse(vecLen, indices, values) } - } - /** - * Returns the data type of the output column. - */ - protected def outputDataType: DataType = new VectorUDT + dataset.select(col("*"), encode(col(inputColName).cast(DoubleType)).as(outputColName, metadata)) + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index fdd2494fc87a6..b0fd06d84fdb3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -35,13 +35,13 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with /** * Centers the data with mean before scaling. - * It will build a dense output, so this does not work on sparse input + * It will build a dense output, so this does not work on sparse input * and will raise an exception. * Default: false * @group param */ val withMean: BooleanParam = new BooleanParam(this, "withMean", "Center data with mean") - + /** * Scales the data to unit standard deviation. * Default: true @@ -68,13 +68,13 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - + /** @group setParam */ def setWithMean(value: Boolean): this.type = set(withMean, value) - + /** @group setParam */ def setWithStd(value: Boolean): this.type = set(withStd, value) - + override def fit(dataset: DataFrame): StandardScalerModel = { transformSchema(dataset.schema, logging = true) val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 69f4f5414c8c6..b7e374bb6cb49 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -198,7 +198,7 @@ private[ml] object GBTRegressionModel { require(oldModel.algo == OldAlgo.Regression, "Cannot convert GradientBoostedTreesModel" + s" with algo=${oldModel.algo} (old API) to GBTRegressionModel (new API).") val newTrees = oldModel.trees.map { tree => - // parent, fittingParamMap for each tree is null since there are no good ways to set these. + // parent for each tree is null since there is no good way to set this. DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures) } val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtr") diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 7c40db1a40040..fe2a71a331694 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -321,7 +321,7 @@ private class LeastSquaresAggregator( } (weightsArray, -sum + labelMean / labelStd, weightsArray.length) } - + private val effectiveWeightsVector = Vectors.dense(effectiveWeightsArray) private val gradientSumArray = Array.ofDim[Double](dim) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index ae767a17329d2..49a1f7ce8c995 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -152,7 +152,7 @@ private[ml] object RandomForestRegressionModel { require(oldModel.algo == OldAlgo.Regression, "Cannot convert RandomForestModel" + s" with algo=${oldModel.algo} (old API) to RandomForestRegressionModel (new API).") val newTrees = oldModel.trees.map { tree => - // parent, fittingParamMap for each tree is null since there are no good ways to set these. + // parent for each tree is null since there is no good way to set this. DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures) } new RandomForestRegressionModel(parent.uid, newTrees) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 65f30fdba7393..16f3131796709 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -399,7 +399,7 @@ private[python] class PythonMLLibAPI extends Serializable { val sigma = si.map(_.asInstanceOf[DenseMatrix]) val gaussians = Array.tabulate(weight.length){ i => new MultivariateGaussian(mean(i), sigma(i)) - } + } val model = new GaussianMixtureModel(weight, gaussians) model.predictSoft(data).map(Vectors.dense) } @@ -494,7 +494,7 @@ private[python] class PythonMLLibAPI extends Serializable { def normalizeVector(p: Double, rdd: JavaRDD[Vector]): JavaRDD[Vector] = { new Normalizer(p).transform(rdd) } - + /** * Java stub for StandardScaler.fit(). This stub returns a * handle to the Java object instead of the content of the Java object. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala index e9a23e40cc790..70b0e40948e51 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala @@ -36,11 +36,11 @@ import org.apache.spark.util.Utils * independent Gaussian distributions with associated "mixing" weights * specifying each's contribution to the composite. * - * Given a set of sample points, this class will maximize the log-likelihood - * for a mixture of k Gaussians, iterating until the log-likelihood changes by + * Given a set of sample points, this class will maximize the log-likelihood + * for a mixture of k Gaussians, iterating until the log-likelihood changes by * less than convergenceTol, or until it has reached the max number of iterations. * While this process is generally guaranteed to converge, it is not guaranteed - * to find a global optimum. + * to find a global optimum. * * Note: For high-dimensional data (with many features), this algorithm may perform poorly. * This is due to high-dimensional data (a) making it difficult to cluster at all (based @@ -53,24 +53,24 @@ import org.apache.spark.util.Utils */ @Experimental class GaussianMixture private ( - private var k: Int, - private var convergenceTol: Double, + private var k: Int, + private var convergenceTol: Double, private var maxIterations: Int, private var seed: Long) extends Serializable { - + /** * Constructs a default instance. The default parameters are {k: 2, convergenceTol: 0.01, * maxIterations: 100, seed: random}. */ def this() = this(2, 0.01, 100, Utils.random.nextLong()) - + // number of samples per cluster to use when initializing Gaussians private val nSamples = 5 - - // an initializing GMM can be provided rather than using the + + // an initializing GMM can be provided rather than using the // default random starting point private var initialModel: Option[GaussianMixtureModel] = None - + /** Set the initial GMM starting point, bypassing the random initialization. * You must call setK() prior to calling this method, and the condition * (model.k == this.k) must be met; failure will result in an IllegalArgumentException @@ -83,37 +83,37 @@ class GaussianMixture private ( } this } - + /** Return the user supplied initial GMM, if supplied */ def getInitialModel: Option[GaussianMixtureModel] = initialModel - + /** Set the number of Gaussians in the mixture model. Default: 2 */ def setK(k: Int): this.type = { this.k = k this } - + /** Return the number of Gaussians in the mixture model */ def getK: Int = k - + /** Set the maximum number of iterations to run. Default: 100 */ def setMaxIterations(maxIterations: Int): this.type = { this.maxIterations = maxIterations this } - + /** Return the maximum number of iterations to run */ def getMaxIterations: Int = maxIterations - + /** - * Set the largest change in log-likelihood at which convergence is + * Set the largest change in log-likelihood at which convergence is * considered to have occurred. */ def setConvergenceTol(convergenceTol: Double): this.type = { this.convergenceTol = convergenceTol this } - + /** * Return the largest change in log-likelihood at which convergence is * considered to have occurred. @@ -132,41 +132,41 @@ class GaussianMixture private ( /** Perform expectation maximization */ def run(data: RDD[Vector]): GaussianMixtureModel = { val sc = data.sparkContext - + // we will operate on the data as breeze data val breezeData = data.map(_.toBreeze).cache() - + // Get length of the input vectors val d = breezeData.first().length - + // Determine initial weights and corresponding Gaussians. // If the user supplied an initial GMM, we use those values, otherwise // we start with uniform weights, a random mean from the data, and // diagonal covariance matrices using component variances - // derived from the samples + // derived from the samples val (weights, gaussians) = initialModel match { case Some(gmm) => (gmm.weights, gmm.gaussians) - + case None => { val samples = breezeData.takeSample(withReplacement = true, k * nSamples, seed) - (Array.fill(k)(1.0 / k), Array.tabulate(k) { i => + (Array.fill(k)(1.0 / k), Array.tabulate(k) { i => val slice = samples.view(i * nSamples, (i + 1) * nSamples) - new MultivariateGaussian(vectorMean(slice), initCovariance(slice)) + new MultivariateGaussian(vectorMean(slice), initCovariance(slice)) }) } } - - var llh = Double.MinValue // current log-likelihood + + var llh = Double.MinValue // current log-likelihood var llhp = 0.0 // previous log-likelihood - + var iter = 0 while (iter < maxIterations && math.abs(llh-llhp) > convergenceTol) { // create and broadcast curried cluster contribution function val compute = sc.broadcast(ExpectationSum.add(weights, gaussians)_) - + // aggregate the cluster contribution for all sample points val sums = breezeData.aggregate(ExpectationSum.zero(k, d))(compute.value, _ += _) - + // Create new distributions based on the partial assignments // (often referred to as the "M" step in literature) val sumWeights = sums.weights.sum @@ -179,22 +179,22 @@ class GaussianMixture private ( gaussians(i) = new MultivariateGaussian(mu, sums.sigmas(i) / sums.weights(i)) i = i + 1 } - + llhp = llh // current becomes previous llh = sums.logLikelihood // this is the freshly computed log-likelihood iter += 1 - } - + } + new GaussianMixtureModel(weights, gaussians) } - + /** Average of dense breeze vectors */ private def vectorMean(x: IndexedSeq[BV[Double]]): BDV[Double] = { val v = BDV.zeros[Double](x(0).length) x.foreach(xi => v += xi) - v / x.length.toDouble + v / x.length.toDouble } - + /** * Construct matrix where diagonal entries are element-wise * variance of input vectors (computes biased variance) @@ -210,14 +210,14 @@ class GaussianMixture private ( // companion class to provide zero constructor for ExpectationSum private object ExpectationSum { def zero(k: Int, d: Int): ExpectationSum = { - new ExpectationSum(0.0, Array.fill(k)(0.0), + new ExpectationSum(0.0, Array.fill(k)(0.0), Array.fill(k)(BDV.zeros(d)), Array.fill(k)(BreezeMatrix.zeros(d, d))) } - + // compute cluster contributions for each input point // (U, T) => U for aggregation def add( - weights: Array[Double], + weights: Array[Double], dists: Array[MultivariateGaussian]) (sums: ExpectationSum, x: BV[Double]): ExpectationSum = { val p = weights.zip(dists).map { @@ -235,7 +235,7 @@ private object ExpectationSum { i = i + 1 } sums - } + } } // Aggregation class for partial expectation results @@ -244,9 +244,9 @@ private class ExpectationSum( val weights: Array[Double], val means: Array[BDV[Double]], val sigmas: Array[BreezeMatrix[Double]]) extends Serializable { - + val k = weights.length - + def +=(x: ExpectationSum): ExpectationSum = { var i = 0 while (i < k) { @@ -257,5 +257,5 @@ private class ExpectationSum( } logLikelihood += x.logLikelihood this - } + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index 86353aed81156..5fc2cb1b62d33 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -34,10 +34,10 @@ import org.apache.spark.sql.{SQLContext, Row} /** * :: Experimental :: * - * Multivariate Gaussian Mixture Model (GMM) consisting of k Gaussians, where points - * are drawn from each Gaussian i=1..k with probability w(i); mu(i) and sigma(i) are - * the respective mean and covariance for each Gaussian distribution i=1..k. - * + * Multivariate Gaussian Mixture Model (GMM) consisting of k Gaussians, where points + * are drawn from each Gaussian i=1..k with probability w(i); mu(i) and sigma(i) are + * the respective mean and covariance for each Gaussian distribution i=1..k. + * * @param weights Weights for each Gaussian distribution in the mixture, where weights(i) is * the weight for Gaussian i, and weights.sum == 1 * @param gaussians Array of MultivariateGaussian where gaussians(i) represents @@ -45,9 +45,9 @@ import org.apache.spark.sql.{SQLContext, Row} */ @Experimental class GaussianMixtureModel( - val weights: Array[Double], + val weights: Array[Double], val gaussians: Array[MultivariateGaussian]) extends Serializable with Saveable{ - + require(weights.length == gaussians.length, "Length of weight and Gaussian arrays must match") override protected def formatVersion = "1.0" @@ -64,20 +64,20 @@ class GaussianMixtureModel( val responsibilityMatrix = predictSoft(points) responsibilityMatrix.map(r => r.indexOf(r.max)) } - + /** * Given the input vectors, return the membership value of each vector - * to all mixture components. + * to all mixture components. */ def predictSoft(points: RDD[Vector]): RDD[Array[Double]] = { val sc = points.sparkContext val bcDists = sc.broadcast(gaussians) val bcWeights = sc.broadcast(weights) - points.map { x => + points.map { x => computeSoftAssignments(x.toBreeze.toDenseVector, bcDists.value, bcWeights.value, k) } } - + /** * Compute the partial assignments for each vector */ @@ -89,7 +89,7 @@ class GaussianMixtureModel( val p = weights.zip(dists).map { case (weight, dist) => MLUtils.EPSILON + weight * dist.pdf(pt) } - val pSum = p.sum + val pSum = p.sum for (i <- 0 until k) { p(i) /= pSum } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala index 1ed01c9d8ba0b..e7a243f854e33 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala @@ -121,7 +121,7 @@ class PowerIterationClustering private[clustering] ( import org.apache.spark.mllib.clustering.PowerIterationClustering._ /** Constructs a PIC instance with default parameters: {k: 2, maxIterations: 100, - * initMode: "random"}. + * initMode: "random"}. */ def this() = this(k = 2, maxIterations = 100, initMode = "random") @@ -243,7 +243,7 @@ object PowerIterationClustering extends Logging { /** * Generates random vertex properties (v0) to start power iteration. - * + * * @param g a graph representing the normalized affinity matrix (W) * @return a graph with edges representing W and vertices representing a random vector * with unit 1-norm @@ -266,7 +266,7 @@ object PowerIterationClustering extends Logging { * Generates the degree vector as the vertex properties (v0) to start power iteration. * It is not exactly the node degrees but just the normalized sum similarities. Call it * as degree vector because it is used in the PIC paper. - * + * * @param g a graph representing the normalized affinity matrix (W) * @return a graph with edges representing W and vertices representing the degree vector */ @@ -276,7 +276,7 @@ object PowerIterationClustering extends Logging { val v0 = g.vertices.mapValues(_ / sum) GraphImpl.fromExistingRDDs(VertexRDD(v0), g.edges) } - + /** * Runs power iteration. * @param g input graph with edges representing the normalized affinity matrix (W) and vertices diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala index 812014a041719..c21e4fe7dc9b6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala @@ -178,7 +178,7 @@ class StreamingKMeans( /** Set the decay factor directly (for forgetful algorithms). */ def setDecayFactor(a: Double): this.type = { - this.decayFactor = decayFactor + this.decayFactor = a this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala index 9cc2d0ffcab7d..5f8c1dea237b4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -108,7 +108,7 @@ class ChiSqSelectorModel (val selectedFeatures: Array[Int]) extends VectorTransf * (ordered by statistic value descending) */ @Experimental -class ChiSqSelector (val numTopFeatures: Int) { +class ChiSqSelector (val numTopFeatures: Int) extends Serializable { /** * Returns a ChiSquared feature selector. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 466ae95859b82..51546d41c36a6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -42,7 +42,7 @@ import org.apache.spark.util.random.XORShiftRandom import org.apache.spark.sql.{SQLContext, Row} /** - * Entry in vocabulary + * Entry in vocabulary */ private case class VocabWord( var word: String, @@ -56,18 +56,18 @@ private case class VocabWord( * :: Experimental :: * Word2Vec creates vector representation of words in a text corpus. * The algorithm first constructs a vocabulary from the corpus - * and then learns vector representation of words in the vocabulary. - * The vector representation can be used as features in + * and then learns vector representation of words in the vocabulary. + * The vector representation can be used as features in * natural language processing and machine learning algorithms. - * - * We used skip-gram model in our implementation and hierarchical softmax + * + * We used skip-gram model in our implementation and hierarchical softmax * method to train the model. The variable names in the implementation * matches the original C implementation. * - * For original C implementation, see https://code.google.com/p/word2vec/ - * For research papers, see + * For original C implementation, see https://code.google.com/p/word2vec/ + * For research papers, see * Efficient Estimation of Word Representations in Vector Space - * and + * and * Distributed Representations of Words and Phrases and their Compositionality. */ @Experimental @@ -79,7 +79,7 @@ class Word2Vec extends Serializable with Logging { private var numIterations = 1 private var seed = Utils.random.nextLong() private var minCount = 5 - + /** * Sets vector size (default: 100). */ @@ -122,15 +122,15 @@ class Word2Vec extends Serializable with Logging { this } - /** - * Sets minCount, the minimum number of times a token must appear to be included in the word2vec + /** + * Sets minCount, the minimum number of times a token must appear to be included in the word2vec * model's vocabulary (default: 5). */ def setMinCount(minCount: Int): this.type = { this.minCount = minCount this } - + private val EXP_TABLE_SIZE = 1000 private val MAX_EXP = 6 private val MAX_CODE_LENGTH = 40 @@ -150,13 +150,13 @@ class Word2Vec extends Serializable with Logging { .map(x => VocabWord( x._1, x._2, - new Array[Int](MAX_CODE_LENGTH), - new Array[Int](MAX_CODE_LENGTH), + new Array[Int](MAX_CODE_LENGTH), + new Array[Int](MAX_CODE_LENGTH), 0)) .filter(_.cn >= minCount) .collect() .sortWith((a, b) => a.cn > b.cn) - + vocabSize = vocab.length require(vocabSize > 0, "The vocabulary size should be > 0. You may need to check " + "the setting of minCount, which could be large enough to remove all your words in sentences.") @@ -198,8 +198,8 @@ class Word2Vec extends Serializable with Logging { } var pos1 = vocabSize - 1 var pos2 = vocabSize - - var min1i = 0 + + var min1i = 0 var min2i = 0 a = 0 @@ -268,15 +268,15 @@ class Word2Vec extends Serializable with Logging { val words = dataset.flatMap(x => x) learnVocab(words) - + createBinaryTree() - + val sc = dataset.context val expTable = sc.broadcast(createExpTable()) val bcVocab = sc.broadcast(vocab) val bcVocabHash = sc.broadcast(vocabHash) - + val sentences: RDD[Array[Int]] = words.mapPartitions { iter => new Iterator[Array[Int]] { def hasNext: Boolean = iter.hasNext @@ -297,7 +297,7 @@ class Word2Vec extends Serializable with Logging { } } } - + val newSentences = sentences.repartition(numPartitions).cache() val initRandom = new XORShiftRandom(seed) @@ -402,7 +402,7 @@ class Word2Vec extends Serializable with Logging { } } newSentences.unpersist() - + val word2VecMap = mutable.HashMap.empty[String, Array[Float]] var i = 0 while (i < vocabSize) { @@ -480,7 +480,7 @@ class Word2VecModel private[mllib] ( /** * Transforms a word to its vector representation - * @param word a word + * @param word a word * @return vector representation of word */ def transform(word: String): Vector = { @@ -495,7 +495,7 @@ class Word2VecModel private[mllib] ( /** * Find synonyms of a word * @param word a word - * @param num number of synonyms to find + * @param num number of synonyms to find * @return array of (word, cosineSimilarity) */ def findSynonyms(word: String, num: Int): Array[(String, Double)] = { @@ -506,7 +506,7 @@ class Word2VecModel private[mllib] ( /** * Find synonyms of the vector representation of a word * @param vector vector representation of a word - * @param num number of synonyms to find + * @param num number of synonyms to find * @return array of (word, cosineSimilarity) */ def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala index ec38529cf8fae..557119f7b1cd1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -228,7 +228,7 @@ private[spark] object BLAS extends Serializable with Logging { } _nativeBLAS } - + /** * A := alpha * x * x^T^ + A * @param alpha a real scalar that will be multiplied to x * x^T^. @@ -264,7 +264,7 @@ private[spark] object BLAS extends Serializable with Logging { j += 1 } i += 1 - } + } } private def syr(alpha: Double, x: SparseVector, A: DenseMatrix) { @@ -505,7 +505,7 @@ private[spark] object BLAS extends Serializable with Logging { nativeBLAS.dgemv(tStrA, mA, nA, alpha, A.values, mA, x.values, 1, beta, y.values, 1) } - + /** * y := alpha * A * x + beta * y * For `DenseMatrix` A and `SparseVector` x. @@ -557,7 +557,7 @@ private[spark] object BLAS extends Serializable with Logging { } } } - + /** * y := alpha * A * x + beta * y * For `SparseMatrix` A and `SparseVector` x. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala index 866936aa4f118..ae3ba3099c878 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala @@ -81,7 +81,7 @@ private[mllib] object EigenValueDecomposition { require(n * ncv.toLong <= Integer.MAX_VALUE && ncv * (ncv.toLong + 8) <= Integer.MAX_VALUE, s"k = $k and/or n = $n are too large to compute an eigendecomposition") - + var ido = new intW(0) var info = new intW(0) var resid = new Array[Double](n) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala index 34b447584e521..622b53a252ac5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala @@ -27,10 +27,10 @@ import org.apache.spark.mllib.regression.GeneralizedLinearModel * PMML Model Export for GeneralizedLinearModel class with binary ClassificationModel */ private[mllib] class BinaryClassificationPMMLModelExport( - model : GeneralizedLinearModel, + model : GeneralizedLinearModel, description : String, normalizationMethod : RegressionNormalizationMethodType, - threshold: Double) + threshold: Double) extends PMMLModelExport { populateBinaryClassificationPMML() @@ -72,7 +72,7 @@ private[mllib] class BinaryClassificationPMMLModelExport( .withUsageType(FieldUsageType.ACTIVE)) regressionTableYES.withNumericPredictors(new NumericPredictor(fields(i), model.weights(i))) } - + // add target field val targetField = FieldName.create("target") dataDictionary @@ -80,9 +80,9 @@ private[mllib] class BinaryClassificationPMMLModelExport( miningSchema .withMiningFields(new MiningField(targetField) .withUsageType(FieldUsageType.TARGET)) - + dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size) - + pmml.setDataDictionary(dataDictionary) pmml.withModels(regressionModel) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala index ebdeae50bb32f..c5fdecd3ca17f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala @@ -25,7 +25,7 @@ import scala.beans.BeanProperty import org.dmg.pmml.{Application, Header, PMML, Timestamp} private[mllib] trait PMMLModelExport { - + /** * Holder of the exported model in PMML format */ @@ -33,7 +33,7 @@ private[mllib] trait PMMLModelExport { val pmml: PMML = new PMML setHeader(pmml) - + private def setHeader(pmml: PMML): Unit = { val version = getClass.getPackage.getImplementationVersion val app = new Application().withName("Apache Spark MLlib").withVersion(version) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala index c16e83d6a067d..29bd689e1185a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala @@ -27,9 +27,9 @@ import org.apache.spark.mllib.regression.LinearRegressionModel import org.apache.spark.mllib.regression.RidgeRegressionModel private[mllib] object PMMLModelExportFactory { - + /** - * Factory object to help creating the necessary PMMLModelExport implementation + * Factory object to help creating the necessary PMMLModelExport implementation * taking as input the machine learning model (for example KMeansModel). */ def createPMMLModelExport(model: Any): PMMLModelExport = { @@ -44,7 +44,7 @@ private[mllib] object PMMLModelExportFactory { new GeneralizedLinearPMMLModelExport(lasso, "lasso regression") case svm: SVMModel => new BinaryClassificationPMMLModelExport( - svm, "linear SVM", RegressionNormalizationMethodType.NONE, + svm, "linear SVM", RegressionNormalizationMethodType.NONE, svm.getThreshold.getOrElse(0.0)) case logistic: LogisticRegressionModel => if (logistic.numClasses == 2) { @@ -60,5 +60,5 @@ private[mllib] object PMMLModelExportFactory { "PMML Export not supported for model: " + model.getClass.getName) } } - + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala index 7db5a14fd45a5..174d5e0f6c9f0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala @@ -234,7 +234,7 @@ object RandomRDDs { * * @param sc SparkContext used to create the RDD. * @param shape shape parameter (> 0) for the gamma distribution - * @param scale scale parameter (> 0) for the gamma distribution + * @param scale scale parameter (> 0) for the gamma distribution * @param size Size of the RDD. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). * @param seed Random seed (default: a random long integer). @@ -293,7 +293,7 @@ object RandomRDDs { * * @param sc SparkContext used to create the RDD. * @param mean mean for the log normal distribution - * @param std standard deviation for the log normal distribution + * @param std standard deviation for the log normal distribution * @param size Size of the RDD. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). * @param seed Random seed (default: a random long integer). @@ -671,7 +671,7 @@ object RandomRDDs { * * @param sc SparkContext used to create the RDD. * @param shape shape parameter (> 0) for the gamma distribution. - * @param scale scale parameter (> 0) for the gamma distribution. + * @param scale scale parameter (> 0) for the gamma distribution. * @param numRows Number of Vectors in the RDD. * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index dddefe1944e9d..93290e6508529 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -175,7 +175,7 @@ class ALS private ( /** * :: DeveloperApi :: * Sets storage level for final RDDs (user/product used in MatrixFactorizationModel). The default - * value is `MEMORY_AND_DISK`. Users can change it to a serialized storage, e.g. + * value is `MEMORY_AND_DISK`. Users can change it to a serialized storage, e.g. * `MEMORY_AND_DISK_SER` and set `spark.rdd.compress` to `true` to reduce the space requirement, * at the cost of speed. */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala index 96e50faca2b19..f3b46c75c05f3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala @@ -170,15 +170,15 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] { case class Data(boundary: Double, prediction: Double) def save( - sc: SparkContext, - path: String, - boundaries: Array[Double], - predictions: Array[Double], + sc: SparkContext, + path: String, + boundaries: Array[Double], + predictions: Array[Double], isotonic: Boolean): Unit = { val sqlContext = new SQLContext(sc) val metadata = compact(render( - ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("isotonic" -> isotonic))) sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path)) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala index cd6add9d60b0d..cf51b24ff777f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala @@ -29,102 +29,102 @@ import org.apache.spark.mllib.util.MLUtils * the event that the covariance matrix is singular, the density will be computed in a * reduced dimensional subspace under which the distribution is supported. * (see [[http://en.wikipedia.org/wiki/Multivariate_normal_distribution#Degenerate_case]]) - * + * * @param mu The mean vector of the distribution * @param sigma The covariance matrix of the distribution */ @DeveloperApi class MultivariateGaussian ( - val mu: Vector, + val mu: Vector, val sigma: Matrix) extends Serializable { require(sigma.numCols == sigma.numRows, "Covariance matrix must be square") require(mu.size == sigma.numCols, "Mean vector length must match covariance matrix size") - + private val breezeMu = mu.toBreeze.toDenseVector - + /** * private[mllib] constructor - * + * * @param mu The mean vector of the distribution * @param sigma The covariance matrix of the distribution */ private[mllib] def this(mu: DBV[Double], sigma: DBM[Double]) = { this(Vectors.fromBreeze(mu), Matrices.fromBreeze(sigma)) } - + /** * Compute distribution dependent constants: * rootSigmaInv = D^(-1/2)^ * U, where sigma = U * D * U.t - * u = log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^) + * u = log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^) */ private val (rootSigmaInv: DBM[Double], u: Double) = calculateCovarianceConstants - + /** Returns density of this multivariate Gaussian at given point, x */ def pdf(x: Vector): Double = { pdf(x.toBreeze) } - + /** Returns the log-density of this multivariate Gaussian at given point, x */ def logpdf(x: Vector): Double = { logpdf(x.toBreeze) } - + /** Returns density of this multivariate Gaussian at given point, x */ private[mllib] def pdf(x: BV[Double]): Double = { math.exp(logpdf(x)) } - + /** Returns the log-density of this multivariate Gaussian at given point, x */ private[mllib] def logpdf(x: BV[Double]): Double = { val delta = x - breezeMu val v = rootSigmaInv * delta u + v.t * v * -0.5 } - + /** * Calculate distribution dependent components used for the density function: * pdf(x) = (2*pi)^(-k/2)^ * det(sigma)^(-1/2)^ * exp((-1/2) * (x-mu).t * inv(sigma) * (x-mu)) * where k is length of the mean vector. - * - * We here compute distribution-fixed parts + * + * We here compute distribution-fixed parts * log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^) * and * D^(-1/2)^ * U, where sigma = U * D * U.t - * + * * Both the determinant and the inverse can be computed from the singular value decomposition * of sigma. Noting that covariance matrices are always symmetric and positive semi-definite, * we can use the eigendecomposition. We also do not compute the inverse directly; noting - * that - * + * that + * * sigma = U * D * U.t - * inv(Sigma) = U * inv(D) * U.t + * inv(Sigma) = U * inv(D) * U.t * = (D^{-1/2}^ * U).t * (D^{-1/2}^ * U) - * + * * and thus - * + * * -0.5 * (x-mu).t * inv(Sigma) * (x-mu) = -0.5 * norm(D^{-1/2}^ * U * (x-mu))^2^ - * - * To guard against singular covariance matrices, this method computes both the + * + * To guard against singular covariance matrices, this method computes both the * pseudo-determinant and the pseudo-inverse (Moore-Penrose). Singular values are considered * to be non-zero only if they exceed a tolerance based on machine precision, matrix size, and * relation to the maximum singular value (same tolerance used by, e.g., Octave). */ private def calculateCovarianceConstants: (DBM[Double], Double) = { val eigSym.EigSym(d, u) = eigSym(sigma.toBreeze.toDenseMatrix) // sigma = u * diag(d) * u.t - + // For numerical stability, values are considered to be non-zero only if they exceed tol. // This prevents any inverted value from exceeding (eps * n * max(d))^-1 val tol = MLUtils.EPSILON * max(d) * d.length - + try { // log(pseudo-determinant) is sum of the logs of all non-zero singular values val logPseudoDetSigma = d.activeValuesIterator.filter(_ > tol).map(math.log).sum - - // calculate the root-pseudo-inverse of the diagonal matrix of singular values + + // calculate the root-pseudo-inverse of the diagonal matrix of singular values // by inverting the square root of all non-zero values val pinvS = diag(new DBV(d.map(v => if (v > tol) math.sqrt(1.0 / v) else 0.0).toArray)) - + (pinvS * u, -0.5 * (mu.size * math.log(2.0 * math.Pi) + logPseudoDetSigma)) } catch { case uex: UnsupportedOperationException => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index e3ddc7053693c..a835f96d5d0e3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -270,7 +270,7 @@ object GradientBoostedTrees extends Logging { logInfo(s"$timer") if (persistedInput) input.unpersist() - + if (validate) { new GradientBoostedTreesModel( boostingStrategy.treeStrategy.algo, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index 99d0e3cf2fd6d..069959976a188 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -474,7 +474,7 @@ object RandomForest extends Serializable with Logging { val (treeIndex, node) = nodeQueue.head // Choose subset of features for node (if subsampling). val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) { - Some(SamplingUtils.reservoirSampleAndCount(Range(0, + Some(SamplingUtils.reservoirSampleAndCount(Range(0, metadata.numFeatures).iterator, metadata.numFeaturesPerNode, rng.nextLong)._1) } else { None diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index ee710fc1ed299..a6d1398fc267b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -83,7 +83,7 @@ class Node ( def predict(features: Vector) : Double = { if (isLeaf) { predict.predict - } else{ + } else { if (split.get.featureType == Continuous) { if (features(split.get.feature) <= split.get.threshold) { leftNode.get.predict(features) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 8676a8983cf2a..67d7d4d79f08b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -276,7 +276,7 @@ object MLUtils { } Vectors.fromBreeze(vector1) } - + /** * Returns the squared Euclidean distance between two vectors. The following formula will be used * if it does not introduce too much numerical error: diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java new file mode 100644 index 0000000000000..35b18c5308f61 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature; + +import java.util.Arrays; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import static org.apache.spark.sql.types.DataTypes.*; + +public class JavaStringIndexerSuite { + private transient JavaSparkContext jsc; + private transient SQLContext sqlContext; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaStringIndexerSuite"); + sqlContext = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + sqlContext = null; + } + + @Test + public void testStringIndexer() { + StructType schema = createStructType(new StructField[] { + createStructField("id", IntegerType, false), + createStructField("label", StringType, false) + }); + JavaRDD rdd = jsc.parallelize( + Arrays.asList(c(0, "a"), c(1, "b"), c(2, "c"), c(3, "a"), c(4, "a"), c(5, "c"))); + DataFrame dataset = sqlContext.createDataFrame(rdd, schema); + + StringIndexer indexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex"); + DataFrame output = indexer.fit(dataset).transform(dataset); + + Assert.assertArrayEquals( + new Row[] { c(0, 0.0), c(1, 2.0), c(2, 1.0), c(3, 0.0), c(4, 0.0), c(5, 1.0) }, + output.orderBy("id").select("id", "labelIndex").collect()); + } + + /** An alias for RowFactory.create. */ + private Row c(Object... values) { + return RowFactory.create(values); + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java new file mode 100644 index 0000000000000..b7c564caad3bd --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature; + +import java.util.Arrays; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.*; +import static org.apache.spark.sql.types.DataTypes.*; + +public class JavaVectorAssemblerSuite { + private transient JavaSparkContext jsc; + private transient SQLContext sqlContext; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaVectorAssemblerSuite"); + sqlContext = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void testVectorAssembler() { + StructType schema = createStructType(new StructField[] { + createStructField("id", IntegerType, false), + createStructField("x", DoubleType, false), + createStructField("y", new VectorUDT(), false), + createStructField("name", StringType, false), + createStructField("z", new VectorUDT(), false), + createStructField("n", LongType, false) + }); + Row row = RowFactory.create( + 0, 0.0, Vectors.dense(1.0, 2.0), "a", + Vectors.sparse(2, new int[] {1}, new double[] {3.0}), 10L); + JavaRDD rdd = jsc.parallelize(Arrays.asList(row)); + DataFrame dataset = sqlContext.createDataFrame(rdd, schema); + VectorAssembler assembler = new VectorAssembler() + .setInputCols(new String[] {"x", "y", "z", "n"}) + .setOutputCol("features"); + DataFrame output = assembler.transform(dataset); + Assert.assertEquals( + Vectors.sparse(6, new int[] {1, 2, 4, 5}, new double[] {1.0, 2.0, 3.0, 10.0}), + output.select("features").first().getAs(0)); + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala b/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala index 67c262d0f9d8d..928301523fba9 100644 --- a/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala +++ b/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.ml.util -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class IdentifiableSuite extends FunSuite { +class IdentifiableSuite extends SparkFunSuite { import IdentifiableSuite.Test diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 2b04a3034782e..05bf58e63abaf 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -19,13 +19,13 @@ package org.apache.spark.ml import org.mockito.Matchers.{any, eq => meq} import org.mockito.Mockito.when -import org.scalatest.FunSuite import org.scalatest.mock.MockitoSugar.mock +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamMap import org.apache.spark.sql.DataFrame -class PipelineSuite extends FunSuite { +class PipelineSuite extends SparkFunSuite { abstract class MyModel extends Model[MyModel] diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala index 17ddd335deb6d..512cffb1acb66 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.ml.attribute -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class AttributeGroupSuite extends FunSuite { +class AttributeGroupSuite extends SparkFunSuite { test("attribute group") { val attrs = Array( diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala index ec9b717e41ce8..72b575d022547 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala @@ -17,11 +17,10 @@ package org.apache.spark.ml.attribute -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types._ -class AttributeSuite extends FunSuite { +class AttributeSuite extends SparkFunSuite { test("default numeric attribute") { val attr: NumericAttribute = NumericAttribute.defaultAttr diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 3fdc66be8a314..ae40b0b8ff854 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.ml.classification -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint @@ -29,7 +28,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame -class DecisionTreeClassifierSuite extends FunSuite with MLlibTestSparkContext { +class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { import DecisionTreeClassifierSuite.compareAPIs @@ -251,7 +250,7 @@ class DecisionTreeClassifierSuite extends FunSuite with MLlibTestSparkContext { */ } -private[ml] object DecisionTreeClassifierSuite extends FunSuite { +private[ml] object DecisionTreeClassifierSuite extends SparkFunSuite { /** * Train 2 decision trees on the given dataset, one using the old API and one using the new API. @@ -266,7 +265,7 @@ private[ml] object DecisionTreeClassifierSuite extends FunSuite { val oldTree = OldDecisionTree.train(data, oldStrategy) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses) val newTree = dt.fit(newData) - // Use parent, fittingParamMap from newTree since these are not checked anyways. + // Use parent from newTree since this is not checked anyways. val oldTreeAsNew = DecisionTreeClassificationModel.fromOld( oldTree, newTree.parent.asInstanceOf[DecisionTreeClassifier], categoricalFeatures) TreeTests.checkEqual(oldTreeAsNew, newTree) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index ea86867f1161a..1302da3c373ff 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.ml.classification -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} @@ -31,7 +30,7 @@ import org.apache.spark.sql.DataFrame /** * Test suite for [[GBTClassifier]]. */ -class GBTClassifierSuite extends FunSuite with MLlibTestSparkContext { +class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { import GBTClassifierSuite.compareAPIs @@ -128,7 +127,7 @@ private object GBTClassifierSuite { val oldModel = oldGBT.run(data) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 2) val newModel = gbt.fit(newData) - // Use parent, fittingParamMap from newTree since these are not checked anyways. + // Use parent from newTree since this is not checked anyways. val oldModelAsNew = GBTClassificationModel.fromOld( oldModel, newModel.parent.asInstanceOf[GBTClassifier], categoricalFeatures) TreeTests.checkEqual(oldModelAsNew, newModel) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 9f77d5f3efc55..a755cac3ea76e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.ml.classification -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.classification.LogisticRegressionSuite._ import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} -class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { +class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { @transient var dataset: DataFrame = _ @transient var binaryDataset: DataFrame = _ diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index 770b56890fa45..1d04ccb509057 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.ml.classification -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.util.MetadataUtils import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS @@ -30,7 +29,7 @@ import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame -class OneVsRestSuite extends FunSuite with MLlibTestSparkContext { +class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext { @transient var dataset: DataFrame = _ @transient var rdd: RDD[LabeledPoint] = _ @@ -94,6 +93,15 @@ class OneVsRestSuite extends FunSuite with MLlibTestSparkContext { val datasetWithLabelMetadata = dataset.select(labelWithMetadata, features) ova.fit(datasetWithLabelMetadata) } + + test("SPARK-8049: OneVsRest shouldn't output temp columns") { + val logReg = new LogisticRegression() + .setMaxIter(1) + val ovr = new OneVsRest() + .setClassifier(logReg) + val output = ovr.fit(dataset).transform(dataset) + assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction")) + } } private class MockLogisticRegression(uid: String) extends LogisticRegression(uid) { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index cdbbacab8e0e3..eee9355a67be3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.ml.classification -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint @@ -32,7 +31,7 @@ import org.apache.spark.sql.DataFrame /** * Test suite for [[RandomForestClassifier]]. */ -class RandomForestClassifierSuite extends FunSuite with MLlibTestSparkContext { +class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { import RandomForestClassifierSuite.compareAPIs @@ -158,7 +157,7 @@ private object RandomForestClassifierSuite { data, oldStrategy, rf.getNumTrees, rf.getFeatureSubsetStrategy, rf.getSeed.toInt) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses) val newModel = rf.fit(newData) - // Use parent, fittingParamMap from newTree since these are not checked anyways. + // Use parent from newTree since this is not checked anyways. val oldModelAsNew = RandomForestClassificationModel.fromOld( oldModel, newModel.parent.asInstanceOf[RandomForestClassifier], categoricalFeatures) TreeTests.checkEqual(oldModelAsNew, newModel) diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala index 3ea7aad5274f2..36a1ac6b7996d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.ml.evaluation -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.regression.LinearRegression import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ -class RegressionEvaluatorSuite extends FunSuite with MLlibTestSparkContext { +class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext { test("Regression Evaluator: default params") { /** @@ -39,7 +38,7 @@ class RegressionEvaluatorSuite extends FunSuite with MLlibTestSparkContext { val dataset = sqlContext.createDataFrame( sc.parallelize(LinearDataGenerator.generateLinearInput( 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2)) - + /** * Using the following R code to load the data, train the model and evaluate metrics. * diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala index 8f6c6b39dc93b..7953bd0417191 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.ml.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} -class BinarizerSuite extends FunSuite with MLlibTestSparkContext { +class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext { @transient var data: Array[Double] = _ @@ -48,7 +47,7 @@ class BinarizerSuite extends FunSuite with MLlibTestSparkContext { test("Binarize continuous features with setter") { val threshold: Double = 0.2 - val thresholdBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0) + val thresholdBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0) val dataFrame: DataFrame = sqlContext.createDataFrame( data.zip(thresholdBinarized)).toDF("feature", "expected") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index 0391bd8427c2c..507a8a7db24c7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -19,15 +19,13 @@ package org.apache.spark.ml.feature import scala.util.Random -import org.scalatest.FunSuite - -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} -class BucketizerSuite extends FunSuite with MLlibTestSparkContext { +class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext { test("Bucket continuous features, without -inf,inf") { // Check a set of valid feature values. @@ -110,7 +108,7 @@ class BucketizerSuite extends FunSuite with MLlibTestSparkContext { } } -private object BucketizerSuite extends FunSuite { +private object BucketizerSuite extends SparkFunSuite { /** Brute force search for buckets. Bucket i is defined by the range [split(i), split(i+1)). */ def linearSearchForBuckets(splits: Array[Double], feature: Double): Double = { require(feature >= splits.head) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala index 2e4beb0bfff63..7b2d70e644005 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.ml.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.mllib.linalg.{Vector, Vectors} @@ -26,7 +25,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.util.Utils -class HashingTFSuite extends FunSuite with MLlibTestSparkContext { +class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext { test("params") { val hashingTF = new HashingTF diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala index f85e85471617a..d83772e8be755 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.ml.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.Row -class IDFSuite extends FunSuite with MLlibTestSparkContext { +class IDFSuite extends SparkFunSuite with MLlibTestSparkContext { def scaleDataWithIDF(dataSet: Array[Vector], model: Vector): Array[Vector] = { dataSet.map { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala index 9d09f24709e23..9f03470b7f328 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.ml.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row, SQLContext} -class NormalizerSuite extends FunSuite with MLlibTestSparkContext { +class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext { @transient var data: Array[Vector] = _ @transient var dataFrame: DataFrame = _ diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala index 056b9eda86bba..2e5036a844562 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala @@ -17,14 +17,14 @@ package org.apache.spark.ml.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions.col - -class OneHotEncoderSuite extends FunSuite with MLlibTestSparkContext { +class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext { def stringIndexed(): DataFrame = { val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2) @@ -36,15 +36,16 @@ class OneHotEncoderSuite extends FunSuite with MLlibTestSparkContext { indexer.transform(df) } - test("OneHotEncoder includeFirst = true") { + test("OneHotEncoder dropLast = false") { val transformed = stringIndexed() val encoder = new OneHotEncoder() .setInputCol("labelIndex") .setOutputCol("labelVec") + .setDropLast(false) val encoded = encoder.transform(transformed) val output = encoded.select("id", "labelVec").map { r => - val vec = r.get(1).asInstanceOf[Vector] + val vec = r.getAs[Vector](1) (r.getInt(0), vec(0), vec(1), vec(2)) }.collect().toSet // a -> 0, b -> 2, c -> 1 @@ -53,22 +54,46 @@ class OneHotEncoderSuite extends FunSuite with MLlibTestSparkContext { assert(output === expected) } - test("OneHotEncoder includeFirst = false") { + test("OneHotEncoder dropLast = true") { val transformed = stringIndexed() val encoder = new OneHotEncoder() - .setIncludeFirst(false) .setInputCol("labelIndex") .setOutputCol("labelVec") val encoded = encoder.transform(transformed) val output = encoded.select("id", "labelVec").map { r => - val vec = r.get(1).asInstanceOf[Vector] + val vec = r.getAs[Vector](1) (r.getInt(0), vec(0), vec(1)) }.collect().toSet // a -> 0, b -> 2, c -> 1 - val expected = Set((0, 0.0, 0.0), (1, 0.0, 1.0), (2, 1.0, 0.0), - (3, 0.0, 0.0), (4, 0.0, 0.0), (5, 1.0, 0.0)) + val expected = Set((0, 1.0, 0.0), (1, 0.0, 0.0), (2, 0.0, 1.0), + (3, 1.0, 0.0), (4, 1.0, 0.0), (5, 0.0, 1.0)) assert(output === expected) } + test("input column with ML attribute") { + val attr = NominalAttribute.defaultAttr.withValues("small", "medium", "large") + val df = sqlContext.createDataFrame(Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply)).toDF("size") + .select(col("size").as("size", attr.toMetadata())) + val encoder = new OneHotEncoder() + .setInputCol("size") + .setOutputCol("encoded") + val output = encoder.transform(df) + val group = AttributeGroup.fromStructField(output.schema("encoded")) + assert(group.size === 2) + assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("size_is_small").withIndex(0)) + assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("size_is_medium").withIndex(1)) + } + + test("input column without ML attribute") { + val df = sqlContext.createDataFrame(Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply)).toDF("index") + val encoder = new OneHotEncoder() + .setInputCol("index") + .setOutputCol("encoded") + val output = encoder.transform(df) + val group = AttributeGroup.fromStructField(output.schema("encoded")) + assert(group.size === 2) + assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("index_is_0").withIndex(0)) + assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("index_is_1").withIndex(1)) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala index aa230ca073d5b..feca866cd711d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala @@ -17,15 +17,15 @@ package org.apache.spark.ml.feature -import org.scalatest.FunSuite import org.scalatest.exceptions.TestFailedException +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.Row -class PolynomialExpansionSuite extends FunSuite with MLlibTestSparkContext { +class PolynomialExpansionSuite extends SparkFunSuite with MLlibTestSparkContext { test("Polynomial expansion with default parameter") { val data = Array( diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 89c2fe45573aa..cbf1e8ddcb48a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.ml.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.mllib.util.MLlibTestSparkContext -class StringIndexerSuite extends FunSuite with MLlibTestSparkContext { +class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { test("StringIndexer") { val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala index eabda089d0988..ac279cb3215c2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala @@ -19,15 +19,14 @@ package org.apache.spark.ml.feature import scala.beans.BeanInfo -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} @BeanInfo case class TokenizerTestData(rawText: String, wantedTokens: Array[String]) -class RegexTokenizerSuite extends FunSuite with MLlibTestSparkContext { +class RegexTokenizerSuite extends SparkFunSuite with MLlibTestSparkContext { import org.apache.spark.ml.feature.RegexTokenizerSuite._ test("RegexTokenizer") { @@ -60,7 +59,7 @@ class RegexTokenizerSuite extends FunSuite with MLlibTestSparkContext { } } -object RegexTokenizerSuite extends FunSuite { +object RegexTokenizerSuite extends SparkFunSuite { def testRegexTokenizer(t: RegexTokenizer, dataset: DataFrame): Unit = { t.transform(dataset) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala index 43534e89928b1..489abb5af7130 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -17,16 +17,14 @@ package org.apache.spark.ml.feature -import org.scalatest.FunSuite - -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute} import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row import org.apache.spark.sql.functions.col -class VectorAssemblerSuite extends FunSuite with MLlibTestSparkContext { +class VectorAssemblerSuite extends SparkFunSuite with MLlibTestSparkContext { test("assemble") { import org.apache.spark.ml.feature.VectorAssembler.assemble diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala index b11b029c6343e..06affc7305cf5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala @@ -19,16 +19,14 @@ package org.apache.spark.ml.feature import scala.beans.{BeanInfo, BeanProperty} -import org.scalatest.FunSuite - -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute._ import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame -class VectorIndexerSuite extends FunSuite with MLlibTestSparkContext { +class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { import VectorIndexerSuite.FeatureData diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index df446d0c22015..94ebc3aebfa37 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.ml.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{Row, SQLContext} -class Word2VecSuite extends FunSuite with MLlibTestSparkContext { +class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { test("Word2Vec") { val sqlContext = new SQLContext(sc) diff --git a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala index 1505ad872536b..778abcba22c10 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala @@ -19,8 +19,7 @@ package org.apache.spark.ml.impl import scala.collection.JavaConverters._ -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.api.java.JavaRDD import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute} import org.apache.spark.ml.tree._ @@ -29,7 +28,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{SQLContext, DataFrame} -private[ml] object TreeTests extends FunSuite { +private[ml] object TreeTests extends SparkFunSuite { /** * Convert the given data to a DataFrame, and set the features and label metadata. diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index 04f2af4727ea4..f80e7749098a5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.ml.param -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class ParamsSuite extends FunSuite { +class ParamsSuite extends SparkFunSuite { test("param") { val solver = new TestParams() @@ -202,7 +202,7 @@ class ParamsSuite extends FunSuite { } } -object ParamsSuite extends FunSuite { +object ParamsSuite extends SparkFunSuite { /** * Checks common requirements for [[Params.params]]: 1) number of params; 2) params are ordered diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala index ca18fa1ad3c15..eb5408d3fee7c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala @@ -17,11 +17,10 @@ package org.apache.spark.ml.param.shared -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.Params -class SharedParamsSuite extends FunSuite { +class SharedParamsSuite extends SparkFunSuite { test("outputCol") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index 9a35555e52b90..2e5cfe7027eb6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -25,9 +25,8 @@ import scala.collection.mutable.ArrayBuffer import scala.language.existentials import com.github.fommil.netlib.BLAS.{getInstance => blas} -import org.scalatest.FunSuite -import org.apache.spark.{Logging, SparkException} +import org.apache.spark.{Logging, SparkException, SparkFunSuite} import org.apache.spark.ml.recommendation.ALS._ import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -36,7 +35,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.util.Utils -class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { +class ALSSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { private var tempDir: File = _ diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index 1196a772dfdd4..33aa9d0d62343 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.ml.regression -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, @@ -28,7 +27,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame -class DecisionTreeRegressorSuite extends FunSuite with MLlibTestSparkContext { +class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { import DecisionTreeRegressorSuite.compareAPIs @@ -69,7 +68,7 @@ class DecisionTreeRegressorSuite extends FunSuite with MLlibTestSparkContext { // TODO: test("model save/load") SPARK-6725 } -private[ml] object DecisionTreeRegressorSuite extends FunSuite { +private[ml] object DecisionTreeRegressorSuite extends SparkFunSuite { /** * Train 2 decision trees on the given dataset, one using the old API and one using the new API. @@ -83,7 +82,7 @@ private[ml] object DecisionTreeRegressorSuite extends FunSuite { val oldTree = OldDecisionTree.train(data, oldStrategy) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0) val newTree = dt.fit(newData) - // Use parent, fittingParamMap from newTree since these are not checked anyways. + // Use parent from newTree since this is not checked anyways. val oldTreeAsNew = DecisionTreeRegressionModel.fromOld( oldTree, newTree.parent.asInstanceOf[DecisionTreeRegressor], categoricalFeatures) TreeTests.checkEqual(oldTreeAsNew, newTree) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index 40e7e3273e965..98fb3d3f5f22c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.ml.regression -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} @@ -31,7 +30,7 @@ import org.apache.spark.sql.DataFrame /** * Test suite for [[GBTRegressor]]. */ -class GBTRegressorSuite extends FunSuite with MLlibTestSparkContext { +class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { import GBTRegressorSuite.compareAPIs @@ -129,7 +128,7 @@ private object GBTRegressorSuite { val oldModel = oldGBT.run(data) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0) val newModel = gbt.fit(newData) - // Use parent, fittingParamMap from newTree since these are not checked anyways. + // Use parent from newTree since this is not checked anyways. val oldModelAsNew = GBTRegressionModel.fromOld( oldModel, newModel.parent.asInstanceOf[GBTRegressor], categoricalFeatures) TreeTests.checkEqual(oldModelAsNew, newModel) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 50a78631fa6d6..732e2c42be144 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.ml.regression -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.DenseVector import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} -class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext { +class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { @transient var dataset: DataFrame = _ diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala index 3efffbb763b78..b24ecaa57c89b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.ml.regression -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} @@ -31,7 +30,7 @@ import org.apache.spark.sql.DataFrame /** * Test suite for [[RandomForestRegressor]]. */ -class RandomForestRegressorSuite extends FunSuite with MLlibTestSparkContext { +class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { import RandomForestRegressorSuite.compareAPIs @@ -98,7 +97,7 @@ class RandomForestRegressorSuite extends FunSuite with MLlibTestSparkContext { */ } -private object RandomForestRegressorSuite extends FunSuite { +private object RandomForestRegressorSuite extends SparkFunSuite { /** * Train 2 models on the given dataset, one using the old API and one using the new API. @@ -114,7 +113,7 @@ private object RandomForestRegressorSuite extends FunSuite { data, oldStrategy, rf.getNumTrees, rf.getFeatureSubsetStrategy, rf.getSeed.toInt) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0) val newModel = rf.fit(newData) - // Use parent, fittingParamMap from newTree since these are not checked anyways. + // Use parent from newTree since this is not checked anyways. val oldModelAsNew = RandomForestRegressionModel.fromOld( oldModel, newModel.parent.asInstanceOf[RandomForestRegressor], categoricalFeatures) TreeTests.checkEqual(oldModelAsNew, newModel) diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 60d8bfe38fb13..5ba469c7b10a0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.tuning -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.classification.LogisticRegression @@ -29,7 +29,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.types.StructType -class CrossValidatorSuite extends FunSuite with MLlibTestSparkContext { +class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext { @transient var dataset: DataFrame = _ diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala index 20aa100112bfe..810b70049ec15 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala @@ -19,11 +19,10 @@ package org.apache.spark.ml.tuning import scala.collection.mutable -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.{ParamMap, TestParams} -class ParamGridBuilderSuite extends FunSuite { +class ParamGridBuilderSuite extends SparkFunSuite { val solver = new TestParams() import solver.{inputCol, maxIter} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala index 3d362b5ee53ea..59944416d96a6 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.mllib.api.python -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Vectors, SparseMatrix} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.recommendation.Rating -class PythonMLLibAPISuite extends FunSuite { +class PythonMLLibAPISuite extends SparkFunSuite { SerDe.initialize() diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index 966811a5a3263..e8f3d0c4db20a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -21,9 +21,9 @@ import scala.collection.JavaConversions._ import scala.util.Random import scala.util.control.Breaks._ -import org.scalatest.FunSuite import org.scalatest.Matchers +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} @@ -119,7 +119,7 @@ object LogisticRegressionSuite { } // Preventing the overflow when we compute the probability val maxMargin = margins.max - if (maxMargin > 0) for (i <-0 until nClasses) margins(i) -= maxMargin + if (maxMargin > 0) for (i <- 0 until nClasses) margins(i) -= maxMargin // Computing the probabilities for each class from the margins. val norm = { @@ -130,7 +130,7 @@ object LogisticRegressionSuite { } temp } - for (i <-0 until nClasses) probs(i) /= norm + for (i <- 0 until nClasses) probs(i) /= norm // Compute the cumulative probability so we can generate a random number and assign a label. for (i <- 1 until nClasses) probs(i) += probs(i - 1) @@ -169,7 +169,7 @@ object LogisticRegressionSuite { } -class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with Matchers { +class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers { def validatePrediction( predictions: Seq[Double], input: Seq[LabeledPoint], @@ -541,7 +541,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M } -class LogisticRegressionClusterSuite extends FunSuite with LocalClusterSparkContext { +class LogisticRegressionClusterSuite extends SparkFunSuite with LocalClusterSparkContext { test("task size should be small in both training and prediction using SGD optimizer") { val m = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index ea40b41bbbe5e..f7fc8730606af 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -21,9 +21,8 @@ import scala.util.Random import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum} import breeze.stats.distributions.{Multinomial => BrzMultinomial} -import org.scalatest.FunSuite -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} @@ -86,7 +85,7 @@ object NaiveBayesSuite { pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), Multinomial) } -class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { +class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { import NaiveBayes.{Multinomial, Bernoulli} @@ -286,7 +285,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { } } -class NaiveBayesClusterSuite extends FunSuite with LocalClusterSparkContext { +class NaiveBayesClusterSuite extends SparkFunSuite with LocalClusterSparkContext { test("task size should be small in both training and prediction") { val m = 10 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala index 90f9cec6855bf..b1d78cba9e3dc 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala @@ -21,9 +21,8 @@ import scala.collection.JavaConversions._ import scala.util.Random import org.jblas.DoubleMatrix -import org.scalatest.FunSuite -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} @@ -62,7 +61,7 @@ object SVMSuite { } -class SVMSuite extends FunSuite with MLlibTestSparkContext { +class SVMSuite extends SparkFunSuite with MLlibTestSparkContext { def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => @@ -229,7 +228,7 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext { } } -class SVMClusterSuite extends FunSuite with LocalClusterSparkContext { +class SVMClusterSuite extends SparkFunSuite with LocalClusterSparkContext { test("task size should be small in both training and prediction") { val m = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala index 5683b55e8500a..e98b61e13e21f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala @@ -19,15 +19,14 @@ package org.apache.spark.mllib.classification import scala.collection.mutable.ArrayBuffer -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.TestSuiteBase -class StreamingLogisticRegressionSuite extends FunSuite with TestSuiteBase { +class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase { // use longer wait time to ensure job completion override def maxWaitTimeMillis: Int = 30000 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala index f356ffa3e3a26..b218d72f1268a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.mllib.clustering -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{Vectors, Matrices} import org.apache.spark.mllib.stat.distribution.MultivariateGaussian import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.util.Utils -class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext { +class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext { test("single cluster") { val data = sc.parallelize(Array( Vectors.dense(6.0, 9.0), @@ -47,7 +46,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext { } } - + test("two clusters") { val data = sc.parallelize(GaussianTestData.data) @@ -63,7 +62,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext { val Ew = Array(1.0 / 3.0, 2.0 / 3.0) val Emu = Array(Vectors.dense(-4.3673), Vectors.dense(5.1604)) val Esigma = Array(Matrices.dense(1, 1, Array(1.1098)), Matrices.dense(1, 1, Array(0.86644))) - + val gmm = new GaussianMixture() .setK(2) .setInitialModel(initialGmm) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala index 877e6dc699523..0dbbd7127444f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala @@ -19,14 +19,13 @@ package org.apache.spark.mllib.clustering import scala.util.Random -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.util.Utils -class KMeansSuite extends FunSuite with MLlibTestSparkContext { +class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { import org.apache.spark.mllib.clustering.KMeans.{K_MEANS_PARALLEL, RANDOM} @@ -281,7 +280,7 @@ class KMeansSuite extends FunSuite with MLlibTestSparkContext { } } -object KMeansSuite extends FunSuite { +object KMeansSuite extends SparkFunSuite { def createModel(dim: Int, k: Int, isSparse: Boolean): KMeansModel = { val singlePoint = isSparse match { case true => @@ -305,7 +304,7 @@ object KMeansSuite extends FunSuite { } } -class KMeansClusterSuite extends FunSuite with LocalClusterSparkContext { +class KMeansClusterSuite extends SparkFunSuite with LocalClusterSparkContext { test("task size should be small in both training and prediction") { val m = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index d5b7d96335744..406affa25539d 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -19,13 +19,12 @@ package org.apache.spark.mllib.clustering import breeze.linalg.{DenseMatrix => BDM} -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{Vector, DenseMatrix, Matrix, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class LDASuite extends FunSuite with MLlibTestSparkContext { +class LDASuite extends SparkFunSuite with MLlibTestSparkContext { import LDASuite._ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala index 556842f3129a3..19e65f1b53ab5 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala @@ -20,15 +20,13 @@ package org.apache.spark.mllib.clustering import scala.collection.mutable import scala.util.Random -import org.scalatest.FunSuite - -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.graphx.{Edge, Graph} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.util.Utils -class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext { +class PowerIterationClusteringSuite extends SparkFunSuite with MLlibTestSparkContext { import org.apache.spark.mllib.clustering.PowerIterationClustering._ @@ -58,7 +56,7 @@ class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext predictions(a.cluster) += a.id } assert(predictions.toSet == Set((0 to 3).toSet, (4 to 15).toSet)) - + val model2 = new PowerIterationClustering() .setK(2) .setInitializationMode("degree") @@ -130,7 +128,7 @@ class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext } } -object PowerIterationClusteringSuite extends FunSuite { +object PowerIterationClusteringSuite extends SparkFunSuite { def createModel(sc: SparkContext, k: Int, nPoints: Int): PowerIterationClusteringModel = { val assignments = sc.parallelize( (0 until nPoints).map(p => PowerIterationClustering.Assignment(p, Random.nextInt(k)))) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala index f90025d535e45..ac01622b8a089 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.mllib.clustering -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.streaming.TestSuiteBase import org.apache.spark.streaming.dstream.DStream import org.apache.spark.util.random.XORShiftRandom -class StreamingKMeansSuite extends FunSuite with TestSuiteBase { +class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase { override def maxWaitTimeMillis: Int = 30000 @@ -133,6 +132,13 @@ class StreamingKMeansSuite extends FunSuite with TestSuiteBase { assert(math.abs(c1) ~== 0.8 absTol 0.6) } + test("SPARK-7946 setDecayFactor") { + val kMeans = new StreamingKMeans() + assert(kMeans.decayFactor === 1.0) + kMeans.setDecayFactor(2.0) + assert(kMeans.decayFactor === 2.0) + } + def StreamingKMeansDataGenerator( numPoints: Int, numBatches: Int, diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala index 79847633ff0dc..87ccc7eda44ea 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.evaluation -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class AreaUnderCurveSuite extends FunSuite with MLlibTestSparkContext { +class AreaUnderCurveSuite extends SparkFunSuite with MLlibTestSparkContext { test("auc computation") { val curve = Seq((0.0, 0.0), (1.0, 1.0), (2.0, 3.0), (3.0, 0.0)) val auc = 4.0 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala index e0224f960cc43..99d52fabc5309 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.evaluation -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class BinaryClassificationMetricsSuite extends FunSuite with MLlibTestSparkContext { +class BinaryClassificationMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { private def areWithinEpsilon(x: (Double, Double)): Boolean = x._1 ~= (x._2) absTol 1E-5 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala index 7dc4f3cfbc4e4..d55bc8c3ec09f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.evaluation -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Matrices import org.apache.spark.mllib.util.MLlibTestSparkContext -class MulticlassMetricsSuite extends FunSuite with MLlibTestSparkContext { +class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { test("Multiclass evaluation metrics") { /* * Confusion matrix for 3-class classification with total 9 instances: diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala index 2537dd62c92f2..f3b19aeb42f84 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.evaluation -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD -class MultilabelMetricsSuite extends FunSuite with MLlibTestSparkContext { +class MultilabelMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { test("Multilabel evaluation metrics") { /* * Documents true labels (5x class0, 3x class1, 4x class2): diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala index 609eed983ff4e..c0924a213a844 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.evaluation -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext -class RankingMetricsSuite extends FunSuite with MLlibTestSparkContext { +class RankingMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { test("Ranking metrics: map, ndcg") { val predictionAndLabels = sc.parallelize( Seq( diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala index 3aa732474ec2e..9de2bdb6d7246 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.evaluation -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class RegressionMetricsSuite extends FunSuite with MLlibTestSparkContext { +class RegressionMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { test("regression metrics") { val predictionAndObservations = sc.parallelize( diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala index 747f5914598ec..889727fb55823 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.mllib.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLlibTestSparkContext -class ChiSqSelectorSuite extends FunSuite with MLlibTestSparkContext { +class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext { /* * Contingency tables diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/ElementwiseProductSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/ElementwiseProductSuite.scala index f3a482abda873..ccbf8a91cdd37 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/ElementwiseProductSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/ElementwiseProductSuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.mllib.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class ElementwiseProductSuite extends FunSuite with MLlibTestSparkContext { +class ElementwiseProductSuite extends SparkFunSuite with MLlibTestSparkContext { test("elementwise (hadamard) product should properly apply vector to dense data set") { val denseData = Array( diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala index 0c4dfb7b97c7f..cf279c02334e9 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext -class HashingTFSuite extends FunSuite with MLlibTestSparkContext { +class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext { test("hashing tf on a single doc") { val hashingTF = new HashingTF(1000) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala index 0a5cad7caf8e4..21163633051e5 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.mllib.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors, Vector} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class IDFSuite extends FunSuite with MLlibTestSparkContext { +class IDFSuite extends SparkFunSuite with MLlibTestSparkContext { test("idf") { val n = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala index 5c4af2b99e68b..34122d6ed2e95 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.mllib.feature -import org.scalatest.FunSuite - import breeze.linalg.{norm => brzNorm} +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class NormalizerSuite extends FunSuite with MLlibTestSparkContext { +class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext { val data = Array( Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala index 758af588f1c69..e57f49191378f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.mllib.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.linalg.distributed.RowMatrix import org.apache.spark.mllib.util.MLlibTestSparkContext -class PCASuite extends FunSuite with MLlibTestSparkContext { +class PCASuite extends SparkFunSuite with MLlibTestSparkContext { private val data = Array( Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))), diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala index 1eb991869de40..6ab2fa6770123 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.mllib.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, MultivariateOnlineSummarizer} import org.apache.spark.rdd.RDD -class StandardScalerSuite extends FunSuite with MLlibTestSparkContext { +class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext { // When the input data is all constant, the variance is zero. The standardization against // zero variance is not well-defined, but we decide to just set it into zero here. diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala index 98a98a7599bcb..b6818369208d7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.mllib.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.util.Utils -class Word2VecSuite extends FunSuite with MLlibTestSparkContext { +class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { // TODO: add more tests diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala index bd5b9cc3afa10..66ae3543ecc4e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala @@ -16,11 +16,10 @@ */ package org.apache.spark.mllib.fpm -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext -class FPGrowthSuite extends FunSuite with MLlibTestSparkContext { +class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext { test("FP-Growth using String type") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPTreeSuite.scala index 04017f67c311d..a56d7b3579213 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPTreeSuite.scala @@ -19,11 +19,10 @@ package org.apache.spark.mllib.fpm import scala.language.existentials -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext -class FPTreeSuite extends FunSuite with MLlibTestSparkContext { +class FPTreeSuite extends SparkFunSuite with MLlibTestSparkContext { test("add transaction") { val tree = new FPTree[String] diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala index 699f009f0f2ec..d34888af2d73b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala @@ -17,18 +17,16 @@ package org.apache.spark.mllib.impl -import org.scalatest.FunSuite - import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.graphx.{Edge, Graph} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils -class PeriodicGraphCheckpointerSuite extends FunSuite with MLlibTestSparkContext { +class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkContext { import PeriodicGraphCheckpointerSuite._ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala index 64ecd12ea7ded..b0f3f71113c57 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.linalg -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.mllib.linalg.BLAS._ -class BLASSuite extends FunSuite { +class BLASSuite extends SparkFunSuite { test("copy") { val sx = Vectors.sparse(4, Array(0, 2), Array(1.0, -2.0)) @@ -140,7 +139,7 @@ class BLASSuite extends FunSuite { syr(alpha, x, dA) assert(dA ~== expected absTol 1e-15) - + val dB = new DenseMatrix(3, 4, Array(0.0, 1.2, 2.2, 3.1, 1.2, 3.2, 5.3, 4.6, 2.2, 5.3, 1.8, 3.0)) @@ -149,7 +148,7 @@ class BLASSuite extends FunSuite { syr(alpha, x, dB) } } - + val dC = new DenseMatrix(3, 3, Array(0.0, 1.2, 2.2, 1.2, 3.2, 5.3, 2.2, 5.3, 1.8)) @@ -158,7 +157,7 @@ class BLASSuite extends FunSuite { syr(alpha, x, dC) } } - + val y = new DenseVector(Array(0.0, 2.7, 3.5, 2.1, 1.5)) withClue("Size of vector must match the rank of matrix") { @@ -256,13 +255,13 @@ class BLASSuite extends FunSuite { val dA = new DenseMatrix(4, 3, Array(0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 3.0)) val sA = new SparseMatrix(4, 3, Array(0, 1, 3, 4), Array(1, 0, 2, 3), Array(1.0, 2.0, 1.0, 3.0)) - + val dA2 = new DenseMatrix(4, 3, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0), true) val sA2 = new SparseMatrix(4, 3, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0), true) - + val dx = new DenseVector(Array(1.0, 2.0, 3.0)) val sx = dx.toSparse val expected = new DenseVector(Array(4.0, 1.0, 2.0, 9.0)) @@ -271,7 +270,7 @@ class BLASSuite extends FunSuite { assert(sA.multiply(dx) ~== expected absTol 1e-15) assert(dA.multiply(sx) ~== expected absTol 1e-15) assert(sA.multiply(sx) ~== expected absTol 1e-15) - + val y1 = new DenseVector(Array(1.0, 3.0, 1.0, 0.0)) val y2 = y1.copy val y3 = y1.copy @@ -288,7 +287,7 @@ class BLASSuite extends FunSuite { val y14 = y1.copy val y15 = y1.copy val y16 = y1.copy - + val expected2 = new DenseVector(Array(6.0, 7.0, 4.0, 9.0)) val expected3 = new DenseVector(Array(10.0, 8.0, 6.0, 18.0)) @@ -296,42 +295,42 @@ class BLASSuite extends FunSuite { gemv(1.0, sA, dx, 2.0, y2) gemv(1.0, dA, sx, 2.0, y3) gemv(1.0, sA, sx, 2.0, y4) - + gemv(1.0, dA2, dx, 2.0, y5) gemv(1.0, sA2, dx, 2.0, y6) gemv(1.0, dA2, sx, 2.0, y7) gemv(1.0, sA2, sx, 2.0, y8) - + gemv(2.0, dA, dx, 2.0, y9) gemv(2.0, sA, dx, 2.0, y10) gemv(2.0, dA, sx, 2.0, y11) gemv(2.0, sA, sx, 2.0, y12) - + gemv(2.0, dA2, dx, 2.0, y13) gemv(2.0, sA2, dx, 2.0, y14) gemv(2.0, dA2, sx, 2.0, y15) gemv(2.0, sA2, sx, 2.0, y16) - + assert(y1 ~== expected2 absTol 1e-15) assert(y2 ~== expected2 absTol 1e-15) assert(y3 ~== expected2 absTol 1e-15) assert(y4 ~== expected2 absTol 1e-15) - + assert(y5 ~== expected2 absTol 1e-15) assert(y6 ~== expected2 absTol 1e-15) assert(y7 ~== expected2 absTol 1e-15) assert(y8 ~== expected2 absTol 1e-15) - + assert(y9 ~== expected3 absTol 1e-15) assert(y10 ~== expected3 absTol 1e-15) assert(y11 ~== expected3 absTol 1e-15) assert(y12 ~== expected3 absTol 1e-15) - + assert(y13 ~== expected3 absTol 1e-15) assert(y14 ~== expected3 absTol 1e-15) assert(y15 ~== expected3 absTol 1e-15) assert(y16 ~== expected3 absTol 1e-15) - + withClue("columns of A don't match the rows of B") { intercept[Exception] { gemv(1.0, dA.transpose, dx, 2.0, y1) @@ -346,12 +345,12 @@ class BLASSuite extends FunSuite { gemv(1.0, sA.transpose, sx, 2.0, y1) } } - + val dAT = new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0)) val sAT = new SparseMatrix(3, 4, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0)) - + val dATT = dAT.transpose val sATT = sAT.transpose diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala index 2031032373971..dc04258e41d27 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala @@ -17,11 +17,11 @@ package org.apache.spark.mllib.linalg -import org.scalatest.FunSuite - import breeze.linalg.{DenseMatrix => BDM, CSCMatrix => BSM} -class BreezeMatrixConversionSuite extends FunSuite { +import org.apache.spark.SparkFunSuite + +class BreezeMatrixConversionSuite extends SparkFunSuite { test("dense matrix to breeze") { val mat = Matrices.dense(3, 2, Array(0.0, 1.0, 2.0, 3.0, 4.0, 5.0)) val breeze = mat.toBreeze.asInstanceOf[BDM[Double]] diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeVectorConversionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeVectorConversionSuite.scala index 8abdac72902c6..3772c9235ad3a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeVectorConversionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeVectorConversionSuite.scala @@ -17,14 +17,14 @@ package org.apache.spark.mllib.linalg -import org.scalatest.FunSuite - import breeze.linalg.{DenseVector => BDV, SparseVector => BSV} +import org.apache.spark.SparkFunSuite + /** * Test Breeze vector conversions. */ -class BreezeVectorConversionSuite extends FunSuite { +class BreezeVectorConversionSuite extends SparkFunSuite { val arr = Array(0.1, 0.2, 0.3, 0.4) val n = 20 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala index 86119ec38101e..8dbb70f5d1c4c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala @@ -20,13 +20,13 @@ package org.apache.spark.mllib.linalg import java.util.Random import org.mockito.Mockito.when -import org.scalatest.FunSuite import org.scalatest.mock.MockitoSugar._ import scala.collection.mutable.{Map => MutableMap} +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.TestingUtils._ -class MatricesSuite extends FunSuite { +class MatricesSuite extends SparkFunSuite { test("dense matrix construction") { val m = 3 val n = 2 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index 24755e9ff46fc..c4ae0a16f7c04 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -20,12 +20,11 @@ package org.apache.spark.mllib.linalg import scala.util.Random import breeze.linalg.{DenseMatrix => BDM, squaredDistance => breezeSquaredDistance} -import org.scalatest.FunSuite -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.mllib.util.TestingUtils._ -class VectorsSuite extends FunSuite { +class VectorsSuite extends SparkFunSuite { val arr = Array(0.1, 0.0, 0.3, 0.4) val n = 4 @@ -215,13 +214,13 @@ class VectorsSuite extends FunSuite { val squaredDist = breezeSquaredDistance(sparseVector1.toBreeze, sparseVector2.toBreeze) - // SparseVector vs. SparseVector - assert(Vectors.sqdist(sparseVector1, sparseVector2) ~== squaredDist relTol 1E-8) + // SparseVector vs. SparseVector + assert(Vectors.sqdist(sparseVector1, sparseVector2) ~== squaredDist relTol 1E-8) // DenseVector vs. SparseVector assert(Vectors.sqdist(denseVector1, sparseVector2) ~== squaredDist relTol 1E-8) // DenseVector vs. DenseVector assert(Vectors.sqdist(denseVector1, denseVector2) ~== squaredDist relTol 1E-8) - } + } } test("foreachActive") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala index a58336175899c..93fe04c139b9a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala @@ -20,14 +20,13 @@ package org.apache.spark.mllib.linalg.distributed import java.{util => ju} import breeze.linalg.{DenseMatrix => BDM} -import org.scalatest.FunSuite -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.mllib.linalg.{SparseMatrix, DenseMatrix, Matrices, Matrix} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class BlockMatrixSuite extends FunSuite with MLlibTestSparkContext { +class BlockMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { val m = 5 val n = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala index 04b36a9ef9990..f3728cd036a3f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.mllib.linalg.distributed -import org.scalatest.FunSuite - import breeze.linalg.{DenseMatrix => BDM} +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.linalg.Vectors -class CoordinateMatrixSuite extends FunSuite with MLlibTestSparkContext { +class CoordinateMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { val m = 5 val n = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala index 2ab53cc13db71..4a7b99a976f0a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.mllib.linalg.distributed -import org.scalatest.FunSuite - import breeze.linalg.{diag => brzDiag, DenseMatrix => BDM, DenseVector => BDV} +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.mllib.linalg.{Matrices, Vectors} -class IndexedRowMatrixSuite extends FunSuite with MLlibTestSparkContext { +class IndexedRowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { val m = 4 val n = 3 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala index 27bb19f472e1e..b6cb53d0c743e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala @@ -20,12 +20,12 @@ package org.apache.spark.mllib.linalg.distributed import scala.util.Random import breeze.linalg.{DenseVector => BDV, DenseMatrix => BDM, norm => brzNorm, svd => brzSvd} -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{Matrices, Vectors, Vector} import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} -class RowMatrixSuite extends FunSuite with MLlibTestSparkContext { +class RowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { val m = 4 val n = 3 @@ -240,7 +240,7 @@ class RowMatrixSuite extends FunSuite with MLlibTestSparkContext { } } -class RowMatrixClusterSuite extends FunSuite with LocalClusterSparkContext { +class RowMatrixClusterSuite extends SparkFunSuite with LocalClusterSparkContext { var mat: RowMatrix = _ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala index e110506d579b0..a5a59e9fad5ae 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala @@ -20,8 +20,9 @@ package org.apache.spark.mllib.optimization import scala.collection.JavaConversions._ import scala.util.Random -import org.scalatest.{FunSuite, Matchers} +import org.scalatest.Matchers +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} @@ -61,7 +62,7 @@ object GradientDescentSuite { } } -class GradientDescentSuite extends FunSuite with MLlibTestSparkContext with Matchers { +class GradientDescentSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers { test("Assert the loss is decreasing.") { val nPoints = 10000 @@ -140,7 +141,7 @@ class GradientDescentSuite extends FunSuite with MLlibTestSparkContext with Matc } } -class GradientDescentClusterSuite extends FunSuite with LocalClusterSparkContext { +class GradientDescentClusterSuite extends SparkFunSuite with LocalClusterSparkContext { test("task size should be small") { val m = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala index c8f2adcf155a7..d07b9d5b89227 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala @@ -19,14 +19,15 @@ package org.apache.spark.mllib.optimization import scala.util.Random -import org.scalatest.{FunSuite, Matchers} +import org.scalatest.Matchers +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ -class LBFGSSuite extends FunSuite with MLlibTestSparkContext with Matchers { +class LBFGSSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers { val nPoints = 10000 val A = 2.0 @@ -229,7 +230,7 @@ class LBFGSSuite extends FunSuite with MLlibTestSparkContext with Matchers { } } -class LBFGSClusterSuite extends FunSuite with LocalClusterSparkContext { +class LBFGSClusterSuite extends SparkFunSuite with LocalClusterSparkContext { test("task size should be small") { val m = 10 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala index bb723fc471181..d8f9b8c33963d 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala @@ -19,13 +19,12 @@ package org.apache.spark.mllib.optimization import scala.util.Random -import org.scalatest.FunSuite - import org.jblas.{DoubleMatrix, SimpleBlas} +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.TestingUtils._ -class NNLSSuite extends FunSuite { +class NNLSSuite extends SparkFunSuite { /** Generate an NNLS problem whose optimal solution is the all-ones vector. */ def genOnesData(n: Int, rand: Random): (DoubleMatrix, DoubleMatrix) = { val A = new DoubleMatrix(n, n, Array.fill(n*n)(rand.nextDouble()): _*) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala index 0b646cf1ce6c4..4c6e76e47419b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala @@ -19,13 +19,13 @@ package org.apache.spark.mllib.pmml.export import org.dmg.pmml.RegressionModel import org.dmg.pmml.RegressionNormalizationMethodType -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.classification.LogisticRegressionModel import org.apache.spark.mllib.classification.SVMModel import org.apache.spark.mllib.util.LinearDataGenerator -class BinaryClassificationPMMLModelExportSuite extends FunSuite { +class BinaryClassificationPMMLModelExportSuite extends SparkFunSuite { test("logistic regression PMML export") { val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) @@ -53,13 +53,13 @@ class BinaryClassificationPMMLModelExportSuite extends FunSuite { // ensure logistic regression has normalization method set to LOGIT assert(pmmlRegressionModel.getNormalizationMethod() == RegressionNormalizationMethodType.LOGIT) } - + test("linear SVM PMML export") { val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label) - + val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel) - + // assert that the PMML format is as expected assert(svmModelExport.isInstanceOf[PMMLModelExport]) val pmml = svmModelExport.getPmml @@ -80,5 +80,5 @@ class BinaryClassificationPMMLModelExportSuite extends FunSuite { // ensure linear SVM has normalization method set to NONE assert(pmmlRegressionModel.getNormalizationMethod() == RegressionNormalizationMethodType.NONE) } - + } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala index f9afbd888dfc5..1d32309481787 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala @@ -18,12 +18,12 @@ package org.apache.spark.mllib.pmml.export import org.dmg.pmml.RegressionModel -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.regression.{LassoModel, LinearRegressionModel, RidgeRegressionModel} import org.apache.spark.mllib.util.LinearDataGenerator -class GeneralizedLinearPMMLModelExportSuite extends FunSuite { +class GeneralizedLinearPMMLModelExportSuite extends SparkFunSuite { test("linear regression PMML export") { val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala index b985d0446d7b0..b3f9750afa730 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala @@ -18,12 +18,12 @@ package org.apache.spark.mllib.pmml.export import org.dmg.pmml.ClusteringModel -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.clustering.KMeansModel import org.apache.spark.mllib.linalg.Vectors -class KMeansPMMLModelExportSuite extends FunSuite { +class KMeansPMMLModelExportSuite extends SparkFunSuite { test("KMeansPMMLModelExport generate PMML format") { val clusterCenters = Array( @@ -45,5 +45,5 @@ class KMeansPMMLModelExportSuite extends FunSuite { val pmmlClusteringModel = pmml.getModels.get(0).asInstanceOf[ClusteringModel] assert(pmmlClusteringModel.getNumberOfClusters === clusterCenters.length) } - + } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala index f28a4ac8ad01f..af49450961750 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.mllib.pmml.export -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.classification.{LogisticRegressionModel, SVMModel} import org.apache.spark.mllib.clustering.KMeansModel import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.{LassoModel, LinearRegressionModel, RidgeRegressionModel} import org.apache.spark.mllib.util.LinearDataGenerator -class PMMLModelExportFactorySuite extends FunSuite { +class PMMLModelExportFactorySuite extends SparkFunSuite { test("PMMLModelExportFactory create KMeansPMMLModelExport when passing a KMeansModel") { val clusterCenters = Array( @@ -61,25 +60,25 @@ class PMMLModelExportFactorySuite extends FunSuite { test("PMMLModelExportFactory create BinaryClassificationPMMLModelExport " + "when passing a LogisticRegressionModel or SVMModel") { val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) - + val logisticRegressionModel = new LogisticRegressionModel(linearInput(0).features, linearInput(0).label) val logisticRegressionModelExport = PMMLModelExportFactory.createPMMLModelExport(logisticRegressionModel) assert(logisticRegressionModelExport.isInstanceOf[BinaryClassificationPMMLModelExport]) - + val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label) val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel) assert(svmModelExport.isInstanceOf[BinaryClassificationPMMLModelExport]) } - + test("PMMLModelExportFactory throw IllegalArgumentException " + "when passing a Multinomial Logistic Regression") { /** 3 classes, 2 features */ val multiclassLogisticRegressionModel = new LogisticRegressionModel( - weights = Vectors.dense(0.1, 0.2, 0.3, 0.4), intercept = 1.0, + weights = Vectors.dense(0.1, 0.2, 0.3, 0.4), intercept = 1.0, numFeatures = 2, numClasses = 3) - + intercept[IllegalArgumentException] { PMMLModelExportFactory.createPMMLModelExport(multiclassLogisticRegressionModel) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala index b792d819fdabb..a5ca1518f82f5 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala @@ -19,12 +19,11 @@ package org.apache.spark.mllib.random import scala.math -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.util.StatCounter // TODO update tests to use TestingUtils for floating point comparison after PR 1367 is merged -class RandomDataGeneratorSuite extends FunSuite { +class RandomDataGeneratorSuite extends SparkFunSuite { def apiChecks(gen: RandomDataGenerator[Double]) { // resetting seed should generate the same sequence of random numbers diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala index 63f2ea916d457..413db2000d6d7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala @@ -19,8 +19,7 @@ package org.apache.spark.mllib.random import scala.collection.mutable.ArrayBuffer -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.SparkContext._ import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.rdd.{RandomRDDPartition, RandomRDD} @@ -34,7 +33,7 @@ import org.apache.spark.util.StatCounter * * TODO update tests to use TestingUtils for floating point comparison after PR 1367 is merged */ -class RandomRDDsSuite extends FunSuite with MLlibTestSparkContext with Serializable { +class RandomRDDsSuite extends SparkFunSuite with MLlibTestSparkContext with Serializable { def testGeneratedRDD(rdd: RDD[Double], expectedSize: Long, diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala index 57216e8eb4a55..10f5a2be48f7c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.rdd -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.rdd.MLPairRDDFunctions._ -class MLPairRDDFunctionsSuite extends FunSuite with MLlibTestSparkContext { +class MLPairRDDFunctionsSuite extends SparkFunSuite with MLlibTestSparkContext { test("topByKey") { val topMap = sc.parallelize(Array((1, 7), (1, 3), (1, 6), (1, 1), (1, 2), (3, 2), (3, 7), (5, 1), (3, 5)), 2) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala index 6d6c0aa5be812..bc64172614830 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.rdd -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.rdd.RDDFunctions._ -class RDDFunctionsSuite extends FunSuite with MLlibTestSparkContext { +class RDDFunctionsSuite extends SparkFunSuite with MLlibTestSparkContext { test("sliding") { val data = 0 until 6 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala index b3798940ddc38..05b87728d6fdb 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala @@ -21,9 +21,9 @@ import scala.collection.JavaConversions._ import scala.math.abs import scala.util.Random -import org.scalatest.FunSuite import org.jblas.DoubleMatrix +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.storage.StorageLevel @@ -84,7 +84,7 @@ object ALSSuite { } -class ALSSuite extends FunSuite with MLlibTestSparkContext { +class ALSSuite extends SparkFunSuite with MLlibTestSparkContext { test("rank-1 matrices") { testALS(50, 100, 1, 15, 0.7, 0.3) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala index 2c92866f3893d..2c8ed057a516a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.mllib.recommendation -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils -class MatrixFactorizationModelSuite extends FunSuite with MLlibTestSparkContext { +class MatrixFactorizationModelSuite extends SparkFunSuite with MLlibTestSparkContext { val rank = 2 var userFeatures: RDD[(Int, Array[Double])] = _ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala index 3b38bdf5ef5eb..ea4f2865757c1 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala @@ -17,13 +17,14 @@ package org.apache.spark.mllib.regression -import org.scalatest.{Matchers, FunSuite} +import org.scalatest.Matchers +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.util.Utils -class IsotonicRegressionSuite extends FunSuite with MLlibTestSparkContext with Matchers { +class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers { private def round(d: Double) = { math.round(d * 100).toDouble / 100 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala index 110c44a7193fd..d8364a06de4da 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala @@ -17,11 +17,10 @@ package org.apache.spark.mllib.regression -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors -class LabeledPointSuite extends FunSuite { +class LabeledPointSuite extends SparkFunSuite { test("parse labeled points") { val points = Seq( diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala index 71dce50922991..08a152ffc7a23 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala @@ -19,8 +19,7 @@ package org.apache.spark.mllib.regression import scala.util.Random -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator, MLlibTestSparkContext} @@ -32,7 +31,7 @@ private object LassoSuite { val model = new LassoModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5) } -class LassoSuite extends FunSuite with MLlibTestSparkContext { +class LassoSuite extends SparkFunSuite with MLlibTestSparkContext { def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => @@ -143,7 +142,7 @@ class LassoSuite extends FunSuite with MLlibTestSparkContext { } } -class LassoClusterSuite extends FunSuite with LocalClusterSparkContext { +class LassoClusterSuite extends SparkFunSuite with LocalClusterSparkContext { test("task size should be small in both training and prediction") { val m = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala index 3781931c2f819..f88a1c33c9f7c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala @@ -19,8 +19,7 @@ package org.apache.spark.mllib.regression import scala.util.Random -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator, MLlibTestSparkContext} @@ -32,7 +31,7 @@ private object LinearRegressionSuite { val model = new LinearRegressionModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5) } -class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext { +class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => @@ -150,7 +149,7 @@ class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext { } } -class LinearRegressionClusterSuite extends FunSuite with LocalClusterSparkContext { +class LinearRegressionClusterSuite extends SparkFunSuite with LocalClusterSparkContext { test("task size should be small in both training and prediction") { val m = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala index d6c93cc0e49cd..7a781fee634c8 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala @@ -20,8 +20,8 @@ package org.apache.spark.mllib.regression import scala.util.Random import org.jblas.DoubleMatrix -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator, MLlibTestSparkContext} @@ -33,7 +33,7 @@ private object RidgeRegressionSuite { val model = new RidgeRegressionModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5) } -class RidgeRegressionSuite extends FunSuite with MLlibTestSparkContext { +class RidgeRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { def predictionError(predictions: Seq[Double], input: Seq[LabeledPoint]): Double = { predictions.zip(input).map { case (prediction, expected) => @@ -101,7 +101,7 @@ class RidgeRegressionSuite extends FunSuite with MLlibTestSparkContext { } } -class RidgeRegressionClusterSuite extends FunSuite with LocalClusterSparkContext { +class RidgeRegressionClusterSuite extends SparkFunSuite with LocalClusterSparkContext { test("task size should be small in both training and prediction") { val m = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala index 26604dbe6c1ef..9a379406d5061 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala @@ -19,14 +19,13 @@ package org.apache.spark.mllib.regression import scala.collection.mutable.ArrayBuffer -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.LinearDataGenerator import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.TestSuiteBase -class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase { +class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase { // use longer wait time to ensure job completion override def maxWaitTimeMillis: Int = 20000 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala index a7e6fce31ff7e..c292ced75e870 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala @@ -17,16 +17,15 @@ package org.apache.spark.mllib.stat -import org.scalatest.FunSuite - import breeze.linalg.{DenseMatrix => BDM, Matrix => BM} +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.stat.correlation.{Correlations, PearsonCorrelation, SpearmanCorrelation} import org.apache.spark.mllib.util.MLlibTestSparkContext -class CorrelationSuite extends FunSuite with MLlibTestSparkContext { +class CorrelationSuite extends SparkFunSuite with MLlibTestSparkContext { // test input data val xData = Array(1.0, 0.0, -2.0) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala index 15418e6035965..b084a5fb4313f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala @@ -19,16 +19,14 @@ package org.apache.spark.mllib.stat import java.util.Random -import org.scalatest.FunSuite - -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.mllib.linalg.{DenseVector, Matrices, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.stat.test.ChiSqTest import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class HypothesisTestSuite extends FunSuite with MLlibTestSparkContext { +class HypothesisTestSuite extends SparkFunSuite with MLlibTestSparkContext { test("chi squared pearson goodness of fit") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala index a309c942cf8ff..5feccdf33681a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala @@ -18,11 +18,11 @@ package org.apache.spark.mllib.stat import org.apache.commons.math3.distribution.NormalDistribution -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext -class KernelDensitySuite extends FunSuite with MLlibTestSparkContext { +class KernelDensitySuite extends SparkFunSuite with MLlibTestSparkContext { test("kernel density single sample") { val rdd = sc.parallelize(Array(5.0)) val evaluationPoints = Array(5.0, 6.0) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala index 23b0eec865de6..07efde4f5e6dc 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.stat -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.TestingUtils._ -class MultivariateOnlineSummarizerSuite extends FunSuite { +class MultivariateOnlineSummarizerSuite extends SparkFunSuite { test("basic error handing") { val summarizer = new MultivariateOnlineSummarizer diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala index fac2498e4dcb3..aa60deb665aeb 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala @@ -17,49 +17,48 @@ package org.apache.spark.mllib.stat.distribution -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{ Vectors, Matrices } import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class MultivariateGaussianSuite extends FunSuite with MLlibTestSparkContext { +class MultivariateGaussianSuite extends SparkFunSuite with MLlibTestSparkContext { test("univariate") { val x1 = Vectors.dense(0.0) val x2 = Vectors.dense(1.5) - + val mu = Vectors.dense(0.0) val sigma1 = Matrices.dense(1, 1, Array(1.0)) val dist1 = new MultivariateGaussian(mu, sigma1) assert(dist1.pdf(x1) ~== 0.39894 absTol 1E-5) assert(dist1.pdf(x2) ~== 0.12952 absTol 1E-5) - + val sigma2 = Matrices.dense(1, 1, Array(4.0)) val dist2 = new MultivariateGaussian(mu, sigma2) assert(dist2.pdf(x1) ~== 0.19947 absTol 1E-5) assert(dist2.pdf(x2) ~== 0.15057 absTol 1E-5) } - + test("multivariate") { val x1 = Vectors.dense(0.0, 0.0) val x2 = Vectors.dense(1.0, 1.0) - + val mu = Vectors.dense(0.0, 0.0) val sigma1 = Matrices.dense(2, 2, Array(1.0, 0.0, 0.0, 1.0)) val dist1 = new MultivariateGaussian(mu, sigma1) assert(dist1.pdf(x1) ~== 0.15915 absTol 1E-5) assert(dist1.pdf(x2) ~== 0.05855 absTol 1E-5) - + val sigma2 = Matrices.dense(2, 2, Array(4.0, -1.0, -1.0, 2.0)) val dist2 = new MultivariateGaussian(mu, sigma2) assert(dist2.pdf(x1) ~== 0.060155 absTol 1E-5) assert(dist2.pdf(x2) ~== 0.033971 absTol 1E-5) } - + test("multivariate degenerate") { val x1 = Vectors.dense(0.0, 0.0) val x2 = Vectors.dense(1.0, 1.0) - + val mu = Vectors.dense(0.0, 0.0) val sigma = Matrices.dense(2, 2, Array(1.0, 1.0, 1.0, 1.0)) val dist = new MultivariateGaussian(mu, sigma) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index ce983eb27fa35..356d957f15909 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -20,8 +20,7 @@ package org.apache.spark.mllib.tree import scala.collection.JavaConverters._ import scala.collection.mutable -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ @@ -34,7 +33,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.util.Utils -class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { +class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { ///////////////////////////////////////////////////////////////////////////// // Tests examining individual elements of training @@ -859,7 +858,7 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { } } -object DecisionTreeSuite extends FunSuite { +object DecisionTreeSuite extends SparkFunSuite { def validateClassifier( model: DecisionTreeModel, diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala index 55b0bac7d49fe..84dd3b342d4c0 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.mllib.tree -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy} @@ -32,7 +31,7 @@ import org.apache.spark.util.Utils /** * Test suite for [[GradientBoostedTrees]]. */ -class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext { +class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext { test("Regression with continuous features: SquaredError") { GradientBoostedTreesSuite.testCombinations.foreach { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala index 92b498580af03..49aff21fe7914 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.mllib.tree -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.tree.impurity.{EntropyAggregator, GiniAggregator} import org.apache.spark.mllib.util.MLlibTestSparkContext /** * Test suites for [[GiniAggregator]] and [[EntropyAggregator]]. */ -class ImpuritySuite extends FunSuite with MLlibTestSparkContext { +class ImpuritySuite extends SparkFunSuite with MLlibTestSparkContext { test("Gini impurity does not support negative labels") { val gini = new GiniAggregator(2) intercept[IllegalArgumentException] { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala index 4ed66953cb628..e6df5d974bf36 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala @@ -19,8 +19,7 @@ package org.apache.spark.mllib.tree import scala.collection.mutable -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ @@ -35,7 +34,7 @@ import org.apache.spark.util.Utils /** * Test suite for [[RandomForest]]. */ -class RandomForestSuite extends FunSuite with MLlibTestSparkContext { +class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { def binaryClassificationTestWithContinuousFeatures(strategy: Strategy) { val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000) val rdd = sc.parallelize(arr) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala index b184e936672ca..9d756da410325 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.mllib.tree.impl -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.tree.EnsembleTestHelper import org.apache.spark.mllib.util.MLlibTestSparkContext /** * Test suite for [[BaggedPoint]]. */ -class BaggedPointSuite extends FunSuite with MLlibTestSparkContext { +class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext { test("BaggedPoint RDD: without subsampling") { val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala index 27050bde16ef3..7bb9d570a0b89 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala @@ -64,7 +64,7 @@ class MLUtilsSuite extends FunSuite with MLlibTestSparkContext { val fastSquaredDist3 = fastSquaredDistance(v2, norm2, v3, norm3, precision) assert((fastSquaredDist3 - squaredDist2) <= precision * squaredDist2, s"failed with m = $m") - if (m > 10) { + if (m > 10) { val v4 = Vectors.sparse(n, indices.slice(0, m - 10), indices.map(i => a(i) + 0.5).slice(0, m - 10)) val norm4 = Vectors.norm(v4, 2.0) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala index f68fb95eac4e4..8dcb9ba9be108 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala @@ -17,11 +17,9 @@ package org.apache.spark.mllib.util -import org.scalatest.FunSuite +import org.apache.spark.{SparkException, SparkFunSuite} -import org.apache.spark.SparkException - -class NumericParserSuite extends FunSuite { +class NumericParserSuite extends SparkFunSuite { test("parser") { val s = "((1.0,2e3),-4,[5e-6,7.0E8],+9)" diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala index 59e6c778806f4..8f475f30249d6 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala @@ -17,12 +17,12 @@ package org.apache.spark.mllib.util +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors -import org.scalatest.FunSuite import org.apache.spark.mllib.util.TestingUtils._ import org.scalatest.exceptions.TestFailedException -class TestingUtilsSuite extends FunSuite { +class TestingUtilsSuite extends SparkFunSuite { test("Comparing doubles using relative error.") { diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 11b439e7875fc..8da72b3fa7cdb 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -38,6 +38,8 @@ object MimaExcludes { Seq( MimaBuild.excludeSparkPackage("deploy"), MimaBuild.excludeSparkPackage("ml"), + // SPARK-7910 Adding a method to get the partioner to JavaRDD, + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.partitioner"), // SPARK-5922 Adding a generalized diff(other: RDD[(VertexId, VD)]) to VertexRDD ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.VertexRDD.diff"), // These are needed if checking against the sbt build, since they are part of diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index b9515a12bc573..9a849639233bc 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -23,7 +23,6 @@ import scala.collection.JavaConversions._ import sbt._ import sbt.Classpaths.publishTask import sbt.Keys._ -import sbtunidoc.Plugin.genjavadocSettings import sbtunidoc.Plugin.UnidocKeys.unidocGenjavadocVersion import com.typesafe.sbt.pom.{loadEffectivePom, PomBuild, SbtPomKeys} import net.virtualvoid.sbt.graph.Plugin.graphSettings @@ -118,7 +117,12 @@ object SparkBuild extends PomBuild { lazy val MavenCompile = config("m2r") extend(Compile) lazy val publishLocalBoth = TaskKey[Unit]("publish-local", "publish local for m2 and ivy") - lazy val sharedSettings = graphSettings ++ genjavadocSettings ++ Seq ( + lazy val sparkGenjavadocSettings: Seq[sbt.Def.Setting[_]] = Seq( + libraryDependencies += compilerPlugin( + "org.spark-project" %% "genjavadoc-plugin" % unidocGenjavadocVersion.value cross CrossVersion.full), + scalacOptions <+= target.map(t => "-P:genjavadoc:out=" + (t / "java"))) + + lazy val sharedSettings = graphSettings ++ sparkGenjavadocSettings ++ Seq ( javaHome := sys.env.get("JAVA_HOME") .orElse(sys.props.get("java.home").map { p => new File(p).getParentFile().getAbsolutePath() }) .map(file), @@ -126,7 +130,7 @@ object SparkBuild extends PomBuild { retrieveManaged := true, retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]", publishMavenStyle := true, - unidocGenjavadocVersion := "0.8", + unidocGenjavadocVersion := "0.9-spark0", resolvers += Resolver.mavenLocal, otherResolvers <<= SbtPomKeys.mvnLocalRepository(dotM2 => Seq(Resolver.file("dotM2", dotM2))), diff --git a/project/plugins.sbt b/project/plugins.sbt index 7096b0d3ee7de..75bd604a1b857 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -25,7 +25,7 @@ addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.6") addSbtPlugin("com.alpinenow" % "junit_xml_listener" % "0.5.1") -addSbtPlugin("com.eed3si9n" % "sbt-unidoc" % "0.3.1") +addSbtPlugin("com.eed3si9n" % "sbt-unidoc" % "0.3.3") addSbtPlugin("com.cavorite" % "sbt-avro" % "0.3.2") diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index 0d21a132048a5..adca90ddaf397 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -261,3 +261,7 @@ def _start_update_server(): thread.daemon = True thread.start() return server + +if __name__ == "__main__": + import doctest + doctest.testmod() diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index b0479d9b074db..ddb33f427ac64 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -324,65 +324,73 @@ def getP(self): @inherit_doc class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol): """ - A one-hot encoder that maps a column of label indices to a column of binary vectors, with - at most a single one-value. By default, the binary vector has an element for each category, so - with 5 categories, an input value of 2.0 would map to an output vector of - (0.0, 0.0, 1.0, 0.0, 0.0). If includeFirst is set to false, the first category is omitted, so - the output vector for the previous example would be (0.0, 1.0, 0.0, 0.0) and an input value - of 0.0 would map to a vector of all zeros. Including the first category makes the vector columns - linearly dependent because they sum up to one. - - TODO: This method requires the use of StringIndexer first. Decouple them. + A one-hot encoder that maps a column of category indices to a + column of binary vectors, with at most a single one-value per row + that indicates the input category index. + For example with 5 categories, an input value of 2.0 would map to + an output vector of `[0.0, 0.0, 1.0, 0.0]`. + The last category is not included by default (configurable via + :py:attr:`dropLast`) because it makes the vector entries sum up to + one, and hence linearly dependent. + So an input value of 4.0 maps to `[0.0, 0.0, 0.0, 0.0]`. + Note that this is different from scikit-learn's OneHotEncoder, + which keeps all categories. + The output vectors are sparse. + + .. seealso:: + + :py:class:`StringIndexer` for converting categorical values into + category indices >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed") >>> model = stringIndexer.fit(stringIndDf) >>> td = model.transform(stringIndDf) - >>> encoder = OneHotEncoder(includeFirst=False, inputCol="indexed", outputCol="features") + >>> encoder = OneHotEncoder(inputCol="indexed", outputCol="features") >>> encoder.transform(td).head().features - SparseVector(2, {}) + SparseVector(2, {0: 1.0}) >>> encoder.setParams(outputCol="freqs").transform(td).head().freqs - SparseVector(2, {}) - >>> params = {encoder.includeFirst: True, encoder.outputCol: "test"} + SparseVector(2, {0: 1.0}) + >>> params = {encoder.dropLast: False, encoder.outputCol: "test"} >>> encoder.transform(td, params).head().test SparseVector(3, {0: 1.0}) """ # a placeholder to make it appear in the generated doc - includeFirst = Param(Params._dummy(), "includeFirst", "include first category") + dropLast = Param(Params._dummy(), "dropLast", "whether to drop the last category") @keyword_only - def __init__(self, includeFirst=True, inputCol=None, outputCol=None): + def __init__(self, dropLast=True, inputCol=None, outputCol=None): """ __init__(self, includeFirst=True, inputCol=None, outputCol=None) """ super(OneHotEncoder, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.OneHotEncoder", self.uid) - self.includeFirst = Param(self, "includeFirst", "include first category") - self._setDefault(includeFirst=True) + self.dropLast = Param(self, "dropLast", "whether to drop the last category") + self._setDefault(dropLast=True) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only - def setParams(self, includeFirst=True, inputCol=None, outputCol=None): + def setParams(self, dropLast=True, inputCol=None, outputCol=None): """ - setParams(self, includeFirst=True, inputCol=None, outputCol=None) + setParams(self, dropLast=True, inputCol=None, outputCol=None) Sets params for this OneHotEncoder. """ kwargs = self.setParams._input_kwargs return self._set(**kwargs) - def setIncludeFirst(self, value): + def setDropLast(self, value): """ - Sets the value of :py:attr:`includeFirst`. + Sets the value of :py:attr:`dropLast`. """ - self._paramMap[self.includeFirst] = value + self._paramMap[self.dropLast] = value return self - def getIncludeFirst(self): + def getDropLast(self): """ - Gets the value of includeFirst or its default value. + Gets the value of dropLast or its default value. """ - return self.getOrDefault(self.includeFirst) + return self.getOrDefault(self.dropLast) @inherit_doc diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 497841b6c8ce6..0bf988fd72f14 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -91,20 +91,19 @@ class CrossValidator(Estimator): >>> from pyspark.ml.evaluation import BinaryClassificationEvaluator >>> from pyspark.mllib.linalg import Vectors >>> dataset = sqlContext.createDataFrame( - ... [(Vectors.dense([0.0, 1.0]), 0.0), - ... (Vectors.dense([1.0, 2.0]), 1.0), - ... (Vectors.dense([0.55, 3.0]), 0.0), - ... (Vectors.dense([0.45, 4.0]), 1.0), - ... (Vectors.dense([0.51, 5.0]), 1.0)] * 10, + ... [(Vectors.dense([0.0]), 0.0), + ... (Vectors.dense([0.4]), 1.0), + ... (Vectors.dense([0.5]), 0.0), + ... (Vectors.dense([0.6]), 1.0), + ... (Vectors.dense([1.0]), 1.0)] * 10, ... ["features", "label"]) >>> lr = LogisticRegression() - >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1, 5]).build() + >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() >>> evaluator = BinaryClassificationEvaluator() >>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) - >>> # SPARK-7432: The following test is flaky. - >>> # cvModel = cv.fit(dataset) - >>> # expected = lr.fit(dataset, {lr.maxIter: 5}).transform(dataset) - >>> # cvModel.transform(dataset).collect() == expected.collect() + >>> cvModel = cv.fit(dataset) + >>> evaluator.evaluate(cvModel.transform(dataset)) + 0.8333... """ # a placeholder to make it appear in the generated doc diff --git a/python/pyspark/mllib/__init__.py b/python/pyspark/mllib/__init__.py index 07507b2ad0d05..b11aed2c3afda 100644 --- a/python/pyspark/mllib/__init__.py +++ b/python/pyspark/mllib/__init__.py @@ -28,11 +28,3 @@ __all__ = ['classification', 'clustering', 'feature', 'fpm', 'linalg', 'random', 'recommendation', 'regression', 'stat', 'tree', 'util'] - -import sys -from . import rand as random -modname = __name__ + '.random' -random.__name__ = modname -random.RandomRDDs.__module__ = modname -sys.modules[modname] = random -del modname, sys diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py index aab5e5f4b77b5..c5cf3a4e7ff22 100644 --- a/python/pyspark/mllib/evaluation.py +++ b/python/pyspark/mllib/evaluation.py @@ -27,6 +27,8 @@ class BinaryClassificationMetrics(JavaModelWrapper): """ Evaluator for binary classification. + :param scoreAndLabels: an RDD of (score, label) pairs + >>> scoreAndLabels = sc.parallelize([ ... (0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)], 2) >>> metrics = BinaryClassificationMetrics(scoreAndLabels) @@ -38,9 +40,6 @@ class BinaryClassificationMetrics(JavaModelWrapper): """ def __init__(self, scoreAndLabels): - """ - :param scoreAndLabels: an RDD of (score, label) pairs - """ sc = scoreAndLabels.ctx sql_ctx = SQLContext(sc) df = sql_ctx.createDataFrame(scoreAndLabels, schema=StructType([ @@ -76,6 +75,9 @@ class RegressionMetrics(JavaModelWrapper): """ Evaluator for regression. + :param predictionAndObservations: an RDD of (prediction, + observation) pairs. + >>> predictionAndObservations = sc.parallelize([ ... (2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)]) >>> metrics = RegressionMetrics(predictionAndObservations) @@ -92,9 +94,6 @@ class RegressionMetrics(JavaModelWrapper): """ def __init__(self, predictionAndObservations): - """ - :param predictionAndObservations: an RDD of (prediction, observation) pairs. - """ sc = predictionAndObservations.ctx sql_ctx = SQLContext(sc) df = sql_ctx.createDataFrame(predictionAndObservations, schema=StructType([ @@ -148,6 +147,8 @@ class MulticlassMetrics(JavaModelWrapper): """ Evaluator for multiclass classification. + :param predictionAndLabels an RDD of (prediction, label) pairs. + >>> predictionAndLabels = sc.parallelize([(0.0, 0.0), (0.0, 1.0), (0.0, 0.0), ... (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)]) >>> metrics = MulticlassMetrics(predictionAndLabels) @@ -176,9 +177,6 @@ class MulticlassMetrics(JavaModelWrapper): """ def __init__(self, predictionAndLabels): - """ - :param predictionAndLabels an RDD of (prediction, label) pairs. - """ sc = predictionAndLabels.ctx sql_ctx = SQLContext(sc) df = sql_ctx.createDataFrame(predictionAndLabels, schema=StructType([ @@ -277,6 +275,9 @@ class RankingMetrics(JavaModelWrapper): """ Evaluator for ranking algorithms. + :param predictionAndLabels: an RDD of (predicted ranking, + ground truth set) pairs. + >>> predictionAndLabels = sc.parallelize([ ... ([1, 6, 2, 7, 8, 3, 9, 10, 4, 5], [1, 2, 3, 4, 5]), ... ([4, 1, 5, 6, 2, 7, 3, 8, 9, 10], [1, 2, 3]), @@ -298,9 +299,6 @@ class RankingMetrics(JavaModelWrapper): """ def __init__(self, predictionAndLabels): - """ - :param predictionAndLabels: an RDD of (predicted ranking, ground truth set) pairs. - """ sc = predictionAndLabels.ctx sql_ctx = SQLContext(sc) df = sql_ctx.createDataFrame(predictionAndLabels, @@ -347,6 +345,10 @@ class MultilabelMetrics(JavaModelWrapper): """ Evaluator for multilabel classification. + :param predictionAndLabels: an RDD of (predictions, labels) pairs, + both are non-null Arrays, each with + unique elements. + >>> predictionAndLabels = sc.parallelize([([0.0, 1.0], [0.0, 2.0]), ([0.0, 2.0], [0.0, 1.0]), ... ([], [0.0]), ([2.0], [2.0]), ([2.0, 0.0], [2.0, 0.0]), ... ([0.0, 1.0, 2.0], [0.0, 1.0]), ([1.0], [1.0, 2.0])]) diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index aac305db6c19a..da90554f41437 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -68,6 +68,8 @@ class Normalizer(VectorTransformer): For `p` = float('inf'), max(abs(vector)) will be used as norm for normalization. + :param p: Normalization in L^p^ space, p = 2 by default. + >>> v = Vectors.dense(range(3)) >>> nor = Normalizer(1) >>> nor.transform(v) @@ -82,9 +84,6 @@ class Normalizer(VectorTransformer): DenseVector([0.0, 0.5, 1.0]) """ def __init__(self, p=2.0): - """ - :param p: Normalization in L^p^ space, p = 2 by default. - """ assert p >= 1.0, "p should be greater than 1.0" self.p = float(p) @@ -94,7 +93,7 @@ def transform(self, vector): :param vector: vector or RDD of vector to be normalized. :return: normalized vector. If the norm of the input is zero, it - will return the input vector. + will return the input vector. """ sc = SparkContext._active_spark_context assert sc is not None, "SparkContext should be initialized first" @@ -164,6 +163,13 @@ class StandardScaler(object): variance using column summary statistics on the samples in the training set. + :param withMean: False by default. Centers the data with mean + before scaling. It will build a dense output, so this + does not work on sparse input and will raise an + exception. + :param withStd: True by default. Scales the data to unit + standard deviation. + >>> vs = [Vectors.dense([-2.0, 2.3, 0]), Vectors.dense([3.8, 0.0, 1.9])] >>> dataset = sc.parallelize(vs) >>> standardizer = StandardScaler(True, True) @@ -174,14 +180,6 @@ class StandardScaler(object): DenseVector([0.7071, -0.7071, 0.7071]) """ def __init__(self, withMean=False, withStd=True): - """ - :param withMean: False by default. Centers the data with mean - before scaling. It will build a dense output, so this - does not work on sparse input and will raise an - exception. - :param withStd: True by default. Scales the data to unit - standard deviation. - """ if not (withMean or withStd): warnings.warn("Both withMean and withStd are false. The model does nothing.") self.withMean = withMean @@ -193,7 +191,7 @@ def fit(self, dataset): for later scaling. :param data: The data used to compute the mean and variance - to build the transformation model. + to build the transformation model. :return: a StandardScalarModel """ dataset = dataset.map(_convert_to_vector) @@ -223,6 +221,8 @@ class ChiSqSelector(object): Creates a ChiSquared feature selector. + :param numTopFeatures: number of features that selector will select. + >>> data = [ ... LabeledPoint(0.0, SparseVector(3, {0: 8.0, 1: 7.0})), ... LabeledPoint(1.0, SparseVector(3, {1: 9.0, 2: 6.0})), @@ -236,9 +236,6 @@ class ChiSqSelector(object): DenseVector([5.0]) """ def __init__(self, numTopFeatures): - """ - :param numTopFeatures: number of features that selector will select. - """ self.numTopFeatures = int(numTopFeatures) def fit(self, data): @@ -246,9 +243,9 @@ def fit(self, data): Returns a ChiSquared feature selector. :param data: an `RDD[LabeledPoint]` containing the labeled dataset - with categorical features. Real-valued features will be - treated as categorical for each distinct value. - Apply feature discretizer before using this function. + with categorical features. Real-valued features will be + treated as categorical for each distinct value. + Apply feature discretizer before using this function. """ jmodel = callMLlibFunc("fitChiSqSelector", self.numTopFeatures, data) return ChiSqSelectorModel(jmodel) @@ -263,15 +260,14 @@ class HashingTF(object): Note: the terms must be hashable (can not be dict/set/list...). + :param numFeatures: number of features (default: 2^20) + >>> htf = HashingTF(100) >>> doc = "a a b b c d".split(" ") >>> htf.transform(doc) SparseVector(100, {...}) """ def __init__(self, numFeatures=1 << 20): - """ - :param numFeatures: number of features (default: 2^20) - """ self.numFeatures = numFeatures def indexOf(self, term): @@ -311,7 +307,7 @@ def transform(self, x): Call transform directly on the RDD instead. :param x: an RDD of term frequency vectors or a term frequency - vector + vector :return: an RDD of TF-IDF vectors or a TF-IDF vector """ if isinstance(x, RDD): @@ -342,6 +338,9 @@ class IDF(object): `minDocFreq`). For terms that are not in at least `minDocFreq` documents, the IDF is found as 0, resulting in TF-IDFs of 0. + :param minDocFreq: minimum of documents in which a term + should appear for filtering + >>> n = 4 >>> freqs = [Vectors.sparse(n, (1, 3), (1.0, 2.0)), ... Vectors.dense([0.0, 1.0, 2.0, 3.0]), @@ -362,10 +361,6 @@ class IDF(object): SparseVector(4, {1: 0.0, 3: 0.5754}) """ def __init__(self, minDocFreq=0): - """ - :param minDocFreq: minimum of documents in which a term - should appear for filtering - """ self.minDocFreq = minDocFreq def fit(self, dataset): diff --git a/python/pyspark/mllib/rand.py b/python/pyspark/mllib/random.py similarity index 100% rename from python/pyspark/mllib/rand.py rename to python/pyspark/mllib/random.py diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py index 8fee92ae3aed5..726d288d97b2e 100644 --- a/python/pyspark/sql/__init__.py +++ b/python/pyspark/sql/__init__.py @@ -50,18 +50,6 @@ def deco(f): return f return deco -# fix the module name conflict for Python 3+ -import sys -from . import _types as types -modname = __name__ + '.types' -types.__name__ = modname -# update the __module__ for all objects, make them picklable -for v in types.__dict__.values(): - if hasattr(v, "__module__") and v.__module__.endswith('._types'): - v.__module__ = modname -sys.modules[modname] = types -del modname, sys - from pyspark.sql.types import Row from pyspark.sql.context import SQLContext, HiveContext from pyspark.sql.column import Column diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 8dc5039f587f0..1ecec5b126505 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -315,6 +315,14 @@ def between(self, lowerBound, upperBound): """ A boolean expression that is evaluated to true if the value of this expression is between the given columns. + + >>> df.select(df.name, df.age.between(2, 4)).show() + +-----+--------------------------+ + | name|((age >= 2) && (age <= 4))| + +-----+--------------------------+ + |Alice| true| + | Bob| false| + +-----+--------------------------+ """ return (self >= lowerBound) & (self <= upperBound) @@ -328,12 +336,20 @@ def when(self, condition, value): :param condition: a boolean :class:`Column` expression. :param value: a literal value, or a :class:`Column` expression. + + >>> from pyspark.sql import functions as F + >>> df.select(df.name, F.when(df.age > 4, 1).when(df.age < 3, -1).otherwise(0)).show() + +-----+--------------------------------------------------------+ + | name|CASE WHEN (age > 4) THEN 1 WHEN (age < 3) THEN -1 ELSE 0| + +-----+--------------------------------------------------------+ + |Alice| -1| + | Bob| 1| + +-----+--------------------------------------------------------+ """ - sc = SparkContext._active_spark_context if not isinstance(condition, Column): raise TypeError("condition should be a Column") v = value._jc if isinstance(value, Column) else value - jc = sc._jvm.functions.when(condition._jc, v) + jc = self._jc.when(condition._jc, v) return Column(jc) @since(1.4) @@ -345,9 +361,18 @@ def otherwise(self, value): See :func:`pyspark.sql.functions.when` for example usage. :param value: a literal value, or a :class:`Column` expression. + + >>> from pyspark.sql import functions as F + >>> df.select(df.name, F.when(df.age > 3, 1).otherwise(0)).show() + +-----+---------------------------------+ + | name|CASE WHEN (age > 3) THEN 1 ELSE 0| + +-----+---------------------------------+ + |Alice| 0| + | Bob| 1| + +-----+---------------------------------+ """ v = value._jc if isinstance(value, Column) else value - jc = self._jc.otherwise(value) + jc = self._jc.otherwise(v) return Column(jc) @since(1.4) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 936487519a645..a82b6b87c413e 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1170,6 +1170,9 @@ def freqItems(self, cols, support=None): "http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou". :func:`DataFrame.freqItems` and :func:`DataFrameStatFunctions.freqItems` are aliases. + This function is meant for exploratory data analysis, as we make no guarantee about the + backward compatibility of the schema of the resulting DataFrame. + :param cols: Names of the columns to calculate frequent items for as a list or tuple of strings. :param support: The frequency with which to consider an item 'frequent'. Default is 1%. diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index b6fd413bec7db..d17d87419fe3d 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -43,6 +43,39 @@ def _df(self, jdf): from pyspark.sql.dataframe import DataFrame return DataFrame(jdf, self._sqlContext) + @since(1.4) + def format(self, source): + """ + Specifies the input data source format. + """ + self._jreader = self._jreader.format(source) + return self + + @since(1.4) + def schema(self, schema): + """ + Specifies the input schema. Some data sources (e.g. JSON) can + infer the input schema automatically from data. By specifying + the schema here, the underlying data source can skip the schema + inference step, and thus speed up data loading. + + :param schema: a StructType object + """ + if not isinstance(schema, StructType): + raise TypeError("schema should be StructType") + jschema = self._sqlContext._ssql_ctx.parseDataType(schema.json()) + self._jreader = self._jreader.schema(jschema) + return self + + @since(1.4) + def options(self, **options): + """ + Adds input options for the underlying data source. + """ + for k in options: + self._jreader = self._jreader.option(k, options[k]) + return self + @since(1.4) def load(self, path=None, format=None, schema=None, **options): """Loads data from a data source and returns it as a :class`DataFrame`. @@ -52,20 +85,15 @@ def load(self, path=None, format=None, schema=None, **options): :param schema: optional :class:`StructType` for the input schema. :param options: all other string options """ - jreader = self._jreader if format is not None: - jreader = jreader.format(format) + self.format(format) if schema is not None: - if not isinstance(schema, StructType): - raise TypeError("schema should be StructType") - jschema = self._sqlContext._ssql_ctx.parseDataType(schema.json()) - jreader = jreader.schema(jschema) - for k in options: - jreader = jreader.option(k, options[k]) + self.schema(schema) + self.options(**options) if path is not None: - return self._df(jreader.load(path)) + return self._df(self._jreader.load(path)) else: - return self._df(jreader.load()) + return self._df(self._jreader.load()) @since(1.4) def json(self, path, schema=None): @@ -105,12 +133,9 @@ def json(self, path, schema=None): | |-- field5: array (nullable = true) | | |-- element: integer (containsNull = true) """ - if schema is None: - jdf = self._jreader.json(path) - else: - jschema = self._sqlContext._ssql_ctx.parseDataType(schema.json()) - jdf = self._jreader.schema(jschema).json(path) - return self._df(jdf) + if schema is not None: + self.schema(schema) + return self._df(self._jreader.json(path)) @since(1.4) def table(self, tableName): @@ -194,6 +219,51 @@ def __init__(self, df): self._sqlContext = df.sql_ctx self._jwrite = df._jdf.write() + @since(1.4) + def mode(self, saveMode): + """ + Specifies the behavior when data or table already exists. Options include: + + * `append`: Append contents of this :class:`DataFrame` to existing data. + * `overwrite`: Overwrite existing data. + * `error`: Throw an exception if data already exists. + * `ignore`: Silently ignore this operation if data already exists. + """ + self._jwrite = self._jwrite.mode(saveMode) + return self + + @since(1.4) + def format(self, source): + """ + Specifies the underlying output data source. Built-in options include + "parquet", "json", etc. + """ + self._jwrite = self._jwrite.format(source) + return self + + @since(1.4) + def options(self, **options): + """ + Adds output options for the underlying data source. + """ + for k in options: + self._jwrite = self._jwrite.option(k, options[k]) + return self + + @since(1.4) + def partitionBy(self, *cols): + """ + Partitions the output by the given columns on the file system. + If specified, the output is laid out on the file system similar + to Hive's partitioning scheme. + + :param cols: name of columns + """ + if len(cols) == 1 and isinstance(cols[0], (list, tuple)): + cols = cols[0] + self._jwrite = self._jwrite.partitionBy(_to_seq(self._sqlContext._sc, cols)) + return self + @since(1.4) def save(self, path=None, format=None, mode="error", **options): """ @@ -216,16 +286,15 @@ def save(self, path=None, format=None, mode="error", **options): :param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error) :param options: all other string options """ - jwrite = self._jwrite.mode(mode) + self.mode(mode).options(**options) if format is not None: - jwrite = jwrite.format(format) - for k in options: - jwrite = jwrite.option(k, options[k]) + self.format(format) if path is None: - jwrite.save() + self._jwrite.save() else: - jwrite.save(path) + self._jwrite.save(path) + @since(1.4) def insertInto(self, tableName, overwrite=False): """ Inserts the content of the :class:`DataFrame` to the specified table. @@ -256,12 +325,10 @@ def saveAsTable(self, name, format=None, mode="error", **options): :param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error) :param options: all other string options """ - jwrite = self._jwrite.mode(mode) + self.mode(mode).options(**options) if format is not None: - jwrite = jwrite.format(format) - for k in options: - jwrite = jwrite.option(k, options[k]) - return jwrite.saveAsTable(name) + self.format(format) + return self._jwrite.saveAsTable(name) @since(1.4) def json(self, path, mode="error"): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 5c53c3a8ed4f1..76384d31f1bf4 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -100,6 +100,15 @@ def test_data_type_eq(self): lt2 = pickle.loads(pickle.dumps(LongType())) self.assertEquals(lt, lt2) + # regression test for SPARK-7978 + def test_decimal_type(self): + t1 = DecimalType() + t2 = DecimalType(10, 2) + self.assertTrue(t2 is not t1) + self.assertNotEqual(t1, t2) + t3 = DecimalType(8) + self.assertNotEqual(t2, t3) + class SQLTests(ReusedPySparkTestCase): diff --git a/python/pyspark/sql/_types.py b/python/pyspark/sql/types.py similarity index 99% rename from python/pyspark/sql/_types.py rename to python/pyspark/sql/types.py index 9e7e9f04bc35d..b6ec6137c9180 100644 --- a/python/pyspark/sql/_types.py +++ b/python/pyspark/sql/types.py @@ -97,8 +97,6 @@ class AtomicType(DataType): """An internal type used to represent everything that is not null, UDTs, arrays, structs, and maps.""" - __metaclass__ = DataTypeSingleton - class NumericType(AtomicType): """Numeric data types. @@ -109,6 +107,8 @@ class IntegralType(NumericType): """Integral data types. """ + __metaclass__ = DataTypeSingleton + class FractionalType(NumericType): """Fractional data types. @@ -119,26 +119,36 @@ class StringType(AtomicType): """String data type. """ + __metaclass__ = DataTypeSingleton + class BinaryType(AtomicType): """Binary (byte array) data type. """ + __metaclass__ = DataTypeSingleton + class BooleanType(AtomicType): """Boolean data type. """ + __metaclass__ = DataTypeSingleton + class DateType(AtomicType): """Date (datetime.date) data type. """ + __metaclass__ = DataTypeSingleton + class TimestampType(AtomicType): """Timestamp (datetime.datetime) data type. """ + __metaclass__ = DataTypeSingleton + class DecimalType(FractionalType): """Decimal (decimal.Decimal) data type. @@ -172,11 +182,15 @@ class DoubleType(FractionalType): """Double data type, representing double precision floats. """ + __metaclass__ = DataTypeSingleton + class FloatType(FractionalType): """Float data type, representing single precision floats. """ + __metaclass__ = DataTypeSingleton + class ByteType(IntegralType): """Byte data type, i.e. a signed integer in a single byte. diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 33ea8c9293d74..46cb18b2e8ef9 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -41,8 +41,8 @@ class PySparkStreamingTestCase(unittest.TestCase): - timeout = 4 # seconds - duration = .2 + timeout = 10 # seconds + duration = .5 @classmethod def setUpClass(cls): @@ -379,13 +379,13 @@ def func(dstream): class WindowFunctionTests(PySparkStreamingTestCase): - timeout = 5 + timeout = 15 def test_window(self): input = [range(1), range(2), range(3), range(4), range(5)] def func(dstream): - return dstream.window(.6, .2).count() + return dstream.window(1.5, .5).count() expected = [[1], [3], [6], [9], [12], [9], [5]] self._test_func(input, func, expected) @@ -394,7 +394,7 @@ def test_count_by_window(self): input = [range(1), range(2), range(3), range(4), range(5)] def func(dstream): - return dstream.countByWindow(.6, .2) + return dstream.countByWindow(1.5, .5) expected = [[1], [3], [6], [9], [12], [9], [5]] self._test_func(input, func, expected) @@ -403,7 +403,7 @@ def test_count_by_window_large(self): input = [range(1), range(2), range(3), range(4), range(5), range(6)] def func(dstream): - return dstream.countByWindow(1, .2) + return dstream.countByWindow(2.5, .5) expected = [[1], [3], [6], [10], [15], [20], [18], [15], [11], [6]] self._test_func(input, func, expected) @@ -412,7 +412,7 @@ def test_count_by_value_and_window(self): input = [range(1), range(2), range(3), range(4), range(5), range(6)] def func(dstream): - return dstream.countByValueAndWindow(1, .2) + return dstream.countByValueAndWindow(2.5, .5) expected = [[1], [2], [3], [4], [5], [6], [6], [6], [6], [6]] self._test_func(input, func, expected) @@ -421,7 +421,7 @@ def test_group_by_key_and_window(self): input = [[('a', i)] for i in range(5)] def func(dstream): - return dstream.groupByKeyAndWindow(.6, .2).mapValues(list) + return dstream.groupByKeyAndWindow(1.5, .5).mapValues(list) expected = [[('a', [0])], [('a', [0, 1])], [('a', [0, 1, 2])], [('a', [1, 2, 3])], [('a', [2, 3, 4])], [('a', [3, 4])], [('a', [4])]] diff --git a/python/run-tests b/python/run-tests index ffde2fb24b369..17dda3eadac0c 100755 --- a/python/run-tests +++ b/python/run-tests @@ -57,54 +57,56 @@ function run_test() { function run_core_tests() { echo "Run core tests ..." - run_test "pyspark/rdd.py" - run_test "pyspark/context.py" - run_test "pyspark/conf.py" - PYSPARK_DOC_TEST=1 run_test "pyspark/broadcast.py" - PYSPARK_DOC_TEST=1 run_test "pyspark/accumulators.py" - run_test "pyspark/serializers.py" - run_test "pyspark/profiler.py" - run_test "pyspark/shuffle.py" - run_test "pyspark/tests.py" + run_test "pyspark.rdd" + run_test "pyspark.context" + run_test "pyspark.conf" + run_test "pyspark.broadcast" + run_test "pyspark.accumulators" + run_test "pyspark.serializers" + run_test "pyspark.profiler" + run_test "pyspark.shuffle" + run_test "pyspark.tests" } function run_sql_tests() { echo "Run sql tests ..." - run_test "pyspark/sql/_types.py" - run_test "pyspark/sql/context.py" - run_test "pyspark/sql/column.py" - run_test "pyspark/sql/dataframe.py" - run_test "pyspark/sql/group.py" - run_test "pyspark/sql/functions.py" - run_test "pyspark/sql/tests.py" + run_test "pyspark.sql.types" + run_test "pyspark.sql.context" + run_test "pyspark.sql.column" + run_test "pyspark.sql.dataframe" + run_test "pyspark.sql.group" + run_test "pyspark.sql.functions" + run_test "pyspark.sql.readwriter" + run_test "pyspark.sql.window" + run_test "pyspark.sql.tests" } function run_mllib_tests() { echo "Run mllib tests ..." - run_test "pyspark/mllib/classification.py" - run_test "pyspark/mllib/clustering.py" - run_test "pyspark/mllib/evaluation.py" - run_test "pyspark/mllib/feature.py" - run_test "pyspark/mllib/fpm.py" - run_test "pyspark/mllib/linalg.py" - run_test "pyspark/mllib/rand.py" - run_test "pyspark/mllib/recommendation.py" - run_test "pyspark/mllib/regression.py" - run_test "pyspark/mllib/stat/_statistics.py" - run_test "pyspark/mllib/tree.py" - run_test "pyspark/mllib/util.py" - run_test "pyspark/mllib/tests.py" + run_test "pyspark.mllib.classification" + run_test "pyspark.mllib.clustering" + run_test "pyspark.mllib.evaluation" + run_test "pyspark.mllib.feature" + run_test "pyspark.mllib.fpm" + run_test "pyspark.mllib.linalg" + run_test "pyspark.mllib.random" + run_test "pyspark.mllib.recommendation" + run_test "pyspark.mllib.regression" + run_test "pyspark.mllib.stat._statistics" + run_test "pyspark.mllib.tree" + run_test "pyspark.mllib.util" + run_test "pyspark.mllib.tests" } function run_ml_tests() { echo "Run ml tests ..." - run_test "pyspark/ml/feature.py" - run_test "pyspark/ml/classification.py" - run_test "pyspark/ml/recommendation.py" - run_test "pyspark/ml/regression.py" - run_test "pyspark/ml/tuning.py" - run_test "pyspark/ml/tests.py" - run_test "pyspark/ml/evaluation.py" + run_test "pyspark.ml.feature" + run_test "pyspark.ml.classification" + run_test "pyspark.ml.recommendation" + run_test "pyspark.ml.regression" + run_test "pyspark.ml.tuning" + run_test "pyspark.ml.tests" + run_test "pyspark.ml.evaluation" } function run_streaming_tests() { @@ -124,8 +126,8 @@ function run_streaming_tests() { done export PYSPARK_SUBMIT_ARGS="--jars ${KAFKA_ASSEMBLY_JAR} pyspark-shell" - run_test "pyspark/streaming/util.py" - run_test "pyspark/streaming/tests.py" + run_test "pyspark.streaming.util" + run_test "pyspark.streaming.tests" } echo "Running PySpark tests. Output is in python/$LOG_FILE." diff --git a/repl/pom.xml b/repl/pom.xml index 03053b4c3b287..6e5cb7f77e1df 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -48,6 +48,13 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-bagel_${scala.binary.version} diff --git a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 934daaeaafca1..50fd43a418bca 100644 --- a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -22,13 +22,12 @@ import java.net.URLClassLoader import scala.collection.mutable.ArrayBuffer -import org.scalatest.FunSuite -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.util.Utils -class ReplSuite extends FunSuite { +class ReplSuite extends SparkFunSuite { def runInterpreter(master: String, input: String): String = { val CONF_EXECUTOR_CLASSPATH = "spark.executor.extraClassPath" diff --git a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 14f5e9ed4f25e..9ecc7c229e38a 100644 --- a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -24,14 +24,13 @@ import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ import scala.tools.nsc.interpreter.SparkILoop -import org.scalatest.FunSuite import org.apache.commons.lang3.StringEscapeUtils -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.util.Utils -class ReplSuite extends FunSuite { +class ReplSuite extends SparkFunSuite { def runInterpreter(master: String, input: String): String = { val CONF_EXECUTOR_CLASSPATH = "spark.executor.extraClassPath" diff --git a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala index c709cde740748..a58eda12b1120 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala @@ -25,7 +25,6 @@ import scala.language.implicitConversions import scala.language.postfixOps import org.scalatest.BeforeAndAfterAll -import org.scalatest.FunSuite import org.scalatest.concurrent.Interruptor import org.scalatest.concurrent.Timeouts._ import org.scalatest.mock.MockitoSugar @@ -35,7 +34,7 @@ import org.apache.spark._ import org.apache.spark.util.Utils class ExecutorClassLoaderSuite - extends FunSuite + extends SparkFunSuite with BeforeAndAfterAll with MockitoSugar with Logging { diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 68d980b610c00..d6f927b6fa803 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -14,25 +14,41 @@ ~ See the License for the specific language governing permissions and ~ limitations under the License. --> - - - - - - + - Scalastyle standard configuration - - - - - - - - - Scalastyle standard configuration + + + + + + + + + + - - - - - - - - - - true - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + true + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - ARROW, EQUALS + ARROW, EQUALS, ELSE, TRY, CATCH, FINALLY, LARROW, RARROW - + + - ARROW, EQUALS, COMMA, COLON, IF, WHILE, FOR + ARROW, EQUALS, COMMA, COLON, IF, ELSE, DO, WHILE, FOR, MATCH, TRY, CATCH, FINALLY, LARROW, RARROW + + + + + + ^FunSuite[A-Za-z]*$ + Tests must extend org.apache.spark.SparkFunSuite instead. + + + + + + + + + ^println$ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 800> + + + + + 30 + + + + + 10 + + + + + 50 + + + + + + + + + + + -1,0,1,2,3 + + diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 5c322d032d474..d9e1cdb84bb27 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -50,6 +50,13 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-unsafe_${scala.binary.version} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 75a493b248f6e..1c0ddb5093d17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -233,7 +233,7 @@ object CatalystTypeConverters { case other => other } - /** + /** * Converts Catalyst types used internally in rows to standard Scala types * This method is slow, and for batch conversion you should be using converter * produced by createToScalaConverter. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index df37889eedcf0..bc17169f35a46 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -633,10 +633,10 @@ class Analyzer( * it into the plan tree. */ object ExtractWindowExpressions extends Rule[LogicalPlan] { - def hasWindowFunction(projectList: Seq[NamedExpression]): Boolean = + private def hasWindowFunction(projectList: Seq[NamedExpression]): Boolean = projectList.exists(hasWindowFunction) - def hasWindowFunction(expr: NamedExpression): Boolean = { + private def hasWindowFunction(expr: NamedExpression): Boolean = { expr.find { case window: WindowExpression => true case _ => false @@ -644,14 +644,24 @@ class Analyzer( } /** - * From a Seq of [[NamedExpression]]s, extract window expressions and - * other regular expressions. + * From a Seq of [[NamedExpression]]s, extract expressions containing window expressions and + * other regular expressions that do not contain any window expression. For example, for + * `col1, Sum(col2 + col3) OVER (PARTITION BY col4 ORDER BY col5)`, we will extract + * `col1`, `col2 + col3`, `col4`, and `col5` out and replace their appearances in + * the window expression as attribute references. So, the first returned value will be + * `[Sum(_w0) OVER (PARTITION BY _w1 ORDER BY _w2)]` and the second returned value will be + * [col1, col2 + col3 as _w0, col4 as _w1, col5 as _w2]. + * + * @return (seq of expressions containing at lease one window expressions, + * seq of non-window expressions) */ - def extract( + private def extract( expressions: Seq[NamedExpression]): (Seq[NamedExpression], Seq[NamedExpression]) = { - // First, we simple partition the input expressions to two part, one having - // WindowExpressions and another one without WindowExpressions. - val (windowExpressions, regularExpressions) = expressions.partition(hasWindowFunction) + // First, we partition the input expressions to two part. For the first part, + // every expression in it contain at least one WindowExpression. + // Expressions in the second part do not have any WindowExpression. + val (expressionsWithWindowFunctions, regularExpressions) = + expressions.partition(hasWindowFunction) // Then, we need to extract those regular expressions used in the WindowExpression. // For example, when we have col1 - Sum(col2 + col3) OVER (PARTITION BY col4 ORDER BY col5), @@ -660,8 +670,8 @@ class Analyzer( val extractedExprBuffer = new ArrayBuffer[NamedExpression]() def extractExpr(expr: Expression): Expression = expr match { case ne: NamedExpression => - // If a named expression is not in regularExpressions, add extract it and replace it - // with an AttributeReference. + // If a named expression is not in regularExpressions, add it to + // extractedExprBuffer and replace it with an AttributeReference. val missingExpr = AttributeSet(Seq(expr)) -- (regularExpressions ++ extractedExprBuffer) if (missingExpr.nonEmpty) { @@ -678,8 +688,9 @@ class Analyzer( withName.toAttribute } - // Now, we extract expressions from windowExpressions by using extractExpr. - val newWindowExpressions = windowExpressions.map { + // Now, we extract regular expressions from expressionsWithWindowFunctions + // by using extractExpr. + val newExpressionsWithWindowFunctions = expressionsWithWindowFunctions.map { _.transform { // Extracts children expressions of a WindowFunction (input parameters of // a WindowFunction). @@ -705,37 +716,80 @@ class Analyzer( }.asInstanceOf[NamedExpression] } - (newWindowExpressions, regularExpressions ++ extractedExprBuffer) - } + (newExpressionsWithWindowFunctions, regularExpressions ++ extractedExprBuffer) + } // end of extract /** * Adds operators for Window Expressions. Every Window operator handles a single Window Spec. */ - def addWindow(windowExpressions: Seq[NamedExpression], child: LogicalPlan): LogicalPlan = { - // First, we group window expressions based on their Window Spec. - val groupedWindowExpression = windowExpressions.groupBy { expr => - val windowSpec = expr.collectFirst { + private def addWindow( + expressionsWithWindowFunctions: Seq[NamedExpression], + child: LogicalPlan): LogicalPlan = { + // First, we need to extract all WindowExpressions from expressionsWithWindowFunctions + // and put those extracted WindowExpressions to extractedWindowExprBuffer. + // This step is needed because it is possible that an expression contains multiple + // WindowExpressions with different Window Specs. + // After extracting WindowExpressions, we need to construct a project list to generate + // expressionsWithWindowFunctions based on extractedWindowExprBuffer. + // For example, for "sum(a) over (...) / sum(b) over (...)", we will first extract + // "sum(a) over (...)" and "sum(b) over (...)" out, and assign "_we0" as the alias to + // "sum(a) over (...)" and "_we1" as the alias to "sum(b) over (...)". + // Then, the projectList will be [_we0/_we1]. + val extractedWindowExprBuffer = new ArrayBuffer[NamedExpression]() + val newExpressionsWithWindowFunctions = expressionsWithWindowFunctions.map { + // We need to use transformDown because we want to trigger + // "case alias @ Alias(window: WindowExpression, _)" first. + _.transformDown { + case alias @ Alias(window: WindowExpression, _) => + // If a WindowExpression has an assigned alias, just use it. + extractedWindowExprBuffer += alias + alias.toAttribute + case window: WindowExpression => + // If there is no alias assigned to the WindowExpressions. We create an + // internal column. + val withName = Alias(window, s"_we${extractedWindowExprBuffer.length}")() + extractedWindowExprBuffer += withName + withName.toAttribute + }.asInstanceOf[NamedExpression] + } + + // Second, we group extractedWindowExprBuffer based on their Window Spec. + val groupedWindowExpressions = extractedWindowExprBuffer.groupBy { expr => + val distinctWindowSpec = expr.collect { case window: WindowExpression => window.windowSpec + }.distinct + + // We do a final check and see if we only have a single Window Spec defined in an + // expressions. + if (distinctWindowSpec.length == 0 ) { + failAnalysis(s"$expr does not have any WindowExpression.") + } else if (distinctWindowSpec.length > 1) { + // newExpressionsWithWindowFunctions only have expressions with a single + // WindowExpression. If we reach here, we have a bug. + failAnalysis(s"$expr has multiple Window Specifications ($distinctWindowSpec)." + + s"Please file a bug report with this error message, stack trace, and the query.") + } else { + distinctWindowSpec.head } - windowSpec.getOrElse( - failAnalysis(s"$windowExpressions does not have any WindowExpression.")) }.toSeq - // For every Window Spec, we add a Window operator and set currentChild as the child of it. + // Third, for every Window Spec, we add a Window operator and set currentChild as the + // child of it. var currentChild = child var i = 0 - while (i < groupedWindowExpression.size) { - val (windowSpec, windowExpressions) = groupedWindowExpression(i) + while (i < groupedWindowExpressions.size) { + val (windowSpec, windowExpressions) = groupedWindowExpressions(i) // Set currentChild to the newly created Window operator. currentChild = Window(currentChild.output, windowExpressions, windowSpec, currentChild) - // Move to next WindowExpression. + // Move to next Window Spec. i += 1 } - // We return the top operator. - currentChild - } + // Finally, we create a Project to output currentChild's output + // newExpressionsWithWindowFunctions. + Project(currentChild.output ++ newExpressionsWithWindowFunctions, currentChild) + } // end of addWindow // We have to use transformDown at here to make sure the rule of // "Aggregate with Having clause" will be triggered. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 44664f898f762..edcc918bfe921 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -76,8 +76,7 @@ trait HiveTypeCoercion { WidenTypes :: PromoteStrings :: DecimalPrecision :: - BooleanComparisons :: - BooleanCasts :: + BooleanEqualization :: StringToIntegralCasts :: FunctionArgumentConversion :: CaseWhenCoercion :: @@ -120,7 +119,7 @@ trait HiveTypeCoercion { * the appropriate numeric equivalent. */ object ConvertNaNs extends Rule[LogicalPlan] { - val stringNaN = Literal("NaN") + private val stringNaN = Literal("NaN") def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressions { @@ -297,8 +296,8 @@ trait HiveTypeCoercion { object InConversion extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. - case e if !e.childrenResolved => e - + case e if !e.childrenResolved => e + case i @ In(a, b) if b.exists(_.dataType != a.dataType) => i.makeCopy(Array(a, b.map(Cast(_, a.dataType)))) } @@ -350,17 +349,17 @@ trait HiveTypeCoercion { import scala.math.{max, min} // Conversion rules for integer types into fixed-precision decimals - val intTypeToFixed: Map[DataType, DecimalType] = Map( + private val intTypeToFixed: Map[DataType, DecimalType] = Map( ByteType -> DecimalType(3, 0), ShortType -> DecimalType(5, 0), IntegerType -> DecimalType(10, 0), LongType -> DecimalType(20, 0) ) - def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType + private def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType // Conversion rules for float and double into fixed-precision decimals - val floatTypeToFixed: Map[DataType, DecimalType] = Map( + private val floatTypeToFixed: Map[DataType, DecimalType] = Map( FloatType -> DecimalType(7, 7), DoubleType -> DecimalType(15, 15) ) @@ -483,56 +482,66 @@ trait HiveTypeCoercion { } /** - * Changes Boolean values to Bytes so that expressions like true < false can be Evaluated. + * Changes numeric values to booleans so that expressions like true = 1 can be evaluated. */ - object BooleanComparisons extends Rule[LogicalPlan] { - val trueValues = Seq(1, 1L, 1.toByte, 1.toShort, new java.math.BigDecimal(1)).map(Literal(_)) - val falseValues = Seq(0, 0L, 0.toByte, 0.toShort, new java.math.BigDecimal(0)).map(Literal(_)) + object BooleanEqualization extends Rule[LogicalPlan] { + private val trueValues = Seq(1.toByte, 1.toShort, 1, 1L, new java.math.BigDecimal(1)) + private val falseValues = Seq(0.toByte, 0.toShort, 0, 0L, new java.math.BigDecimal(0)) + + private def buildCaseKeyWhen(booleanExpr: Expression, numericExpr: Expression) = { + CaseKeyWhen(numericExpr, Seq( + Literal(trueValues.head), booleanExpr, + Literal(falseValues.head), Not(booleanExpr), + Literal(false))) + } - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - // Skip nodes who's children have not been resolved yet. - case e if !e.childrenResolved => e + private def transform(booleanExpr: Expression, numericExpr: Expression) = { + If(Or(IsNull(booleanExpr), IsNull(numericExpr)), + Literal.create(null, BooleanType), + buildCaseKeyWhen(booleanExpr, numericExpr)) + } - // Hive treats (true = 1) as true and (false = 0) as true. - case EqualTo(l @ BooleanType(), r) if trueValues.contains(r) => l - case EqualTo(l, r @ BooleanType()) if trueValues.contains(l) => r - case EqualTo(l @ BooleanType(), r) if falseValues.contains(r) => Not(l) - case EqualTo(l, r @ BooleanType()) if falseValues.contains(l) => Not(r) - - // No need to change other EqualTo operators as that actually makes sense for boolean types. - case e: EqualTo => e - // No need to change the EqualNullSafe operators, too - case e: EqualNullSafe => e - // Otherwise turn them to Byte types so that there exists and ordering. - case p: BinaryComparison if p.left.dataType == BooleanType && - p.right.dataType == BooleanType => - p.makeCopy(Array(Cast(p.left, ByteType), Cast(p.right, ByteType))) + private def transformNullSafe(booleanExpr: Expression, numericExpr: Expression) = { + CaseWhen(Seq( + And(IsNull(booleanExpr), IsNull(numericExpr)), Literal(true), + Or(IsNull(booleanExpr), IsNull(numericExpr)), Literal(false), + buildCaseKeyWhen(booleanExpr, numericExpr) + )) } - } - /** - * Casts to/from [[BooleanType]] are transformed into comparisons since - * the JVM does not consider Booleans to be numeric types. - */ - object BooleanCasts extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - // Skip if the type is boolean type already. Note that this extra cast should be removed - // by optimizer.SimplifyCasts. - case Cast(e, BooleanType) if e.dataType == BooleanType => e - // DateType should be null if be cast to boolean. - case Cast(e, BooleanType) if e.dataType == DateType => Cast(e, BooleanType) - // If the data type is not boolean and is being cast boolean, turn it into a comparison - // with the numeric value, i.e. x != 0. This will coerce the type into numeric type. - case Cast(e, BooleanType) if e.dataType != BooleanType => Not(EqualTo(e, Literal(0))) - // Stringify boolean if casting to StringType. - // TODO Ensure true/false string letter casing is consistent with Hive in all cases. - case Cast(e, StringType) if e.dataType == BooleanType => - If(e, Literal("true"), Literal("false")) - // Turn true into 1, and false into 0 if casting boolean into other types. - case Cast(e, dataType) if e.dataType == BooleanType => - Cast(If(e, Literal(1), Literal(0)), dataType) + + // Hive treats (true = 1) as true and (false = 0) as true, + // all other cases are considered as false. + + // We may simplify the expression if one side is literal numeric values + case EqualTo(l @ BooleanType(), Literal(value, _: NumericType)) + if trueValues.contains(value) => l + case EqualTo(l @ BooleanType(), Literal(value, _: NumericType)) + if falseValues.contains(value) => Not(l) + case EqualTo(Literal(value, _: NumericType), r @ BooleanType()) + if trueValues.contains(value) => r + case EqualTo(Literal(value, _: NumericType), r @ BooleanType()) + if falseValues.contains(value) => Not(r) + case EqualNullSafe(l @ BooleanType(), Literal(value, _: NumericType)) + if trueValues.contains(value) => And(IsNotNull(l), l) + case EqualNullSafe(l @ BooleanType(), Literal(value, _: NumericType)) + if falseValues.contains(value) => And(IsNotNull(l), Not(l)) + case EqualNullSafe(Literal(value, _: NumericType), r @ BooleanType()) + if trueValues.contains(value) => And(IsNotNull(r), r) + case EqualNullSafe(Literal(value, _: NumericType), r @ BooleanType()) + if falseValues.contains(value) => And(IsNotNull(r), Not(r)) + + case EqualTo(l @ BooleanType(), r @ NumericType()) => + transform(l , r) + case EqualTo(l @ NumericType(), r @ BooleanType()) => + transform(r, l) + case EqualNullSafe(l @ BooleanType(), r @ NumericType()) => + transformNullSafe(l, r) + case EqualNullSafe(l @ NumericType(), r @ BooleanType()) => + transformNullSafe(r, l) } } @@ -633,7 +642,7 @@ trait HiveTypeCoercion { import HiveTypeCoercion._ def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case cw: CaseWhenLike if !cw.resolved && cw.childrenResolved && !cw.valueTypesEqual => + case cw: CaseWhenLike if cw.childrenResolved && !cw.valueTypesEqual => logDebug(s"Input values for null casting ${cw.valueTypes.mkString(",")}") val commonType = cw.valueTypes.reduce { (v1, v2) => findTightestCommonType(v1, v2).getOrElse(sys.error( @@ -652,6 +661,23 @@ trait HiveTypeCoercion { case CaseKeyWhen(key, _) => CaseKeyWhen(key, transformedBranches) } + + case ckw: CaseKeyWhen if ckw.childrenResolved && !ckw.resolved => + val commonType = (ckw.key +: ckw.whenList).map(_.dataType).reduce { (v1, v2) => + findTightestCommonType(v1, v2).getOrElse(sys.error( + s"Types in CASE WHEN must be the same or coercible to a common type: $v1 != $v2")) + } + val transformedBranches = ckw.branches.sliding(2, 2).map { + case Seq(when, then) if when.dataType != commonType => + Seq(Cast(when, commonType), then) + case s => s + }.reduce(_ ++ _) + val transformedKey = if (ckw.key.dataType != commonType) { + Cast(ckw.key, commonType) + } else { + ckw.key + } + CaseKeyWhen(transformedKey, transformedBranches) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index d19928784442e..adc6505d69cdf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -37,7 +37,15 @@ abstract class Expression extends TreeNode[Expression] { * - A [[Cast]] or [[UnaryMinus]] is foldable if its child is foldable */ def foldable: Boolean = false + + /** + * Returns true when the current expression always return the same result for fixed input values. + */ + // TODO: Need to define explicit input values vs implicit input values. + def deterministic: Boolean = true + def nullable: Boolean + def references: AttributeSet = AttributeSet(children.flatMap(_.references.iterator)) /** Returns the result of evaluating this expression on a given input Row */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index 195eec8e5cdc4..99340a14c9ecc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -29,7 +29,7 @@ case object Descending extends SortDirection * An expression that can be used to sort a tuple. This class extends expression primarily so that * transformations over expression will descend into its child. */ -case class SortOrder(child: Expression, direction: SortDirection) extends Expression +case class SortOrder(child: Expression, direction: SortDirection) extends Expression with trees.UnaryNode[Expression] { override def dataType: DataType = child.dataType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 6c380d3084652..0266084a6d174 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -394,13 +394,13 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[ * Combining PartitionLevel InputData * <-- null * Zero <-- Zero <-- null - * + * * <-- null <-- no data - * null <-- null <-- no data + * null <-- null <-- no data */ case class CombineSum(child: Expression) extends AggregateExpression { def this() = this(null) - + override def children: Seq[Expression] = child :: Nil override def nullable: Boolean = true override def dataType: DataType = child.dataType @@ -616,7 +616,7 @@ case class SumFunction(expr: Expression, base: AggregateExpression) extends Aggr private val sum = MutableLiteral(null, calcType) - private val addFunction = + private val addFunction = Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum, zero)) override def update(input: Row): Unit = { @@ -634,7 +634,7 @@ case class SumFunction(expr: Expression, base: AggregateExpression) extends Aggr case class CombineSumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { - + def this() = this(null, null) // Required for serialization. private val calcType = @@ -649,12 +649,12 @@ case class CombineSumFunction(expr: Expression, base: AggregateExpression) private val sum = MutableLiteral(null, calcType) - private val addFunction = + private val addFunction = Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum, zero)) - + override def update(input: Row): Unit = { val result = expr.eval(input) - // partial sum result can be null only when no input rows present + // partial sum result can be null only when no input rows present if(result != null) { sum.update(addFunction, input) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 34c833b260dc0..f2299d5db6e9f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -180,7 +180,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic case it: IntegralType => it.integral.asInstanceOf[Integral[Any]].quot case other => sys.error(s"Type $other does not support numeric operations") } - + override def eval(input: Row): Any = { val evalE2 = right.eval(input) if (evalE2 == null || evalE2 == 0) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala index e7cd7131a9e56..6398b8f9e4ed7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.types._ case class CreateArray(children: Seq[Expression]) extends Expression { override def foldable: Boolean = children.forall(_.foldable) - + lazy val childTypes = children.map(_.dataType).distinct override lazy val resolved = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala index 890efc9f52ca3..01f62ba0442e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.types._ * @param f The math function. * @param name The short name of the function */ -abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) +abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) extends BinaryExpression with Serializable with ExpectsInputTypes { self: Product => override def symbol: String = null diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index e2d1c8115e051..4f422d69c4382 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -366,7 +366,7 @@ trait CaseWhenLike extends Expression { // both then and else val should be considered. def valueTypes: Seq[DataType] = (thenList ++ elseValue).map(_.dataType) - def valueTypesEqual: Boolean = valueTypes.distinct.size <= 1 + def valueTypesEqual: Boolean = valueTypes.distinct.size == 1 override def dataType: DataType = { if (!resolved) { @@ -442,7 +442,8 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW override def children: Seq[Expression] = key +: branches override lazy val resolved: Boolean = - childrenResolved && valueTypesEqual + childrenResolved && valueTypesEqual && + (key +: whenList).map(_.dataType).distinct.size == 1 /** Written in imperative fashion for performance considerations. */ override def eval(input: Row): Any = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala index de82c15680607..b2647124c4e49 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala @@ -24,7 +24,7 @@ import org.apache.spark.util.random.XORShiftRandom /** * A Random distribution generating expression. - * TODO: This can be made generic to generate any type of random distribution, or any type of + * TODO: This can be made generic to generate any type of random distribution, or any type of * StructType. * * Since this expression is stateful, it cannot be a case object. @@ -38,6 +38,8 @@ abstract class RDG(seed: Long) extends LeafExpression with Serializable { */ @transient protected lazy val rng = new XORShiftRandom(seed + TaskContext.get().partitionId()) + override def deterministic: Boolean = false + override def nullable: Boolean = false override def dataType: DataType = DoubleType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 83a44a12f0682..c4ef9c30907f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -133,7 +133,7 @@ trait CaseConversionExpression extends ExpectsInputTypes { * A function that converts the characters of a string to uppercase. */ case class Upper(child: Expression) extends UnaryExpression with CaseConversionExpression { - + override def convert(v: UTF8String): UTF8String = v.toUpperCase() override def toString: String = s"Upper($child)" @@ -143,7 +143,7 @@ case class Upper(child: Expression) extends UnaryExpression with CaseConversionE * A function that converts the characters of a string to lowercase. */ case class Lower(child: Expression) extends UnaryExpression with CaseConversionExpression { - + override def convert(v: UTF8String): UTF8String = v.toLowerCase() override def toString: String = s"Lower($child)" @@ -223,7 +223,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) @inline def slicePos(startPos: Int, sliceLen: Int, length: () => Int): (Int, Int) = { // Hive and SQL use one-based indexing for SUBSTR arguments but also accept zero and - // negative indices for start positions. If a start index i is greater than 0, it + // negative indices for start positions. If a start index i is greater than 0, it // refers to element i-1 in the sequence. If a start index i is less than 0, it refers // to the -ith element before the end of the sequence. If a start index i is 0, it // refers to the first element. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index c2818d957cc79..b25fb48f55e2b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -179,8 +179,17 @@ object ColumnPruning extends Rule[LogicalPlan] { * expressions into one single expression. */ object ProjectCollapsing extends Rule[LogicalPlan] { + + /** Returns true if any expression in projectList is non-deterministic. */ + private def hasNondeterministic(projectList: Seq[NamedExpression]): Boolean = { + projectList.exists(expr => expr.find(!_.deterministic).isDefined) + } + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case Project(projectList1, Project(projectList2, child)) => + // We only collapse these two Projects if the child Project's expressions are all + // deterministic. + case Project(projectList1, Project(projectList2, child)) + if !hasNondeterministic(projectList2) => // Create a map of Aliases to their values from the child projection. // e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)). val aliasMap = AttributeMap(projectList2.collect { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 54604808e133e..1ba3a2686639f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -165,6 +165,9 @@ object DataType { def fromJson(json: String): DataType = parseDataType(parse(json)) + /** + * @deprecated As of 1.2.0, replaced by `DataType.fromJson()` + */ @deprecated("Use DataType.fromJson instead", "1.2.0") def fromCaseClassString(string: String): DataType = CaseClassStringParser(string) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 994c5202c15dc..eb3c58c37f308 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -313,7 +313,7 @@ object Decimal { // See scala.math's Numeric.scala for examples for Scala's built-in types. /** Common methods for Decimal evidence parameters */ - trait DecimalIsConflicted extends Numeric[Decimal] { + private[sql] trait DecimalIsConflicted extends Numeric[Decimal] { override def plus(x: Decimal, y: Decimal): Decimal = x + y override def times(x: Decimal, y: Decimal): Decimal = x * y override def minus(x: Decimal, y: Decimal): Decimal = x - y @@ -327,12 +327,12 @@ object Decimal { } /** A [[scala.math.Fractional]] evidence parameter for Decimals. */ - object DecimalIsFractional extends DecimalIsConflicted with Fractional[Decimal] { + private[sql] object DecimalIsFractional extends DecimalIsConflicted with Fractional[Decimal] { override def div(x: Decimal, y: Decimal): Decimal = x / y } /** A [[scala.math.Integral]] evidence parameter for Decimals. */ - object DecimalAsIfIntegral extends DecimalIsConflicted with Integral[Decimal] { + private[sql] object DecimalAsIfIntegral extends DecimalIsConflicted with Integral[Decimal] { override def quot(x: Decimal, y: Decimal): Decimal = x / y override def rem(x: Decimal, y: Decimal): Decimal = x % y } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index 0f8cecd28f7df..407dc27326c2e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -82,12 +82,12 @@ case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalT object DecimalType { val Unlimited: DecimalType = DecimalType(None) - object Fixed { + private[sql] object Fixed { def unapply(t: DecimalType): Option[(Int, Int)] = t.precisionInfo.map(p => (p.precision, p.scale)) } - object Expression { + private[sql] object Expression { def unapply(e: Expression): Option[(Int, Int)] = e.dataType match { case t: DecimalType => t.precisionInfo.map(p => (p.precision, p.scale)) case _ => None diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/SQLUserDefinedType.java b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/SQLUserDefinedType.java index a64d2bb7cde37..df64a878b6b36 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/SQLUserDefinedType.java +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/SQLUserDefinedType.java @@ -24,11 +24,11 @@ /** * ::DeveloperApi:: * A user-defined type which can be automatically recognized by a SQLContext and registered. - * + *

    * WARNING: This annotation will only work if both Java and Scala reflection return the same class * names (after erasure) for the UDT. This will NOT be the case when, e.g., the UDT class * is enclosed in an object (a singleton). - * + *

    * WARNING: UDTs are currently only supported from Scala. */ // TODO: Should I used @Documented ? diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index a4f30c825befb..193c08a4d0df7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -265,7 +265,7 @@ object StructType { case _ => throw new SparkException(s"Failed to merge incompatible data types $left and $right") } - + private[sql] def fieldsMap(fields: Array[StructField]): Map[String, StructField] = { import scala.collection.breakOut fields.map(s => (s.name, s))(breakOut) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala index bc9c37bf2d5d2..f5d8fcced362b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala @@ -203,7 +203,7 @@ object UTF8String { def apply(s: String): UTF8String = { if (s != null) { new UTF8String().set(s) - } else{ + } else { null } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala index ea82cd2622de9..c046dbf4dc2c9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.sql.catalyst -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.plans.physical._ /* Implicit conversions */ import org.apache.spark.sql.catalyst.dsl.expressions._ -class DistributionSuite extends FunSuite { +class DistributionSuite extends SparkFunSuite { protected def checkSatisfied( inputPartitioning: Partitioning, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 7ff51db76b6bb..9a24b23024e18 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -20,8 +20,7 @@ package org.apache.spark.sql.catalyst import java.math.BigInteger import java.sql.{Date, Timestamp} -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.types._ @@ -75,7 +74,7 @@ case class MultipleConstructorsData(a: Int, b: String, c: Double) { def this(b: String, a: Int) = this(a, b, c = 1.0) } -class ScalaReflectionSuite extends FunSuite { +class ScalaReflectionSuite extends SparkFunSuite { import ScalaReflection._ test("primitive data") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala index 9eed15952d82b..b93a3abc6ebd2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.catalyst +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.logical.Command -import org.scalatest.FunSuite private[sql] case class TestCommand(cmd: String) extends LogicalPlan with Command { override def output: Seq[Attribute] = Seq.empty @@ -49,7 +49,7 @@ private[sql] class CaseInsensitiveTestParser extends AbstractSparkSQLParser { } } -class SqlParserSuite extends FunSuite { +class SqlParserSuite extends SparkFunSuite { test("test long keyword") { val parser = new SuperLongKeywordTestParser diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index fcff24ca31486..e09cd790a7187 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -17,8 +17,9 @@ package org.apache.spark.sql.catalyst.analysis -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -27,7 +28,7 @@ import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -class AnalysisSuite extends FunSuite with BeforeAndAfter { +class AnalysisSuite extends SparkFunSuite with BeforeAndAfter { val caseSensitiveConf = new SimpleCatalystConf(true) val caseInsensitiveConf = new SimpleCatalystConf(false) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index 565b1cfe019c7..1b8d18ded2257 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -17,14 +17,15 @@ package org.apache.spark.sql.catalyst.analysis -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{Union, Project, LocalRelation} import org.apache.spark.sql.types._ import org.apache.spark.sql.catalyst.SimpleCatalystConf -class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter { +class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter { val conf = new SimpleCatalystConf(true) val catalog = new SimpleCatalog(conf) val analyzer = new Analyzer(catalog, EmptyFunctionRegistry, conf) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index fcd745f43cfbf..a0798428db094 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -20,7 +20,8 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LocalRelation, Project} +import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types._ class HiveTypeCoercionSuite extends PlanTest { @@ -104,31 +105,16 @@ class HiveTypeCoercionSuite extends PlanTest { widenTest(ArrayType(IntegerType), StructType(Seq()), None) } - test("boolean casts") { - val booleanCasts = new HiveTypeCoercion { }.BooleanCasts - def ruleTest(initial: Expression, transformed: Expression) { - val testRelation = LocalRelation(AttributeReference("a", IntegerType)()) - comparePlans( - booleanCasts(Project(Seq(Alias(initial, "a")()), testRelation)), - Project(Seq(Alias(transformed, "a")()), testRelation)) - } - // Remove superflous boolean -> boolean casts. - ruleTest(Cast(Literal(true), BooleanType), Literal(true)) - // Stringify boolean when casting to string. - ruleTest( - Cast(Literal(false), StringType), - If(Literal(false), Literal("true"), Literal("false"))) + private def ruleTest(rule: Rule[LogicalPlan], initial: Expression, transformed: Expression) { + val testRelation = LocalRelation(AttributeReference("a", IntegerType)()) + comparePlans( + rule(Project(Seq(Alias(initial, "a")()), testRelation)), + Project(Seq(Alias(transformed, "a")()), testRelation)) } test("coalesce casts") { val fac = new HiveTypeCoercion { }.FunctionArgumentConversion - def ruleTest(initial: Expression, transformed: Expression) { - val testRelation = LocalRelation(AttributeReference("a", IntegerType)()) - comparePlans( - fac(Project(Seq(Alias(initial, "a")()), testRelation)), - Project(Seq(Alias(transformed, "a")()), testRelation)) - } - ruleTest( + ruleTest(fac, Coalesce(Literal(1.0) :: Literal(1) :: Literal.create(1.0, FloatType) @@ -137,7 +123,7 @@ class HiveTypeCoercionSuite extends PlanTest { :: Cast(Literal(1), DoubleType) :: Cast(Literal.create(1.0, FloatType), DoubleType) :: Nil)) - ruleTest( + ruleTest(fac, Coalesce(Literal(1L) :: Literal(1) :: Literal(new java.math.BigDecimal("1000000000000000000000")) @@ -147,4 +133,39 @@ class HiveTypeCoercionSuite extends PlanTest { :: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType()) :: Nil)) } + + test("type coercion for CaseKeyWhen") { + val cwc = new HiveTypeCoercion {}.CaseWhenCoercion + ruleTest(cwc, + CaseKeyWhen(Literal(1.toShort), Seq(Literal(1), Literal("a"))), + CaseKeyWhen(Cast(Literal(1.toShort), IntegerType), Seq(Literal(1), Literal("a"))) + ) + // Will remove exception expectation in PR#6405 + intercept[RuntimeException] { + ruleTest(cwc, + CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))), + CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))) + ) + } + } + + test("type coercion simplification for equal to") { + val be = new HiveTypeCoercion {}.BooleanEqualization + ruleTest(be, + EqualTo(Literal(true), Literal(1)), + Literal(true) + ) + ruleTest(be, + EqualTo(Literal(true), Literal(0)), + Not(Literal(true)) + ) + ruleTest(be, + EqualNullSafe(Literal(true), Literal(1)), + And(IsNotNull(Literal(true)), Literal(true)) + ) + ruleTest(be, + EqualNullSafe(Literal(true), Literal(0)), + And(IsNotNull(Literal(true)), Not(Literal(true))) + ) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala index f2f3a84d19380..97cfb5f06dd73 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala @@ -17,11 +17,10 @@ package org.apache.spark.sql.catalyst.expressions -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types.IntegerType -class AttributeSetSuite extends FunSuite { +class AttributeSetSuite extends SparkFunSuite { val aUpper = AttributeReference("A", IntegerType)(exprId = ExprId(1)) val aLower = AttributeReference("a", IntegerType)(exprId = ExprId(1)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index a14f776b1eaee..b6927485f42bf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -22,9 +22,9 @@ import java.sql.{Date, Timestamp} import scala.collection.immutable.HashSet import org.scalactic.TripleEqualsSupport.Spread -import org.scalatest.FunSuite import org.scalatest.Matchers._ +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue import org.apache.spark.sql.catalyst.dsl.expressions._ @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.types._ -class ExpressionEvaluationBaseSuite extends FunSuite { +class ExpressionEvaluationBaseSuite extends SparkFunSuite { def evaluate(expression: Expression, inputRow: Row = EmptyRow): Any = { expression.eval(inputRow) @@ -372,6 +372,8 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { DecimalType.Unlimited, ByteType), TimestampType), LongType), StringType), ShortType), 0) checkEvaluation(Literal(true) cast IntegerType, 1) checkEvaluation(Literal(false) cast IntegerType, 0) + checkEvaluation(Literal(true) cast StringType, "true") + checkEvaluation(Literal(false) cast StringType, "false") checkEvaluation(Cast(Literal(1) cast BooleanType, IntegerType), 1) checkEvaluation(Cast(Literal(0) cast BooleanType, IntegerType), 0) checkEvaluation("23" cast DoubleType, 23d) @@ -860,7 +862,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { val c5 = 'a.string.at(4) val c6 = 'a.string.at(5) - val literalNull = Literal.create(null, BooleanType) + val literalNull = Literal.create(null, IntegerType) val literalInt = Literal(1) val literalString = Literal("a") @@ -869,12 +871,12 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation(CaseKeyWhen(c2, Seq(literalInt, c4, c5)), "a", row) checkEvaluation(CaseKeyWhen(c2, Seq(c1, c4, c5)), "b", row) checkEvaluation(CaseKeyWhen(c4, Seq(literalString, c2, c3)), 1, row) - checkEvaluation(CaseKeyWhen(c4, Seq(c1, c3, c5, c2, Literal(3))), 3, row) + checkEvaluation(CaseKeyWhen(c4, Seq(c6, c3, c5, c2, Literal(3))), 3, row) checkEvaluation(CaseKeyWhen(literalInt, Seq(c2, c4, c5)), "a", row) checkEvaluation(CaseKeyWhen(literalString, Seq(c5, c2, c4, c3)), 2, row) - checkEvaluation(CaseKeyWhen(literalInt, Seq(c5, c2, c4, c3)), null, row) - checkEvaluation(CaseKeyWhen(literalNull, Seq(c5, c2, c1, c3)), 2, row) + checkEvaluation(CaseKeyWhen(c6, Seq(c5, c2, c4, c3)), null, row) + checkEvaluation(CaseKeyWhen(literalNull, Seq(c2, c5, c1, c6)), "c", row) } test("complex type") { @@ -1207,7 +1209,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { } /** - * Used for testing math functions for DataFrames. + * Used for testing math functions for DataFrames. * @param c The DataFrame function * @param f The functions in scala.math * @param domain The set of values to run the function with @@ -1215,7 +1217,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { * @tparam T Generic type for primitives */ def unaryMathFunctionEvaluation[@specialized(Int, Double, Float, Long) T]( - c: Expression => Expression, + c: Expression => Expression, f: T => T, domain: Iterable[T] = (-20 to 20).map(_ * 0.1), expectNull: Boolean = false): Unit = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala index 7a19e511eb8b5..88a36aa121b55 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala @@ -20,12 +20,16 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.JavaConverters._ import scala.util.Random +import org.apache.spark.SparkFunSuite import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, TaskMemoryManager, MemoryAllocator} -import org.scalatest.{BeforeAndAfterEach, FunSuite, Matchers} +import org.scalatest.{BeforeAndAfterEach, Matchers} import org.apache.spark.sql.types._ -class UnsafeFixedWidthAggregationMapSuite extends FunSuite with Matchers with BeforeAndAfterEach { +class UnsafeFixedWidthAggregationMapSuite + extends SparkFunSuite + with Matchers + with BeforeAndAfterEach { import UnsafeFixedWidthAggregationMap._ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index 3a60c7fd32675..61722f1ffa462 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -19,13 +19,14 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Arrays -import org.scalatest.{FunSuite, Matchers} +import org.scalatest.Matchers +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types._ import org.apache.spark.unsafe.PlatformDependent import org.apache.spark.unsafe.array.ByteArrayMethods -class UnsafeRowConverterSuite extends FunSuite with Matchers { +class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { test("basic conversion with only primitive types") { val fieldTypes: Array[DataType] = Array(LongType, LongType, IntegerType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala index a30052b38fc11..06c592f4905a3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala @@ -71,7 +71,7 @@ class CombiningLimitsSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - + test("limits: combines two limits after ColumnPruning") { val originalQuery = testRelation @@ -79,7 +79,7 @@ class CombiningLimitsSuite extends PlanTest { .limit(2) .select('a) .limit(5) - + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index 5697c2272b8e8..ec3b2f1edfa05 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -248,7 +248,7 @@ class ConstantFoldingSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - + test("Constant folding test: Fold In(v, list) into true or false") { var originalQuery = testRelation diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index ff25470bf0946..17dc9124749e8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -93,7 +93,7 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - + test("column pruning for Project(ne, Limit)") { val originalQuery = testRelation @@ -109,7 +109,7 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - + // After this line is unimplemented. test("simple push down") { val originalQuery = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index 11b0859d3f066..1d433275fed2e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -57,7 +57,7 @@ class OptimizeInSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - + test("OptimizedIn test: In clause not optimized in case filter has attributes") { val originalQuery = testRelation diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala new file mode 100644 index 0000000000000..151654bffbd66 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.Rand +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor + + +class ProjectCollapsingSuite extends PlanTest { + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Subqueries", FixedPoint(10), EliminateSubQueries) :: + Batch("ProjectCollapsing", Once, ProjectCollapsing) :: Nil + } + + val testRelation = LocalRelation('a.int, 'b.int) + + test("collapse two deterministic, independent projects into one") { + val query = testRelation + .select(('a + 1).as('a_plus_1), 'b) + .select('a_plus_1, ('b + 1).as('b_plus_1)) + + val optimized = Optimize.execute(query.analyze) + val correctAnswer = testRelation.select(('a + 1).as('a_plus_1), ('b + 1).as('b_plus_1)).analyze + + comparePlans(optimized, correctAnswer) + } + + test("collapse two deterministic, dependent projects into one") { + val query = testRelation + .select(('a + 1).as('a_plus_1), 'b) + .select(('a_plus_1 + 1).as('a_plus_2), 'b) + + val optimized = Optimize.execute(query.analyze) + + val correctAnswer = testRelation.select( + (('a + 1).as('a_plus_1) + 1).as('a_plus_2), + 'b).analyze + + comparePlans(optimized, correctAnswer) + } + + test("do not collapse nondeterministic projects") { + val query = testRelation + .select(Rand(10).as('rand)) + .select(('rand + 1).as('rand1), ('rand + 2).as('rand2)) + + val optimized = Optimize.execute(query.analyze) + val correctAnswer = query.analyze + + comparePlans(optimized, correctAnswer) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index e7cafcc96de87..765c1e2dda99f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.catalyst.plans -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Filter, LogicalPlan} import org.apache.spark.sql.catalyst.util._ @@ -26,7 +25,7 @@ import org.apache.spark.sql.catalyst.util._ /** * Provides helper methods for comparing plans. */ -class PlanTest extends FunSuite { +class PlanTest extends SparkFunSuite { /** * Since attribute references are given globally unique ids during analysis, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala index 1273921f6394c..62d5f6ac74885 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.catalyst.plans -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.{ExprId, AttributeReference} @@ -28,7 +27,7 @@ import org.apache.spark.sql.catalyst.util._ /** * Tests for the sameResult function of [[LogicalPlan]]. */ -class SameResultSuite extends FunSuite { +class SameResultSuite extends SparkFunSuite { val testRelation = LocalRelation('a.int, 'b.int, 'c.int) val testRelation2 = LocalRelation('a.int, 'b.int, 'c.int) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala index 2a641c63f87bb..a7de7b052bdc3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.sql.catalyst.trees -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.{Expression, IntegerLiteral, Literal} import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} -class RuleExecutorSuite extends FunSuite { +class RuleExecutorSuite extends SparkFunSuite { object DecrementLiterals extends Rule[Expression] { def apply(e: Expression): Expression = e transform { case IntegerLiteral(i) if i > 0 => Literal(i - 1) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 9fcfc51c96139..67db3d5e6d751 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -19,8 +19,7 @@ package org.apache.spark.sql.catalyst.trees import scala.collection.mutable.ArrayBuffer -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.{IntegerType, StringType, NullType} @@ -32,7 +31,7 @@ case class Dummy(optKey: Option[Expression]) extends Expression { override def eval(input: Row): Any = null.asInstanceOf[Any] } -class TreeNodeSuite extends FunSuite { +class TreeNodeSuite extends SparkFunSuite { test("top node changed") { val after = Literal(1) transform { case Literal(1, _) => Literal(2) } assert(after === Literal(2)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala index d7d60efee50fa..4030a1b1df358 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala @@ -18,11 +18,11 @@ package org.apache.spark.sql.catalyst.util import org.json4s.jackson.JsonMethods.parse -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types.{MetadataBuilder, Metadata} -class MetadataSuite extends FunSuite { +class MetadataSuite extends SparkFunSuite { val baseMetadata = new MetadataBuilder() .putString("purpose", "ml") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala index 3e7cf7cbb5e63..c6171b7b6916d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.types -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class DataTypeParserSuite extends FunSuite { +class DataTypeParserSuite extends SparkFunSuite { def checkDataType(dataTypeString: String, expectedDataType: DataType): Unit = { test(s"parse ${dataTypeString.replace("\n", "")}") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index df119827812f9..261c4fcad24aa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -17,10 +17,9 @@ package org.apache.spark.sql.types -import org.apache.spark.SparkException -import org.scalatest.FunSuite +import org.apache.spark.{SparkException, SparkFunSuite} -class DataTypeSuite extends FunSuite { +class DataTypeSuite extends SparkFunSuite { test("construct an ArrayType") { val array = ArrayType(StringType) @@ -72,7 +71,7 @@ class DataTypeSuite extends FunSuite { test("fieldsMap returns map of name to StructField") { val struct = StructType( - StructField("a", LongType) :: + StructField("a", LongType) :: StructField("b", FloatType) :: Nil) val mapped = StructType.fieldsMap(struct.fields) @@ -91,7 +90,7 @@ class DataTypeSuite extends FunSuite { val right = StructType(List()) val merged = left.merge(right) - + assert(merged === left) } @@ -134,7 +133,7 @@ class DataTypeSuite extends FunSuite { val right = StructType( StructField("b", LongType) :: Nil) - + intercept[SparkException] { left.merge(right) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala index a22aa6f244c48..81d7ab010f394 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.types -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite // scalastyle:off -class UTF8StringSuite extends FunSuite { +class UTF8StringSuite extends SparkFunSuite { test("basic") { def check(str: String, len: Int) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala index de6a2cd448c47..28b373e258311 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala @@ -17,12 +17,13 @@ package org.apache.spark.sql.types.decimal +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types.Decimal -import org.scalatest.{PrivateMethodTester, FunSuite} +import org.scalatest.PrivateMethodTester import scala.language.postfixOps -class DecimalSuite extends FunSuite with PrivateMethodTester { +class DecimalSuite extends SparkFunSuite with PrivateMethodTester { test("creating decimals") { /** Check that a Decimal has the given string representation, precision and scale */ def checkDecimal(d: Decimal, string: String, precision: Int, scale: Int): Unit = { diff --git a/sql/core/pom.xml b/sql/core/pom.xml index ffe95bb49188f..8210c552603ea 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -41,6 +41,13 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-catalyst_${scala.binary.version} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index b49b1d327289f..d3efa83380d04 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -716,6 +716,18 @@ class Column(protected[sql] val expr: Expression) extends Logging { */ def endsWith(literal: String): Column = this.endsWith(lit(literal)) + /** + * Gives the column an alias. Same as `as`. + * {{{ + * // Renames colA to colB in select output. + * df.select($"colA".alias("colB")) + * }}} + * + * @group expr_ops + * @since 1.4.0 + */ + def alias(alias: String): Column = as(alias) + /** * Gives the column an alias. * {{{ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index e90109446b642..034d887901975 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -57,14 +57,11 @@ private[sql] object DataFrame { * :: Experimental :: * A distributed collection of data organized into named columns. * - * A [[DataFrame]] is equivalent to a relational table in Spark SQL. There are multiple ways - * to create a [[DataFrame]]: + * A [[DataFrame]] is equivalent to a relational table in Spark SQL. The following example creates + * a [[DataFrame]] by pointing Spark SQL to a Parquet data set. * {{{ - * // Create a DataFrame from Parquet files - * val people = sqlContext.parquetFile("...") - * - * // Create a DataFrame from data sources - * val df = sqlContext.load("...", "json") + * val people = sqlContext.read.parquet("...") // in Scala + * DataFrame people = sqlContext.read().parquet("...") // in Java * }}} * * Once created, it can be manipulated using the various domain-specific-language (DSL) functions @@ -86,8 +83,8 @@ private[sql] object DataFrame { * A more concrete example in Scala: * {{{ * // To create DataFrame using SQLContext - * val people = sqlContext.parquetFile("...") - * val department = sqlContext.parquetFile("...") + * val people = sqlContext.read.parquet("...") + * val department = sqlContext.read.parquet("...") * * people.filter("age > 30") * .join(department, people("deptId") === department("id")) @@ -98,8 +95,8 @@ private[sql] object DataFrame { * and in Java: * {{{ * // To create DataFrame using SQLContext - * DataFrame people = sqlContext.parquetFile("..."); - * DataFrame department = sqlContext.parquetFile("..."); + * DataFrame people = sqlContext.read().parquet("..."); + * DataFrame department = sqlContext.read().parquet("..."); * * people.filter("age".gt(30)) * .join(department, people.col("deptId").equalTo(department("id"))) @@ -1444,7 +1441,9 @@ class DataFrame private[sql]( //////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////// - /** Left here for backward compatibility. */ + /** + * @deprecated As of 1.3.0, replaced by `toDF()`. + */ @deprecated("use toDF", "1.3.0") def toSchemaRDD: DataFrame = this @@ -1455,6 +1454,7 @@ class DataFrame private[sql]( * given name; if you pass `false`, it will throw if the table already * exists. * @group output + * @deprecated As of 1.340, replaced by `write().jdbc()`. */ @deprecated("Use write.jdbc()", "1.4.0") def createJDBCTable(url: String, table: String, allowExisting: Boolean): Unit = { @@ -1473,6 +1473,7 @@ class DataFrame private[sql]( * the RDD in order via the simple statement * `INSERT INTO table VALUES (?, ?, ..., ?)` should not fail. * @group output + * @deprecated As of 1.4.0, replaced by `write().jdbc()`. */ @deprecated("Use write.jdbc()", "1.4.0") def insertIntoJDBC(url: String, table: String, overwrite: Boolean): Unit = { @@ -1485,6 +1486,7 @@ class DataFrame private[sql]( * Files that are written out using this method can be read back in as a [[DataFrame]] * using the `parquetFile` function in [[SQLContext]]. * @group output + * @deprecated As of 1.4.0, replaced by `write().parquet()`. */ @deprecated("Use write.parquet(path)", "1.4.0") def saveAsParquetFile(path: String): Unit = { @@ -1508,6 +1510,7 @@ class DataFrame private[sql]( * Also note that while this function can persist the table metadata into Hive's metastore, * the table will NOT be accessible from Hive, until SPARK-7550 is resolved. * @group output + * @deprecated As of 1.4.0, replaced by `write().saveAsTable(tableName)`. */ @deprecated("Use write.saveAsTable(tableName)", "1.4.0") def saveAsTable(tableName: String): Unit = { @@ -1526,6 +1529,7 @@ class DataFrame private[sql]( * Also note that while this function can persist the table metadata into Hive's metastore, * the table will NOT be accessible from Hive, until SPARK-7550 is resolved. * @group output + * @deprecated As of 1.4.0, replaced by `write().mode(mode).saveAsTable(tableName)`. */ @deprecated("Use write.mode(mode).saveAsTable(tableName)", "1.4.0") def saveAsTable(tableName: String, mode: SaveMode): Unit = { @@ -1545,6 +1549,7 @@ class DataFrame private[sql]( * Also note that while this function can persist the table metadata into Hive's metastore, * the table will NOT be accessible from Hive, until SPARK-7550 is resolved. * @group output + * @deprecated As of 1.4.0, replaced by `write().format(source).saveAsTable(tableName)`. */ @deprecated("Use write.format(source).saveAsTable(tableName)", "1.4.0") def saveAsTable(tableName: String, source: String): Unit = { @@ -1564,6 +1569,7 @@ class DataFrame private[sql]( * Also note that while this function can persist the table metadata into Hive's metastore, * the table will NOT be accessible from Hive, until SPARK-7550 is resolved. * @group output + * @deprecated As of 1.4.0, replaced by `write().mode(mode).saveAsTable(tableName)`. */ @deprecated("Use write.format(source).mode(mode).saveAsTable(tableName)", "1.4.0") def saveAsTable(tableName: String, source: String, mode: SaveMode): Unit = { @@ -1582,6 +1588,8 @@ class DataFrame private[sql]( * Also note that while this function can persist the table metadata into Hive's metastore, * the table will NOT be accessible from Hive, until SPARK-7550 is resolved. * @group output + * @deprecated As of 1.4.0, replaced by + * `write().format(source).mode(mode).options(options).saveAsTable(tableName)`. */ @deprecated("Use write.format(source).mode(mode).options(options).saveAsTable(tableName)", "1.4.0") @@ -1606,6 +1614,8 @@ class DataFrame private[sql]( * Also note that while this function can persist the table metadata into Hive's metastore, * the table will NOT be accessible from Hive, until SPARK-7550 is resolved. * @group output + * @deprecated As of 1.4.0, replaced by + * `write().format(source).mode(mode).options(options).saveAsTable(tableName)`. */ @deprecated("Use write.format(source).mode(mode).options(options).saveAsTable(tableName)", "1.4.0") @@ -1622,6 +1632,7 @@ class DataFrame private[sql]( * using the default data source configured by spark.sql.sources.default and * [[SaveMode.ErrorIfExists]] as the save mode. * @group output + * @deprecated As of 1.4.0, replaced by `write().save(path)`. */ @deprecated("Use write.save(path)", "1.4.0") def save(path: String): Unit = { @@ -1632,6 +1643,7 @@ class DataFrame private[sql]( * Saves the contents of this DataFrame to the given path and [[SaveMode]] specified by mode, * using the default data source configured by spark.sql.sources.default. * @group output + * @deprecated As of 1.4.0, replaced by `write().mode(mode).save(path)`. */ @deprecated("Use write.mode(mode).save(path)", "1.4.0") def save(path: String, mode: SaveMode): Unit = { @@ -1642,6 +1654,7 @@ class DataFrame private[sql]( * Saves the contents of this DataFrame to the given path based on the given data source, * using [[SaveMode.ErrorIfExists]] as the save mode. * @group output + * @deprecated As of 1.4.0, replaced by `write().format(source).save(path)`. */ @deprecated("Use write.format(source).save(path)", "1.4.0") def save(path: String, source: String): Unit = { @@ -1652,6 +1665,7 @@ class DataFrame private[sql]( * Saves the contents of this DataFrame to the given path based on the given data source and * [[SaveMode]] specified by mode. * @group output + * @deprecated As of 1.4.0, replaced by `write().format(source).mode(mode).save(path)`. */ @deprecated("Use write.format(source).mode(mode).save(path)", "1.4.0") def save(path: String, source: String, mode: SaveMode): Unit = { @@ -1662,6 +1676,8 @@ class DataFrame private[sql]( * Saves the contents of this DataFrame based on the given data source, * [[SaveMode]] specified by mode, and a set of options. * @group output + * @deprecated As of 1.4.0, replaced by + * `write().format(source).mode(mode).options(options).save(path)`. */ @deprecated("Use write.format(source).mode(mode).options(options).save()", "1.4.0") def save( @@ -1676,6 +1692,8 @@ class DataFrame private[sql]( * Saves the contents of this DataFrame based on the given data source, * [[SaveMode]] specified by mode, and a set of options * @group output + * @deprecated As of 1.4.0, replaced by + * `write().format(source).mode(mode).options(options).save(path)`. */ @deprecated("Use write.format(source).mode(mode).options(options).save()", "1.4.0") def save( @@ -1689,6 +1707,8 @@ class DataFrame private[sql]( /** * Adds the rows from this RDD to the specified table, optionally overwriting the existing data. * @group output + * @deprecated As of 1.4.0, replaced by + * `write().mode(SaveMode.Append|SaveMode.Overwrite).saveAsTable(tableName)`. */ @deprecated("Use write.mode(SaveMode.Append|SaveMode.Overwrite).saveAsTable(tableName)", "1.4.0") def insertInto(tableName: String, overwrite: Boolean): Unit = { @@ -1699,6 +1719,8 @@ class DataFrame private[sql]( * Adds the rows from this RDD to the specified table. * Throws an exception if the table already exists. * @group output + * @deprecated As of 1.4.0, replaced by + * `write().mode(SaveMode.Append).saveAsTable(tableName)`. */ @deprecated("Use write.mode(SaveMode.Append).saveAsTable(tableName)", "1.4.0") def insertInto(tableName: String): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index 5d106c1ac2674..edb9ed7bba56a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -43,7 +43,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { /** * Calculates the correlation of two columns of a DataFrame. Currently only supports the Pearson - * Correlation Coefficient. For Spearman Correlation, consider using RDD methods found in + * Correlation Coefficient. For Spearman Correlation, consider using RDD methods found in * MLlib's Statistics. * * @param col1 the name of the column @@ -97,6 +97,9 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * [[http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou]]. * The `support` should be greater than 1e-4. * + * This function is meant for exploratory data analysis, as we make no guarantee about the + * backward compatibility of the schema of the resulting [[DataFrame]]. + * * @param cols the names of the columns to search frequent items in. * @param support The minimum frequency for an item to be considered `frequent`. Should be greater * than 1e-4. @@ -114,6 +117,9 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * [[http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou]]. * Uses a `default` support of 1%. * + * This function is meant for exploratory data analysis, as we make no guarantee about the + * backward compatibility of the schema of the resulting [[DataFrame]]. + * * @param cols the names of the columns to search frequent items in. * @return A Local DataFrame with the Array of frequent items for each column. * @@ -128,6 +134,9 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * frequent element count algorithm described in * [[http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou]]. * + * This function is meant for exploratory data analysis, as we make no guarantee about the + * backward compatibility of the schema of the resulting [[DataFrame]]. + * * @param cols the names of the columns to search frequent items in. * @return A Local DataFrame with the Array of frequent items for each column. * @@ -143,6 +152,9 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * [[http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou]]. * Uses a `default` support of 1%. * + * This function is meant for exploratory data analysis, as we make no guarantee about the + * backward compatibility of the schema of the resulting [[DataFrame]]. + * * @param cols the names of the columns to search frequent items in. * @return A Local DataFrame with the Array of frequent items for each column. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index 516ba2ac23371..45b3e1bc627d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -40,22 +40,22 @@ private[sql] object GroupedData { /** * The Grouping Type */ - trait GroupType + private[sql] trait GroupType /** * To indicate it's the GroupBy */ - object GroupByType extends GroupType + private[sql] object GroupByType extends GroupType /** * To indicate it's the CUBE */ - object CubeType extends GroupType + private[sql] object CubeType extends GroupType /** * To indicate it's the ROLLUP */ - object RollupType extends GroupType + private[sql] object RollupType extends GroupType } /** @@ -249,7 +249,7 @@ class GroupedData protected[sql]( def mean(colNames: String*): DataFrame = { aggregateNumericColumns(colNames : _*)(Average) } - + /** * Compute the max value for each numeric columns for each group. * The resulting [[DataFrame]] will also contain the grouping columns. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index a32897c20b474..91e6385dec81b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -182,9 +182,28 @@ class SQLContext(@transient val sparkContext: SparkContext) conf.dialect } - sparkContext.getConf.getAll.foreach { - case (key, value) if key.startsWith("spark.sql") => setConf(key, value) - case _ => + { + // We extract spark sql settings from SparkContext's conf and put them to + // Spark SQL's conf. + // First, we populate the SQLConf (conf). So, we can make sure that other values using + // those settings in their construction can get the correct settings. + // For example, metadataHive in HiveContext may need both spark.sql.hive.metastore.version + // and spark.sql.hive.metastore.jars to get correctly constructed. + val properties = new Properties + sparkContext.getConf.getAll.foreach { + case (key, value) if key.startsWith("spark.sql") => properties.setProperty(key, value) + case _ => + } + // We directly put those settings to conf to avoid of calling setConf, which may have + // side-effects. For example, in HiveContext, setConf may cause executionHive and metadataHive + // get constructed. If we call setConf directly, the constructed metadataHive may have + // wrong settings, or the construction may fail. + conf.setConf(properties) + // After we have populated SQLConf, we call setConf to populate other confs in the subclass + // (e.g. hiveconf in HiveContext). + properties.foreach { + case (key, value) => setConf(key, value) + } } @transient @@ -1021,21 +1040,33 @@ class SQLContext(@transient val sparkContext: SparkContext) //////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////// + /** + * @deprecated As of 1.3.0, replaced by `createDataFrame()`. + */ @deprecated("use createDataFrame", "1.3.0") def applySchema(rowRDD: RDD[Row], schema: StructType): DataFrame = { createDataFrame(rowRDD, schema) } + /** + * @deprecated As of 1.3.0, replaced by `createDataFrame()`. + */ @deprecated("use createDataFrame", "1.3.0") def applySchema(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = { createDataFrame(rowRDD, schema) } + /** + * @deprecated As of 1.3.0, replaced by `createDataFrame()`. + */ @deprecated("use createDataFrame", "1.3.0") def applySchema(rdd: RDD[_], beanClass: Class[_]): DataFrame = { createDataFrame(rdd, beanClass) } + /** + * @deprecated As of 1.3.0, replaced by `createDataFrame()`. + */ @deprecated("use createDataFrame", "1.3.0") def applySchema(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame = { createDataFrame(rdd, beanClass) @@ -1046,6 +1077,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * [[DataFrame]] if no paths are passed in. * * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().parquet()`. */ @deprecated("Use read.parquet()", "1.4.0") @scala.annotation.varargs @@ -1065,6 +1097,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * It goes through the entire dataset once to determine the schema. * * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().json()`. */ @deprecated("Use read.json()", "1.4.0") def jsonFile(path: String): DataFrame = { @@ -1076,6 +1109,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * returning the result as a [[DataFrame]]. * * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().json()`. */ @deprecated("Use read.json()", "1.4.0") def jsonFile(path: String, schema: StructType): DataFrame = { @@ -1084,6 +1118,7 @@ class SQLContext(@transient val sparkContext: SparkContext) /** * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().json()`. */ @deprecated("Use read.json()", "1.4.0") def jsonFile(path: String, samplingRatio: Double): DataFrame = { @@ -1096,6 +1131,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * It goes through the entire dataset once to determine the schema. * * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().json()`. */ @deprecated("Use read.json()", "1.4.0") def jsonRDD(json: RDD[String]): DataFrame = read.json(json) @@ -1106,6 +1142,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * It goes through the entire dataset once to determine the schema. * * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().json()`. */ @deprecated("Use read.json()", "1.4.0") def jsonRDD(json: JavaRDD[String]): DataFrame = read.json(json) @@ -1115,6 +1152,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * returning the result as a [[DataFrame]]. * * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().json()`. */ @deprecated("Use read.json()", "1.4.0") def jsonRDD(json: RDD[String], schema: StructType): DataFrame = { @@ -1126,6 +1164,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * schema, returning the result as a [[DataFrame]]. * * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().json()`. */ @deprecated("Use read.json()", "1.4.0") def jsonRDD(json: JavaRDD[String], schema: StructType): DataFrame = { @@ -1137,6 +1176,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * schema, returning the result as a [[DataFrame]]. * * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().json()`. */ @deprecated("Use read.json()", "1.4.0") def jsonRDD(json: RDD[String], samplingRatio: Double): DataFrame = { @@ -1148,6 +1188,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * schema, returning the result as a [[DataFrame]]. * * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().json()`. */ @deprecated("Use read.json()", "1.4.0") def jsonRDD(json: JavaRDD[String], samplingRatio: Double): DataFrame = { @@ -1159,6 +1200,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * using the default data source configured by spark.sql.sources.default. * * @group genericdata + * @deprecated As of 1.4.0, replaced by `read().load(path)`. */ @deprecated("Use read.load(path)", "1.4.0") def load(path: String): DataFrame = { @@ -1169,6 +1211,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * Returns the dataset stored at path as a DataFrame, using the given data source. * * @group genericdata + * @deprecated As of 1.4.0, replaced by `read().format(source).load(path)`. */ @deprecated("Use read.format(source).load(path)", "1.4.0") def load(path: String, source: String): DataFrame = { @@ -1180,6 +1223,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * a set of options as a DataFrame. * * @group genericdata + * @deprecated As of 1.4.0, replaced by `read().format(source).options(options).load()`. */ @deprecated("Use read.format(source).options(options).load()", "1.4.0") def load(source: String, options: java.util.Map[String, String]): DataFrame = { @@ -1191,6 +1235,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * a set of options as a DataFrame. * * @group genericdata + * @deprecated As of 1.4.0, replaced by `read().format(source).options(options).load()`. */ @deprecated("Use read.format(source).options(options).load()", "1.4.0") def load(source: String, options: Map[String, String]): DataFrame = { @@ -1202,6 +1247,8 @@ class SQLContext(@transient val sparkContext: SparkContext) * a set of options as a DataFrame, using the given schema as the schema of the DataFrame. * * @group genericdata + * @deprecated As of 1.4.0, replaced by + * `read().format(source).schema(schema).options(options).load()`. */ @deprecated("Use read.format(source).schema(schema).options(options).load()", "1.4.0") def load(source: String, schema: StructType, options: java.util.Map[String, String]): DataFrame = @@ -1214,6 +1261,8 @@ class SQLContext(@transient val sparkContext: SparkContext) * a set of options as a DataFrame, using the given schema as the schema of the DataFrame. * * @group genericdata + * @deprecated As of 1.4.0, replaced by + * `read().format(source).schema(schema).options(options).load()`. */ @deprecated("Use read.format(source).schema(schema).options(options).load()", "1.4.0") def load(source: String, schema: StructType, options: Map[String, String]): DataFrame = { @@ -1225,6 +1274,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * url named table. * * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().jdbc()`. */ @deprecated("use read.jdbc()", "1.4.0") def jdbc(url: String, table: String): DataFrame = { @@ -1242,6 +1292,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * @param numPartitions the number of partitions. the range `minValue`-`maxValue` will be split * evenly into this many partitions * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().jdbc()`. */ @deprecated("use read.jdbc()", "1.4.0") def jdbc( @@ -1261,6 +1312,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * of the [[DataFrame]]. * * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().jdbc()`. */ @deprecated("use read.jdbc()", "1.4.0") def jdbc(url: String, table: String, theParts: Array[String]): DataFrame = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index 423ecdff5804a..604f3124e23ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -106,7 +106,7 @@ private[r] object SQLUtils { dfCols.map { col => colToRBytes(col) - } + } } def convertRowsToColumns(localDF: Array[Row], numCols: Int): Array[Array[Any]] = { @@ -121,7 +121,7 @@ private[r] object SQLUtils { val numRows = col.length val bos = new ByteArrayOutputStream() val dos = new DataOutputStream(bos) - + SerDe.writeInt(dos, numRows) col.map { item => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index 2ec7d4fbc92de..3e27c1bde2dfd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -138,15 +138,15 @@ case class GeneratedAggregate( case UnscaledValue(e) => e case _ => expr } - // partial sum result can be null only when no input rows present + // partial sum result can be null only when no input rows present val updateFunction = If( IsNotNull(actualExpr), Coalesce( Add( - Coalesce(currentSum :: zero :: Nil), + Coalesce(currentSum :: zero :: Nil), Cast(expr, calcType)) :: currentSum :: zero :: Nil), currentSum) - + val result = expr.dataType match { case DecimalType.Fixed(_, _) => @@ -155,7 +155,7 @@ case class GeneratedAggregate( } AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result) - + case m @ Max(expr) => val currentMax = AttributeReference("currentMax", expr.dataType, nullable = true)() val initialValue = Literal.create(null, expr.dataType) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 6cb67b4bbbb65..a30ade86441ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -65,7 +65,7 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { * :: DeveloperApi :: * Sample the dataset. * @param lowerBound Lower-bound of the sampling probability (usually 0.0) - * @param upperBound Upper-bound of the sampling probability. The expected fraction sampled + * @param upperBound Upper-bound of the sampling probability. The expected fraction sampled * will be ub - lb. * @param withReplacement Whether to sample with replacement. * @param seed the random seed diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala index fe8a81e3d0434..c41c21c0eeb50 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala @@ -62,7 +62,7 @@ private[sql] object FrequentItems extends Logging { } /** - * Finding frequent items for columns, possibly with false positives. Using the + * Finding frequent items for columns, possibly with false positives. Using the * frequent element count algorithm described in * [[http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou]]. * The `support` should be greater than 1e-4. @@ -75,7 +75,7 @@ private[sql] object FrequentItems extends Logging { * @return A Local DataFrame with the Array of frequent items for each column. */ private[sql] def singlePassFreqItems( - df: DataFrame, + df: DataFrame, cols: Seq[String], support: Double): DataFrame = { require(support >= 1e-4, s"support ($support) must be greater than 1e-4.") @@ -88,7 +88,7 @@ private[sql] object FrequentItems extends Logging { val index = originalSchema.fieldIndex(name) (name, originalSchema.fields(index).dataType) } - + val freqItems = df.select(cols.map(Column(_)) : _*).rdd.aggregate(countMaps)( seqOp = (counts, row) => { var i = 0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index d22f5fd2d439c..93383e5a62f11 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -18,14 +18,14 @@ package org.apache.spark.sql.execution.stat import org.apache.spark.Logging -import org.apache.spark.sql.{Column, DataFrame} +import org.apache.spark.sql.{Row, Column, DataFrame} import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Cast} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ private[sql] object StatFunctions extends Logging { - + /** Calculate the Pearson Correlation Coefficient for the given columns */ private[sql] def pearsonCorrelation(df: DataFrame, cols: Seq[String]): Double = { val counts = collectStatisticalData(df, cols) @@ -116,7 +116,10 @@ private[sql] object StatFunctions extends Logging { s"exceed 1e4. Currently $columnSize") val table = counts.groupBy(_.get(0)).map { case (col1Item, rows) => val countsRow = new GenericMutableRow(columnSize + 1) - rows.foreach { row => + rows.foreach { (row: Row) => + // row.get(0) is column 1 + // row.get(1) is column 2 + // row.get(3) is the frequency countsRow.setLong(distinctCol2.get(row.get(1)).get + 1, row.getLong(2)) } // the value of col1 is the first value, the rest are the counts @@ -126,6 +129,6 @@ private[sql] object StatFunctions extends Logging { val headerNames = distinctCol2.map(r => StructField(r._1.toString, LongType)).toSeq val schema = StructType(StructField(tableName, StringType) +: headerNames) - new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table)) + new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table)).na.fill(0.0) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala index d4003b2d9cbf6..e9b60841fc28c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala @@ -79,3 +79,20 @@ object Window { } } + +/** + * :: Experimental :: + * Utility functions for defining window in DataFrames. + * + * {{{ + * // PARTITION BY country ORDER BY date ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW + * Window.partitionBy("country").orderBy("date").rowsBetween(Long.MinValue, 0) + * + * // PARTITION BY country ORDER BY date ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING + * Window.partitionBy("country").orderBy("date").rowsBetween(-3, 3) + * }}} + * + * @since 1.4.0 + */ +@Experimental +class Window private() // So we can see Window in JavaDoc. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 6dc17bbb2e768..77327f2b84eaa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1299,7 +1299,7 @@ object functions { * @since 1.4.0 */ def toRadians(columnName: String): Column = toRadians(Column(columnName)) - + ////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala index 0bdb68e8ac845..40b604d710dce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala @@ -262,7 +262,7 @@ private[sql] class JDBCRDD( } private def escapeSql(value: String): String = - if (value == null) null else StringUtils.replace(value, "'", "''") + if (value == null) null else StringUtils.replace(value, "'", "''") /** * Turns a single Filter into a String representing a SQL expression. @@ -304,7 +304,7 @@ private[sql] class JDBCRDD( // Each JDBC-to-Catalyst conversion corresponds to a tag defined here so that // we don't have to potentially poke around in the Metadata once for every - // row. + // row. // Is there a better way to do this? I'd rather be using a type that // contains only the tags I define. abstract class JDBCConversion diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala index 09d6865457df6..30f9190d45bf8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala @@ -54,7 +54,7 @@ private[sql] object JDBCRelation { if (numPartitions == 1) return Array[Partition](JDBCPartition(null, 0)) // Overflow and silliness can happen if you subtract then divide. // Here we get a little roundoff, but that's (hopefully) OK. - val stride: Long = (partitioning.upperBound / numPartitions + val stride: Long = (partitioning.upperBound / numPartitions - partitioning.lowerBound / numPartitions) var i: Int = 0 var currentValue: Long = partitioning.lowerBound @@ -140,10 +140,10 @@ private[sql] case class JDBCRelation( filters, parts) } - + override def insert(data: DataFrame, overwrite: Boolean): Unit = { data.write .mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append) .jdbc(url, table, properties) - } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala index f21dd29aca37f..dd8aaf6474895 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala @@ -240,10 +240,10 @@ package object jdbc { } } } - + def getDriverClassName(url: String): String = DriverManager.getDriver(url) match { case wrapper: DriverWrapper => wrapper.wrapped.getClass.getCanonicalName - case driver => driver.getClass.getCanonicalName + case driver => driver.getClass.getCanonicalName } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index 3f97a11ceb97d..4e94fd07a8771 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -44,6 +44,7 @@ package object sql { /** * Type alias for [[DataFrame]]. Kept here for backward source compatibility for Scala. + * @deprecated As of 1.3.0, replaced by `DataFrame`. */ @deprecated("1.3.0", "use DataFrame") type SchemaRDD = DataFrame diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index 1b4196ab0be35..caa9f045537d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -243,8 +243,10 @@ private[parquet] abstract class CatalystConverter extends GroupConverter { /** * Read a decimal value from a Parquet Binary into "dest". Only supports decimals that fit in * a long (i.e. precision <= 18) + * + * Returned value is needed by CatalystConverter, which doesn't reuse the Decimal object. */ - protected[parquet] def readDecimal(dest: Decimal, value: Binary, ctype: DecimalType): Unit = { + protected[parquet] def readDecimal(dest: Decimal, value: Binary, ctype: DecimalType): Decimal = { val precision = ctype.precisionInfo.get.precision val scale = ctype.precisionInfo.get.scale val bytes = value.getBytes diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index 8b3e1b2b59bf6..824ae36968c32 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -155,7 +155,7 @@ private[sql] class ParquetRelation2( meta } - override def equals(other: scala.Any): Boolean = other match { + override def equals(other: Any): Boolean = other match { case that: ParquetRelation2 => val schemaEquality = if (shouldMergeSchemas) { this.shouldMergeSchemas == that.shouldMergeSchemas @@ -190,7 +190,7 @@ private[sql] class ParquetRelation2( } } - override def dataSchema: StructType = metadataCache.dataSchema + override def dataSchema: StructType = maybeDataSchema.getOrElse(metadataCache.dataSchema) override private[sql] def refresh(): Unit = { super.refresh() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala index dafdf0f8b4564..c4c99de5a38dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala @@ -187,7 +187,7 @@ private[sql] object PartitioningUtils { Seq.empty } else { assert(distinctPartitionsColNames.size == 1, { - val list = distinctPartitionsColNames.mkString("\t", "\n", "") + val list = distinctPartitionsColNames.mkString("\t", "\n\t", "") s"Conflicting partition column names detected:\n$list" }) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/SqlNewHadoopRDD.scala index a74a98631da35..ebad0c1564ec0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/SqlNewHadoopRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/SqlNewHadoopRDD.scala @@ -216,7 +216,7 @@ private[sql] class SqlNewHadoopRDD[K, V]( override def getPreferredLocations(hsplit: SparkPartition): Seq[String] = { val split = hsplit.asInstanceOf[SqlNewHadoopPartition].serializableHadoopSplit.value val locs = HadoopRDD.SPLIT_INFO_REFLECTIONS match { - case Some(c) => + case Some(c) => try { val infos = c.newGetLocationInfo.invoke(split).asInstanceOf[Array[AnyRef]] Some(HadoopRDD.convertSplitLocationInfo(infos)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala index 3132067d562f6..71f016b1f14de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala @@ -30,9 +30,10 @@ import org.apache.spark._ import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateProjection -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{Project, LogicalPlan} import org.apache.spark.sql.execution.RunnableCommand import org.apache.spark.sql.types.StructType import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext, SaveMode} @@ -94,10 +95,19 @@ private[sql] case class InsertIntoHadoopFsRelation( // We create a DataFrame by applying the schema of relation to the data to make sure. // We are writing data based on the expected schema, - val df = sqlContext.createDataFrame( - DataFrame(sqlContext, query).queryExecution.toRdd, - relation.schema, - needsConversion = false) + val df = { + // For partitioned relation r, r.schema's column ordering can be different from the column + // ordering of data.logicalPlan (partition columns are all moved after data column). We + // need a Project to adjust the ordering, so that inside InsertIntoHadoopFsRelation, we can + // safely apply the schema of r.schema to the data. + val project = Project( + relation.schema.map(field => new UnresolvedAttribute(Seq(field.name))), query) + + sqlContext.createDataFrame( + DataFrame(sqlContext, project).queryExecution.toRdd, + relation.schema, + needsConversion = false) + } val partitionColumns = relation.partitionColumns.fieldNames if (partitionColumns.isEmpty) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala index 22587f5a1c6f1..20afd60cb7767 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.Logging import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql.catalyst.AbstractSparkSQLParser -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation} +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Row} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.RunnableCommand @@ -322,19 +322,13 @@ private[sql] object ResolvedDataSource { Some(partitionColumnsSchema(data.schema, partitionColumns)), caseInsensitiveOptions) - // For partitioned relation r, r.schema's column ordering is different with the column - // ordering of data.logicalPlan. We need a Project to adjust the ordering. - // So, inside InsertIntoHadoopFsRelation, we can safely apply the schema of r.schema to - // the data. - val project = - Project( - r.schema.map(field => new UnresolvedAttribute(Seq(field.name))), - data.logicalPlan) - + // For partitioned relation r, r.schema's column ordering can be different from the column + // ordering of data.logicalPlan (partition columns are all moved after data column). This + // will be adjusted within InsertIntoHadoopFsRelation. sqlContext.executePlan( InsertIntoHadoopFsRelation( r, - project, + data.logicalPlan, mode)).toRdd r case _ => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index c06026e042d9f..f5bd2d2941ca0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -93,7 +93,7 @@ trait SchemaRelationProvider { } /** - * ::DeveloperApi:: + * ::Experimental:: * Implemented by objects that produce relations for a specific kind of data source * with a given schema and partitioned columns. When Spark SQL is given a DDL operation with a * USING clause specified (to specify the implemented [[HadoopFsRelationProvider]]), a user defined @@ -115,6 +115,7 @@ trait SchemaRelationProvider { * * @since 1.4.0 */ +@Experimental trait HadoopFsRelationProvider { /** * Returns a new base relation with the given parameters, a user defined schema, and a list of @@ -378,10 +379,10 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio var leafDirToChildrenFiles = mutable.Map.empty[Path, Array[FileStatus]] def refresh(): Unit = { - // We don't filter files/directories whose name start with "_" or "." here, as specific data - // sources may take advantages over them (e.g. Parquet _metadata and _common_metadata files). - // But "_temporary" directories are explicitly ignored since failed tasks/jobs may leave - // partial/corrupted data files there. + // We don't filter files/directories whose name start with "_" except "_temporary" here, as + // specific data sources may take advantages over them (e.g. Parquet _metadata and + // _common_metadata files). "_temporary" directories are explicitly ignored since failed + // tasks/jobs may leave partial/corrupted data files there. def listLeafFilesAndDirs(fs: FileSystem, status: FileStatus): Set[FileStatus] = { if (status.getPath.getName.toLowerCase == "_temporary") { Set.empty @@ -399,6 +400,9 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio val fs = hdfsPath.getFileSystem(hadoopConf) val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) Try(fs.getFileStatus(qualified)).toOption.toArray.flatMap(listLeafFilesAndDirs(fs, _)) + }.filterNot { status => + // SPARK-8037: Ignores files like ".DS_Store" and other hidden files/directories + status.getPath.getName.startsWith(".") } val files = statuses.filterNot(_.isDir) @@ -499,7 +503,7 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio */ override lazy val schema: StructType = { val dataSchemaColumnNames = dataSchema.map(_.name.toLowerCase).toSet - StructType(dataSchema ++ partitionSpec.partitionColumns.filterNot { column => + StructType(dataSchema ++ partitionColumns.filterNot { column => dataSchemaColumnNames.contains(column.name.toLowerCase) }) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index d006b83fc075a..bfba379d9a518 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import org.scalatest.Matchers._ +import org.apache.spark.sql.execution.Project import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext.implicits._ @@ -27,6 +28,12 @@ import org.apache.spark.sql.types._ class ColumnExpressionSuite extends QueryTest { import org.apache.spark.sql.TestData._ + test("alias") { + val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") + assert(df.select(df("a").as("b")).columns.head === "b") + assert(df.select(df("a").alias("b")).columns.head === "b") + } + test("single explode") { val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") checkAnswer( @@ -446,13 +453,51 @@ class ColumnExpressionSuite extends QueryTest { } test("rand") { - val randCol = testData.select('key, rand(5L).as("rand")) + val randCol = testData.select($"key", rand(5L).as("rand")) randCol.columns.length should be (2) val rows = randCol.collect() rows.foreach { row => assert(row.getDouble(1) <= 1.0) assert(row.getDouble(1) >= 0.0) } + + def checkNumProjects(df: DataFrame, expectedNumProjects: Int): Unit = { + val projects = df.queryExecution.executedPlan.collect { + case project: Project => project + } + assert(projects.size === expectedNumProjects) + } + + // We first create a plan with two Projects. + // Project [rand + 1 AS rand1, rand - 1 AS rand2] + // Project [key, (Rand 5 + 1) AS rand] + // LogicalRDD [key, value] + // Because Rand function is not deterministic, the column rand is not deterministic. + // So, in the optimizer, we will not collapse Project [rand + 1 AS rand1, rand - 1 AS rand2] + // and Project [key, Rand 5 AS rand]. The final plan still has two Projects. + val dfWithTwoProjects = + testData + .select($"key", (rand(5L) + 1).as("rand")) + .select(($"rand" + 1).as("rand1"), ($"rand" - 1).as("rand2")) + checkNumProjects(dfWithTwoProjects, 2) + + // Now, we add one more project rand1 - rand2 on top of the query plan. + // Since rand1 and rand2 are deterministic (they basically apply +/- to the generated + // rand value), we can collapse rand1 - rand2 to the Project generating rand1 and rand2. + // So, the plan will be optimized from ... + // Project [(rand1 - rand2) AS (rand1 - rand2)] + // Project [rand + 1 AS rand1, rand - 1 AS rand2] + // Project [key, (Rand 5 + 1) AS rand] + // LogicalRDD [key, value] + // to ... + // Project [((rand + 1 AS rand1) - (rand - 1 AS rand2)) AS (rand1 - rand2)] + // Project [key, Rand 5 AS rand] + // LogicalRDD [key, value] + val dfWithThreeProjects = dfWithTwoProjects.select($"rand1" - $"rand2") + checkNumProjects(dfWithThreeProjects, 2) + dfWithThreeProjects.collect().foreach { row => + assert(row.getDouble(0) === 2.0 +- 0.0001) + } } test("randn") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 46b1845a9180c..438f479459dfe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -17,14 +17,14 @@ package org.apache.spark.sql -import org.scalatest.FunSuite import org.scalatest.Matchers._ +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext.implicits._ -class DataFrameStatSuite extends FunSuite { - +class DataFrameStatSuite extends SparkFunSuite { + val sqlCtx = TestSQLContext def toLetter(i: Int): String = (i + 97).toChar.toString @@ -74,10 +74,10 @@ class DataFrameStatSuite extends FunSuite { val rows: Array[Row] = crosstab.collect().sortBy(_.getString(0)) assert(rows(0).get(0).toString === "0") assert(rows(0).getLong(1) === 2L) - assert(rows(0).get(2) === null) + assert(rows(0).get(2) === 0L) assert(rows(1).get(0).toString === "1") assert(rows(1).getLong(1) === 1L) - assert(rows(1).get(2) === null) + assert(rows(1).get(2) === 0L) assert(rows(2).get(0).toString === "2") assert(rows(2).getLong(1) === 2L) assert(rows(2).getLong(2) === 1L) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index c4281c4b55c02..dd68965444f5d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -206,7 +206,7 @@ class MathExpressionsSuite extends QueryTest { } test("log") { - testOneToOneNonNegativeMathFunction(log, math.log) + testOneToOneNonNegativeMathFunction(org.apache.spark.sql.functions.log, math.log) } test("log10") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala index fb3ba4bc1b908..513ac915dcb2a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -17,15 +17,15 @@ package org.apache.spark.sql +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.execution.SparkSqlSerializer -import org.scalatest.FunSuite import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, SpecificMutableRow} import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ -class RowSuite extends FunSuite { +class RowSuite extends SparkFunSuite { test("create row") { val expected = new GenericMutableRow(4) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala index bf73d0c7074a5..3a5f071e2f7cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.sql -import org.scalatest.FunSuiteLike - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.test._ /* Implicits */ import TestSQLContext._ -class SQLConfSuite extends QueryTest with FunSuiteLike { +class SQLConfSuite extends QueryTest { val testKey = "test.key.0" val testVal = "test.val.0" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala index f186bc1c18123..797d123b48668 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala @@ -17,11 +17,12 @@ package org.apache.spark.sql -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.BeforeAndAfterAll +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.test.TestSQLContext -class SQLContextSuite extends FunSuite with BeforeAndAfterAll { +class SQLContextSuite extends SparkFunSuite with BeforeAndAfterAll { private val testSqlContext = TestSQLContext private val testSparkContext = TestSQLContext.sparkContext diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index bf18bf854aa4a..63f7d314fb699 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.execution.GeneratedAggregate import org.apache.spark.sql.functions._ import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.{SQLTestUtils, TestSQLContext} import org.apache.spark.sql.test.TestSQLContext.{udf => _, _} import org.apache.spark.sql.types._ @@ -32,12 +32,12 @@ import org.apache.spark.sql.types._ /** A SQL Dialect for testing purpose, and it can not be nested type */ class MyDialect extends DefaultParserDialect -class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { +class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { // Make sure the tables are loaded. TestData - import org.apache.spark.sql.test.TestSQLContext.implicits._ - val sqlCtx = TestSQLContext + val sqlContext = TestSQLContext + import sqlContext.implicits._ test("SPARK-6743: no columns from cache") { Seq( @@ -915,7 +915,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { Row(values(0).toInt, values(1), values(2).toBoolean, v4) } - val df1 = sqlCtx.createDataFrame(rowRDD1, schema1) + val df1 = createDataFrame(rowRDD1, schema1) df1.registerTempTable("applySchema1") checkAnswer( sql("SELECT * FROM applySchema1"), @@ -945,7 +945,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4)) } - val df2 = sqlCtx.createDataFrame(rowRDD2, schema2) + val df2 = createDataFrame(rowRDD2, schema2) df2.registerTempTable("applySchema2") checkAnswer( sql("SELECT * FROM applySchema2"), @@ -970,7 +970,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { Row(Row(values(0).toInt, values(2).toBoolean), scala.collection.mutable.Map(values(1) -> v4)) } - val df3 = sqlCtx.createDataFrame(rowRDD3, schema2) + val df3 = createDataFrame(rowRDD3, schema2) df3.registerTempTable("applySchema3") checkAnswer( @@ -1015,7 +1015,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { .build() val schemaWithMeta = new StructType(Array( schema("id"), schema("name").copy(metadata = metadata), schema("age"))) - val personWithMeta = sqlCtx.createDataFrame(person.rdd, schemaWithMeta) + val personWithMeta = createDataFrame(person.rdd, schemaWithMeta) def validateMetadata(rdd: DataFrame): Unit = { assert(rdd.schema("name").metadata.getString(docKey) == docValue) } @@ -1331,4 +1331,24 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { checkAnswer(sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, `q.w`.`w.i&`[0] FROM t"), Row(1, 1, 1)) } + + test("SPARK-7952: fix the equality check between boolean and numeric types") { + withTempTable("t") { + // numeric field i, boolean field j, result of i = j, result of i <=> j + Seq[(Integer, java.lang.Boolean, java.lang.Boolean, java.lang.Boolean)]( + (1, true, true, true), + (0, false, true, true), + (2, true, false, false), + (2, false, false, false), + (null, true, null, false), + (null, false, null, false), + (0, null, null, false), + (1, null, null, false), + (null, null, null, true) + ).toDF("i", "b", "r1", "r2").registerTempTable("t") + + checkAnswer(sql("select i = b from t"), sql("select r1 from t")) + checkAnswer(sql("select i <=> b from t"), sql("select r2 from t")) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index 52d265b445e14..d2ede39f0a5f6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -19,8 +19,7 @@ package org.apache.spark.sql import java.sql.{Date, Timestamp} -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.test.TestSQLContext._ @@ -74,7 +73,7 @@ case class ComplexReflectData( mapFieldContainsNull: Map[Int, Option[Long]], dataField: Data) -class ScalaReflectionRelationSuite extends FunSuite { +class ScalaReflectionRelationSuite extends SparkFunSuite { import org.apache.spark.sql.test.TestSQLContext.implicits._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala index 6f6d3c9c243d4..1e8cde606b67b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala @@ -17,13 +17,11 @@ package org.apache.spark.sql -import org.scalatest.FunSuite - -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer import org.apache.spark.sql.test.TestSQLContext -class SerializationSuite extends FunSuite { +class SerializationSuite extends SparkFunSuite { test("[SPARK-5235] SQLContext should be serializable") { val sqlContext = new SQLContext(TestSQLContext.sparkContext) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala index 7cefcf44061ce..339e719f39f16 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.sql.columnar -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.types._ -class ColumnStatsSuite extends FunSuite { +class ColumnStatsSuite extends SparkFunSuite { testColumnStats(classOf[ByteColumnStats], BYTE, Row(Byte.MaxValue, Byte.MinValue, 0)) testColumnStats(classOf[ShortColumnStats], SHORT, Row(Short.MaxValue, Short.MinValue, 0)) testColumnStats(classOf[IntColumnStats], INT, Row(Int.MaxValue, Int.MinValue, 0)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala index 061efb37a0ac3..a1e76eaa982cc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala @@ -23,15 +23,14 @@ import java.sql.Timestamp import com.esotericsoftware.kryo.{Serializer, Kryo} import com.esotericsoftware.kryo.io.{Input, Output} import org.apache.spark.serializer.KryoRegistrator -import org.scalatest.FunSuite -import org.apache.spark.{SparkConf, Logging} +import org.apache.spark.{Logging, SparkConf, SparkFunSuite} import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.columnar.ColumnarTestUtils._ import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.sql.types._ -class ColumnTypeSuite extends FunSuite with Logging { +class ColumnTypeSuite extends SparkFunSuite with Logging { val DEFAULT_BUFFER_SIZE = 512 test("defaultSize") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala index a0702144f942c..2a6e0c376551a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala @@ -19,8 +19,7 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.types.DataType @@ -39,7 +38,7 @@ object TestNullableColumnAccessor { } } -class NullableColumnAccessorSuite extends FunSuite { +class NullableColumnAccessorSuite extends SparkFunSuite { import ColumnarTestUtils._ Seq( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala index 3a5605d2335d7..cb4e9f1eb7f46 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.columnar -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.sql.types._ @@ -35,7 +34,7 @@ object TestNullableColumnBuilder { } } -class NullableColumnBuilderSuite extends FunSuite { +class NullableColumnBuilderSuite extends SparkFunSuite { import ColumnarTestUtils._ Seq( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index 2a0b701cad7fa..cda1b0992e36f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -17,13 +17,14 @@ package org.apache.spark.sql.columnar -import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} +import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.test.TestSQLContext.implicits._ -class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfter { +class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter { val originalColumnBatchSize = conf.columnBatchSize val originalInMemoryPartitionPruning = conf.inMemoryPartitionPruning diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala index 8b518f094174c..20d65a74e3b7a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.sql.columnar.compression -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.columnar.{NoopColumnStats, BOOLEAN} import org.apache.spark.sql.columnar.ColumnarTestUtils._ -class BooleanBitSetSuite extends FunSuite { +class BooleanBitSetSuite extends SparkFunSuite { import BooleanBitSet._ def skeleton(count: Int) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala index cef60ec204faa..acfab6586c0d1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala @@ -19,14 +19,13 @@ package org.apache.spark.sql.columnar.compression import java.nio.ByteBuffer -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.columnar._ import org.apache.spark.sql.columnar.ColumnarTestUtils._ import org.apache.spark.sql.types.AtomicType -class DictionaryEncodingSuite extends FunSuite { +class DictionaryEncodingSuite extends SparkFunSuite { testDictionaryEncoding(new IntColumnStats, INT) testDictionaryEncoding(new LongColumnStats, LONG) testDictionaryEncoding(new StringColumnStats, STRING) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala index 5514590541dd6..2111e9fbe62cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.sql.columnar.compression -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.columnar._ import org.apache.spark.sql.columnar.ColumnarTestUtils._ import org.apache.spark.sql.types.IntegralType -class IntegralDeltaSuite extends FunSuite { +class IntegralDeltaSuite extends SparkFunSuite { testIntegralDelta(new IntColumnStats, INT, IntDelta) testIntegralDelta(new LongColumnStats, LONG, LongDelta) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala index 6ee48f6291914..67ec08f594a43 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.sql.columnar.compression -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.columnar._ import org.apache.spark.sql.columnar.ColumnarTestUtils._ import org.apache.spark.sql.types.AtomicType -class RunLengthEncodingSuite extends FunSuite { +class RunLengthEncodingSuite extends SparkFunSuite { testRunLengthEncoding(new NoopColumnStats, BOOLEAN) testRunLengthEncoding(new ByteColumnStats, BYTE) testRunLengthEncoding(new ShortColumnStats, SHORT) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 523be56df65ba..45a7e8fe68f72 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.execution -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{SQLConf, execution} import org.apache.spark.sql.functions._ import org.apache.spark.sql.TestData._ @@ -31,7 +30,7 @@ import org.apache.spark.sql.test.TestSQLContext.planner._ import org.apache.spark.sql.types._ -class PlannerSuite extends FunSuite { +class PlannerSuite extends SparkFunSuite { test("unions are collapsed") { val query = testData.unionAll(testData).unionAll(testData).logicalPlan val planned = BasicOperators(query).head diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala index 15337c4045436..6ca5390cde23e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala @@ -19,17 +19,17 @@ package org.apache.spark.sql.execution import java.sql.{Timestamp, Date} -import org.scalatest.{FunSuite, BeforeAndAfterAll} +import org.scalatest.BeforeAndAfterAll import org.apache.spark.rdd.ShuffledRDD import org.apache.spark.serializer.Serializer -import org.apache.spark.ShuffleDependency +import org.apache.spark.{ShuffleDependency, SparkFunSuite} import org.apache.spark.sql.types._ import org.apache.spark.sql.Row import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.{MyDenseVectorUDT, QueryTest} -class SparkSqlSerializer2DataTypeSuite extends FunSuite { +class SparkSqlSerializer2DataTypeSuite extends SparkFunSuite { // Make sure that we will not use serializer2 for unsupported data types. def checkSupported(dataType: DataType, isSupported: Boolean): Unit = { val testName = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala index 358d8cf06e463..8ec3985e00360 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.sql.execution.debug -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.TestData._ import org.apache.spark.sql.test.TestSQLContext._ -class DebuggingSuite extends FunSuite { +class DebuggingSuite extends SparkFunSuite { test("DataFrame.debug()") { testData.debug() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 2aad01ded1acf..5290c28cfca02 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.sql.execution.joins -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.{Projection, Row} import org.apache.spark.util.collection.CompactBuffer -class HashedRelationSuite extends FunSuite { +class HashedRelationSuite extends SparkFunSuite { // Key is simply the record itself private val keyProjection = new Projection { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 30279f528944b..e20c66cb2f1d7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -21,14 +21,15 @@ import java.math.BigDecimal import java.sql.DriverManager import java.util.{Calendar, GregorianCalendar, Properties} +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.test._ import org.apache.spark.sql.types._ import org.h2.jdbc.JdbcSQLException -import org.scalatest.{FunSuite, BeforeAndAfter} +import org.scalatest.BeforeAndAfter import TestSQLContext._ import TestSQLContext.implicits._ -class JDBCSuite extends FunSuite with BeforeAndAfter { +class JDBCSuite extends SparkFunSuite with BeforeAndAfter { val url = "jdbc:h2:mem:testdb0" val urlWithUserAndPass = "jdbc:h2:mem:testdb0;user=testUser;password=testPass" var conn: java.sql.Connection = null @@ -67,7 +68,7 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { |USING org.apache.spark.sql.jdbc |OPTIONS (url '$url', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) - + sql( s""" |CREATE TEMPORARY TABLE fetchtwo @@ -75,7 +76,7 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { |OPTIONS (url '$url', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass', | fetchSize '2') """.stripMargin.replaceAll("\n", " ")) - + sql( s""" |CREATE TEMPORARY TABLE parts @@ -208,7 +209,7 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { assert(ids(1) === 2) assert(ids(2) === 3) } - + test("SELECT second field when fetchSize is two") { val ids = sql("SELECT THEID FROM fetchtwo").collect().map(x => x.getInt(0)).sortWith(_ < _) assert(ids.size === 3) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index 2e4c12f9da80c..2de8c1a6098e0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -20,13 +20,14 @@ package org.apache.spark.sql.jdbc import java.sql.DriverManager import java.util.Properties -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{SaveMode, Row} import org.apache.spark.sql.test._ import org.apache.spark.sql.types._ -class JDBCWriteSuite extends FunSuite with BeforeAndAfter { +class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter { val url = "jdbc:h2:mem:testdb2" var conn: java.sql.Connection = null val url1 = "jdbc:h2:mem:testdb3" @@ -35,12 +36,12 @@ class JDBCWriteSuite extends FunSuite with BeforeAndAfter { properties.setProperty("user", "testUser") properties.setProperty("password", "testPass") properties.setProperty("rowId", "false") - + before { Class.forName("org.h2.Driver") conn = DriverManager.getConnection(url) conn.prepareStatement("create schema test").executeUpdate() - + conn1 = DriverManager.getConnection(url1, properties) conn1.prepareStatement("create schema test").executeUpdate() conn1.prepareStatement("drop table if exists test.people").executeUpdate() @@ -52,20 +53,20 @@ class JDBCWriteSuite extends FunSuite with BeforeAndAfter { conn1.prepareStatement( "create table test.people1 (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)").executeUpdate() conn1.commit() - + TestSQLContext.sql( s""" |CREATE TEMPORARY TABLE PEOPLE |USING org.apache.spark.sql.jdbc |OPTIONS (url '$url1', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) - + TestSQLContext.sql( s""" |CREATE TEMPORARY TABLE PEOPLE1 |USING org.apache.spark.sql.jdbc |OPTIONS (url '$url1', dbtable 'TEST.PEOPLE1', user 'testUser', password 'testPass') - """.stripMargin.replaceAll("\n", " ")) + """.stripMargin.replaceAll("\n", " ")) } after { @@ -151,5 +152,5 @@ class JDBCWriteSuite extends FunSuite with BeforeAndAfter { TestSQLContext.sql("INSERT OVERWRITE TABLE PEOPLE1 SELECT * FROM PEOPLE") assert(2 == TestSQLContext.read.jdbc(url1, "TEST.PEOPLE1", properties).count) assert(2 == TestSQLContext.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) - } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala index f231589e9674d..3b29979452ad9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala @@ -18,10 +18,11 @@ package org.apache.spark.sql.parquet import java.io.File import java.math.BigInteger -import java.sql.{Timestamp, Date} +import java.sql.Timestamp import scala.collection.mutable.ArrayBuffer +import com.google.common.io.Files import org.apache.hadoop.fs.Path import org.apache.spark.sql.catalyst.expressions.Literal @@ -432,4 +433,20 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { checkAnswer(read.load(dir.toString).select(fields: _*), row) } } + + test("SPARK-8037: Ignores files whose name starts with dot") { + withTempPath { dir => + val df = (1 to 3).map(i => (i, i, i, i)).toDF("a", "b", "c", "d") + + df.write + .format("parquet") + .partitionBy("b", "c", "d") + .save(dir.getCanonicalPath) + + Files.touch(new File(s"${dir.getCanonicalPath}/b=1", ".DS_Store")) + Files.createParentDirs(new File(s"${dir.getCanonicalPath}/b=1/c=1/.foo/bar")) + + checkAnswer(read.format("parquet").load(dir.getCanonicalPath), df) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index b98ba09ccfc2d..304936fb2be8e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.parquet import org.scalatest.BeforeAndAfterAll +import org.apache.spark.sql.types._ import org.apache.spark.sql.{SQLConf, QueryTest} import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.test.TestSQLContext @@ -111,6 +112,18 @@ class ParquetQuerySuiteBase extends QueryTest with ParquetTest { List(Row("same", "run_5", 100))) } } + + test("SPARK-6917 DecimalType should work with non-native types") { + val data = (1 to 10).map(i => Row(Decimal(i, 18, 0), new java.sql.Timestamp(i))) + val schema = StructType(List(StructField("d", DecimalType(18, 0), false), + StructField("time", TimestampType, false)).toArray) + withTempPath { file => + val df = sqlContext.createDataFrame(sparkContext.parallelize(data), schema) + df.write.parquet(file.getCanonicalPath) + val df2 = sqlContext.read.parquet(file.getCanonicalPath) + checkAnswer(df2, df.collect().toSeq) + } + } } class ParquetDataSourceOnQuerySuite extends ParquetQuerySuiteBase with BeforeAndAfterAll { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala index c964b6d984557..caec2a6f25489 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala @@ -20,14 +20,14 @@ package org.apache.spark.sql.parquet import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag -import org.scalatest.FunSuite import parquet.schema.MessageTypeParser +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.types._ -class ParquetSchemaSuite extends FunSuite with ParquetTest { +class ParquetSchemaSuite extends SparkFunSuite with ParquetTest { val sqlContext = TestSQLContext /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala index 8331a14c9295c..296b0d6f74a0c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.sources -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class ResolvedDataSourceSuite extends FunSuite { +class ResolvedDataSourceSuite extends SparkFunSuite { test("builtin sources") { assert(ResolvedDataSource.lookupDataSource("jdbc") === diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index 437f697d25bf3..20d3c7d4c5959 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -41,6 +41,13 @@ spark-hive_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + com.google.guava guava diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index 3458b04bfba0f..94687eeda4179 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -17,23 +17,23 @@ package org.apache.spark.sql.hive.thriftserver +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + import org.apache.commons.logging.LogFactory import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hive.service.cli.thrift.{ThriftBinaryCLIService, ThriftHttpCLIService} import org.apache.hive.service.server.{HiveServer2, ServerOptionsProcessor} -import org.apache.spark.sql.SQLConf -import org.apache.spark.{SparkContext, SparkConf, Logging} import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd, SparkListenerJobStart} +import org.apache.spark.sql.SQLConf import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ -import org.apache.spark.scheduler.{SparkListenerJobStart, SparkListenerApplicationEnd, SparkListener} import org.apache.spark.sql.hive.thriftserver.ui.ThriftServerTab +import org.apache.spark.sql.hive.{HiveContext, HiveShim} import org.apache.spark.util.Utils - -import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer +import org.apache.spark.{Logging, SparkContext} /** * The main entry point for the Spark SQL port of HiveServer2. Starts up a `SparkSQLContext` and a @@ -51,6 +51,7 @@ object HiveThriftServer2 extends Logging { @DeveloperApi def startWithContext(sqlContext: HiveContext): Unit = { val server = new HiveThriftServer2(sqlContext) + sqlContext.setConf("spark.sql.hive.version", HiveShim.version) server.init(sqlContext.hiveconf) server.start() listener = new HiveThriftServer2Listener(server, sqlContext.conf) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index cc07db827d359..3732af7870b93 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -25,16 +25,16 @@ import scala.concurrent.{Await, Promise} import scala.sys.process.{Process, ProcessLogger} import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter -import org.apache.spark.Logging +import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.util.Utils /** * A test suite for the `spark-sql` CLI tool. Note that all test cases share the same temporary * Hive metastore and warehouse. */ -class CliSuite extends FunSuite with BeforeAndAfter with Logging { +class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { val warehousePath = Utils.createTempDir() val metastorePath = Utils.createTempDir() diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index 1fadea97fd07f..a93a3dee43511 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -27,6 +27,8 @@ import scala.concurrent.{Await, Promise} import scala.sys.process.{Process, ProcessLogger} import scala.util.{Random, Try} +import com.google.common.base.Charsets.UTF_8 +import com.google.common.io.Files import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hive.jdbc.HiveDriver import org.apache.hive.service.auth.PlainSaslHelper @@ -35,9 +37,9 @@ import org.apache.hive.service.cli.thrift.TCLIService.Client import org.apache.hive.service.cli.thrift.ThriftCLIServiceClient import org.apache.thrift.protocol.TBinaryProtocol import org.apache.thrift.transport.TSocket -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.BeforeAndAfterAll -import org.apache.spark.Logging +import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.sql.hive.HiveShim import org.apache.spark.util.Utils @@ -54,7 +56,7 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { override def mode: ServerMode.Value = ServerMode.binary private def withCLIServiceClient(f: ThriftCLIServiceClient => Unit): Unit = { - // Transport creation logics below mimics HiveConnection.createBinaryTransport + // Transport creation logic below mimics HiveConnection.createBinaryTransport val rawTransport = new TSocket("localhost", serverPort) val user = System.getProperty("user.name") val transport = PlainSaslHelper.getPlainTransport(user, "anonymous", rawTransport) @@ -391,10 +393,10 @@ abstract class HiveThriftJdbcTest extends HiveThriftServer2Test { val statements = connections.map(_.createStatement()) try { - statements.zip(fs).map { case (s, f) => f(s) } + statements.zip(fs).foreach { case (s, f) => f(s) } } finally { - statements.map(_.close()) - connections.map(_.close()) + statements.foreach(_.close()) + connections.foreach(_.close()) } } @@ -403,7 +405,7 @@ abstract class HiveThriftJdbcTest extends HiveThriftServer2Test { } } -abstract class HiveThriftServer2Test extends FunSuite with BeforeAndAfterAll with Logging { +abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAll with Logging { def mode: ServerMode.Value private val CLASS_NAME = HiveThriftServer2.getClass.getCanonicalName.stripSuffix("$") @@ -433,15 +435,33 @@ abstract class HiveThriftServer2Test extends FunSuite with BeforeAndAfterAll wit ConfVars.HIVE_SERVER2_THRIFT_HTTP_PORT } + val driverClassPath = { + // Writes a temporary log4j.properties and prepend it to driver classpath, so that it + // overrides all other potential log4j configurations contained in other dependency jar files. + val tempLog4jConf = Utils.createTempDir().getCanonicalPath + + Files.write( + """log4j.rootCategory=INFO, console + |log4j.appender.console=org.apache.log4j.ConsoleAppender + |log4j.appender.console.target=System.err + |log4j.appender.console.layout=org.apache.log4j.PatternLayout + |log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n + """.stripMargin, + new File(s"$tempLog4jConf/log4j.properties"), + UTF_8) + + tempLog4jConf + File.pathSeparator + sys.props("java.class.path") + } + s"""$startScript | --master local - | --hiveconf hive.root.logger=INFO,console | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$metastoreJdbcUri | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=localhost | --hiveconf ${ConfVars.HIVE_SERVER2_TRANSPORT_MODE}=$mode | --hiveconf $portConf=$port - | --driver-class-path ${sys.props("java.class.path")} + | --driver-class-path $driverClassPath + | --driver-java-options -Dlog4j.debug | --conf spark.ui.enabled=false """.stripMargin.split("\\s+").toSeq } diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 615b07e74d535..923ffabb9b99e 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -41,6 +41,13 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-sql_${scala.binary.version} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 47b85731587d5..ca1f49b546bd7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -596,7 +596,7 @@ private[hive] case class MetastoreRelation self: Product => - override def equals(other: scala.Any): Boolean = other match { + override def equals(other: Any): Boolean = other match { case relation: MetastoreRelation => databaseName == relation.databaseName && tableName == relation.tableName && diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 3915ee835685f..a5ca3613c5e00 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -57,7 +57,7 @@ private[hive] case object NativePlaceholder extends LogicalPlan { override def output: Seq[Attribute] = Seq.empty } -case class CreateTableAsSelect( +private[hive] case class CreateTableAsSelect( tableDesc: HiveTable, child: LogicalPlan, allowExisting: Boolean) extends UnaryNode with Command { @@ -1561,6 +1561,10 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C """.stripMargin) } + /* Case insensitive matches for Window Specification */ + val PRECEDING = "(?i)preceding".r + val FOLLOWING = "(?i)following".r + val CURRENT = "(?i)current".r def nodesToWindowSpecification(nodes: Seq[ASTNode]): WindowSpec = nodes match { case Token(windowName, Nil) :: Nil => // Refer to a window spec defined in the window clause. @@ -1614,11 +1618,19 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C } else { val frameType = rowFrame.map(_ => RowFrame).getOrElse(RangeFrame) def nodeToBoundary(node: Node): FrameBoundary = node match { - case Token("preceding", Token(count, Nil) :: Nil) => - if (count == "unbounded") UnboundedPreceding else ValuePreceding(count.toInt) - case Token("following", Token(count, Nil) :: Nil) => - if (count == "unbounded") UnboundedFollowing else ValueFollowing(count.toInt) - case Token("current", Nil) => CurrentRow + case Token(PRECEDING(), Token(count, Nil) :: Nil) => + if (count.toLowerCase() == "unbounded") { + UnboundedPreceding + } else { + ValuePreceding(count.toInt) + } + case Token(FOLLOWING(), Token(count, Nil) :: Nil) => + if (count.toLowerCase() == "unbounded") { + UnboundedFollowing + } else { + ValueFollowing(count.toInt) + } + case Token(CURRENT(), Nil) => CurrentRow case _ => throw new NotImplementedError( s"""No parse rules for the Window Frame Boundary based on Node ${node.getName} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index 196a3d836cab2..16851fdd71a98 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -90,14 +90,14 @@ private[hive] object IsolatedClientLoader { * `ClientInterface`, unless `isolationOn` is set to `false`. * * @param version The version of hive on the classpath. used to pick specific function signatures - * that are not compatibile accross versions. + * that are not compatible across versions. * @param execJars A collection of jar files that must include hive and hadoop. * @param config A set of options that will be added to the HiveConf of the constructed client. * @param isolationOn When true, custom versions of barrier classes will be constructed. Must be * true unless loading the version of hive that is on Sparks classloader. - * @param rootClassLoader The system root classloader. Must not know about hive classes. - * @param baseClassLoader The spark classloader that is used to load shared classes. - * + * @param rootClassLoader The system root classloader. + * @param baseClassLoader The spark classloader that is used to load shared classes. Must not know + * about Hive classes. */ private[hive] class IsolatedClientLoader( val version: HiveVersion, @@ -110,7 +110,7 @@ private[hive] class IsolatedClientLoader( val barrierPrefixes: Seq[String] = Seq.empty) extends Logging { - // Check to make sure that the root classloader does not know about Hive. + // Check to make sure that the base classloader does not know about Hive. assert(Try(baseClassLoader.loadClass("org.apache.hive.HiveConf")).isFailure) /** All jars used by the hive specific classloader. */ diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ReflectionMagic.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ReflectionMagic.scala index c600b158c5460..4d053ae42c2ea 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ReflectionMagic.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ReflectionMagic.scala @@ -30,7 +30,7 @@ private[client] object ReflectionException { /** * Provides implicit functions on any object for calling methods reflectively. */ -protected trait ReflectionMagic { +private[client] trait ReflectionMagic { /** code for InstanceMagic println( (1 to 22).map { n => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala index 7db9200d47440..410d9881ac214 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala @@ -29,5 +29,5 @@ package object client { case object v13 extends HiveVersion("0.13.1", false) } // scalastyle:on - + } \ No newline at end of file diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala index 62dc4167b78dd..11ee5503146b9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala @@ -63,7 +63,7 @@ case class HiveTableScan( BindReferences.bindReference(pred, relation.partitionKeys) } - // Create a local copy of hiveconf,so that scan specific modifications should not impact + // Create a local copy of hiveconf,so that scan specific modifications should not impact // other queries @transient private[this] val hiveExtraConf = new HiveConf(context.hiveconf) @@ -72,7 +72,7 @@ case class HiveTableScan( addColumnMetadataToConf(hiveExtraConf) @transient - private[this] val hadoopReader = + private[this] val hadoopReader = new HadoopTableReader(attributes, relation, context, hiveExtraConf) private[this] def castFromString(value: String, dataType: DataType) = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index 6f27a8626fc1e..fd623370cc407 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -62,7 +62,7 @@ case class ScriptTransformation( val inputStream = proc.getInputStream val outputStream = proc.getOutputStream val reader = new BufferedReader(new InputStreamReader(inputStream)) - + val (outputSerde, outputSoi) = ioschema.initOutputSerDe(output) val iterator: Iterator[Row] = new Iterator[Row] with HiveInspectors { @@ -95,7 +95,7 @@ case class ScriptTransformation( val raw = outputSerde.deserialize(writable) val dataList = outputSoi.getStructFieldsDataAsList(raw) val fieldList = outputSoi.getAllStructFieldRefs() - + var i = 0 dataList.foreach( element => { if (element == null) { @@ -117,7 +117,7 @@ case class ScriptTransformation( if (!hasNext) { throw new NoSuchElementException } - + if (outputSerde == null) { val prevLine = curLine curLine = reader.readLine() @@ -192,7 +192,7 @@ case class HiveScriptIOSchema ( val inputRowFormatMap = inputRowFormat.toMap.withDefault((k) => defaultFormat(k)) val outputRowFormatMap = outputRowFormat.toMap.withDefault((k) => defaultFormat(k)) - + def initInputSerDe(input: Seq[Expression]): (AbstractSerDe, ObjectInspector) = { val (columns, columnTypes) = parseAttrs(input) val serde = initSerDe(inputSerdeClass, columns, columnTypes, inputSerdeProps) @@ -206,13 +206,13 @@ case class HiveScriptIOSchema ( } def parseAttrs(attrs: Seq[Expression]): (Seq[String], Seq[DataType]) = { - + val columns = attrs.map { case aref: AttributeReference => aref.name case e: NamedExpression => e.name case _ => null } - + val columnTypes = attrs.map { case aref: AttributeReference => aref.dataType case e: NamedExpression => e.dataType @@ -221,7 +221,7 @@ case class HiveScriptIOSchema ( (columns, columnTypes) } - + def initSerDe(serdeClassName: String, columns: Seq[String], columnTypes: Seq[DataType], serdeProps: Seq[(String, String)]): AbstractSerDe = { @@ -240,7 +240,7 @@ case class HiveScriptIOSchema ( (kv._1.split("'")(1), kv._2.split("'")(1)) }).toMap + (serdeConstants.LIST_COLUMNS -> columns.mkString(",")) propsMap = propsMap + (serdeConstants.LIST_COLUMN_TYPES -> columnTypesNames) - + val properties = new Properties() properties.putAll(propsMap) serde.initialize(null, properties) @@ -261,7 +261,7 @@ case class HiveScriptIOSchema ( null } } - + def initOutputputSoi(outputSerde: AbstractSerDe): StructObjectInspector = { if (outputSerde != null) { outputSerde.getObjectInspector().asInstanceOf[StructObjectInspector] diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index bb116e3ab7de7..1658bb93b0b79 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -78,6 +78,8 @@ private[hive] case class HiveSimpleUdf(funcWrapper: HiveFunctionWrapper, childre type UDFType = UDF + override def deterministic: Boolean = isUDFDeterministic + override def nullable: Boolean = true @transient @@ -140,6 +142,8 @@ private[hive] case class HiveGenericUdf(funcWrapper: HiveFunctionWrapper, childr extends Expression with HiveInspectors with Logging { type UDFType = GenericUDF + override def deterministic: Boolean = isUDFDeterministic + override def nullable: Boolean = true @transient @@ -555,12 +559,12 @@ private[hive] case class HiveUdafFunction( } else { funcWrapper.createFunction[AbstractGenericUDAFResolver]() } - + private val inspectors = exprs.map(toInspector).toArray - - private val function = { + + private val function = { val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors, false, false) - resolver.getEvaluator(parameterInfo) + resolver.getEvaluator(parameterInfo) } private val returnInspector = function.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors) @@ -575,7 +579,7 @@ private[hive] case class HiveUdafFunction( @transient protected lazy val cached = new Array[AnyRef](exprs.length) - + def update(input: Row): Unit = { val inputs = inputProjection(input) function.iterate(buffer, wrap(inputs, inspectors, cached)) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 945596db80326..39d315aaeab57 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -57,7 +57,7 @@ class CachedTableSuite extends QueryTest { checkAnswer( sql("SELECT * FROM src s"), preCacheResults) - + uncacheTable("src") assertCached(sql("SELECT * FROM src"), 0) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala index 80c2d32bf70d7..df137e7b2b333 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala @@ -26,12 +26,12 @@ import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectIns import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory import org.apache.hadoop.io.LongWritable -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.{Literal, Row} import org.apache.spark.sql.types._ -class HiveInspectorSuite extends FunSuite with HiveInspectors { +class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { test("Test wrap SettableStructObjectInspector") { val udaf = new UDAFPercentile.PercentileLongEvaluator() udaf.init() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index fa8e11ffec2b4..e9bb32667936c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -17,13 +17,13 @@ package org.apache.spark.sql.hive +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.hive.test.TestHive -import org.scalatest.FunSuite import org.apache.spark.sql.test.ExamplePointUDT import org.apache.spark.sql.types.StructType -class HiveMetastoreCatalogSuite extends FunSuite { +class HiveMetastoreCatalogSuite extends SparkFunSuite { test("struct field should accept underscore in sub-column name") { val metastr = "struct" diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala index 941a2941649b8..f765395e148af 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala @@ -20,12 +20,13 @@ package org.apache.spark.sql.hive import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.serde.serdeConstants +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.hive.client.{ManagedTable, HiveColumn, ExternalTable, HiveTable} -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.BeforeAndAfterAll -class HiveQlSuite extends FunSuite with BeforeAndAfterAll { +class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll { override def beforeAll() { if (SessionState.get() == null) { SessionState.start(new HiveConf()) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index 9cc4685499f19..aa5dbe2db6903 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -240,7 +240,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { checkAnswer(sql("select key,value from table_with_partition where ds='1' "), testData.collect().toSeq ) - + // test difference type of field sql("ALTER TABLE table_with_partition CHANGE COLUMN key key BIGINT") checkAnswer(sql("select key,value from table_with_partition where ds='1' "), diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala index 8afe5459d4f1b..a492ecf203d17 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala @@ -17,13 +17,11 @@ package org.apache.spark.sql.hive -import org.scalatest.FunSuite - -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer import org.apache.spark.sql.hive.test.TestHive -class SerializationSuite extends FunSuite { +class SerializationSuite extends SparkFunSuite { test("[SPARK-5840] HiveContext should be serializable") { val hiveContext = TestHive diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 321dc8d7322b8..7eb4842726665 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -17,18 +17,17 @@ package org.apache.spark.sql.hive.client -import org.apache.spark.Logging +import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.util.Utils -import org.scalatest.FunSuite /** - * A simple set of tests that call the methods of a hive ClientInterface, loading different version - * of hive from maven central. These tests are simple in that they are mostly just testing to make - * sure that reflective calls are not throwing NoSuchMethod error, but the actually functionallity + * A simple set of tests that call the methods of a hive ClientInterface, loading different version + * of hive from maven central. These tests are simple in that they are mostly just testing to make + * sure that reflective calls are not throwing NoSuchMethod error, but the actually functionality * is not fully tested. */ -class VersionsSuite extends FunSuite with Logging { +class VersionsSuite extends SparkFunSuite with Logging { private def buildConf() = { lazy val warehousePath = Utils.createTempDir() lazy val metastorePath = Utils.createTempDir() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala index 23ece7e7cf6e9..b0d3dd44daedc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.sql.hive.test.TestHiveContext -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.BeforeAndAfterAll -class ConcurrentHiveSuite extends FunSuite with BeforeAndAfterAll { +class ConcurrentHiveSuite extends SparkFunSuite with BeforeAndAfterAll { ignore("multiple instances not supported") { test("Multiple Hive Instances") { (1 to 10).map { i => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 55e5551b63818..c9dd4c0935a72 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -19,9 +19,9 @@ package org.apache.spark.sql.hive.execution import java.io._ -import org.scalatest.{BeforeAndAfterAll, FunSuite, GivenWhenThen} +import org.scalatest.{BeforeAndAfterAll, GivenWhenThen} -import org.apache.spark.Logging +import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.sql.sources.DescribeCommand import org.apache.spark.sql.execution.{SetCommand, ExplainCommand} import org.apache.spark.sql.catalyst.planning.PhysicalOperation @@ -40,7 +40,7 @@ import org.apache.spark.sql.hive.test.TestHive * configured using system properties. */ abstract class HiveComparisonTest - extends FunSuite with BeforeAndAfterAll with GivenWhenThen with Logging { + extends SparkFunSuite with BeforeAndAfterAll with GivenWhenThen with Logging { /** * When set, any cache files that result in test failures will be deleted. Used when the test diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 4af31d482ce42..440b7c87b0da2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -57,7 +57,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { // https://cwiki.apache.org/confluence/display/Hive/DeveloperGuide+UDTF sql( """ - |CREATE TEMPORARY FUNCTION udtf_count2 + |CREATE TEMPORARY FUNCTION udtf_count2 |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2' """.stripMargin) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala index 0ba4d11478211..2209fc2f30a3c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala @@ -61,7 +61,7 @@ class HiveTableScanSuite extends HiveComparisonTest { TestHive.sql("select KEY from tb where VALUE='just_for_test' limit 5").collect() TestHive.sql("drop table tb") } - + test("Spark-4077: timestamp query for null value") { TestHive.sql("DROP TABLE IF EXISTS timestamp_query_null") TestHive.sql( @@ -71,11 +71,11 @@ class HiveTableScanSuite extends HiveComparisonTest { FIELDS TERMINATED BY ',' LINES TERMINATED BY '\n' """.stripMargin) - val location = + val location = Utils.getSparkClassLoader.getResource("data/files/issue-4077-data.txt").getFile() - + TestHive.sql(s"LOAD DATA LOCAL INPATH '$location' INTO TABLE timestamp_query_null") - assert(TestHive.sql("SELECT time from timestamp_query_null limit 2").collect() + assert(TestHive.sql("SELECT time from timestamp_query_null limit 2").collect() === Array(Row(java.sql.Timestamp.valueOf("2014-12-11 00:00:00")), Row(null))) TestHive.sql("DROP TABLE timestamp_query_null") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala index 7f49eac490572..ce5985888f540 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala @@ -101,7 +101,7 @@ class HiveUdfSuite extends QueryTest { sql("DROP TEMPORARY FUNCTION IF EXISTS test_avg") TestHive.reset() } - + test("SPARK-2693 udaf aggregates test") { checkAnswer(sql("SELECT percentile(key, 1) FROM src LIMIT 1"), sql("SELECT max(key) FROM src").collect().toSeq) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 27863a60145d7..aba3becb1bce2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -780,6 +780,42 @@ class SQLQuerySuite extends QueryTest { ).map(i => Row(i._1, i._2, i._3, i._4))) } + test("window function: multiple window expressions in a single expression") { + val nums = sparkContext.parallelize(1 to 10).map(x => (x, x % 2)).toDF("x", "y") + nums.registerTempTable("nums") + + val expected = + Row(1, 1, 1, 55, 1, 57) :: + Row(0, 2, 3, 55, 2, 60) :: + Row(1, 3, 6, 55, 4, 65) :: + Row(0, 4, 10, 55, 6, 71) :: + Row(1, 5, 15, 55, 9, 79) :: + Row(0, 6, 21, 55, 12, 88) :: + Row(1, 7, 28, 55, 16, 99) :: + Row(0, 8, 36, 55, 20, 111) :: + Row(1, 9, 45, 55, 25, 125) :: + Row(0, 10, 55, 55, 30, 140) :: Nil + + val actual = sql( + """ + |SELECT + | y, + | x, + | sum(x) OVER w1 AS running_sum, + | sum(x) OVER w2 AS total_sum, + | sum(x) OVER w3 AS running_sum_per_y, + | ((sum(x) OVER w1) + (sum(x) OVER w2) + (sum(x) OVER w3)) as combined2 + |FROM nums + |WINDOW w1 AS (ORDER BY x ROWS BETWEEN UnBOUNDED PRECEDiNG AND CuRRENT RoW), + | w2 AS (ORDER BY x ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOuNDED FoLLOWING), + | w3 AS (PARTITION BY y ORDER BY x ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) + """.stripMargin) + + checkAnswer(actual, expected) + + dropTempTable("nums") + } + test("test case key when") { (1 to 5).map(i => (i, i.toString)).toDF("k", "v").registerTempTable("t") checkAnswer( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala index 88c99e35260d9..0e63d84e9824a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala @@ -19,13 +19,14 @@ package org.apache.spark.sql.hive.orc import java.io.File import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ import org.apache.spark.util.Utils -import org.scalatest.{BeforeAndAfterAll, FunSuiteLike} +import org.scalatest.BeforeAndAfterAll import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag @@ -38,7 +39,7 @@ case class OrcParData(intField: Int, stringField: String) case class OrcParDataWithKey(intField: Int, pi: Int, stringField: String, ps: String) // TODO This test suite duplicates ParquetPartitionDiscoverySuite a lot -class OrcPartitionDiscoverySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll { +class OrcPartitionDiscoverySuite extends QueryTest with BeforeAndAfterAll { val defaultPartitionName = ConfVars.DEFAULTPARTITIONNAME.defaultVal def withTempDir(f: File => Unit): Unit = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index cdd6e705f4a2c..57c23fe77f8b5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -21,8 +21,9 @@ import java.io.File import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.ql.io.orc.CompressionKind -import org.scalatest.{BeforeAndAfterAll, FunSuiteLike} +import org.scalatest.BeforeAndAfterAll +import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.hive.test.TestHive @@ -50,7 +51,7 @@ case class Contact(name: String, phone: String) case class Person(name: String, age: Int, contacts: Seq[Contact]) -class OrcQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll with OrcTest { +class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { override val sqlContext = TestHive import TestHive.read diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index cf5ae88dc4bee..74095426741e3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -17,10 +17,12 @@ package org.apache.spark.sql.sources +import java.io.File + +import com.google.common.io.Files import org.apache.hadoop.fs.Path -import org.scalatest.FunSuite -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql._ import org.apache.spark.sql.hive.test.TestHive @@ -454,6 +456,20 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { } } } + + test("SPARK-7616: adjust column name order accordingly when saving partitioned table") { + val df = (1 to 3).map(i => (i, s"val_$i", i * 2)).toDF("a", "b", "c") + + df.write + .format(dataSourceName) + .mode(SaveMode.Overwrite) + .partitionBy("c", "a") + .saveAsTable("t") + + withTable("t") { + checkAnswer(table("t"), df.select('b, 'c, 'a).collect()) + } + } } class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest { @@ -485,7 +501,7 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest { } } -class CommitFailureTestRelationSuite extends FunSuite with SQLTestUtils { +class CommitFailureTestRelationSuite extends SparkFunSuite with SQLTestUtils { import TestHive.implicits._ override val sqlContext = TestHive @@ -535,20 +551,6 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { } } - test("SPARK-7616: adjust column name order accordingly when saving partitioned table") { - val df = (1 to 3).map(i => (i, s"val_$i", i * 2)).toDF("a", "b", "c") - - df.write - .format("parquet") - .mode(SaveMode.Overwrite) - .partitionBy("c", "a") - .saveAsTable("t") - - withTable("t") { - checkAnswer(table("t"), df.select('b, 'c, 'a).collect()) - } - } - test("SPARK-7868: _temporary directories should be ignored") { withTempPath { dir => val df = Seq("a", "b", "c").zipWithIndex.toDF() @@ -564,4 +566,32 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { checkAnswer(read.format("parquet").load(dir.getCanonicalPath), df.collect()) } } + + test("SPARK-8014: Avoid scanning output directory when SaveMode isn't SaveMode.Append") { + withTempDir { dir => + val path = dir.getCanonicalPath + val df = Seq(1 -> "a").toDF() + + // Creates an arbitrary file. If this directory gets scanned, ParquetRelation2 will throw + // since it's not a valid Parquet file. + val emptyFile = new File(path, "empty") + Files.createParentDirs(emptyFile) + Files.touch(emptyFile) + + // This shouldn't throw anything. + df.write.format("parquet").mode(SaveMode.Ignore).save(path) + + // This should only complain that the destination directory already exists, rather than file + // "empty" is not a Parquet file. + assert { + intercept[RuntimeException] { + df.write.format("parquet").mode(SaveMode.ErrorIfExists).save(path) + }.getMessage.contains("already exists") + } + + // This shouldn't throw anything. + df.write.format("parquet").mode(SaveMode.Overwrite).save(path) + checkAnswer(read.format("parquet").load(path), df) + } + } } diff --git a/streaming/pom.xml b/streaming/pom.xml index 5ab7f4472c38b..49d035a1e9696 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -40,6 +40,13 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 25842d502543e..9cd9684d36404 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -23,6 +23,7 @@ import java.util.concurrent.atomic.{AtomicInteger, AtomicReference} import scala.collection.Map import scala.collection.mutable.Queue import scala.reflect.ClassTag +import scala.util.control.NonFatal import akka.actor.{Props, SupervisorStrategy} import org.apache.hadoop.conf.Configuration @@ -270,6 +271,8 @@ class StreamingContext private[streaming] ( * Create an input stream with any arbitrary user implemented receiver. * Find more details at: http://spark.apache.org/docs/latest/streaming-custom-receivers.html * @param receiver Custom implementation of Receiver + * + * @deprecated As of 1.0.0", replaced by `receiverStream`. */ @deprecated("Use receiverStream", "1.0.0") def networkStream[T: ClassTag](receiver: Receiver[T]): ReceiverInputDStream[T] = { @@ -576,18 +579,26 @@ class StreamingContext private[streaming] ( def start(): Unit = synchronized { state match { case INITIALIZED => - validate() startSite.set(DStream.getCreationSite()) sparkContext.setCallSite(startSite.get) StreamingContext.ACTIVATION_LOCK.synchronized { StreamingContext.assertNoOtherContextIsActive() - scheduler.start() - uiTab.foreach(_.attach()) - state = StreamingContextState.ACTIVE + try { + validate() + scheduler.start() + state = StreamingContextState.ACTIVE + } catch { + case NonFatal(e) => + logError("Error starting the context, marking it as stopped", e) + scheduler.stop(false) + state = StreamingContextState.STOPPED + throw e + } StreamingContext.setActiveContext(this) } shutdownHookRef = Utils.addShutdownHook( StreamingContext.SHUTDOWN_HOOK_PRIORITY)(stopOnShutdown) + uiTab.foreach(_.attach()) logInfo("StreamingContext started") case ACTIVE => logWarning("StreamingContext has already been started") @@ -608,6 +619,8 @@ class StreamingContext private[streaming] ( * Wait for the execution to stop. Any exceptions that occurs during the execution * will be thrown in this thread. * @param timeout time to wait in milliseconds + * + * @deprecated As of 1.3.0, replaced by `awaitTerminationOrTimeout(Long)`. */ @deprecated("Use awaitTerminationOrTimeout(Long) instead", "1.3.0") def awaitTermination(timeout: Long) { @@ -732,6 +745,10 @@ object StreamingContext extends Logging { } } + /** + * @deprecated As of 1.3.0, replaced by implicit functions in the DStream companion object. + * This is kept here only for backward compatibility. + */ @deprecated("Replaced by implicit functions in the DStream companion object. This is " + "kept here only for backward compatibility.", "1.3.0") def toPairDStreamFunctions[K, V](stream: DStream[(K, V)]) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index b639b94d5ca47..989e3a729ebc2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -148,6 +148,9 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { /** The underlying SparkContext */ val sparkContext = new JavaSparkContext(ssc.sc) + /** + * @deprecated As of 0.9.0, replaced by `sparkContext` + */ @deprecated("use sparkContext", "0.9.0") val sc: JavaSparkContext = sparkContext @@ -619,6 +622,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * Wait for the execution to stop. Any exceptions that occurs during the execution * will be thrown in this thread. * @param timeout time to wait in milliseconds + * @deprecated As of 1.3.0, replaced by `awaitTerminationOrTimeout(Long)`. */ @deprecated("Use awaitTerminationOrTimeout(Long) instead", "1.3.0") def awaitTermination(timeout: Long): Unit = { @@ -677,6 +681,7 @@ object JavaStreamingContext { * * @param checkpointPath Checkpoint directory used in an earlier JavaStreamingContext program * @param factory JavaStreamingContextFactory object to create a new JavaStreamingContext + * @deprecated As of 1.4.0, replaced by `getOrCreate` without JavaStreamingContextFactor. */ @deprecated("use getOrCreate without JavaStreamingContextFactor", "1.4.0") def getOrCreate( @@ -699,6 +704,7 @@ object JavaStreamingContext { * @param factory JavaStreamingContextFactory object to create a new JavaStreamingContext * @param hadoopConf Hadoop configuration if necessary for reading from any HDFS compatible * file system + * @deprecated As of 1.4.0, replaced by `getOrCreate` without JavaStreamingContextFactor. */ @deprecated("use getOrCreate without JavaStreamingContextFactory", "1.4.0") def getOrCreate( @@ -724,6 +730,7 @@ object JavaStreamingContext { * file system * @param createOnError Whether to create a new JavaStreamingContext if there is an * error in reading checkpoint data. + * @deprecated As of 1.4.0, replaced by `getOrCreate` without JavaStreamingContextFactor. */ @deprecated("use getOrCreate without JavaStreamingContextFactory", "1.4.0") def getOrCreate( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index 6efcc193bfccc..192aa6a139bcb 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -603,6 +603,8 @@ abstract class DStream[T: ClassTag] ( /** * Apply a function to each RDD in this DStream. This is an output operator, so * 'this' DStream will be registered as an output stream and therefore materialized. + * + * @deprecated As of 0.9.0, replaced by `foreachRDD`. */ @deprecated("use foreachRDD", "0.9.0") def foreach(foreachFunc: RDD[T] => Unit): Unit = ssc.withScope { @@ -612,6 +614,8 @@ abstract class DStream[T: ClassTag] ( /** * Apply a function to each RDD in this DStream. This is an output operator, so * 'this' DStream will be registered as an output stream and therefore materialized. + * + * @deprecated As of 0.9.0, replaced by `foreachRDD`. */ @deprecated("use foreachRDD", "0.9.0") def foreach(foreachFunc: (RDD[T], Time) => Unit): Unit = ssc.withScope { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala index 0588517a2de39..8d73593ab6375 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala @@ -191,7 +191,7 @@ private[streaming] class BlockGenerator( logError(message, t) listener.onError(message, t) } - + private def pushBlock(block: Block) { listener.onPushBlock(block.id, block.buffer) logInfo("Pushed block " + block.id) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala index 651b534ac1900..207d64d9414ee 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala @@ -62,7 +62,7 @@ private[streaming] case class BlockManagerBasedStoreResult(blockId: StreamBlockI private[streaming] class BlockManagerBasedBlockHandler( blockManager: BlockManager, storageLevel: StorageLevel) extends ReceivedBlockHandler with Logging { - + def storeBlock(blockId: StreamBlockId, block: ReceivedBlock): ReceivedBlockStoreResult = { val putResult: Seq[(BlockId, BlockStatus)] = block match { case ArrayBufferBlock(arrayBuffer) => diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index 1d1ddaaccf217..4af9b6d3b56ab 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -126,6 +126,10 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { eventLoop.post(ErrorReported(msg, e)) } + def isStarted(): Boolean = synchronized { + eventLoop != null + } + private def processEvent(event: JobSchedulerEvent) { try { event match { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala index 6a1dd6949b204..9b5e4dc819a2b 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala @@ -19,9 +19,9 @@ package org.apache.spark.streaming import java.io.NotSerializableException -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.BeforeAndAfterAll -import org.apache.spark.{HashPartitioner, SparkContext, SparkException} +import org.apache.spark.{HashPartitioner, SparkContext, SparkException, SparkFunSuite} import org.apache.spark.rdd.RDD import org.apache.spark.streaming.dstream.DStream import org.apache.spark.util.ReturnStatementInClosureException @@ -29,7 +29,7 @@ import org.apache.spark.util.ReturnStatementInClosureException /** * Test that closures passed to DStream operations are actually cleaned. */ -class DStreamClosureSuite extends FunSuite with BeforeAndAfterAll { +class DStreamClosureSuite extends SparkFunSuite with BeforeAndAfterAll { private var ssc: StreamingContext = null override def beforeAll(): Unit = { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala index e3fb2ef130859..8844c9d74b933 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.streaming -import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.rdd.RDDOperationScope import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.ui.UIUtils @@ -27,7 +27,7 @@ import org.apache.spark.streaming.ui.UIUtils /** * Tests whether scope information is passed from DStream operations to RDDs correctly. */ -class DStreamScopeSuite extends FunSuite with BeforeAndAfter with BeforeAndAfterAll { +class DStreamScopeSuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterAll { private var ssc: StreamingContext = null private val batchDuration: Duration = Seconds(1) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index 23804237bda80..cca8cedb1d080 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -25,7 +25,7 @@ import scala.concurrent.duration._ import scala.language.postfixOps import org.apache.hadoop.conf.Configuration -import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} +import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ import org.apache.spark._ @@ -41,7 +41,11 @@ import org.apache.spark.util.{ManualClock, Utils} import WriteAheadLogBasedBlockHandler._ import WriteAheadLogSuite._ -class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matchers with Logging { +class ReceivedBlockHandlerSuite + extends SparkFunSuite + with BeforeAndAfter + with Matchers + with Logging { val conf = new SparkConf().set("spark.streaming.receiver.writeAheadLog.rollingIntervalSecs", "1") val hadoopConf = new Configuration() diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala index b1af8d5eaacfb..6f0ee774cb5cf 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala @@ -25,10 +25,10 @@ import scala.language.{implicitConversions, postfixOps} import scala.util.Random import org.apache.hadoop.conf.Configuration -import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} +import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ -import org.apache.spark.{Logging, SparkConf, SparkException} +import org.apache.spark.{Logging, SparkConf, SparkException, SparkFunSuite} import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.receiver.BlockManagerBasedStoreResult import org.apache.spark.streaming.scheduler._ @@ -37,7 +37,7 @@ import org.apache.spark.streaming.util.WriteAheadLogSuite._ import org.apache.spark.util.{Clock, ManualClock, SystemClock, Utils} class ReceivedBlockTrackerSuite - extends FunSuite with BeforeAndAfter with Matchers with Logging { + extends SparkFunSuite with BeforeAndAfter with Matchers with Logging { val hadoopConf = new Configuration() val akkaTimeout = 10 seconds diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index e36c7914b130e..819dd2ccfe915 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -25,16 +25,16 @@ import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.Timeouts import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.scalatest.time.SpanSugar._ -import org.scalatest.{Assertions, BeforeAndAfter, FunSuite} +import org.scalatest.{Assertions, BeforeAndAfter} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.receiver.Receiver import org.apache.spark.util.Utils -import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException} +import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException, SparkFunSuite} -class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts with Logging { +class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeouts with Logging { val master = "local[2]" val appName = this.getClass.getSimpleName @@ -151,6 +151,22 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w assert(StreamingContext.getActive().isEmpty) } + test("start failure should stop internal components") { + ssc = new StreamingContext(conf, batchDuration) + val inputStream = addInputStream(ssc) + val updateFunc = (values: Seq[Int], state: Option[Int]) => { + Some(values.sum + state.getOrElse(0)) + } + inputStream.map(x => (x, 1)).updateStateByKey[Int](updateFunc) + // Require that the start fails because checkpoint directory was not set + intercept[Exception] { + ssc.start() + } + assert(ssc.getState() === StreamingContextState.STOPPED) + assert(ssc.scheduler.isStarted === false) + } + + test("start multiple times") { ssc = new StreamingContext(master, appName, batchDuration) addInputStream(ssc).register() diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index 554cd30223f44..31b1aebf6a8ec 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -24,12 +24,12 @@ import scala.collection.mutable.SynchronizedBuffer import scala.language.implicitConversions import scala.reflect.ClassTag -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter import org.scalatest.time.{Span, Seconds => ScalaTestSeconds} import org.scalatest.concurrent.Eventually.timeout import org.scalatest.concurrent.PatienceConfiguration -import org.apache.spark.{SparkConf, Logging} +import org.apache.spark.{Logging, SparkConf, SparkFunSuite} import org.apache.spark.rdd.RDD import org.apache.spark.streaming.dstream.{DStream, InputDStream, ForEachDStream} import org.apache.spark.streaming.scheduler._ @@ -204,7 +204,7 @@ class BatchCounter(ssc: StreamingContext) { * This is the base trait for Spark Streaming testsuites. This provides basic functionality * to run user-defined set of input on user-defined stream operations, and verify the output. */ -trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { +trait TestSuiteBase extends SparkFunSuite with BeforeAndAfter with Logging { // Name of the framework for Spark context def framework: String = this.getClass.getSimpleName diff --git a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala index 441bbf95d0153..cbc24aee4fa1e 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala @@ -28,14 +28,11 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark._ - - - /** * Selenium tests for the Spark Web UI. */ class UISeleniumSuite - extends FunSuite with WebBrowser with Matchers with BeforeAndAfterAll with TestSuiteBase { + extends SparkFunSuite with WebBrowser with Matchers with BeforeAndAfterAll with TestSuiteBase { implicit var webDriver: WebDriver = _ @@ -197,4 +194,3 @@ class UISeleniumSuite } } } - diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala index 6859b65c7165f..cb017b798b2a4 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala @@ -21,15 +21,15 @@ import java.io.File import scala.util.Random import org.apache.hadoop.conf.Configuration -import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite} +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, StreamBlockId} import org.apache.spark.streaming.util.{FileBasedWriteAheadLogSegment, FileBasedWriteAheadLogWriter} import org.apache.spark.util.Utils -import org.apache.spark.{SparkConf, SparkContext, SparkException} +import org.apache.spark.{SparkConf, SparkContext, SparkException, SparkFunSuite} class WriteAheadLogBackedBlockRDDSuite - extends FunSuite with BeforeAndAfterAll with BeforeAndAfterEach { + extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfterEach { val conf = new SparkConf() .setMaster("local[2]") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala index 5478b41845943..2e210397fe7c7 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala @@ -17,12 +17,12 @@ package org.apache.spark.streaming.scheduler -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.streaming.{Time, Duration, StreamingContext} -class InputInfoTrackerSuite extends FunSuite with BeforeAndAfter { +class InputInfoTrackerSuite extends SparkFunSuite with BeforeAndAfter { private var ssc: StreamingContext = _ diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ui/UIUtilsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ui/UIUtilsSuite.scala index e9ab917ab845c..d3ca2b58f36c2 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ui/UIUtilsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ui/UIUtilsSuite.scala @@ -20,10 +20,11 @@ package org.apache.spark.streaming.ui import java.util.TimeZone import java.util.concurrent.TimeUnit -import org.scalatest.FunSuite import org.scalatest.Matchers -class UIUtilsSuite extends FunSuite with Matchers{ +import org.apache.spark.SparkFunSuite + +class UIUtilsSuite extends SparkFunSuite with Matchers{ test("shortTimeUnitString") { assert("ns" === UIUtils.shortTimeUnitString(TimeUnit.NANOSECONDS)) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala index 9ebf7b484f421..78fc344b00177 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala @@ -20,9 +20,9 @@ package org.apache.spark.streaming.util import java.io.ByteArrayOutputStream import java.util.concurrent.TimeUnit._ -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class RateLimitedOutputStreamSuite extends FunSuite { +class RateLimitedOutputStreamSuite extends SparkFunSuite { private def benchmark[U](f: => U): Long = { val start = System.nanoTime diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index 79098bcf4861c..325ff7c74c39d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -28,15 +28,15 @@ import scala.reflect.ClassTag import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.scalatest.concurrent.Eventually._ -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter import org.apache.spark.util.{ManualClock, Utils} -import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} -class WriteAheadLogSuite extends FunSuite with BeforeAndAfter { +class WriteAheadLogSuite extends SparkFunSuite with BeforeAndAfter { import WriteAheadLogSuite._ - + val hadoopConf = new Configuration() var tempDir: File = null var testDir: String = null @@ -359,7 +359,7 @@ object WriteAheadLogSuite { ): FileBasedWriteAheadLog = { if (manualClock.getTimeMillis() < 100000) manualClock.setTime(10000) val wal = new FileBasedWriteAheadLog(new SparkConf(), logDirectory, hadoopConf, 1, 1) - + // Ensure that 500 does not get sorted after 2000, so put a high base value. data.foreach { item => manualClock.advance(500) diff --git a/yarn/pom.xml b/yarn/pom.xml index 00d219f836708..e207a46809684 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -39,6 +39,13 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.hadoop hadoop-yarn-api diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala index aaae6f9734a85..77af46c192cc2 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala @@ -60,8 +60,11 @@ private[yarn] class AMDelegationTokenRenewer( private val hadoopUtil = YarnSparkHadoopUtil.get - private val daysToKeepFiles = sparkConf.getInt("spark.yarn.credentials.file.retention.days", 5) - private val numFilesToKeep = sparkConf.getInt("spark.yarn.credentials.file.retention.count", 5) + private val credentialsFile = sparkConf.get("spark.yarn.credentials.file") + private val daysToKeepFiles = + sparkConf.getInt("spark.yarn.credentials.file.retention.days", 5) + private val numFilesToKeep = + sparkConf.getInt("spark.yarn.credentials.file.retention.count", 5) /** * Schedule a login from the keytab and principal set using the --principal and --keytab @@ -121,7 +124,7 @@ private[yarn] class AMDelegationTokenRenewer( import scala.concurrent.duration._ try { val remoteFs = FileSystem.get(hadoopConf) - val credentialsPath = new Path(sparkConf.get("spark.yarn.credentials.file")) + val credentialsPath = new Path(credentialsFile) val thresholdTime = System.currentTimeMillis() - (daysToKeepFiles days).toMillis hadoopUtil.listFilesSorted( remoteFs, credentialsPath.getParent, @@ -160,7 +163,7 @@ private[yarn] class AMDelegationTokenRenewer( val keytabLoggedInUGI = UserGroupInformation.loginUserFromKeytabAndReturnUGI(principal, keytab) logInfo("Successfully logged into KDC.") val tempCreds = keytabLoggedInUGI.getCredentials - val credentialsPath = new Path(sparkConf.get("spark.yarn.credentials.file")) + val credentialsPath = new Path(credentialsFile) val dst = credentialsPath.getParent keytabLoggedInUGI.doAs(new PrivilegedExceptionAction[Void] { // Get a copy of the credentials @@ -186,8 +189,7 @@ private[yarn] class AMDelegationTokenRenewer( } val nextSuffix = lastCredentialsFileSuffix + 1 val tokenPathStr = - sparkConf.get("spark.yarn.credentials.file") + - SparkHadoopUtil.SPARK_YARN_CREDS_COUNTER_DELIM + nextSuffix + credentialsFile + SparkHadoopUtil.SPARK_YARN_CREDS_COUNTER_DELIM + nextSuffix val tokenPath = new Path(tokenPathStr) val tempTokenPath = new Path(tokenPathStr + SparkHadoopUtil.SPARK_YARN_CREDS_TEMP_EXTENSION) logInfo("Writing out delegation tokens to " + tempTokenPath.toString) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index 5653c9f14dc6d..9c7b1b3988082 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -98,6 +98,12 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) numExecutors = initialNumExecutors } + principal = Option(principal) + .orElse(sparkConf.getOption("spark.yarn.principal")) + .orNull + keytab = Option(keytab) + .orElse(sparkConf.getOption("spark.yarn.keytab")) + .orNull } /** diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala index 4ca6c903fcf12..3d3a966960e9f 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala @@ -43,22 +43,22 @@ private[spark] class ClientDistributedCacheManager() extends Logging { * Add a resource to the list of distributed cache resources. This list can * be sent to the ApplicationMaster and possibly the executors so that it can * be downloaded into the Hadoop distributed cache for use by this application. - * Adds the LocalResource to the localResources HashMap passed in and saves + * Adds the LocalResource to the localResources HashMap passed in and saves * the stats of the resources to they can be sent to the executors and verified. * * @param fs FileSystem * @param conf Configuration * @param destPath path to the resource * @param localResources localResource hashMap to insert the resource into - * @param resourceType LocalResourceType + * @param resourceType LocalResourceType * @param link link presented in the distributed cache to the destination - * @param statCache cache to store the file/directory stats + * @param statCache cache to store the file/directory stats * @param appMasterOnly Whether to only add the resource to the app master */ def addResource( fs: FileSystem, conf: Configuration, - destPath: Path, + destPath: Path, localResources: HashMap[String, LocalResource], resourceType: LocalResourceType, link: String, @@ -74,15 +74,15 @@ private[spark] class ClientDistributedCacheManager() extends Logging { amJarRsrc.setSize(destStatus.getLen()) if (link == null || link.isEmpty()) throw new Exception("You must specify a valid link name") localResources(link) = amJarRsrc - + if (!appMasterOnly) { val uri = destPath.toUri() val pathURI = new URI(uri.getScheme(), uri.getAuthority(), uri.getPath(), null, link) if (resourceType == LocalResourceType.FILE) { - distCacheFiles(pathURI.toString()) = (destStatus.getLen().toString(), + distCacheFiles(pathURI.toString()) = (destStatus.getLen().toString(), destStatus.getModificationTime().toString(), visibility.name()) } else { - distCacheArchives(pathURI.toString()) = (destStatus.getLen().toString(), + distCacheArchives(pathURI.toString()) = (destStatus.getLen().toString(), destStatus.getModificationTime().toString(), visibility.name()) } } @@ -96,11 +96,11 @@ private[spark] class ClientDistributedCacheManager() extends Logging { val (sizes, timeStamps, visibilities) = tupleValues.unzip3 if (keys.size > 0) { env("SPARK_YARN_CACHE_FILES") = keys.reduceLeft[String] { (acc, n) => acc + "," + n } - env("SPARK_YARN_CACHE_FILES_TIME_STAMPS") = + env("SPARK_YARN_CACHE_FILES_TIME_STAMPS") = timeStamps.reduceLeft[String] { (acc, n) => acc + "," + n } - env("SPARK_YARN_CACHE_FILES_FILE_SIZES") = + env("SPARK_YARN_CACHE_FILES_FILE_SIZES") = sizes.reduceLeft[String] { (acc, n) => acc + "," + n } - env("SPARK_YARN_CACHE_FILES_VISIBILITIES") = + env("SPARK_YARN_CACHE_FILES_VISIBILITIES") = visibilities.reduceLeft[String] { (acc, n) => acc + "," + n } } } @@ -113,11 +113,11 @@ private[spark] class ClientDistributedCacheManager() extends Logging { val (sizes, timeStamps, visibilities) = tupleValues.unzip3 if (keys.size > 0) { env("SPARK_YARN_CACHE_ARCHIVES") = keys.reduceLeft[String] { (acc, n) => acc + "," + n } - env("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") = + env("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") = timeStamps.reduceLeft[String] { (acc, n) => acc + "," + n } env("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES") = sizes.reduceLeft[String] { (acc, n) => acc + "," + n } - env("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") = + env("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") = visibilities.reduceLeft[String] { (acc, n) => acc + "," + n } } } @@ -197,7 +197,7 @@ private[spark] class ClientDistributedCacheManager() extends Logging { def getFileStatus(fs: FileSystem, uri: URI, statCache: Map[URI, FileStatus]): FileStatus = { val stat = statCache.get(uri) match { case Some(existstat) => existstat - case None => + case None => val newStat = fs.getFileStatus(new Path(uri)) statCache.put(uri, newStat) newStat diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index 5e6531895c7ba..68d01c17ef720 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -144,9 +144,9 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { } object YarnSparkHadoopUtil { - // Additional memory overhead + // Additional memory overhead // 10% was arrived at experimentally. In the interest of minimizing memory waste while covering - // the common cases. Memory overhead tends to grow with container size. + // the common cases. Memory overhead tends to grow with container size. val MEMORY_OVERHEAD_FACTOR = 0.10 val MEMORY_OVERHEAD_MIN = 384 diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala index 80b57d1355a3a..804dfecde7867 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.deploy.yarn import java.net.URI -import org.scalatest.FunSuite import org.scalatest.mock.MockitoSugar import org.mockito.Mockito.when @@ -36,16 +35,18 @@ import org.apache.hadoop.yarn.util.{Records, ConverterUtils} import scala.collection.mutable.HashMap import scala.collection.mutable.Map +import org.apache.spark.SparkFunSuite -class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar { + +class ClientDistributedCacheManagerSuite extends SparkFunSuite with MockitoSugar { class MockClientDistributedCacheManager extends ClientDistributedCacheManager { - override def getVisibility(conf: Configuration, uri: URI, statCache: Map[URI, FileStatus]): + override def getVisibility(conf: Configuration, uri: URI, statCache: Map[URI, FileStatus]): LocalResourceVisibility = { LocalResourceVisibility.PRIVATE } } - + test("test getFileStatus empty") { val distMgr = new ClientDistributedCacheManager() val fs = mock[FileSystem] @@ -60,7 +61,7 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar { val distMgr = new ClientDistributedCacheManager() val fs = mock[FileSystem] val uri = new URI("/tmp/testing") - val realFileStatus = new FileStatus(10, false, 1, 1024, 10, 10, null, "testOwner", + val realFileStatus = new FileStatus(10, false, 1, 1024, 10, 10, null, "testOwner", null, new Path("/tmp/testing")) when(fs.getFileStatus(new Path(uri))).thenReturn(new FileStatus()) val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus](uri -> realFileStatus) @@ -77,7 +78,7 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar { val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]() when(fs.getFileStatus(destPath)).thenReturn(new FileStatus()) - distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, "link", + distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, "link", statCache, false) val resource = localResources("link") assert(resource.getVisibility() === LocalResourceVisibility.PRIVATE) @@ -100,11 +101,11 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar { assert(env.get("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") === None) // add another one and verify both there and order correct - val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner", + val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner", null, new Path("/tmp/testing2")) val destPath2 = new Path("file:///foo.invalid.com:8080/tmp/testing2") when(fs.getFileStatus(destPath2)).thenReturn(realFileStatus) - distMgr.addResource(fs, conf, destPath2, localResources, LocalResourceType.FILE, "link2", + distMgr.addResource(fs, conf, destPath2, localResources, LocalResourceType.FILE, "link2", statCache, false) val resource2 = localResources("link2") assert(resource2.getVisibility() === LocalResourceVisibility.PRIVATE) @@ -116,7 +117,7 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar { val env2 = new HashMap[String, String]() distMgr.setDistFilesEnv(env2) val timestamps = env2("SPARK_YARN_CACHE_FILES_TIME_STAMPS").split(',') - val files = env2("SPARK_YARN_CACHE_FILES").split(',') + val files = env2("SPARK_YARN_CACHE_FILES").split(',') val sizes = env2("SPARK_YARN_CACHE_FILES_FILE_SIZES").split(',') val visibilities = env2("SPARK_YARN_CACHE_FILES_VISIBILITIES") .split(',') assert(files(0) === "file:/foo.invalid.com:8080/tmp/testing#link") @@ -140,7 +141,7 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar { when(fs.getFileStatus(destPath)).thenReturn(new FileStatus()) intercept[Exception] { - distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, null, + distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, null, statCache, false) } assert(localResources.get("link") === None) @@ -154,11 +155,11 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar { val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing") val localResources = HashMap[String, LocalResource]() val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]() - val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner", + val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner", null, new Path("/tmp/testing")) when(fs.getFileStatus(destPath)).thenReturn(realFileStatus) - distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, "link", + distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, "link", statCache, true) val resource = localResources("link") assert(resource.getVisibility() === LocalResourceVisibility.PRIVATE) @@ -188,11 +189,11 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar { val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing") val localResources = HashMap[String, LocalResource]() val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]() - val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner", + val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner", null, new Path("/tmp/testing")) when(fs.getFileStatus(destPath)).thenReturn(realFileStatus) - distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, "link", + distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, "link", statCache, false) val resource = localResources("link") assert(resource.getVisibility() === LocalResourceVisibility.PRIVATE) diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala index 6da3e82acdb14..01d33c9ce9297 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala @@ -33,12 +33,12 @@ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.mockito.Matchers._ import org.mockito.Mockito._ -import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers} +import org.scalatest.{BeforeAndAfterAll, Matchers} -import org.apache.spark.{SparkException, SparkConf} +import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} import org.apache.spark.util.Utils -class ClientSuite extends FunSuite with Matchers with BeforeAndAfterAll { +class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll { override def beforeAll(): Unit = { System.setProperty("SPARK_YARN_MODE", "true") diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index b343cbb0c7569..7509000771d94 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -26,13 +26,13 @@ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.client.api.AMRMClient import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest -import org.apache.spark.SecurityManager +import org.apache.spark.{SecurityManager, SparkFunSuite} import org.apache.spark.SparkConf import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ import org.apache.spark.deploy.yarn.YarnAllocator._ import org.apache.spark.scheduler.SplitInfo -import org.scalatest.{BeforeAndAfterEach, FunSuite, Matchers} +import org.scalatest.{BeforeAndAfterEach, Matchers} class MockResolver extends DNSToSwitchMapping { @@ -46,7 +46,7 @@ class MockResolver extends DNSToSwitchMapping { def reloadCachedMappings(names: JList[String]) {} } -class YarnAllocatorSuite extends FunSuite with Matchers with BeforeAndAfterEach { +class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach { val conf = new Configuration() conf.setClass( CommonConfigurationKeysPublic.NET_TOPOLOGY_NODE_SWITCH_MAPPING_IMPL_KEY, diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index dcaeb2e43ff41..d8bc2534c1a6a 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -30,9 +30,9 @@ import com.google.common.io.ByteStreams import com.google.common.io.Files import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.server.MiniYARNCluster -import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers} +import org.scalatest.{BeforeAndAfterAll, Matchers} -import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException, TestUtils} +import org.apache.spark._ import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationStart, SparkListenerExecutorAdded} @@ -43,7 +43,7 @@ import org.apache.spark.util.Utils * applications, and require the Spark assembly to be built before they can be successfully * run. */ -class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers with Logging { +class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matchers with Logging { // log4j configuration for the YARN containers, so that their output is collected // by YARN instead of trying to overwrite unit-tests.log. diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala index e10b985c3c236..49bee0866dd43 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala @@ -25,15 +25,15 @@ import org.apache.hadoop.fs.Path import org.apache.hadoop.yarn.api.ApplicationConstants import org.apache.hadoop.yarn.api.ApplicationConstants.Environment import org.apache.hadoop.yarn.conf.YarnConfiguration -import org.scalatest.{FunSuite, Matchers} +import org.scalatest.Matchers import org.apache.hadoop.yarn.api.records.ApplicationAccessType -import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException, SparkFunSuite} import org.apache.spark.util.Utils -class YarnSparkHadoopUtilSuite extends FunSuite with Matchers with Logging { +class YarnSparkHadoopUtilSuite extends SparkFunSuite with Matchers with Logging { val hasBash = try {