diff --git a/.circleci/config.yml b/.circleci/config.yml index 43f2d58acdf31..86636d6d20314 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -2,7 +2,7 @@ version: 2 defaults: &defaults docker: - - image: palantirtechnologies/circle-spark-base:0.1.0 + - image: palantirtechnologies/circle-spark-base:0.1.3 resource_class: xlarge environment: &defaults-environment TERM: dumb @@ -128,7 +128,7 @@ jobs: <<: *defaults # Some part of the maven setup fails if there's no R, so we need to use the R image here docker: - - image: palantirtechnologies/circle-spark-r:0.1.0 + - image: palantirtechnologies/circle-spark-r:0.1.3 steps: # Saves us from recompiling every time... - restore_cache: @@ -147,12 +147,7 @@ jobs: keys: - build-binaries-{{ checksum "build/mvn" }}-{{ checksum "build/sbt" }} - build-binaries- - - run: | - ./build/mvn -T1C -DskipTests -Phadoop-cloud -Phadoop-palantir -Pkinesis-asl -Pkubernetes -Pyarn -Psparkr install \ - | tee -a "/tmp/mvn-install.log" - - store_artifacts: - path: /tmp/mvn-install.log - destination: mvn-install.log + - run: ./build/mvn -DskipTests -Phadoop-cloud -Phadoop-palantir -Pkinesis-asl -Pkubernetes -Pyarn -Psparkr install # Get sbt to run trivially, ensures its launcher is downloaded under build/ - run: ./build/sbt -h || true - save_cache: @@ -300,7 +295,7 @@ jobs: # depends on build-sbt, but we only need the assembly jars <<: *defaults docker: - - image: palantirtechnologies/circle-spark-python:0.1.0 + - image: palantirtechnologies/circle-spark-python:0.1.3 parallelism: 2 steps: - *checkout-code @@ -325,7 +320,7 @@ jobs: # depends on build-sbt, but we only need the assembly jars <<: *defaults docker: - - image: palantirtechnologies/circle-spark-r:0.1.0 + - image: palantirtechnologies/circle-spark-r:0.1.3 steps: - *checkout-code - attach_workspace: @@ -438,7 +433,7 @@ jobs: <<: *defaults # Some part of the maven setup fails if there's no R, so we need to use the R image here docker: - - image: palantirtechnologies/circle-spark-r:0.1.0 + - image: palantirtechnologies/circle-spark-r:0.1.3 steps: - *checkout-code - restore_cache: @@ -458,7 +453,7 @@ jobs: deploy-gradle: <<: *defaults docker: - - image: palantirtechnologies/circle-spark-r:0.1.0 + - image: palantirtechnologies/circle-spark-r:0.1.3 steps: - *checkout-code - *restore-gradle-wrapper-cache @@ -470,7 +465,7 @@ jobs: <<: *defaults # Some part of the maven setup fails if there's no R, so we need to use the R image here docker: - - image: palantirtechnologies/circle-spark-r:0.1.0 + - image: palantirtechnologies/circle-spark-r:0.1.3 steps: # This cache contains the whole project after version was set and mvn package was called # Restoring first (and instead of checkout) as mvn versions:set mutates real source code... 
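The CircleCI changes above bump every job image from `palantirtechnologies/circle-spark-*:0.1.0` to `:0.1.3` and replace the logged, parallel (`-T1C`) Maven install with a plain `install`. As a rough sketch (not part of the diff), the same install step could be reproduced locally inside the bumped base image; the bind-mount path and working directory here are assumptions, only the image tag and the `./build/mvn` invocation come from the config itself:

```bash
# Hypothetical local reproduction of the CI Maven install step.
# Image tag and mvn flags are taken from the config above; the mount/workdir are guesses.
docker run --rm -it \
  -v "$PWD":/workspace \
  -w /workspace \
  palantirtechnologies/circle-spark-base:0.1.3 \
  ./build/mvn -DskipTests -Phadoop-cloud -Phadoop-palantir -Pkinesis-asl -Pkubernetes -Pyarn -Psparkr install
```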
diff --git a/.gitignore b/.gitignore index 61f9349f42a76..7114705bfccf4 100644 --- a/.gitignore +++ b/.gitignore @@ -81,7 +81,6 @@ work/ .credentials dev/pr-deps docs/.jekyll-metadata -*.crc # For Hive TempStatsStore/ diff --git a/FORK.md b/FORK.md index 7efc0e7961237..d1262b890b67d 100644 --- a/FORK.md +++ b/FORK.md @@ -29,3 +29,5 @@ # Reverted * [SPARK-25908](https://issues.apache.org/jira/browse/SPARK-25908) - Removal of `monotonicall_increasing_id`, `toDegree`, `toRadians`, `approxCountDistinct`, `unionAll` * [SPARK-25862](https://issues.apache.org/jira/browse/SPARK-25862) - Removal of `unboundedPreceding`, `unboundedFollowing`, `currentRow` +* [SPARK-26127](https://issues.apache.org/jira/browse/SPARK-26127) - Removal of deprecated setters from tree regression and classification models +* [SPARK-25867](https://issues.apache.org/jira/browse/SPARK-25867) - Removal of KMeans computeCost diff --git a/R/WINDOWS.md b/R/WINDOWS.md index da668a69b8679..33a4c850cfdac 100644 --- a/R/WINDOWS.md +++ b/R/WINDOWS.md @@ -3,7 +3,7 @@ To build SparkR on Windows, the following steps are required 1. Install R (>= 3.1) and [Rtools](http://cran.r-project.org/bin/windows/Rtools/). Make sure to -include Rtools and R in `PATH`. +include Rtools and R in `PATH`. Note that support for R prior to version 3.4 is deprecated as of Spark 3.0.0. 2. Install [JDK8](http://www.oracle.com/technetwork/java/javase/downloads/jdk8-downloads-2133151.html) and set diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 70f22bf895495..74cdbd185e570 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -15,7 +15,7 @@ URL: http://www.apache.org/ http://spark.apache.org/ BugReports: http://spark.apache.org/contributing.html SystemRequirements: Java (== 8) Depends: - R (>= 3.0), + R (>= 3.1), methods Suggests: knitr, diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index fb158e17fc19c..ad9cd845f696c 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -767,6 +767,14 @@ setMethod("repartition", #' using \code{spark.sql.shuffle.partitions} as number of partitions.} #'} #' +#' At least one partition-by expression must be specified. +#' When no explicit sort order is specified, "ascending nulls first" is assumed. +#' +#' Note that due to performance reasons this method uses sampling to estimate the ranges. +#' Hence, the output may not be consistent, since sampling can return different values. +#' The sample size can be controlled by the config +#' \code{spark.sql.execution.rangeExchange.sampleSizePerPartition}. +#' #' @param x a SparkDataFrame. #' @param numPartitions the number of partitions to use. #' @param col the column by which the range partitioning will be performed. diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 9abb7fc1fadb4..f72645a257796 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -3370,7 +3370,7 @@ setMethod("flatten", #' #' @rdname column_collection_functions #' @aliases map_entries map_entries,Column-method -#' @note map_entries since 2.4.0 +#' @note map_entries since 3.0.0 setMethod("map_entries", signature(x = "Column"), function(x) { diff --git a/R/pkg/R/stats.R b/R/pkg/R/stats.R index 497f18c763048..7252351ebebb2 100644 --- a/R/pkg/R/stats.R +++ b/R/pkg/R/stats.R @@ -109,7 +109,7 @@ setMethod("corr", #' #' Finding frequent items for columns, possibly with false positives. #' Using the frequent element count algorithm described in -#' \url{http://dx.doi.org/10.1145/762471.762473}, proposed by Karp, Schenker, and Papadimitriou. 
+#' \url{https://doi.org/10.1145/762471.762473}, proposed by Karp, Schenker, and Papadimitriou. #' #' @param x A SparkDataFrame. #' @param cols A vector column names to search frequent items in. @@ -143,7 +143,7 @@ setMethod("freqItems", signature(x = "SparkDataFrame", cols = "character"), #' *exact* rank of x is close to (p * N). More precisely, #' floor((p - err) * N) <= rank(x) <= ceil((p + err) * N). #' This method implements a variation of the Greenwald-Khanna algorithm (with some speed -#' optimizations). The algorithm was first present in [[http://dx.doi.org/10.1145/375663.375670 +#' optimizations). The algorithm was first present in [[https://doi.org/10.1145/375663.375670 #' Space-efficient Online Computation of Quantile Summaries]] by Greenwald and Khanna. #' Note that NA values will be ignored in numerical columns before calculation. For #' columns only containing NA values, an empty list is returned. diff --git a/R/pkg/inst/profile/general.R b/R/pkg/inst/profile/general.R index 8c75c19ca7ac3..3efb460846fc2 100644 --- a/R/pkg/inst/profile/general.R +++ b/R/pkg/inst/profile/general.R @@ -16,6 +16,10 @@ # .First <- function() { + if (utils::compareVersion(paste0(R.version$major, ".", R.version$minor), "3.4.0") == -1) { + warning("Support for R prior to version 3.4 is deprecated since Spark 3.0.0") + } + packageDir <- Sys.getenv("SPARKR_PACKAGE_DIR") dirs <- strsplit(packageDir, ",")[[1]] .libPaths(c(dirs, .libPaths())) diff --git a/R/pkg/inst/profile/shell.R b/R/pkg/inst/profile/shell.R index 8a8111a8c5419..32eb3671b5941 100644 --- a/R/pkg/inst/profile/shell.R +++ b/R/pkg/inst/profile/shell.R @@ -16,6 +16,10 @@ # .First <- function() { + if (utils::compareVersion(paste0(R.version$major, ".", R.version$minor), "3.4.0") == -1) { + warning("Support for R prior to version 3.4 is deprecated since Spark 3.0.0") + } + home <- Sys.getenv("SPARK_HOME") .libPaths(c(file.path(home, "R", "lib"), .libPaths())) Sys.setenv(NOAWT = 1) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 4cd7fc857a2b2..77a29c9ecad86 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1674,7 +1674,7 @@ test_that("column functions", { # check for unparseable df <- as.DataFrame(list(list("a" = ""))) - expect_equal(collect(select(df, from_json(df$a, schema)))[[1]][[1]], NA) + expect_equal(collect(select(df, from_json(df$a, schema)))[[1]][[1]]$a, NA) # check if array type in string is correctly supported. jsonArr <- "[{\"name\":\"Bob\"}, {\"name\":\"Alice\"}]" diff --git a/R/pkg/tests/fulltests/test_streaming.R b/R/pkg/tests/fulltests/test_streaming.R index bfb1a046490ec..6f0d2aefee886 100644 --- a/R/pkg/tests/fulltests/test_streaming.R +++ b/R/pkg/tests/fulltests/test_streaming.R @@ -127,6 +127,7 @@ test_that("Specify a schema by using a DDL-formatted string when reading", { expect_false(awaitTermination(q, 5 * 1000)) callJMethod(q@ssq, "processAllAvailable") expect_equal(head(sql("SELECT count(*) FROM people3"))[[1]], 3) + stopQuery(q) expect_error(read.stream(path = parquetPath, schema = "name stri"), "DataType stri is not supported.") diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index 7d924efa9d4bb..f80b45b4f36a8 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -57,6 +57,20 @@ First, let's load and attach the package. 
library(SparkR) ``` +```{r, include=FALSE} +# disable eval if java version not supported +override_eval <- tryCatch(!is.numeric(SparkR:::checkJavaVersion()), + error = function(e) { TRUE }, + warning = function(e) { TRUE }) + +if (override_eval) { + opts_hooks$set(eval = function(options) { + options$eval = FALSE + options + }) +} +``` + `SparkSession` is the entry point into SparkR which connects your R program to a Spark cluster. You can create a `SparkSession` using `sparkR.session` and pass in options such as the application name, any Spark packages depended on, etc. We use default settings in which it runs in local mode. It auto downloads Spark package in the background if no previous installation is found. For more details about setup, see [Spark Session](#SetupSparkSession). diff --git a/assembly/README b/assembly/README index d5dafab477410..1fd6d8858348c 100644 --- a/assembly/README +++ b/assembly/README @@ -9,4 +9,4 @@ This module is off by default. To activate it specify the profile in the command If you need to build an assembly for a different version of Hadoop the hadoop-version system property needs to be set as in this example: - -Dhadoop.version=2.7.3 + -Dhadoop.version=2.7.4 diff --git a/assembly/pom.xml b/assembly/pom.xml index 5cda79564104c..6f8153e847a47 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -20,12 +20,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../pom.xml - spark-assembly_2.11 + spark-assembly_2.12 Spark Project Assembly http://spark.apache.org/ pom @@ -76,7 +76,7 @@ org.apache.spark - spark-avro_2.11 + spark-avro_${scala.binary.version} ${project.version} diff --git a/bin/docker-image-tool.sh b/bin/docker-image-tool.sh index aa5d847f4be2f..9f735f1148da4 100755 --- a/bin/docker-image-tool.sh +++ b/bin/docker-image-tool.sh @@ -29,6 +29,20 @@ if [ -z "${SPARK_HOME}" ]; then fi . "${SPARK_HOME}/bin/load-spark-env.sh" +CTX_DIR="$SPARK_HOME/target/tmp/docker" + +function is_dev_build { + [ ! -f "$SPARK_HOME/RELEASE" ] +} + +function cleanup_ctx_dir { + if is_dev_build; then + rm -rf "$CTX_DIR" + fi +} + +trap cleanup_ctx_dir EXIT + function image_ref { local image="$1" local add_repo="${2:-1}" @@ -41,94 +55,136 @@ function image_ref { echo "$image" } +function docker_push { + local image_name="$1" + if [ ! -z $(docker images -q "$(image_ref ${image_name})") ]; then + docker push "$(image_ref ${image_name})" + if [ $? -ne 0 ]; then + error "Failed to push $image_name Docker image." + fi + else + echo "$(image_ref ${image_name}) image not found. Skipping push for this image." + fi +} + +# Create a smaller build context for docker in dev builds to make the build faster. Docker +# uploads all of the current directory to the daemon, and it can get pretty big with dev +# builds that contain test log files and other artifacts. +# +# Three build contexts are created, one for each image: base, pyspark, and sparkr. For them +# to have the desired effect, the docker command needs to be executed inside the appropriate +# context directory. +# +# Note: docker does not support symlinks in the build context. 
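The comment block above describes the new dev-build behaviour of `docker-image-tool.sh`: instead of sending the whole source tree to the Docker daemon, three slim build contexts (base, pyspark, sparkr) are assembled under `$SPARK_HOME/target/tmp/docker` and each `docker build` runs from inside the matching directory. A rough walkthrough under assumed conditions (a source checkout with the assembly and examples modules already built; repo and tag names are placeholders):

```bash
# In a dev checkout (no RELEASE file) the tool assembles per-image contexts before building.
./bin/docker-image-tool.sh -r docker.io/myrepo -t dev build

# Approximate context layout, derived from the copies performed in create_dev_build_context
# below; the directories are deleted by the EXIT trap once the build finishes:
#   target/tmp/docker/base     -> jars/, bin/, sbin/, data/, examples/, kubernetes/dockerfiles, kubernetes/tests
#   target/tmp/docker/pyspark  -> python/lib plus kubernetes/dockerfiles
#   target/tmp/docker/sparkr   -> R/ plus kubernetes/dockerfiles
```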
+function create_dev_build_context {( + set -e + local BASE_CTX="$CTX_DIR/base" + mkdir -p "$BASE_CTX/kubernetes" + cp -r "resource-managers/kubernetes/docker/src/main/dockerfiles" \ + "$BASE_CTX/kubernetes/dockerfiles" + + cp -r "assembly/target/scala-$SPARK_SCALA_VERSION/jars" "$BASE_CTX/jars" + cp -r "resource-managers/kubernetes/integration-tests/tests" \ + "$BASE_CTX/kubernetes/tests" + + mkdir "$BASE_CTX/examples" + cp -r "examples/src" "$BASE_CTX/examples/src" + # Copy just needed examples jars instead of everything. + mkdir "$BASE_CTX/examples/jars" + for i in examples/target/scala-$SPARK_SCALA_VERSION/jars/*; do + if [ ! -f "$BASE_CTX/jars/$(basename $i)" ]; then + cp $i "$BASE_CTX/examples/jars" + fi + done + + for other in bin sbin data; do + cp -r "$other" "$BASE_CTX/$other" + done + + local PYSPARK_CTX="$CTX_DIR/pyspark" + mkdir -p "$PYSPARK_CTX/kubernetes" + cp -r "resource-managers/kubernetes/docker/src/main/dockerfiles" \ + "$PYSPARK_CTX/kubernetes/dockerfiles" + mkdir "$PYSPARK_CTX/python" + cp -r "python/lib" "$PYSPARK_CTX/python/lib" + + local R_CTX="$CTX_DIR/sparkr" + mkdir -p "$R_CTX/kubernetes" + cp -r "resource-managers/kubernetes/docker/src/main/dockerfiles" \ + "$R_CTX/kubernetes/dockerfiles" + cp -r "R" "$R_CTX/R" +)} + +function img_ctx_dir { + if is_dev_build; then + echo "$CTX_DIR/$1" + else + echo "$SPARK_HOME" + fi +} + function build { local BUILD_ARGS - local IMG_PATH - local JARS - - if [ ! -f "$SPARK_HOME/RELEASE" ]; then - # Set image build arguments accordingly if this is a source repo and not a distribution archive. - # - # Note that this will copy all of the example jars directory into the image, and that will - # contain a lot of duplicated jars with the main Spark directory. In a proper distribution, - # the examples directory is cleaned up before generating the distribution tarball, so this - # issue does not occur. - IMG_PATH=resource-managers/kubernetes/docker/src/main/dockerfiles - JARS=assembly/target/scala-$SPARK_SCALA_VERSION/jars - BUILD_ARGS=( - ${BUILD_PARAMS} - --build-arg - img_path=$IMG_PATH - --build-arg - spark_jars=$JARS - --build-arg - example_jars=examples/target/scala-$SPARK_SCALA_VERSION/jars - --build-arg - k8s_tests=resource-managers/kubernetes/integration-tests/tests - ) - else - # Not passed as arguments to docker, but used to validate the Spark directory. - IMG_PATH="kubernetes/dockerfiles" - JARS=jars - BUILD_ARGS=(${BUILD_PARAMS}) + local SPARK_ROOT="$SPARK_HOME" + + if is_dev_build; then + create_dev_build_context || error "Failed to create docker build context." + SPARK_ROOT="$CTX_DIR/base" fi # Verify that the Docker image content directory is present - if [ ! -d "$IMG_PATH" ]; then + if [ ! -d "$SPARK_ROOT/kubernetes/dockerfiles" ]; then error "Cannot find docker image. This script must be run from a runnable distribution of Apache Spark." fi # Verify that Spark has actually been built/is a runnable distribution # i.e. the Spark JARs that the Docker files will place into the image are present - local TOTAL_JARS=$(ls $JARS/spark-* | wc -l) + local TOTAL_JARS=$(ls $SPARK_ROOT/jars/spark-* | wc -l) TOTAL_JARS=$(( $TOTAL_JARS )) if [ "${TOTAL_JARS}" -eq 0 ]; then error "Cannot find Spark JARs. This script assumes that Apache Spark has first been built locally or this is a runnable distribution." 
fi + local BUILD_ARGS=(${BUILD_PARAMS}) local BINDING_BUILD_ARGS=( ${BUILD_PARAMS} --build-arg base_img=$(image_ref spark) ) - local BASEDOCKERFILE=${BASEDOCKERFILE:-"$IMG_PATH/spark/Dockerfile"} - local PYDOCKERFILE=${PYDOCKERFILE:-"$IMG_PATH/spark/bindings/python/Dockerfile"} - local RDOCKERFILE=${RDOCKERFILE:-"$IMG_PATH/spark/bindings/R/Dockerfile"} + local BASEDOCKERFILE=${BASEDOCKERFILE:-"kubernetes/dockerfiles/spark/Dockerfile"} + local PYDOCKERFILE=${PYDOCKERFILE:-false} + local RDOCKERFILE=${RDOCKERFILE:-false} - docker build $NOCACHEARG "${BUILD_ARGS[@]}" \ + (cd $(img_ctx_dir base) && docker build $NOCACHEARG "${BUILD_ARGS[@]}" \ -t $(image_ref spark) \ - -f "$BASEDOCKERFILE" . + -f "$BASEDOCKERFILE" .) if [ $? -ne 0 ]; then error "Failed to build Spark JVM Docker image, please refer to Docker build output for details." fi - docker build $NOCACHEARG "${BINDING_BUILD_ARGS[@]}" \ - -t $(image_ref spark-py) \ - -f "$PYDOCKERFILE" . + if [ "${PYDOCKERFILE}" != "false" ]; then + (cd $(img_ctx_dir pyspark) && docker build $NOCACHEARG "${BINDING_BUILD_ARGS[@]}" \ + -t $(image_ref spark-py) \ + -f "$PYDOCKERFILE" .) + if [ $? -ne 0 ]; then + error "Failed to build PySpark Docker image, please refer to Docker build output for details." + fi + fi + + if [ "${RDOCKERFILE}" != "false" ]; then + (cd $(img_ctx_dir sparkr) && docker build $NOCACHEARG "${BINDING_BUILD_ARGS[@]}" \ + -t $(image_ref spark-r) \ + -f "$RDOCKERFILE" .) if [ $? -ne 0 ]; then - error "Failed to build PySpark Docker image, please refer to Docker build output for details." + error "Failed to build SparkR Docker image, please refer to Docker build output for details." fi - docker build $NOCACHEARG "${BINDING_BUILD_ARGS[@]}" \ - -t $(image_ref spark-r) \ - -f "$RDOCKERFILE" . - if [ $? -ne 0 ]; then - error "Failed to build SparkR Docker image, please refer to Docker build output for details." fi } function push { - docker push "$(image_ref spark)" - if [ $? -ne 0 ]; then - error "Failed to push Spark JVM Docker image." - fi - docker push "$(image_ref spark-py)" - if [ $? -ne 0 ]; then - error "Failed to push PySpark Docker image." - fi - docker push "$(image_ref spark-r)" - if [ $? -ne 0 ]; then - error "Failed to push SparkR Docker image." - fi + docker_push "spark" + docker_push "spark-py" + docker_push "spark-r" } function usage { @@ -143,8 +199,10 @@ Commands: Options: -f file Dockerfile to build for JVM based Jobs. By default builds the Dockerfile shipped with Spark. - -p file Dockerfile to build for PySpark Jobs. Builds Python dependencies and ships with Spark. - -R file Dockerfile to build for SparkR Jobs. Builds R dependencies and ships with Spark. + -p file (Optional) Dockerfile to build for PySpark Jobs. Builds Python dependencies and ships with Spark. + Skips building PySpark docker image if not specified. + -R file (Optional) Dockerfile to build for SparkR Jobs. Builds R dependencies and ships with Spark. + Skips building SparkR docker image if not specified. -r repo Repository address. -t tag Tag to apply to the built image, or to identify the image to be pushed. -m Use minikube's Docker daemon. 
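With the hunk above, the PySpark and SparkR Dockerfiles become opt-in: `PYDOCKERFILE` and `RDOCKERFILE` default to `false`, and the corresponding images are built only when `-p` / `-R` are passed. Building on the usage examples that follow, a sketch of building all three images in one invocation; the Dockerfile paths are resolved relative to the image build context and mirror the defaults the script used before this change:

```bash
# Build the JVM image plus the PySpark and SparkR bindings by naming their Dockerfiles
# explicitly (otherwise only the JVM image is built).
./bin/docker-image-tool.sh -r docker.io/myrepo -t v2.3.0 \
  -p kubernetes/dockerfiles/spark/bindings/python/Dockerfile \
  -R kubernetes/dockerfiles/spark/bindings/R/Dockerfile \
  build
```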
@@ -164,6 +222,9 @@ Examples: - Build image in minikube with tag "testing" $0 -m -t testing build + - Build PySpark docker image + $0 -r docker.io/myrepo -t v2.3.0 -p kubernetes/dockerfiles/spark/bindings/python/Dockerfile build + - Build and push image with tag "v2.3.0" to docker.io/myrepo $0 -r docker.io/myrepo -t v2.3.0 build $0 -r docker.io/myrepo -t v2.3.0 push diff --git a/bin/load-spark-env.sh b/bin/load-spark-env.sh index 0b5006dbd63ac..0ada5d8d0fc1d 100644 --- a/bin/load-spark-env.sh +++ b/bin/load-spark-env.sh @@ -26,15 +26,17 @@ if [ -z "${SPARK_HOME}" ]; then source "$(dirname "$0")"/find-spark-home fi +SPARK_ENV_SH="spark-env.sh" if [ -z "$SPARK_ENV_LOADED" ]; then export SPARK_ENV_LOADED=1 export SPARK_CONF_DIR="${SPARK_CONF_DIR:-"${SPARK_HOME}"/conf}" - if [ -f "${SPARK_CONF_DIR}/spark-env.sh" ]; then + SPARK_ENV_SH="${SPARK_CONF_DIR}/${SPARK_ENV_SH}" + if [[ -f "${SPARK_ENV_SH}" ]]; then # Promote all variable declarations to environment (exported) variables set -a - . "${SPARK_CONF_DIR}/spark-env.sh" + . ${SPARK_ENV_SH} set +a fi fi @@ -42,19 +44,22 @@ fi # Setting SPARK_SCALA_VERSION if not already set. if [ -z "$SPARK_SCALA_VERSION" ]; then + SCALA_VERSION_1=2.12 + SCALA_VERSION_2=2.11 - ASSEMBLY_DIR2="${SPARK_HOME}/assembly/target/scala-2.11" - ASSEMBLY_DIR1="${SPARK_HOME}/assembly/target/scala-2.12" - - if [[ -d "$ASSEMBLY_DIR2" && -d "$ASSEMBLY_DIR1" ]]; then - echo -e "Presence of build for multiple Scala versions detected." 1>&2 - echo -e 'Either clean one of them or, export SPARK_SCALA_VERSION in spark-env.sh.' 1>&2 + ASSEMBLY_DIR_1="${SPARK_HOME}/assembly/target/scala-${SCALA_VERSION_1}" + ASSEMBLY_DIR_2="${SPARK_HOME}/assembly/target/scala-${SCALA_VERSION_2}" + ENV_VARIABLE_DOC="https://spark.apache.org/docs/latest/configuration.html#environment-variables" + if [[ -d "$ASSEMBLY_DIR_1" && -d "$ASSEMBLY_DIR_2" ]]; then + echo "Presence of build for multiple Scala versions detected ($ASSEMBLY_DIR_1 and $ASSEMBLY_DIR_2)." 1>&2 + echo "Remove one of them or, export SPARK_SCALA_VERSION=$SCALA_VERSION_1 in ${SPARK_ENV_SH}." 1>&2 + echo "Visit ${ENV_VARIABLE_DOC} for more details about setting environment variables in spark-env.sh." 
1>&2 exit 1 fi - if [ -d "$ASSEMBLY_DIR2" ]; then - export SPARK_SCALA_VERSION="2.11" + if [[ -d "$ASSEMBLY_DIR_1" ]]; then + export SPARK_SCALA_VERSION=${SCALA_VERSION_1} else - export SPARK_SCALA_VERSION="2.12" + export SPARK_SCALA_VERSION=${SCALA_VERSION_2} fi fi diff --git a/build/mvn b/build/mvn index 5b2b3c8351114..8931f4f9ffabe 100755 --- a/build/mvn +++ b/build/mvn @@ -112,7 +112,8 @@ install_zinc() { # the build/ folder install_scala() { # determine the Scala version used in Spark - local scala_version=`grep "scala.version" "${_DIR}/../pom.xml" | head -n1 | awk -F '[<>]' '{print $3}'` + local scala_binary_version=`grep "scala.binary.version" "${_DIR}/../pom.xml" | head -n1 | awk -F '[<>]' '{print $3}'` + local scala_version=`grep "scala.version" "${_DIR}/../pom.xml" | grep ${scala_binary_version} | head -n1 | awk -F '[<>]' '{print $3}'` local scala_bin="${_DIR}/scala-${scala_version}/bin/scala" local TYPESAFE_MIRROR=${TYPESAFE_MIRROR:-https://downloads.lightbend.com} diff --git a/common/kvstore/pom.xml b/common/kvstore/pom.xml index 23a0f49206909..f042a12fda3d2 100644 --- a/common/kvstore/pom.xml +++ b/common/kvstore/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-kvstore_2.11 + spark-kvstore_2.12 jar Spark Project Local DB http://spark.apache.org/ diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java index 448ba78bd23d0..55b4754a6c4ec 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java @@ -195,6 +195,7 @@ public synchronized void close() throws IOException { * when Scala wrappers are used, this makes sure that, hopefully, the JNI resources held by * the iterator will eventually be released. 
*/ + @SuppressWarnings("deprecation") @Override protected void finalize() throws Throwable { db.closeIterator(this); diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml index 41fcbf0589499..56d01fa0e8b3d 100644 --- a/common/network-common/pom.xml +++ b/common/network-common/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-network-common_2.11 + spark-network-common_2.12 jar Spark Project Networking http://spark.apache.org/ diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java index 7b28a9a969486..a7afbfa8621c8 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java @@ -33,7 +33,7 @@ public ChunkFetchFailure(StreamChunkId streamChunkId, String errorString) { } @Override - public Type type() { return Type.ChunkFetchFailure; } + public Message.Type type() { return Type.ChunkFetchFailure; } @Override public int encodedLength() { diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java index 26d063feb5fe3..fe54fcc50dc86 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java @@ -32,7 +32,7 @@ public ChunkFetchRequest(StreamChunkId streamChunkId) { } @Override - public Type type() { return Type.ChunkFetchRequest; } + public Message.Type type() { return Type.ChunkFetchRequest; } @Override public int encodedLength() { diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java index 94c2ac9b20e43..d5c9a9b3202fb 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java @@ -39,7 +39,7 @@ public ChunkFetchSuccess(StreamChunkId streamChunkId, ManagedBuffer buffer) { } @Override - public Type type() { return Type.ChunkFetchSuccess; } + public Message.Type type() { return Type.ChunkFetchSuccess; } @Override public int encodedLength() { diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java index f7ffb1bd49bb6..1632fb9e03687 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java @@ -34,7 +34,7 @@ public OneWayMessage(ManagedBuffer body) { } @Override - public Type type() { return Type.OneWayMessage; } + public Message.Type type() { return Type.OneWayMessage; } @Override public int encodedLength() { diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java index a76624ef5dc96..61061903de23f 100644 --- 
a/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java @@ -31,7 +31,7 @@ public RpcFailure(long requestId, String errorString) { } @Override - public Type type() { return Type.RpcFailure; } + public Message.Type type() { return Type.RpcFailure; } @Override public int encodedLength() { diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java index 2b30920f0598d..cc1bb95d2d566 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java @@ -38,7 +38,7 @@ public RpcRequest(long requestId, ManagedBuffer message) { } @Override - public Type type() { return Type.RpcRequest; } + public Message.Type type() { return Type.RpcRequest; } @Override public int encodedLength() { diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java index d73014ecd8506..c03291e9c0b23 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java @@ -33,7 +33,7 @@ public RpcResponse(long requestId, ManagedBuffer message) { } @Override - public Type type() { return Type.RpcResponse; } + public Message.Type type() { return Type.RpcResponse; } @Override public int encodedLength() { diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java index 258ef81c6783d..68fcfa7748611 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java @@ -33,7 +33,7 @@ public StreamFailure(String streamId, String error) { } @Override - public Type type() { return Type.StreamFailure; } + public Message.Type type() { return Type.StreamFailure; } @Override public int encodedLength() { diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java index dc183c043ed9a..1b135af752bd8 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java @@ -34,7 +34,7 @@ public StreamRequest(String streamId) { } @Override - public Type type() { return Type.StreamRequest; } + public Message.Type type() { return Type.StreamRequest; } @Override public int encodedLength() { diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java index 50b811604b84b..568108c4fe5e8 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java @@ -40,7 +40,7 @@ public StreamResponse(String streamId, long byteCount, ManagedBuffer buffer) { } 
@Override - public Type type() { return Type.StreamResponse; } + public Message.Type type() { return Type.StreamResponse; } @Override public int encodedLength() { diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/UploadStream.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/UploadStream.java index fa1d26e76b852..7d21151e01074 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/UploadStream.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/UploadStream.java @@ -52,7 +52,7 @@ private UploadStream(long requestId, ManagedBuffer meta, long bodyByteCount) { } @Override - public Type type() { return Type.UploadStream; } + public Message.Type type() { return Type.UploadStream; } @Override public int encodedLength() { diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java index 7331c2b481fb1..1b03300d948e2 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java @@ -23,6 +23,7 @@ import org.apache.spark.network.buffer.NettyManagedBuffer; import org.apache.spark.network.protocol.Encoders; import org.apache.spark.network.protocol.AbstractMessage; +import org.apache.spark.network.protocol.Message; /** * Encodes a Sasl-related message which is attempting to authenticate using some credentials tagged @@ -46,7 +47,7 @@ class SaslMessage extends AbstractMessage { } @Override - public Type type() { return Type.User; } + public Message.Type type() { return Type.User; } @Override public int encodedLength() { diff --git a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java index 1f4d75c7e2ec5..1c0aa4da27ff9 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java @@ -371,23 +371,33 @@ private void assertErrorsContain(Set errors, Set contains) { private void assertErrorAndClosed(RpcResult result, String expectedError) { assertTrue("unexpected success: " + result.successMessages, result.successMessages.isEmpty()); - // we expect 1 additional error, which contains *either* "closed" or "Connection reset" Set errors = result.errorMessages; assertEquals("Expected 2 errors, got " + errors.size() + "errors: " + errors, 2, errors.size()); + // We expect 1 additional error due to closed connection and here are possible keywords in the + // error message. 
+ Set possibleClosedErrors = Sets.newHashSet( + "closed", + "Connection reset", + "java.nio.channels.ClosedChannelException", + "java.io.IOException: Broken pipe" + ); Set containsAndClosed = Sets.newHashSet(expectedError); - containsAndClosed.add("closed"); - containsAndClosed.add("Connection reset"); + containsAndClosed.addAll(possibleClosedErrors); Pair, Set> r = checkErrorsContain(errors, containsAndClosed); - Set errorsNotFound = r.getRight(); - assertEquals(1, errorsNotFound.size()); - String err = errorsNotFound.iterator().next(); - assertTrue(err.equals("closed") || err.equals("Connection reset")); + assertTrue("Got a non-empty set " + r.getLeft(), r.getLeft().isEmpty()); - assertTrue(r.getLeft().isEmpty()); + Set errorsNotFound = r.getRight(); + assertEquals( + "The size of " + errorsNotFound + " was not " + (possibleClosedErrors.size() - 1), + possibleClosedErrors.size() - 1, + errorsNotFound.size()); + for (String err: errorsNotFound) { + assertTrue("Found a wrong error " + err, containsAndClosed.contains(err)); + } } private Pair, Set> checkErrorsContain( diff --git a/common/network-shuffle/pom.xml b/common/network-shuffle/pom.xml index ff717057bb25d..a6d99813a8501 100644 --- a/common/network-shuffle/pom.xml +++ b/common/network-shuffle/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-network-shuffle_2.11 + spark-network-shuffle_2.12 jar Spark Project Shuffle Streaming Service http://spark.apache.org/ diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java index f309dda8afca6..6bf3da94030d4 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java @@ -101,7 +101,7 @@ void createAndStart(String[] blockIds, BlockFetchingListener listener) public RetryingBlockFetcher( TransportConf conf, - BlockFetchStarter fetchStarter, + RetryingBlockFetcher.BlockFetchStarter fetchStarter, String[] blockIds, BlockFetchingListener listener) { this.fetchStarter = fetchStarter; diff --git a/common/network-yarn/pom.xml b/common/network-yarn/pom.xml index a1cf761d12d8b..55cdc3140aa08 100644 --- a/common/network-yarn/pom.xml +++ b/common/network-yarn/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-network-yarn_2.11 + spark-network-yarn_2.12 jar Spark Project YARN Shuffle Service http://spark.apache.org/ diff --git a/common/sketch/pom.xml b/common/sketch/pom.xml index adbbcb1cb3040..3c3c0d2d96a1c 100644 --- a/common/sketch/pom.xml +++ b/common/sketch/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-sketch_2.11 + spark-sketch_2.12 jar Spark Project Sketch http://spark.apache.org/ diff --git a/common/tags/pom.xml b/common/tags/pom.xml index f6627beabe84b..883b73a69c9de 100644 --- a/common/tags/pom.xml +++ b/common/tags/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-tags_2.11 + spark-tags_2.12 jar Spark Project Tags http://spark.apache.org/ diff --git a/common/tags/src/main/java/org/apache/spark/annotation/DeveloperApi.java 
b/common/tags/src/main/java/org/apache/spark/annotation/DeveloperApi.java index 0ecef6db0e039..890f2faca28b0 100644 --- a/common/tags/src/main/java/org/apache/spark/annotation/DeveloperApi.java +++ b/common/tags/src/main/java/org/apache/spark/annotation/DeveloperApi.java @@ -29,6 +29,7 @@ * of the known issue that Scaladoc displays only either the annotation or the comment, whichever * comes first. */ +@Documented @Retention(RetentionPolicy.RUNTIME) @Target({ElementType.TYPE, ElementType.FIELD, ElementType.METHOD, ElementType.PARAMETER, ElementType.CONSTRUCTOR, ElementType.LOCAL_VARIABLE, ElementType.PACKAGE}) diff --git a/common/tags/src/main/java/org/apache/spark/annotation/Evolving.java b/common/tags/src/main/java/org/apache/spark/annotation/Evolving.java new file mode 100644 index 0000000000000..87e8948f204ff --- /dev/null +++ b/common/tags/src/main/java/org/apache/spark/annotation/Evolving.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.annotation; + +import java.lang.annotation.*; + +/** + * APIs that are meant to evolve towards becoming stable APIs, but are not stable APIs yet. + * Evolving interfaces can change from one feature release to another release (i.e. 2.1 to 2.2). + */ +@Documented +@Retention(RetentionPolicy.RUNTIME) +@Target({ElementType.TYPE, ElementType.FIELD, ElementType.METHOD, ElementType.PARAMETER, + ElementType.CONSTRUCTOR, ElementType.LOCAL_VARIABLE, ElementType.PACKAGE}) +public @interface Evolving {} diff --git a/common/tags/src/main/java/org/apache/spark/annotation/Experimental.java b/common/tags/src/main/java/org/apache/spark/annotation/Experimental.java index ff8120291455f..96875920cd9c3 100644 --- a/common/tags/src/main/java/org/apache/spark/annotation/Experimental.java +++ b/common/tags/src/main/java/org/apache/spark/annotation/Experimental.java @@ -30,6 +30,7 @@ * of the known issue that Scaladoc displays only either the annotation or the comment, whichever * comes first. */ +@Documented @Retention(RetentionPolicy.RUNTIME) @Target({ElementType.TYPE, ElementType.FIELD, ElementType.METHOD, ElementType.PARAMETER, ElementType.CONSTRUCTOR, ElementType.LOCAL_VARIABLE, ElementType.PACKAGE}) diff --git a/common/tags/src/main/java/org/apache/spark/annotation/InterfaceStability.java b/common/tags/src/main/java/org/apache/spark/annotation/InterfaceStability.java deleted file mode 100644 index 323098f69c6e1..0000000000000 --- a/common/tags/src/main/java/org/apache/spark/annotation/InterfaceStability.java +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.annotation; - -import java.lang.annotation.Documented; - -/** - * Annotation to inform users of how much to rely on a particular package, - * class or method not changing over time. - */ -public class InterfaceStability { - - /** - * Stable APIs that retain source and binary compatibility within a major release. - * These interfaces can change from one major release to another major release - * (e.g. from 1.0 to 2.0). - */ - @Documented - public @interface Stable {}; - - /** - * APIs that are meant to evolve towards becoming stable APIs, but are not stable APIs yet. - * Evolving interfaces can change from one feature release to another release (i.e. 2.1 to 2.2). - */ - @Documented - public @interface Evolving {}; - - /** - * Unstable APIs, with no guarantee on stability. - * Classes that are unannotated are considered Unstable. - */ - @Documented - public @interface Unstable {}; -} diff --git a/common/tags/src/main/java/org/apache/spark/annotation/Private.java b/common/tags/src/main/java/org/apache/spark/annotation/Private.java index 9082fcf0c84bc..a460d608ae16b 100644 --- a/common/tags/src/main/java/org/apache/spark/annotation/Private.java +++ b/common/tags/src/main/java/org/apache/spark/annotation/Private.java @@ -17,10 +17,7 @@ package org.apache.spark.annotation; -import java.lang.annotation.ElementType; -import java.lang.annotation.Retention; -import java.lang.annotation.RetentionPolicy; -import java.lang.annotation.Target; +import java.lang.annotation.*; /** * A class that is considered private to the internals of Spark -- there is a high-likelihood @@ -35,6 +32,7 @@ * of the known issue that Scaladoc displays only either the annotation or the comment, whichever * comes first. */ +@Documented @Retention(RetentionPolicy.RUNTIME) @Target({ElementType.TYPE, ElementType.FIELD, ElementType.METHOD, ElementType.PARAMETER, ElementType.CONSTRUCTOR, ElementType.LOCAL_VARIABLE, ElementType.PACKAGE}) diff --git a/common/tags/src/main/java/org/apache/spark/annotation/Stable.java b/common/tags/src/main/java/org/apache/spark/annotation/Stable.java new file mode 100644 index 0000000000000..b198bfbe91e10 --- /dev/null +++ b/common/tags/src/main/java/org/apache/spark/annotation/Stable.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.annotation; + +import java.lang.annotation.*; + +/** + * Stable APIs that retain source and binary compatibility within a major release. + * These interfaces can change from one major release to another major release + * (e.g. from 1.0 to 2.0). + */ +@Documented +@Retention(RetentionPolicy.RUNTIME) +@Target({ElementType.TYPE, ElementType.FIELD, ElementType.METHOD, ElementType.PARAMETER, + ElementType.CONSTRUCTOR, ElementType.LOCAL_VARIABLE, ElementType.PACKAGE}) +public @interface Stable {} diff --git a/common/tags/src/main/java/org/apache/spark/annotation/Unstable.java b/common/tags/src/main/java/org/apache/spark/annotation/Unstable.java new file mode 100644 index 0000000000000..88ee72125b23f --- /dev/null +++ b/common/tags/src/main/java/org/apache/spark/annotation/Unstable.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.annotation; + +import java.lang.annotation.*; + +/** + * Unstable APIs, with no guarantee on stability. + * Classes that are unannotated are considered Unstable. 
+ */ +@Documented +@Retention(RetentionPolicy.RUNTIME) +@Target({ElementType.TYPE, ElementType.FIELD, ElementType.METHOD, ElementType.PARAMETER, + ElementType.CONSTRUCTOR, ElementType.LOCAL_VARIABLE, ElementType.PACKAGE}) +public @interface Unstable {} diff --git a/common/unsafe/pom.xml b/common/unsafe/pom.xml index 62c493a5e1ed8..93a4f67fd23f2 100644 --- a/common/unsafe/pom.xml +++ b/common/unsafe/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-unsafe_2.11 + spark-unsafe_2.12 jar Spark Project Unsafe http://spark.apache.org/ @@ -89,6 +89,11 @@ commons-lang3 test + + org.apache.commons + commons-text + test + target/scala-${scala.binary.version}/classes diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java index aca6fca00c48b..4563efcfcf474 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java @@ -19,10 +19,10 @@ import java.lang.reflect.Constructor; import java.lang.reflect.Field; +import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.nio.ByteBuffer; -import sun.misc.Cleaner; import sun.misc.Unsafe; public final class Platform { @@ -67,6 +67,60 @@ public final class Platform { unaligned = _unaligned; } + // Access fields and constructors once and store them, for performance: + + private static final Constructor DBB_CONSTRUCTOR; + private static final Field DBB_CLEANER_FIELD; + static { + try { + Class cls = Class.forName("java.nio.DirectByteBuffer"); + Constructor constructor = cls.getDeclaredConstructor(Long.TYPE, Integer.TYPE); + constructor.setAccessible(true); + Field cleanerField = cls.getDeclaredField("cleaner"); + cleanerField.setAccessible(true); + DBB_CONSTRUCTOR = constructor; + DBB_CLEANER_FIELD = cleanerField; + } catch (ClassNotFoundException | NoSuchMethodException | NoSuchFieldException e) { + throw new IllegalStateException(e); + } + } + + private static final Method CLEANER_CREATE_METHOD; + static { + // The implementation of Cleaner changed from JDK 8 to 9 + // Split java.version on non-digit chars: + int majorVersion = Integer.parseInt(System.getProperty("java.version").split("\\D+")[0]); + String cleanerClassName; + if (majorVersion < 9) { + cleanerClassName = "sun.misc.Cleaner"; + } else { + cleanerClassName = "jdk.internal.ref.Cleaner"; + } + try { + Class cleanerClass = Class.forName(cleanerClassName); + Method createMethod = cleanerClass.getMethod("create", Object.class, Runnable.class); + // Accessing jdk.internal.ref.Cleaner should actually fail by default in JDK 9+, + // unfortunately, unless the user has allowed access with something like + // --add-opens java.base/java.lang=ALL-UNNAMED If not, we can't really use the Cleaner + // hack below. It doesn't break, just means the user might run into the default JVM limit + // on off-heap memory and increase it or set the flag above. This tests whether it's + // available: + try { + createMethod.invoke(null, null, null); + } catch (IllegalAccessException e) { + // Don't throw an exception, but can't log here? 
+ createMethod = null; + } catch (InvocationTargetException ite) { + // shouldn't happen; report it + throw new IllegalStateException(ite); + } + CLEANER_CREATE_METHOD = createMethod; + } catch (ClassNotFoundException | NoSuchMethodException e) { + throw new IllegalStateException(e); + } + + } + /** * @return true when running JVM is having sun's Unsafe package available in it and underlying * system having unaligned-access capability. @@ -120,6 +174,11 @@ public static float getFloat(Object object, long offset) { } public static void putFloat(Object object, long offset, float value) { + if (Float.isNaN(value)) { + value = Float.NaN; + } else if (value == -0.0f) { + value = 0.0f; + } _UNSAFE.putFloat(object, offset, value); } @@ -128,6 +187,11 @@ public static double getDouble(Object object, long offset) { } public static void putDouble(Object object, long offset, double value) { + if (Double.isNaN(value)) { + value = Double.NaN; + } else if (value == -0.0d) { + value = 0.0d; + } _UNSAFE.putDouble(object, offset, value); } @@ -159,18 +223,18 @@ public static long reallocateMemory(long address, long oldSize, long newSize) { * MaxDirectMemorySize limit (the default limit is too low and we do not want to require users * to increase it). */ - @SuppressWarnings("unchecked") public static ByteBuffer allocateDirectBuffer(int size) { try { - Class cls = Class.forName("java.nio.DirectByteBuffer"); - Constructor constructor = cls.getDeclaredConstructor(Long.TYPE, Integer.TYPE); - constructor.setAccessible(true); - Field cleanerField = cls.getDeclaredField("cleaner"); - cleanerField.setAccessible(true); long memory = allocateMemory(size); - ByteBuffer buffer = (ByteBuffer) constructor.newInstance(memory, size); - Cleaner cleaner = Cleaner.create(buffer, () -> freeMemory(memory)); - cleanerField.set(buffer, cleaner); + ByteBuffer buffer = (ByteBuffer) DBB_CONSTRUCTOR.newInstance(memory, size); + if (CLEANER_CREATE_METHOD != null) { + try { + DBB_CLEANER_FIELD.set(buffer, + CLEANER_CREATE_METHOD.invoke(null, buffer, (Runnable) () -> freeMemory(memory))); + } catch (IllegalAccessException | InvocationTargetException e) { + throw new IllegalStateException(e); + } + } return buffer; } catch (Exception e) { throwException(e); diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/UnsafeAlignedOffset.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/UnsafeAlignedOffset.java index be62e40412f83..546e8780a6606 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/UnsafeAlignedOffset.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/UnsafeAlignedOffset.java @@ -39,7 +39,9 @@ public static int getSize(Object object, long offset) { case 8: return (int)Platform.getLong(object, offset); default: + // checkstyle.off: RegexpSinglelineJava throw new AssertionError("Illegal UAO_SIZE"); + // checkstyle.on: RegexpSinglelineJava } } @@ -52,7 +54,9 @@ public static void putSize(Object object, long offset, int value) { Platform.putLong(object, offset, value); break; default: + // checkstyle.off: RegexpSinglelineJava throw new AssertionError("Illegal UAO_SIZE"); + // checkstyle.on: RegexpSinglelineJava } } } diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java index 3ad9ac7b4de9c..2474081dad5c9 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java @@ -157,4 
+157,22 @@ public void heapMemoryReuse() { Assert.assertEquals(onheap4.size(), 1024 * 1024 + 7); Assert.assertEquals(obj3, onheap4.getBaseObject()); } + + @Test + // SPARK-26021 + public void writeMinusZeroIsReplacedWithZero() { + byte[] doubleBytes = new byte[Double.BYTES]; + byte[] floatBytes = new byte[Float.BYTES]; + Platform.putDouble(doubleBytes, Platform.BYTE_ARRAY_OFFSET, -0.0d); + Platform.putFloat(floatBytes, Platform.BYTE_ARRAY_OFFSET, -0.0f); + + byte[] doubleBytes2 = new byte[Double.BYTES]; + byte[] floatBytes2 = new byte[Float.BYTES]; + Platform.putDouble(doubleBytes, Platform.BYTE_ARRAY_OFFSET, 0.0d); + Platform.putFloat(floatBytes, Platform.BYTE_ARRAY_OFFSET, 0.0f); + + // Make sure the bytes we write from 0.0 and -0.0 are same. + Assert.assertArrayEquals(doubleBytes, doubleBytes2); + Assert.assertArrayEquals(floatBytes, floatBytes2); + } } diff --git a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala index 9656951810daf..fdb81a06d41c9 100644 --- a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala +++ b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.unsafe.types -import org.apache.commons.lang3.StringUtils +import org.apache.commons.text.similarity.LevenshteinDistance import org.scalacheck.{Arbitrary, Gen} import org.scalatest.prop.GeneratorDrivenPropertyChecks // scalastyle:off @@ -232,7 +232,7 @@ class UTF8StringPropertyCheckSuite extends FunSuite with GeneratorDrivenProperty test("levenshteinDistance") { forAll { (one: String, another: String) => assert(toUTF8(one).levenshteinDistance(toUTF8(another)) === - StringUtils.getLevenshteinDistance(one, another)) + LevenshteinDistance.getDefaultInstance.apply(one, another)) } } diff --git a/core/benchmarks/KryoSerializerBenchmark-results.txt b/core/benchmarks/KryoSerializerBenchmark-results.txt new file mode 100644 index 0000000000000..c3ce336d93241 --- /dev/null +++ b/core/benchmarks/KryoSerializerBenchmark-results.txt @@ -0,0 +1,12 @@ +================================================================================================ +Benchmark KryoPool vs "pool of 1" +================================================================================================ + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_131-b11 on Mac OS X 10.14 +Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz +Benchmark KryoPool vs "pool of 1": Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +KryoPool:true 2682 / 3425 0.0 5364627.9 1.0X +KryoPool:false 8176 / 9292 0.0 16351252.2 0.3X + + diff --git a/core/pom.xml b/core/pom.xml index 1e9f428d86860..c374862a5d64e 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -20,12 +20,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../pom.xml - spark-core_2.11 + spark-core_2.12 core diff --git a/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java b/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java index f6d1288cb263d..92bf0ecc1b5cb 100644 --- a/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java +++ b/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java @@ -27,7 +27,7 @@ * to read a file to avoid extra copy of data between Java and * native memory 
which happens when using {@link java.io.BufferedInputStream}. * Unfortunately, this is not something already available in JDK, - * {@link sun.nio.ch.ChannelInputStream} supports reading a file using nio, + * {@code sun.nio.ch.ChannelInputStream} supports reading a file using nio, * but does not support buffering. */ public final class NioBufferedFileInputStream extends InputStream { @@ -130,6 +130,7 @@ public synchronized void close() throws IOException { StorageUtils.dispose(byteBuffer); } + @SuppressWarnings("deprecation") @Override protected void finalize() throws IOException { close(); diff --git a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java index 115e1fbb79a2e..4bfd2d358f36f 100644 --- a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java +++ b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java @@ -83,10 +83,10 @@ public void spill() throws IOException { public abstract long spill(long size, MemoryConsumer trigger) throws IOException; /** - * Allocates a LongArray of `size`. Note that this method may throw `OutOfMemoryError` if Spark - * doesn't have enough memory for this allocation, or throw `TooLargePageException` if this - * `LongArray` is too large to fit in a single page. The caller side should take care of these - * two exceptions, or make sure the `size` is small enough that won't trigger exceptions. + * Allocates a LongArray of `size`. Note that this method may throw `SparkOutOfMemoryError` + * if Spark doesn't have enough memory for this allocation, or throw `TooLargePageException` + * if this `LongArray` is too large to fit in a single page. The caller side should take care of + * these two exceptions, or make sure the `size` is small enough that won't trigger exceptions. * * @throws SparkOutOfMemoryError * @throws TooLargePageException @@ -111,7 +111,7 @@ public void freeArray(LongArray array) { /** * Allocate a memory block with at least `required` bytes. 
* - * @throws OutOfMemoryError + * @throws SparkOutOfMemoryError */ protected MemoryBlock allocatePage(long required) { MemoryBlock page = taskMemoryManager.allocatePage(Math.max(pageSize, required), this); @@ -154,7 +154,9 @@ private void throwOom(final MemoryBlock page, final long required) { taskMemoryManager.freePage(page, this); } taskMemoryManager.showMemoryUsage(); + // checkstyle.off: RegexpSinglelineJava throw new SparkOutOfMemoryError("Unable to acquire " + required + " bytes of memory, got " + got); + // checkstyle.on: RegexpSinglelineJava } } diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index d07faf1da1248..28b646ba3c951 100644 --- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -194,8 +194,10 @@ public long acquireExecutionMemory(long required, MemoryConsumer consumer) { throw new RuntimeException(e.getMessage()); } catch (IOException e) { logger.error("error while calling spill() on " + c, e); + // checkstyle.off: RegexpSinglelineJava throw new SparkOutOfMemoryError("error while calling spill() on " + c + " : " + e.getMessage()); + // checkstyle.on: RegexpSinglelineJava } } } @@ -215,8 +217,10 @@ public long acquireExecutionMemory(long required, MemoryConsumer consumer) { throw new RuntimeException(e.getMessage()); } catch (IOException e) { logger.error("error while calling spill() on " + consumer, e); + // checkstyle.off: RegexpSinglelineJava throw new SparkOutOfMemoryError("error while calling spill() on " + consumer + " : " + e.getMessage()); + // checkstyle.on: RegexpSinglelineJava } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index b020a6d99247b..fda33cd8293d5 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -37,12 +37,11 @@ import org.apache.spark.Partitioner; import org.apache.spark.ShuffleDependency; import org.apache.spark.SparkConf; -import org.apache.spark.TaskContext; -import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.scheduler.MapStatus; import org.apache.spark.scheduler.MapStatus$; import org.apache.spark.serializer.Serializer; import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; import org.apache.spark.shuffle.IndexShuffleBlockResolver; import org.apache.spark.shuffle.ShuffleWriter; import org.apache.spark.storage.*; @@ -79,7 +78,7 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { private final int numPartitions; private final BlockManager blockManager; private final Partitioner partitioner; - private final ShuffleWriteMetrics writeMetrics; + private final ShuffleWriteMetricsReporter writeMetrics; private final int shuffleId; private final int mapId; private final Serializer serializer; @@ -103,8 +102,8 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { IndexShuffleBlockResolver shuffleBlockResolver, BypassMergeSortShuffleHandle handle, int mapId, - TaskContext taskContext, - SparkConf conf) { + SparkConf conf, + ShuffleWriteMetricsReporter writeMetrics) { // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided this.fileBufferSize = 
(int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true); @@ -114,7 +113,7 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { this.shuffleId = dep.shuffleId(); this.partitioner = dep.partitioner(); this.numPartitions = partitioner.numPartitions(); - this.writeMetrics = taskContext.taskMetrics().shuffleWriteMetrics(); + this.writeMetrics = writeMetrics; this.serializer = dep.serializer(); this.shuffleBlockResolver = shuffleBlockResolver; } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java index 1c0d664afb138..6ee9d5f0eec3b 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java @@ -38,6 +38,7 @@ import org.apache.spark.memory.TooLargePageException; import org.apache.spark.serializer.DummySerializerInstance; import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; import org.apache.spark.storage.BlockManager; import org.apache.spark.storage.DiskBlockObjectWriter; import org.apache.spark.storage.FileSegment; @@ -75,7 +76,7 @@ final class ShuffleExternalSorter extends MemoryConsumer { private final TaskMemoryManager taskMemoryManager; private final BlockManager blockManager; private final TaskContext taskContext; - private final ShuffleWriteMetrics writeMetrics; + private final ShuffleWriteMetricsReporter writeMetrics; /** * Force this sorter to spill when there are this many elements in memory. @@ -113,7 +114,7 @@ final class ShuffleExternalSorter extends MemoryConsumer { int initialSize, int numPartitions, SparkConf conf, - ShuffleWriteMetrics writeMetrics) { + ShuffleWriteMetricsReporter writeMetrics) { super(memoryManager, (int) Math.min(PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, memoryManager.pageSizeBytes()), memoryManager.getTungstenMemoryMode()); @@ -144,7 +145,7 @@ final class ShuffleExternalSorter extends MemoryConsumer { */ private void writeSortedFile(boolean isLastFile) { - final ShuffleWriteMetrics writeMetricsToUse; + final ShuffleWriteMetricsReporter writeMetricsToUse; if (isLastFile) { // We're writing the final non-spill file, so we _do_ want to count this as shuffle bytes. @@ -241,9 +242,14 @@ private void writeSortedFile(boolean isLastFile) { // // Note that we intentionally ignore the value of `writeMetricsToUse.shuffleWriteTime()`. // Consistent with ExternalSorter, we do not count this IO towards shuffle write time. - // This means that this IO time is not accounted for anywhere; SPARK-3577 will fix this. - writeMetrics.incRecordsWritten(writeMetricsToUse.recordsWritten()); - taskContext.taskMetrics().incDiskBytesSpilled(writeMetricsToUse.bytesWritten()); + // SPARK-3577 tracks the spill time separately. + + // This is guaranteed to be a ShuffleWriteMetrics based on the if check in the beginning + // of this method. 
+ writeMetrics.incRecordsWritten( + ((ShuffleWriteMetrics)writeMetricsToUse).recordsWritten()); + taskContext.taskMetrics().incDiskBytesSpilled( + ((ShuffleWriteMetrics)writeMetricsToUse).bytesWritten()); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 4839d04522f10..4b0c74341551e 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -37,7 +37,6 @@ import org.apache.spark.*; import org.apache.spark.annotation.Private; -import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.io.CompressionCodec; import org.apache.spark.io.CompressionCodec$; import org.apache.spark.io.NioBufferedFileInputStream; @@ -47,6 +46,7 @@ import org.apache.spark.network.util.LimitedInputStream; import org.apache.spark.scheduler.MapStatus; import org.apache.spark.scheduler.MapStatus$; +import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; import org.apache.spark.serializer.SerializationStream; import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.shuffle.IndexShuffleBlockResolver; @@ -73,7 +73,7 @@ public class UnsafeShuffleWriter extends ShuffleWriter { private final TaskMemoryManager memoryManager; private final SerializerInstance serializer; private final Partitioner partitioner; - private final ShuffleWriteMetrics writeMetrics; + private final ShuffleWriteMetricsReporter writeMetrics; private final int shuffleId; private final int mapId; private final TaskContext taskContext; @@ -122,7 +122,8 @@ public UnsafeShuffleWriter( SerializedShuffleHandle handle, int mapId, TaskContext taskContext, - SparkConf sparkConf) throws IOException { + SparkConf sparkConf, + ShuffleWriteMetricsReporter writeMetrics) throws IOException { final int numPartitions = handle.dependency().partitioner().numPartitions(); if (numPartitions > SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE()) { throw new IllegalArgumentException( @@ -138,7 +139,7 @@ public UnsafeShuffleWriter( this.shuffleId = dep.shuffleId(); this.serializer = dep.serializer().newInstance(); this.partitioner = dep.partitioner(); - this.writeMetrics = taskContext.taskMetrics().shuffleWriteMetrics(); + this.writeMetrics = writeMetrics; this.taskContext = taskContext; this.sparkConf = sparkConf; this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true); diff --git a/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java b/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java index 5d0555a8c28e1..fcba3b73445c9 100644 --- a/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java +++ b/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java @@ -21,7 +21,7 @@ import java.io.OutputStream; import org.apache.spark.annotation.Private; -import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; /** * Intercepts write calls and tracks total time spent writing in order to update shuffle write @@ -30,10 +30,11 @@ @Private public final class TimeTrackingOutputStream extends OutputStream { - private final ShuffleWriteMetrics writeMetrics; + private final ShuffleWriteMetricsReporter writeMetrics; private final OutputStream outputStream; - public TimeTrackingOutputStream(ShuffleWriteMetrics writeMetrics, OutputStream outputStream) { + 
public TimeTrackingOutputStream( + ShuffleWriteMetricsReporter writeMetrics, OutputStream outputStream) { this.writeMetrics = writeMetrics; this.outputStream = outputStream; } diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 9b6cbab38cbcc..a4e88598f7607 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -31,6 +31,7 @@ import org.apache.spark.SparkEnv; import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.memory.MemoryConsumer; +import org.apache.spark.memory.SparkOutOfMemoryError; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.serializer.SerializerManager; import org.apache.spark.storage.BlockManager; @@ -741,7 +742,7 @@ public boolean append(Object kbase, long koff, int klen, Object vbase, long voff if (numKeys >= growthThreshold && longArray.size() < MAX_CAPACITY) { try { growAndRehash(); - } catch (OutOfMemoryError oom) { + } catch (SparkOutOfMemoryError oom) { canGrowArray = false; } } @@ -757,7 +758,7 @@ public boolean append(Object kbase, long koff, int klen, Object vbase, long voff private boolean acquireNewPage(long required) { try { currentPage = allocatePage(required); - } catch (OutOfMemoryError e) { + } catch (SparkOutOfMemoryError e) { return false; } dataPages.add(currentPage); diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index 75690ae264838..1a9453a8b3e80 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -214,7 +214,9 @@ public boolean hasSpaceForAnotherRecord() { public void expandPointerArray(LongArray newArray) { if (newArray.size() < array.size()) { + // checkstyle.off: RegexpSinglelineJava throw new SparkOutOfMemoryError("Not enough memory to grow pointer array"); + // checkstyle.on: RegexpSinglelineJava } Platform.copyMemory( array.getBaseObject(), diff --git a/core/src/main/resources/org/apache/spark/ui/static/executorspage-template.html b/core/src/main/resources/org/apache/spark/ui/static/executorspage-template.html index 5c91304e49fd7..f2c17aef097a4 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/executorspage-template.html +++ b/core/src/main/resources/org/apache/spark/ui/static/executorspage-template.html @@ -16,10 +16,10 @@ --> diff --git a/core/src/main/resources/org/apache/spark/ui/static/utils.js b/core/src/main/resources/org/apache/spark/ui/static/utils.js index 4f63f6413d6de..deeafad4eb5f5 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/utils.js +++ b/core/src/main/resources/org/apache/spark/ui/static/utils.js @@ -18,7 +18,7 @@ // this function works exactly the same as UIUtils.formatDuration function formatDuration(milliseconds) { if (milliseconds < 100) { - return milliseconds + " ms"; + return parseInt(milliseconds).toFixed(1) + " ms"; } var seconds = milliseconds * 1.0 / 1000; if (seconds < 1) { @@ -74,3 +74,114 @@ function getTimeZone() { return new Date().toString().match(/\((.*)\)/)[1]; } } + +function formatLogsCells(execLogs, type) { + if (type !== 'display') return Object.keys(execLogs); + if (!execLogs) return; + var result = ''; + $.each(execLogs, function 
(logName, logUrl) { + result += '' + }); + return result; +} + +function getStandAloneAppId(cb) { + var words = document.baseURI.split('/'); + var ind = words.indexOf("proxy"); + if (ind > 0) { + var appId = words[ind + 1]; + cb(appId); + return; + } + ind = words.indexOf("history"); + if (ind > 0) { + var appId = words[ind + 1]; + cb(appId); + return; + } + // Looks like Web UI is running in standalone mode + // Let's get application-id using REST End Point + $.getJSON(location.origin + "/api/v1/applications", function(response, status, jqXHR) { + if (response && response.length > 0) { + var appId = response[0].id; + cb(appId); + return; + } + }); +} + +// This function is a helper function for sorting in datatable. +// When the data is in duration (e.g. 12ms 2s 2min 2h ) +// It will convert the string into integer for correct ordering +function ConvertDurationString(data) { + data = data.toString(); + var units = data.replace(/[\d\.]/g, '' ) + .replace(' ', '') + .toLowerCase(); + var multiplier = 1; + + switch(units) { + case 's': + multiplier = 1000; + break; + case 'min': + multiplier = 600000; + break; + case 'h': + multiplier = 3600000; + break; + default: + break; + } + return parseFloat(data) * multiplier; +} + +function createTemplateURI(appId, templateName) { + var words = document.baseURI.split('/'); + var ind = words.indexOf("proxy"); + if (ind > 0) { + var baseURI = words.slice(0, ind + 1).join('/') + '/' + appId + '/static/' + templateName + '-template.html'; + return baseURI; + } + ind = words.indexOf("history"); + if(ind > 0) { + var baseURI = words.slice(0, ind).join('/') + '/static/' + templateName + '-template.html'; + return baseURI; + } + return location.origin + "/static/" + templateName + "-template.html"; +} + +function setDataTableDefaults() { + $.extend($.fn.dataTable.defaults, { + stateSave: true, + lengthMenu: [[20, 40, 60, 100, -1], [20, 40, 60, 100, "All"]], + pageLength: 20 + }); +} + +function formatDate(date) { + if (date <= 0) return "-"; + else return date.split(".")[0].replace("T", " "); +} + +function createRESTEndPointForExecutorsPage(appId) { + var words = document.baseURI.split('/'); + var ind = words.indexOf("proxy"); + if (ind > 0) { + var appId = words[ind + 1]; + var newBaseURI = words.slice(0, ind + 2).join('/'); + return newBaseURI + "/api/v1/applications/" + appId + "/allexecutors" + } + ind = words.indexOf("history"); + if (ind > 0) { + var appId = words[ind + 1]; + var attemptId = words[ind + 2]; + var newBaseURI = words.slice(0, ind).join('/'); + if (isNaN(attemptId)) { + return newBaseURI + "/api/v1/applications/" + appId + "/allexecutors"; + } else { + return newBaseURI + "/api/v1/applications/" + appId + "/" + attemptId + "/allexecutors"; + } + } + return location.origin + "/api/v1/applications/" + appId + "/allexecutors"; +} diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui-dataTables.css b/core/src/main/resources/org/apache/spark/ui/static/webui-dataTables.css new file mode 100644 index 0000000000000..f6b4abed21e0d --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/webui-dataTables.css @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +table.dataTable thead .sorting_asc { background: url('images/sort_asc.png') no-repeat bottom right; } + +table.dataTable thead .sorting_desc { background: url('images/sort_desc.png') no-repeat bottom right; } \ No newline at end of file diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index 266eeec55576e..fe5bb25687af1 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -260,4 +260,105 @@ a.expandbutton { .paginate_button.active > a { color: #999999; text-decoration: underline; +} + +.title-table { + clear: left; + display: inline-block; +} + +.table-dataTable { + width: 100%; +} + +.container-fluid-div { + width: 200px; +} + +.scheduler-delay-checkbox-div { + width: 120px; +} + +.task-deserialization-time-checkbox-div { + width: 175px; +} + +.shuffle-read-blocked-time-checkbox-div { + width: 187px; +} + +.shuffle-remote-reads-checkbox-div { + width: 157px; +} + +.result-serialization-time-checkbox-div { + width: 171px; +} + +.getting-result-time-checkbox-div { + width: 141px; +} + +.peak-execution-memory-checkbox-div { + width: 170px; +} + +#active-tasks-table th { + border-top: 1px solid #dddddd; + border-bottom: 1px solid #dddddd; + border-right: 1px solid #dddddd; +} + +#active-tasks-table th:first-child { + border-left: 1px solid #dddddd; +} + +#accumulator-table th { + border-top: 1px solid #dddddd; + border-bottom: 1px solid #dddddd; + border-right: 1px solid #dddddd; +} + +#accumulator-table th:first-child { + border-left: 1px solid #dddddd; +} + +#summary-executor-table th { + border-top: 1px solid #dddddd; + border-bottom: 1px solid #dddddd; + border-right: 1px solid #dddddd; +} + +#summary-executor-table th:first-child { + border-left: 1px solid #dddddd; +} + +#summary-metrics-table th { + border-top: 1px solid #dddddd; + border-bottom: 1px solid #dddddd; + border-right: 1px solid #dddddd; +} + +#summary-metrics-table th:first-child { + border-left: 1px solid #dddddd; +} + +#summary-execs-table th { + border-top: 1px solid #dddddd; + border-bottom: 1px solid #dddddd; + border-right: 1px solid #dddddd; +} + +#summary-execs-table th:first-child { + border-left: 1px solid #dddddd; +} + +#active-executors-table th { + border-top: 1px solid #dddddd; + border-bottom: 1px solid #dddddd; + border-right: 1px solid #dddddd; +} + +#active-executors-table th:first-child { + border-left: 1px solid #dddddd; } \ No newline at end of file diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index e50020b07b0b9..c609161146b22 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -61,6 +61,7 @@ import org.apache.spark.storage._ import org.apache.spark.storage.BlockManagerMessages.TriggerThreadDump import org.apache.spark.ui.{ConsoleProgressBar, SparkUI} import org.apache.spark.util._ +import org.apache.spark.util.logging.DriverLogger /** * Main 
entry point for Spark functionality. A SparkContext represents the connection to a Spark @@ -208,6 +209,7 @@ class SparkContext(config: SparkConf) extends SafeLogging { private var _applicationId: String = _ private var _applicationAttemptId: Option[String] = None private var _eventLogger: Option[EventLoggingListener] = None + private var _driverLogger: Option[DriverLogger] = None private var _executorAllocationManager: Option[ExecutorAllocationManager] = None private var _cleaner: Option[ContextCleaner] = None private var _listenerBusStarted: Boolean = false @@ -377,6 +379,8 @@ class SparkContext(config: SparkConf) extends SafeLogging { throw new SparkException("An application name must be set in your configuration") } + _driverLogger = DriverLogger(_conf) + // log out spark.app.name in the Spark driver logs safeLogInfo("Submitted application", SafeArg.of("appName", appName)) @@ -504,7 +508,9 @@ class SparkContext(config: SparkConf) extends SafeLogging { _heartbeatReceiver.ask[Boolean](TaskSchedulerIsSet) // create and start the heartbeater for collecting memory metrics - _heartbeater = new Heartbeater(env.memoryManager, reportHeartBeat, "driver-heartbeater", + _heartbeater = new Heartbeater(env.memoryManager, + () => SparkContext.this.reportHeartBeat(), + "driver-heartbeater", conf.get(EXECUTOR_HEARTBEAT_INTERVAL)) _heartbeater.start() @@ -1911,6 +1917,9 @@ class SparkContext(config: SparkConf) extends SafeLogging { Utils.tryLogNonFatalError { postApplicationEnd() } + Utils.tryLogNonFatalError { + _driverLogger.foreach(_.stop()) + } Utils.tryLogNonFatalError { _ui.foreach(_.stop()) } @@ -2401,6 +2410,7 @@ class SparkContext(config: SparkConf) extends SafeLogging { // the cluster manager to get an application ID (in case the cluster manager provides one). listenerBus.post(SparkListenerApplicationStart(appName, Some(applicationId), startTime, sparkUser, applicationAttemptId, schedulerBackend.getDriverLogUrls)) + _driverLogger.foreach(_.startSync(_hadoopConfiguration)) } /** Post the application end event */ diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index dd087071113fe..c41d958347f0e 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -299,7 +299,7 @@ object SparkEnv extends Logging { // SparkConf, then one taking no arguments try { cls.getConstructor(classOf[SparkConf], java.lang.Boolean.TYPE) - .newInstance(conf, new java.lang.Boolean(isDriver)) + .newInstance(conf, java.lang.Boolean.valueOf(isDriver)) .asInstanceOf[T] } catch { case _: NoSuchMethodException => diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 80a4f84087466..50ed8d9bd3f68 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -952,7 +952,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * * The algorithm used is based on streamlib's implementation of "HyperLogLog in Practice: * Algorithmic Engineering of a State of The Art Cardinality Estimation Algorithm", available - * here. + * here. * * @param relativeSD Relative accuracy. Smaller values create counters that require more space. * It must be greater than 0.000017. 
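For context, a minimal sketch of how the relativeSD parameter documented above is typically used from Scala; the dataset and the 0.05 value are illustrative and not taken from this patch:

    import org.apache.spark.{SparkConf, SparkContext}

    // Hypothetical driver code; any SparkContext will do.
    val sc = new SparkContext(new SparkConf().setAppName("hll-sketch").setMaster("local[*]"))
    val pairs = sc.parallelize(Seq(("a", 1), ("a", 2), ("b", 3), ("b", 3)))

    // relativeSD must be greater than 0.000017; smaller values use more space per counter.
    val approxByKey = pairs.countApproxDistinctByKey(relativeSD = 0.05).collect()
    // e.g. Array(("a", 2), ("b", 2)) -- the counts are approximate, not exact

The same relativeSD knob applies to RDD.countApproxDistinct and to the Java wrappers touched in these hunks.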
@@ -969,7 +969,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * * The algorithm used is based on streamlib's implementation of "HyperLogLog in Practice: * Algorithmic Engineering of a State of The Art Cardinality Estimation Algorithm", available - * here. + * here. * * @param relativeSD Relative accuracy. Smaller values create counters that require more space. * It must be greater than 0.000017. @@ -985,7 +985,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * * The algorithm used is based on streamlib's implementation of "HyperLogLog in Practice: * Algorithmic Engineering of a State of The Art Cardinality Estimation Algorithm", available - * here. + * here. * * @param relativeSD Relative accuracy. Smaller values create counters that require more space. * It must be greater than 0.000017. diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index 91ae1002abd21..5ba821935ac69 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -685,7 +685,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * * The algorithm used is based on streamlib's implementation of "HyperLogLog in Practice: * Algorithmic Engineering of a State of The Art Cardinality Estimation Algorithm", available - * here. + * here. * * @param relativeSD Relative accuracy. Smaller values create counters that require more space. * It must be greater than 0.000017. diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala index 6259bead3ea88..2ab8add63efae 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala @@ -43,7 +43,8 @@ private[python] object Converter extends Logging { defaultConverter: Converter[Any, Any]): Converter[Any, Any] = { converterClass.map { cc => Try { - val c = Utils.classForName(cc).newInstance().asInstanceOf[Converter[Any, Any]] + val c = Utils.classForName(cc).getConstructor(). + newInstance().asInstanceOf[Converter[Any, Any]] logInfo(s"Loaded converter: $cc") c } match { diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala index 7ce2581555014..50c8fdf5316d6 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala @@ -17,7 +17,7 @@ package org.apache.spark.api.r -import java.io.{DataInputStream, DataOutputStream, File, FileOutputStream, IOException} +import java.io.{DataOutputStream, File, FileOutputStream, IOException} import java.net.{InetAddress, InetSocketAddress, ServerSocket, Socket} import java.util.concurrent.TimeUnit @@ -32,8 +32,6 @@ import io.netty.handler.timeout.ReadTimeoutHandler import org.apache.spark.SparkConf import org.apache.spark.internal.Logging -import org.apache.spark.network.util.JavaUtils -import org.apache.spark.util.Utils /** * Netty-based backend server that is used to communicate between R and Java. 
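Several hunks in this patch (the Converter loading above, and later SparkSubmit, MetricsSystem, and the Hadoop input RDDs) replace Class#newInstance with getConstructor().newInstance(). A minimal sketch of the pattern with a hypothetical class name; Class#newInstance is deprecated since Java 9 and rethrows checked constructor exceptions directly, whereas Constructor#newInstance wraps them in InvocationTargetException:

    // Hypothetical class name, shown only to illustrate the reflection pattern.
    val clazz = Class.forName("com.example.MyConverter")

    // Deprecated style being removed by this patch:
    //   val instance = clazz.newInstance()

    // Style the patch migrates to: resolve the no-arg constructor explicitly.
    val instance = clazz.getConstructor().newInstance()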
@@ -99,7 +97,7 @@ private[spark] class RBackend { if (bootstrap != null && bootstrap.config().group() != null) { bootstrap.config().group().shutdownGracefully() } - if (bootstrap != null && bootstrap.childGroup() != null) { + if (bootstrap != null && bootstrap.config().childGroup() != null) { bootstrap.config().childGroup().shutdownGracefully() } bootstrap = null @@ -147,7 +145,7 @@ private[spark] object RBackend extends Logging { new Thread("wait for socket to close") { setDaemon(true) override def run(): Unit = { - // any un-catched exception will also shutdown JVM + // any uncaught exception will also shutdown JVM val buf = new Array[Byte](1024) // shutdown JVM if R does not connect back in 10 seconds serverSocket.setSoTimeout(10000) diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala index 537ab57f9664d..6e0a3f63988d4 100644 --- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala +++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala @@ -74,9 +74,9 @@ private[spark] object SerDe { jvmObjectTracker: JVMObjectTracker): Object = { dataType match { case 'n' => null - case 'i' => new java.lang.Integer(readInt(dis)) - case 'd' => new java.lang.Double(readDouble(dis)) - case 'b' => new java.lang.Boolean(readBoolean(dis)) + case 'i' => java.lang.Integer.valueOf(readInt(dis)) + case 'd' => java.lang.Double.valueOf(readDouble(dis)) + case 'b' => java.lang.Boolean.valueOf(readBoolean(dis)) case 'c' => readString(dis) case 'e' => readMap(dis, jvmObjectTracker) case 'r' => readBytes(dis) diff --git a/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala b/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala index 178bdcfccb603..5a17a6b6e169c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala @@ -61,11 +61,12 @@ private[deploy] object DependencyUtils extends Logging { hadoopConf: Configuration, secMgr: SecurityManager): String = { val targetDir = Utils.createTempDir() + val userJarName = userJar.split(File.separatorChar).last Option(jars) .map { resolveGlobPaths(_, hadoopConf) .split(",") - .filterNot(_.contains(userJar.split("/").last)) + .filterNot(_.contains(userJarName)) .mkString(",") } .filterNot(_ == "") diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 5979151345415..7bb2a419107d6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -18,7 +18,6 @@ package org.apache.spark.deploy import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream, File, IOException} -import java.lang.reflect.Method import java.security.PrivilegedExceptionAction import java.text.DateFormat import java.util.{Arrays, Comparator, Date, Locale} @@ -38,17 +37,13 @@ import org.apache.hadoop.security.token.{Token, TokenIdentifier} import org.apache.hadoop.security.token.delegation.AbstractDelegationTokenIdentifier import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.internal.Logging -import org.apache.spark.internal.config._ import org.apache.spark.util.Utils /** - * :: DeveloperApi :: * Contains util methods to interact with Hadoop from Spark. 
*/ -@DeveloperApi -class SparkHadoopUtil extends Logging { +private[spark] class SparkHadoopUtil extends Logging { private val sparkConf = new SparkConf(false).loadFromSystemProperties(true) val conf: Configuration = newConfiguration(sparkConf) UserGroupInformation.setConfiguration(conf) @@ -388,7 +383,7 @@ class SparkHadoopUtil extends Logging { } -object SparkHadoopUtil { +private[spark] object SparkHadoopUtil { private lazy val instance = new SparkHadoopUtil diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 88df7324a354a..324f6f8894d34 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -318,7 +318,7 @@ private[spark] class SparkSubmit extends Logging { if (!StringUtils.isBlank(resolvedMavenCoordinates)) { args.jars = mergeFileLists(args.jars, resolvedMavenCoordinates) - if (args.isPython) { + if (args.isPython || isInternal(args.primaryResource)) { args.pyFiles = mergeFileLists(args.pyFiles, resolvedMavenCoordinates) } } @@ -829,7 +829,7 @@ private[spark] class SparkSubmit extends Logging { } val app: SparkApplication = if (classOf[SparkApplication].isAssignableFrom(mainClass)) { - mainClass.newInstance().asInstanceOf[SparkApplication] + mainClass.getConstructor().newInstance().asInstanceOf[SparkApplication] } else { // SPARK-4170 if (classOf[scala.App].isAssignableFrom(mainClass)) { diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 2230bc8d6c641..4f4a00b7d831d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -42,6 +42,7 @@ import org.fusesource.leveldbjni.internal.NativeDB import org.apache.spark.{SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.DRIVER_LOG_DFS_DIR import org.apache.spark.internal.config.History._ import org.apache.spark.internal.config.Status._ import org.apache.spark.io.CompressionCodec @@ -97,7 +98,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) private val UPDATE_INTERVAL_S = conf.getTimeAsSeconds("spark.history.fs.update.interval", "10s") // Interval between each cleaner checks for event logs to delete - private val CLEAN_INTERVAL_S = conf.getTimeAsSeconds("spark.history.fs.cleaner.interval", "1d") + private val CLEAN_INTERVAL_S = conf.get(CLEANER_INTERVAL_S) // Number of threads used to replay event logs. private val NUM_PROCESSING_THREADS = conf.getInt(SPARK_HISTORY_FS_NUM_REPLAY_THREADS, @@ -274,11 +275,18 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) pool.scheduleWithFixedDelay( getRunner(() => checkForLogs()), 0, UPDATE_INTERVAL_S, TimeUnit.SECONDS) - if (conf.getBoolean("spark.history.fs.cleaner.enabled", false)) { + if (conf.get(CLEANER_ENABLED)) { // A task that periodically cleans event logs on disk. 
pool.scheduleWithFixedDelay( getRunner(() => cleanLogs()), 0, CLEAN_INTERVAL_S, TimeUnit.SECONDS) } + + if (conf.contains(DRIVER_LOG_DFS_DIR) && conf.get(DRIVER_LOG_CLEANER_ENABLED)) { + pool.scheduleWithFixedDelay(getRunner(() => cleanDriverLogs()), + 0, + conf.get(DRIVER_LOG_CLEANER_INTERVAL), + TimeUnit.SECONDS) + } } else { logDebug("Background update thread disabled for testing") } @@ -466,8 +474,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // If the file is currently not being tracked by the SHS, add an entry for it and try // to parse it. This will allow the cleaner code to detect the file as stale later on // if it was not possible to parse it. - listing.write(LogInfo(entry.getPath().toString(), newLastScanTime, None, None, - entry.getLen())) + listing.write(LogInfo(entry.getPath().toString(), newLastScanTime, LogType.EventLogs, + None, None, entry.getLen())) entry.getLen() > 0 } } @@ -724,7 +732,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // listing data is good. invalidateUI(app.info.id, app.attempts.head.info.attemptId) addListing(app) - listing.write(LogInfo(logPath.toString(), scanTime, Some(app.info.id), + listing.write(LogInfo(logPath.toString(), scanTime, LogType.EventLogs, Some(app.info.id), app.attempts.head.info.attemptId, fileStatus.getLen())) // For a finished log, remove the corresponding "in progress" entry from the listing DB if @@ -753,7 +761,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // If the app hasn't written down its app ID to the logs, still record the entry in the // listing db, with an empty ID. This will make the log eligible for deletion if the app // does not make progress after the configured max log age. - listing.write(LogInfo(logPath.toString(), scanTime, None, None, fileStatus.getLen())) + listing.write( + LogInfo(logPath.toString(), scanTime, LogType.EventLogs, None, None, fileStatus.getLen())) } } @@ -798,7 +807,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) val logPath = new Path(logDir, attempt.logPath) listing.delete(classOf[LogInfo], logPath.toString()) cleanAppData(app.id, attempt.info.attemptId, logPath.toString()) - deleteLog(logPath) + deleteLog(fs, logPath) } if (remaining.isEmpty) { @@ -812,11 +821,12 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) .reverse() .first(maxTime) .asScala + .filter { l => l.logType == null || l.logType == LogType.EventLogs } .toList stale.foreach { log => if (log.appId.isEmpty) { logInfo(s"Deleting invalid / corrupt event log ${log.logPath}") - deleteLog(new Path(log.logPath)) + deleteLog(fs, new Path(log.logPath)) listing.delete(classOf[LogInfo], log.logPath) } } @@ -824,6 +834,61 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) clearBlacklist(CLEAN_INTERVAL_S) } + /** + * Delete driver logs from the configured spark dfs dir that exceed the configured max age + */ + private[history] def cleanDriverLogs(): Unit = Utils.tryLog { + val driverLogDir = conf.get(DRIVER_LOG_DFS_DIR).get + val driverLogFs = new Path(driverLogDir).getFileSystem(hadoopConf) + val currentTime = clock.getTimeMillis() + val maxTime = currentTime - conf.get(MAX_DRIVER_LOG_AGE_S) * 1000 + val logFiles = driverLogFs.listLocatedStatus(new Path(driverLogDir)) + while (logFiles.hasNext()) { + val f = logFiles.next() + // Do not rely on 'modtime' as it is not updated for all filesystems when files are written to + val deleteFile = + try { + val info = 
listing.read(classOf[LogInfo], f.getPath().toString()) + // Update the lastprocessedtime of file if it's length or modification time has changed + if (info.fileSize < f.getLen() || info.lastProcessed < f.getModificationTime()) { + listing.write( + info.copy(lastProcessed = currentTime, fileSize = f.getLen())) + false + } else if (info.lastProcessed > maxTime) { + false + } else { + true + } + } catch { + case e: NoSuchElementException => + // For every new driver log file discovered, create a new entry in listing + listing.write(LogInfo(f.getPath().toString(), currentTime, LogType.DriverLogs, None, + None, f.getLen())) + false + } + if (deleteFile) { + logInfo(s"Deleting expired driver log for: ${f.getPath().getName()}") + listing.delete(classOf[LogInfo], f.getPath().toString()) + deleteLog(driverLogFs, f.getPath()) + } + } + + // Delete driver log file entries that exceed the configured max age and + // may have been deleted on filesystem externally. + val stale = listing.view(classOf[LogInfo]) + .index("lastProcessed") + .reverse() + .first(maxTime) + .asScala + .filter { l => l.logType != null && l.logType == LogType.DriverLogs } + .toList + stale.foreach { log => + logInfo(s"Deleting invalid driver log ${log.logPath}") + listing.delete(classOf[LogInfo], log.logPath) + deleteLog(driverLogFs, new Path(log.logPath)) + } + } + /** * Rebuilds the application state store from its event log. */ @@ -980,7 +1045,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) throw new NoSuchElementException(s"Cannot find attempt $attemptId of $appId.")) } - private def deleteLog(log: Path): Unit = { + private def deleteLog(fs: FileSystem, log: Path): Unit = { if (isBlacklisted(log)) { logDebug(s"Skipping deleting $log as we don't have permissions on it.") } else { @@ -1025,6 +1090,10 @@ private[history] case class FsHistoryProviderMetadata( uiVersion: Long, logDir: String) +private[history] object LogType extends Enumeration { + val DriverLogs, EventLogs = Value +} + /** * Tracking info for event logs detected in the configured log directory. Tracks both valid and * invalid logs (e.g. 
unparseable logs, recorded as logs with no app ID) so that the cleaner @@ -1033,6 +1102,7 @@ private[history] case class FsHistoryProviderMetadata( private[history] case class LogInfo( @KVIndexParam logPath: String, @KVIndexParam("lastProcessed") lastProcessed: Long, + logType: LogType.Value, appId: Option[String], attemptId: Option[String], fileSize: Long) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index 32667ddf5c7ea..00ca4efa4d266 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -31,8 +31,8 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") val requestedIncomplete = Option(UIUtils.stripXSS(request.getParameter("showIncomplete"))).getOrElse("false").toBoolean - val allAppsSize = parent.getApplicationList() - .count(isApplicationCompleted(_) != requestedIncomplete) + val displayApplications = parent.getApplicationList() + .exists(isApplicationCompleted(_) != requestedIncomplete) val eventLogsUnderProcessCount = parent.getEventLogsUnderProcess() val lastUpdatedTime = parent.getLastUpdatedTime() val providerConfig = parent.getProviderConfig() @@ -63,9 +63,9 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") } { - if (allAppsSize > 0) { + if (displayApplications) { ++ + request, "/static/dataTables.rowsGroup.js")}> ++
++ ++ diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala index 31a8e3e60c067..afa413fe165df 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala @@ -408,6 +408,10 @@ private[spark] class RestSubmissionClient(master: String) extends Logging { } private[spark] object RestSubmissionClient { + + // SPARK_HOME and SPARK_CONF_DIR are filtered out because they are usually wrong + // on the remote machine (SPARK-12345) (SPARK-25934) + private val BLACKLISTED_SPARK_ENV_VARS = Set("SPARK_ENV_LOADED", "SPARK_HOME", "SPARK_CONF_DIR") private val REPORT_DRIVER_STATUS_INTERVAL = 1000 private val REPORT_DRIVER_STATUS_MAX_TRIES = 10 val PROTOCOL_VERSION = "v1" @@ -417,9 +421,7 @@ private[spark] object RestSubmissionClient { */ private[rest] def filterSystemEnvironment(env: Map[String, String]): Map[String, String] = { env.filterKeys { k => - // SPARK_HOME is filtered out because it is usually wrong on the remote machine (SPARK-12345) - (k.startsWith("SPARK_") && k != "SPARK_ENV_LOADED" && k != "SPARK_HOME") || - k.startsWith("MESOS_") + (k.startsWith("SPARK_") && !BLACKLISTED_SPARK_ENV_VARS.contains(k)) || k.startsWith("MESOS_") } } } diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala index 10cd8742f2b49..1169b2878e993 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala @@ -270,7 +270,9 @@ private[spark] class HadoopDelegationTokenManager( } private def loadProviders(): Map[String, HadoopDelegationTokenProvider] = { - val providers = Seq(new HadoopFSDelegationTokenProvider(fileSystemsToAccess)) ++ + val providers = Seq( + new HadoopFSDelegationTokenProvider( + () => HadoopDelegationTokenManager.this.fileSystemsToAccess())) ++ safeCreateProvider(new HiveDelegationTokenProvider) ++ safeCreateProvider(new HBaseDelegationTokenProvider) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 78556f3f5f7c5..0ea3647466734 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -195,8 +195,11 @@ private[spark] class Executor( private val HEARTBEAT_INTERVAL_MS = conf.get(EXECUTOR_HEARTBEAT_INTERVAL) // Executor for the heartbeat task. 
- private val heartbeater = new Heartbeater(env.memoryManager, reportHeartBeat, - "executor-heartbeater", HEARTBEAT_INTERVAL_MS) + private val heartbeater = new Heartbeater( + env.memoryManager, + () => Executor.this.reportHeartBeat(), + "executor-heartbeater", + HEARTBEAT_INTERVAL_MS) // must be initialized before running startDriverHeartbeat() private val heartbeatReceiverRef = diff --git a/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala b/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala index 4be395c8358b2..12c4b8f67f71c 100644 --- a/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala @@ -18,6 +18,7 @@ package org.apache.spark.executor import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.shuffle.ShuffleReadMetricsReporter import org.apache.spark.util.LongAccumulator @@ -123,12 +124,13 @@ class ShuffleReadMetrics private[spark] () extends Serializable { } } + /** * A temporary shuffle read metrics holder that is used to collect shuffle read metrics for each * shuffle dependency, and all temporary metrics will be merged into the [[ShuffleReadMetrics]] at * last. */ -private[spark] class TempShuffleReadMetrics { +private[spark] class TempShuffleReadMetrics extends ShuffleReadMetricsReporter { private[this] var _remoteBlocksFetched = 0L private[this] var _localBlocksFetched = 0L private[this] var _remoteBytesRead = 0L @@ -137,13 +139,13 @@ private[spark] class TempShuffleReadMetrics { private[this] var _fetchWaitTime = 0L private[this] var _recordsRead = 0L - def incRemoteBlocksFetched(v: Long): Unit = _remoteBlocksFetched += v - def incLocalBlocksFetched(v: Long): Unit = _localBlocksFetched += v - def incRemoteBytesRead(v: Long): Unit = _remoteBytesRead += v - def incRemoteBytesReadToDisk(v: Long): Unit = _remoteBytesReadToDisk += v - def incLocalBytesRead(v: Long): Unit = _localBytesRead += v - def incFetchWaitTime(v: Long): Unit = _fetchWaitTime += v - def incRecordsRead(v: Long): Unit = _recordsRead += v + override def incRemoteBlocksFetched(v: Long): Unit = _remoteBlocksFetched += v + override def incLocalBlocksFetched(v: Long): Unit = _localBlocksFetched += v + override def incRemoteBytesRead(v: Long): Unit = _remoteBytesRead += v + override def incRemoteBytesReadToDisk(v: Long): Unit = _remoteBytesReadToDisk += v + override def incLocalBytesRead(v: Long): Unit = _localBytesRead += v + override def incFetchWaitTime(v: Long): Unit = _fetchWaitTime += v + override def incRecordsRead(v: Long): Unit = _recordsRead += v def remoteBlocksFetched: Long = _remoteBlocksFetched def localBlocksFetched: Long = _localBlocksFetched diff --git a/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala b/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala index 0c9da657c2b60..d0b0e7da079c9 100644 --- a/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala @@ -18,6 +18,7 @@ package org.apache.spark.executor import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.shuffle.ShuffleWriteMetricsReporter import org.apache.spark.util.LongAccumulator @@ -27,7 +28,7 @@ import org.apache.spark.util.LongAccumulator * Operations are not thread-safe. 
*/ @DeveloperApi -class ShuffleWriteMetrics private[spark] () extends Serializable { +class ShuffleWriteMetrics private[spark] () extends ShuffleWriteMetricsReporter with Serializable { private[executor] val _bytesWritten = new LongAccumulator private[executor] val _recordsWritten = new LongAccumulator private[executor] val _writeTime = new LongAccumulator @@ -47,13 +48,13 @@ class ShuffleWriteMetrics private[spark] () extends Serializable { */ def writeTime: Long = _writeTime.sum - private[spark] def incBytesWritten(v: Long): Unit = _bytesWritten.add(v) - private[spark] def incRecordsWritten(v: Long): Unit = _recordsWritten.add(v) - private[spark] def incWriteTime(v: Long): Unit = _writeTime.add(v) - private[spark] def decBytesWritten(v: Long): Unit = { + private[spark] override def incBytesWritten(v: Long): Unit = _bytesWritten.add(v) + private[spark] override def incRecordsWritten(v: Long): Unit = _recordsWritten.add(v) + private[spark] override def incWriteTime(v: Long): Unit = _writeTime.add(v) + private[spark] override def decBytesWritten(v: Long): Unit = { _bytesWritten.setValue(bytesWritten - v) } - private[spark] def decRecordsWritten(v: Long): Unit = { + private[spark] override def decRecordsWritten(v: Long): Unit = { _recordsWritten.setValue(recordsWritten - v) } } diff --git a/core/src/main/scala/org/apache/spark/internal/Logging.scala b/core/src/main/scala/org/apache/spark/internal/Logging.scala index c0d709ad25f29..00db9af846ab9 100644 --- a/core/src/main/scala/org/apache/spark/internal/Logging.scala +++ b/core/src/main/scala/org/apache/spark/internal/Logging.scala @@ -17,7 +17,11 @@ package org.apache.spark.internal -import org.apache.log4j.{Level, LogManager, PropertyConfigurator} +import java.util.concurrent.ConcurrentHashMap + +import scala.collection.JavaConverters._ + +import org.apache.log4j._ import org.slf4j.{Logger, LoggerFactory} import org.slf4j.impl.StaticLoggerBinder @@ -143,13 +147,25 @@ trait Logging { // overriding the root logger's config if they're different. val replLogger = LogManager.getLogger(logName) val replLevel = Option(replLogger.getLevel()).getOrElse(Level.WARN) + // Update the consoleAppender threshold to replLevel if (replLevel != rootLogger.getEffectiveLevel()) { if (!silent) { System.err.printf("Setting default log level to \"%s\".\n", replLevel) System.err.println("To adjust logging level use sc.setLogLevel(newLevel). 
" + "For SparkR, use setLogLevel(newLevel).") } - rootLogger.setLevel(replLevel) + rootLogger.getAllAppenders().asScala.foreach { + case ca: ConsoleAppender => + Option(ca.getThreshold()) match { + case Some(t) => + Logging.consoleAppenderToThreshold.put(ca, t) + if (!t.isGreaterOrEqual(replLevel)) { + ca.setThreshold(replLevel) + } + case None => ca.setThreshold(replLevel) + } + case _ => // no-op + } } } // scalastyle:on println @@ -166,6 +182,7 @@ private[spark] object Logging { @volatile private var initialized = false @volatile private var defaultRootLevel: Level = null @volatile private var defaultSparkLog4jConfig = false + private val consoleAppenderToThreshold = new ConcurrentHashMap[ConsoleAppender, Priority]() val initLock = new Object() try { @@ -192,7 +209,13 @@ private[spark] object Logging { defaultSparkLog4jConfig = false LogManager.resetConfiguration() } else { - LogManager.getRootLogger().setLevel(defaultRootLevel) + val rootLogger = LogManager.getRootLogger() + rootLogger.setLevel(defaultRootLevel) + rootLogger.getAllAppenders().asScala.foreach { + case ca: ConsoleAppender => + ca.setThreshold(consoleAppenderToThreshold.get(ca)) + case _ => // no-op + } } } this.initialized = false diff --git a/core/src/main/scala/org/apache/spark/internal/config/History.scala b/core/src/main/scala/org/apache/spark/internal/config/History.scala index 3f74eb3d13d4b..b7d8061d26d21 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/History.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/History.scala @@ -29,6 +29,14 @@ private[spark] object History { .stringConf .createWithDefault(DEFAULT_LOG_DIR) + val CLEANER_ENABLED = ConfigBuilder("spark.history.fs.cleaner.enabled") + .booleanConf + .createWithDefault(false) + + val CLEANER_INTERVAL_S = ConfigBuilder("spark.history.fs.cleaner.interval") + .timeConf(TimeUnit.SECONDS) + .createWithDefaultString("1d") + val MAX_LOG_AGE_S = ConfigBuilder("spark.history.fs.cleaner.maxAge") .timeConf(TimeUnit.SECONDS) .createWithDefaultString("7d") @@ -62,4 +70,13 @@ private[spark] object History { "parts of event log files. 
It can be disabled by setting this config to 0.") .bytesConf(ByteUnit.BYTE) .createWithDefaultString("1m") + + val DRIVER_LOG_CLEANER_ENABLED = ConfigBuilder("spark.history.fs.driverlog.cleaner.enabled") + .fallbackConf(CLEANER_ENABLED) + + val DRIVER_LOG_CLEANER_INTERVAL = ConfigBuilder("spark.history.fs.driverlog.cleaner.interval") + .fallbackConf(CLEANER_INTERVAL_S) + + val MAX_DRIVER_LOG_AGE_S = ConfigBuilder("spark.history.fs.driverlog.cleaner.maxAge") + .fallbackConf(MAX_LOG_AGE_S) } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index a8a593278fd55..8597e9848f9ff 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -49,6 +49,19 @@ package object config { .bytesConf(ByteUnit.MiB) .createOptional + private[spark] val DRIVER_LOG_DFS_DIR = + ConfigBuilder("spark.driver.log.dfsDir").stringConf.createOptional + + private[spark] val DRIVER_LOG_LAYOUT = + ConfigBuilder("spark.driver.log.layout") + .stringConf + .createOptional + + private[spark] val DRIVER_LOG_PERSISTTODFS = + ConfigBuilder("spark.driver.log.persistToDfs.enabled") + .booleanConf + .createWithDefault(false) + private[spark] val EVENT_LOG_COMPRESS = ConfigBuilder("spark.eventLog.compress") .booleanConf @@ -594,6 +607,12 @@ package object config { .stringConf .createOptional + private[spark] val UI_REQUEST_HEADER_SIZE = + ConfigBuilder("spark.ui.requestHeaderSize") + .doc("Value for HTTP request header size in bytes.") + .bytesConf(ByteUnit.BYTE) + .createWithDefaultString("8k") + private[spark] val EXTRA_LISTENERS = ConfigBuilder("spark.extraListeners") .doc("Class names of listeners to add to SparkContext during initialization.") .stringConf diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala index 3e60c50ada59b..7477e03bfaa76 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala @@ -91,7 +91,7 @@ class HadoopMapReduceCommitProtocol( private def stagingDir = new Path(path, ".spark-staging-" + jobId) protected def setupCommitter(context: TaskAttemptContext): OutputCommitter = { - val format = context.getOutputFormatClass.newInstance() + val format = context.getOutputFormatClass.getConstructor().newInstance() // If OutputFormat is Configurable, we should set conf to it. 
format match { case c: Configurable => c.setConf(context.getConfiguration) diff --git a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala index 9ebd0aa301592..3a58ea816937b 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala @@ -256,7 +256,7 @@ class HadoopMapRedWriteConfigUtil[K, V: ClassTag](conf: SerializableJobConf) private def getOutputFormat(): OutputFormat[K, V] = { require(outputFormat != null, "Must call initOutputFormat first.") - outputFormat.newInstance() + outputFormat.getConstructor().newInstance() } // -------------------------------------------------------------------------- @@ -379,7 +379,7 @@ class HadoopMapReduceWriteConfigUtil[K, V: ClassTag](conf: SerializableConfigura private def getOutputFormat(): NewOutputFormat[K, V] = { require(outputFormat != null, "Must call initOutputFormat first.") - outputFormat.newInstance() + outputFormat.getConstructor().newInstance() } // -------------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala index fd2d6006a467e..6b84299859381 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -194,7 +194,7 @@ private[spark] class MetricsSystem private ( sourceConfigs.foreach { kv => val classPath = kv._2.getProperty("class") try { - val source = Utils.classForName(classPath).newInstance() + val source = Utils.classForName(classPath).getConstructor().newInstance() registerSource(source.asInstanceOf[Source]) } catch { case e: Exception => logError("Source class " + classPath + " cannot be instantiated", e) diff --git a/core/src/main/scala/org/apache/spark/package.scala b/core/src/main/scala/org/apache/spark/package.scala index 8058a4d5dbdea..5d0639e92c36a 100644 --- a/core/src/main/scala/org/apache/spark/package.scala +++ b/core/src/main/scala/org/apache/spark/package.scala @@ -19,6 +19,8 @@ package org.apache import java.util.Properties +import org.apache.spark.util.VersionUtils + /** * Core Spark functionality. [[org.apache.spark.SparkContext]] serves as the main entry point to * Spark, while [[org.apache.spark.rdd.RDD]] is the data type representing a distributed collection, @@ -89,6 +91,7 @@ package object spark { } val SPARK_VERSION = SparkBuildInfo.spark_version + val SPARK_VERSION_SHORT = VersionUtils.shortVersion(SparkBuildInfo.spark_version) val SPARK_BRANCH = SparkBuildInfo.spark_branch val SPARK_REVISION = SparkBuildInfo.spark_revision val SPARK_BUILD_USER = SparkBuildInfo.spark_build_user diff --git a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala index a14bad47dfe10..039dbcbd5e035 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala @@ -41,7 +41,7 @@ private[spark] class BinaryFileRDD[T]( // traversing a large number of directories and files. Parallelize it. 
conf.setIfUnset(FileInputFormat.LIST_STATUS_NUM_THREADS, Runtime.getRuntime.availableProcessors().toString) - val inputFormat = inputFormatClass.newInstance + val inputFormat = inputFormatClass.getConstructor().newInstance() inputFormat match { case configurable: Configurable => configurable.setConf(conf) diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 4574c3724962e..7e76731f5e454 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -143,8 +143,10 @@ class CoGroupedRDD[K: ClassTag]( case shuffleDependency: ShuffleDependency[_, _, _] => // Read map outputs of shuffle + val metrics = context.taskMetrics().createTempShuffleReadMetrics() val it = SparkEnv.get.shuffleManager - .getReader(shuffleDependency.shuffleHandle, split.index, split.index + 1, context) + .getReader( + shuffleDependency.shuffleHandle, split.index, split.index + 1, context, metrics) .read() rddIterators += ((it, depNum)) } diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 2d66d25ba39fa..483de28d92ab7 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -120,7 +120,7 @@ class NewHadoopRDD[K, V]( } override def getPartitions: Array[Partition] = { - val inputFormat = inputFormatClass.newInstance + val inputFormat = inputFormatClass.getConstructor().newInstance() inputFormat match { case configurable: Configurable => configurable.setConf(_conf) @@ -183,7 +183,7 @@ class NewHadoopRDD[K, V]( } } - private val format = inputFormatClass.newInstance + private val format = inputFormatClass.getConstructor().newInstance() format match { case configurable: Configurable => configurable.setConf(conf) diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index e68c6b1366c7f..4bf4f082d0382 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -394,7 +394,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * * The algorithm used is based on streamlib's implementation of "HyperLogLog in Practice: * Algorithmic Engineering of a State of The Art Cardinality Estimation Algorithm", available - * here. + * here. * * The relative accuracy is approximately `1.054 / sqrt(2^p)`. Setting a nonzero (`sp` is * greater than `p`) would trigger sparse representation of registers, which may reduce the @@ -436,7 +436,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * * The algorithm used is based on streamlib's implementation of "HyperLogLog in Practice: * Algorithmic Engineering of a State of The Art Cardinality Estimation Algorithm", available - * here. + * here. * * @param relativeSD Relative accuracy. Smaller values create counters that require more space. * It must be greater than 0.000017. @@ -456,7 +456,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * * The algorithm used is based on streamlib's implementation of "HyperLogLog in Practice: * Algorithmic Engineering of a State of The Art Cardinality Estimation Algorithm", available - * here. + * here. * * @param relativeSD Relative accuracy. Smaller values create counters that require more space. * It must be greater than 0.000017. 
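For reference, a usage sketch of the countApproxDistinct family whose scaladoc is touched here; relativeSD trades memory for accuracy and, as the doc says, must be greater than 0.000017. `sc` is assumed to be an existing SparkContext.

    import org.apache.spark.SparkContext

    val sc: SparkContext = ???  // assumed to exist

    // ~5% relative error on the distinct count, using the HyperLogLog++ sketch described above.
    val values = sc.parallelize(1 to 1000000).map(_ % 50000)
    val approxDistinct: Long = values.countApproxDistinct(relativeSD = 0.05)

    // Per-key variant from PairRDDFunctions.
    val approxPerKey = values.map(v => (v % 10, v)).countApproxDistinctByKey(relativeSD = 0.05)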
@@ -473,7 +473,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * * The algorithm used is based on streamlib's implementation of "HyperLogLog in Practice: * Algorithmic Engineering of a State of The Art Cardinality Estimation Algorithm", available - * here. + * here. * * @param relativeSD Relative accuracy. Smaller values create counters that require more space. * It must be greater than 0.000017. diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 743e3441eea55..6a25ee20b2c68 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -1258,7 +1258,7 @@ abstract class RDD[T: ClassTag]( * * The algorithm used is based on streamlib's implementation of "HyperLogLog in Practice: * Algorithmic Engineering of a State of The Art Cardinality Estimation Algorithm", available - * here. + * here. * * The relative accuracy is approximately `1.054 / sqrt(2^p)`. Setting a nonzero (`sp` is greater * than `p`) would trigger sparse representation of registers, which may reduce the memory @@ -1290,7 +1290,7 @@ abstract class RDD[T: ClassTag]( * * The algorithm used is based on streamlib's implementation of "HyperLogLog in Practice: * Algorithmic Engineering of a State of The Art Cardinality Estimation Algorithm", available - * here. + * here. * * @param relativeSD Relative accuracy. Smaller values create counters that require more space. * It must be greater than 0.000017. diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala index e8f9b27b7eb55..5ec99b7f4f3ab 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala @@ -101,7 +101,9 @@ class ShuffledRDD[K: ClassTag, V: ClassTag, C: ClassTag]( override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = { val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]] - SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context) + val metrics = context.taskMetrics().createTempShuffleReadMetrics() + SparkEnv.get.shuffleManager.getReader( + dep.shuffleHandle, split.index, split.index + 1, context, metrics) .read() .asInstanceOf[Iterator[(K, C)]] } diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala index a733eaa5d7e53..42d190377f104 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala @@ -107,9 +107,14 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( .asInstanceOf[Iterator[Product2[K, V]]].foreach(op) case shuffleDependency: ShuffleDependency[_, _, _] => + val metrics = context.taskMetrics().createTempShuffleReadMetrics() val iter = SparkEnv.get.shuffleManager .getReader( - shuffleDependency.shuffleHandle, partition.index, partition.index + 1, context) + shuffleDependency.shuffleHandle, + partition.index, + partition.index + 1, + context, + metrics) .read() iter.foreach(op) } diff --git a/core/src/main/scala/org/apache/spark/rdd/WholeTextFileRDD.scala b/core/src/main/scala/org/apache/spark/rdd/WholeTextFileRDD.scala index 9f3d0745c33c9..eada762b99c8e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/WholeTextFileRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/WholeTextFileRDD.scala @@ -44,7 +44,7 
@@ private[spark] class WholeTextFileRDD( // traversing a large number of directories and files. Parallelize it. conf.setIfUnset(FileInputFormat.LIST_STATUS_NUM_THREADS, Runtime.getRuntime.availableProcessors().toString) - val inputFormat = inputFormatClass.newInstance + val inputFormat = inputFormatClass.getConstructor().newInstance() inputFormat match { case configurable: Configurable => configurable.setConf(conf) diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 788b23d1bfb03..5f697fe99258d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -382,19 +382,15 @@ private[spark] object EventLoggingListener extends Logging { appId: String, appAttemptId: Option[String], compressionCodecName: Option[String] = None): String = { - val base = new Path(logBaseDir).toString.stripSuffix("/") + "/" + sanitize(appId) + val base = new Path(logBaseDir).toString.stripSuffix("/") + "/" + Utils.sanitizeDirName(appId) val codec = compressionCodecName.map("." + _).getOrElse("") if (appAttemptId.isDefined) { - base + "_" + sanitize(appAttemptId.get) + codec + base + "_" + Utils.sanitizeDirName(appAttemptId.get) + codec } else { base + codec } } - private def sanitize(str: String): String = { - str.replaceAll("[ :/]", "-").replaceAll("[.${}'\"]", "_").toLowerCase(Locale.ROOT) - } - /** * Opens an event log file and returns an input stream that contains the event data. * diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index f2cd65fd523ab..5412717d61988 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -95,7 +95,8 @@ private[spark] class ShuffleMapTask( var writer: ShuffleWriter[Any, Any] = null try { val manager = SparkEnv.get.shuffleManager - writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context) + writer = manager.getWriter[Any, Any]( + dep.shuffleHandle, partitionId, context, context.taskMetrics().shuffleWriteMetrics) writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]]) writer.stop(success = true).get } catch { diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala index 903e25b7986f2..33a68f24bd53a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala @@ -30,7 +30,7 @@ import org.apache.spark.storage.RDDInfo @DeveloperApi class StageInfo( val stageId: Int, - @deprecated("Use attemptNumber instead", "2.3.0") val attemptId: Int, + private val attemptId: Int, val name: String, val numTasks: Int, val rddInfos: Seq[RDDInfo], @@ -56,6 +56,8 @@ class StageInfo( completionTime = Some(System.currentTimeMillis) } + // This would just be the second constructor arg, except we need to maintain this method + // with parentheses for compatibility def attemptNumber(): Int = attemptId private[spark] def getStatusString: String = { diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 72427dd6ce4d4..1e1c27c477877 100644 --- 
a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -30,6 +30,7 @@ import scala.util.control.NonFatal import com.esotericsoftware.kryo.{Kryo, KryoException, Serializer => KryoClassSerializer} import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} import com.esotericsoftware.kryo.io.{UnsafeInput => KryoUnsafeInput, UnsafeOutput => KryoUnsafeOutput} +import com.esotericsoftware.kryo.pool.{KryoCallback, KryoFactory, KryoPool} import com.esotericsoftware.kryo.serializers.{JavaSerializer => KryoJavaSerializer} import com.twitter.chill.{AllScalaRegistrar, EmptyScalaKryoInstantiator} import org.apache.avro.generic.{GenericData, GenericRecord} @@ -41,7 +42,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.network.util.ByteUnit import org.apache.spark.scheduler.{CompressedMapStatus, HighlyCompressedMapStatus} import org.apache.spark.storage._ -import org.apache.spark.util.{BoundedPriorityQueue, SerializableConfiguration, SerializableJobConf, Utils} +import org.apache.spark.util.{BoundedPriorityQueue, ByteBufferInputStream, SerializableConfiguration, SerializableJobConf, Utils} import org.apache.spark.util.collection.CompactBuffer /** @@ -84,6 +85,7 @@ class KryoSerializer(conf: SparkConf) private val avroSchemas = conf.getAvroSchema // whether to use unsafe based IO for serialization private val useUnsafe = conf.getBoolean("spark.kryo.unsafe", false) + private val usePool = conf.getBoolean("spark.kryo.pool", true) def newKryoOutput(): KryoOutput = if (useUnsafe) { @@ -92,6 +94,36 @@ class KryoSerializer(conf: SparkConf) new KryoOutput(bufferSize, math.max(bufferSize, maxBufferSize)) } + @transient + private lazy val factory: KryoFactory = new KryoFactory() { + override def create: Kryo = { + newKryo() + } + } + + private class PoolWrapper extends KryoPool { + private var pool: KryoPool = getPool + + override def borrow(): Kryo = pool.borrow() + + override def release(kryo: Kryo): Unit = pool.release(kryo) + + override def run[T](kryoCallback: KryoCallback[T]): T = pool.run(kryoCallback) + + def reset(): Unit = { + pool = getPool + } + + private def getPool: KryoPool = { + new KryoPool.Builder(factory).softReferences.build + } + } + + @transient + private lazy val internalPool = new PoolWrapper + + def pool: KryoPool = internalPool + def newKryo(): Kryo = { val instantiator = new EmptyScalaKryoInstantiator val kryo = instantiator.newKryo() @@ -132,7 +164,8 @@ class KryoSerializer(conf: SparkConf) .foreach { className => kryo.register(Class.forName(className, true, classLoader)) } // Allow the user to register their own classes by setting spark.kryo.registrator. userRegistrators - .map(Class.forName(_, true, classLoader).newInstance().asInstanceOf[KryoRegistrator]) + .map(Class.forName(_, true, classLoader).getConstructor(). + newInstance().asInstanceOf[KryoRegistrator]) .foreach { reg => reg.registerClasses(kryo) } // scalastyle:on classforname } catch { @@ -182,6 +215,12 @@ class KryoSerializer(conf: SparkConf) // We can't load those class directly in order to avoid unnecessary jar dependencies. // We load them safely, ignore it if the class not found. 
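The PoolWrapper above delegates to Kryo's own pooling support (com.esotericsoftware.kryo.pool in Kryo 4.x), gated by the new spark.kryo.pool flag and rebuilt whenever the default class loader changes. A standalone sketch of that API, independent of Spark's wrapper; newKryo() in the real code plays the role of the factory below.

    import com.esotericsoftware.kryo.Kryo
    import com.esotericsoftware.kryo.pool.{KryoCallback, KryoFactory, KryoPool}

    // Creating a Kryo instance is comparatively expensive, so instances are pooled and
    // held through soft references (reclaimable under memory pressure).
    val factory: KryoFactory = new KryoFactory {
      override def create(): Kryo = new Kryo()  // stands in for KryoSerializer.newKryo()
    }
    val pool: KryoPool = new KryoPool.Builder(factory).softReferences().build()

    // Explicit borrow/release, which is what KryoSerializerInstance does when spark.kryo.pool=true.
    val kryo = pool.borrow()
    try {
      kryo.reset()  // defensive reset, mirroring borrowKryo()
    } finally {
      pool.release(kryo)
    }

    // Callback form: the pool borrows, runs the callback, and releases automatically.
    val kryoClassName: String = pool.run(new KryoCallback[String] {
      override def execute(k: Kryo): String = k.getClass.getName
    })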
Seq( + "org.apache.spark.ml.attribute.Attribute", + "org.apache.spark.ml.attribute.AttributeGroup", + "org.apache.spark.ml.attribute.BinaryAttribute", + "org.apache.spark.ml.attribute.NominalAttribute", + "org.apache.spark.ml.attribute.NumericAttribute", + "org.apache.spark.ml.feature.Instance", "org.apache.spark.ml.feature.LabeledPoint", "org.apache.spark.ml.feature.OffsetInstance", @@ -191,6 +230,7 @@ class KryoSerializer(conf: SparkConf) "org.apache.spark.ml.linalg.SparseMatrix", "org.apache.spark.ml.linalg.SparseVector", "org.apache.spark.ml.linalg.Vector", + "org.apache.spark.ml.stat.distribution.MultivariateGaussian", "org.apache.spark.ml.tree.impl.TreePoint", "org.apache.spark.mllib.clustering.VectorWithNorm", "org.apache.spark.mllib.linalg.DenseMatrix", @@ -199,7 +239,8 @@ class KryoSerializer(conf: SparkConf) "org.apache.spark.mllib.linalg.SparseMatrix", "org.apache.spark.mllib.linalg.SparseVector", "org.apache.spark.mllib.linalg.Vector", - "org.apache.spark.mllib.regression.LabeledPoint" + "org.apache.spark.mllib.regression.LabeledPoint", + "org.apache.spark.mllib.stat.distribution.MultivariateGaussian" ).foreach { name => try { val clazz = Utils.classForName(name) @@ -214,8 +255,14 @@ class KryoSerializer(conf: SparkConf) kryo } + override def setDefaultClassLoader(classLoader: ClassLoader): Serializer = { + super.setDefaultClassLoader(classLoader) + internalPool.reset() + this + } + override def newInstance(): SerializerInstance = { - new KryoSerializerInstance(this, useUnsafe) + new KryoSerializerInstance(this, useUnsafe, usePool) } private[spark] override lazy val supportsRelocationOfSerializedObjects: Boolean = { @@ -298,7 +345,8 @@ class KryoDeserializationStream( } } -private[spark] class KryoSerializerInstance(ks: KryoSerializer, useUnsafe: Boolean) +private[spark] class KryoSerializerInstance( + ks: KryoSerializer, useUnsafe: Boolean, usePool: Boolean) extends SerializerInstance { /** * A re-used [[Kryo]] instance. Methods will borrow this instance by calling `borrowKryo()`, do @@ -306,22 +354,29 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer, useUnsafe: Boole * pool of size one. SerializerInstances are not thread-safe, hence accesses to this field are * not synchronized. */ - @Nullable private[this] var cachedKryo: Kryo = borrowKryo() + @Nullable private[this] var cachedKryo: Kryo = if (usePool) null else borrowKryo() /** * Borrows a [[Kryo]] instance. If possible, this tries to re-use a cached Kryo instance; * otherwise, it allocates a new instance. */ private[serializer] def borrowKryo(): Kryo = { - if (cachedKryo != null) { - val kryo = cachedKryo - // As a defensive measure, call reset() to clear any Kryo state that might have been modified - // by the last operation to borrow this instance (see SPARK-7766 for discussion of this issue) + if (usePool) { + val kryo = ks.pool.borrow() kryo.reset() - cachedKryo = null kryo } else { - ks.newKryo() + if (cachedKryo != null) { + val kryo = cachedKryo + // As a defensive measure, call reset() to clear any Kryo state that might have + // been modified by the last operation to borrow this instance + // (see SPARK-7766 for discussion of this issue) + kryo.reset() + cachedKryo = null + kryo + } else { + ks.newKryo() + } } } @@ -331,8 +386,12 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer, useUnsafe: Boole * re-use. 
*/ private[serializer] def releaseKryo(kryo: Kryo): Unit = { - if (cachedKryo == null) { - cachedKryo = kryo + if (usePool) { + ks.pool.release(kryo) + } else { + if (cachedKryo == null) { + cachedKryo = kryo + } } } @@ -358,7 +417,12 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer, useUnsafe: Boole override def deserialize[T: ClassTag](bytes: ByteBuffer): T = { val kryo = borrowKryo() try { - input.setBuffer(bytes.array(), bytes.arrayOffset() + bytes.position(), bytes.remaining()) + if (bytes.hasArray) { + input.setBuffer(bytes.array(), bytes.arrayOffset() + bytes.position(), bytes.remaining()) + } else { + input.setBuffer(new Array[Byte](4096)) + input.setInputStream(new ByteBufferInputStream(bytes)) + } kryo.readClassAndObject(input).asInstanceOf[T] } finally { releaseKryo(kryo) @@ -370,7 +434,12 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer, useUnsafe: Boole val oldClassLoader = kryo.getClassLoader try { kryo.setClassLoader(loader) - input.setBuffer(bytes.array(), bytes.arrayOffset() + bytes.position(), bytes.remaining()) + if (bytes.hasArray) { + input.setBuffer(bytes.array(), bytes.arrayOffset() + bytes.position(), bytes.remaining()) + } else { + input.setBuffer(new Array[Byte](4096)) + input.setInputStream(new ByteBufferInputStream(bytes)) + } kryo.readClassAndObject(input).asInstanceOf[T] } finally { kryo.setClassLoader(oldClassLoader) diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 74b0e0b3a741a..27e2f98c58f0c 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -33,6 +33,7 @@ private[spark] class BlockStoreShuffleReader[K, C]( startPartition: Int, endPartition: Int, context: TaskContext, + readMetrics: ShuffleReadMetricsReporter, serializerManager: SerializerManager = SparkEnv.get.serializerManager, blockManager: BlockManager = SparkEnv.get.blockManager, mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker) @@ -53,7 +54,8 @@ private[spark] class BlockStoreShuffleReader[K, C]( SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue), SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS), SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM), - SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true)) + SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true), + readMetrics) val serializerInstance = dep.serializer.newInstance() @@ -66,7 +68,6 @@ private[spark] class BlockStoreShuffleReader[K, C]( } // Update the context task metrics for each record read. - val readMetrics = context.taskMetrics.createTempShuffleReadMetrics() val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]]( recordIter.map { record => readMetrics.incRecordsRead(1) diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala index 4ea8a7120a9cc..18a743fbfa6fc 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala @@ -38,7 +38,11 @@ private[spark] trait ShuffleManager { dependency: ShuffleDependency[K, V, C]): ShuffleHandle /** Get a writer for a given partition. Called on executors by map tasks. 
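The hasArray branch added to deserialize() exists because the incoming ByteBuffer may be a direct (off-heap) buffer, for example when a block was fetched into off-heap memory; ByteBuffer.array() only works for heap buffers and throws UnsupportedOperationException otherwise, so the fallback feeds Kryo through a stream view instead. A self-contained java.nio sketch of the distinction (not Spark code):

    import java.nio.ByteBuffer

    val heapBuf = ByteBuffer.wrap(Array[Byte](1, 2, 3))
    assert(heapBuf.hasArray)            // array()/arrayOffset() are safe here

    val directBuf = ByteBuffer.allocateDirect(3)
    assert(!directBuf.hasArray)         // directBuf.array() would throw UnsupportedOperationException

    // Generic read that works for both kinds of buffer, analogous to the new else-branch.
    def readAll(buf: ByteBuffer): Array[Byte] = {
      if (buf.hasArray) {
        java.util.Arrays.copyOfRange(
          buf.array(), buf.arrayOffset() + buf.position(), buf.arrayOffset() + buf.limit())
      } else {
        val out = new Array[Byte](buf.remaining())
        buf.duplicate().get(out)        // copy through the buffer API, no backing array assumed
        out
      }
    }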
*/ - def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext): ShuffleWriter[K, V] + def getWriter[K, V]( + handle: ShuffleHandle, + mapId: Int, + context: TaskContext, + metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] /** * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive). @@ -48,7 +52,8 @@ private[spark] trait ShuffleManager { handle: ShuffleHandle, startPartition: Int, endPartition: Int, - context: TaskContext): ShuffleReader[K, C] + context: TaskContext, + metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] /** * Remove a shuffle's metadata from the ShuffleManager. diff --git a/core/src/main/scala/org/apache/spark/shuffle/metrics.scala b/core/src/main/scala/org/apache/spark/shuffle/metrics.scala new file mode 100644 index 0000000000000..33be677bc90cb --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/metrics.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +/** + * An interface for reporting shuffle read metrics, for each shuffle. This interface assumes + * all the methods are called on a single-threaded, i.e. concrete implementations would not need + * to synchronize. + * + * All methods have additional Spark visibility modifier to allow public, concrete implementations + * that still have these methods marked as private[spark]. + */ +private[spark] trait ShuffleReadMetricsReporter { + private[spark] def incRemoteBlocksFetched(v: Long): Unit + private[spark] def incLocalBlocksFetched(v: Long): Unit + private[spark] def incRemoteBytesRead(v: Long): Unit + private[spark] def incRemoteBytesReadToDisk(v: Long): Unit + private[spark] def incLocalBytesRead(v: Long): Unit + private[spark] def incFetchWaitTime(v: Long): Unit + private[spark] def incRecordsRead(v: Long): Unit +} + + +/** + * An interface for reporting shuffle write metrics. This interface assumes all the methods are + * called on a single-threaded, i.e. concrete implementations would not need to synchronize. + * + * All methods have additional Spark visibility modifier to allow public, concrete implementations + * that still have these methods marked as private[spark]. 
+ */ +private[spark] trait ShuffleWriteMetricsReporter { + private[spark] def incBytesWritten(v: Long): Unit + private[spark] def incRecordsWritten(v: Long): Unit + private[spark] def incWriteTime(v: Long): Unit + private[spark] def decBytesWritten(v: Long): Unit + private[spark] def decRecordsWritten(v: Long): Unit +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 0caf84c6050a8..b51a843a31c31 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -114,16 +114,19 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager handle: ShuffleHandle, startPartition: Int, endPartition: Int, - context: TaskContext): ShuffleReader[K, C] = { + context: TaskContext, + metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = { new BlockStoreShuffleReader( - handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context) + handle.asInstanceOf[BaseShuffleHandle[K, _, C]], + startPartition, endPartition, context, metrics) } /** Get a writer for a given partition. Called on executors by map tasks. */ override def getWriter[K, V]( handle: ShuffleHandle, mapId: Int, - context: TaskContext): ShuffleWriter[K, V] = { + context: TaskContext, + metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = { numMapsForShuffle.putIfAbsent( handle.shuffleId, handle.asInstanceOf[BaseShuffleHandle[_, _, _]].numMaps) val env = SparkEnv.get @@ -136,15 +139,16 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager unsafeShuffleHandle, mapId, context, - env.conf) + env.conf, + metrics) case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] => new BypassMergeSortShuffleWriter( env.blockManager, shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver], bypassMergeSortHandle, mapId, - context, - env.conf) + env.conf, + metrics) case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] => new SortShuffleWriter(shuffleBlockResolver, other, mapId, context) } diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index 81d39e0407fed..8e845573a903d 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -473,6 +473,7 @@ private[spark] class AppStatusListener( val locality = event.taskInfo.taskLocality.toString() val count = stage.localitySummary.getOrElse(locality, 0L) + 1L stage.localitySummary = stage.localitySummary ++ Map(locality -> count) + stage.activeTasksPerExecutor(event.taskInfo.executorId) += 1 maybeUpdate(stage, now) stage.jobs.foreach { job => @@ -558,6 +559,7 @@ private[spark] class AppStatusListener( if (killedDelta > 0) { stage.killedSummary = killedTasksSummary(event.reason, stage.killedSummary) } + stage.activeTasksPerExecutor(event.taskInfo.executorId) -= 1 // [SPARK-24415] Wait for all tasks to finish before removing stage from live list val removeStage = stage.activeTasks == 0 && @@ -582,7 +584,11 @@ private[spark] class AppStatusListener( if (killedDelta > 0) { job.killedSummary = killedTasksSummary(event.reason, job.killedSummary) } - conditionalLiveUpdate(job, now, removeStage) + if (removeStage) { + update(job, now) + } else { + maybeUpdate(job, now) + } } val 
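The two new reporter traits let a shuffle reader or writer push metrics to whatever object the caller supplies instead of reaching into TaskContext.taskMetrics() itself, which is what the updated ShuffleManager.getReader/getWriter signatures thread through. A hypothetical in-memory implementation, just to show the shape; the real implementations are the TaskMetrics shuffle metrics classes.

    package org.apache.spark.shuffle

    // Hypothetical reporter that only counts; because the trait methods are private[spark],
    // an implementation has to live under an org.apache.spark package.
    private[spark] class CountingShuffleWriteMetrics extends ShuffleWriteMetricsReporter {
      // Single-threaded per the trait contract above, so plain vars are sufficient.
      private var bytes = 0L
      private var records = 0L
      private var writeTimeNs = 0L

      override private[spark] def incBytesWritten(v: Long): Unit = bytes += v
      override private[spark] def incRecordsWritten(v: Long): Unit = records += v
      override private[spark] def incWriteTime(v: Long): Unit = writeTimeNs += v
      // The dec* hooks are used when a partially written file has to be rolled back.
      override private[spark] def decBytesWritten(v: Long): Unit = bytes -= v
      override private[spark] def decRecordsWritten(v: Long): Unit = records -= v

      def snapshot: (Long, Long, Long) = (bytes, records, writeTimeNs)
    }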
esummary = stage.executorSummary(event.taskInfo.executorId) @@ -593,7 +599,16 @@ private[spark] class AppStatusListener( if (metricsDelta != null) { esummary.metrics = LiveEntityHelpers.addMetrics(esummary.metrics, metricsDelta) } - conditionalLiveUpdate(esummary, now, removeStage) + + val isLastTask = stage.activeTasksPerExecutor(event.taskInfo.executorId) == 0 + + // If the last task of the executor finished, then update the esummary + // for both live and history events. + if (isLastTask) { + update(esummary, now) + } else { + maybeUpdate(esummary, now) + } if (!stage.cleaning && stage.savedTasks.get() > maxTasksPerStage) { stage.cleaning = true diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala index 63b9d8988499d..5c0ed4d5d8f4c 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala @@ -351,7 +351,9 @@ private[spark] class AppStatusStore( def taskList(stageId: Int, stageAttemptId: Int, maxTasks: Int): Seq[v1.TaskData] = { val stageKey = Array(stageId, stageAttemptId) store.view(classOf[TaskDataWrapper]).index("stage").first(stageKey).last(stageKey).reverse() - .max(maxTasks).asScala.map(_.toApi).toSeq.reverse + .max(maxTasks).asScala.map { taskDataWrapper => + constructTaskData(taskDataWrapper) + }.toSeq.reverse } def taskList( @@ -390,7 +392,9 @@ private[spark] class AppStatusStore( } val ordered = if (ascending) indexed else indexed.reverse() - ordered.skip(offset).max(length).asScala.map(_.toApi).toSeq + ordered.skip(offset).max(length).asScala.map { taskDataWrapper => + constructTaskData(taskDataWrapper) + }.toSeq } def executorSummary(stageId: Int, attemptId: Int): Map[String, v1.ExecutorStageSummary] = { @@ -496,6 +500,24 @@ private[spark] class AppStatusStore( store.close() } + def constructTaskData(taskDataWrapper: TaskDataWrapper) : v1.TaskData = { + val taskDataOld: v1.TaskData = taskDataWrapper.toApi + val executorLogs: Option[Map[String, String]] = try { + Some(executorSummary(taskDataOld.executorId).executorLogs) + } catch { + case e: NoSuchElementException => e.getMessage + None + } + new v1.TaskData(taskDataOld.taskId, taskDataOld.index, + taskDataOld.attempt, taskDataOld.launchTime, taskDataOld.resultFetchStart, + taskDataOld.duration, taskDataOld.executorId, taskDataOld.host, taskDataOld.status, + taskDataOld.taskLocality, taskDataOld.speculative, taskDataOld.accumulatorUpdates, + taskDataOld.errorMessage, taskDataOld.taskMetrics, + executorLogs.getOrElse(Map[String, String]()), + AppStatusUtils.schedulerDelay(taskDataOld), + AppStatusUtils.gettingResultTime(taskDataOld)) + } + } private[spark] object AppStatusStore { diff --git a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala index 80663318c1ba1..47e45a66ecccb 100644 --- a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala +++ b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala @@ -376,6 +376,8 @@ private class LiveStage extends LiveEntity { val executorSummaries = new HashMap[String, LiveExecutorStageSummary]() + val activeTasksPerExecutor = new HashMap[String, Int]().withDefaultValue(0) + var blackListedExecutors = new HashSet[String]() // Used for cleanup of tasks after they reach the configured limit. Not written to the store. 
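The new taskTable endpoint is consumed by the DataTables-based stage page, but it can also be queried directly. Assuming the usual REST prefix (/api/v1/applications/<appId>/stages/<stageId>/<stageAttemptId>/taskTable) and a live UI on port 4040, a sketch of the query parameters the handler reads (numTasks, start, length, columnNameToSort, order[0][dir], search[value]); the host, application id and stage coordinates below are placeholders.

    import java.net.URLEncoder
    import scala.io.Source

    // Placeholders: adjust host/app/stage to a real application.
    val base = "http://localhost:4040/api/v1/applications/app-20190101000000-0001/stages/3/0/taskTable"

    val query = Seq(
      "numTasks" -> "200",            // total task count, echoed back as recordsTotal
      "start" -> "0",                 // pagination offset
      "length" -> "20",               // page size
      "columnNameToSort" -> "Duration",
      "order[0][dir]" -> "desc",
      "search[value]" -> ""           // a non-empty value triggers the server-side filter above
    ).map { case (k, v) => URLEncoder.encode(k, "UTF-8") + "=" + URLEncoder.encode(v, "UTF-8") }
     .mkString("&")

    // Response is a JSON object with aaData (the task rows), recordsTotal and recordsFiltered.
    val json = Source.fromURL(s"$base?$query").mkString
    println(json)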
diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala index 30d52b97833e6..f81892734c2de 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala @@ -16,15 +16,16 @@ */ package org.apache.spark.status.api.v1 -import java.util.{List => JList} +import java.util.{HashMap, List => JList, Locale} import javax.ws.rs._ -import javax.ws.rs.core.MediaType +import javax.ws.rs.core.{Context, MediaType, MultivaluedMap, UriInfo} import org.apache.spark.SparkException import org.apache.spark.scheduler.StageInfo import org.apache.spark.status.api.v1.StageStatus._ import org.apache.spark.status.api.v1.TaskSorting._ import org.apache.spark.ui.SparkUI +import org.apache.spark.ui.jobs.ApiHelper._ @Produces(Array(MediaType.APPLICATION_JSON)) private[v1] class StagesResource extends BaseAppResource { @@ -102,4 +103,120 @@ private[v1] class StagesResource extends BaseAppResource { withUI(_.store.taskList(stageId, stageAttemptId, offset, length, sortBy)) } + // This api needs to stay formatted exactly as it is below, since, it is being used by the + // datatables for the stages page. + @GET + @Path("{stageId: \\d+}/{stageAttemptId: \\d+}/taskTable") + def taskTable( + @PathParam("stageId") stageId: Int, + @PathParam("stageAttemptId") stageAttemptId: Int, + @QueryParam("details") @DefaultValue("true") details: Boolean, + @Context uriInfo: UriInfo): + HashMap[String, Object] = { + withUI { ui => + val uriQueryParameters = uriInfo.getQueryParameters(true) + val totalRecords = uriQueryParameters.getFirst("numTasks") + var isSearch = false + var searchValue: String = null + var filteredRecords = totalRecords + // The datatables client API sends a list of query parameters to the server which contain + // information like the columns to be sorted, search value typed by the user in the search + // box, pagination index etc. For more information on these query parameters, + // refer https://datatables.net/manual/server-side. 
+ if (uriQueryParameters.getFirst("search[value]") != null && + uriQueryParameters.getFirst("search[value]").length > 0) { + isSearch = true + searchValue = uriQueryParameters.getFirst("search[value]") + } + val _tasksToShow: Seq[TaskData] = doPagination(uriQueryParameters, stageId, stageAttemptId, + isSearch, totalRecords.toInt) + val ret = new HashMap[String, Object]() + if (_tasksToShow.nonEmpty) { + // Performs server-side search based on input from user + if (isSearch) { + val filteredTaskList = filterTaskList(_tasksToShow, searchValue) + filteredRecords = filteredTaskList.length.toString + if (filteredTaskList.length > 0) { + val pageStartIndex = uriQueryParameters.getFirst("start").toInt + val pageLength = uriQueryParameters.getFirst("length").toInt + ret.put("aaData", filteredTaskList.slice( + pageStartIndex, pageStartIndex + pageLength)) + } else { + ret.put("aaData", filteredTaskList) + } + } else { + ret.put("aaData", _tasksToShow) + } + } else { + ret.put("aaData", _tasksToShow) + } + ret.put("recordsTotal", totalRecords) + ret.put("recordsFiltered", filteredRecords) + ret + } + } + + // Performs pagination on the server side + def doPagination(queryParameters: MultivaluedMap[String, String], stageId: Int, + stageAttemptId: Int, isSearch: Boolean, totalRecords: Int): Seq[TaskData] = { + var columnNameToSort = queryParameters.getFirst("columnNameToSort") + // Sorting on Logs column will default to Index column sort + if (columnNameToSort.equalsIgnoreCase("Logs")) { + columnNameToSort = "Index" + } + val isAscendingStr = queryParameters.getFirst("order[0][dir]") + var pageStartIndex = 0 + var pageLength = totalRecords + // We fetch only the desired rows upto the specified page length for all cases except when a + // search query is present, in that case, we need to fetch all the rows to perform the search + // on the entire table + if (!isSearch) { + pageStartIndex = queryParameters.getFirst("start").toInt + pageLength = queryParameters.getFirst("length").toInt + } + withUI(_.store.taskList(stageId, stageAttemptId, pageStartIndex, pageLength, + indexName(columnNameToSort), isAscendingStr.equalsIgnoreCase("asc"))) + } + + // Filters task list based on search parameter + def filterTaskList( + taskDataList: Seq[TaskData], + searchValue: String): Seq[TaskData] = { + val defaultOptionString: String = "d" + val searchValueLowerCase = searchValue.toLowerCase(Locale.ROOT) + val containsValue = (taskDataParams: Any) => taskDataParams.toString.toLowerCase( + Locale.ROOT).contains(searchValueLowerCase) + val taskMetricsContainsValue = (task: TaskData) => task.taskMetrics match { + case None => false + case Some(metrics) => + (containsValue(task.taskMetrics.get.executorDeserializeTime) + || containsValue(task.taskMetrics.get.executorRunTime) + || containsValue(task.taskMetrics.get.jvmGcTime) + || containsValue(task.taskMetrics.get.resultSerializationTime) + || containsValue(task.taskMetrics.get.memoryBytesSpilled) + || containsValue(task.taskMetrics.get.diskBytesSpilled) + || containsValue(task.taskMetrics.get.peakExecutionMemory) + || containsValue(task.taskMetrics.get.inputMetrics.bytesRead) + || containsValue(task.taskMetrics.get.inputMetrics.recordsRead) + || containsValue(task.taskMetrics.get.outputMetrics.bytesWritten) + || containsValue(task.taskMetrics.get.outputMetrics.recordsWritten) + || containsValue(task.taskMetrics.get.shuffleReadMetrics.fetchWaitTime) + || containsValue(task.taskMetrics.get.shuffleReadMetrics.recordsRead) + || 
containsValue(task.taskMetrics.get.shuffleWriteMetrics.bytesWritten) + || containsValue(task.taskMetrics.get.shuffleWriteMetrics.recordsWritten) + || containsValue(task.taskMetrics.get.shuffleWriteMetrics.writeTime)) + } + val filteredTaskDataSequence: Seq[TaskData] = taskDataList.filter(f => + (containsValue(f.taskId) || containsValue(f.index) || containsValue(f.attempt) + || containsValue(f.launchTime) + || containsValue(f.resultFetchStart.getOrElse(defaultOptionString)) + || containsValue(f.duration.getOrElse(defaultOptionString)) + || containsValue(f.executorId) || containsValue(f.host) || containsValue(f.status) + || containsValue(f.taskLocality) || containsValue(f.speculative) + || containsValue(f.errorMessage.getOrElse(defaultOptionString)) + || taskMetricsContainsValue(f) + || containsValue(f.schedulerDelay) || containsValue(f.gettingResultTime))) + filteredTaskDataSequence + } + } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index 30afd8b769720..aa21da2b66ab2 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -253,7 +253,10 @@ class TaskData private[spark]( val speculative: Boolean, val accumulatorUpdates: Seq[AccumulableInfo], val errorMessage: Option[String] = None, - val taskMetrics: Option[TaskMetrics] = None) + val taskMetrics: Option[TaskMetrics] = None, + val executorLogs: Map[String, String], + val schedulerDelay: Long, + val gettingResultTime: Long) class TaskMetrics private[spark]( val executorDeserializeTime: Long, diff --git a/core/src/main/scala/org/apache/spark/status/storeTypes.scala b/core/src/main/scala/org/apache/spark/status/storeTypes.scala index 646cf25880e37..ef19e86f3135f 100644 --- a/core/src/main/scala/org/apache/spark/status/storeTypes.scala +++ b/core/src/main/scala/org/apache/spark/status/storeTypes.scala @@ -283,7 +283,10 @@ private[spark] class TaskDataWrapper( speculative, accumulatorUpdates, errorMessage, - metrics) + metrics, + executorLogs = null, + schedulerDelay = 0L, + gettingResultTime = 0L) } @JsonIgnore @KVIndex(TaskIndexNames.STAGE) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 2361dd9aafea7..69258faa6c409 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -33,10 +33,9 @@ import scala.util.Random import scala.util.control.NonFatal import com.codahale.metrics.{MetricRegistry, MetricSet} -import com.google.common.io.CountingOutputStream import org.apache.spark._ -import org.apache.spark.executor.{DataReadMethod, ShuffleWriteMetrics} +import org.apache.spark.executor.DataReadMethod import org.apache.spark.internal.{config, Logging} import org.apache.spark.memory.{MemoryManager, MemoryMode} import org.apache.spark.metrics.source.Source @@ -50,7 +49,7 @@ import org.apache.spark.network.util.TransportConf import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.ExecutorCacheTaskLocation import org.apache.spark.serializer.{SerializerInstance, SerializerManager} -import org.apache.spark.shuffle.ShuffleManager +import org.apache.spark.shuffle.{ShuffleManager, ShuffleWriteMetricsReporter} import org.apache.spark.storage.memory._ import org.apache.spark.unsafe.Platform import org.apache.spark.util._ @@ -237,7 +236,7 @@ private[spark] class BlockManager( val 
priorityClass = conf.get( "spark.storage.replication.policy", classOf[RandomBlockReplicationPolicy].getName) val clazz = Utils.classForName(priorityClass) - val ret = clazz.newInstance.asInstanceOf[BlockReplicationPolicy] + val ret = clazz.getConstructor().newInstance().asInstanceOf[BlockReplicationPolicy] logInfo(s"Using $priorityClass for block replication policy") ret } @@ -938,7 +937,7 @@ private[spark] class BlockManager( file: File, serializerInstance: SerializerInstance, bufferSize: Int, - writeMetrics: ShuffleWriteMetrics): DiskBlockObjectWriter = { + writeMetrics: ShuffleWriteMetricsReporter): DiskBlockObjectWriter = { val syncWrites = conf.getBoolean("spark.shuffle.sync", false) new DiskBlockObjectWriter(file, serializerManager, serializerInstance, bufferSize, syncWrites, writeMetrics, blockId) diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala index a024c83d8d8b7..17390f9c60e79 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala @@ -20,9 +20,9 @@ package org.apache.spark.storage import java.io.{BufferedOutputStream, File, FileOutputStream, OutputStream} import java.nio.channels.FileChannel -import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.internal.Logging import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager} +import org.apache.spark.shuffle.ShuffleWriteMetricsReporter import org.apache.spark.util.Utils /** @@ -43,7 +43,7 @@ private[spark] class DiskBlockObjectWriter( syncWrites: Boolean, // These write metrics concurrently shared with other active DiskBlockObjectWriters who // are themselves performing writes. All updates must be relative. - writeMetrics: ShuffleWriteMetrics, + writeMetrics: ShuffleWriteMetricsReporter, val blockId: BlockId = null) extends OutputStream with Logging { diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index aecc2284a9588..86f7c08eddcb5 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -30,7 +30,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.shuffle._ import org.apache.spark.network.util.TransportConf -import org.apache.spark.shuffle.FetchFailedException +import org.apache.spark.shuffle.{FetchFailedException, ShuffleReadMetricsReporter} import org.apache.spark.util.Utils import org.apache.spark.util.io.ChunkedByteBufferOutputStream @@ -51,7 +51,7 @@ import org.apache.spark.util.io.ChunkedByteBufferOutputStream * For each block we also require the size (in bytes as a long field) in * order to throttle the memory usage. Note that zero-sized blocks are * already excluded, which happened in - * [[MapOutputTracker.convertMapStatuses]]. + * [[org.apache.spark.MapOutputTracker.convertMapStatuses]]. * @param streamWrapper A function to wrap the returned input stream. * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point. * @param maxReqsInFlight max number of remote requests to fetch blocks at any given point. 
@@ -59,6 +59,7 @@ import org.apache.spark.util.io.ChunkedByteBufferOutputStream * for a given remote host:port. * @param maxReqSizeShuffleToMem max size (in bytes) of a request that can be shuffled to memory. * @param detectCorrupt whether to detect any corruption in fetched blocks. + * @param shuffleMetrics used to report shuffle metrics. */ private[spark] final class ShuffleBlockFetcherIterator( @@ -71,7 +72,8 @@ final class ShuffleBlockFetcherIterator( maxReqsInFlight: Int, maxBlocksInFlightPerAddress: Int, maxReqSizeShuffleToMem: Long, - detectCorrupt: Boolean) + detectCorrupt: Boolean, + shuffleMetrics: ShuffleReadMetricsReporter) extends Iterator[(BlockId, InputStream)] with DownloadFileManager with Logging { import ShuffleBlockFetcherIterator._ @@ -137,8 +139,6 @@ final class ShuffleBlockFetcherIterator( */ private[this] val corruptedBlocks = mutable.HashSet[BlockId]() - private[this] val shuffleMetrics = context.taskMetrics().createTempShuffleReadMetrics() - /** * Whether the iterator is still active. If isZombie is true, the callback interface will no * longer place fetched blocks into [[results]]. diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala index adc406bb1c441..1c9ea1dba97d7 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala @@ -22,9 +22,12 @@ import java.nio.{ByteBuffer, MappedByteBuffer} import scala.collection.Map import scala.collection.mutable +import org.apache.commons.lang3.{JavaVersion, SystemUtils} +import sun.misc.Unsafe import sun.nio.ch.DirectBuffer import org.apache.spark.internal.Logging +import org.apache.spark.util.Utils /** * Storage information for each BlockManager. @@ -193,6 +196,31 @@ private[spark] class StorageStatus( /** Helper methods for storage-related objects. */ private[spark] object StorageUtils extends Logging { + + // In Java 8, the type of DirectBuffer.cleaner() was sun.misc.Cleaner, and it was possible + // to access the method sun.misc.Cleaner.clean() to invoke it. The type changed to + // jdk.internal.ref.Cleaner in later JDKs, and the .clean() method is not accessible even with + // reflection. However sun.misc.Unsafe added a invokeCleaner() method in JDK 9+ and this is + // still accessible with reflection. + private val bufferCleaner: DirectBuffer => Unit = + if (SystemUtils.isJavaVersionAtLeast(JavaVersion.JAVA_9)) { + val cleanerMethod = + Utils.classForName("sun.misc.Unsafe").getMethod("invokeCleaner", classOf[ByteBuffer]) + val unsafeField = classOf[Unsafe].getDeclaredField("theUnsafe") + unsafeField.setAccessible(true) + val unsafe = unsafeField.get(null).asInstanceOf[Unsafe] + buffer: DirectBuffer => cleanerMethod.invoke(unsafe, buffer) + } else { + val cleanerMethod = Utils.classForName("sun.misc.Cleaner").getMethod("clean") + buffer: DirectBuffer => { + // Careful to avoid the return type of .cleaner(), which changes with JDK + val cleaner: AnyRef = buffer.cleaner() + if (cleaner != null) { + cleanerMethod.invoke(cleaner) + } + } + } + /** * Attempt to clean up a ByteBuffer if it is direct or memory-mapped. This uses an *unsafe* Sun * API that will cause errors if one attempts to read from the disposed buffer. 
However, neither @@ -204,14 +232,8 @@ private[spark] object StorageUtils extends Logging { def dispose(buffer: ByteBuffer): Unit = { if (buffer != null && buffer.isInstanceOf[MappedByteBuffer]) { logTrace(s"Disposing of $buffer") - cleanDirectBuffer(buffer.asInstanceOf[DirectBuffer]) + bufferCleaner(buffer.asInstanceOf[DirectBuffer]) } } - private def cleanDirectBuffer(buffer: DirectBuffer) = { - val cleaner = buffer.cleaner() - if (cleaner != null) { - cleaner.clean() - } - } } diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 52a955111231a..316af9b79d286 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -356,13 +356,15 @@ private[spark] object JettyUtils extends Logging { (connector, connector.getLocalPort()) } + val httpConfig = new HttpConfiguration() + httpConfig.setRequestHeaderSize(conf.get(UI_REQUEST_HEADER_SIZE).toInt) // If SSL is configured, create the secure connector first. val securePort = sslOptions.createJettySslContextFactory().map { factory => val securePort = sslOptions.port.getOrElse(if (port > 0) Utils.userPort(port, 400) else 0) val secureServerName = if (serverName.nonEmpty) s"$serverName (HTTPS)" else serverName val connectionFactories = AbstractConnectionFactory.getFactories(factory, - new HttpConnectionFactory()) + new HttpConnectionFactory(httpConfig)) def sslConnect(currentPort: Int): (ServerConnector, Int) = { newConnector(connectionFactories, currentPort) @@ -377,7 +379,7 @@ private[spark] object JettyUtils extends Logging { // Bind the HTTP port. def httpConnect(currentPort: Int): (ServerConnector, Int) = { - newConnector(Array(new HttpConnectionFactory()), currentPort) + newConnector(Array(new HttpConnectionFactory(httpConfig)), currentPort) } val (httpConnector, httpPort) = Utils.startServiceOnPort[ServerConnector](port, httpConnect, diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 3aed4647a96f0..60a929375baae 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -204,6 +204,8 @@ private[spark] object UIUtils extends Logging { href={prependBaseUri(request, "/static/dataTables.bootstrap.css")} type="text/css"/> + diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala deleted file mode 100644 index 0ff64f053f371..0000000000000 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala +++ /dev/null @@ -1,152 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.ui.jobs - -import scala.xml.{Node, Unparsed} - -import org.apache.spark.status.AppStatusStore -import org.apache.spark.status.api.v1.StageData -import org.apache.spark.ui.{ToolTips, UIUtils} -import org.apache.spark.util.Utils - -/** Stage summary grouped by executors. */ -private[ui] class ExecutorTable(stage: StageData, store: AppStatusStore) { - - import ApiHelper._ - - def toNodeSeq: Seq[Node] = { - - - - - - - - - - {if (hasInput(stage)) { - - }} - {if (hasOutput(stage)) { - - }} - {if (hasShuffleRead(stage)) { - - }} - {if (hasShuffleWrite(stage)) { - - }} - {if (hasBytesSpilled(stage)) { - - - }} - - - - {createExecutorTable(stage)} - -
Executor IDAddressTask TimeTotal TasksFailed TasksKilled TasksSucceeded Tasks - Input Size / Records - - Output Size / Records - - - Shuffle Read Size / Records - - - Shuffle Write Size / Records - Shuffle Spill (Memory)Shuffle Spill (Disk) - - Blacklisted - -
- - } - - private def createExecutorTable(stage: StageData) : Seq[Node] = { - val executorSummary = store.executorSummary(stage.stageId, stage.attemptId) - - executorSummary.toSeq.sortBy(_._1).map { case (k, v) => - val executor = store.asOption(store.executorSummary(k)) - - -
{k}
-
- { - executor.map(_.executorLogs).getOrElse(Map.empty).map { - case (logName, logUrl) => - } - } -
- - {executor.map { e => e.hostPort }.getOrElse("CANNOT FIND ADDRESS")} - {UIUtils.formatDuration(v.taskTime)} - {v.failedTasks + v.succeededTasks + v.killedTasks} - {v.failedTasks} - {v.killedTasks} - {v.succeededTasks} - {if (hasInput(stage)) { - - {s"${Utils.bytesToString(v.inputBytes)} / ${v.inputRecords}"} - - }} - {if (hasOutput(stage)) { - - {s"${Utils.bytesToString(v.outputBytes)} / ${v.outputRecords}"} - - }} - {if (hasShuffleRead(stage)) { - - {s"${Utils.bytesToString(v.shuffleRead)} / ${v.shuffleReadRecords}"} - - }} - {if (hasShuffleWrite(stage)) { - - {s"${Utils.bytesToString(v.shuffleWrite)} / ${v.shuffleWriteRecords}"} - - }} - {if (hasBytesSpilled(stage)) { - - {Utils.bytesToString(v.memoryBytesSpilled)} - - - {Utils.bytesToString(v.diskBytesSpilled)} - - }} - { - if (executor.map(_.isBlacklisted).getOrElse(false)) { - for application - } else if (v.isBlacklistedForStage) { - for stage - } else { - false - } - } - - } - } - -} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 477b9ce7f7848..a213b764abea7 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -92,6 +92,14 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We val parameterTaskSortDesc = UIUtils.stripXSS(request.getParameter("task.desc")) val parameterTaskPageSize = UIUtils.stripXSS(request.getParameter("task.pageSize")) + val eventTimelineParameterTaskPage = UIUtils.stripXSS( + request.getParameter("task.eventTimelinePageNumber")) + val eventTimelineParameterTaskPageSize = UIUtils.stripXSS( + request.getParameter("task.eventTimelinePageSize")) + var eventTimelineTaskPage = Option(eventTimelineParameterTaskPage).map(_.toInt).getOrElse(1) + var eventTimelineTaskPageSize = Option( + eventTimelineParameterTaskPageSize).map(_.toInt).getOrElse(100) + val taskPage = Option(parameterTaskPage).map(_.toInt).getOrElse(1) val taskSortColumn = Option(parameterTaskSortColumn).map { sortColumn => UIUtils.decodeURLParameter(sortColumn) @@ -132,6 +140,14 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We } else { s"$totalTasks, showing $storedTasks" } + if (eventTimelineTaskPageSize < 1 || eventTimelineTaskPageSize > totalTasks) { + eventTimelineTaskPageSize = totalTasks + } + val eventTimelineTotalPages = + (totalTasks + eventTimelineTaskPageSize - 1) / eventTimelineTaskPageSize + if (eventTimelineTaskPage < 1 || eventTimelineTaskPage > eventTimelineTotalPages) { + eventTimelineTaskPage = 1 + } val summary =
@@ -152,20 +168,20 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We }} {if (hasOutput(stageData)) {
              <li>
-                <strong>Output: </strong>
+                <strong>Output Size / Records: </strong>
                {s"${Utils.bytesToString(stageData.outputBytes)} / ${stageData.outputRecords}"}
              </li>
            }}
            {if (hasShuffleRead(stageData)) {
              <li>
-                <strong>Shuffle Read: </strong>
+                <strong>Shuffle Read Size / Records: </strong>
                {s"${Utils.bytesToString(stageData.shuffleReadBytes)} / " +
                 s"${stageData.shuffleReadRecords}"}
              </li>
            }}
            {if (hasShuffleWrite(stageData)) {
              <li>
-                <strong>Shuffle Write: </strong>
+                <strong>Shuffle Write Size / Records: </strong>
                {s"${Utils.bytesToString(stageData.shuffleWriteBytes)} / " +
                 s"${stageData.shuffleWriteRecords}"}
  • @@ -183,82 +199,16 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We {if (!stageJobIds.isEmpty) {
              <li>
                <strong>Associated Job Ids: </strong>
-                {stageJobIds.map(jobId => {val detailUrl = "%s/jobs/job/?id=%s".format(
-                  UIUtils.prependBaseUri(request, parent.basePath), jobId)
-                  <a href={detailUrl}>{s"${jobId}"}</a>&nbsp;&nbsp;
-                })}
+                {stageJobIds.sorted.map { jobId =>
+                  val jobURL = "%s/jobs/job/?id=%s"
+                    .format(UIUtils.prependBaseUri(request, parent.basePath), jobId)
+                  <a href={jobURL}>{jobId.toString}</a>&nbsp;
+                }}
  • }}
    - val showAdditionalMetrics = -
    - - - Show Additional Metrics - - -
    - val stageGraph = parent.store.asOption(parent.store.operationGraphForStage(stageId)) val dagViz = UIUtils.showDagVizForStage(stageId, stageGraph) @@ -276,7 +226,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We stageData.accumulatorUpdates.toSeq) val currentTime = System.currentTimeMillis() - val (taskTable, taskTableHTML) = try { + val taskTable = try { val _taskTable = new TaskPagedTable( stageData, UIUtils.prependBaseUri(request, parent.basePath) + @@ -287,17 +237,10 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We desc = taskSortDesc, store = parent.store ) - (_taskTable, _taskTable.table(taskPage)) + _taskTable } catch { case e @ (_ : IllegalArgumentException | _ : IndexOutOfBoundsException) => - val errorMessage = -
          <div class="alert alert-error">
-            <p>Error while rendering stage table:</p>
-            <pre>
-              {Utils.exceptionString(e)}
-            </pre>
-          </div>
    - (null, errorMessage) + null } val jsForScrollingDownToTaskTable = @@ -315,190 +258,36 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We } - val metricsSummary = store.taskSummary(stageData.stageId, stageData.attemptId, - Array(0, 0.25, 0.5, 0.75, 1.0)) - - val summaryTable = metricsSummary.map { metrics => - def timeQuantiles(data: IndexedSeq[Double]): Seq[Node] = { - data.map { millis => - {UIUtils.formatDuration(millis.toLong)} - } - } - - def sizeQuantiles(data: IndexedSeq[Double]): Seq[Node] = { - data.map { size => - {Utils.bytesToString(size.toLong)} - } - } - - def sizeQuantilesWithRecords( - data: IndexedSeq[Double], - records: IndexedSeq[Double]) : Seq[Node] = { - data.zip(records).map { case (d, r) => - {s"${Utils.bytesToString(d.toLong)} / ${r.toLong}"} - } - } - - def titleCell(title: String, tooltip: String): Seq[Node] = { - - - {title} - - - } - - def simpleTitleCell(title: String): Seq[Node] = {title} - - val deserializationQuantiles = titleCell("Task Deserialization Time", - ToolTips.TASK_DESERIALIZATION_TIME) ++ timeQuantiles(metrics.executorDeserializeTime) - - val serviceQuantiles = simpleTitleCell("Duration") ++ timeQuantiles(metrics.executorRunTime) - - val gcQuantiles = titleCell("GC Time", ToolTips.GC_TIME) ++ timeQuantiles(metrics.jvmGcTime) - - val serializationQuantiles = titleCell("Result Serialization Time", - ToolTips.RESULT_SERIALIZATION_TIME) ++ timeQuantiles(metrics.resultSerializationTime) - - val gettingResultQuantiles = titleCell("Getting Result Time", ToolTips.GETTING_RESULT_TIME) ++ - timeQuantiles(metrics.gettingResultTime) - - val peakExecutionMemoryQuantiles = titleCell("Peak Execution Memory", - ToolTips.PEAK_EXECUTION_MEMORY) ++ sizeQuantiles(metrics.peakExecutionMemory) - - // The scheduler delay includes the network delay to send the task to the worker - // machine and to send back the result (but not the time to fetch the task result, - // if it needed to be fetched from the block manager on the worker). 
- val schedulerDelayQuantiles = titleCell("Scheduler Delay", ToolTips.SCHEDULER_DELAY) ++ - timeQuantiles(metrics.schedulerDelay) - - def inputQuantiles: Seq[Node] = { - simpleTitleCell("Input Size / Records") ++ - sizeQuantilesWithRecords(metrics.inputMetrics.bytesRead, metrics.inputMetrics.recordsRead) - } - - def outputQuantiles: Seq[Node] = { - simpleTitleCell("Output Size / Records") ++ - sizeQuantilesWithRecords(metrics.outputMetrics.bytesWritten, - metrics.outputMetrics.recordsWritten) - } - - def shuffleReadBlockedQuantiles: Seq[Node] = { - titleCell("Shuffle Read Blocked Time", ToolTips.SHUFFLE_READ_BLOCKED_TIME) ++ - timeQuantiles(metrics.shuffleReadMetrics.fetchWaitTime) - } - - def shuffleReadTotalQuantiles: Seq[Node] = { - titleCell("Shuffle Read Size / Records", ToolTips.SHUFFLE_READ) ++ - sizeQuantilesWithRecords(metrics.shuffleReadMetrics.readBytes, - metrics.shuffleReadMetrics.readRecords) - } - - def shuffleReadRemoteQuantiles: Seq[Node] = { - titleCell("Shuffle Remote Reads", ToolTips.SHUFFLE_READ_REMOTE_SIZE) ++ - sizeQuantiles(metrics.shuffleReadMetrics.remoteBytesRead) - } - - def shuffleWriteQuantiles: Seq[Node] = { - simpleTitleCell("Shuffle Write Size / Records") ++ - sizeQuantilesWithRecords(metrics.shuffleWriteMetrics.writeBytes, - metrics.shuffleWriteMetrics.writeRecords) - } - - def memoryBytesSpilledQuantiles: Seq[Node] = { - simpleTitleCell("Shuffle spill (memory)") ++ sizeQuantiles(metrics.memoryBytesSpilled) - } - - def diskBytesSpilledQuantiles: Seq[Node] = { - simpleTitleCell("Shuffle spill (disk)") ++ sizeQuantiles(metrics.diskBytesSpilled) - } - - val listings: Seq[Seq[Node]] = Seq( - {serviceQuantiles}, - {schedulerDelayQuantiles}, - - {deserializationQuantiles} - - {gcQuantiles}, - - {serializationQuantiles} - , - {gettingResultQuantiles}, - - {peakExecutionMemoryQuantiles} - , - if (hasInput(stageData)) {inputQuantiles} else Nil, - if (hasOutput(stageData)) {outputQuantiles} else Nil, - if (hasShuffleRead(stageData)) { - - {shuffleReadBlockedQuantiles} - - {shuffleReadTotalQuantiles} - - {shuffleReadRemoteQuantiles} - - } else { - Nil - }, - if (hasShuffleWrite(stageData)) {shuffleWriteQuantiles} else Nil, - if (hasBytesSpilled(stageData)) {memoryBytesSpilledQuantiles} else Nil, - if (hasBytesSpilled(stageData)) {diskBytesSpilledQuantiles} else Nil) - - val quantileHeaders = Seq("Metric", "Min", "25th percentile", "Median", "75th percentile", - "Max") - // The summary table does not use CSS to stripe rows, which doesn't work with hidden - // rows (instead, JavaScript in table.js is used to stripe the non-hidden rows). - UIUtils.listingTable( - quantileHeaders, - identity[Seq[Node]], - listings, - fixedWidth = true, - id = Some("task-summary-table"), - stripeRowsWithCss = false) - } - - val executorTable = new ExecutorTable(stageData, parent.store) - - val maybeAccumulableTable: Seq[Node] = - if (hasAccumulators(stageData)) {

    Accumulators

    ++ accumulableTable } else Seq() - - val aggMetrics = - -

    - - Aggregated Metrics by Executor -

    -
    -
    - {executorTable.toNodeSeq} -
    - val content = summary ++ - dagViz ++ - showAdditionalMetrics ++ + dagViz ++
    ++ makeTimeline( // Only show the tasks in the table - Option(taskTable).map(_.dataSource.tasks).getOrElse(Nil), - currentTime) ++ -

    Summary Metrics for {numCompleted} Completed Tasks

    ++ -
    {summaryTable.getOrElse("No tasks have reported metrics yet.")}
    ++ - aggMetrics ++ - maybeAccumulableTable ++ - -

    - - Tasks ({totalTasksNumStr}) -

    -
    ++ -
    - {taskTableHTML ++ jsForScrollingDownToTaskTable} -
    - UIUtils.headerSparkPage(request, stageHeader, content, parent, showVisualization = true) + Option(taskTable).map({ taskPagedTable => + val from = (eventTimelineTaskPage - 1) * eventTimelineTaskPageSize + val to = taskPagedTable.dataSource.dataSize.min( + eventTimelineTaskPage * eventTimelineTaskPageSize) + taskPagedTable.dataSource.sliceData(from, to)}).getOrElse(Nil), currentTime, + eventTimelineTaskPage, eventTimelineTaskPageSize, eventTimelineTotalPages, stageId, + stageAttemptId, totalTasks) ++ +
    + + +
    + UIUtils.headerSparkPage(request, stageHeader, content, parent, showVisualization = true, + useDataTables = true) + } - def makeTimeline(tasks: Seq[TaskData], currentTime: Long): Seq[Node] = { + def makeTimeline( + tasks: Seq[TaskData], + currentTime: Long, + page: Int, + pageSize: Int, + totalPages: Int, + stageId: Int, + stageAttemptId: Int, + totalTasks: Int): Seq[Node] = { val executorsSet = new HashSet[(String, String)] var minLaunchTime = Long.MaxValue var maxFinishTime = Long.MinValue @@ -657,6 +446,31 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We Enable zooming +
    +
    + + + + + + + + + + +
    +
    {TIMELINE_LEGEND} ++ @@ -843,7 +657,7 @@ private[ui] class TaskPagedTable( {UIUtils.formatDate(task.launchTime)} - {formatDuration(task.duration)} + {formatDuration(task.taskMetrics.map(_.executorRunTime))} {UIUtils.formatDuration(AppStatusUtils.schedulerDelay(task))} @@ -958,7 +772,7 @@ private[ui] class TaskPagedTable( } } -private[ui] object ApiHelper { +private[spark] object ApiHelper { val HEADER_ID = "ID" val HEADER_TASK_INDEX = "Index" @@ -996,7 +810,9 @@ private[ui] object ApiHelper { HEADER_EXECUTOR -> TaskIndexNames.EXECUTOR, HEADER_HOST -> TaskIndexNames.HOST, HEADER_LAUNCH_TIME -> TaskIndexNames.LAUNCH_TIME, - HEADER_DURATION -> TaskIndexNames.DURATION, + // SPARK-26109: Duration of task as executorRunTime to make it consistent with the + // aggregated tasks summary metrics table and the previous versions of Spark. + HEADER_DURATION -> TaskIndexNames.EXEC_RUN_TIME, HEADER_SCHEDULER_DELAY -> TaskIndexNames.SCHEDULER_DELAY, HEADER_DESER_TIME -> TaskIndexNames.DESER_TIME, HEADER_GC_TIME -> TaskIndexNames.GC_TIME, diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index b9abd39b4705d..766efc15e26ba 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -368,7 +368,7 @@ private[ui] class StagePagedTable( {if (cachedRddInfos.nonEmpty) { Text("RDD: ") ++ cachedRddInfos.map { i => - {i.name} + {i.name} } }}
    {s.details}
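As an aside on the StagePage change above, which now feeds only the current page of tasks into the event timeline: the following is a minimal sketch of that 1-indexed page-slicing arithmetic. The object and method names are illustrative, not the Spark classes themselves, and `slicePage` stands in for the `dataSource.sliceData(from, to)` call in the patch.

```scala
// Minimal sketch, assuming a 1-indexed `page` and an in-memory task sequence.
object TimelinePageSliceSketch {
  def slicePage[T](tasks: Seq[T], page: Int, pageSize: Int): Seq[T] = {
    val from = (page - 1) * pageSize                  // first index of the requested page
    val to   = math.min(tasks.size, page * pageSize)  // clamp to the total number of tasks
    tasks.slice(from, to)
  }

  def main(args: Array[String]): Unit = {
    val tasks = (1 to 10).toVector
    // Page 2 with a page size of 4 covers indices 4 until 8, i.e. tasks 5..8.
    println(slicePage(tasks, page = 2, pageSize = 4)) // Vector(5, 6, 7, 8)
  }
}
```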
    diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala index 3eb546e336e99..2488197814ffd 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala @@ -78,7 +78,7 @@ private[ui] class StoragePage(parent: SparkUITab, store: AppStatusStore) extends {rdd.id} - {rdd.name} diff --git a/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala b/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala index 21acaa95c5645..f4d6c7a28d2e4 100644 --- a/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala +++ b/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala @@ -25,11 +25,14 @@ private[spark] abstract class CompletionIterator[ +A, +I <: Iterator[A]](sub: I) extends Iterator[A] { private[this] var completed = false - def next(): A = sub.next() + private[this] var iter = sub + def next(): A = iter.next() def hasNext: Boolean = { - val r = sub.hasNext + val r = iter.hasNext if (!r && !completed) { completed = true + // reassign to release resources of highly resource consuming iterators early + iter = Iterator.empty.asInstanceOf[I] completion() } r diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 93b5826f8a74b..227c9e734f0af 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -31,7 +31,6 @@ import java.security.SecureRandom import java.util.{Locale, Properties, Random, UUID} import java.util.concurrent._ import java.util.concurrent.TimeUnit.NANOSECONDS -import java.util.concurrent.atomic.AtomicBoolean import java.util.zip.GZIPInputStream import scala.annotation.tailrec @@ -93,53 +92,6 @@ private[spark] object Utils extends Logging { private val MAX_DIR_CREATION_ATTEMPTS: Int = 10 @volatile private var localRootDirs: Array[String] = null - /** - * The performance overhead of creating and logging strings for wide schemas can be large. To - * limit the impact, we bound the number of fields to include by default. This can be overridden - * by setting the 'spark.debug.maxToStringFields' conf in SparkEnv. - */ - val DEFAULT_MAX_TO_STRING_FIELDS = 25 - - private[spark] def maxNumToStringFields = { - if (SparkEnv.get != null) { - SparkEnv.get.conf.getInt("spark.debug.maxToStringFields", DEFAULT_MAX_TO_STRING_FIELDS) - } else { - DEFAULT_MAX_TO_STRING_FIELDS - } - } - - /** Whether we have warned about plan string truncation yet. */ - private val truncationWarningPrinted = new AtomicBoolean(false) - - /** - * Format a sequence with semantics similar to calling .mkString(). Any elements beyond - * maxNumToStringFields will be dropped and replaced by a "... N more fields" placeholder. - * - * @return the trimmed and formatted string. - */ - def truncatedString[T]( - seq: Seq[T], - start: String, - sep: String, - end: String, - maxNumFields: Int = maxNumToStringFields): String = { - if (seq.length > maxNumFields) { - if (truncationWarningPrinted.compareAndSet(false, true)) { - logWarning( - "Truncated the string representation of a plan since it was too large. This " + - "behavior can be adjusted by setting 'spark.debug.maxToStringFields' in SparkEnv.conf.") - } - val numFields = math.max(0, maxNumFields - 1) - seq.take(numFields).mkString( - start, sep, sep + "... 
" + (seq.length - numFields) + " more fields" + end) - } else { - seq.mkString(start, sep, end) - } - } - - /** Shorthand for calling truncatedString() without start or end strings. */ - def truncatedString[T](seq: Seq[T], sep: String): String = truncatedString(seq, "", sep, "") - /** Serialize an object using Java serialization */ def serialize[T](o: T): Array[Byte] = { val bos = new ByteArrayOutputStream() @@ -2328,7 +2280,12 @@ private[spark] object Utils extends Logging { * configure a new log4j level */ def setLogLevel(l: org.apache.log4j.Level) { - org.apache.log4j.Logger.getRootLogger().setLevel(l) + val rootLogger = org.apache.log4j.Logger.getRootLogger() + rootLogger.setLevel(l) + rootLogger.getAllAppenders().asScala.foreach { + case ca: org.apache.log4j.ConsoleAppender => ca.setThreshold(l) + case _ => // no-op + } } /** @@ -2430,7 +2387,8 @@ private[spark] object Utils extends Logging { "org.apache.spark.security.ShellBasedGroupsMappingProvider") if (groupProviderClassName != "") { try { - val groupMappingServiceProvider = classForName(groupProviderClassName).newInstance. + val groupMappingServiceProvider = classForName(groupProviderClassName). + getConstructor().newInstance(). asInstanceOf[org.apache.spark.security.GroupMappingServiceProvider] val currentUserGroups = groupMappingServiceProvider.getGroups(username) return currentUserGroups @@ -2863,6 +2821,14 @@ private[spark] object Utils extends Logging { def stringHalfWidth(str: String): Int = { if (str == null) 0 else str.length + fullWidthRegex.findAllIn(str).size } + + def sanitizeDirName(str: String): String = { + str.replaceAll("[ :/]", "-").replaceAll("[.${}'\"]", "_").toLowerCase(Locale.ROOT) + } + + def isClientMode(conf: SparkConf): Boolean = { + "client".equals(conf.get(SparkLauncher.DEPLOY_MODE, "client")) + } } private[util] object CallerContext extends Logging { diff --git a/core/src/main/scala/org/apache/spark/util/VersionUtils.scala b/core/src/main/scala/org/apache/spark/util/VersionUtils.scala index 828153b868420..c0f8866dd58dc 100644 --- a/core/src/main/scala/org/apache/spark/util/VersionUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/VersionUtils.scala @@ -23,6 +23,7 @@ package org.apache.spark.util private[spark] object VersionUtils { private val majorMinorRegex = """^(\d+)\.(\d+)(\..*)?$""".r + private val shortVersionRegex = """^(\d+\.\d+\.\d+)(.*)?$""".r /** * Given a Spark version string, return the major version number. @@ -36,6 +37,19 @@ private[spark] object VersionUtils { */ def minorVersion(sparkVersion: String): Int = majorMinorVersion(sparkVersion)._2 + /** + * Given a Spark version string, return the short version string. + * E.g., for 3.0.0-SNAPSHOT, return '3.0.0'. + */ + def shortVersion(sparkVersion: String): String = { + shortVersionRegex.findFirstMatchIn(sparkVersion) match { + case Some(m) => m.group(1) + case None => + throw new IllegalArgumentException(s"Spark tried to parse '$sparkVersion' as a Spark" + + s" version string, but it could not find the major/minor/maintenance version numbers.") + } + } + /** * Given a Spark version string, return the (major version number, minor version number). * E.g., for 2.0.1-SNAPSHOT, return (2, 0). 
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index b159200d79222..46279e79d78db 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -727,9 +727,10 @@ private[spark] class ExternalSorter[K, V, C]( spills.clear() forceSpillFiles.foreach(s => s.file.delete()) forceSpillFiles.clear() - if (map != null || buffer != null) { + if (map != null || buffer != null || readingIterator != null) { map = null // So that the memory can be garbage-collected buffer = null // So that the memory can be garbage-collected + readingIterator = null // So that the memory can be garbage-collected releaseMemory() } } @@ -793,8 +794,8 @@ private[spark] class ExternalSorter[K, V, C]( def nextPartition(): Int = cur._1._1 } - logInfo(s"Task ${context.taskAttemptId} force spilling in-memory map to disk and " + - s" it will release ${org.apache.spark.util.Utils.bytesToString(getUsed())} memory") + logInfo(s"Task ${TaskContext.get().taskAttemptId} force spilling in-memory map to disk " + + s"and it will release ${org.apache.spark.util.Utils.bytesToString(getUsed())} memory") val spillFile = spillMemoryIteratorToDisk(inMemoryIterator) forceSpillFiles += spillFile val spillReader = new SpillReader(spillFile) diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala index 870830fff4c3e..128d6ff8cd746 100644 --- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala @@ -222,7 +222,8 @@ private[spark] class ChunkedByteBufferInputStream( dispose: Boolean) extends InputStream { - private[this] var chunks = chunkedByteBuffer.getChunks().iterator + // Filter out empty chunks since `read()` assumes all chunks are non-empty. + private[this] var chunks = chunkedByteBuffer.getChunks().filter(_.hasRemaining).iterator private[this] var currentChunk: ByteBuffer = { if (chunks.hasNext) { chunks.next() diff --git a/core/src/main/scala/org/apache/spark/util/logging/DriverLogger.scala b/core/src/main/scala/org/apache/spark/util/logging/DriverLogger.scala new file mode 100644 index 0000000000000..bea18a3df4783 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/logging/DriverLogger.scala @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.util.logging + +import java.io._ +import java.util.concurrent.{ScheduledExecutorService, TimeUnit} + +import org.apache.commons.io.FileUtils +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, FSDataOutputStream, Path} +import org.apache.hadoop.fs.permission.FsPermission +import org.apache.log4j.{FileAppender => Log4jFileAppender, _} + +import org.apache.spark.SparkConf +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.util.{ThreadUtils, Utils} + +private[spark] class DriverLogger(conf: SparkConf) extends Logging { + + private val UPLOAD_CHUNK_SIZE = 1024 * 1024 + private val UPLOAD_INTERVAL_IN_SECS = 5 + private val DEFAULT_LAYOUT = "%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n" + private val LOG_FILE_PERMISSIONS = new FsPermission(Integer.parseInt("770", 8).toShort) + + private var localLogFile: String = FileUtils.getFile( + Utils.getLocalDir(conf), + DriverLogger.DRIVER_LOG_DIR, + DriverLogger.DRIVER_LOG_FILE).getAbsolutePath() + private var writer: Option[DfsAsyncWriter] = None + + addLogAppender() + + private def addLogAppender(): Unit = { + val appenders = LogManager.getRootLogger().getAllAppenders() + val layout = if (conf.contains(DRIVER_LOG_LAYOUT)) { + new PatternLayout(conf.get(DRIVER_LOG_LAYOUT).get) + } else if (appenders.hasMoreElements()) { + appenders.nextElement().asInstanceOf[Appender].getLayout() + } else { + new PatternLayout(DEFAULT_LAYOUT) + } + val fa = new Log4jFileAppender(layout, localLogFile) + fa.setName(DriverLogger.APPENDER_NAME) + LogManager.getRootLogger().addAppender(fa) + logInfo(s"Added a local log appender at: ${localLogFile}") + } + + def startSync(hadoopConf: Configuration): Unit = { + try { + // Setup a writer which moves the local file to hdfs continuously + val appId = Utils.sanitizeDirName(conf.getAppId) + writer = Some(new DfsAsyncWriter(appId, hadoopConf)) + } catch { + case e: Exception => + logError(s"Could not persist driver logs to dfs", e) + } + } + + def stop(): Unit = { + try { + val fa = LogManager.getRootLogger.getAppender(DriverLogger.APPENDER_NAME) + LogManager.getRootLogger().removeAppender(DriverLogger.APPENDER_NAME) + Utils.tryLogNonFatalError(fa.close()) + writer.foreach(_.closeWriter()) + } catch { + case e: Exception => + logError(s"Error in persisting driver logs", e) + } finally { + Utils.tryLogNonFatalError { + JavaUtils.deleteRecursively(FileUtils.getFile(localLogFile).getParentFile()) + } + } + } + + // Visible for testing + private[spark] class DfsAsyncWriter(appId: String, hadoopConf: Configuration) extends Runnable + with Logging { + + private var streamClosed = false + private var inStream: InputStream = null + private var outputStream: FSDataOutputStream = null + private val tmpBuffer = new Array[Byte](UPLOAD_CHUNK_SIZE) + private var threadpool: ScheduledExecutorService = _ + init() + + private def init(): Unit = { + val rootDir = conf.get(DRIVER_LOG_DFS_DIR).get + val fileSystem: FileSystem = new Path(rootDir).getFileSystem(hadoopConf) + if (!fileSystem.exists(new Path(rootDir))) { + throw new RuntimeException(s"${rootDir} does not exist." 
+ + s" Please create this dir in order to persist driver logs") + } + val dfsLogFile: String = FileUtils.getFile(rootDir, appId + + DriverLogger.DRIVER_LOG_FILE_SUFFIX).getAbsolutePath() + try { + inStream = new BufferedInputStream(new FileInputStream(localLogFile)) + outputStream = fileSystem.create(new Path(dfsLogFile), true) + fileSystem.setPermission(new Path(dfsLogFile), LOG_FILE_PERMISSIONS) + } catch { + case e: Exception => + JavaUtils.closeQuietly(inStream) + JavaUtils.closeQuietly(outputStream) + throw e + } + threadpool = ThreadUtils.newDaemonSingleThreadScheduledExecutor("dfsSyncThread") + threadpool.scheduleWithFixedDelay(this, UPLOAD_INTERVAL_IN_SECS, UPLOAD_INTERVAL_IN_SECS, + TimeUnit.SECONDS) + logInfo(s"Started driver log file sync to: ${dfsLogFile}") + } + + def run(): Unit = { + if (streamClosed) { + return + } + try { + var remaining = inStream.available() + while (remaining > 0) { + val read = inStream.read(tmpBuffer, 0, math.min(remaining, UPLOAD_CHUNK_SIZE)) + outputStream.write(tmpBuffer, 0, read) + remaining -= read + } + outputStream.hflush() + } catch { + case e: Exception => logError("Failed writing driver logs to dfs", e) + } + } + + private def close(): Unit = { + if (streamClosed) { + return + } + try { + // Write all remaining bytes + run() + } finally { + try { + streamClosed = true + inStream.close() + outputStream.close() + } catch { + case e: Exception => + logError("Error in closing driver log input/output stream", e) + } + } + } + + def closeWriter(): Unit = { + try { + threadpool.execute(new Runnable() { + override def run(): Unit = DfsAsyncWriter.this.close() + }) + threadpool.shutdown() + threadpool.awaitTermination(1, TimeUnit.MINUTES) + } catch { + case e: Exception => + logError("Error in shutting down threadpool", e) + } + } + } + +} + +private[spark] object DriverLogger extends Logging { + val DRIVER_LOG_DIR = "__driver_logs__" + val DRIVER_LOG_FILE = "driver.log" + val DRIVER_LOG_FILE_SUFFIX = "_" + DRIVER_LOG_FILE + val APPENDER_NAME = "_DriverLogAppender" + + def apply(conf: SparkConf): Option[DriverLogger] = { + if (conf.get(DRIVER_LOG_PERSISTTODFS) && Utils.isClientMode(conf)) { + if (conf.contains(DRIVER_LOG_DFS_DIR)) { + try { + Some(new DriverLogger(conf)) + } catch { + case e: Exception => + logError("Could not add driver logger", e) + None + } + } else { + logWarning(s"Driver logs are not persisted because" + + s" ${DRIVER_LOG_DFS_DIR.key} is not configured") + None + } + } else { + None + } + } +} diff --git a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala index ea99a7e5b4847..70554f1d03067 100644 --- a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala +++ b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala @@ -49,7 +49,7 @@ trait RandomSampler[T, U] extends Pseudorandom with Cloneable with Serializable /** return a copy of the RandomSampler object */ override def clone: RandomSampler[T, U] = - throw new NotImplementedError("clone() is not implemented.") + throw new UnsupportedOperationException("clone() is not implemented.") } private[spark] diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index a07d0e84ea854..30ad3f5575545 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ 
b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -162,7 +162,8 @@ private UnsafeShuffleWriter createWriter( new SerializedShuffleHandle<>(0, 1, shuffleDep), 0, // map id taskContext, - conf + conf, + taskContext.taskMetrics().shuffleWriteMetrics() ); } @@ -521,7 +522,8 @@ public void testPeakMemoryUsed() throws Exception { new SerializedShuffleHandle<>(0, 1, shuffleDep), 0, // map id taskContext, - conf); + conf, + taskContext.taskMetrics().shuffleWriteMetrics()); // Peak memory should be monotonically increasing. More specifically, every time // we allocate a new page it should increase by exactly the size of the page. diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index 411cd5cb57331..d1b29d90ad913 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -38,6 +38,7 @@ import org.apache.spark.executor.TaskMetrics; import org.apache.spark.internal.config.package$; import org.apache.spark.memory.TestMemoryManager; +import org.apache.spark.memory.SparkOutOfMemoryError; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.serializer.JavaSerializer; import org.apache.spark.serializer.SerializerInstance; @@ -534,10 +535,10 @@ public void testOOMDuringSpill() throws Exception { insertNumber(sorter, 1024); fail("expected OutOfMmoryError but it seems operation surprisingly succeeded"); } - // we expect an OutOfMemoryError here, anything else (i.e the original NPE is a failure) - catch (OutOfMemoryError oom){ + // we expect an SparkOutOfMemoryError here, anything else (i.e the original NPE is a failure) + catch (SparkOutOfMemoryError oom){ String oomStackTrace = Utils.exceptionString(oom); - assertThat("expected OutOfMemoryError in " + + assertThat("expected SparkOutOfMemoryError in " + "org.apache.spark.util.collection.unsafe.sort.UnsafeInMemorySorter.reset", oomStackTrace, Matchers.containsString( diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java index 85ffdca436e14..b0d485f0c953f 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java @@ -27,6 +27,7 @@ import org.apache.spark.SparkConf; import org.apache.spark.memory.TestMemoryConsumer; import org.apache.spark.memory.TestMemoryManager; +import org.apache.spark.memory.SparkOutOfMemoryError; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.MemoryBlock; @@ -178,8 +179,8 @@ public int compare( testMemoryManager.markExecutionAsOutOfMemoryOnce(); try { sorter.reset(); - fail("expected OutOfMmoryError but it seems operation surprisingly succeeded"); - } catch (OutOfMemoryError oom) { + fail("expected SparkOutOfMemoryError but it seems operation surprisingly succeeded"); + } catch (SparkOutOfMemoryError oom) { // as expected } // [SPARK-21907] this failed on NPE at diff --git a/core/src/test/resources/HistoryServerExpectations/blacklisting_for_stage_expectation.json 
b/core/src/test/resources/HistoryServerExpectations/blacklisting_for_stage_expectation.json index 5e9e8230e2745..62e5c123fd3d4 100644 --- a/core/src/test/resources/HistoryServerExpectations/blacklisting_for_stage_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/blacklisting_for_stage_expectation.json @@ -1,639 +1,708 @@ { - "status": "COMPLETE", - "stageId": 0, - "attemptId": 0, - "numTasks": 10, - "numActiveTasks": 0, - "numCompleteTasks": 10, - "numFailedTasks": 2, - "numKilledTasks": 0, - "numCompletedIndices": 10, - "executorRunTime": 761, - "executorCpuTime": 269916000, - "submissionTime": "2018-01-09T10:21:18.152GMT", - "firstTaskLaunchedTime": "2018-01-09T10:21:18.347GMT", - "completionTime": "2018-01-09T10:21:19.062GMT", - "inputBytes": 0, - "inputRecords": 0, - "outputBytes": 0, - "outputRecords": 0, - "shuffleReadBytes": 0, - "shuffleReadRecords": 0, - "shuffleWriteBytes": 460, - "shuffleWriteRecords": 10, - "memoryBytesSpilled": 0, - "diskBytesSpilled": 0, - "name": "map at :26", - "details": "org.apache.spark.rdd.RDD.map(RDD.scala:370)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:26)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:34)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:36)\n$line17.$read$$iw$$iw$$iw$$iw$$iw.(:38)\n$line17.$read$$iw$$iw$$iw$$iw.(:40)\n$line17.$read$$iw$$iw$$iw.(:42)\n$line17.$read$$iw$$iw.(:44)\n$line17.$read$$iw.(:46)\n$line17.$read.(:48)\n$line17.$read$.(:52)\n$line17.$read$.()\n$line17.$eval$.$print$lzycompute(:7)\n$line17.$eval$.$print(:6)\n$line17.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)", - "schedulingPool": "default", - "rddIds": [ - 1, - 0 - ], - "accumulatorUpdates": [], - "tasks": { - "0": { - "taskId": 0, - "index": 0, - "attempt": 0, - "launchTime": "2018-01-09T10:21:18.347GMT", - "duration": 562, - "executorId": "0", - "host": "172.30.65.138", - "status": "FAILED", - "taskLocality": "PROCESS_LOCAL", - "speculative": false, - "accumulatorUpdates": [], - "errorMessage": "java.lang.RuntimeException: Bad executor\n\tat $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:26)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n", - "taskMetrics": { - "executorDeserializeTime": 0, - "executorDeserializeCpuTime": 0, - "executorRunTime": 460, - "executorCpuTime": 0, - "resultSize": 0, - "jvmGcTime": 14, - "resultSerializationTime": 0, - "memoryBytesSpilled": 0, - "diskBytesSpilled": 0, - "peakExecutionMemory": 0, - "inputMetrics": { - 
"bytesRead": 0, - "recordsRead": 0 - }, - "outputMetrics": { - "bytesWritten": 0, - "recordsWritten": 0 - }, - "shuffleReadMetrics": { - "remoteBlocksFetched": 0, - "localBlocksFetched": 0, - "fetchWaitTime": 0, - "remoteBytesRead": 0, - "remoteBytesReadToDisk": 0, - "localBytesRead": 0, - "recordsRead": 0 - }, - "shuffleWriteMetrics": { - "bytesWritten": 0, - "writeTime": 3873006, - "recordsWritten": 0 + "status" : "COMPLETE", + "stageId" : 0, + "attemptId" : 0, + "numTasks" : 10, + "numActiveTasks" : 0, + "numCompleteTasks" : 10, + "numFailedTasks" : 2, + "numKilledTasks" : 0, + "numCompletedIndices" : 10, + "executorRunTime" : 761, + "executorCpuTime" : 269916000, + "submissionTime" : "2018-01-09T10:21:18.152GMT", + "firstTaskLaunchedTime" : "2018-01-09T10:21:18.347GMT", + "completionTime" : "2018-01-09T10:21:19.062GMT", + "inputBytes" : 0, + "inputRecords" : 0, + "outputBytes" : 0, + "outputRecords" : 0, + "shuffleReadBytes" : 0, + "shuffleReadRecords" : 0, + "shuffleWriteBytes" : 460, + "shuffleWriteRecords" : 10, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "name" : "map at :26", + "details" : "org.apache.spark.rdd.RDD.map(RDD.scala:370)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:26)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:34)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:36)\n$line17.$read$$iw$$iw$$iw$$iw$$iw.(:38)\n$line17.$read$$iw$$iw$$iw$$iw.(:40)\n$line17.$read$$iw$$iw$$iw.(:42)\n$line17.$read$$iw$$iw.(:44)\n$line17.$read$$iw.(:46)\n$line17.$read.(:48)\n$line17.$read$.(:52)\n$line17.$read$.()\n$line17.$eval$.$print$lzycompute(:7)\n$line17.$eval$.$print(:6)\n$line17.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)", + "schedulingPool" : "default", + "rddIds" : [ 1, 0 ], + "accumulatorUpdates" : [ ], + "tasks" : { + "0" : { + "taskId" : 0, + "index" : 0, + "attempt" : 0, + "launchTime" : "2018-01-09T10:21:18.347GMT", + "duration" : 562, + "executorId" : "0", + "host" : "172.30.65.138", + "status" : "FAILED", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "errorMessage" : "java.lang.RuntimeException: Bad executor\n\tat $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:26)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n", + "taskMetrics" : { + "executorDeserializeTime" : 0, + "executorDeserializeCpuTime" : 0, + "executorRunTime" : 460, + "executorCpuTime" : 0, + "resultSize" : 0, + "jvmGcTime" : 14, + "resultSerializationTime" : 
0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 0, + "writeTime" : 3873006, + "recordsWritten" : 0 } - } + }, + "executorLogs" : { + "stdout" : "http://172.30.65.138:64279/logPage/?appId=app-20180109111548-0000&executorId=0&logType=stdout", + "stderr" : "http://172.30.65.138:64279/logPage/?appId=app-20180109111548-0000&executorId=0&logType=stderr" + }, + "schedulerDelay" : 102, + "gettingResultTime" : 0 }, - "5": { - "taskId": 5, - "index": 3, - "attempt": 0, - "launchTime": "2018-01-09T10:21:18.958GMT", - "duration": 22, - "executorId": "1", - "host": "172.30.65.138", - "status": "SUCCESS", - "taskLocality": "PROCESS_LOCAL", - "speculative": false, - "accumulatorUpdates": [], - "taskMetrics": { - "executorDeserializeTime": 3, - "executorDeserializeCpuTime": 2586000, - "executorRunTime": 9, - "executorCpuTime": 9635000, - "resultSize": 1029, - "jvmGcTime": 0, - "resultSerializationTime": 0, - "memoryBytesSpilled": 0, - "diskBytesSpilled": 0, - "peakExecutionMemory": 0, - "inputMetrics": { - "bytesRead": 0, - "recordsRead": 0 - }, - "outputMetrics": { - "bytesWritten": 0, - "recordsWritten": 0 - }, - "shuffleReadMetrics": { - "remoteBlocksFetched": 0, - "localBlocksFetched": 0, - "fetchWaitTime": 0, - "remoteBytesRead": 0, - "remoteBytesReadToDisk": 0, - "localBytesRead": 0, - "recordsRead": 0 - }, - "shuffleWriteMetrics": { - "bytesWritten": 46, - "writeTime": 262919, - "recordsWritten": 1 + "5" : { + "taskId" : 5, + "index" : 3, + "attempt" : 0, + "launchTime" : "2018-01-09T10:21:18.958GMT", + "duration" : 22, + "executorId" : "1", + "host" : "172.30.65.138", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 2586000, + "executorRunTime" : 9, + "executorCpuTime" : 9635000, + "resultSize" : 1029, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 46, + "writeTime" : 262919, + "recordsWritten" : 1 } - } + }, + "executorLogs" : { + "stdout" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stdout", + "stderr" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stderr" + }, + "schedulerDelay" : 10, + "gettingResultTime" : 0 }, - "10": { - "taskId": 10, - "index": 8, - "attempt": 0, - "launchTime": "2018-01-09T10:21:19.034GMT", - "duration": 12, - "executorId": "1", - "host": "172.30.65.138", - "status": "SUCCESS", - "taskLocality": "PROCESS_LOCAL", - "speculative": false, - "accumulatorUpdates": [], - "taskMetrics": { - "executorDeserializeTime": 2, - 
"executorDeserializeCpuTime": 1803000, - "executorRunTime": 6, - "executorCpuTime": 6157000, - "resultSize": 1029, - "jvmGcTime": 0, - "resultSerializationTime": 0, - "memoryBytesSpilled": 0, - "diskBytesSpilled": 0, - "peakExecutionMemory": 0, - "inputMetrics": { - "bytesRead": 0, - "recordsRead": 0 - }, - "outputMetrics": { - "bytesWritten": 0, - "recordsWritten": 0 - }, - "shuffleReadMetrics": { - "remoteBlocksFetched": 0, - "localBlocksFetched": 0, - "fetchWaitTime": 0, - "remoteBytesRead": 0, - "remoteBytesReadToDisk": 0, - "localBytesRead": 0, - "recordsRead": 0 - }, - "shuffleWriteMetrics": { - "bytesWritten": 46, - "writeTime": 243647, - "recordsWritten": 1 + "10" : { + "taskId" : 10, + "index" : 8, + "attempt" : 0, + "launchTime" : "2018-01-09T10:21:19.034GMT", + "duration" : 12, + "executorId" : "1", + "host" : "172.30.65.138", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 1803000, + "executorRunTime" : 6, + "executorCpuTime" : 6157000, + "resultSize" : 1029, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 46, + "writeTime" : 243647, + "recordsWritten" : 1 } - } + }, + "executorLogs" : { + "stdout" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stdout", + "stderr" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stderr" + }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 }, - "1": { - "taskId": 1, - "index": 1, - "attempt": 0, - "launchTime": "2018-01-09T10:21:18.364GMT", - "duration": 565, - "executorId": "1", - "host": "172.30.65.138", - "status": "SUCCESS", - "taskLocality": "PROCESS_LOCAL", - "speculative": false, - "accumulatorUpdates": [], - "taskMetrics": { - "executorDeserializeTime": 301, - "executorDeserializeCpuTime": 200029000, - "executorRunTime": 212, - "executorCpuTime": 198479000, - "resultSize": 1115, - "jvmGcTime": 13, - "resultSerializationTime": 1, - "memoryBytesSpilled": 0, - "diskBytesSpilled": 0, - "peakExecutionMemory": 0, - "inputMetrics": { - "bytesRead": 0, - "recordsRead": 0 - }, - "outputMetrics": { - "bytesWritten": 0, - "recordsWritten": 0 - }, - "shuffleReadMetrics": { - "remoteBlocksFetched": 0, - "localBlocksFetched": 0, - "fetchWaitTime": 0, - "remoteBytesRead": 0, - "remoteBytesReadToDisk": 0, - "localBytesRead": 0, - "recordsRead": 0 - }, - "shuffleWriteMetrics": { - "bytesWritten": 46, - "writeTime": 2409488, - "recordsWritten": 1 + "1" : { + "taskId" : 1, + "index" : 1, + "attempt" : 0, + "launchTime" : "2018-01-09T10:21:18.364GMT", + "duration" : 565, + "executorId" : "1", + "host" : "172.30.65.138", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 301, + "executorDeserializeCpuTime" : 200029000, + "executorRunTime" : 212, + "executorCpuTime" : 198479000, + "resultSize" : 1115, + "jvmGcTime" : 13, + 
"resultSerializationTime" : 1, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 46, + "writeTime" : 2409488, + "recordsWritten" : 1 } - } + }, + "executorLogs" : { + "stdout" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stdout", + "stderr" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stderr" + }, + "schedulerDelay" : 51, + "gettingResultTime" : 0 }, - "6": { - "taskId": 6, - "index": 4, - "attempt": 0, - "launchTime": "2018-01-09T10:21:18.980GMT", - "duration": 16, - "executorId": "1", - "host": "172.30.65.138", - "status": "SUCCESS", - "taskLocality": "PROCESS_LOCAL", - "speculative": false, - "accumulatorUpdates": [], - "taskMetrics": { - "executorDeserializeTime": 3, - "executorDeserializeCpuTime": 2610000, - "executorRunTime": 10, - "executorCpuTime": 9622000, - "resultSize": 1029, - "jvmGcTime": 0, - "resultSerializationTime": 0, - "memoryBytesSpilled": 0, - "diskBytesSpilled": 0, - "peakExecutionMemory": 0, - "inputMetrics": { - "bytesRead": 0, - "recordsRead": 0 - }, - "outputMetrics": { - "bytesWritten": 0, - "recordsWritten": 0 - }, - "shuffleReadMetrics": { - "remoteBlocksFetched": 0, - "localBlocksFetched": 0, - "fetchWaitTime": 0, - "remoteBytesRead": 0, - "remoteBytesReadToDisk": 0, - "localBytesRead": 0, - "recordsRead": 0 - }, - "shuffleWriteMetrics": { - "bytesWritten": 46, - "writeTime": 385110, - "recordsWritten": 1 + "6" : { + "taskId" : 6, + "index" : 4, + "attempt" : 0, + "launchTime" : "2018-01-09T10:21:18.980GMT", + "duration" : 16, + "executorId" : "1", + "host" : "172.30.65.138", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 2610000, + "executorRunTime" : 10, + "executorCpuTime" : 9622000, + "resultSize" : 1029, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 46, + "writeTime" : 385110, + "recordsWritten" : 1 } - } + }, + "executorLogs" : { + "stdout" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stdout", + "stderr" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stderr" + }, + "schedulerDelay" : 3, + "gettingResultTime" : 0 }, - "9": { - "taskId": 9, - "index": 7, - "attempt": 0, - "launchTime": "2018-01-09T10:21:19.022GMT", - "duration": 12, - "executorId": "1", - "host": "172.30.65.138", - "status": "SUCCESS", - "taskLocality": "PROCESS_LOCAL", - "speculative": false, - "accumulatorUpdates": [], - "taskMetrics": { - "executorDeserializeTime": 
2, - "executorDeserializeCpuTime": 1981000, - "executorRunTime": 7, - "executorCpuTime": 6335000, - "resultSize": 1029, - "jvmGcTime": 0, - "resultSerializationTime": 0, - "memoryBytesSpilled": 0, - "diskBytesSpilled": 0, - "peakExecutionMemory": 0, - "inputMetrics": { - "bytesRead": 0, - "recordsRead": 0 - }, - "outputMetrics": { - "bytesWritten": 0, - "recordsWritten": 0 - }, - "shuffleReadMetrics": { - "remoteBlocksFetched": 0, - "localBlocksFetched": 0, - "fetchWaitTime": 0, - "remoteBytesRead": 0, - "remoteBytesReadToDisk": 0, - "localBytesRead": 0, - "recordsRead": 0 - }, - "shuffleWriteMetrics": { - "bytesWritten": 46, - "writeTime": 259354, - "recordsWritten": 1 + "9" : { + "taskId" : 9, + "index" : 7, + "attempt" : 0, + "launchTime" : "2018-01-09T10:21:19.022GMT", + "duration" : 12, + "executorId" : "1", + "host" : "172.30.65.138", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 1981000, + "executorRunTime" : 7, + "executorCpuTime" : 6335000, + "resultSize" : 1029, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 46, + "writeTime" : 259354, + "recordsWritten" : 1 } - } + }, + "executorLogs" : { + "stdout" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stdout", + "stderr" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stderr" + }, + "schedulerDelay" : 3, + "gettingResultTime" : 0 }, - "2": { - "taskId": 2, - "index": 2, - "attempt": 0, - "launchTime": "2018-01-09T10:21:18.899GMT", - "duration": 27, - "executorId": "0", - "host": "172.30.65.138", - "status": "FAILED", - "taskLocality": "PROCESS_LOCAL", - "speculative": false, - "accumulatorUpdates": [], - "errorMessage": "java.lang.RuntimeException: Bad executor\n\tat $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:26)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n", - "taskMetrics": { - "executorDeserializeTime": 0, - "executorDeserializeCpuTime": 0, - "executorRunTime": 16, - "executorCpuTime": 0, - "resultSize": 0, - "jvmGcTime": 0, - "resultSerializationTime": 0, - "memoryBytesSpilled": 0, - "diskBytesSpilled": 0, - 
"peakExecutionMemory": 0, - "inputMetrics": { - "bytesRead": 0, - "recordsRead": 0 - }, - "outputMetrics": { - "bytesWritten": 0, - "recordsWritten": 0 - }, - "shuffleReadMetrics": { - "remoteBlocksFetched": 0, - "localBlocksFetched": 0, - "fetchWaitTime": 0, - "remoteBytesRead": 0, - "remoteBytesReadToDisk": 0, - "localBytesRead": 0, - "recordsRead": 0 - }, - "shuffleWriteMetrics": { - "bytesWritten": 0, - "writeTime": 126128, - "recordsWritten": 0 + "2" : { + "taskId" : 2, + "index" : 2, + "attempt" : 0, + "launchTime" : "2018-01-09T10:21:18.899GMT", + "duration" : 27, + "executorId" : "0", + "host" : "172.30.65.138", + "status" : "FAILED", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "errorMessage" : "java.lang.RuntimeException: Bad executor\n\tat $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:26)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n", + "taskMetrics" : { + "executorDeserializeTime" : 0, + "executorDeserializeCpuTime" : 0, + "executorRunTime" : 16, + "executorCpuTime" : 0, + "resultSize" : 0, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 0, + "writeTime" : 126128, + "recordsWritten" : 0 } - } + }, + "executorLogs" : { + "stdout" : "http://172.30.65.138:64279/logPage/?appId=app-20180109111548-0000&executorId=0&logType=stdout", + "stderr" : "http://172.30.65.138:64279/logPage/?appId=app-20180109111548-0000&executorId=0&logType=stderr" + }, + "schedulerDelay" : 11, + "gettingResultTime" : 0 }, - "7": { - "taskId": 7, - "index": 5, - "attempt": 0, - "launchTime": "2018-01-09T10:21:18.996GMT", - "duration": 15, - "executorId": "1", - "host": "172.30.65.138", - "status": "SUCCESS", - "taskLocality": "PROCESS_LOCAL", - "speculative": false, - "accumulatorUpdates": [], - "taskMetrics": { - "executorDeserializeTime": 2, - "executorDeserializeCpuTime": 2231000, - "executorRunTime": 9, - "executorCpuTime": 8407000, - "resultSize": 1029, - "jvmGcTime": 0, - "resultSerializationTime": 0, - "memoryBytesSpilled": 0, - "diskBytesSpilled": 0, - "peakExecutionMemory": 0, - "inputMetrics": { - "bytesRead": 0, - "recordsRead": 0 - }, - "outputMetrics": { - "bytesWritten": 0, - "recordsWritten": 0 - }, - "shuffleReadMetrics": { - "remoteBlocksFetched": 0, - "localBlocksFetched": 0, - 
"fetchWaitTime": 0, - "remoteBytesRead": 0, - "remoteBytesReadToDisk": 0, - "localBytesRead": 0, - "recordsRead": 0 - }, - "shuffleWriteMetrics": { - "bytesWritten": 46, - "writeTime": 205520, - "recordsWritten": 1 + "7" : { + "taskId" : 7, + "index" : 5, + "attempt" : 0, + "launchTime" : "2018-01-09T10:21:18.996GMT", + "duration" : 15, + "executorId" : "1", + "host" : "172.30.65.138", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 2231000, + "executorRunTime" : 9, + "executorCpuTime" : 8407000, + "resultSize" : 1029, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 46, + "writeTime" : 205520, + "recordsWritten" : 1 } - } + }, + "executorLogs" : { + "stdout" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stdout", + "stderr" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stderr" + }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 }, - "3": { - "taskId": 3, - "index": 0, - "attempt": 1, - "launchTime": "2018-01-09T10:21:18.919GMT", - "duration": 24, - "executorId": "1", - "host": "172.30.65.138", - "status": "SUCCESS", - "taskLocality": "PROCESS_LOCAL", - "speculative": false, - "accumulatorUpdates": [], - "taskMetrics": { - "executorDeserializeTime": 8, - "executorDeserializeCpuTime": 8878000, - "executorRunTime": 10, - "executorCpuTime": 9364000, - "resultSize": 1029, - "jvmGcTime": 0, - "resultSerializationTime": 0, - "memoryBytesSpilled": 0, - "diskBytesSpilled": 0, - "peakExecutionMemory": 0, - "inputMetrics": { - "bytesRead": 0, - "recordsRead": 0 - }, - "outputMetrics": { - "bytesWritten": 0, - "recordsWritten": 0 - }, - "shuffleReadMetrics": { - "remoteBlocksFetched": 0, - "localBlocksFetched": 0, - "fetchWaitTime": 0, - "remoteBytesRead": 0, - "remoteBytesReadToDisk": 0, - "localBytesRead": 0, - "recordsRead": 0 - }, - "shuffleWriteMetrics": { - "bytesWritten": 46, - "writeTime": 207014, - "recordsWritten": 1 + "3" : { + "taskId" : 3, + "index" : 0, + "attempt" : 1, + "launchTime" : "2018-01-09T10:21:18.919GMT", + "duration" : 24, + "executorId" : "1", + "host" : "172.30.65.138", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 8, + "executorDeserializeCpuTime" : 8878000, + "executorRunTime" : 10, + "executorCpuTime" : 9364000, + "resultSize" : 1029, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + 
"bytesWritten" : 46, + "writeTime" : 207014, + "recordsWritten" : 1 } - } + }, + "executorLogs" : { + "stdout" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stdout", + "stderr" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stderr" + }, + "schedulerDelay" : 6, + "gettingResultTime" : 0 }, - "11": { - "taskId": 11, - "index": 9, - "attempt": 0, - "launchTime": "2018-01-09T10:21:19.045GMT", - "duration": 15, - "executorId": "1", - "host": "172.30.65.138", - "status": "SUCCESS", - "taskLocality": "PROCESS_LOCAL", - "speculative": false, - "accumulatorUpdates": [], - "taskMetrics": { - "executorDeserializeTime": 3, - "executorDeserializeCpuTime": 2017000, - "executorRunTime": 6, - "executorCpuTime": 6676000, - "resultSize": 1029, - "jvmGcTime": 0, - "resultSerializationTime": 0, - "memoryBytesSpilled": 0, - "diskBytesSpilled": 0, - "peakExecutionMemory": 0, - "inputMetrics": { - "bytesRead": 0, - "recordsRead": 0 - }, - "outputMetrics": { - "bytesWritten": 0, - "recordsWritten": 0 - }, - "shuffleReadMetrics": { - "remoteBlocksFetched": 0, - "localBlocksFetched": 0, - "fetchWaitTime": 0, - "remoteBytesRead": 0, - "remoteBytesReadToDisk": 0, - "localBytesRead": 0, - "recordsRead": 0 - }, - "shuffleWriteMetrics": { - "bytesWritten": 46, - "writeTime": 233652, - "recordsWritten": 1 + "11" : { + "taskId" : 11, + "index" : 9, + "attempt" : 0, + "launchTime" : "2018-01-09T10:21:19.045GMT", + "duration" : 15, + "executorId" : "1", + "host" : "172.30.65.138", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 3, + "executorDeserializeCpuTime" : 2017000, + "executorRunTime" : 6, + "executorCpuTime" : 6676000, + "resultSize" : 1029, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 46, + "writeTime" : 233652, + "recordsWritten" : 1 } - } + }, + "executorLogs" : { + "stdout" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stdout", + "stderr" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stderr" + }, + "schedulerDelay" : 6, + "gettingResultTime" : 0 }, - "8": { - "taskId": 8, - "index": 6, - "attempt": 0, - "launchTime": "2018-01-09T10:21:19.011GMT", - "duration": 11, - "executorId": "1", - "host": "172.30.65.138", - "status": "SUCCESS", - "taskLocality": "PROCESS_LOCAL", - "speculative": false, - "accumulatorUpdates": [], - "taskMetrics": { - "executorDeserializeTime": 1, - "executorDeserializeCpuTime": 1554000, - "executorRunTime": 7, - "executorCpuTime": 6034000, - "resultSize": 1029, - "jvmGcTime": 0, - "resultSerializationTime": 0, - "memoryBytesSpilled": 0, - "diskBytesSpilled": 0, - "peakExecutionMemory": 0, - "inputMetrics": { - "bytesRead": 0, - "recordsRead": 0 - }, - "outputMetrics": { - "bytesWritten": 0, - "recordsWritten": 0 - }, - "shuffleReadMetrics": { - "remoteBlocksFetched": 0, - "localBlocksFetched": 0, - "fetchWaitTime": 0, - 
"remoteBytesRead": 0, - "remoteBytesReadToDisk": 0, - "localBytesRead": 0, - "recordsRead": 0 - }, - "shuffleWriteMetrics": { - "bytesWritten": 46, - "writeTime": 213296, - "recordsWritten": 1 + "8" : { + "taskId" : 8, + "index" : 6, + "attempt" : 0, + "launchTime" : "2018-01-09T10:21:19.011GMT", + "duration" : 11, + "executorId" : "1", + "host" : "172.30.65.138", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 1, + "executorDeserializeCpuTime" : 1554000, + "executorRunTime" : 7, + "executorCpuTime" : 6034000, + "resultSize" : 1029, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 46, + "writeTime" : 213296, + "recordsWritten" : 1 } - } + }, + "executorLogs" : { + "stdout" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stdout", + "stderr" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stderr" + }, + "schedulerDelay" : 3, + "gettingResultTime" : 0 }, - "4": { - "taskId": 4, - "index": 2, - "attempt": 1, - "launchTime": "2018-01-09T10:21:18.943GMT", - "duration": 16, - "executorId": "1", - "host": "172.30.65.138", - "status": "SUCCESS", - "taskLocality": "PROCESS_LOCAL", - "speculative": false, - "accumulatorUpdates": [], - "taskMetrics": { - "executorDeserializeTime": 2, - "executorDeserializeCpuTime": 2211000, - "executorRunTime": 9, - "executorCpuTime": 9207000, - "resultSize": 1029, - "jvmGcTime": 0, - "resultSerializationTime": 0, - "memoryBytesSpilled": 0, - "diskBytesSpilled": 0, - "peakExecutionMemory": 0, - "inputMetrics": { - "bytesRead": 0, - "recordsRead": 0 - }, - "outputMetrics": { - "bytesWritten": 0, - "recordsWritten": 0 - }, - "shuffleReadMetrics": { - "remoteBlocksFetched": 0, - "localBlocksFetched": 0, - "fetchWaitTime": 0, - "remoteBytesRead": 0, - "remoteBytesReadToDisk": 0, - "localBytesRead": 0, - "recordsRead": 0 - }, - "shuffleWriteMetrics": { - "bytesWritten": 46, - "writeTime": 292381, - "recordsWritten": 1 + "4" : { + "taskId" : 4, + "index" : 2, + "attempt" : 1, + "launchTime" : "2018-01-09T10:21:18.943GMT", + "duration" : 16, + "executorId" : "1", + "host" : "172.30.65.138", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 2, + "executorDeserializeCpuTime" : 2211000, + "executorRunTime" : 9, + "executorCpuTime" : 9207000, + "resultSize" : 1029, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 46, + 
"writeTime" : 292381, + "recordsWritten" : 1 } - } + }, + "executorLogs" : { + "stdout" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stdout", + "stderr" : "http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stderr" + }, + "schedulerDelay" : 5, + "gettingResultTime" : 0 } }, - "executorSummary": { - "0": { - "taskTime": 589, - "failedTasks": 2, - "succeededTasks": 0, - "killedTasks": 0, - "inputBytes": 0, - "inputRecords": 0, - "outputBytes": 0, - "outputRecords": 0, - "shuffleRead": 0, - "shuffleReadRecords": 0, - "shuffleWrite": 0, - "shuffleWriteRecords": 0, - "memoryBytesSpilled": 0, - "diskBytesSpilled": 0, - "isBlacklistedForStage": true + "executorSummary" : { + "0" : { + "taskTime" : 589, + "failedTasks" : 2, + "succeededTasks" : 0, + "killedTasks" : 0, + "inputBytes" : 0, + "inputRecords" : 0, + "outputBytes" : 0, + "outputRecords" : 0, + "shuffleRead" : 0, + "shuffleReadRecords" : 0, + "shuffleWrite" : 0, + "shuffleWriteRecords" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "isBlacklistedForStage" : true }, - "1": { - "taskTime": 708, - "failedTasks": 0, - "succeededTasks": 10, - "killedTasks": 0, - "inputBytes": 0, - "inputRecords": 0, - "outputBytes": 0, - "outputRecords": 0, - "shuffleRead": 0, - "shuffleReadRecords": 0, - "shuffleWrite": 460, - "shuffleWriteRecords": 10, - "memoryBytesSpilled": 0, - "diskBytesSpilled": 0, - "isBlacklistedForStage": false + "1" : { + "taskTime" : 708, + "failedTasks" : 0, + "succeededTasks" : 10, + "killedTasks" : 0, + "inputBytes" : 0, + "inputRecords" : 0, + "outputBytes" : 0, + "outputRecords" : 0, + "shuffleRead" : 0, + "shuffleReadRecords" : 0, + "shuffleWrite" : 460, + "shuffleWriteRecords" : 10, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "isBlacklistedForStage" : false } }, - "killedTasksSummary": {} + "killedTasksSummary" : { } } diff --git a/core/src/test/resources/HistoryServerExpectations/blacklisting_node_for_stage_expectation.json b/core/src/test/resources/HistoryServerExpectations/blacklisting_node_for_stage_expectation.json index acd4cc53de6cd..6e46c881b2a21 100644 --- a/core/src/test/resources/HistoryServerExpectations/blacklisting_node_for_stage_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/blacklisting_node_for_stage_expectation.json @@ -74,7 +74,13 @@ "writeTime" : 3662221, "recordsWritten" : 3 } - } + }, + "executorLogs" : { + "stdout" : "http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000002/attilapiros/stdout?start=-4096", + "stderr" : "http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000002/attilapiros/stderr?start=-4096" + }, + "schedulerDelay" : 68, + "gettingResultTime" : 0 }, "5" : { "taskId" : 5, @@ -122,7 +128,13 @@ "writeTime" : 191901, "recordsWritten" : 0 } - } + }, + "executorLogs" : { + "stdout" : "http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000007/attilapiros/stdout?start=-4096", + "stderr" : "http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000007/attilapiros/stderr?start=-4096" + }, + "schedulerDelay" : 46, + "gettingResultTime" : 0 }, "10" : { "taskId" : 10, @@ -169,7 +181,13 @@ "writeTime" : 301705, "recordsWritten" : 3 } - } + }, + "executorLogs" : { + "stdout" : "http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000003/attilapiros/stdout?start=-4096", + "stderr" : 
"http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000003/attilapiros/stderr?start=-4096" + }, + "schedulerDelay" : 50, + "gettingResultTime" : 0 }, "1" : { "taskId" : 1, @@ -217,7 +235,13 @@ "writeTime" : 3075188, "recordsWritten" : 0 } - } + }, + "executorLogs" : { + "stdout" : "http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000007/attilapiros/stdout?start=-4096", + "stderr" : "http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000007/attilapiros/stderr?start=-4096" + }, + "schedulerDelay" : 174, + "gettingResultTime" : 0 }, "6" : { "taskId" : 6, @@ -265,7 +289,13 @@ "writeTime" : 183718, "recordsWritten" : 0 } - } + }, + "executorLogs" : { + "stdout" : "http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000005/attilapiros/stdout?start=-4096", + "stderr" : "http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000005/attilapiros/stderr?start=-4096" + }, + "schedulerDelay" : 16, + "gettingResultTime" : 0 }, "9" : { "taskId" : 9, @@ -312,7 +342,13 @@ "writeTime" : 366050, "recordsWritten" : 3 } - } + }, + "executorLogs" : { + "stdout" : "http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000003/attilapiros/stdout?start=-4096", + "stderr" : "http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000003/attilapiros/stderr?start=-4096" + }, + "schedulerDelay" : 42, + "gettingResultTime" : 0 }, "13" : { "taskId" : 13, @@ -359,7 +395,13 @@ "writeTime" : 369513, "recordsWritten" : 3 } - } + }, + "executorLogs" : { + "stdout" : "http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000003/attilapiros/stdout?start=-4096", + "stderr" : "http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000003/attilapiros/stderr?start=-4096" + }, + "schedulerDelay" : 26, + "gettingResultTime" : 0 }, "2" : { "taskId" : 2, @@ -406,7 +448,13 @@ "writeTime" : 3322956, "recordsWritten" : 3 } - } + }, + "executorLogs" : { + "stdout" : "http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000004/attilapiros/stdout?start=-4096", + "stderr" : "http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000004/attilapiros/stderr?start=-4096" + }, + "schedulerDelay" : 74, + "gettingResultTime" : 0 }, "12" : { "taskId" : 12, @@ -453,7 +501,13 @@ "writeTime" : 319101, "recordsWritten" : 3 } - } + }, + "executorLogs" : { + "stdout" : "http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000002/attilapiros/stdout?start=-4096", + "stderr" : "http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000002/attilapiros/stderr?start=-4096" + }, + "schedulerDelay" : 22, + "gettingResultTime" : 0 }, "7" : { "taskId" : 7, @@ -500,7 +554,13 @@ "writeTime" : 377601, "recordsWritten" : 3 } - } + }, + "executorLogs" : { + "stdout" : "http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000003/attilapiros/stdout?start=-4096", + "stderr" : "http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000003/attilapiros/stderr?start=-4096" + }, + "schedulerDelay" : 20, + "gettingResultTime" : 0 }, "3" : { "taskId" : 3, @@ -547,7 +607,13 @@ "writeTime" : 3587839, "recordsWritten" : 3 } - } + }, + "executorLogs" : { + "stdout" : 
"http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000003/attilapiros/stdout?start=-4096", + "stderr" : "http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000003/attilapiros/stderr?start=-4096" + }, + "schedulerDelay" : 63, + "gettingResultTime" : 0 }, "11" : { "taskId" : 11, @@ -594,7 +660,13 @@ "writeTime" : 323898, "recordsWritten" : 3 } - } + }, + "executorLogs" : { + "stdout" : "http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000002/attilapiros/stdout?start=-4096", + "stderr" : "http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000002/attilapiros/stderr?start=-4096" + }, + "schedulerDelay" : 12, + "gettingResultTime" : 0 }, "8" : { "taskId" : 8, @@ -641,7 +713,13 @@ "writeTime" : 311940, "recordsWritten" : 3 } - } + }, + "executorLogs" : { + "stdout" : "http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000002/attilapiros/stdout?start=-4096", + "stderr" : "http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000002/attilapiros/stderr?start=-4096" + }, + "schedulerDelay" : 84, + "gettingResultTime" : 0 }, "4" : { "taskId" : 4, @@ -689,7 +767,13 @@ "writeTime" : 16858066, "recordsWritten" : 0 } - } + }, + "executorLogs" : { + "stdout" : "http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000005/attilapiros/stdout?start=-4096", + "stderr" : "http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000005/attilapiros/stderr?start=-4096" + }, + "schedulerDelay" : 338, + "gettingResultTime" : 0 } }, "executorSummary" : { diff --git a/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json index 03f886afa5413..aa9471301fe3e 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json @@ -74,7 +74,10 @@ "writeTime" : 76000, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 19, + "gettingResultTime" : 0 }, "14" : { "taskId" : 14, @@ -121,7 +124,10 @@ "writeTime" : 88000, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 13, + "gettingResultTime" : 0 }, "9" : { "taskId" : 9, @@ -168,7 +174,10 @@ "writeTime" : 98000, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 17, + "gettingResultTime" : 0 }, "13" : { "taskId" : 13, @@ -215,7 +224,10 @@ "writeTime" : 73000, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 14, + "gettingResultTime" : 0 }, "12" : { "taskId" : 12, @@ -262,7 +274,10 @@ "writeTime" : 101000, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 17, + "gettingResultTime" : 0 }, "11" : { "taskId" : 11, @@ -309,7 +324,10 @@ "writeTime" : 83000, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 17, + "gettingResultTime" : 0 }, "8" : { "taskId" : 8, @@ -356,7 +374,10 @@ "writeTime" : 94000, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 16, + "gettingResultTime" : 0 }, "15" : { "taskId" : 15, @@ -403,7 +424,10 @@ "writeTime" : 79000, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 13, + "gettingResultTime" : 0 } }, "executorSummary" : { diff 
--git a/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json index 947c89906955d..584803b5e8631 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json @@ -74,7 +74,10 @@ "writeTime" : 76000, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 19, + "gettingResultTime" : 0 }, "14" : { "taskId" : 14, @@ -121,7 +124,10 @@ "writeTime" : 88000, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 13, + "gettingResultTime" : 0 }, "9" : { "taskId" : 9, @@ -168,7 +174,10 @@ "writeTime" : 98000, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 17, + "gettingResultTime" : 0 }, "13" : { "taskId" : 13, @@ -215,7 +224,10 @@ "writeTime" : 73000, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 14, + "gettingResultTime" : 0 }, "12" : { "taskId" : 12, @@ -262,7 +274,10 @@ "writeTime" : 101000, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 17, + "gettingResultTime" : 0 }, "11" : { "taskId" : 11, @@ -309,7 +324,10 @@ "writeTime" : 83000, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 17, + "gettingResultTime" : 0 }, "8" : { "taskId" : 8, @@ -356,7 +374,10 @@ "writeTime" : 94000, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 16, + "gettingResultTime" : 0 }, "15" : { "taskId" : 15, @@ -403,7 +424,10 @@ "writeTime" : 79000, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 13, + "gettingResultTime" : 0 } }, "executorSummary" : { diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_expectation.json index a15ee23523365..f859ab6fff240 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_expectation.json @@ -43,7 +43,10 @@ "writeTime" : 3842811, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 53, + "gettingResultTime" : 0 }, { "taskId" : 1, "index" : 1, @@ -89,7 +92,10 @@ "writeTime" : 3934399, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 40, + "gettingResultTime" : 0 }, { "taskId" : 2, "index" : 2, @@ -135,7 +141,10 @@ "writeTime" : 89885, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 37, + "gettingResultTime" : 0 }, { "taskId" : 3, "index" : 3, @@ -181,7 +190,10 @@ "writeTime" : 1311694, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 41, + "gettingResultTime" : 0 }, { "taskId" : 4, "index" : 4, @@ -227,7 +239,10 @@ "writeTime" : 83022, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 38, + "gettingResultTime" : 0 }, { "taskId" : 5, "index" : 5, @@ -273,7 +288,10 @@ "writeTime" : 3675510, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 33, + "gettingResultTime" : 0 }, { "taskId" : 6, "index" : 6, @@ -319,7 +337,10 @@ "writeTime" : 4016617, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 38, + "gettingResultTime" : 0 }, { "taskId" : 7, "index" : 7, @@ -365,7 +386,10 @@ "writeTime" : 2579051, "recordsWritten" : 10 } - } + }, + 
"executorLogs" : { }, + "schedulerDelay" : 43, + "gettingResultTime" : 0 }, { "taskId" : 8, "index" : 8, @@ -411,7 +435,10 @@ "writeTime" : 121551, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 5, + "gettingResultTime" : 0 }, { "taskId" : 9, "index" : 9, @@ -457,7 +484,10 @@ "writeTime" : 101664, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 8, + "gettingResultTime" : 0 }, { "taskId" : 10, "index" : 10, @@ -503,7 +533,10 @@ "writeTime" : 94709, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 18, + "gettingResultTime" : 0 }, { "taskId" : 11, "index" : 11, @@ -549,7 +582,10 @@ "writeTime" : 94507, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 8, + "gettingResultTime" : 0 }, { "taskId" : 12, "index" : 12, @@ -595,7 +631,10 @@ "writeTime" : 102476, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 7, + "gettingResultTime" : 0 }, { "taskId" : 13, "index" : 13, @@ -641,7 +680,10 @@ "writeTime" : 95004, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 53, + "gettingResultTime" : 0 }, { "taskId" : 14, "index" : 14, @@ -687,7 +729,10 @@ "writeTime" : 95646, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 5, + "gettingResultTime" : 0 }, { "taskId" : 15, "index" : 15, @@ -733,7 +778,10 @@ "writeTime" : 602780, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 }, { "taskId" : 16, "index" : 16, @@ -779,7 +827,10 @@ "writeTime" : 108320, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 }, { "taskId" : 17, "index" : 17, @@ -825,7 +876,10 @@ "writeTime" : 99944, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 20, + "gettingResultTime" : 0 }, { "taskId" : 18, "index" : 18, @@ -871,7 +925,10 @@ "writeTime" : 100836, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 3, + "gettingResultTime" : 0 }, { "taskId" : 19, "index" : 19, @@ -917,5 +974,8 @@ "writeTime" : 95788, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 5, + "gettingResultTime" : 0 } ] diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_1__expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_1__expectation.json index f9182b1658334..ea88ca116707a 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_1__expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_1__expectation.json @@ -48,7 +48,10 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 30, + "gettingResultTime" : 0 }, { "taskId" : 1, "index" : 1, @@ -99,7 +102,10 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 22, + "gettingResultTime" : 0 }, { "taskId" : 2, "index" : 2, @@ -150,7 +156,10 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 18, + "gettingResultTime" : 0 }, { "taskId" : 3, "index" : 3, @@ -201,7 +210,10 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 20, + "gettingResultTime" : 0 }, { "taskId" : 4, "index" : 4, @@ -252,7 +264,10 @@ "writeTime" : 0, 
"recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 24, + "gettingResultTime" : 0 }, { "taskId" : 5, "index" : 5, @@ -303,7 +318,10 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 23, + "gettingResultTime" : 0 }, { "taskId" : 6, "index" : 6, @@ -354,7 +372,10 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 22, + "gettingResultTime" : 0 }, { "taskId" : 7, "index" : 7, @@ -405,5 +426,8 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 22, + "gettingResultTime" : 0 } ] diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_2__expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_2__expectation.json index 76dd2f710b90f..efd0a45bf01d0 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_2__expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_2__expectation.json @@ -48,7 +48,10 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 30, + "gettingResultTime" : 0 }, { "taskId" : 1, "index" : 1, @@ -99,7 +102,10 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 22, + "gettingResultTime" : 0 }, { "taskId" : 2, "index" : 2, @@ -150,7 +156,10 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 18, + "gettingResultTime" : 0 }, { "taskId" : 3, "index" : 3, @@ -201,7 +210,10 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 20, + "gettingResultTime" : 0 }, { "taskId" : 4, "index" : 4, @@ -252,7 +264,10 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 24, + "gettingResultTime" : 0 }, { "taskId" : 5, "index" : 5, @@ -303,7 +318,10 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 23, + "gettingResultTime" : 0 }, { "taskId" : 6, "index" : 6, @@ -354,7 +372,10 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 22, + "gettingResultTime" : 0 }, { "taskId" : 7, "index" : 7, @@ -405,5 +426,8 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 22, + "gettingResultTime" : 0 } ] diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__offset___length_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__offset___length_expectation.json index 6bdc10465d89e..d83528d84972c 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__offset___length_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__offset___length_expectation.json @@ -43,7 +43,10 @@ "writeTime" : 94709, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 18, + "gettingResultTime" : 0 }, { "taskId" : 11, "index" : 11, @@ -89,7 +92,10 @@ "writeTime" : 94507, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 8, + "gettingResultTime" : 0 }, { "taskId" : 12, "index" : 12, @@ -135,7 +141,10 @@ "writeTime" : 102476, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 7, + "gettingResultTime" : 0 }, { "taskId" 
: 13, "index" : 13, @@ -181,7 +190,10 @@ "writeTime" : 95004, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 53, + "gettingResultTime" : 0 }, { "taskId" : 14, "index" : 14, @@ -227,7 +239,10 @@ "writeTime" : 95646, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 5, + "gettingResultTime" : 0 }, { "taskId" : 15, "index" : 15, @@ -273,7 +288,10 @@ "writeTime" : 602780, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 }, { "taskId" : 16, "index" : 16, @@ -319,7 +337,10 @@ "writeTime" : 108320, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 }, { "taskId" : 17, "index" : 17, @@ -365,7 +386,10 @@ "writeTime" : 99944, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 20, + "gettingResultTime" : 0 }, { "taskId" : 18, "index" : 18, @@ -411,7 +435,10 @@ "writeTime" : 100836, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 3, + "gettingResultTime" : 0 }, { "taskId" : 19, "index" : 19, @@ -457,7 +484,10 @@ "writeTime" : 95788, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 5, + "gettingResultTime" : 0 }, { "taskId" : 20, "index" : 20, @@ -503,7 +533,10 @@ "writeTime" : 97716, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 }, { "taskId" : 21, "index" : 21, @@ -549,7 +582,10 @@ "writeTime" : 100270, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 3, + "gettingResultTime" : 0 }, { "taskId" : 22, "index" : 22, @@ -595,7 +631,10 @@ "writeTime" : 143427, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 3, + "gettingResultTime" : 0 }, { "taskId" : 23, "index" : 23, @@ -641,7 +680,10 @@ "writeTime" : 91844, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 16, + "gettingResultTime" : 0 }, { "taskId" : 24, "index" : 24, @@ -687,7 +729,10 @@ "writeTime" : 157194, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 5, + "gettingResultTime" : 0 }, { "taskId" : 25, "index" : 25, @@ -733,7 +778,10 @@ "writeTime" : 94134, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 9, + "gettingResultTime" : 0 }, { "taskId" : 26, "index" : 26, @@ -779,7 +827,10 @@ "writeTime" : 108213, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 8, + "gettingResultTime" : 0 }, { "taskId" : 27, "index" : 27, @@ -825,7 +876,10 @@ "writeTime" : 102019, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 6, + "gettingResultTime" : 0 }, { "taskId" : 28, "index" : 28, @@ -871,7 +925,10 @@ "writeTime" : 104299, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 17, + "gettingResultTime" : 0 }, { "taskId" : 29, "index" : 29, @@ -917,7 +974,10 @@ "writeTime" : 114938, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 11, + "gettingResultTime" : 0 }, { "taskId" : 30, "index" : 30, @@ -963,7 +1023,10 @@ "writeTime" : 119770, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 24, + "gettingResultTime" : 0 }, { "taskId" : 31, "index" : 31, @@ -1009,7 +1072,10 @@ "writeTime" : 92619, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 14, + "gettingResultTime" : 0 }, { "taskId" : 32, "index" : 32, @@ -1055,7 +1121,10 @@ "writeTime" : 
89603, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 13, + "gettingResultTime" : 0 }, { "taskId" : 33, "index" : 33, @@ -1101,7 +1170,10 @@ "writeTime" : 118329, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 10, + "gettingResultTime" : 0 }, { "taskId" : 34, "index" : 34, @@ -1147,7 +1219,10 @@ "writeTime" : 127746, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 9, + "gettingResultTime" : 0 }, { "taskId" : 35, "index" : 35, @@ -1193,7 +1268,10 @@ "writeTime" : 160963, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 10, + "gettingResultTime" : 0 }, { "taskId" : 36, "index" : 36, @@ -1239,7 +1317,10 @@ "writeTime" : 123855, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 6, + "gettingResultTime" : 0 }, { "taskId" : 37, "index" : 37, @@ -1285,7 +1366,10 @@ "writeTime" : 111869, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 6, + "gettingResultTime" : 0 }, { "taskId" : 38, "index" : 38, @@ -1331,7 +1415,10 @@ "writeTime" : 131158, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 13, + "gettingResultTime" : 0 }, { "taskId" : 39, "index" : 39, @@ -1377,7 +1464,10 @@ "writeTime" : 98748, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 13, + "gettingResultTime" : 0 }, { "taskId" : 40, "index" : 40, @@ -1423,7 +1513,10 @@ "writeTime" : 94792, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 6, + "gettingResultTime" : 0 }, { "taskId" : 41, "index" : 41, @@ -1469,7 +1562,10 @@ "writeTime" : 90765, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 6, + "gettingResultTime" : 0 }, { "taskId" : 42, "index" : 42, @@ -1515,7 +1611,10 @@ "writeTime" : 103713, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 15, + "gettingResultTime" : 0 }, { "taskId" : 43, "index" : 43, @@ -1561,7 +1660,10 @@ "writeTime" : 171516, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 21, + "gettingResultTime" : 0 }, { "taskId" : 44, "index" : 44, @@ -1607,7 +1709,10 @@ "writeTime" : 98293, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 16, + "gettingResultTime" : 0 }, { "taskId" : 45, "index" : 45, @@ -1653,7 +1758,10 @@ "writeTime" : 92985, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 16, + "gettingResultTime" : 0 }, { "taskId" : 46, "index" : 46, @@ -1699,7 +1807,10 @@ "writeTime" : 113322, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 11, + "gettingResultTime" : 0 }, { "taskId" : 47, "index" : 47, @@ -1745,7 +1856,10 @@ "writeTime" : 103015, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 13, + "gettingResultTime" : 0 }, { "taskId" : 48, "index" : 48, @@ -1791,7 +1905,10 @@ "writeTime" : 139844, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 3, + "gettingResultTime" : 0 }, { "taskId" : 49, "index" : 49, @@ -1837,7 +1954,10 @@ "writeTime" : 94984, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 }, { "taskId" : 50, "index" : 50, @@ -1883,7 +2003,10 @@ "writeTime" : 90836, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 }, { "taskId" : 51, "index" : 51, @@ -1929,7 +2052,10 @@ "writeTime" : 96013, 
"recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 2, + "gettingResultTime" : 0 }, { "taskId" : 52, "index" : 52, @@ -1975,7 +2101,10 @@ "writeTime" : 89664, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 5, + "gettingResultTime" : 0 }, { "taskId" : 53, "index" : 53, @@ -2021,7 +2150,10 @@ "writeTime" : 92835, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 5, + "gettingResultTime" : 0 }, { "taskId" : 54, "index" : 54, @@ -2067,7 +2199,10 @@ "writeTime" : 90506, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 38, + "gettingResultTime" : 0 }, { "taskId" : 55, "index" : 55, @@ -2113,7 +2248,10 @@ "writeTime" : 108309, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 5, + "gettingResultTime" : 0 }, { "taskId" : 56, "index" : 56, @@ -2159,7 +2297,10 @@ "writeTime" : 90329, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 6, + "gettingResultTime" : 0 }, { "taskId" : 57, "index" : 57, @@ -2205,7 +2346,10 @@ "writeTime" : 96849, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 2, + "gettingResultTime" : 0 }, { "taskId" : 58, "index" : 58, @@ -2251,7 +2395,10 @@ "writeTime" : 97521, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 }, { "taskId" : 59, "index" : 59, @@ -2297,5 +2444,8 @@ "writeTime" : 100753, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 3, + "gettingResultTime" : 0 } ] diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_expectation.json index bc1cd49909d31..82e339c8f56dd 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_expectation.json @@ -43,7 +43,10 @@ "writeTime" : 4016617, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 38, + "gettingResultTime" : 0 }, { "taskId" : 5, "index" : 5, @@ -89,7 +92,10 @@ "writeTime" : 3675510, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 33, + "gettingResultTime" : 0 }, { "taskId" : 1, "index" : 1, @@ -135,7 +141,10 @@ "writeTime" : 3934399, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 40, + "gettingResultTime" : 0 }, { "taskId" : 7, "index" : 7, @@ -181,7 +190,10 @@ "writeTime" : 2579051, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 43, + "gettingResultTime" : 0 }, { "taskId" : 4, "index" : 4, @@ -227,7 +239,10 @@ "writeTime" : 83022, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 38, + "gettingResultTime" : 0 }, { "taskId" : 3, "index" : 3, @@ -273,7 +288,10 @@ "writeTime" : 1311694, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 41, + "gettingResultTime" : 0 }, { "taskId" : 0, "index" : 0, @@ -319,7 +337,10 @@ "writeTime" : 3842811, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 53, + "gettingResultTime" : 0 }, { "taskId" : 2, "index" : 2, @@ -365,7 +386,10 @@ "writeTime" : 89885, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 37, + "gettingResultTime" : 0 }, { "taskId" : 22, "index" : 22, @@ -411,7 +435,10 @@ "writeTime" : 143427, 
"recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 3, + "gettingResultTime" : 0 }, { "taskId" : 18, "index" : 18, @@ -457,7 +484,10 @@ "writeTime" : 100836, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 3, + "gettingResultTime" : 0 }, { "taskId" : 17, "index" : 17, @@ -503,7 +533,10 @@ "writeTime" : 99944, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 20, + "gettingResultTime" : 0 }, { "taskId" : 21, "index" : 21, @@ -549,7 +582,10 @@ "writeTime" : 100270, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 3, + "gettingResultTime" : 0 }, { "taskId" : 19, "index" : 19, @@ -595,7 +631,10 @@ "writeTime" : 95788, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 5, + "gettingResultTime" : 0 }, { "taskId" : 16, "index" : 16, @@ -641,7 +680,10 @@ "writeTime" : 108320, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 }, { "taskId" : 9, "index" : 9, @@ -687,7 +729,10 @@ "writeTime" : 101664, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 8, + "gettingResultTime" : 0 }, { "taskId" : 20, "index" : 20, @@ -733,7 +778,10 @@ "writeTime" : 97716, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 }, { "taskId" : 14, "index" : 14, @@ -779,7 +827,10 @@ "writeTime" : 95646, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 5, + "gettingResultTime" : 0 }, { "taskId" : 8, "index" : 8, @@ -825,7 +876,10 @@ "writeTime" : 121551, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 5, + "gettingResultTime" : 0 }, { "taskId" : 12, "index" : 12, @@ -871,7 +925,10 @@ "writeTime" : 102476, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 7, + "gettingResultTime" : 0 }, { "taskId" : 15, "index" : 15, @@ -917,5 +974,8 @@ "writeTime" : 602780, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 } ] diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names___runtime_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names___runtime_expectation.json index bc1cd49909d31..82e339c8f56dd 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names___runtime_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names___runtime_expectation.json @@ -43,7 +43,10 @@ "writeTime" : 4016617, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 38, + "gettingResultTime" : 0 }, { "taskId" : 5, "index" : 5, @@ -89,7 +92,10 @@ "writeTime" : 3675510, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 33, + "gettingResultTime" : 0 }, { "taskId" : 1, "index" : 1, @@ -135,7 +141,10 @@ "writeTime" : 3934399, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 40, + "gettingResultTime" : 0 }, { "taskId" : 7, "index" : 7, @@ -181,7 +190,10 @@ "writeTime" : 2579051, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 43, + "gettingResultTime" : 0 }, { "taskId" : 4, "index" : 4, @@ -227,7 +239,10 @@ "writeTime" : 83022, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 38, + "gettingResultTime" : 0 }, { "taskId" : 
3, "index" : 3, @@ -273,7 +288,10 @@ "writeTime" : 1311694, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 41, + "gettingResultTime" : 0 }, { "taskId" : 0, "index" : 0, @@ -319,7 +337,10 @@ "writeTime" : 3842811, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 53, + "gettingResultTime" : 0 }, { "taskId" : 2, "index" : 2, @@ -365,7 +386,10 @@ "writeTime" : 89885, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 37, + "gettingResultTime" : 0 }, { "taskId" : 22, "index" : 22, @@ -411,7 +435,10 @@ "writeTime" : 143427, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 3, + "gettingResultTime" : 0 }, { "taskId" : 18, "index" : 18, @@ -457,7 +484,10 @@ "writeTime" : 100836, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 3, + "gettingResultTime" : 0 }, { "taskId" : 17, "index" : 17, @@ -503,7 +533,10 @@ "writeTime" : 99944, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 20, + "gettingResultTime" : 0 }, { "taskId" : 21, "index" : 21, @@ -549,7 +582,10 @@ "writeTime" : 100270, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 3, + "gettingResultTime" : 0 }, { "taskId" : 19, "index" : 19, @@ -595,7 +631,10 @@ "writeTime" : 95788, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 5, + "gettingResultTime" : 0 }, { "taskId" : 16, "index" : 16, @@ -641,7 +680,10 @@ "writeTime" : 108320, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 }, { "taskId" : 9, "index" : 9, @@ -687,7 +729,10 @@ "writeTime" : 101664, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 8, + "gettingResultTime" : 0 }, { "taskId" : 20, "index" : 20, @@ -733,7 +778,10 @@ "writeTime" : 97716, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 }, { "taskId" : 14, "index" : 14, @@ -779,7 +827,10 @@ "writeTime" : 95646, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 5, + "gettingResultTime" : 0 }, { "taskId" : 8, "index" : 8, @@ -825,7 +876,10 @@ "writeTime" : 121551, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 5, + "gettingResultTime" : 0 }, { "taskId" : 12, "index" : 12, @@ -871,7 +925,10 @@ "writeTime" : 102476, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 7, + "gettingResultTime" : 0 }, { "taskId" : 15, "index" : 15, @@ -917,5 +974,8 @@ "writeTime" : 602780, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 } ] diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names__runtime_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names__runtime_expectation.json index 09857cb401acd..01eef1b565bf6 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names__runtime_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names__runtime_expectation.json @@ -43,7 +43,10 @@ "writeTime" : 94792, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 6, + "gettingResultTime" : 0 }, { "taskId" : 41, "index" : 41, @@ -89,7 +92,10 @@ "writeTime" : 90765, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + 
"schedulerDelay" : 6, + "gettingResultTime" : 0 }, { "taskId" : 43, "index" : 43, @@ -135,7 +141,10 @@ "writeTime" : 171516, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 21, + "gettingResultTime" : 0 }, { "taskId" : 57, "index" : 57, @@ -181,7 +190,10 @@ "writeTime" : 96849, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 2, + "gettingResultTime" : 0 }, { "taskId" : 58, "index" : 58, @@ -227,7 +239,10 @@ "writeTime" : 97521, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 }, { "taskId" : 68, "index" : 68, @@ -273,7 +288,10 @@ "writeTime" : 101750, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 }, { "taskId" : 86, "index" : 86, @@ -319,7 +337,10 @@ "writeTime" : 95848, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 8, + "gettingResultTime" : 0 }, { "taskId" : 32, "index" : 32, @@ -365,7 +386,10 @@ "writeTime" : 89603, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 13, + "gettingResultTime" : 0 }, { "taskId" : 39, "index" : 39, @@ -411,7 +435,10 @@ "writeTime" : 98748, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 13, + "gettingResultTime" : 0 }, { "taskId" : 42, "index" : 42, @@ -457,7 +484,10 @@ "writeTime" : 103713, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 15, + "gettingResultTime" : 0 }, { "taskId" : 51, "index" : 51, @@ -503,7 +533,10 @@ "writeTime" : 96013, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 2, + "gettingResultTime" : 0 }, { "taskId" : 59, "index" : 59, @@ -549,7 +582,10 @@ "writeTime" : 100753, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 3, + "gettingResultTime" : 0 }, { "taskId" : 63, "index" : 63, @@ -595,7 +631,10 @@ "writeTime" : 102779, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 3, + "gettingResultTime" : 0 }, { "taskId" : 87, "index" : 87, @@ -641,7 +680,10 @@ "writeTime" : 102159, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 7, + "gettingResultTime" : 0 }, { "taskId" : 90, "index" : 90, @@ -687,7 +729,10 @@ "writeTime" : 98472, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 }, { "taskId" : 99, "index" : 99, @@ -733,7 +778,10 @@ "writeTime" : 133964, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 3, + "gettingResultTime" : 0 }, { "taskId" : 44, "index" : 44, @@ -779,7 +827,10 @@ "writeTime" : 98293, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 16, + "gettingResultTime" : 0 }, { "taskId" : 47, "index" : 47, @@ -825,7 +876,10 @@ "writeTime" : 103015, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 13, + "gettingResultTime" : 0 }, { "taskId" : 50, "index" : 50, @@ -871,7 +925,10 @@ "writeTime" : 90836, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 4, + "gettingResultTime" : 0 }, { "taskId" : 52, "index" : 52, @@ -917,5 +974,8 @@ "writeTime" : 89664, "recordsWritten" : 10 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 5, + "gettingResultTime" : 0 } ] diff --git a/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json 
b/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json index 963f010968b62..a8e1fd303a42a 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json @@ -83,14 +83,17 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 30, + "gettingResultTime" : 0 }, - "1" : { - "taskId" : 1, - "index" : 1, + "5" : { + "taskId" : 5, + "index" : 5, "attempt" : 0, - "launchTime" : "2015-03-16T19:25:36.521GMT", - "duration" : 53, + "launchTime" : "2015-03-16T19:25:36.523GMT", + "duration" : 52, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -99,11 +102,11 @@ "accumulatorUpdates" : [ { "id" : 1, "name" : "my counter", - "update" : "247", - "value" : "2175" + "update" : "897", + "value" : "3750" } ], "taskMetrics" : { - "executorDeserializeTime" : 14, + "executorDeserializeTime" : 12, "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, "executorCpuTime" : 0, @@ -135,14 +138,17 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 23, + "gettingResultTime" : 0 }, - "2" : { - "taskId" : 2, - "index" : 2, + "1" : { + "taskId" : 1, + "index" : 1, "attempt" : 0, - "launchTime" : "2015-03-16T19:25:36.522GMT", - "duration" : 48, + "launchTime" : "2015-03-16T19:25:36.521GMT", + "duration" : 53, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -151,11 +157,11 @@ "accumulatorUpdates" : [ { "id" : 1, "name" : "my counter", - "update" : "378", - "value" : "378" + "update" : "247", + "value" : "2175" } ], "taskMetrics" : { - "executorDeserializeTime" : 13, + "executorDeserializeTime" : 14, "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, "executorCpuTime" : 0, @@ -187,14 +193,17 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 22, + "gettingResultTime" : 0 }, - "3" : { - "taskId" : 3, - "index" : 3, + "6" : { + "taskId" : 6, + "index" : 6, "attempt" : 0, - "launchTime" : "2015-03-16T19:25:36.522GMT", - "duration" : 50, + "launchTime" : "2015-03-16T19:25:36.523GMT", + "duration" : 51, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -203,11 +212,11 @@ "accumulatorUpdates" : [ { "id" : 1, "name" : "my counter", - "update" : "572", - "value" : "950" + "update" : "978", + "value" : "1928" } ], "taskMetrics" : { - "executorDeserializeTime" : 13, + "executorDeserializeTime" : 12, "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, "executorCpuTime" : 0, @@ -239,14 +248,17 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 22, + "gettingResultTime" : 0 }, - "4" : { - "taskId" : 4, - "index" : 4, + "2" : { + "taskId" : 2, + "index" : 2, "attempt" : 0, "launchTime" : "2015-03-16T19:25:36.522GMT", - "duration" : 52, + "duration" : 48, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -255,17 +267,17 @@ "accumulatorUpdates" : [ { "id" : 1, "name" : "my counter", - "update" : "678", - "value" : "2853" + "update" : "378", + "value" : "378" } ], "taskMetrics" : { - "executorDeserializeTime" : 12, + "executorDeserializeTime" : 13, "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, "executorCpuTime" : 0, "resultSize" : 697, "jvmGcTime" : 0, - "resultSerializationTime" : 1, + "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, "peakExecutionMemory" : 
0, @@ -291,14 +303,17 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 18, + "gettingResultTime" : 0 }, - "5" : { - "taskId" : 5, - "index" : 5, + "7" : { + "taskId" : 7, + "index" : 7, "attempt" : 0, - "launchTime" : "2015-03-16T19:25:36.523GMT", - "duration" : 52, + "launchTime" : "2015-03-16T19:25:36.524GMT", + "duration" : 51, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -307,8 +322,8 @@ "accumulatorUpdates" : [ { "id" : 1, "name" : "my counter", - "update" : "897", - "value" : "3750" + "update" : "1222", + "value" : "4972" } ], "taskMetrics" : { "executorDeserializeTime" : 12, @@ -343,14 +358,17 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 22, + "gettingResultTime" : 0 }, - "6" : { - "taskId" : 6, - "index" : 6, + "3" : { + "taskId" : 3, + "index" : 3, "attempt" : 0, - "launchTime" : "2015-03-16T19:25:36.523GMT", - "duration" : 51, + "launchTime" : "2015-03-16T19:25:36.522GMT", + "duration" : 50, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -359,11 +377,11 @@ "accumulatorUpdates" : [ { "id" : 1, "name" : "my counter", - "update" : "978", - "value" : "1928" + "update" : "572", + "value" : "950" } ], "taskMetrics" : { - "executorDeserializeTime" : 12, + "executorDeserializeTime" : 13, "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, "executorCpuTime" : 0, @@ -395,14 +413,17 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 20, + "gettingResultTime" : 0 }, - "7" : { - "taskId" : 7, - "index" : 7, + "4" : { + "taskId" : 4, + "index" : 4, "attempt" : 0, - "launchTime" : "2015-03-16T19:25:36.524GMT", - "duration" : 51, + "launchTime" : "2015-03-16T19:25:36.522GMT", + "duration" : 52, "executorId" : "", "host" : "localhost", "status" : "SUCCESS", @@ -411,8 +432,8 @@ "accumulatorUpdates" : [ { "id" : 1, "name" : "my counter", - "update" : "1222", - "value" : "4972" + "update" : "678", + "value" : "2853" } ], "taskMetrics" : { "executorDeserializeTime" : 12, @@ -421,7 +442,7 @@ "executorCpuTime" : 0, "resultSize" : 697, "jvmGcTime" : 0, - "resultSerializationTime" : 2, + "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, "peakExecutionMemory" : 0, @@ -447,7 +468,10 @@ "writeTime" : 0, "recordsWritten" : 0 } - } + }, + "executorLogs" : { }, + "schedulerDelay" : 24, + "gettingResultTime" : 0 } }, "executorSummary" : { diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala index d805c67714ff8..f2d97d452ddb0 100644 --- a/core/src/test/scala/org/apache/spark/FailureSuite.scala +++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala @@ -257,7 +257,9 @@ class FailureSuite extends SparkFunSuite with LocalSparkContext { sc = new SparkContext("local[1,2]", "test") intercept[SparkException] { sc.parallelize(1 to 2).foreach { i => + // scalastyle:off throwerror throw new LinkageError() + // scalastyle:on throwerror } } } diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index 34efcdf4bc886..df04a5ea1d99e 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -202,7 +202,7 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { sc = new SparkContext("local", "test") val objs = sc.makeRDD(1 to 3).map { x => val loader = 
Thread.currentThread().getContextClassLoader - Class.forName(className, true, loader).newInstance() + Class.forName(className, true, loader).getConstructor().newInstance() } val outputDir = new File(tempDir, "output").getAbsolutePath objs.saveAsObjectFile(outputDir) diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index b917469e48747..35f728cd57fe2 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -362,15 +362,19 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC mapTrackerMaster.registerShuffle(0, 1) // first attempt -- its successful - val writer1 = manager.getWriter[Int, Int](shuffleHandle, 0, - new TaskContextImpl(0, 0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem)) + val context1 = + new TaskContextImpl(0, 0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem) + val writer1 = manager.getWriter[Int, Int]( + shuffleHandle, 0, context1, context1.taskMetrics.shuffleWriteMetrics) val data1 = (1 to 10).map { x => x -> x} // second attempt -- also successful. We'll write out different data, // just to simulate the fact that the records may get written differently // depending on what gets spilled, what gets combined, etc. - val writer2 = manager.getWriter[Int, Int](shuffleHandle, 0, - new TaskContextImpl(0, 0, 0, 1L, 0, taskMemoryManager, new Properties, metricsSystem)) + val context2 = + new TaskContextImpl(0, 0, 0, 1L, 0, taskMemoryManager, new Properties, metricsSystem) + val writer2 = manager.getWriter[Int, Int]( + shuffleHandle, 0, context2, context2.taskMetrics.shuffleWriteMetrics) val data2 = (11 to 20).map { x => x -> x} // interleave writes of both attempts -- we want to test that both attempts can occur @@ -397,8 +401,10 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC mapTrackerMaster.registerMapOutput(0, 0, mapStatus) } - val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1, - new TaskContextImpl(1, 0, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem)) + val taskContext = new TaskContextImpl( + 1, 0, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem) + val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics() + val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1, taskContext, metrics) val readData = reader.read().toIndexedSeq assert(readData === data1.toIndexedSeq || readData === data2.toIndexedSeq) diff --git a/core/src/test/scala/org/apache/spark/benchmark/BenchmarkBase.scala b/core/src/test/scala/org/apache/spark/benchmark/BenchmarkBase.scala index 24e596e1ecdaf..a6666db4e95c3 100644 --- a/core/src/test/scala/org/apache/spark/benchmark/BenchmarkBase.scala +++ b/core/src/test/scala/org/apache/spark/benchmark/BenchmarkBase.scala @@ -58,5 +58,12 @@ abstract class BenchmarkBase { o.close() } } + + afterAll() } + + /** + * Any shutdown code to ensure a clean shutdown + */ + def afterAll(): Unit = {} } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 652c36ffa6e71..c093789244bfe 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -962,6 +962,25 @@ class SparkSubmitSuite } } + test("remove copies of application jar from classpath") { + val fs = File.separator + val sparkConf = new 
SparkConf(false) + val hadoopConf = new Configuration() + val secMgr = new SecurityManager(sparkConf) + + val appJarName = "myApp.jar" + val jar1Name = "myJar1.jar" + val jar2Name = "myJar2.jar" + val userJar = s"file:/path${fs}to${fs}app${fs}jar$fs$appJarName" + val jars = s"file:/$jar1Name,file:/$appJarName,file:/$jar2Name" + + val resolvedJars = DependencyUtils + .resolveAndDownloadJars(jars, userJar, sparkConf, hadoopConf, secMgr) + + assert(!resolvedJars.contains(appJarName)) + assert(resolvedJars.contains(jar1Name) && resolvedJars.contains(jar2Name)) + } + test("Avoid re-upload remote resources in yarn client mode") { val hadoopConf = new Configuration() updateConfWithFakeS3Fs(hadoopConf) diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 87d585278a747..b0ced46c9c7b8 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -23,10 +23,12 @@ import java.util.{Date, Locale} import java.util.concurrent.TimeUnit import java.util.zip.{ZipInputStream, ZipOutputStream} +import scala.collection.JavaConverters._ import scala.concurrent.duration._ import scala.language.postfixOps import com.google.common.io.{ByteStreams, Files} +import org.apache.commons.io.FileUtils import org.apache.hadoop.fs.{FileStatus, FileSystem, FSDataInputStream, Path} import org.apache.hadoop.hdfs.{DFSInputStream, DistributedFileSystem} import org.apache.hadoop.security.AccessControlException @@ -40,6 +42,7 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.DRIVER_LOG_DFS_DIR import org.apache.spark.internal.config.History._ import org.apache.spark.io._ import org.apache.spark.scheduler._ @@ -47,6 +50,7 @@ import org.apache.spark.security.GroupMappingServiceProvider import org.apache.spark.status.AppStatusStore import org.apache.spark.status.api.v1.{ApplicationAttemptInfo, ApplicationInfo} import org.apache.spark.util.{Clock, JsonProtocol, ManualClock, Utils} +import org.apache.spark.util.logging.DriverLogger class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matchers with Logging { @@ -413,6 +417,63 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc } } + test("driver log cleaner") { + val firstFileModifiedTime = TimeUnit.SECONDS.toMillis(10) + val secondFileModifiedTime = TimeUnit.SECONDS.toMillis(20) + val maxAge = TimeUnit.SECONDS.toSeconds(40) + val clock = new ManualClock(0) + val testConf = new SparkConf() + testConf.set("spark.history.fs.logDirectory", + Utils.createTempDir(namePrefix = "eventLog").getAbsolutePath()) + testConf.set(DRIVER_LOG_DFS_DIR, testDir.getAbsolutePath()) + testConf.set(DRIVER_LOG_CLEANER_ENABLED, true) + testConf.set(DRIVER_LOG_CLEANER_INTERVAL, maxAge / 4) + testConf.set(MAX_DRIVER_LOG_AGE_S, maxAge) + val provider = new FsHistoryProvider(testConf, clock) + + val log1 = FileUtils.getFile(testDir, "1" + DriverLogger.DRIVER_LOG_FILE_SUFFIX) + createEmptyFile(log1) + clock.setTime(firstFileModifiedTime) + log1.setLastModified(clock.getTimeMillis()) + provider.cleanDriverLogs() + + val log2 = FileUtils.getFile(testDir, "2" + DriverLogger.DRIVER_LOG_FILE_SUFFIX) + createEmptyFile(log2) + val log3 = FileUtils.getFile(testDir, "3" + 
DriverLogger.DRIVER_LOG_FILE_SUFFIX) + createEmptyFile(log3) + clock.setTime(secondFileModifiedTime) + log2.setLastModified(clock.getTimeMillis()) + log3.setLastModified(clock.getTimeMillis()) + // This should not trigger any cleanup + provider.cleanDriverLogs() + provider.listing.view(classOf[LogInfo]).iterator().asScala.toSeq.size should be(3) + + // Should trigger cleanup for first file but not second one + clock.setTime(firstFileModifiedTime + TimeUnit.SECONDS.toMillis(maxAge) + 1) + provider.cleanDriverLogs() + provider.listing.view(classOf[LogInfo]).iterator().asScala.toSeq.size should be(2) + assert(!log1.exists()) + assert(log2.exists()) + assert(log3.exists()) + + // Update the third file length while keeping the original modified time + Files.write("Add logs to file".getBytes(), log3) + log3.setLastModified(secondFileModifiedTime) + // Should cleanup the second file but not the third file, as filelength changed. + clock.setTime(secondFileModifiedTime + TimeUnit.SECONDS.toMillis(maxAge) + 1) + provider.cleanDriverLogs() + provider.listing.view(classOf[LogInfo]).iterator().asScala.toSeq.size should be(1) + assert(!log1.exists()) + assert(!log2.exists()) + assert(log3.exists()) + + // Should cleanup the third file as well. + clock.setTime(secondFileModifiedTime + 2 * TimeUnit.SECONDS.toMillis(maxAge) + 2) + provider.cleanDriverLogs() + provider.listing.view(classOf[LogInfo]).iterator().asScala.toSeq.size should be(0) + assert(!log3.exists()) + } + test("SPARK-8372: new logs with no app ID are ignored") { val provider = new FsHistoryProvider(createTestConf()) @@ -869,16 +930,19 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc when(dfsIn.getFileLength).thenReturn(200) // FileStatus.getLen is more than logInfo fileSize var fileStatus = new FileStatus(200, false, 0, 0, 0, path) - var logInfo = new LogInfo(path.toString, 0, Some("appId"), Some("attemptId"), 100) + var logInfo = new LogInfo(path.toString, 0, LogType.EventLogs, Some("appId"), + Some("attemptId"), 100) assert(mockedProvider.shouldReloadLog(logInfo, fileStatus)) fileStatus = new FileStatus() fileStatus.setPath(path) // DFSInputStream.getFileLength is more than logInfo fileSize - logInfo = new LogInfo(path.toString, 0, Some("appId"), Some("attemptId"), 100) + logInfo = new LogInfo(path.toString, 0, LogType.EventLogs, Some("appId"), + Some("attemptId"), 100) assert(mockedProvider.shouldReloadLog(logInfo, fileStatus)) // DFSInputStream.getFileLength is equal to logInfo fileSize - logInfo = new LogInfo(path.toString, 0, Some("appId"), Some("attemptId"), 200) + logInfo = new LogInfo(path.toString, 0, LogType.EventLogs, Some("appId"), + Some("attemptId"), 200) assert(!mockedProvider.shouldReloadLog(logInfo, fileStatus)) // in.getWrappedStream returns other than DFSInputStream val bin = mock(classOf[BufferedInputStream]) diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala index 4839c842cc785..89b8bb4ff7d03 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala @@ -396,6 +396,18 @@ class StandaloneRestSubmitSuite extends SparkFunSuite with BeforeAndAfterEach { assert(filteredVariables == Map("SPARK_VAR" -> "1")) } + test("client does not send 'SPARK_HOME' env var by default") { + val environmentVariables = Map("SPARK_VAR" -> "1", "SPARK_HOME" -> "1") + 
val filteredVariables = RestSubmissionClient.filterSystemEnvironment(environmentVariables) + assert(filteredVariables == Map("SPARK_VAR" -> "1")) + } + + test("client does not send 'SPARK_CONF_DIR' env var by default") { + val environmentVariables = Map("SPARK_VAR" -> "1", "SPARK_CONF_DIR" -> "1") + val filteredVariables = RestSubmissionClient.filterSystemEnvironment(environmentVariables) + assert(filteredVariables == Map("SPARK_VAR" -> "1")) + } + test("client includes mesos env vars") { val environmentVariables = Map("SPARK_VAR" -> "1", "MESOS_VAR" -> "1", "OTHER_VAR" -> "1") val filteredVariables = RestSubmissionClient.filterSystemEnvironment(environmentVariables) diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala index 1f8a65707b2f7..32a94e60484e3 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala @@ -467,7 +467,9 @@ class FetchFailureHidingRDD( } catch { case t: Throwable => if (throwOOM) { + // scalastyle:off throwerror throw new OutOfMemoryError("OOM while handling another exception") + // scalastyle:on throwerror } else if (interrupt) { // make sure our test is setup correctly assert(TaskContext.get().asInstanceOf[TaskContextImpl].fetchFailed.isDefined) diff --git a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala index ff117b1c21cb1..083c5e696b753 100644 --- a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala +++ b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala @@ -90,7 +90,7 @@ class ChunkedByteBufferSuite extends SparkFunSuite with SharedSparkContext { val empty = ByteBuffer.wrap(Array.empty[Byte]) val bytes1 = ByteBuffer.wrap(Array.tabulate(256)(_.toByte)) val bytes2 = ByteBuffer.wrap(Array.tabulate(128)(_.toByte)) - val chunkedByteBuffer = new ChunkedByteBuffer(Array(empty, bytes1, bytes2)) + val chunkedByteBuffer = new ChunkedByteBuffer(Array(empty, bytes1, empty, bytes2)) assert(chunkedByteBuffer.size === bytes1.limit() + bytes2.limit()) val inputStream = chunkedByteBuffer.toInputStream(dispose = false) diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index 47af5c3320dd9..0ec359d1c94f3 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -574,7 +574,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { } test("saveNewAPIHadoopFile should call setConf if format is configurable") { - val pairs = sc.parallelize(Array((new Integer(1), new Integer(1)))) + val pairs = sc.parallelize(Array((Integer.valueOf(1), Integer.valueOf(1)))) // No error, non-configurable formats still work pairs.saveAsNewAPIHadoopFile[NewFakeFormat]("ignored") @@ -591,14 +591,14 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { test("The JobId on the driver and executors should be the same during the commit") { // Create more than one rdd to mimic stageId not equal to rddId val pairs = sc.parallelize(Array((1, 2), (2, 3)), 2) - .map { p => (new Integer(p._1 + 1), new Integer(p._2 + 1)) } + .map { p => (Integer.valueOf(p._1 + 1), Integer.valueOf(p._2 + 1)) } .filter { p => p._1 > 0 } 
pairs.saveAsNewAPIHadoopFile[YetAnotherFakeFormat]("ignored") assert(JobID.jobid != -1) } test("saveAsHadoopFile should respect configured output committers") { - val pairs = sc.parallelize(Array((new Integer(1), new Integer(1)))) + val pairs = sc.parallelize(Array((Integer.valueOf(1), Integer.valueOf(1)))) val conf = new JobConf() conf.setOutputCommitter(classOf[FakeOutputCommitter]) @@ -610,7 +610,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { } test("failure callbacks should be called before calling writer.close() in saveNewAPIHadoopFile") { - val pairs = sc.parallelize(Array((new Integer(1), new Integer(2))), 1) + val pairs = sc.parallelize(Array((Integer.valueOf(1), Integer.valueOf(2))), 1) FakeWriterWithCallback.calledBy = "" FakeWriterWithCallback.exception = null @@ -625,7 +625,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { } test("failure callbacks should be called before calling writer.close() in saveAsHadoopFile") { - val pairs = sc.parallelize(Array((new Integer(1), new Integer(2))), 1) + val pairs = sc.parallelize(Array((Integer.valueOf(1), Integer.valueOf(2))), 1) val conf = new JobConf() FakeWriterWithCallback.calledBy = "" @@ -643,7 +643,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { test("saveAsNewAPIHadoopDataset should support invalid output paths when " + "there are no files to be committed to an absolute output location") { - val pairs = sc.parallelize(Array((new Integer(1), new Integer(2))), 1) + val pairs = sc.parallelize(Array((Integer.valueOf(1), Integer.valueOf(2))), 1) def saveRddWithPath(path: String): Unit = { val job = NewJob.getInstance(new Configuration(sc.hadoopConfiguration)) @@ -671,7 +671,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { // for non-null invalid paths. 
test("saveAsHadoopDataset should respect empty output directory when " + "there are no files to be committed to an absolute output location") { - val pairs = sc.parallelize(Array((new Integer(1), new Integer(2))), 1) + val pairs = sc.parallelize(Array((Integer.valueOf(1), Integer.valueOf(2))), 1) val conf = new JobConf() conf.setOutputKeyClass(classOf[Integer]) diff --git a/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala b/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala index 31ce9483cf20a..424d9f825c465 100644 --- a/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala @@ -215,7 +215,7 @@ class ParallelCollectionSplitSuite extends SparkFunSuite with Checkers { } test("exclusive ranges of doubles") { - val data = 1.0 until 100.0 by 1.0 + val data = Range.BigDecimal(1, 100, 1) val slices = ParallelCollectionRDD.slice(data, 3) assert(slices.size === 3) assert(slices.map(_.size).sum === 99) @@ -223,7 +223,7 @@ class ParallelCollectionSplitSuite extends SparkFunSuite with Checkers { } test("inclusive ranges of doubles") { - val data = 1.0 to 100.0 by 1.0 + val data = Range.BigDecimal.inclusive(1, 100, 1) val slices = ParallelCollectionRDD.slice(data, 3) assert(slices.size === 3) assert(slices.map(_.size).sum === 100) diff --git a/core/src/test/scala/org/apache/spark/scheduler/CustomShuffledRDD.scala b/core/src/test/scala/org/apache/spark/scheduler/CustomShuffledRDD.scala index 838686923767e..1be2e2a067115 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/CustomShuffledRDD.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/CustomShuffledRDD.scala @@ -104,8 +104,9 @@ class CustomShuffledRDD[K, V, C]( override def compute(p: Partition, context: TaskContext): Iterator[(K, C)] = { val part = p.asInstanceOf[CustomShuffledRDDPartition] + val metrics = context.taskMetrics().createTempShuffleReadMetrics() SparkEnv.get.shuffleManager.getReader( - dependency.shuffleHandle, part.startIndexInParent, part.endIndexInParent, context) + dependency.shuffleHandle, part.startIndexInParent, part.endIndexInParent, context, metrics) .read() .asInstanceOf[Iterator[(K, C)]] } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala index 1bddba8f6c82b..efb8b15cf6b4d 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala @@ -194,7 +194,7 @@ class TaskResultGetterSuite extends SparkFunSuite with BeforeAndAfter with Local // jar. 
sc = new SparkContext("local", "test", conf) val rdd = sc.parallelize(Seq(1), 1).map { _ => - val exc = excClass.newInstance().asInstanceOf[Exception] + val exc = excClass.getConstructor().newInstance().asInstanceOf[Exception] throw exc } @@ -265,7 +265,9 @@ class TaskResultGetterSuite extends SparkFunSuite with BeforeAndAfter with Local private class UndeserializableException extends Exception { private def readObject(in: ObjectInputStream): Unit = { + // scalastyle:off throwerror throw new NoClassDefFoundError() + // scalastyle:on throwerror } } diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerBenchmark.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerBenchmark.scala new file mode 100644 index 0000000000000..2a15c6f6a2d96 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerBenchmark.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.serializer + +import scala.concurrent._ +import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.duration._ + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.benchmark.{Benchmark, BenchmarkBase} +import org.apache.spark.serializer.KryoTest._ +import org.apache.spark.util.ThreadUtils + +/** + * Benchmark for KryoPool vs old "pool of 1". + * To run this benchmark: + * {{{ + * 1. without sbt: + * bin/spark-submit --class <this class> --jars <spark core test jar> + * 2. build/sbt "core/test:runMain <this class>" + * 3. generate result: + * SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "core/test:runMain <this class>" + * Results will be written to "benchmarks/KryoSerializerBenchmark-results.txt".
+ * }}} + */ +object KryoSerializerBenchmark extends BenchmarkBase { + + var sc: SparkContext = null + val N = 500 + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + val name = "Benchmark KryoPool vs old\"pool of 1\" implementation" + runBenchmark(name) { + val benchmark = new Benchmark(name, N, 10, output = output) + Seq(true, false).foreach(usePool => run(usePool, benchmark)) + benchmark.run() + } + } + + private def run(usePool: Boolean, benchmark: Benchmark): Unit = { + lazy val sc = createSparkContext(usePool) + + benchmark.addCase(s"KryoPool:$usePool") { _ => + val futures = for (_ <- 0 until N) yield { + Future { + sc.parallelize(0 until 10).map(i => i + 1).count() + } + } + + val future = Future.sequence(futures) + + ThreadUtils.awaitResult(future, 10.minutes) + } + } + + def createSparkContext(usePool: Boolean): SparkContext = { + val conf = new SparkConf() + conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + conf.set("spark.kryo.registrator", classOf[MyRegistrator].getName) + conf.set("spark.kryo.pool", usePool.toString) + + if (sc != null) { + sc.stop() + } + + sc = new SparkContext("local-cluster[4,1,1024]", "test", conf) + sc + } + + override def afterAll(): Unit = { + if (sc != null) { + sc.stop() + } + } + +} diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index ac25bcef54349..a7eed4b6a8b88 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -18,9 +18,13 @@ package org.apache.spark.serializer import java.io.{ByteArrayInputStream, ByteArrayOutputStream, FileInputStream, FileOutputStream} +import java.nio.ByteBuffer +import java.util.concurrent.Executors import scala.collection.JavaConverters._ import scala.collection.mutable +import scala.concurrent.{ExecutionContext, Future} +import scala.concurrent.duration._ import scala.reflect.ClassTag import com.esotericsoftware.kryo.{Kryo, KryoException} @@ -31,7 +35,7 @@ import org.apache.spark.{SharedSparkContext, SparkConf, SparkFunSuite} import org.apache.spark.scheduler.HighlyCompressedMapStatus import org.apache.spark.serializer.KryoTest._ import org.apache.spark.storage.BlockManagerId -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") @@ -199,7 +203,7 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { def check[T: ClassTag](t: T) { assert(ser.deserialize[T](ser.serialize(t)) === t) // Check that very long ranges don't get written one element at a time - assert(ser.serialize(t).limit() < 100) + assert(ser.serialize(t).limit() < 200) } check(1 to 1000000) check(1 to 1000000 by 2) @@ -209,10 +213,10 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { check(1L to 1000000L by 2L) check(1L until 1000000L) check(1L until 1000000L by 2L) - check(1.0 to 1000000.0 by 1.0) - check(1.0 to 1000000.0 by 2.0) - check(1.0 until 1000000.0 by 1.0) - check(1.0 until 1000000.0 by 2.0) + check(Range.BigDecimal.inclusive(1, 1000000, 1)) + check(Range.BigDecimal.inclusive(1, 1000000, 2)) + check(Range.BigDecimal(1, 1000000, 1)) + check(Range.BigDecimal(1, 1000000, 2)) } test("asJavaIterable") { @@ -308,7 +312,7 @@ class KryoSerializerSuite extends 
SparkFunSuite with SharedSparkContext { val conf = new SparkConf(false) conf.set("spark.kryo.registrator", "this.class.does.not.exist") - val thrown = intercept[SparkException](new KryoSerializer(conf).newInstance()) + val thrown = intercept[SparkException](new KryoSerializer(conf).newInstance().serialize(1)) assert(thrown.getMessage.contains("Failed to register classes with Kryo")) } @@ -431,9 +435,11 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { ser.deserialize[HashMap[Int, List[String]]](serializedMap) } - private def testSerializerInstanceReuse(autoReset: Boolean, referenceTracking: Boolean): Unit = { + private def testSerializerInstanceReuse( + autoReset: Boolean, referenceTracking: Boolean, usePool: Boolean): Unit = { val conf = new SparkConf(loadDefaults = false) .set("spark.kryo.referenceTracking", referenceTracking.toString) + .set("spark.kryo.pool", usePool.toString) if (!autoReset) { conf.set("spark.kryo.registrator", classOf[RegistratorWithoutAutoReset].getName) } @@ -456,9 +462,58 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { // Regression test for SPARK-7766, an issue where disabling auto-reset and enabling // reference-tracking would lead to corrupted output when serializer instances are re-used - for (referenceTracking <- Set(true, false); autoReset <- Set(true, false)) { - test(s"instance reuse with autoReset = $autoReset, referenceTracking = $referenceTracking") { - testSerializerInstanceReuse(autoReset = autoReset, referenceTracking = referenceTracking) + for { + referenceTracking <- Seq(true, false) + autoReset <- Seq(true, false) + usePool <- Seq(true, false) + } { + test(s"instance reuse with autoReset = $autoReset, referenceTracking = $referenceTracking" + + s", usePool = $usePool") { + testSerializerInstanceReuse( + autoReset, referenceTracking, usePool) + } + } + + test("SPARK-25839 KryoPool implementation works correctly in multi-threaded environment") { + implicit val executionContext: ExecutionContext = ExecutionContext.fromExecutor( + Executors.newFixedThreadPool(4)) + + val ser = new KryoSerializer(conf.clone.set("spark.kryo.pool", "true")) + + val tests = mutable.ListBuffer[Future[Boolean]]() + + def check[T: ClassTag](t: T) { + tests += Future { + val serializerInstance = ser.newInstance() + serializerInstance.deserialize[T](serializerInstance.serialize(t)) === t + } + } + + check((1, 3)) + check(Array((1, 3))) + check(List((1, 3))) + check(List[Int]()) + check(List[Int](1, 2, 3)) + check(List[String]()) + check(List[String]("x", "y", "z")) + check(None) + check(Some(1)) + check(Some("hi")) + check(1 -> 1) + check(mutable.ArrayBuffer(1, 2, 3)) + check(mutable.ArrayBuffer("1", "2", "3")) + check(mutable.Map()) + check(mutable.Map(1 -> "one", 2 -> "two")) + check(mutable.Map("one" -> 1, "two" -> 2)) + check(mutable.HashMap(1 -> "one", 2 -> "two")) + check(mutable.HashMap("one" -> 1, "two" -> 2)) + check(List(Some(mutable.HashMap(1 -> 1, 2 -> 2)), None, Some(mutable.HashMap(3 -> 4)))) + check(List( + mutable.HashMap("one" -> 1, "two" -> 2), + mutable.HashMap(1 -> "one", 2 -> "two", 3 -> "three"))) + + tests.foreach { f => + assert(ThreadUtils.awaitResult(f, 10.seconds)) } } } @@ -497,6 +552,17 @@ class KryoSerializerAutoResetDisabledSuite extends SparkFunSuite with SharedSpar deserializationStream.close() assert(serInstance.deserialize[Any](helloHello) === ((hello, hello))) } + + test("SPARK-25786: ByteBuffer.array -- UnsupportedOperationException") { + val serInstance = new 
KryoSerializer(conf).newInstance().asInstanceOf[KryoSerializerInstance] + val obj = "UnsupportedOperationException" + val serObj = serInstance.serialize(obj) + val byteBuffer = ByteBuffer.allocateDirect(serObj.array().length) + byteBuffer.put(serObj.array()) + byteBuffer.flip() + assert(serInstance.deserialize[Any](serObj) === (obj)) + assert(serInstance.deserialize[Any](byteBuffer) === (obj)) + } } class ClassLoaderTestingObject diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala index 2d8a83c6fabed..eb97d5a1e5074 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -126,11 +126,14 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext .set("spark.shuffle.compress", "false") .set("spark.shuffle.spill.compress", "false")) + val taskContext = TaskContext.empty() + val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics() val shuffleReader = new BlockStoreShuffleReader( shuffleHandle, reduceId, reduceId + 1, - TaskContext.empty(), + taskContext, + metrics, serializerManager, blockManager, mapOutputTracker) diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index 85ccb33471048..4467c3241a947 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -136,8 +136,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte blockResolver, shuffleHandle, 0, // MapId - taskContext, - conf + conf, + taskContext.taskMetrics().shuffleWriteMetrics ) writer.write(Iterator.empty) writer.stop( /* success = */ true) @@ -160,8 +160,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte blockResolver, shuffleHandle, 0, // MapId - taskContext, - conf + conf, + taskContext.taskMetrics().shuffleWriteMetrics ) writer.write(records) writer.stop( /* success = */ true) @@ -195,8 +195,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte blockResolver, shuffleHandle, 0, // MapId - taskContext, - conf + conf, + taskContext.taskMetrics().shuffleWriteMetrics ) intercept[SparkException] { @@ -217,8 +217,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte blockResolver, shuffleHandle, 0, // MapId - taskContext, - conf + conf, + taskContext.taskMetrics().shuffleWriteMetrics ) intercept[SparkException] { writer.write((0 until 100000).iterator.map(i => { diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index bfd73069fbff8..7860a0df4bb2d 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -18,8 +18,7 @@ package org.apache.spark.status import java.io.File -import java.lang.{Integer => JInteger, Long => JLong} -import java.util.{Arrays, Date, Properties} +import java.util.{Date, Properties} import scala.collection.JavaConverters._ import scala.collection.immutable.Map @@ -1171,12 +1170,12 @@ class 
AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { // Stop task 2 before task 1 time += 1 tasks(1).markFinished(TaskState.FINISHED, time) - listener.onTaskEnd( - SparkListenerTaskEnd(stage1.stageId, stage1.attemptId, "taskType", Success, tasks(1), null)) + listener.onTaskEnd(SparkListenerTaskEnd( + stage1.stageId, stage1.attemptNumber, "taskType", Success, tasks(1), null)) time += 1 tasks(0).markFinished(TaskState.FINISHED, time) - listener.onTaskEnd( - SparkListenerTaskEnd(stage1.stageId, stage1.attemptId, "taskType", Success, tasks(0), null)) + listener.onTaskEnd(SparkListenerTaskEnd( + stage1.stageId, stage1.attemptNumber, "taskType", Success, tasks(0), null)) // Start task 3 and task 2 should be evicted. listener.onTaskStart(SparkListenerTaskStart(stage1.stageId, stage1.attemptNumber, tasks(2))) @@ -1241,8 +1240,8 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { // Task 1 Finished time += 1 tasks(0).markFinished(TaskState.FINISHED, time) - listener.onTaskEnd( - SparkListenerTaskEnd(stage1.stageId, stage1.attemptId, "taskType", Success, tasks(0), null)) + listener.onTaskEnd(SparkListenerTaskEnd( + stage1.stageId, stage1.attemptNumber, "taskType", Success, tasks(0), null)) // Stage 1 Completed stage1.failureReason = Some("Failed") @@ -1256,7 +1255,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { time += 1 tasks(1).markFinished(TaskState.FINISHED, time) listener.onTaskEnd( - SparkListenerTaskEnd(stage1.stageId, stage1.attemptId, "taskType", + SparkListenerTaskEnd(stage1.stageId, stage1.attemptNumber, "taskType", TaskKilled(reason = "Killed"), tasks(1), null)) // Ensure killed task metrics are updated @@ -1274,6 +1273,51 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { assert(allJobs.head.numFailedStages == 1) } + test("SPARK-25451: total tasks in the executor summary should match total stage tasks") { + val testConf = conf.clone.set(LIVE_ENTITY_UPDATE_PERIOD, Long.MaxValue) + + val listener = new AppStatusListener(store, testConf, true) + + val stage = new StageInfo(1, 0, "stage", 4, Nil, Nil, "details") + listener.onJobStart(SparkListenerJobStart(1, time, Seq(stage), null)) + listener.onStageSubmitted(SparkListenerStageSubmitted(stage, new Properties())) + + val tasks = createTasks(4, Array("1", "2")) + tasks.foreach { task => + listener.onTaskStart(SparkListenerTaskStart(stage.stageId, stage.attemptNumber, task)) + } + + time += 1 + tasks(0).markFinished(TaskState.FINISHED, time) + listener.onTaskEnd(SparkListenerTaskEnd(stage.stageId, stage.attemptNumber, "taskType", + Success, tasks(0), null)) + time += 1 + tasks(1).markFinished(TaskState.FINISHED, time) + listener.onTaskEnd(SparkListenerTaskEnd(stage.stageId, stage.attemptNumber, "taskType", + Success, tasks(1), null)) + + stage.failureReason = Some("Failed") + listener.onStageCompleted(SparkListenerStageCompleted(stage)) + time += 1 + listener.onJobEnd(SparkListenerJobEnd(1, time, JobFailed(new RuntimeException("Bad Executor")))) + + time += 1 + tasks(2).markFinished(TaskState.FAILED, time) + listener.onTaskEnd(SparkListenerTaskEnd(stage.stageId, stage.attemptNumber, "taskType", + ExecutorLostFailure("1", true, Some("Lost executor")), tasks(2), null)) + time += 1 + tasks(3).markFinished(TaskState.FAILED, time) + listener.onTaskEnd(SparkListenerTaskEnd(stage.stageId, stage.attemptNumber, "taskType", + ExecutorLostFailure("2", true, Some("Lost executor")), tasks(3), null)) + + val esummary = 
store.view(classOf[ExecutorStageSummaryWrapper]).asScala.map(_.info) + esummary.foreach { execSummary => + assert(execSummary.failedTasks === 1) + assert(execSummary.succeededTasks === 1) + assert(execSummary.killedTasks === 0) + } + } + test("driver logs") { val listener = new AppStatusListener(store, conf, true) diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusUtilsSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusUtilsSuite.scala index 9e74e86ad54b9..a01b24d323d28 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusUtilsSuite.scala @@ -52,7 +52,10 @@ class AppStatusUtilsSuite extends SparkFunSuite { inputMetrics = null, outputMetrics = null, shuffleReadMetrics = null, - shuffleWriteMetrics = null))) + shuffleWriteMetrics = null)), + executorLogs = null, + schedulerDelay = 0L, + gettingResultTime = 0L) assert(AppStatusUtils.schedulerDelay(runningTask) === 0L) val finishedTask = new TaskData( @@ -83,7 +86,10 @@ class AppStatusUtilsSuite extends SparkFunSuite { inputMetrics = null, outputMetrics = null, shuffleReadMetrics = null, - shuffleWriteMetrics = null))) + shuffleWriteMetrics = null)), + executorLogs = null, + schedulerDelay = 0L, + gettingResultTime = 0L) assert(AppStatusUtils.schedulerDelay(finishedTask) === 3L) } } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 32d6e8b94e1a2..cf00c1c3aad39 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -574,7 +574,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE "list1", StorageLevel.MEMORY_ONLY, ClassTag.Any, - () => throw new AssertionError("attempted to compute locally")).isLeft) + () => fail("attempted to compute locally")).isLeft) } test("in-memory LRU storage") { diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index b268195e09a5b..01ee9ef0825f8 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -102,8 +102,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq) ).toIterator + val taskContext = TaskContext.empty() + val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics() val iterator = new ShuffleBlockFetcherIterator( - TaskContext.empty(), + taskContext, transfer, blockManager, blocksByAddress, @@ -112,7 +114,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, Int.MaxValue, - true) + true, + metrics) // 3 local blocks fetched in initialization verify(blockManager, times(3)).getBlockData(any()) @@ -190,7 +193,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, Int.MaxValue, - true) + true, + taskContext.taskMetrics.createTempShuffleReadMetrics()) verify(blocks(ShuffleBlockId(0, 0, 0)), times(0)).release() iterator.next()._2.close() // close() first block's input stream @@ -258,7 +262,8 @@ class ShuffleBlockFetcherIteratorSuite extends 
SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, Int.MaxValue, - true) + true, + taskContext.taskMetrics.createTempShuffleReadMetrics()) // Continue only after the mock calls onBlockFetchFailure sem.acquire() @@ -328,7 +333,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, Int.MaxValue, - true) + true, + taskContext.taskMetrics.createTempShuffleReadMetrics()) // Continue only after the mock calls onBlockFetchFailure sem.acquire() @@ -392,7 +398,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, Int.MaxValue, - true) + true, + taskContext.taskMetrics.createTempShuffleReadMetrics()) // Blocks should be returned without exceptions. assert(Set(iterator.next()._1, iterator.next()._1) === Set(ShuffleBlockId(0, 0, 0), ShuffleBlockId(0, 1, 0))) @@ -446,7 +453,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, Int.MaxValue, - false) + false, + taskContext.taskMetrics.createTempShuffleReadMetrics()) // Continue only after the mock calls onBlockFetchFailure sem.acquire() @@ -496,8 +504,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // Set `maxBytesInFlight` and `maxReqsInFlight` to `Int.MaxValue`, so that during the // construction of `ShuffleBlockFetcherIterator`, all requests to fetch remote shuffle blocks // are issued. The `maxReqSizeShuffleToMem` is hard-coded as 200 here. + val taskContext = TaskContext.empty() new ShuffleBlockFetcherIterator( - TaskContext.empty(), + taskContext, transfer, blockManager, blocksByAddress, @@ -506,7 +515,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT maxReqsInFlight = Int.MaxValue, maxBlocksInFlightPerAddress = Int.MaxValue, maxReqSizeShuffleToMem = 200, - detectCorrupt = true) + detectCorrupt = true, + taskContext.taskMetrics.createTempShuffleReadMetrics()) } val blocksByAddress1 = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( @@ -552,7 +562,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, Int.MaxValue, - true) + true, + taskContext.taskMetrics.createTempShuffleReadMetrics()) // All blocks fetched return zero length and should trigger a receive-side error: val e = intercept[FetchFailedException] { iterator.next() } diff --git a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala index 2945c3ee0a9d9..5e976ae4e91da 100644 --- a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala @@ -96,18 +96,6 @@ class StagePageSuite extends SparkFunSuite with LocalSparkContext { } } - test("peak execution memory should displayed") { - val html = renderStagePage().toString().toLowerCase(Locale.ROOT) - val targetString = "peak execution memory" - assert(html.contains(targetString)) - } - - test("SPARK-10543: peak execution memory should be per-task rather than cumulative") { - val html = renderStagePage().toString().toLowerCase(Locale.ROOT) - // verify min/25/50/75/max show task value not cumulative values - assert(html.contains(s"$peakExecutionMemory.0 b" * 5)) - } - /** * Render a stage page started with the given conf and return the HTML. * This also runs a dummy stage to populate the page with useful content. 
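A note on the shuffle test changes above (ShuffleSuite, CustomShuffledRDD, BlockStoreShuffleReaderSuite, ShuffleBlockFetcherIteratorSuite): the readers and writers now receive a metrics reporter that the caller builds from the task context (taskContext.taskMetrics.createTempShuffleReadMetrics() for reads, taskMetrics().shuffleWriteMetrics for writes) instead of deriving it internally. The short sketch below only illustrates that call-site shape with stand-in types; every name in it is invented for illustration and is not one of Spark's actual classes.

// Minimal sketch (hypothetical types): the caller creates the metrics reporter
// and injects it into the reader, mirroring how the updated tests pass
// taskContext.taskMetrics.createTempShuffleReadMetrics() into getReader(...).
trait ReadMetricsReporter {
  def incRecordsRead(n: Long): Unit
}

final class RecordingReporter extends ReadMetricsReporter {
  private var records = 0L
  override def incRecordsRead(n: Long): Unit = records += n
  def recordsRead: Long = records
}

// A toy "reader" that reports every record it yields to the injected reporter.
final class SimpleReader[K, V](blocks: Seq[(K, V)], metrics: ReadMetricsReporter) {
  def read(): Iterator[(K, V)] = blocks.iterator.map { kv =>
    metrics.incRecordsRead(1L)
    kv
  }
}

object InjectedMetricsSketch {
  def main(args: Array[String]): Unit = {
    val reporter = new RecordingReporter                          // built by the caller, like createTempShuffleReadMetrics()
    val reader = new SimpleReader(Seq(1 -> "a", 2 -> "b"), reporter)
    assert(reader.read().toSeq == Seq(1 -> "a", 2 -> "b"))
    assert(reporter.recordsRead == 2L)                            // a test can assert on the reporter it supplied
  }
}

The sketch's only point is the injection pattern: whoever constructs the reader owns the reporter, so the tests above can verify read and write metrics against the object they passed in rather than reaching back into TaskContext internals afterwards.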
diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala index cdc7f541b9552..06f01a60868f9 100644 --- a/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala @@ -81,19 +81,19 @@ class StoragePageSuite extends SparkFunSuite { Seq("1", "rdd1", "Memory Deserialized 1x Replicated", "10", "100%", "100.0 B", "0.0 B")) // Check the url assert(((xmlNodes \\ "tr")(0) \\ "td" \ "a")(0).attribute("href").map(_.text) === - Some("http://localhost:4040/storage/rdd?id=1")) + Some("http://localhost:4040/storage/rdd/?id=1")) assert(((xmlNodes \\ "tr")(1) \\ "td").map(_.text.trim) === Seq("2", "rdd2", "Disk Serialized 1x Replicated", "5", "50%", "0.0 B", "200.0 B")) // Check the url assert(((xmlNodes \\ "tr")(1) \\ "td" \ "a")(0).attribute("href").map(_.text) === - Some("http://localhost:4040/storage/rdd?id=2")) + Some("http://localhost:4040/storage/rdd/?id=2")) assert(((xmlNodes \\ "tr")(2) \\ "td").map(_.text.trim) === Seq("3", "rdd3", "Disk Memory Serialized 1x Replicated", "10", "100%", "400.0 B", "500.0 B")) // Check the url assert(((xmlNodes \\ "tr")(2) \\ "td" \ "a")(0).attribute("href").map(_.text) === - Some("http://localhost:4040/storage/rdd?id=3")) + Some("http://localhost:4040/storage/rdd/?id=3")) } test("empty rddTable") { diff --git a/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala b/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala index 621399af731f7..172bebbfec61d 100644 --- a/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala +++ b/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala @@ -40,7 +40,7 @@ class AccumulatorV2Suite extends SparkFunSuite { assert(acc.avg == 0.5) // Also test add using non-specialized add function - acc.add(new java.lang.Long(2)) + acc.add(java.lang.Long.valueOf(2)) assert(acc.count == 3) assert(acc.sum == 3) assert(acc.avg == 1.0) @@ -73,7 +73,7 @@ class AccumulatorV2Suite extends SparkFunSuite { assert(acc.avg == 0.5) // Also test add using non-specialized add function - acc.add(new java.lang.Double(2.0)) + acc.add(java.lang.Double.valueOf(2.0)) assert(acc.count == 3) assert(acc.sum == 3.0) assert(acc.avg == 1.0) @@ -96,7 +96,7 @@ class AccumulatorV2Suite extends SparkFunSuite { assert(acc.value.contains(0.0)) assert(!acc.isZero) - acc.add(new java.lang.Double(1.0)) + acc.add(java.lang.Double.valueOf(1.0)) val acc2 = acc.copyAndReset() assert(acc2.value.isEmpty) diff --git a/core/src/test/scala/org/apache/spark/util/CompletionIteratorSuite.scala b/core/src/test/scala/org/apache/spark/util/CompletionIteratorSuite.scala index 688fcd9f9aaba..29421f7aa9e36 100644 --- a/core/src/test/scala/org/apache/spark/util/CompletionIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/CompletionIteratorSuite.scala @@ -17,6 +17,9 @@ package org.apache.spark.util +import java.lang.ref.PhantomReference +import java.lang.ref.ReferenceQueue + import org.apache.spark.SparkFunSuite class CompletionIteratorSuite extends SparkFunSuite { @@ -44,4 +47,23 @@ class CompletionIteratorSuite extends SparkFunSuite { assert(!completionIter.hasNext) assert(numTimesCompleted === 1) } + test("reference to sub iterator should not be available after completion") { + var sub = Iterator(1, 2, 3) + + val refQueue = new ReferenceQueue[Iterator[Int]] + val ref = new PhantomReference[Iterator[Int]](sub, refQueue) + + val iter = 
CompletionIterator[Int, Iterator[Int]](sub, {}) + sub = null + iter.toArray + + for (_ <- 1 to 100 if !ref.isEnqueued) { + System.gc() + if (!ref.isEnqueued) { + Thread.sleep(10) + } + } + assert(ref.isEnqueued) + assert(refQueue.poll() === ref) + } } diff --git a/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala b/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala index f6ac89fc2742a..8d844bd08771c 100644 --- a/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala @@ -22,7 +22,6 @@ import java.net.URLClassLoader import scala.collection.JavaConverters._ import org.scalatest.Matchers -import org.scalatest.Matchers._ import org.apache.spark.{SparkContext, SparkException, SparkFunSuite, TestUtils} @@ -46,10 +45,10 @@ class MutableURLClassLoaderSuite extends SparkFunSuite with Matchers { test("child first") { val parentLoader = new URLClassLoader(urls2, null) val classLoader = new ChildFirstURLClassLoader(urls, parentLoader) - val fakeClass = classLoader.loadClass("FakeClass2").newInstance() + val fakeClass = classLoader.loadClass("FakeClass2").getConstructor().newInstance() val fakeClassVersion = fakeClass.toString assert(fakeClassVersion === "1") - val fakeClass2 = classLoader.loadClass("FakeClass2").newInstance() + val fakeClass2 = classLoader.loadClass("FakeClass2").getConstructor().newInstance() assert(fakeClass.getClass === fakeClass2.getClass) classLoader.close() parentLoader.close() @@ -58,10 +57,10 @@ class MutableURLClassLoaderSuite extends SparkFunSuite with Matchers { test("parent first") { val parentLoader = new URLClassLoader(urls2, null) val classLoader = new MutableURLClassLoader(urls, parentLoader) - val fakeClass = classLoader.loadClass("FakeClass1").newInstance() + val fakeClass = classLoader.loadClass("FakeClass1").getConstructor().newInstance() val fakeClassVersion = fakeClass.toString assert(fakeClassVersion === "2") - val fakeClass2 = classLoader.loadClass("FakeClass1").newInstance() + val fakeClass2 = classLoader.loadClass("FakeClass1").getConstructor().newInstance() assert(fakeClass.getClass === fakeClass2.getClass) classLoader.close() parentLoader.close() @@ -70,7 +69,7 @@ class MutableURLClassLoaderSuite extends SparkFunSuite with Matchers { test("child first can fall back") { val parentLoader = new URLClassLoader(urls2, null) val classLoader = new ChildFirstURLClassLoader(urls, parentLoader) - val fakeClass = classLoader.loadClass("FakeClass3").newInstance() + val fakeClass = classLoader.loadClass("FakeClass3").getConstructor().newInstance() val fakeClassVersion = fakeClass.toString assert(fakeClassVersion === "2") classLoader.close() @@ -81,7 +80,7 @@ class MutableURLClassLoaderSuite extends SparkFunSuite with Matchers { val parentLoader = new URLClassLoader(urls2, null) val classLoader = new ChildFirstURLClassLoader(urls, parentLoader) intercept[java.lang.ClassNotFoundException] { - classLoader.loadClass("FakeClassDoesNotExist").newInstance() + classLoader.loadClass("FakeClassDoesNotExist").getConstructor().newInstance() } classLoader.close() parentLoader.close() @@ -137,7 +136,7 @@ class MutableURLClassLoaderSuite extends SparkFunSuite with Matchers { sc.makeRDD(1 to 5, 2).mapPartitions { x => val loader = Thread.currentThread().getContextClassLoader // scalastyle:off classforname - Class.forName(className, true, loader).newInstance() + Class.forName(className, true, loader).getConstructor().newInstance() // 
scalastyle:on classforname Seq().iterator }.count() diff --git a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala index 2695295d451d5..63f9f82adf3e0 100644 --- a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala @@ -92,14 +92,14 @@ class SizeEstimatorSuite } test("primitive wrapper objects") { - assertResult(16)(SizeEstimator.estimate(new java.lang.Boolean(true))) - assertResult(16)(SizeEstimator.estimate(new java.lang.Byte("1"))) - assertResult(16)(SizeEstimator.estimate(new java.lang.Character('1'))) - assertResult(16)(SizeEstimator.estimate(new java.lang.Short("1"))) - assertResult(16)(SizeEstimator.estimate(new java.lang.Integer(1))) - assertResult(24)(SizeEstimator.estimate(new java.lang.Long(1))) - assertResult(16)(SizeEstimator.estimate(new java.lang.Float(1.0))) - assertResult(24)(SizeEstimator.estimate(new java.lang.Double(1.0d))) + assertResult(16)(SizeEstimator.estimate(java.lang.Boolean.TRUE)) + assertResult(16)(SizeEstimator.estimate(java.lang.Byte.valueOf("1"))) + assertResult(16)(SizeEstimator.estimate(java.lang.Character.valueOf('1'))) + assertResult(16)(SizeEstimator.estimate(java.lang.Short.valueOf("1"))) + assertResult(16)(SizeEstimator.estimate(java.lang.Integer.valueOf(1))) + assertResult(24)(SizeEstimator.estimate(java.lang.Long.valueOf(1))) + assertResult(16)(SizeEstimator.estimate(java.lang.Float.valueOf(1.0f))) + assertResult(24)(SizeEstimator.estimate(java.lang.Double.valueOf(1.0))) } test("class field blocks rounding") { @@ -202,14 +202,14 @@ class SizeEstimatorSuite assertResult(72)(SizeEstimator.estimate(DummyString("abcdefgh"))) // primitive wrapper classes - assertResult(24)(SizeEstimator.estimate(new java.lang.Boolean(true))) - assertResult(24)(SizeEstimator.estimate(new java.lang.Byte("1"))) - assertResult(24)(SizeEstimator.estimate(new java.lang.Character('1'))) - assertResult(24)(SizeEstimator.estimate(new java.lang.Short("1"))) - assertResult(24)(SizeEstimator.estimate(new java.lang.Integer(1))) - assertResult(24)(SizeEstimator.estimate(new java.lang.Long(1))) - assertResult(24)(SizeEstimator.estimate(new java.lang.Float(1.0))) - assertResult(24)(SizeEstimator.estimate(new java.lang.Double(1.0d))) + assertResult(24)(SizeEstimator.estimate(java.lang.Boolean.TRUE)) + assertResult(24)(SizeEstimator.estimate(java.lang.Byte.valueOf("1"))) + assertResult(24)(SizeEstimator.estimate(java.lang.Character.valueOf('1'))) + assertResult(24)(SizeEstimator.estimate(java.lang.Short.valueOf("1"))) + assertResult(24)(SizeEstimator.estimate(java.lang.Integer.valueOf(1))) + assertResult(24)(SizeEstimator.estimate(java.lang.Long.valueOf(1))) + assertResult(24)(SizeEstimator.estimate(java.lang.Float.valueOf(1.0f))) + assertResult(24)(SizeEstimator.estimate(java.lang.Double.valueOf(1.0))) } test("class field blocks rounding on 64-bit VM without useCompressedOops") { diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 39f4fba78583f..f5e912b50d1ab 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -33,7 +33,7 @@ import scala.util.Random import com.google.common.io.Files import org.apache.commons.io.IOUtils -import org.apache.commons.lang3.SystemUtils +import org.apache.commons.lang3.{JavaVersion, SystemUtils} import 
org.apache.commons.math3.stat.inference.ChiSquareTest import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path @@ -45,14 +45,6 @@ import org.apache.spark.scheduler.SparkListener class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { - test("truncatedString") { - assert(Utils.truncatedString(Nil, "[", ", ", "]", 2) == "[]") - assert(Utils.truncatedString(Seq(1, 2), "[", ", ", "]", 2) == "[1, 2]") - assert(Utils.truncatedString(Seq(1, 2, 3), "[", ", ", "]", 2) == "[1, ... 2 more fields]") - assert(Utils.truncatedString(Seq(1, 2, 3), "[", ", ", "]", -5) == "[, ... 3 more fields]") - assert(Utils.truncatedString(Seq(1, 2, 3), ", ") == "1, 2, 3") - } - test("timeConversion") { // Test -1 assert(Utils.timeStringAsSeconds("-1") === -1) @@ -932,10 +924,7 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { signal(pid, "SIGKILL") } - val versionParts = System.getProperty("java.version").split("[+.\\-]+", 3) - var majorVersion = versionParts(0).toInt - if (majorVersion == 1) majorVersion = versionParts(1).toInt - if (majorVersion >= 8) { + if (SystemUtils.isJavaVersionAtLeast(JavaVersion.JAVA_1_8)) { // We'll make sure that forcibly terminating a process works by // creating a very misbehaving process. It ignores SIGTERM and has been SIGSTOPed. On // older versions of java, this will *not* terminate. diff --git a/core/src/test/scala/org/apache/spark/util/VersionUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/VersionUtilsSuite.scala index b36d6be231d39..56623ebea1651 100644 --- a/core/src/test/scala/org/apache/spark/util/VersionUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/VersionUtilsSuite.scala @@ -73,4 +73,29 @@ class VersionUtilsSuite extends SparkFunSuite { } } } + + test("Return short version number") { + assert(shortVersion("3.0.0") === "3.0.0") + assert(shortVersion("3.0.0-SNAPSHOT") === "3.0.0") + withClue("shortVersion parsing should fail for missing maintenance version number") { + intercept[IllegalArgumentException] { + shortVersion("3.0") + } + } + withClue("shortVersion parsing should fail for invalid major version number") { + intercept[IllegalArgumentException] { + shortVersion("x.0.0") + } + } + withClue("shortVersion parsing should fail for invalid minor version number") { + intercept[IllegalArgumentException] { + shortVersion("3.x.0") + } + } + withClue("shortVersion parsing should fail for invalid maintenance version number") { + intercept[IllegalArgumentException] { + shortVersion("3.0.x") + } + } + } } diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala index cd25265784136..35fba1a3b73c6 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.util.collection import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ +import scala.language.postfixOps import scala.ref.WeakReference import org.scalatest.Matchers diff --git a/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala index 65bf857e22c02..46a05e2ba798b 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala +++ 
b/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.util.collection -import java.lang.{Float => JFloat, Integer => JInteger} +import java.lang.{Float => JFloat} import java.util.{Arrays, Comparator} import org.apache.spark.SparkFunSuite @@ -48,7 +48,7 @@ class SorterSuite extends SparkFunSuite with Logging { // alternate. Keys are random doubles, values are ordinals from 0 to length. val keys = Array.tabulate[Double](5000) { i => rand.nextDouble() } val keyValueArray = Array.tabulate[Number](10000) { i => - if (i % 2 == 0) keys(i / 2) else new Integer(i / 2) + if (i % 2 == 0) keys(i / 2) else Integer.valueOf(i / 2) } // Map from generated keys to values, to verify correctness later @@ -112,7 +112,7 @@ class SorterSuite extends SparkFunSuite with Logging { // Test our key-value pairs where each element is a Tuple2[Float, Integer]. val kvTuples = Array.tabulate(numElements) { i => - (new JFloat(rand.nextFloat()), new JInteger(i)) + (JFloat.valueOf(rand.nextFloat()), Integer.valueOf(i)) } val kvTupleArray = new Array[AnyRef](numElements) @@ -167,23 +167,23 @@ class SorterSuite extends SparkFunSuite with Logging { val ints = Array.fill(numElements)(rand.nextInt()) val intObjects = { - val data = new Array[JInteger](numElements) + val data = new Array[Integer](numElements) var i = 0 while (i < numElements) { - data(i) = new JInteger(ints(i)) + data(i) = Integer.valueOf(ints(i)) i += 1 } data } - val intObjectArray = new Array[JInteger](numElements) + val intObjectArray = new Array[Integer](numElements) val prepareIntObjectArray = () => { System.arraycopy(intObjects, 0, intObjectArray, 0, numElements) } runExperiment("Java Arrays.sort() on non-primitive int array")({ - Arrays.sort(intObjectArray, new Comparator[JInteger] { - override def compare(x: JInteger, y: JInteger): Int = x.compareTo(y) + Arrays.sort(intObjectArray, new Comparator[Integer] { + override def compare(x: Integer, y: Integer): Int = x.compareTo(y) }) }, prepareIntObjectArray) diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala index d5956ea32096a..d570630c1a095 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala @@ -78,7 +78,7 @@ class RadixSortSuite extends SparkFunSuite with Logging { private def generateTestData(size: Long, rand: => Long): (Array[JLong], LongArray) = { val ref = Array.tabulate[Long](Ints.checkedCast(size)) { i => rand } val extended = ref ++ Array.fill[Long](Ints.checkedCast(size))(0) - (ref.map(i => new JLong(i)), new LongArray(MemoryBlock.fromLongArray(extended))) + (ref.map(i => JLong.valueOf(i)), new LongArray(MemoryBlock.fromLongArray(extended))) } private def generateKeyPrefixTestData(size: Long, rand: => Long): (LongArray, LongArray) = { diff --git a/core/src/test/scala/org/apache/spark/util/logging/DriverLoggerSuite.scala b/core/src/test/scala/org/apache/spark/util/logging/DriverLoggerSuite.scala new file mode 100644 index 0000000000000..973f71cdeb755 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/logging/DriverLoggerSuite.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.logging + +import java.io.{BufferedInputStream, File, FileInputStream} + +import org.apache.commons.io.FileUtils + +import org.apache.spark._ +import org.apache.spark.{SparkContext, SparkFunSuite} +import org.apache.spark.internal.config._ +import org.apache.spark.launcher.SparkLauncher +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.util.Utils + +class DriverLoggerSuite extends SparkFunSuite with LocalSparkContext { + + private var rootDfsDir : File = _ + + override def beforeAll(): Unit = { + super.beforeAll() + rootDfsDir = Utils.createTempDir(namePrefix = "dfs_logs") + } + + override def afterAll(): Unit = { + super.afterAll() + JavaUtils.deleteRecursively(rootDfsDir) + } + + test("driver logs are persisted locally and synced to dfs") { + val sc = getSparkContext() + + val app_id = sc.applicationId + // Run a simple spark application + sc.parallelize(1 to 1000).count() + + // Assert driver log file exists + val rootDir = Utils.getLocalDir(sc.getConf) + val driverLogsDir = FileUtils.getFile(rootDir, DriverLogger.DRIVER_LOG_DIR) + assert(driverLogsDir.exists()) + val files = driverLogsDir.listFiles() + assert(files.length === 1) + assert(files(0).getName.equals(DriverLogger.DRIVER_LOG_FILE)) + + sc.stop() + assert(!driverLogsDir.exists()) + val dfsFile = FileUtils.getFile(sc.getConf.get(DRIVER_LOG_DFS_DIR).get, + app_id + DriverLogger.DRIVER_LOG_FILE_SUFFIX) + assert(dfsFile.exists()) + assert(dfsFile.length() > 0) + } + + private def getSparkContext(): SparkContext = { + val conf = new SparkConf() + conf.set(DRIVER_LOG_DFS_DIR, rootDfsDir.getAbsolutePath()) + conf.set(DRIVER_LOG_PERSISTTODFS, true) + conf.set(SparkLauncher.SPARK_MASTER, "local") + conf.set(SparkLauncher.DEPLOY_MODE, "client") + sc = new SparkContext("local", "DriverLogTest", conf) + sc + } + +} + diff --git a/dev/appveyor-install-dependencies.ps1 b/dev/appveyor-install-dependencies.ps1 index 06d9d70af311a..cc68ffb90d875 100644 --- a/dev/appveyor-install-dependencies.ps1 +++ b/dev/appveyor-install-dependencies.ps1 @@ -116,7 +116,7 @@ Pop-Location # ========================== R $rVer = "3.5.1" -$rToolsVer = "3.4.0" +$rToolsVer = "3.5.1" InstallR InstallRtools diff --git a/dev/checkstyle.xml b/dev/checkstyle.xml index 53c284888ebb0..e8859c01f2bd8 100644 --- a/dev/checkstyle.xml +++ b/dev/checkstyle.xml @@ -71,13 +71,13 @@ If you wish to turn off checking for a section of code, you can put a comment in the source before and after the section, with the following syntax: - // checkstyle:off no.XXX (such as checkstyle.off: NoFinalizer) + // checkstyle.off: XXX (such as checkstyle.off: NoFinalizer) ... 
// stuff that breaks the styles - // checkstyle:on + // checkstyle.on: XXX (such as checkstyle.on: NoFinalizer) --> - - + + @@ -180,5 +180,10 @@ + + + + + diff --git a/dev/deps/spark-deps-hadoop-palantir b/dev/deps/spark-deps-hadoop-palantir index 7253e4920b110..0d881c6cdab73 100644 --- a/dev/deps/spark-deps-hadoop-palantir +++ b/dev/deps/spark-deps-hadoop-palantir @@ -2,7 +2,7 @@ HikariCP-java7-2.4.12.jar RoaringBitmap-0.7.16.jar activation-1.1.1.jar aircompressor-0.10.jar -antlr4-runtime-4.7.jar +antlr4-runtime-4.7.1.jar aopalliance-1.0.jar aopalliance-repackaged-2.5.0-b32.jar apacheds-i18n-2.0.0-M15.jar @@ -21,10 +21,10 @@ avro-mapred-1.8.2-hadoop2.jar aws-java-sdk-bundle-1.11.201.jar azure-keyvault-core-0.8.0.jar azure-storage-5.4.0.jar -breeze-macros_2.11-0.13.2.jar -breeze_2.11-0.13.2.jar +breeze-macros_2.12-0.13.2.jar +breeze_2.12-0.13.2.jar chill-java-0.9.3.jar -chill_2.11-0.9.3.jar +chill_2.12-0.9.3.jar classmate-1.1.0.jar commons-beanutils-1.9.3.jar commons-beanutils-core-1.8.0.jar @@ -32,14 +32,14 @@ commons-cli-1.2.jar commons-codec-1.11.jar commons-collections-3.2.2.jar commons-compiler-3.0.10.jar -commons-compress-1.8.1.jar +commons-compress-1.18.jar commons-configuration-1.6.jar commons-crypto-1.0.0.jar commons-digester-1.8.jar commons-httpclient-3.1.jar commons-io-2.6.jar commons-lang-2.6.jar -commons-lang3-3.8.jar +commons-lang3-3.8.1.jar commons-logging-1.2.jar commons-math3-3.6.1.jar commons-net-3.6.jar @@ -105,7 +105,7 @@ jackson-mapper-asl-1.9.13.jar jackson-module-afterburner-2.9.7.jar jackson-module-jaxb-annotations-2.9.7.jar jackson-module-paranamer-2.9.7.jar -jackson-module-scala_2.11-2.9.7.jar +jackson-module-scala_2.12-2.9.7.jar jackson-xc-1.9.13.jar janino-3.0.10.jar javassist-3.20.0-GA.jar @@ -132,10 +132,10 @@ jetty-util-6.1.26.jar jetty-util-9.4.12.v20180830.jar joda-time-2.10.jar json-smart-1.3.1.jar -json4s-ast_2.11-3.5.3.jar -json4s-core_2.11-3.5.3.jar -json4s-jackson_2.11-3.5.3.jar -json4s-scalap_2.11-3.5.3.jar +json4s-ast_2.12-3.5.3.jar +json4s-core_2.12-3.5.3.jar +json4s-jackson_2.12-3.5.3.jar +json4s-scalap_2.12-3.5.3.jar jsp-api-2.1.jar jsr305-3.0.2.jar jtransforms-2.4.0.jar @@ -148,8 +148,8 @@ leveldbjni-all-1.8.jar log4j-1.2.17.jar logging-interceptor-3.11.0.jar lz4-java-1.5.0.jar -machinist_2.11-0.6.1.jar -macro-compat_2.11-1.1.1.jar +machinist_2.12-0.6.1.jar +macro-compat_2.12-1.1.1.jar metrics-core-3.2.6.jar metrics-graphite-3.2.6.jar metrics-influxdb-1.2.2.jar @@ -181,18 +181,18 @@ protobuf-java-2.5.0.jar py4j-0.10.8.1.jar pyrolite-4.13.jar safe-logging-1.5.1.jar -scala-compiler-2.11.12.jar -scala-library-2.11.12.jar -scala-parser-combinators_2.11-1.1.0.jar -scala-reflect-2.11.12.jar -scala-xml_2.11-1.0.5.jar -shapeless_2.11-2.3.2.jar +scala-compiler-2.12.7.jar +scala-library-2.12.7.jar +scala-parser-combinators_2.12-1.1.0.jar +scala-reflect-2.12.7.jar +scala-xml_2.12-1.0.5.jar +shapeless_2.12-2.3.2.jar slf4j-api-1.7.25.jar slf4j-log4j12-1.7.25.jar snakeyaml-1.23.jar snappy-java-1.1.7.2.jar -spire-macros_2.11-0.13.0.jar -spire_2.11-0.13.0.jar +spire-macros_2.12-0.13.0.jar +spire_2.12-0.13.0.jar stax-api-1.0-2.jar stax2-api-3.1.4.jar stream-2.9.6.jar @@ -203,5 +203,5 @@ xbean-asm7-shaded-4.12.jar xmlenc-0.52.jar xz-1.5.jar zjsonpatch-0.3.0.jar -zookeeper-3.4.6.jar +zookeeper-3.4.7.jar zstd-jni-1.3.5-3.jar diff --git a/dev/docker-images/Makefile b/dev/docker-images/Makefile index ed3e3a5ee7687..168d386770655 100644 --- a/dev/docker-images/Makefile +++ b/dev/docker-images/Makefile @@ -17,9 +17,9 @@ .PHONY: all publish base python r -BASE_IMAGE_NAME 
= palantirtechnologies/circle-spark-base:0.1.0 -PYTHON_IMAGE_NAME = palantirtechnologies/circle-spark-python:0.1.0 -R_IMAGE_NAME = palantirtechnologies/circle-spark-r:0.1.0 +BASE_IMAGE_NAME = palantirtechnologies/circle-spark-base:0.1.3 +PYTHON_IMAGE_NAME = palantirtechnologies/circle-spark-python:0.1.3 +R_IMAGE_NAME = palantirtechnologies/circle-spark-r:0.1.3 all: base python r diff --git a/dev/docker-images/base/Dockerfile b/dev/docker-images/base/Dockerfile index 0e84ec665fcd7..61293349a0db5 100644 --- a/dev/docker-images/base/Dockerfile +++ b/dev/docker-images/base/Dockerfile @@ -31,7 +31,7 @@ RUN mkdir -p /usr/share/man/man1 \ git \ locales sudo openssh-client ca-certificates tar gzip parallel \ net-tools netcat unzip zip bzip2 gnupg curl wget \ - openjdk-8-jdk rsync pandoc pandoc-citeproc flake8 \ + openjdk-8-jdk rsync pandoc pandoc-citeproc flake8 tzdata \ && rm -rf /var/lib/apt/lists/* # If you update java, make sure this aligns @@ -42,8 +42,9 @@ ENV JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64 RUN ln -sf /usr/share/zoneinfo/Etc/UTC /etc/localtime # Use unicode -RUN locale-gen C.UTF-8 || true -ENV LANG=C.UTF-8 +RUN locale-gen en_US.UTF-8 +ENV LANG=en_US.UTF-8 +ENV TZ=UTC # install jq RUN JQ_URL="https://circle-downloads.s3.amazonaws.com/circleci-images/cache/linux-amd64/jq-latest" \ diff --git a/dev/docker-images/python/Dockerfile b/dev/docker-images/python/Dockerfile index cb44b373617da..bb039b7c5c8b6 100644 --- a/dev/docker-images/python/Dockerfile +++ b/dev/docker-images/python/Dockerfile @@ -34,12 +34,8 @@ RUN mkdir -p $(pyenv root)/versions \ RUN pyenv global our-miniconda/envs/python2 our-miniconda/envs/python3 \ && pyenv rehash -RUN $CONDA_BIN install -y -n python2 -c anaconda -c conda-forge xmlrunner \ - && $CONDA_BIN install -y -n python3 -c anaconda -c conda-forge xmlrunner \ - && $CONDA_BIN clean --all - # Expose pyenv globally ENV PATH=$CIRCLE_HOME/.pyenv/shims:$PATH -RUN PYENV_VERSION=our-miniconda/envs/python2 $CIRCLE_HOME/.pyenv/shims/pip install unishark \ - && PYENV_VERSION=our-miniconda/envs/python3 $CIRCLE_HOME/.pyenv/shims/pip install unishark +RUN PYENV_VERSION=our-miniconda/envs/python2 $CIRCLE_HOME/.pyenv/shims/pip install unishark unittest-xml-reporting \ + && PYENV_VERSION=our-miniconda/envs/python3 $CIRCLE_HOME/.pyenv/shims/pip install unishark unittest-xml-reporting diff --git a/dev/lint-python b/dev/lint-python index 27d87f6b56680..06816932e754a 100755 --- a/dev/lint-python +++ b/dev/lint-python @@ -1,5 +1,4 @@ #!/usr/bin/env bash - # # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with @@ -16,160 +15,242 @@ # See the License for the specific language governing permissions and # limitations under the License. # +# define test binaries + versions +PYDOCSTYLE_BUILD="pydocstyle" +MINIMUM_PYDOCSTYLE="3.0.0" -SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" -SPARK_ROOT_DIR="$(dirname "$SCRIPT_DIR")" -# Exclude auto-generated configuration file. -PATHS_TO_CHECK="$( cd "$SPARK_ROOT_DIR" && find . -name "*.py" )" -DOC_PATHS_TO_CHECK="$( cd "$SPARK_ROOT_DIR" && find . 
-name "*.py" | grep -vF 'functions.py' )" -PYCODESTYLE_REPORT_PATH="$SPARK_ROOT_DIR/dev/pycodestyle-report.txt" -PYDOCSTYLE_REPORT_PATH="$SPARK_ROOT_DIR/dev/pydocstyle-report.txt" -PYLINT_REPORT_PATH="$SPARK_ROOT_DIR/dev/pylint-report.txt" -PYLINT_INSTALL_INFO="$SPARK_ROOT_DIR/dev/pylint-info.txt" - -PYDOCSTYLEBUILD="pydocstyle" -MINIMUM_PYDOCSTYLEVERSION="3.0.0" - -FLAKE8BUILD="flake8" +FLAKE8_BUILD="flake8" MINIMUM_FLAKE8="3.5.0" -SPHINXBUILD=${SPHINXBUILD:=sphinx-build} -SPHINX_REPORT_PATH="$SPARK_ROOT_DIR/dev/sphinx-report.txt" +PYCODESTYLE_BUILD="pycodestyle" +MINIMUM_PYCODESTYLE="2.4.0" -cd "$SPARK_ROOT_DIR" +SPHINX_BUILD="sphinx-build" -# compileall: https://docs.python.org/2/library/compileall.html -python -B -m compileall -q -l $PATHS_TO_CHECK > "$PYCODESTYLE_REPORT_PATH" -compile_status="${PIPESTATUS[0]}" +function compile_python_test { + local COMPILE_STATUS= + local COMPILE_REPORT= + + if [[ ! "$1" ]]; then + echo "No python files found! Something is very wrong -- exiting." + exit 1; + fi -# Get pycodestyle at runtime so that we don't rely on it being installed on the build server. -# See: https://github.com/apache/spark/pull/1744#issuecomment-50982162 -# Updated to the latest official version of pep8. pep8 is formally renamed to pycodestyle. -PYCODESTYLE_VERSION="2.4.0" -PYCODESTYLE_SCRIPT_PATH="$SPARK_ROOT_DIR/dev/pycodestyle-$PYCODESTYLE_VERSION.py" -PYCODESTYLE_SCRIPT_REMOTE_PATH="https://raw.githubusercontent.com/PyCQA/pycodestyle/$PYCODESTYLE_VERSION/pycodestyle.py" + # compileall: https://docs.python.org/2/library/compileall.html + echo "starting python compilation test..." + COMPILE_REPORT=$( (python -B -mcompileall -q -l $1) 2>&1) + COMPILE_STATUS=$? + + if [ $COMPILE_STATUS -ne 0 ]; then + echo "Python compilation failed with the following errors:" + echo "$COMPILE_REPORT" + echo "$COMPILE_STATUS" + exit "$COMPILE_STATUS" + else + echo "python compilation succeeded." + echo + fi +} -if [ ! -e "$PYCODESTYLE_SCRIPT_PATH" ]; then - curl --silent -o "$PYCODESTYLE_SCRIPT_PATH" "$PYCODESTYLE_SCRIPT_REMOTE_PATH" - curl_status="$?" +function pycodestyle_test { + local PYCODESTYLE_STATUS= + local PYCODESTYLE_REPORT= + local RUN_LOCAL_PYCODESTYLE= + local VERSION= + local EXPECTED_PYCODESTYLE= + local PYCODESTYLE_SCRIPT_PATH="$SPARK_ROOT_DIR/dev/pycodestyle-$MINIMUM_PYCODESTYLE.py" + local PYCODESTYLE_SCRIPT_REMOTE_PATH="https://raw.githubusercontent.com/PyCQA/pycodestyle/$MINIMUM_PYCODESTYLE/pycodestyle.py" - if [ "$curl_status" -ne 0 ]; then - echo "Failed to download pycodestyle.py from \"$PYCODESTYLE_SCRIPT_REMOTE_PATH\"." - exit "$curl_status" + if [[ ! "$1" ]]; then + echo "No python files found! Something is very wrong -- exiting." + exit 1; fi -fi - -# Easy install pylint in /dev/pylint. To easy_install into a directory, the PYTHONPATH should -# be set to the directory. -# dev/pylint should be appended to the PATH variable as well. -# Jenkins by default installs the pylint3 version, so for now this just checks the code quality -# of python3. -export "PYTHONPATH=$SPARK_ROOT_DIR/dev/pylint" -export "PYLINT_HOME=$PYTHONPATH" -export "PATH=$PYTHONPATH:$PATH" - -# There is no need to write this output to a file -# first, but we do so so that the check status can -# be output before the report, like with the -# scalastyle and RAT checks. 
-python "$PYCODESTYLE_SCRIPT_PATH" --config=dev/tox.ini $PATHS_TO_CHECK >> "$PYCODESTYLE_REPORT_PATH" -pycodestyle_status="${PIPESTATUS[0]}" - -if [ "$compile_status" -eq 0 -a "$pycodestyle_status" -eq 0 ]; then - lint_status=0 -else - lint_status=1 -fi - -if [ "$lint_status" -ne 0 ]; then - echo "pycodestyle checks failed." - cat "$PYCODESTYLE_REPORT_PATH" - rm "$PYCODESTYLE_REPORT_PATH" - exit "$lint_status" -else - echo "pycodestyle checks passed." - rm "$PYCODESTYLE_REPORT_PATH" -fi - -# Check by flake8 -if hash "$FLAKE8BUILD" 2> /dev/null; then - FLAKE8VERSION="$( $FLAKE8BUILD --version 2> /dev/null )" - VERSION=($FLAKE8VERSION) - IS_EXPECTED_FLAKE8=$(python -c 'from distutils.version import LooseVersion; \ -print(LooseVersion("""'${VERSION[0]}'""") >= LooseVersion("""'$MINIMUM_FLAKE8'"""))' 2> /dev/null) - if [[ "$IS_EXPECTED_FLAKE8" == "True" ]]; then - # stop the build if there are Python syntax errors or undefined names - $FLAKE8BUILD . --count --select=E901,E999,F821,F822,F823 --max-line-length=100 --show-source --statistics - flake8_status="${PIPESTATUS[0]}" - - if [ "$flake8_status" -eq 0 ]; then - lint_status=0 - else - lint_status=1 + + # check for locally installed pycodestyle & version + RUN_LOCAL_PYCODESTYLE="False" + if hash "$PYCODESTYLE_BUILD" 2> /dev/null; then + VERSION=$( $PYCODESTYLE_BUILD --version 2> /dev/null) + EXPECTED_PYCODESTYLE=$( (python -c 'from distutils.version import LooseVersion; + print(LooseVersion("""'${VERSION[0]}'""") >= LooseVersion("""'$MINIMUM_PYCODESTYLE'"""))')\ + 2> /dev/null) + + if [ "$EXPECTED_PYCODESTYLE" == "True" ]; then + RUN_LOCAL_PYCODESTYLE="True" fi + fi - if [ "$lint_status" -ne 0 ]; then - echo "flake8 checks failed." - exit "$lint_status" - else - echo "flake8 checks passed." + # download the right version or run locally + if [ $RUN_LOCAL_PYCODESTYLE == "False" ]; then + # Get pycodestyle at runtime so that we don't rely on it being installed on the build server. + # See: https://github.com/apache/spark/pull/1744#issuecomment-50982162 + # Updated to the latest official version of pep8. pep8 is formally renamed to pycodestyle. + echo "downloading pycodestyle from $PYCODESTYLE_SCRIPT_REMOTE_PATH..." + if [ ! -e "$PYCODESTYLE_SCRIPT_PATH" ]; then + curl --silent -o "$PYCODESTYLE_SCRIPT_PATH" "$PYCODESTYLE_SCRIPT_REMOTE_PATH" + local curl_status="$?" + + if [ "$curl_status" -ne 0 ]; then + echo "Failed to download pycodestyle.py from $PYCODESTYLE_SCRIPT_REMOTE_PATH" + exit "$curl_status" + fi fi + + echo "starting pycodestyle test..." + PYCODESTYLE_REPORT=$( (python "$PYCODESTYLE_SCRIPT_PATH" --config=dev/tox.ini $1) 2>&1) + PYCODESTYLE_STATUS=$? + else + # we have the right version installed, so run locally + echo "starting pycodestyle test..." + PYCODESTYLE_REPORT=$( ($PYCODESTYLE_BUILD --config=dev/tox.ini $1) 2>&1) + PYCODESTYLE_STATUS=$? + fi + + if [ $PYCODESTYLE_STATUS -ne 0 ]; then + echo "pycodestyle checks failed:" + echo "$PYCODESTYLE_REPORT" + exit "$PYCODESTYLE_STATUS" else - echo "The flake8 version needs to be "$MINIMUM_FLAKE8" at latest. Your current version is '"$FLAKE8VERSION"'." + echo "pycodestyle checks passed." + echo + fi +} + +function flake8_test { + local FLAKE8_VERSION= + local VERSION= + local EXPECTED_FLAKE8= + local FLAKE8_REPORT= + local FLAKE8_STATUS= + + if ! hash "$FLAKE8_BUILD" 2> /dev/null; then + echo "The flake8 command was not found." echo "flake8 checks failed." exit 1 fi -else - echo >&2 "The flake8 command was not found." - echo "flake8 checks failed." 
- exit 1 -fi - -# Check python document style, skip check if pydocstyle is not installed. -if hash "$PYDOCSTYLEBUILD" 2> /dev/null; then - PYDOCSTYLEVERSION="$( $PYDOCSTYLEBUILD --version 2> /dev/null )" - IS_EXPECTED_PYDOCSTYLEVERSION=$(python -c 'from distutils.version import LooseVersion; \ -print(LooseVersion("""'$PYDOCSTYLEVERSION'""") >= LooseVersion("""'$MINIMUM_PYDOCSTYLEVERSION'"""))') - if [[ "$IS_EXPECTED_PYDOCSTYLEVERSION" == "True" ]]; then - $PYDOCSTYLEBUILD --config=dev/tox.ini $DOC_PATHS_TO_CHECK >> "$PYDOCSTYLE_REPORT_PATH" - pydocstyle_status="${PIPESTATUS[0]}" - - if [ "$compile_status" -eq 0 -a "$pydocstyle_status" -eq 0 ]; then - echo "pydocstyle checks passed." - rm "$PYDOCSTYLE_REPORT_PATH" - else - echo "pydocstyle checks failed." - cat "$PYDOCSTYLE_REPORT_PATH" - rm "$PYDOCSTYLE_REPORT_PATH" - exit 1 - fi + FLAKE8_VERSION="$($FLAKE8_BUILD --version 2> /dev/null)" + VERSION=($FLAKE8_VERSION) + EXPECTED_FLAKE8=$( (python -c 'from distutils.version import LooseVersion; + print(LooseVersion("""'${VERSION[0]}'""") >= LooseVersion("""'$MINIMUM_FLAKE8'"""))') \ + 2> /dev/null) + + if [[ "$EXPECTED_FLAKE8" == "False" ]]; then + echo "\ +The minimum flake8 version needs to be $MINIMUM_FLAKE8. Your current version is $FLAKE8_VERSION + +flake8 checks failed." + exit 1 + fi + + echo "starting $FLAKE8_BUILD test..." + FLAKE8_REPORT=$( ($FLAKE8_BUILD . --count --select=E901,E999,F821,F822,F823 \ + --max-line-length=100 --show-source --statistics) 2>&1) + FLAKE8_STATUS=$? + + if [ "$FLAKE8_STATUS" -ne 0 ]; then + echo "flake8 checks failed:" + echo "$FLAKE8_REPORT" + echo "$FLAKE8_STATUS" + exit "$FLAKE8_STATUS" else - echo "The pydocstyle version needs to be "$MINIMUM_PYDOCSTYLEVERSION" at latest. Your current version is "$PYDOCSTYLEVERSION". Skipping pydoc checks for now." + echo "flake8 checks passed." + echo fi -else - echo >&2 "The pydocstyle command was not found. Skipping pydoc checks for now" -fi - -# Check that the documentation builds acceptably, skip check if sphinx is not installed. -if hash "$SPHINXBUILD" 2> /dev/null; then - cd python/docs - make clean - # Treat warnings as errors so we stop correctly - SPHINXOPTS="-a -W" make html &> "$SPHINX_REPORT_PATH" || lint_status=1 - if [ "$lint_status" -ne 0 ]; then - echo "pydoc checks failed." - cat "$SPHINX_REPORT_PATH" - echo "re-running make html to print full warning list" - make clean - SPHINXOPTS="-a" make html - rm "$SPHINX_REPORT_PATH" - exit "$lint_status" - else - echo "pydoc checks passed." - rm "$SPHINX_REPORT_PATH" - fi - cd ../.. -else - echo >&2 "The $SPHINXBUILD command was not found. Skipping pydoc checks for now" -fi +} + +function pydocstyle_test { + local PYDOCSTYLE_REPORT= + local PYDOCSTYLE_STATUS= + local PYDOCSTYLE_VERSION= + local EXPECTED_PYDOCSTYLE= + + # Exclude auto-generated configuration file. + local DOC_PATHS_TO_CHECK="$( cd "${SPARK_ROOT_DIR}" && find . -name "*.py" | grep -vF 'functions.py' )" + + # Check python document style, skip check if pydocstyle is not installed. + if ! hash "$PYDOCSTYLE_BUILD" 2> /dev/null; then + echo "The pydocstyle command was not found. Skipping pydocstyle checks for now." 
+ echo + return + fi + + PYDOCSTYLE_VERSION="$($PYDOCSTYLEBUILD --version 2> /dev/null)" + EXPECTED_PYDOCSTYLE=$(python -c 'from distutils.version import LooseVersion; \ + print(LooseVersion("""'$PYDOCSTYLE_VERSION'""") >= LooseVersion("""'$MINIMUM_PYDOCSTYLE'"""))' \ + 2> /dev/null) + + if [[ "$EXPECTED_PYDOCSTYLE" == "False" ]]; then + echo "\ +The minimum version of pydocstyle needs to be $MINIMUM_PYDOCSTYLE. +Your current version is $PYDOCSTYLE_VERSION. +Skipping pydocstyle checks for now." + echo + return + fi + + echo "starting $PYDOCSTYLE_BUILD test..." + PYDOCSTYLE_REPORT=$( ($PYDOCSTYLE_BUILD --config=dev/tox.ini $DOC_PATHS_TO_CHECK) 2>&1) + PYDOCSTYLE_STATUS=$? + + if [ "$PYDOCSTYLE_STATUS" -ne 0 ]; then + echo "pydocstyle checks failed:" + echo "$PYDOCSTYLE_REPORT" + exit "$PYDOCSTYLE_STATUS" + else + echo "pydocstyle checks passed." + echo + fi +} + +function sphinx_test { + local SPHINX_REPORT= + local SPHINX_STATUS= + + # Check that the documentation builds acceptably, skip check if sphinx is not installed. + if ! hash "$SPHINX_BUILD" 2> /dev/null; then + echo "The $SPHINX_BUILD command was not found. Skipping pydoc checks for now." + echo + return + fi + + echo "starting $SPHINX_BUILD tests..." + pushd python/docs &> /dev/null + make clean &> /dev/null + # Treat warnings as errors so we stop correctly + SPHINX_REPORT=$( (SPHINXOPTS="-a -W" make html) 2>&1) + SPHINX_STATUS=$? + + if [ "$SPHINX_STATUS" -ne 0 ]; then + echo "$SPHINX_BUILD checks failed:" + echo "$SPHINX_REPORT" + echo + echo "re-running make html to print full warning list:" + make clean &> /dev/null + SPHINX_REPORT=$( (SPHINXOPTS="-a" make html) 2>&1) + echo "$SPHINX_REPORT" + exit "$SPHINX_STATUS" + else + echo "$SPHINX_BUILD checks passed." + echo + fi + + popd &> /dev/null +} + +SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" +SPARK_ROOT_DIR="$(dirname "${SCRIPT_DIR}")" + +pushd "$SPARK_ROOT_DIR" &> /dev/null + +PYTHON_SOURCE="$(find . -name "*.py")" + +compile_python_test "$PYTHON_SOURCE" +pycodestyle_test "$PYTHON_SOURCE" +flake8_test +pydocstyle_test +sphinx_test + +echo +echo "all lint-python tests passed!" + +popd &> /dev/null diff --git a/dev/make-distribution.sh b/dev/make-distribution.sh index e4405d89a1c17..1bde61b998ca5 100755 --- a/dev/make-distribution.sh +++ b/dev/make-distribution.sh @@ -173,7 +173,7 @@ fi # Normal quoting tricks don't work. # See: http://mywiki.wooledge.org/BashFAQ/050 if [[ -z "$DONT_BUILD" ]]; then - BUILD_COMMAND=("$MVN" -T 1C $MAYBE_CLEAN package -DskipTests $@) + BUILD_COMMAND=("$MVN" $MAYBE_CLEAN package -DskipTests $@) # Actually build the jar echo -e "\nBuilding with..." 
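The rewritten `dev/lint-python` above factors each checker into its own function and gates the locally installed tool on a pinned minimum version by shelling out to `python -c` with `distutils.version.LooseVersion`. A minimal Python sketch of that comparison (the installed-version strings below are made up for illustration):

```python
from distutils.version import LooseVersion

# Minimum versions pinned at the top of the rewritten dev/lint-python.
MINIMUM_PYDOCSTYLE = "3.0.0"
MINIMUM_FLAKE8 = "3.5.0"
MINIMUM_PYCODESTYLE = "2.4.0"

def meets_minimum(installed, minimum):
    """True when the locally installed tool is new enough to run directly."""
    return LooseVersion(installed) >= LooseVersion(minimum)

# Illustrative (made-up) installed versions:
print(meets_minimum("2.4.0", MINIMUM_PYCODESTYLE))  # True  -> run the local pycodestyle
print(meets_minimum("3.4.1", MINIMUM_FLAKE8))       # False -> script reports the version mismatch
print(meets_minimum("3.0.0", MINIMUM_PYDOCSTYLE))   # True  -> pydocstyle checks run
```

When the gate fails, the script downloads the pinned pycodestyle release, fails the run for an old flake8, and skips pydocstyle entirely.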
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 5db8efda90774..db44c3bff9083 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -310,6 +310,7 @@ def __hash__(self): "python/(?!pyspark/(ml|mllib|sql|streaming))" ], python_test_goals=[ + # doctests "pyspark.rdd", "pyspark.context", "pyspark.conf", @@ -318,10 +319,22 @@ def __hash__(self): "pyspark.serializers", "pyspark.profiler", "pyspark.shuffle", - "pyspark.tests", - "pyspark.test_broadcast", - "pyspark.test_serializers", "pyspark.util", + # unittests + "pyspark.tests.test_appsubmit", + "pyspark.tests.test_broadcast", + "pyspark.tests.test_conf", + "pyspark.tests.test_context", + "pyspark.tests.test_daemon", + "pyspark.tests.test_join", + "pyspark.tests.test_profiler", + "pyspark.tests.test_rdd", + "pyspark.tests.test_readwrite", + "pyspark.tests.test_serializers", + "pyspark.tests.test_shuffle", + "pyspark.tests.test_taskcontext", + "pyspark.tests.test_util", + "pyspark.tests.test_worker", ] ) @@ -333,6 +346,7 @@ def __hash__(self): "python/pyspark/sql" ], python_test_goals=[ + # doctests "pyspark.sql.types", "pyspark.sql.context", "pyspark.sql.session", @@ -346,7 +360,29 @@ def __hash__(self): "pyspark.sql.streaming", "pyspark.sql.udf", "pyspark.sql.window", - "pyspark.sql.tests", + # unittests + "pyspark.sql.tests.test_appsubmit", + "pyspark.sql.tests.test_arrow", + "pyspark.sql.tests.test_catalog", + "pyspark.sql.tests.test_column", + "pyspark.sql.tests.test_conf", + "pyspark.sql.tests.test_context", + "pyspark.sql.tests.test_dataframe", + "pyspark.sql.tests.test_datasources", + "pyspark.sql.tests.test_functions", + "pyspark.sql.tests.test_group", + "pyspark.sql.tests.test_pandas_udf", + "pyspark.sql.tests.test_pandas_udf_grouped_agg", + "pyspark.sql.tests.test_pandas_udf_grouped_map", + "pyspark.sql.tests.test_pandas_udf_scalar", + "pyspark.sql.tests.test_pandas_udf_window", + "pyspark.sql.tests.test_readwriter", + "pyspark.sql.tests.test_serde", + "pyspark.sql.tests.test_session", + "pyspark.sql.tests.test_streaming", + "pyspark.sql.tests.test_types", + "pyspark.sql.tests.test_udf", + "pyspark.sql.tests.test_utils", ] ) @@ -362,8 +398,13 @@ def __hash__(self): "python/pyspark/streaming" ], python_test_goals=[ + # doctests "pyspark.streaming.util", - "pyspark.streaming.tests", + # unittests + "pyspark.streaming.tests.test_context", + "pyspark.streaming.tests.test_dstream", + "pyspark.streaming.tests.test_kinesis", + "pyspark.streaming.tests.test_listener", ] ) @@ -375,6 +416,7 @@ def __hash__(self): "python/pyspark/mllib" ], python_test_goals=[ + # doctests "pyspark.mllib.classification", "pyspark.mllib.clustering", "pyspark.mllib.evaluation", @@ -389,7 +431,13 @@ def __hash__(self): "pyspark.mllib.stat.KernelDensity", "pyspark.mllib.tree", "pyspark.mllib.util", - "pyspark.mllib.tests", + # unittests + "pyspark.mllib.tests.test_algorithms", + "pyspark.mllib.tests.test_feature", + "pyspark.mllib.tests.test_linalg", + "pyspark.mllib.tests.test_stat", + "pyspark.mllib.tests.test_streaming_algorithms", + "pyspark.mllib.tests.test_util", ], blacklisted_python_implementations=[ "PyPy" # Skip these tests under PyPy since they require numpy and it isn't available there @@ -404,6 +452,7 @@ def __hash__(self): "python/pyspark/ml/" ], python_test_goals=[ + # doctests "pyspark.ml.classification", "pyspark.ml.clustering", "pyspark.ml.evaluation", @@ -415,7 +464,20 @@ def __hash__(self): "pyspark.ml.regression", "pyspark.ml.stat", "pyspark.ml.tuning", - "pyspark.ml.tests", + # 
unittests + "pyspark.ml.tests.test_algorithms", + "pyspark.ml.tests.test_base", + "pyspark.ml.tests.test_evaluation", + "pyspark.ml.tests.test_feature", + "pyspark.ml.tests.test_image", + "pyspark.ml.tests.test_linalg", + "pyspark.ml.tests.test_param", + "pyspark.ml.tests.test_persistence", + "pyspark.ml.tests.test_pipeline", + "pyspark.ml.tests.test_stat", + "pyspark.ml.tests.test_training_summary", + "pyspark.ml.tests.test_tuning", + "pyspark.ml.tests.test_wrapper", ], blacklisted_python_implementations=[ "PyPy" # Skip these tests under PyPy since they require numpy and it isn't available there diff --git a/dev/test_functions.py b/dev/test_functions.py index 3c9c7ccdefd25..14b038176e1c5 100755 --- a/dev/test_functions.py +++ b/dev/test_functions.py @@ -481,7 +481,7 @@ def parse_opts(): prog="run-tests" ) parser.add_option( - "-p", "--parallelism", type="int", default=4, + "-p", "--parallelism", type="int", default=8, help="The number of suites to test in parallel (default %default)" ) diff --git a/dists/hadoop-palantir-bom/pom.xml b/dists/hadoop-palantir-bom/pom.xml index 4dcbb5149f750..523997699e9e6 100644 --- a/dists/hadoop-palantir-bom/pom.xml +++ b/dists/hadoop-palantir-bom/pom.xml @@ -20,12 +20,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-dist_2.11-hadoop-palantir-bom + spark-dist_2.12-hadoop-palantir-bom Spark Project Dist Palantir Hadoop (BOM) http://spark.apache.org/ pom @@ -107,7 +107,7 @@ org.apache.spark - spark-avro_2.11 + spark-avro_${scala.binary.version} ${project.version} diff --git a/dists/hadoop-palantir/pom.xml b/dists/hadoop-palantir/pom.xml index d3158b1bbef10..a3bb85e70f844 100644 --- a/dists/hadoop-palantir/pom.xml +++ b/dists/hadoop-palantir/pom.xml @@ -20,12 +20,12 @@ 4.0.0 org.apache.spark - spark-dist_2.11-hadoop-palantir-bom + spark-dist_2.12-hadoop-palantir-bom 3.0.0-SNAPSHOT ../hadoop-palantir-bom/pom.xml - spark-dist_2.11-hadoop-palantir + spark-dist_2.12-hadoop-palantir Spark Project Dist Palantir Hadoop http://spark.apache.org/ pom @@ -57,7 +57,7 @@ org.apache.spark - spark-avro_2.11 + spark-avro_${scala.binary.version} org.apache.spark diff --git a/docs/_config.yml b/docs/_config.yml index c3ef98575fa62..649d18bf72b57 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -16,8 +16,8 @@ include: # of Spark, Scala, and Mesos. SPARK_VERSION: 3.0.0-SNAPSHOT SPARK_VERSION_SHORT: 3.0.0 -SCALA_BINARY_VERSION: "2.11" -SCALA_VERSION: "2.11.12" +SCALA_BINARY_VERSION: "2.12" +SCALA_VERSION: "2.12.7" MESOS_VERSION: 1.0.0 SPARK_ISSUE_TRACKER_URL: https://issues.apache.org/jira/browse/SPARK SPARK_GITHUB_URL: https://github.com/apache/spark diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb index 4d0d043a349bb..2d1a9547e3731 100644 --- a/docs/_plugins/copy_api_dirs.rb +++ b/docs/_plugins/copy_api_dirs.rb @@ -37,7 +37,7 @@ # Copy over the unified ScalaDoc for all projects to api/scala. # This directory will be copied over to _site when `jekyll` command is run. - source = "../target/scala-2.11/unidoc" + source = "../target/scala-2.12/unidoc" dest = "api/scala" puts "Making directory " + dest diff --git a/docs/building-spark.md b/docs/building-spark.md index 8af90db9a19dd..dfcd53c48e85c 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -96,9 +96,9 @@ It's possible to build Spark submodules using the `mvn -pl` option. 
For instance, you can build the Spark Streaming module using: - ./build/mvn -pl :spark-streaming_2.11 clean install + ./build/mvn -pl :spark-streaming_{{site.SCALA_BINARY_VERSION}} clean install -where `spark-streaming_2.11` is the `artifactId` as defined in `streaming/pom.xml` file. +where `spark-streaming_{{site.SCALA_BINARY_VERSION}}` is the `artifactId` as defined in `streaming/pom.xml` file. ## Continuous Compilation @@ -230,7 +230,7 @@ Once installed, the `docker` service needs to be started, if not already running On Linux, this can be done by `sudo service docker start`. ./build/mvn install -DskipTests - ./build/mvn test -Pdocker-integration-tests -pl :spark-docker-integration-tests_2.11 + ./build/mvn test -Pdocker-integration-tests -pl :spark-docker-integration-tests_{{site.SCALA_BINARY_VERSION}} or @@ -238,17 +238,17 @@ or ## Change Scala Version -To build Spark using another supported Scala version, please change the major Scala version using (e.g. 2.12): +To build Spark using another supported Scala version, please change the major Scala version using (e.g. 2.11): - ./dev/change-scala-version.sh 2.12 + ./dev/change-scala-version.sh 2.11 -For Maven, please enable the profile (e.g. 2.12): +For Maven, please enable the profile (e.g. 2.11): - ./build/mvn -Pscala-2.12 compile + ./build/mvn -Pscala-2.11 compile -For SBT, specify a complete scala version using (e.g. 2.12.6): +For SBT, specify a complete scala version using (e.g. 2.11.12): - ./build/sbt -Dscala.version=2.12.6 + ./build/sbt -Dscala.version=2.11.12 Otherwise, the sbt-pom-reader plugin will use the `scala.version` specified in the spark-parent pom. diff --git a/docs/cloud-integration.md b/docs/cloud-integration.md index fc24fa3ec4871..ccbbaaf886886 100644 --- a/docs/cloud-integration.md +++ b/docs/cloud-integration.md @@ -85,7 +85,7 @@ is set to the chosen version of Spark: ... org.apache.spark - spark-hadoop-cloud_2.11 + hadoop-cloud_{{site.SCALA_BINARY_VERSION}} ${spark.version} ... diff --git a/docs/configuration.md b/docs/configuration.md index f8937b0bc61a0..04210d855b110 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -266,6 +266,39 @@ of the most common options to set are: Only has effect in Spark standalone mode or Mesos cluster deploy mode. + + spark.driver.log.dfsDir + (none) + + Base directory in which Spark driver logs are synced, if spark.driver.log.persistToDfs.enabled + is true. Within this base directory, each application logs the driver logs to an application specific file. + Users may want to set this to a unified location like an HDFS directory so driver log files can be persisted + for later usage. This directory should allow any Spark user to read/write files and the Spark History Server + user to delete files. Additionally, older logs from this directory are cleaned by the + Spark History Server if + spark.history.fs.driverlog.cleaner.enabled is true and, if they are older than max age configured + by setting spark.history.fs.driverlog.cleaner.maxAge. + + + + spark.driver.log.persistToDfs.enabled + false + + If true, spark application running in client mode will write driver logs to a persistent storage, configured + in spark.driver.log.dfsDir. If spark.driver.log.dfsDir is not configured, driver logs + will not be persisted. Additionally, enable the cleaner by setting spark.history.fs.driverlog.cleaner.enabled + to true in Spark History Server. 
+ + + + spark.driver.log.layout + %d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n + + The layout for the driver logs that are synced to spark.driver.log.dfsDir. If this is not configured, + it uses the layout for the first appender defined in log4j.properties. If that is also not configured, driver logs + use the default layout. + + Apart from these, the following properties are also available, and may be useful in some situations: @@ -940,6 +973,14 @@ Apart from these, the following properties are also available, and may be useful
    spark.com.test.filter1.param.name2=bar + + spark.ui.requestHeaderSize + 8k + + The maximum allowed size for a HTTP request header, in bytes unless otherwise specified. + This setting applies for the Spark History Server too. + + ### Compression and Serialization diff --git a/docs/index.md b/docs/index.md index b9996cc8645d9..b842fca6245ae 100644 --- a/docs/index.md +++ b/docs/index.md @@ -31,7 +31,8 @@ Spark runs on both Windows and UNIX-like systems (e.g. Linux, Mac OS). It's easy locally on one machine --- all you need is to have `java` installed on your system `PATH`, or the `JAVA_HOME` environment variable pointing to a Java installation. -Spark runs on Java 8+, Python 2.7+/3.4+ and R 3.1+. For the Scala API, Spark {{site.SPARK_VERSION}} +Spark runs on Java 8+, Python 2.7+/3.4+ and R 3.1+. R prior to version 3.4 support is deprecated as of Spark 3.0.0. +For the Scala API, Spark {{site.SPARK_VERSION}} uses Scala {{site.SCALA_BINARY_VERSION}}. You will need to use a compatible Scala version ({{site.SCALA_BINARY_VERSION}}.x). diff --git a/docs/ml-classification-regression.md b/docs/ml-classification-regression.md index b3d109039da4d..42912a2e2bc31 100644 --- a/docs/ml-classification-regression.md +++ b/docs/ml-classification-regression.md @@ -941,9 +941,9 @@ Essentially isotonic regression is a best fitting the original data points. We implement a -[pool adjacent violators algorithm](http://doi.org/10.1198/TECH.2010.10111) +[pool adjacent violators algorithm](https://doi.org/10.1198/TECH.2010.10111) which uses an approach to -[parallelizing isotonic regression](http://doi.org/10.1007/978-3-642-99789-1_10). +[parallelizing isotonic regression](https://doi.org/10.1007/978-3-642-99789-1_10). The training input is a DataFrame which contains three columns label, features and weight. Additionally, IsotonicRegression algorithm has one optional parameter called $isotonic$ defaulting to true. diff --git a/docs/ml-collaborative-filtering.md b/docs/ml-collaborative-filtering.md index 8b0f287dc39ad..58646642bfbcc 100644 --- a/docs/ml-collaborative-filtering.md +++ b/docs/ml-collaborative-filtering.md @@ -41,7 +41,7 @@ for example, users giving ratings to movies. It is common in many real-world use cases to only have access to *implicit feedback* (e.g. views, clicks, purchases, likes, shares etc.). The approach used in `spark.ml` to deal with such data is taken -from [Collaborative Filtering for Implicit Feedback Datasets](http://dx.doi.org/10.1109/ICDM.2008.22). +from [Collaborative Filtering for Implicit Feedback Datasets](https://doi.org/10.1109/ICDM.2008.22). Essentially, instead of trying to model the matrix of ratings directly, this approach treats the data as numbers representing the *strength* in observations of user actions (such as the number of clicks, or the cumulative duration someone spent viewing a movie). Those numbers are then related to the level of @@ -55,7 +55,7 @@ We scale the regularization parameter `regParam` in solving each least squares p the number of ratings the user generated in updating user factors, or the number of ratings the product received in updating product factors. This approach is named "ALS-WR" and discussed in the paper -"[Large-Scale Parallel Collaborative Filtering for the Netflix Prize](http://dx.doi.org/10.1007/978-3-540-68880-8_32)". +"[Large-Scale Parallel Collaborative Filtering for the Netflix Prize](https://doi.org/10.1007/978-3-540-68880-8_32)". 
It makes `regParam` less dependent on the scale of the dataset, so we can apply the best parameter learned from a sampled subset to the full dataset and expect similar performance. diff --git a/docs/ml-frequent-pattern-mining.md b/docs/ml-frequent-pattern-mining.md index c2043d495c149..f613664271ec6 100644 --- a/docs/ml-frequent-pattern-mining.md +++ b/docs/ml-frequent-pattern-mining.md @@ -18,7 +18,7 @@ for more information. ## FP-Growth The FP-growth algorithm is described in the paper -[Han et al., Mining frequent patterns without candidate generation](http://dx.doi.org/10.1145/335191.335372), +[Han et al., Mining frequent patterns without candidate generation](https://doi.org/10.1145/335191.335372), where "FP" stands for frequent pattern. Given a dataset of transactions, the first step of FP-growth is to calculate item frequencies and identify frequent items. Different from [Apriori-like](http://en.wikipedia.org/wiki/Apriori_algorithm) algorithms designed for the same purpose, @@ -26,7 +26,7 @@ the second step of FP-growth uses a suffix tree (FP-tree) structure to encode tr explicitly, which are usually expensive to generate. After the second step, the frequent itemsets can be extracted from the FP-tree. In `spark.mllib`, we implemented a parallel version of FP-growth called PFP, -as described in [Li et al., PFP: Parallel FP-growth for query recommendation](http://dx.doi.org/10.1145/1454008.1454027). +as described in [Li et al., PFP: Parallel FP-growth for query recommendation](https://doi.org/10.1145/1454008.1454027). PFP distributes the work of growing FP-trees based on the suffixes of transactions, and hence is more scalable than a single-machine implementation. We refer users to the papers for more details. @@ -90,7 +90,7 @@ Refer to the [R API docs](api/R/spark.fpGrowth.html) for more details. PrefixSpan is a sequential pattern mining algorithm described in [Pei et al., Mining Sequential Patterns by Pattern-Growth: The -PrefixSpan Approach](http://dx.doi.org/10.1109%2FTKDE.2004.77). We refer +PrefixSpan Approach](https://doi.org/10.1109%2FTKDE.2004.77). We refer the reader to the referenced paper for formalizing the sequential pattern mining problem. @@ -137,4 +137,4 @@ Refer to the [R API docs](api/R/spark.prefixSpan.html) for more details. {% include_example r/ml/prefixSpan.R %} - \ No newline at end of file + diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md index b2300028e151b..aeebb26bb45f3 100644 --- a/docs/mllib-collaborative-filtering.md +++ b/docs/mllib-collaborative-filtering.md @@ -37,7 +37,7 @@ for example, users giving ratings to movies. It is common in many real-world use cases to only have access to *implicit feedback* (e.g. views, clicks, purchases, likes, shares etc.). The approach used in `spark.mllib` to deal with such data is taken -from [Collaborative Filtering for Implicit Feedback Datasets](http://dx.doi.org/10.1109/ICDM.2008.22). +from [Collaborative Filtering for Implicit Feedback Datasets](https://doi.org/10.1109/ICDM.2008.22). Essentially, instead of trying to model the matrix of ratings directly, this approach treats the data as numbers representing the *strength* in observations of user actions (such as the number of clicks, or the cumulative duration someone spent viewing a movie). 
Those numbers are then related to the level of @@ -51,7 +51,7 @@ Since v1.1, we scale the regularization parameter `lambda` in solving each least the number of ratings the user generated in updating user factors, or the number of ratings the product received in updating product factors. This approach is named "ALS-WR" and discussed in the paper -"[Large-Scale Parallel Collaborative Filtering for the Netflix Prize](http://dx.doi.org/10.1007/978-3-540-68880-8_32)". +"[Large-Scale Parallel Collaborative Filtering for the Netflix Prize](https://doi.org/10.1007/978-3-540-68880-8_32)". It makes `lambda` less dependent on the scale of the dataset, so we can apply the best parameter learned from a sampled subset to the full dataset and expect similar performance. diff --git a/docs/mllib-frequent-pattern-mining.md b/docs/mllib-frequent-pattern-mining.md index 0d3192c6b1d9c..8e4505756b275 100644 --- a/docs/mllib-frequent-pattern-mining.md +++ b/docs/mllib-frequent-pattern-mining.md @@ -15,7 +15,7 @@ a popular algorithm to mining frequent itemsets. ## FP-growth The FP-growth algorithm is described in the paper -[Han et al., Mining frequent patterns without candidate generation](http://dx.doi.org/10.1145/335191.335372), +[Han et al., Mining frequent patterns without candidate generation](https://doi.org/10.1145/335191.335372), where "FP" stands for frequent pattern. Given a dataset of transactions, the first step of FP-growth is to calculate item frequencies and identify frequent items. Different from [Apriori-like](http://en.wikipedia.org/wiki/Apriori_algorithm) algorithms designed for the same purpose, @@ -23,7 +23,7 @@ the second step of FP-growth uses a suffix tree (FP-tree) structure to encode tr explicitly, which are usually expensive to generate. After the second step, the frequent itemsets can be extracted from the FP-tree. In `spark.mllib`, we implemented a parallel version of FP-growth called PFP, -as described in [Li et al., PFP: Parallel FP-growth for query recommendation](http://dx.doi.org/10.1145/1454008.1454027). +as described in [Li et al., PFP: Parallel FP-growth for query recommendation](https://doi.org/10.1145/1454008.1454027). PFP distributes the work of growing FP-trees based on the suffixes of transactions, and hence more scalable than a single-machine implementation. We refer users to the papers for more details. @@ -122,7 +122,7 @@ Refer to the [`AssociationRules` Java docs](api/java/org/apache/spark/mllib/fpm/ PrefixSpan is a sequential pattern mining algorithm described in [Pei et al., Mining Sequential Patterns by Pattern-Growth: The -PrefixSpan Approach](http://dx.doi.org/10.1109%2FTKDE.2004.77). We refer +PrefixSpan Approach](https://doi.org/10.1109%2FTKDE.2004.77). We refer the reader to the referenced paper for formalizing the sequential pattern mining problem. diff --git a/docs/mllib-isotonic-regression.md b/docs/mllib-isotonic-regression.md index 99cab98c690c6..9964fce3273be 100644 --- a/docs/mllib-isotonic-regression.md +++ b/docs/mllib-isotonic-regression.md @@ -24,9 +24,9 @@ Essentially isotonic regression is a best fitting the original data points. `spark.mllib` supports a -[pool adjacent violators algorithm](http://doi.org/10.1198/TECH.2010.10111) +[pool adjacent violators algorithm](https://doi.org/10.1198/TECH.2010.10111) which uses an approach to -[parallelizing isotonic regression](http://doi.org/10.1007/978-3-642-99789-1_10). +[parallelizing isotonic regression](https://doi.org/10.1007/978-3-642-99789-1_10). 
The training input is an RDD of tuples of three double values that represent label, feature and weight in this order. Additionally, IsotonicRegression algorithm has one optional parameter called $isotonic$ defaulting to true. diff --git a/docs/monitoring.md b/docs/monitoring.md index 69bf3082f0f27..6bb620a2e5f69 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -202,6 +202,28 @@ Security options for the Spark History Server are covered more detail in the applications that fail to rename their event logs listed as in-progress. + + spark.history.fs.driverlog.cleaner.enabled + spark.history.fs.cleaner.enabled + + Specifies whether the History Server should periodically clean up driver logs from storage. + + + + spark.history.fs.driverlog.cleaner.interval + spark.history.fs.cleaner.interval + + How often the filesystem driver log cleaner checks for files to delete. + Files are only deleted if they are older than spark.history.fs.driverlog.cleaner.maxAge + + + + spark.history.fs.driverlog.cleaner.maxAge + spark.history.fs.cleaner.maxAge + + Driver log files older than this will be deleted when the driver log cleaner runs. + + spark.history.fs.numReplayThreads 25% of available cores diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index 29a4d86fecb44..ec9514ea4134a 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -15,7 +15,19 @@ container images and entrypoints.** # Security Security in Spark is OFF by default. This could mean you are vulnerable to attack by default. -Please see [Spark Security](security.html) and the specific security sections in this doc before running Spark. +Please see [Spark Security](security.html) and the specific advice below before running Spark. + +## User Identity + +Images built from the project provided Dockerfiles do not contain any [`USER`](https://docs.docker.com/engine/reference/builder/#user) directives. This means that the resulting images will be running the Spark processes as `root` inside the container. On unsecured clusters this may provide an attack vector for privilege escalation and container breakout. Therefore security conscious deployments should consider providing custom images with `USER` directives specifying an unprivileged UID and GID. + +Alternatively the [Pod Template](#pod-template) feature can be used to add a [Security Context](https://kubernetes.io/docs/tasks/configure-pod-container/security-context/#volumes-and-file-systems) with a `runAsUser` to the pods that Spark submits. Please bear in mind that this requires cooperation from your users and as such may not be a suitable solution for shared environments. Cluster administrators should use [Pod Security Policies](https://kubernetes.io/docs/concepts/policy/pod-security-policy/#users-and-groups) if they wish to limit the users that pods may run as. + +## Volume Mounts + +As described later in this document under [Using Kubernetes Volumes](#using-kubernetes-volumes) Spark on K8S provides configuration options that allow for mounting certain volume types into the driver and executor pods. In particular it allows for [`hostPath`](https://kubernetes.io/docs/concepts/storage/volumes/#hostpath) volumes which as described in the Kubernetes documentation have known security vulnerabilities. + +Cluster administrators should use [Pod Security Policies](https://kubernetes.io/docs/concepts/policy/pod-security-policy/) to limit the ability to mount `hostPath` volumes appropriately for their environments. 
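A hedged sketch (not part of this patch) of how the volume-mount settings this security note points at are passed to `spark-submit`, using the `persistentVolumeClaim` example and the newly documented `.mount.subPath` option from further down in this file; the API server address and the `checkpointpvc` name follow the doc's own placeholders:

```python
# Assemble the --conf flags for mounting a persistentVolumeClaim into the driver pod.
# All values are placeholders taken from the documentation examples, not real settings.
volume = "spark.kubernetes.driver.volumes.persistentVolumeClaim.checkpointpvc"
confs = {
    f"{volume}.mount.path": "/checkpoint",
    f"{volume}.mount.subPath": "checkpoint",   # newly documented subPath option
    f"{volume}.mount.readOnly": "false",
}

args = [
    "spark-submit",
    "--master", "k8s://https://<k8s-apiserver-host>:<k8s-apiserver-port>",
    "--deploy-mode", "cluster",
]
for key, value in confs.items():
    args += ["--conf", f"{key}={value}"]

print(" ".join(args))
```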
# Prerequisites @@ -76,6 +88,18 @@ $ ./bin/docker-image-tool.sh -r -t my-tag build $ ./bin/docker-image-tool.sh -r -t my-tag push ``` +By default `bin/docker-image-tool.sh` builds docker image for running JVM jobs. You need to opt-in to build additional +language binding docker images. + +Example usage is +```bash +# To build additional PySpark docker image +$ ./bin/docker-image-tool.sh -r -t my-tag -p ./kubernetes/dockerfiles/spark/bindings/python/Dockerfile build + +# To build additional SparkR docker image +$ ./bin/docker-image-tool.sh -r -t my-tag -R ./kubernetes/dockerfiles/spark/bindings/R/Dockerfile build +``` + ## Cluster Mode To launch Spark Pi in cluster mode, @@ -142,7 +166,7 @@ hostname via `spark.driver.host` and your spark driver's port to `spark.driver.p ### Client Mode Executor Pod Garbage Collection -If you run your Spark driver in a pod, it is highly recommended to set `spark.driver.pod.name` to the name of that pod. +If you run your Spark driver in a pod, it is highly recommended to set `spark.kubernetes.driver.pod.name` to the name of that pod. When this property is set, the Spark scheduler will deploy the executor pods with an [OwnerReference](https://kubernetes.io/docs/concepts/workloads/controllers/garbage-collection/), which in turn will ensure that once the driver pod is deleted from the cluster, all of the application's executor pods will also be deleted. @@ -151,7 +175,7 @@ an OwnerReference pointing to that pod will be added to each executor pod's Owne setting the OwnerReference to a pod that is not actually that driver pod, or else the executors may be terminated prematurely when the wrong pod is deleted. -If your application is not running inside a pod, or if `spark.driver.pod.name` is not set when your application is +If your application is not running inside a pod, or if `spark.kubernetes.driver.pod.name` is not set when your application is actually running in a pod, keep in mind that the executor pods may not be properly deleted from the cluster when the application exits. The Spark scheduler attempts to delete these pods, but if the network request to the API server fails for any reason, these pods will remain in the cluster. The executor processes should exit when they cannot reach the @@ -214,11 +238,14 @@ Starting with Spark 2.4.0, users can mount the following types of Kubernetes [vo * [emptyDir](https://kubernetes.io/docs/concepts/storage/volumes/#emptydir): an initially empty volume created when a pod is assigned to a node. * [persistentVolumeClaim](https://kubernetes.io/docs/concepts/storage/volumes/#persistentvolumeclaim): used to mount a `PersistentVolume` into a pod. +**NB:** Please see the [Security](#security) section of this document for security issues related to volume mounts. + To mount a volume of any of the types above into the driver pod, use the following configuration property: ``` --conf spark.kubernetes.driver.volumes.[VolumeType].[VolumeName].mount.path= --conf spark.kubernetes.driver.volumes.[VolumeType].[VolumeName].mount.readOnly= +--conf spark.kubernetes.driver.volumes.[VolumeType].[VolumeName].mount.subPath= ``` Specifically, `VolumeType` can be one of the following values: `hostPath`, `emptyDir`, and `persistentVolumeClaim`. `VolumeName` is the name you want to use for the volume under the `volumes` field in the pod specification. @@ -780,6 +807,14 @@ specific to Spark on Kubernetes. spark.kubernetes.driver.volumes.persistentVolumeClaim.checkpointpvc.mount.path=/checkpoint. 
+ + spark.kubernetes.driver.volumes.[VolumeType].[VolumeName].mount.subPath + (none) + + Specifies a subpath to be mounted from the volume into the driver pod. + spark.kubernetes.driver.volumes.persistentVolumeClaim.checkpointpvc.mount.subPath=checkpoint. + + spark.kubernetes.driver.volumes.[VolumeType].[VolumeName].mount.readOnly (none) @@ -804,6 +839,14 @@ specific to Spark on Kubernetes. spark.kubernetes.executor.volumes.persistentVolumeClaim.checkpointpvc.mount.path=/checkpoint. + + spark.kubernetes.executor.volumes.[VolumeType].[VolumeName].mount.subPath + (none) + + Specifies a subpath to be mounted from the volume into the executor pod. + spark.kubernetes.executor.volumes.persistentVolumeClaim.checkpointpvc.mount.subPath=checkpoint. + + spark.kubernetes.executor.volumes.[VolumeType].[VolumeName].mount.readOnly false diff --git a/docs/security.md b/docs/security.md index 2f7fa9c6179f4..02d581c6dad91 100644 --- a/docs/security.md +++ b/docs/security.md @@ -821,3 +821,14 @@ should correspond to the super user who is running the Spark History Server. This will allow all users to write to the directory but will prevent unprivileged users from reading, removing or renaming a file unless they own it. The event log files will be created by Spark with permissions such that only the user and group have read and write access. + +# Persisting driver logs in client mode + +If your applications persist driver logs in client mode by enabling `spark.driver.log.persistToDfs.enabled`, +the directory where the driver logs go (`spark.driver.log.dfsDir`) should be manually created with proper +permissions. To secure the log files, the directory permissions should be set to `drwxrwxrwxt`. The owner +and group of the directory should correspond to the super user who is running the Spark History Server. + +This will allow all users to write to the directory but will prevent unprivileged users from +reading, removing or renaming a file unless they own it. The driver log files will be created by +Spark with permissions such that only the user and group have read and write access. diff --git a/docs/sparkr.md b/docs/sparkr.md index f84ec504b595a..5972435a0e409 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -133,7 +133,7 @@ specifying `--packages` with `spark-submit` or `sparkR` commands, or if initiali
    {% highlight r %} -sparkR.session(sparkPackages = "com.databricks:spark-avro_2.11:3.0.0") +sparkR.session(sparkPackages = "org.apache.spark:spark-avro_{{site.SCALA_BINARY_VERSION}}:{{site.SPARK_VERSION}}") {% endhighlight %}
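A hedged PySpark analogue of the SparkR snippet above (not part of this patch): the same switch from the old `com.databricks` artifact to the built-in `org.apache.spark:spark-avro` coordinates, here attached through `spark.jars.packages`. The literal `2.12`/`3.0.0` stand in for the templated `{{site.SCALA_BINARY_VERSION}}`/`{{site.SPARK_VERSION}}` values, and the input path is a placeholder:

```python
from pyspark.sql import SparkSession

# Pull the built-in Avro data source as a package; coordinates mirror the doc change above.
spark = (
    SparkSession.builder
    .appName("avro-package-example")
    .config("spark.jars.packages", "org.apache.spark:spark-avro_2.12:3.0.0")
    .getOrCreate()
)

df = spark.read.format("avro").load("/tmp/example.avro")  # hypothetical input path
df.show()
```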
    diff --git a/docs/sql-data-sources-hive-tables.md b/docs/sql-data-sources-hive-tables.md index 687e6f8e0a7cc..28e1a39626666 100644 --- a/docs/sql-data-sources-hive-tables.md +++ b/docs/sql-data-sources-hive-tables.md @@ -115,7 +115,7 @@ The following options can be used to configure the version of Hive that is used 1.2.1 Version of the Hive metastore. Available - options are 0.12.0 through 2.3.3. + options are 0.12.0 through 2.3.4. diff --git a/docs/sql-migration-guide-hive-compatibility.md b/docs/sql-migration-guide-hive-compatibility.md index 94849418030ef..dd7b06225714f 100644 --- a/docs/sql-migration-guide-hive-compatibility.md +++ b/docs/sql-migration-guide-hive-compatibility.md @@ -10,7 +10,7 @@ displayTitle: Compatibility with Apache Hive Spark SQL is designed to be compatible with the Hive Metastore, SerDes and UDFs. Currently, Hive SerDes and UDFs are based on Hive 1.2.1, and Spark SQL can be connected to different versions of Hive Metastore -(from 0.12.0 to 2.3.3. Also see [Interacting with Different Versions of Hive Metastore](sql-data-sources-hive-tables.html#interacting-with-different-versions-of-hive-metastore)). +(from 0.12.0 to 2.3.4. Also see [Interacting with Different Versions of Hive Metastore](sql-data-sources-hive-tables.html#interacting-with-different-versions-of-hive-metastore)). #### Deploying in Existing Hive Warehouses diff --git a/docs/sql-migration-guide-upgrade.md b/docs/sql-migration-guide-upgrade.md index 80f113de62ccc..25cd541190919 100644 --- a/docs/sql-migration-guide-upgrade.md +++ b/docs/sql-migration-guide-upgrade.md @@ -17,8 +17,16 @@ displayTitle: Spark SQL Upgrading Guide - Since Spark 3.0, the `from_json` functions supports two modes - `PERMISSIVE` and `FAILFAST`. The modes can be set via the `mode` option. The default mode became `PERMISSIVE`. In previous versions, behavior of `from_json` did not conform to either `PERMISSIVE` nor `FAILFAST`, especially in processing of malformed JSON records. For example, the JSON string `{"a" 1}` with the schema `a INT` is converted to `null` by previous versions but Spark 3.0 converts it to `Row(null)`. + - In Spark version 2.4 and earlier, the `from_json` function produces `null`s for JSON strings and JSON datasource skips the same independetly of its mode if there is no valid root JSON token in its input (` ` for example). Since Spark 3.0, such input is treated as a bad record and handled according to specified mode. For example, in the `PERMISSIVE` mode the ` ` input is converted to `Row(null, null)` if specified schema is `key STRING, value INT`. + - The `ADD JAR` command previously returned a result set with the single value 0. It now returns an empty result set. + - In Spark version 2.4 and earlier, users can create map values with map type key via built-in function like `CreateMap`, `MapFromArrays`, etc. Since Spark 3.0, it's not allowed to create map values with map type key with these built-in functions. Users can still read map values with map type key from data source or Java/Scala collections, though they are not very useful. + + - In Spark version 2.4 and earlier, `Dataset.groupByKey` results to a grouped dataset with key attribute wrongly named as "value", if the key is non-struct type, e.g. int, string, array, etc. This is counterintuitive and makes the schema of aggregation queries weird. For example, the schema of `ds.groupByKey(...).count()` is `(value, count)`. Since Spark 3.0, we name the grouping attribute to "key". 
The old behaviour is preserved under a newly added configuration `spark.sql.legacy.dataset.nameNonStructGroupingKeyAsValue` with a default value of `false`. + + - In Spark version 2.4 and earlier, float/double -0.0 is semantically equal to 0.0, but users can still distinguish them via `Dataset.show`, `Dataset.collect` etc. Since Spark 3.0, float/double -0.0 is replaced by 0.0 internally, and users can't distinguish them any more. + ## Upgrading From Spark SQL 2.3 to 2.4 - In Spark version 2.3 and earlier, the second parameter to array_contains function is implicitly promoted to the element type of first array type parameter. This type promotion can be lossy and may cause `array_contains` function to return wrong result. This problem has been addressed in 2.4 by employing a safer type promotion mechanism. This can cause some change in behavior and are illustrated in the table below. @@ -119,7 +127,7 @@ displayTitle: Spark SQL Upgrading Guide - Since Spark 2.4, Metadata files (e.g. Parquet summary files) and temporary files are not counted as data files when calculating table size during Statistics computation. - - Since Spark 2.4, empty strings are saved as quoted empty strings `""`. In version 2.3 and earlier, empty strings are equal to `null` values and do not reflect to any characters in saved CSV files. For example, the row of `"a", null, "", 1` was written as `a,,,1`. Since Spark 2.4, the same row is saved as `a,,"",1`. To restore the previous behavior, set the CSV option `emptyValue` to empty (not quoted) string. + - Since Spark 2.4, empty strings are saved as quoted empty strings `""`. In version 2.3 and earlier, empty strings are equal to `null` values and do not reflect to any characters in saved CSV files. For example, the row of `"a", null, "", 1` was written as `a,,,1`. Since Spark 2.4, the same row is saved as `a,,"",1`. To restore the previous behavior, set the CSV option `emptyValue` to empty (not quoted) string. - Since Spark 2.4, The LOAD DATA command supports wildcard `?` and `*`, which match any one character, and zero or more characters, respectively. Example: `LOAD DATA INPATH '/tmp/folder*/'` or `LOAD DATA INPATH '/tmp/part-?'`. Special Characters like `space` also now work in paths. Example: `LOAD DATA INPATH '/tmp/folder name/'`. @@ -305,7 +313,7 @@ displayTitle: Spark SQL Upgrading Guide ## Upgrading From Spark SQL 2.1 to 2.2 - Spark 2.1.1 introduced a new configuration key: `spark.sql.hive.caseSensitiveInferenceMode`. It had a default setting of `NEVER_INFER`, which kept behavior identical to 2.1.0. However, Spark 2.2.0 changes this setting's default value to `INFER_AND_SAVE` to restore compatibility with reading Hive metastore tables whose underlying file schema have mixed-case column names. With the `INFER_AND_SAVE` configuration value, on first access Spark will perform schema inference on any Hive metastore table for which it has not already saved an inferred schema. Note that schema inference can be a very time-consuming operation for tables with thousands of partitions. If compatibility with mixed-case column names is not a concern, you can safely set `spark.sql.hive.caseSensitiveInferenceMode` to `NEVER_INFER` to avoid the initial overhead of schema inference. Note that with the new default `INFER_AND_SAVE` setting, the results of the schema inference are saved as a metastore key for future use. Therefore, the initial schema inference occurs only at a table's first access. 
- + - Since Spark 2.2.1 and 2.3.0, the schema is always inferred at runtime when the data source tables have the columns that exist in both partition schema and data schema. The inferred schema does not have the partitioned columns. When reading the table, Spark respects the partition values of these overlapping columns instead of the values stored in the data source files. In 2.2.0 and 2.1.x release, the inferred schema is partitioned but the data of the table is invisible to users (i.e., the result set is empty). - Since Spark 2.2, view definitions are stored in a different way from prior versions. This may cause Spark unable to read views created by prior versions. In such cases, you need to recreate the views using `ALTER VIEW AS` or `CREATE OR REPLACE VIEW AS` with newer Spark versions. diff --git a/docs/structured-streaming-kafka-integration.md b/docs/structured-streaming-kafka-integration.md index 71fd5b10cc407..a549ce2a6a05f 100644 --- a/docs/structured-streaming-kafka-integration.md +++ b/docs/structured-streaming-kafka-integration.md @@ -123,7 +123,7 @@ df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") -### Creating a Kafka Source for Batch Queries +### Creating a Kafka Source for Batch Queries If you have a use case that is better suited to batch processing, you can create a Dataset/DataFrame for a defined range of offsets. @@ -374,17 +374,24 @@ The following configurations are optional: streaming and batch Rate limit on maximum number of offsets processed per trigger interval. The specified total number of offsets will be proportionally split across topicPartitions of different volume. + + groupIdPrefix + string + spark-kafka-source + streaming and batch + Prefix of consumer group identifiers (`group.id`) that are generated by structured streaming queries + ## Writing Data to Kafka -Here, we describe the support for writing Streaming Queries and Batch Queries to Apache Kafka. Take note that +Here, we describe the support for writing Streaming Queries and Batch Queries to Apache Kafka. Take note that Apache Kafka only supports at least once write semantics. Consequently, when writing---either Streaming Queries or Batch Queries---to Kafka, some records may be duplicated; this can happen, for example, if Kafka needs to retry a message that was not acknowledged by a Broker, even though that Broker received and wrote the message record. -Structured Streaming cannot prevent such duplicates from occurring due to these Kafka write semantics. However, +Structured Streaming cannot prevent such duplicates from occurring due to these Kafka write semantics. However, if writing the query is successful, then you can assume that the query output was written at least once. A possible -solution to remove duplicates when reading the written data could be to introduce a primary (unique) key +solution to remove duplicates when reading the written data could be to introduce a primary (unique) key that can be used to perform de-duplication when reading. The Dataframe being written to Kafka should have the following columns in schema: @@ -405,8 +412,8 @@ The Dataframe being written to Kafka should have the following columns in schema \* The topic column is required if the "topic" configuration option is not specified.
    -The value column is the only required option. If a key column is not specified then -a ```null``` valued key column will be automatically added (see Kafka semantics on +The value column is the only required option. If a key column is not specified then +a ```null``` valued key column will be automatically added (see Kafka semantics on how ```null``` valued key values are handled). If a topic column exists then its value is used as the topic when writing the given row to Kafka, unless the "topic" configuration option is set i.e., the "topic" configuration option overrides the topic column. @@ -568,7 +575,7 @@ df.selectExpr("topic", "CAST(key AS STRING)", "CAST(value AS STRING)") \ .format("kafka") \ .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ .save() - + {% endhighlight %} @@ -576,23 +583,25 @@ df.selectExpr("topic", "CAST(key AS STRING)", "CAST(value AS STRING)") \ ## Kafka Specific Configurations -Kafka's own configurations can be set via `DataStreamReader.option` with `kafka.` prefix, e.g, -`stream.option("kafka.bootstrap.servers", "host:port")`. For possible kafka parameters, see +Kafka's own configurations can be set via `DataStreamReader.option` with `kafka.` prefix, e.g, +`stream.option("kafka.bootstrap.servers", "host:port")`. For possible kafka parameters, see [Kafka consumer config docs](http://kafka.apache.org/documentation.html#newconsumerconfigs) for parameters related to reading data, and [Kafka producer config docs](http://kafka.apache.org/documentation/#producerconfigs) for parameters related to writing data. Note that the following Kafka params cannot be set and the Kafka source or sink will throw an exception: -- **group.id**: Kafka source will create a unique group id for each query automatically. +- **group.id**: Kafka source will create a unique group id for each query automatically. The user can +set the prefix of the automatically generated group.id's via the optional source option `groupIdPrefix`, default value +is "spark-kafka-source". - **auto.offset.reset**: Set the source option `startingOffsets` to specify - where to start instead. Structured Streaming manages which offsets are consumed internally, rather - than rely on the kafka Consumer to do it. This will ensure that no data is missed when new + where to start instead. Structured Streaming manages which offsets are consumed internally, rather + than rely on the kafka Consumer to do it. This will ensure that no data is missed when new topics/partitions are dynamically subscribed. Note that `startingOffsets` only applies when a new streaming query is started, and that resuming will always pick up from where the query left off. -- **key.deserializer**: Keys are always deserialized as byte arrays with ByteArrayDeserializer. Use +- **key.deserializer**: Keys are always deserialized as byte arrays with ByteArrayDeserializer. Use DataFrame operations to explicitly deserialize the keys. -- **value.deserializer**: Values are always deserialized as byte arrays with ByteArrayDeserializer. +- **value.deserializer**: Values are always deserialized as byte arrays with ByteArrayDeserializer. Use DataFrame operations to explicitly deserialize the values. - **key.serializer**: Keys are always serialized with ByteArraySerializer or StringSerializer. Use DataFrame operations to explicitly serialize the keys into either strings or byte arrays. 
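A hedged sketch of the `groupIdPrefix` option documented above (not part of this patch): a streaming Kafka source whose auto-generated consumer group ids get a custom prefix instead of the default `spark-kafka-source`; brokers, topic, and the prefix itself are placeholders:

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("kafka-groupid-prefix").getOrCreate()

df = (
    spark.readStream
    .format("kafka")
    .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
    .option("subscribe", "topic1")
    .option("groupIdPrefix", "my-pipeline")   # overrides the "spark-kafka-source" default
    .load()
)

# Keys and values arrive as byte arrays; deserialize explicitly, as the doc notes.
query = (
    df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
    .writeStream
    .format("console")
    .start()
)
query.awaitTermination()
```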
diff --git a/examples/pom.xml b/examples/pom.xml index e9b1d4be33974..fbc3ec40359f7 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -20,12 +20,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../pom.xml - spark-examples_2.11 + spark-examples_2.12 jar Spark Project Examples http://spark.apache.org/ diff --git a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala index c55b68e033964..03187aee044e4 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala @@ -32,13 +32,13 @@ object LogQuery { | GTB7.4; .NET CLR 2.0.50727; .NET CLR 3.0.04506.30; .NET CLR 3.0.04506.648; .NET CLR | 3.5.21022; .NET CLR 3.0.4506.2152; .NET CLR 1.0.3705; .NET CLR 1.1.4322; .NET CLR | 3.5.30729; Release=ARP)" "UD-1" - "image/jpeg" "whatever" 0.350 "-" - "" 265 923 934 "" - | 62.24.11.25 images.com 1358492167 - Whatup""".stripMargin.lines.mkString, + | 62.24.11.25 images.com 1358492167 - Whatup""".stripMargin.split('\n').mkString, """10.10.10.10 - "FRED" [18/Jan/2013:18:02:37 +1100] "GET http://images.com/2013/Generic.jpg | HTTP/1.1" 304 306 "http:/referall.com" "Mozilla/4.0 (compatible; MSIE 7.0; Windows NT 5.1; | GTB7.4; .NET CLR 2.0.50727; .NET CLR 3.0.04506.30; .NET CLR 3.0.04506.648; .NET CLR | 3.5.21022; .NET CLR 3.0.4506.2152; .NET CLR 1.0.3705; .NET CLR 1.1.4322; .NET CLR | 3.5.30729; Release=ARP)" "UD-1" - "image/jpeg" "whatever" 0.352 "-" - "" 256 977 988 "" - | 0 73.23.2.15 images.com 1358492557 - Whatup""".stripMargin.lines.mkString + | 0 73.23.2.15 images.com 1358492557 - Whatup""".stripMargin.split('\n').mkString ) def main(args: Array[String]) { diff --git a/external/avro/benchmarks/AvroReadBenchmark-results.txt b/external/avro/benchmarks/AvroReadBenchmark-results.txt new file mode 100644 index 0000000000000..7900fea453b10 --- /dev/null +++ b/external/avro/benchmarks/AvroReadBenchmark-results.txt @@ -0,0 +1,122 @@ +================================================================================================ +SQL Single Numeric Column Scan +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +SQL Single TINYINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Sum 2774 / 2815 5.7 176.4 1.0X + +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +SQL Single SMALLINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Sum 2761 / 2777 5.7 175.5 1.0X + +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +SQL Single INT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Sum 2783 / 2870 5.7 176.9 1.0X + +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +SQL Single BIGINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative 
+------------------------------------------------------------------------------------------------ +Sum 3256 / 3266 4.8 207.0 1.0X + +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +SQL Single FLOAT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Sum 2841 / 2867 5.5 180.6 1.0X + +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +SQL Single DOUBLE Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Sum 2981 / 2996 5.3 189.5 1.0X + + +================================================================================================ +Int and String Scan +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Int and String Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Sum of columns 4781 / 4783 2.2 456.0 1.0X + + +================================================================================================ +Partitioned Table Scan +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Partitioned Table: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Data column 3372 / 3386 4.7 214.4 1.0X +Partition column 3035 / 3064 5.2 193.0 1.1X +Both columns 3445 / 3461 4.6 219.1 1.0X + + +================================================================================================ +Repeated String Scan +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Repeated String: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Sum of string length 3395 / 3401 3.1 323.8 1.0X + + +================================================================================================ +String with Nulls Scan +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +String with Nulls Scan (0.0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Sum of string length 5580 / 5624 1.9 532.2 1.0X + +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +String with Nulls Scan (50.0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Sum of string length 4622 / 4623 2.3 440.8 1.0X + +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) 
CPU E5-2670 v2 @ 2.50GHz +String with Nulls Scan (95.0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Sum of string length 3238 / 3241 3.2 308.8 1.0X + + +================================================================================================ +Single Column Scan From Wide Columns +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Single Column Scan from 100 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Sum of single column 5472 / 5484 0.2 5218.8 1.0X + +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Single Column Scan from 200 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Sum of single column 10680 / 10701 0.1 10185.1 1.0X + +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 +Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz +Single Column Scan from 300 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Sum of single column 16143 / 16238 0.1 15394.9 1.0X + + diff --git a/external/avro/pom.xml b/external/avro/pom.xml index 9d8f319cc9396..ba6f20bfdbf58 100644 --- a/external/avro/pom.xml +++ b/external/avro/pom.xml @@ -20,12 +20,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-avro_2.11 + spark-avro_2.12 avro diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index 4fea2cb969446..207c54ce75f4c 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -138,7 +138,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("test NULL avro type") { withTempPath { dir => val fields = - Seq(new Field("null", Schema.create(Type.NULL), "doc", null)).asJava + Seq(new Field("null", Schema.create(Type.NULL), "doc", null.asInstanceOf[AnyVal])).asJava val schema = Schema.createRecord("name", "docs", "namespace", false) schema.setFields(fields) val datumWriter = new GenericDatumWriter[GenericRecord](schema) @@ -161,7 +161,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val avroSchema: Schema = { val union = Schema.createUnion(List(Schema.create(Type.INT), Schema.create(Type.LONG)).asJava) - val fields = Seq(new Field("field1", union, "doc", null)).asJava + val fields = Seq(new Field("field1", union, "doc", null.asInstanceOf[AnyVal])).asJava val schema = Schema.createRecord("name", "docs", "namespace", false) schema.setFields(fields) schema @@ -189,7 +189,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val avroSchema: Schema = { val union = Schema.createUnion(List(Schema.create(Type.FLOAT), Schema.create(Type.DOUBLE)).asJava) - val fields = Seq(new Field("field1", union, "doc", null)).asJava + val fields = Seq(new Field("field1", union, "doc", 
null.asInstanceOf[AnyVal])).asJava val schema = Schema.createRecord("name", "docs", "namespace", false) schema.setFields(fields) schema @@ -221,7 +221,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { Schema.create(Type.NULL) ).asJava ) - val fields = Seq(new Field("field1", union, "doc", null)).asJava + val fields = Seq(new Field("field1", union, "doc", null.asInstanceOf[AnyVal])).asJava val schema = Schema.createRecord("name", "docs", "namespace", false) schema.setFields(fields) schema @@ -247,7 +247,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("Union of a single type") { withTempPath { dir => val UnionOfOne = Schema.createUnion(List(Schema.create(Type.INT)).asJava) - val fields = Seq(new Field("field1", UnionOfOne, "doc", null)).asJava + val fields = Seq(new Field("field1", UnionOfOne, "doc", null.asInstanceOf[AnyVal])).asJava val schema = Schema.createRecord("name", "docs", "namespace", false) schema.setFields(fields) @@ -274,10 +274,10 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val complexUnionType = Schema.createUnion( List(Schema.create(Type.INT), Schema.create(Type.STRING), fixedSchema, enumSchema).asJava) val fields = Seq( - new Field("field1", complexUnionType, "doc", null), - new Field("field2", complexUnionType, "doc", null), - new Field("field3", complexUnionType, "doc", null), - new Field("field4", complexUnionType, "doc", null) + new Field("field1", complexUnionType, "doc", null.asInstanceOf[AnyVal]), + new Field("field2", complexUnionType, "doc", null.asInstanceOf[AnyVal]), + new Field("field3", complexUnionType, "doc", null.asInstanceOf[AnyVal]), + new Field("field4", complexUnionType, "doc", null.asInstanceOf[AnyVal]) ).asJava val schema = Schema.createRecord("name", "docs", "namespace", false) schema.setFields(fields) @@ -508,7 +508,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val union2 = spark.read.format("avro").load(testAvro).select("union_float_double").collect() assert( union2 - .map(x => new java.lang.Double(x(0).toString)) + .map(x => java.lang.Double.valueOf(x(0).toString)) .exists(p => Math.abs(p - Math.PI) < 0.001)) val fixed = spark.read.format("avro").load(testAvro).select("fixed3").collect() @@ -941,7 +941,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val avroArrayType = resolveNullable(Schema.createArray(avroType), nullable) val avroMapType = resolveNullable(Schema.createMap(avroType), nullable) val name = "foo" - val avroField = new Field(name, avroType, "", null) + val avroField = new Field(name, avroType, "", null.asInstanceOf[AnyVal]) val recordSchema = Schema.createRecord("name", "doc", "space", true, Seq(avroField).asJava) val avroRecordType = resolveNullable(recordSchema, nullable) diff --git a/external/avro/src/test/scala/org/apache/spark/sql/execution/benchmark/AvroReadBenchmark.scala b/external/avro/src/test/scala/org/apache/spark/sql/execution/benchmark/AvroReadBenchmark.scala new file mode 100644 index 0000000000000..f2f7d650066fb --- /dev/null +++ b/external/avro/src/test/scala/org/apache/spark/sql/execution/benchmark/AvroReadBenchmark.scala @@ -0,0 +1,216 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.benchmark + +import java.io.File + +import scala.util.Random + +import org.apache.spark.benchmark.Benchmark +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.catalyst.plans.SQLHelper +import org.apache.spark.sql.types._ + +/** + * Benchmark to measure Avro read performance. + * {{{ + * To run this benchmark: + * 1. without sbt: bin/spark-submit --class + * --jars ,, + * 2. build/sbt "avro/test:runMain " + * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "avro/test:runMain " + * Results will be written to "benchmarks/AvroReadBenchmark-results.txt". + * }}} + */ +object AvroReadBenchmark extends SqlBasedBenchmark with SQLHelper { + def withTempTable(tableNames: String*)(f: => Unit): Unit = { + try f finally tableNames.foreach(spark.catalog.dropTempView) + } + + private def prepareTable(dir: File, df: DataFrame, partition: Option[String] = None): Unit = { + val dirAvro = dir.getCanonicalPath + + if (partition.isDefined) { + df.write.partitionBy(partition.get).format("avro").save(dirAvro) + } else { + df.write.format("avro").save(dirAvro) + } + + spark.read.format("avro").load(dirAvro).createOrReplaceTempView("avroTable") + } + + def numericScanBenchmark(values: Int, dataType: DataType): Unit = { + val benchmark = + new Benchmark(s"SQL Single ${dataType.sql} Column Scan", values, output = output) + + withTempPath { dir => + withTempTable("t1", "avroTable") { + import spark.implicits._ + spark.range(values).map(_ => Random.nextLong).createOrReplaceTempView("t1") + + prepareTable(dir, spark.sql(s"SELECT CAST(value as ${dataType.sql}) id FROM t1")) + + benchmark.addCase("Sum") { _ => + spark.sql("SELECT sum(id) FROM avroTable").collect() + } + + benchmark.run() + } + } + } + + def intStringScanBenchmark(values: Int): Unit = { + val benchmark = new Benchmark("Int and String Scan", values, output = output) + + withTempPath { dir => + withTempTable("t1", "avroTable") { + import spark.implicits._ + spark.range(values).map(_ => Random.nextLong).createOrReplaceTempView("t1") + + prepareTable( + dir, + spark.sql("SELECT CAST(value AS INT) AS c1, CAST(value as STRING) AS c2 FROM t1")) + + benchmark.addCase("Sum of columns") { _ => + spark.sql("SELECT sum(c1), sum(length(c2)) FROM avroTable").collect() + } + + benchmark.run() + } + } + } + + def partitionTableScanBenchmark(values: Int): Unit = { + val benchmark = new Benchmark("Partitioned Table", values, output = output) + + withTempPath { dir => + withTempTable("t1", "avroTable") { + import spark.implicits._ + spark.range(values).map(_ => Random.nextLong).createOrReplaceTempView("t1") + + prepareTable(dir, spark.sql("SELECT value % 2 AS p, value AS id FROM t1"), Some("p")) + + benchmark.addCase("Data column") { _ => + spark.sql("SELECT sum(id) FROM avroTable").collect() + } + + benchmark.addCase("Partition column") { _ => + spark.sql("SELECT sum(p) FROM avroTable").collect() + } + + 
benchmark.addCase("Both columns") { _ => + spark.sql("SELECT sum(p), sum(id) FROM avroTable").collect() + } + + benchmark.run() + } + } + } + + def repeatedStringScanBenchmark(values: Int): Unit = { + val benchmark = new Benchmark("Repeated String", values, output = output) + + withTempPath { dir => + withTempTable("t1", "avroTable") { + spark.range(values).createOrReplaceTempView("t1") + + prepareTable(dir, spark.sql("SELECT CAST((id % 200) + 10000 as STRING) AS c1 FROM t1")) + + benchmark.addCase("Sum of string length") { _ => + spark.sql("SELECT sum(length(c1)) FROM avroTable").collect() + } + + benchmark.run() + } + } + } + + def stringWithNullsScanBenchmark(values: Int, fractionOfNulls: Double): Unit = { + withTempPath { dir => + withTempTable("t1", "avroTable") { + spark.range(values).createOrReplaceTempView("t1") + + prepareTable( + dir, + spark.sql( + s"SELECT IF(RAND(1) < $fractionOfNulls, NULL, CAST(id as STRING)) AS c1, " + + s"IF(RAND(2) < $fractionOfNulls, NULL, CAST(id as STRING)) AS c2 FROM t1")) + + val percentageOfNulls = fractionOfNulls * 100 + val benchmark = + new Benchmark(s"String with Nulls Scan ($percentageOfNulls%)", values, output = output) + + benchmark.addCase("Sum of string length") { _ => + spark.sql("SELECT SUM(LENGTH(c2)) FROM avroTable " + + "WHERE c1 IS NOT NULL AND c2 IS NOT NULL").collect() + } + + benchmark.run() + } + } + } + + def columnsBenchmark(values: Int, width: Int): Unit = { + val benchmark = + new Benchmark(s"Single Column Scan from $width columns", values, output = output) + + withTempPath { dir => + withTempTable("t1", "avroTable") { + import spark.implicits._ + val middle = width / 2 + val selectExpr = (1 to width).map(i => s"value as c$i") + spark.range(values).map(_ => Random.nextLong).toDF() + .selectExpr(selectExpr: _*).createOrReplaceTempView("t1") + + prepareTable(dir, spark.sql("SELECT * FROM t1")) + + benchmark.addCase("Sum of single column") { _ => + spark.sql(s"SELECT sum(c$middle) FROM avroTable").collect() + } + + benchmark.run() + } + } + } + + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + runBenchmark("SQL Single Numeric Column Scan") { + Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType).foreach { dataType => + numericScanBenchmark(1024 * 1024 * 15, dataType) + } + } + runBenchmark("Int and String Scan") { + intStringScanBenchmark(1024 * 1024 * 10) + } + runBenchmark("Partitioned Table Scan") { + partitionTableScanBenchmark(1024 * 1024 * 15) + } + runBenchmark("Repeated String Scan") { + repeatedStringScanBenchmark(1024 * 1024 * 10) + } + runBenchmark("String with Nulls Scan") { + for (fractionOfNulls <- List(0.0, 0.50, 0.95)) { + stringWithNullsScanBenchmark(1024 * 1024 * 10, fractionOfNulls) + } + } + runBenchmark("Single Column Scan From Wide Columns") { + columnsBenchmark(1024 * 1024 * 1, 100) + columnsBenchmark(1024 * 1024 * 1, 200) + columnsBenchmark(1024 * 1024 * 1, 300) + } + } +} diff --git a/external/docker-integration-tests/pom.xml b/external/docker-integration-tests/pom.xml index f24254b698080..b39db7540b7d2 100644 --- a/external/docker-integration-tests/pom.xml +++ b/external/docker-integration-tests/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-docker-integration-tests_2.11 + spark-docker-integration-tests_2.12 jar Spark Project Docker Integration Tests http://spark.apache.org/ diff --git a/external/kafka-0-10-assembly/pom.xml b/external/kafka-0-10-assembly/pom.xml index 
4f9c3163b2408..f2dcf5d217a89 100644 --- a/external/kafka-0-10-assembly/pom.xml +++ b/external/kafka-0-10-assembly/pom.xml @@ -20,12 +20,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-streaming-kafka-0-10-assembly_2.11 + spark-streaming-kafka-0-10-assembly_2.12 jar Spark Integration for Kafka 0.10 Assembly http://spark.apache.org/ diff --git a/external/kafka-0-10-sql/pom.xml b/external/kafka-0-10-sql/pom.xml index efd0862fb58ee..1af407167597b 100644 --- a/external/kafka-0-10-sql/pom.xml +++ b/external/kafka-0-10-sql/pom.xml @@ -20,17 +20,17 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-sql-kafka-0-10_2.11 + spark-sql-kafka-0-10_2.12 sql-kafka-0-10 - 2.0.0 + 2.1.0 jar Kafka 0.10+ Source for Structured Streaming @@ -89,6 +89,13 @@
    + + + org.apache.zookeeper + zookeeper + 3.4.7 + test + net.sf.jopt-simple jopt-simple diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 28c9853bfea9c..f770f0c2a04c2 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -77,7 +77,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister // Each running query should use its own group id. Otherwise, the query may be only assigned // partial data since Kafka will assign partitions to multiple consumers having the same group // id. Hence, we should generate a unique id for each query. - val uniqueGroupId = s"spark-kafka-source-${UUID.randomUUID}-${metadataPath.hashCode}" + val uniqueGroupId = streamingUniqueGroupId(parameters, metadataPath) val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } val specifiedKafkaParams = @@ -119,7 +119,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister // Each running query should use its own group id. Otherwise, the query may be only assigned // partial data since Kafka will assign partitions to multiple consumers having the same group // id. Hence, we should generate a unique id for each query. - val uniqueGroupId = s"spark-kafka-source-${UUID.randomUUID}-${metadataPath.hashCode}" + val uniqueGroupId = streamingUniqueGroupId(parameters, metadataPath) val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } val specifiedKafkaParams = @@ -159,7 +159,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister // Each running query should use its own group id. Otherwise, the query may be only assigned // partial data since Kafka will assign partitions to multiple consumers having the same group // id. Hence, we should generate a unique id for each query. 
- val uniqueGroupId = s"spark-kafka-source-${UUID.randomUUID}-${metadataPath.hashCode}" + val uniqueGroupId = streamingUniqueGroupId(parameters, metadataPath) val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } val specifiedKafkaParams = @@ -510,7 +510,7 @@ private[kafka010] object KafkaSourceProvider extends Logging { .set(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false") // So that the driver does not pull too much data - .set(ConsumerConfig.MAX_POLL_RECORDS_CONFIG, new java.lang.Integer(1)) + .set(ConsumerConfig.MAX_POLL_RECORDS_CONFIG, java.lang.Integer.valueOf(1)) // If buffer config is not set, set it to reasonable value to work around // buffer issues (see KAFKA-3135) @@ -538,6 +538,18 @@ private[kafka010] object KafkaSourceProvider extends Logging { .setIfUnset(ConsumerConfig.RECEIVE_BUFFER_CONFIG, 65536: java.lang.Integer) .build() + /** + * Returns a unique consumer group (group.id), allowing the user to set the prefix of + * the consumer group + */ + private def streamingUniqueGroupId( + parameters: Map[String, String], + metadataPath: String): String = { + val groupIdPrefix = parameters + .getOrElse("groupIdPrefix", "spark-kafka-source") + s"${groupIdPrefix}-${UUID.randomUUID}-${metadataPath.hashCode}" + } + /** Class to conveniently update Kafka config params, while logging the changes */ private case class ConfigUpdater(module: String, kafkaParams: Map[String, String]) { private val map = new ju.HashMap[String, Object](kafkaParams.asJava) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala index fa6bdc20bd4f9..aa21f1271b817 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala @@ -56,7 +56,7 @@ trait KafkaContinuousTest extends KafkaSourceTest { } // Continuous processing tasks end asynchronously, so test that they actually end. 
- private val tasksEndedListener = new SparkListener() { + private class TasksEndedListener extends SparkListener { val activeTaskIdCount = new AtomicInteger(0) override def onTaskStart(start: SparkListenerTaskStart): Unit = { @@ -68,6 +68,8 @@ trait KafkaContinuousTest extends KafkaSourceTest { } } + private val tasksEndedListener = new TasksEndedListener() + override def beforeEach(): Unit = { super.beforeEach() spark.sparkContext.addSparkListener(tasksEndedListener) diff --git a/external/kafka-0-10/pom.xml b/external/kafka-0-10/pom.xml index f59f07265a0f4..ea18b7e035915 100644 --- a/external/kafka-0-10/pom.xml +++ b/external/kafka-0-10/pom.xml @@ -20,16 +20,16 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-streaming-kafka-0-10_2.11 + spark-streaming-kafka-0-10_2.12 streaming-kafka-0-10 - 2.0.0 + 2.1.0 jar Spark Integration for Kafka 0.10 @@ -74,6 +74,13 @@ + + + org.apache.zookeeper + zookeeper + 3.4.7 + test + net.sf.jopt-simple jopt-simple diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/ConsumerStrategy.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/ConsumerStrategy.scala index cf283a5c3e11e..07960d14b0bfc 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/ConsumerStrategy.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/ConsumerStrategy.scala @@ -228,7 +228,7 @@ object ConsumerStrategies { new Subscribe[K, V]( new ju.ArrayList(topics.asJavaCollection), new ju.HashMap[String, Object](kafkaParams.asJava), - new ju.HashMap[TopicPartition, jl.Long](offsets.mapValues(l => new jl.Long(l)).asJava)) + new ju.HashMap[TopicPartition, jl.Long](offsets.mapValues(jl.Long.valueOf).asJava)) } /** @@ -307,7 +307,7 @@ object ConsumerStrategies { new SubscribePattern[K, V]( pattern, new ju.HashMap[String, Object](kafkaParams.asJava), - new ju.HashMap[TopicPartition, jl.Long](offsets.mapValues(l => new jl.Long(l)).asJava)) + new ju.HashMap[TopicPartition, jl.Long](offsets.mapValues(jl.Long.valueOf).asJava)) } /** @@ -391,7 +391,7 @@ object ConsumerStrategies { new Assign[K, V]( new ju.ArrayList(topicPartitions.asJavaCollection), new ju.HashMap[String, Object](kafkaParams.asJava), - new ju.HashMap[TopicPartition, jl.Long](offsets.mapValues(l => new jl.Long(l)).asJava)) + new ju.HashMap[TopicPartition, jl.Long](offsets.mapValues(jl.Long.valueOf).asJava)) } /** diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala index ba4009ef08856..224f41a683955 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala @@ -70,7 +70,7 @@ private[spark] class DirectKafkaInputDStream[K, V]( @transient private var kc: Consumer[K, V] = null def consumer(): Consumer[K, V] = this.synchronized { if (null == kc) { - kc = consumerStrategy.onStart(currentOffsets.mapValues(l => new java.lang.Long(l)).asJava) + kc = consumerStrategy.onStart(currentOffsets.mapValues(l => java.lang.Long.valueOf(l)).asJava) } kc } diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaUtils.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaUtils.scala index 
64b6ef6c53b6d..2516b948f6650 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaUtils.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaUtils.scala @@ -56,7 +56,7 @@ object KafkaUtils extends Logging { ): RDD[ConsumerRecord[K, V]] = { val preferredHosts = locationStrategy match { case PreferBrokers => - throw new AssertionError( + throw new IllegalArgumentException( "If you want to prefer brokers, you must provide a mapping using PreferFixed " + "A single KafkaRDD does not have a driver consumer and cannot look up brokers for you.") case PreferConsistent => ju.Collections.emptyMap[TopicPartition, String]() diff --git a/external/kinesis-asl-assembly/pom.xml b/external/kinesis-asl-assembly/pom.xml index 0bf4c265939e7..0ce922349ea66 100644 --- a/external/kinesis-asl-assembly/pom.xml +++ b/external/kinesis-asl-assembly/pom.xml @@ -20,12 +20,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-streaming-kinesis-asl-assembly_2.11 + spark-streaming-kinesis-asl-assembly_2.12 jar Spark Project Kinesis Assembly http://spark.apache.org/ diff --git a/external/kinesis-asl/pom.xml b/external/kinesis-asl/pom.xml index 0aef25329db99..7d69764b77de7 100644 --- a/external/kinesis-asl/pom.xml +++ b/external/kinesis-asl/pom.xml @@ -19,13 +19,13 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-streaming-kinesis-asl_2.11 + spark-streaming-kinesis-asl_2.12 jar Spark Kinesis Integration diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala index 1ffec01df9f00..d4a428f45c110 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala @@ -22,7 +22,7 @@ import scala.reflect.ClassTag import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream import com.amazonaws.services.kinesis.model.Record -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Evolving import org.apache.spark.rdd.RDD import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.streaming.{Duration, StreamingContext, Time} @@ -84,14 +84,14 @@ private[kinesis] class KinesisInputDStream[T: ClassTag]( } } -@InterfaceStability.Evolving +@Evolving object KinesisInputDStream { /** * Builder for [[KinesisInputDStream]] instances. * * @since 2.2.0 */ - @InterfaceStability.Evolving + @Evolving class Builder { // Required params private var streamingContext: Option[StreamingContext] = None diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SparkAWSCredentials.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SparkAWSCredentials.scala index 9facfe8ff2b0f..dcb60b21d9851 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SparkAWSCredentials.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SparkAWSCredentials.scala @@ -14,13 +14,12 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.spark.streaming.kinesis -import scala.collection.JavaConverters._ +package org.apache.spark.streaming.kinesis import com.amazonaws.auth._ -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Evolving import org.apache.spark.internal.Logging /** @@ -84,14 +83,14 @@ private[kinesis] final case class STSCredentials( } } -@InterfaceStability.Evolving +@Evolving object SparkAWSCredentials { /** * Builder for [[SparkAWSCredentials]] instances. * * @since 2.2.0 */ - @InterfaceStability.Evolving + @Evolving class Builder { private var basicCreds: Option[BasicCredentials] = None private var stsCreds: Option[STSCredentials] = None diff --git a/external/spark-ganglia-lgpl/pom.xml b/external/spark-ganglia-lgpl/pom.xml index 35a55b70baf33..a23d255f9187c 100644 --- a/external/spark-ganglia-lgpl/pom.xml +++ b/external/spark-ganglia-lgpl/pom.xml @@ -19,13 +19,13 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-ganglia-lgpl_2.11 + spark-ganglia-lgpl_2.12 jar Spark Ganglia Integration diff --git a/graphx/pom.xml b/graphx/pom.xml index 22bc148e068a5..444568a03d6c7 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -20,12 +20,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../pom.xml - spark-graphx_2.11 + spark-graphx_2.12 graphx diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala index a4e293d74a012..184b96426fa9b 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala @@ -117,13 +117,11 @@ class ShippableVertexPartition[VD: ClassTag]( val initialSize = if (shipSrc && shipDst) routingTable.partitionSize(pid) else 64 val vids = new PrimitiveVector[VertexId](initialSize) val attrs = new PrimitiveVector[VD](initialSize) - var i = 0 routingTable.foreachWithinEdgePartition(pid, shipSrc, shipDst) { vid => if (isDefined(vid)) { vids += vid attrs += this(vid) } - i += 1 } (pid, new VertexAttributeBlock(vids.trim().array, attrs.trim().array)) } @@ -137,12 +135,10 @@ class ShippableVertexPartition[VD: ClassTag]( def shipVertexIds(): Iterator[(PartitionID, Array[VertexId])] = { Iterator.tabulate(routingTable.numEdgePartitions) { pid => val vids = new PrimitiveVector[VertexId](routingTable.partitionSize(pid)) - var i = 0 routingTable.foreachWithinEdgePartition(pid, true, true) { vid => if (isDefined(vid)) { vids += vid } - i += 1 } (pid, vids.trim().array) } diff --git a/hadoop-cloud/pom.xml b/hadoop-cloud/pom.xml index 5d08d8d92e577..fc98f2e0023ce 100644 --- a/hadoop-cloud/pom.xml +++ b/hadoop-cloud/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../pom.xml - spark-hadoop-cloud_2.11 + spark-hadoop-cloud_2.12 jar Spark Project Cloud Integration diff --git a/launcher/pom.xml b/launcher/pom.xml index a833a35399918..130519d6c3b08 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../pom.xml - spark-launcher_2.11 + spark-launcher_2.12 jar Spark Project Launcher http://spark.apache.org/ diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java index 
9cbebdaeb33d3..0999cbd216871 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java @@ -31,8 +31,8 @@ abstract class AbstractAppHandle implements SparkAppHandle { private final LauncherServer server; private LauncherServer.ServerConnection connection; - private List listeners; - private AtomicReference state; + private List listeners; + private AtomicReference state; private volatile String appId; private volatile boolean disposed; @@ -42,7 +42,7 @@ protected AbstractAppHandle(LauncherServer server) { } @Override - public synchronized void addListener(Listener l) { + public synchronized void addListener(SparkAppHandle.Listener l) { if (listeners == null) { listeners = new CopyOnWriteArrayList<>(); } @@ -50,7 +50,7 @@ public synchronized void addListener(Listener l) { } @Override - public State getState() { + public SparkAppHandle.State getState() { return state.get(); } @@ -120,11 +120,11 @@ synchronized void dispose() { } } - void setState(State s) { + void setState(SparkAppHandle.State s) { setState(s, false); } - void setState(State s, boolean force) { + void setState(SparkAppHandle.State s, boolean force) { if (force) { state.set(s); fireEvent(false); diff --git a/mllib-local/pom.xml b/mllib-local/pom.xml index ec5f9b0e92c8f..2eab868ac0dc8 100644 --- a/mllib-local/pom.xml +++ b/mllib-local/pom.xml @@ -20,12 +20,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../pom.xml - spark-mllib-local_2.11 + spark-mllib-local_2.12 mllib-local diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala index 5824e463ca1aa..6e950f968a65d 100644 --- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala +++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala @@ -106,7 +106,7 @@ sealed trait Vector extends Serializable { */ @Since("2.0.0") def copy: Vector = { - throw new NotImplementedError(s"copy is not implemented for ${this.getClass}.") + throw new UnsupportedOperationException(s"copy is not implemented for ${this.getClass}.") } /** diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussian.scala b/mllib-local/src/main/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussian.scala index 3167e0c286d47..e7f7a8e07d7f2 100644 --- a/mllib-local/src/main/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussian.scala +++ b/mllib-local/src/main/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussian.scala @@ -48,14 +48,14 @@ class MultivariateGaussian @Since("2.0.0") ( this(Vectors.fromBreeze(mean), Matrices.fromBreeze(cov)) } - private val breezeMu = mean.asBreeze.toDenseVector + @transient private lazy val breezeMu = mean.asBreeze.toDenseVector /** * Compute distribution dependent constants: * rootSigmaInv = D^(-1/2)^ * U.t, where sigma = U * D * U.t * u = log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^) */ - private val (rootSigmaInv: BDM[Double], u: Double) = calculateCovarianceConstants + @transient private lazy val (rootSigmaInv: BDM[Double], u: Double) = calculateCovarianceConstants /** * Returns density of this multivariate Gaussian at given point, x diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/MatricesSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/MatricesSuite.scala index ace44165b1067..332734bd28341 100644 --- 
a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/MatricesSuite.scala +++ b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/MatricesSuite.scala @@ -862,10 +862,10 @@ class MatricesSuite extends SparkMLFunSuite { mat.toString(0, 0) mat.toString(Int.MinValue, Int.MinValue) mat.toString(Int.MaxValue, Int.MaxValue) - var lines = mat.toString(6, 50).lines.toArray + var lines = mat.toString(6, 50).split('\n') assert(lines.size == 5 && lines.forall(_.size <= 50)) - lines = mat.toString(5, 100).lines.toArray + lines = mat.toString(5, 100).split('\n') assert(lines.size == 5 && lines.forall(_.size <= 100)) } diff --git a/mllib/pom.xml b/mllib/pom.xml index 17ddb87c4d86a..0b17345064a71 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -20,12 +20,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../pom.xml - spark-mllib_2.11 + spark-mllib_2.12 mllib diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 62cfa39746ff0..2c4186a13d8f4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -180,7 +180,6 @@ class GBTClassifier @Since("1.4.0") ( (convert2LabeledPoint(dataset), null) } - val numFeatures = trainDataset.first().features.size val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification) val numClasses = 2 @@ -196,7 +195,6 @@ class GBTClassifier @Since("1.4.0") ( maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode, seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy, validationIndicatorCol) - instr.logNumFeatures(numFeatures) instr.logNumClasses(numClasses) val (baseLearners, learnerWeights) = if (withValidation) { @@ -206,6 +204,9 @@ class GBTClassifier @Since("1.4.0") ( GradientBoostedTrees.run(trainDataset, boostingStrategy, $(seed), $(featureSubsetStrategy)) } + val numFeatures = baseLearners.head.numFeatures + instr.logNumFeatures(numFeatures) + new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures) } @@ -427,7 +428,9 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] { s" trees based on metadata but found ${trees.length} trees.") val model = new GBTClassificationModel(metadata.uid, trees, treeWeights, numFeatures) - metadata.getAndSetParams(model) + // We ignore the impurity while loading models because in previous models it was wrongly + // set to gini (see SPARK-25959). + metadata.getAndSetParams(model, Some(List("impurity"))) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 51495c1a74e69..1a7a5e7a52344 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -146,7 +146,7 @@ class NaiveBayes @Since("1.5.0") ( requireZeroOneBernoulliValues case _ => // This should never happen. - throw new UnknownError(s"Invalid modelType: ${$(modelType)}.") + throw new IllegalArgumentException(s"Invalid modelType: ${$(modelType)}.") } } @@ -196,7 +196,7 @@ class NaiveBayes @Since("1.5.0") ( case Bernoulli => math.log(n + 2.0 * lambda) case _ => // This should never happen. 
- throw new UnknownError(s"Invalid modelType: ${$(modelType)}.") + throw new IllegalArgumentException(s"Invalid modelType: ${$(modelType)}.") } var j = 0 while (j < numFeatures) { @@ -295,7 +295,7 @@ class NaiveBayesModel private[ml] ( (Option(thetaMinusNegTheta), Option(negTheta.multiply(ones))) case _ => // This should never happen. - throw new UnknownError(s"Invalid modelType: ${$(modelType)}.") + throw new IllegalArgumentException(s"Invalid modelType: ${$(modelType)}.") } @Since("1.6.0") @@ -329,7 +329,7 @@ class NaiveBayesModel private[ml] ( bernoulliCalculation(features) case _ => // This should never happen. - throw new UnknownError(s"Invalid modelType: ${$(modelType)}.") + throw new IllegalArgumentException(s"Invalid modelType: ${$(modelType)}.") } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 57132381b6474..7598a28b6f89d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -142,7 +142,7 @@ class RandomForestClassifier @Since("1.4.0") ( .run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr)) .map(_.asInstanceOf[DecisionTreeClassificationModel]) - val numFeatures = oldDataset.first().features.size + val numFeatures = trees.head.numFeatures instr.logNumClasses(numClasses) instr.logNumFeatures(numFeatures) new RandomForestClassificationModel(uid, trees, numFeatures, numClasses) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 498310d6644e1..5d02305aafdda 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -279,7 +279,7 @@ object KMeansModel extends MLReadable[KMeansModel] { /** * K-means clustering with support for k-means|| initialization proposed by Bahmani et al. * - * @see Bahmani et al., Scalable k-means++. + * @see Bahmani et al., Scalable k-means++. 
*/ @Since("1.5.0") class KMeans @Since("1.5.0") ( diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala index 794b1e7d9d881..f1602c1bc5333 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} -import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol} +import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol, HasWeightCol} import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils} import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.sql.{Dataset, Row} @@ -33,7 +33,8 @@ import org.apache.spark.sql.types.DoubleType @Since("1.5.0") @Experimental class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") override val uid: String) - extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable { + extends Evaluator with HasPredictionCol with HasLabelCol + with HasWeightCol with DefaultParamsWritable { @Since("1.5.0") def this() = this(Identifiable.randomUID("mcEval")) @@ -67,6 +68,10 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid @Since("1.5.0") def setLabelCol(value: String): this.type = set(labelCol, value) + /** @group setParam */ + @Since("3.0.0") + def setWeightCol(value: String): this.type = set(weightCol, value) + setDefault(metricName -> "f1") @Since("2.0.0") @@ -75,11 +80,13 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid SchemaUtils.checkColumnType(schema, $(predictionCol), DoubleType) SchemaUtils.checkNumericType(schema, $(labelCol)) - val predictionAndLabels = - dataset.select(col($(predictionCol)), col($(labelCol)).cast(DoubleType)).rdd.map { - case Row(prediction: Double, label: Double) => (prediction, label) + val predictionAndLabelsWithWeights = + dataset.select(col($(predictionCol)), col($(labelCol)).cast(DoubleType), + if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))) + .rdd.map { + case Row(prediction: Double, label: Double, weight: Double) => (prediction, label, weight) } - val metrics = new MulticlassMetrics(predictionAndLabels) + val metrics = new MulticlassMetrics(predictionAndLabelsWithWeights) val metric = $(metricName) match { case "f1" => metrics.weightedFMeasure case "weightedPrecision" => metrics.weightedPrecision diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index f99649f7fa164..0b989b0d7d253 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -89,7 +89,8 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String def setOutputCol(value: String): this.type = set(outputCol, value) /** - * Param for how to handle invalid entries. Options are 'skip' (filter out rows with + * Param for how to handle invalid entries containing NaN values. Values outside the splits + * will always be treated as errors. 
Options are 'skip' (filter out rows with * invalid values), 'error' (throw an error), or 'keep' (keep invalid values in a special * additional bucket). Note that in the multiple column case, the invalid handling is applied * to all columns. That said for 'error' it will throw an error if any invalids are found in @@ -99,7 +100,8 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String */ @Since("2.1.0") override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", - "how to handle invalid entries. Options are skip (filter out rows with invalid values), " + + "how to handle invalid entries containing NaN values. Values outside the splits will always " + + "be treated as errors. Options are skip (filter out rows with invalid values), " + "error (throw an error), or keep (keep invalid values in a special additional bucket).", ParamValidators.inArray(Bucketizer.supportedHandleInvalids)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/LabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/LabeledPoint.scala index c5d0ec1a8d350..412954f7b2d5a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/LabeledPoint.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/LabeledPoint.scala @@ -17,8 +17,6 @@ package org.apache.spark.ml.feature -import scala.beans.BeanInfo - import org.apache.spark.annotation.Since import org.apache.spark.ml.linalg.Vector @@ -30,8 +28,12 @@ import org.apache.spark.ml.linalg.Vector * @param features List of features for this data point. */ @Since("2.0.0") -@BeanInfo case class LabeledPoint(@Since("2.0.0") label: Double, @Since("2.0.0") features: Vector) { + + def getLabel: Double = label + + def getFeatures: Vector = features + override def toString: String = { s"($label,$features)" } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 56e2c543d100a..5bfaa3b7f3f52 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -17,10 +17,6 @@ package org.apache.spark.ml.feature -import org.json4s.JsonDSL._ -import org.json4s.JValue -import org.json4s.jackson.JsonMethods._ - import org.apache.spark.annotation.Since import org.apache.spark.internal.Logging import org.apache.spark.ml._ @@ -209,7 +205,7 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui if (isSet(inputCols)) { val splitsArray = if (isSet(numBucketsArray)) { val probArrayPerCol = $(numBucketsArray).map { numOfBuckets => - (0.0 to 1.0 by 1.0 / numOfBuckets).toArray + (0 to numOfBuckets).map(_.toDouble / numOfBuckets).toArray } val probabilityArray = probArrayPerCol.flatten.sorted.distinct @@ -229,12 +225,12 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui } } else { dataset.stat.approxQuantile($(inputCols), - (0.0 to 1.0 by 1.0 / $(numBuckets)).toArray, $(relativeError)) + (0 to $(numBuckets)).map(_.toDouble / $(numBuckets)).toArray, $(relativeError)) } bucketizer.setSplitsArray(splitsArray.map(getDistinctSplits)) } else { val splits = dataset.stat.approxQuantile($(inputCol), - (0.0 to 1.0 by 1.0 / $(numBuckets)).toArray, $(relativeError)) + (0 to $(numBuckets)).map(_.toDouble / $(numBuckets)).toArray, $(relativeError)) bucketizer.setSplits(getDistinctSplits(splits)) } copyValues(bucketizer.setParent(this)) diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala index 840a89b76d26b..7322815c12ab8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala @@ -118,10 +118,10 @@ private[fpm] trait FPGrowthParams extends Params with HasPredictionCol { /** * :: Experimental :: * A parallel FP-growth algorithm to mine frequent itemsets. The algorithm is described in - * Li et al., PFP: Parallel FP-Growth for Query + * Li et al., PFP: Parallel FP-Growth for Query * Recommendation. PFP distributes computation in such a way that each worker executes an * independent group of mining tasks. The FP-Growth algorithm is described in - * Han et al., Mining frequent patterns without + * Han et al., Mining frequent patterns without * candidate generation. Note null values in the itemsCol column are ignored during fit(). * * @see diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala index bd1c1a8885201..2a3413553a6af 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.types.{ArrayType, LongType, StructField, StructType} * A parallel PrefixSpan algorithm to mine frequent sequential patterns. * The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns * Efficiently by Prefix-Projected Pattern Growth - * (see here). + * (see here). * This class is not yet an Estimator/Transformer, use `findFrequentSequentialPatterns` method to * run the PrefixSpan algorithm. * diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index e6c347ed17c15..4c50f1e3292bc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -97,7 +97,7 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali case m: Matrix => JsonMatrixConverter.toJson(m) case _ => - throw new NotImplementedError( + throw new UnsupportedOperationException( "The default jsonEncode only supports string, vector and matrix. " + s"${this.getClass.getName} must override jsonEncode for ${value.getClass.getName}.") } @@ -151,7 +151,7 @@ private[ml] object Param { } case _ => - throw new NotImplementedError( + throw new UnsupportedOperationException( "The default jsonDecode only supports string, vector and matrix. " + s"${this.getClass.getName} must override jsonDecode to support its value type.") } diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index ffe592789b3cc..50ef4330ddc80 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -557,7 +557,7 @@ object ALSModel extends MLReadable[ALSModel] { * * For implicit preference data, the algorithm used is based on * "Collaborative Filtering for Implicit Feedback Datasets", available at - * http://dx.doi.org/10.1109/ICDM.2008.22, adapted for the blocked approach used here. + * https://doi.org/10.1109/ICDM.2008.22, adapted for the blocked approach used here. 
* * Essentially instead of finding the low-rank approximations to the rating matrix `R`, * this finds the approximations for a preference matrix `P` where the elements of `P` are 1 if diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 6fa656275c1fd..c9de85de42fa5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -145,7 +145,7 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S @Since("1.4.0") object DecisionTreeRegressor extends DefaultParamsReadable[DecisionTreeRegressor] { /** Accessor for supported impurities: variance */ - final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities + final val supportedImpurities: Array[String] = HasVarianceImpurity.supportedImpurities @Since("2.0.0") override def load(path: String): DecisionTreeRegressor = super.load(path) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 07f88d8d5f84d..88dee2507bf7e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -166,7 +166,6 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) } else { (extractLabeledPoints(dataset), null) } - val numFeatures = trainDataset.first().features.size val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression) instr.logPipelineStage(this) @@ -174,7 +173,6 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) instr.logParams(this, labelCol, featuresCol, predictionCol, impurity, lossType, maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode, seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy) - instr.logNumFeatures(numFeatures) val (baseLearners, learnerWeights) = if (withValidation) { GradientBoostedTrees.runWithValidation(trainDataset, validationDataset, boostingStrategy, @@ -183,6 +181,10 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) GradientBoostedTrees.run(trainDataset, boostingStrategy, $(seed), $(featureSubsetStrategy)) } + + val numFeatures = baseLearners.head.numFeatures + instr.logNumFeatures(numFeatures) + new GBTRegressionModel(uid, baseLearners, learnerWeights, numFeatures) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 82bf66ff66d8a..a548ec537bb44 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -133,7 +133,7 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S .run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr)) .map(_.asInstanceOf[DecisionTreeRegressionModel]) - val numFeatures = oldDataset.first().features.size + val numFeatures = trees.head.numFeatures instr.logNamedValue(Instrumentation.loggerTags.numFeatures, numFeatures) new RandomForestRegressionModel(uid, trees, numFeatures) } @@ -146,7 +146,7 @@ class 
RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S object RandomForestRegressor extends DefaultParamsReadable[RandomForestRegressor]{ /** Accessor for supported impurity settings: variance */ @Since("1.4.0") - final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities + final val supportedImpurities: Array[String] = HasVarianceImpurity.supportedImpurities /** Accessor for supported featureSubsetStrategy settings: auto, all, onethird, sqrt, log2 */ @Since("1.4.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 00157fe63af41..f1e3836ebe476 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -258,11 +258,7 @@ private[ml] object TreeClassifierParams { private[ml] trait DecisionTreeClassifierParams extends DecisionTreeParams with TreeClassifierParams -/** - * Parameters for Decision Tree-based regression algorithms. - */ -private[ml] trait TreeRegressorParams extends Params { - +private[ml] trait HasVarianceImpurity extends Params { /** * Criterion used for information gain calculation (case-insensitive). * Supported: "variance". @@ -271,9 +267,9 @@ private[ml] trait TreeRegressorParams extends Params { */ final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" + " information gain calculation (case-insensitive). Supported options:" + - s" ${TreeRegressorParams.supportedImpurities.mkString(", ")}", + s" ${HasVarianceImpurity.supportedImpurities.mkString(", ")}", (value: String) => - TreeRegressorParams.supportedImpurities.contains(value.toLowerCase(Locale.ROOT))) + HasVarianceImpurity.supportedImpurities.contains(value.toLowerCase(Locale.ROOT))) setDefault(impurity -> "variance") @@ -299,12 +295,17 @@ private[ml] trait TreeRegressorParams extends Params { } } -private[ml] object TreeRegressorParams { +private[ml] object HasVarianceImpurity { // These options should be lowercase. final val supportedImpurities: Array[String] = Array("variance").map(_.toLowerCase(Locale.ROOT)) } +/** + * Parameters for Decision Tree-based regression algorithms. + */ +private[ml] trait TreeRegressorParams extends HasVarianceImpurity + private[ml] trait DecisionTreeRegressorParams extends DecisionTreeParams with TreeRegressorParams with HasVarianceCol { @@ -538,7 +539,7 @@ private[ml] object GBTClassifierParams { Array("logistic").map(_.toLowerCase(Locale.ROOT)) } -private[ml] trait GBTClassifierParams extends GBTParams with TreeClassifierParams { +private[ml] trait GBTClassifierParams extends GBTParams with HasVarianceImpurity { /** * Loss function which GBT tries to minimize. 
(case-insensitive) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala index 135828815504a..6d46ea0adcc9a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala @@ -140,8 +140,8 @@ private[ml] object ValidatorParams { "value" -> compact(render(JString(relativePath))), "isJson" -> compact(render(JBool(false)))) case _: MLWritable => - throw new NotImplementedError("ValidatorParams.saveImpl does not handle parameters " + - "of type: MLWritable that are not DefaultParamsWritable") + throw new UnsupportedOperationException("ValidatorParams.saveImpl does not handle" + + " parameters of type: MLWritable that are not DefaultParamsWritable") case _ => Map("parent" -> p.parent, "name" -> p.name, "value" -> p.jsonEncode(v), "isJson" -> compact(render(JBool(true)))) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index a0ac26a34d8c8..fbc7be25a5640 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -31,7 +31,7 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.{SparkContext, SparkException} -import org.apache.spark.annotation.{DeveloperApi, InterfaceStability, Since} +import org.apache.spark.annotation.{DeveloperApi, Since, Unstable} import org.apache.spark.internal.Logging import org.apache.spark.ml._ import org.apache.spark.ml.classification.{OneVsRest, OneVsRestModel} @@ -84,7 +84,7 @@ private[util] sealed trait BaseReadWrite { * * @since 2.4.0 */ -@InterfaceStability.Unstable +@Unstable @Since("2.4.0") trait MLWriterFormat { /** @@ -108,7 +108,7 @@ trait MLWriterFormat { * * @since 2.4.0 */ -@InterfaceStability.Unstable +@Unstable @Since("2.4.0") trait MLFormatRegister extends MLWriterFormat { /** @@ -208,7 +208,7 @@ abstract class MLWriter extends BaseReadWrite with Logging { /** * A ML Writer which delegates based on the requested format. */ -@InterfaceStability.Unstable +@Unstable @Since("2.4.0") class GeneralMLWriter(stage: PipelineStage) extends MLWriter with Logging { private var source: String = "internal" @@ -256,7 +256,7 @@ class GeneralMLWriter(stage: PipelineStage) extends MLWriter with Logging { s"Multiple writers found for $source+$stageName, try using the class name of the writer") } if (classOf[MLWriterFormat].isAssignableFrom(writerCls)) { - val writer = writerCls.newInstance().asInstanceOf[MLWriterFormat] + val writer = writerCls.getConstructor().newInstance().asInstanceOf[MLWriterFormat] writer.write(path, sparkSession, optionMap, stage) } else { throw new SparkException(s"ML source $source is not a valid MLWriterFormat") @@ -291,7 +291,7 @@ trait MLWritable { * Trait for classes that provide `GeneralMLWriter`. */ @Since("2.4.0") -@InterfaceStability.Unstable +@Unstable trait GeneralMLWritable extends MLWritable { /** * Returns an `MLWriter` instance for this ML instance. 
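Aside on the ReadWrite.scala hunk above: Class#newInstance is deprecated because it rethrows checked constructor exceptions unchecked, so the writer is now instantiated through the nullary Constructor, which wraps failures in InvocationTargetException. A generic sketch of the idiom (hypothetical helper, not part of the patch):

// Instantiate any class with a public zero-arg constructor, the way
// GeneralMLWriter now does for MLWriterFormat implementations.
def instantiate[T](cls: Class[_]): T =
  cls.getConstructor().newInstance().asInstanceOf[T]

// Hypothetical usage:
// val writer = instantiate[MLWriterFormat](Class.forName("com.example.MyFormat"))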
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index 9e8774732efe6..16ba6cabdc823 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -83,7 +83,7 @@ class NaiveBayesModel private[spark] ( (Option(thetaMinusNegTheta), Option(negTheta.multiply(ones))) case _ => // This should never happen. - throw new UnknownError(s"Invalid modelType: $modelType.") + throw new IllegalArgumentException(s"Invalid modelType: $modelType.") } @Since("1.0.0") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala index 980e0c92531a2..ad83c24ede964 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala @@ -27,10 +27,19 @@ import org.apache.spark.sql.DataFrame /** * Evaluator for multiclass classification. * - * @param predictionAndLabels an RDD of (prediction, label) pairs. + * @param predAndLabelsWithOptWeight an RDD of (prediction, label, weight) or + * (prediction, label) pairs. */ @Since("1.1.0") -class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Double)]) { +class MulticlassMetrics @Since("1.1.0") (predAndLabelsWithOptWeight: RDD[_ <: Product]) { + val predLabelsWeight: RDD[(Double, Double, Double)] = predAndLabelsWithOptWeight.map { + case (prediction: Double, label: Double, weight: Double) => + (prediction, label, weight) + case (prediction: Double, label: Double) => + (prediction, label, 1.0) + case other => + throw new IllegalArgumentException(s"Expected tuples, got $other") + } /** * An auxiliary constructor taking a DataFrame. 
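Aside on the MulticlassMetrics.scala hunk above: the constructor now takes RDD[_ <: Product] and normalizes both (prediction, label) pairs and (prediction, label, weight) triples, with a missing weight defaulting to 1.0. A usage sketch with made-up values (not part of the patch):

import org.apache.spark.SparkContext
import org.apache.spark.mllib.evaluation.MulticlassMetrics

def metricsExample(sc: SparkContext): Unit = {
  // Unweighted pairs: every instance implicitly carries weight 1.0.
  val unweighted = new MulticlassMetrics(
    sc.parallelize(Seq((0.0, 0.0), (1.0, 1.0), (1.0, 0.0))))
  // Weighted triples: the third element is the instance weight.
  val weighted = new MulticlassMetrics(
    sc.parallelize(Seq((0.0, 0.0, 2.0), (1.0, 1.0, 0.5), (1.0, 0.0, 1.5))))
  println(unweighted.accuracy)        // fraction of (weighted) instances predicted correctly
  println(weighted.weightedFMeasure)  // label-frequency-weighted average F-measure
}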
@@ -39,21 +48,29 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl private[mllib] def this(predictionAndLabels: DataFrame) = this(predictionAndLabels.rdd.map(r => (r.getDouble(0), r.getDouble(1)))) - private lazy val labelCountByClass: Map[Double, Long] = predictionAndLabels.values.countByValue() - private lazy val labelCount: Long = labelCountByClass.values.sum - private lazy val tpByClass: Map[Double, Int] = predictionAndLabels - .map { case (prediction, label) => - (label, if (label == prediction) 1 else 0) + private lazy val labelCountByClass: Map[Double, Double] = + predLabelsWeight.map { + case (_: Double, label: Double, weight: Double) => + (label, weight) + }.reduceByKey(_ + _) + .collectAsMap() + private lazy val labelCount: Double = labelCountByClass.values.sum + private lazy val tpByClass: Map[Double, Double] = predLabelsWeight + .map { + case (prediction: Double, label: Double, weight: Double) => + (label, if (label == prediction) weight else 0.0) }.reduceByKey(_ + _) .collectAsMap() - private lazy val fpByClass: Map[Double, Int] = predictionAndLabels - .map { case (prediction, label) => - (prediction, if (prediction != label) 1 else 0) + private lazy val fpByClass: Map[Double, Double] = predLabelsWeight + .map { + case (prediction: Double, label: Double, weight: Double) => + (prediction, if (prediction != label) weight else 0.0) }.reduceByKey(_ + _) .collectAsMap() - private lazy val confusions = predictionAndLabels - .map { case (prediction, label) => - ((label, prediction), 1) + private lazy val confusions = predLabelsWeight + .map { + case (prediction: Double, label: Double, weight: Double) => + ((label, prediction), weight) }.reduceByKey(_ + _) .collectAsMap() @@ -71,7 +88,7 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl while (i < n) { var j = 0 while (j < n) { - values(i + j * n) = confusions.getOrElse((labels(i), labels(j)), 0).toDouble + values(i + j * n) = confusions.getOrElse((labels(i), labels(j)), 0.0) j += 1 } i += 1 @@ -92,8 +109,8 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl */ @Since("1.1.0") def falsePositiveRate(label: Double): Double = { - val fp = fpByClass.getOrElse(label, 0) - fp.toDouble / (labelCount - labelCountByClass(label)) + val fp = fpByClass.getOrElse(label, 0.0) + fp / (labelCount - labelCountByClass(label)) } /** @@ -103,7 +120,7 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl @Since("1.1.0") def precision(label: Double): Double = { val tp = tpByClass(label) - val fp = fpByClass.getOrElse(label, 0) + val fp = fpByClass.getOrElse(label, 0.0) if (tp + fp == 0) 0 else tp.toDouble / (tp + fp) } @@ -112,7 +129,7 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl * @param label the label. */ @Since("1.1.0") - def recall(label: Double): Double = tpByClass(label).toDouble / labelCountByClass(label) + def recall(label: Double): Double = tpByClass(label) / labelCountByClass(label) /** * Returns f-measure for a given label (category) @@ -140,7 +157,7 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl * out of the total number of instances.) 
*/ @Since("2.0.0") - lazy val accuracy: Double = tpByClass.values.sum.toDouble / labelCount + lazy val accuracy: Double = tpByClass.values.sum / labelCount /** * Returns weighted true positive rate diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala index 3a1bc35186dc3..519c1ea47c1db 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -152,10 +152,10 @@ object FPGrowthModel extends Loader[FPGrowthModel[_]] { /** * A parallel FP-growth algorithm to mine frequent itemsets. The algorithm is described in - * Li et al., PFP: Parallel FP-Growth for Query + * Li et al., PFP: Parallel FP-Growth for Query * Recommendation. PFP distributes computation in such a way that each worker executes an * independent group of mining tasks. The FP-Growth algorithm is described in - * Han et al., Mining frequent patterns without + * Han et al., Mining frequent patterns without * candidate generation. * * @param minSupport the minimal support level of the frequent pattern, any pattern that appears diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index 7aed2f3bd8a61..b2c09b408b40b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -45,7 +45,7 @@ import org.apache.spark.storage.StorageLevel * A parallel PrefixSpan algorithm to mine frequent sequential patterns. * The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns * Efficiently by Prefix-Projected Pattern Growth - * (see here). + * (see here). 
* * @param minSupport the minimal support level of the sequential pattern, any pattern that appears * more than (minSupport * size-of-the-dataset) times will be output @@ -174,6 +174,13 @@ class PrefixSpan private ( val freqSequences = results.map { case (seq: Array[Int], count: Long) => new FreqSequence(toPublicRepr(seq), count) } + // Cache the final RDD to the same storage level as input + if (data.getStorageLevel != StorageLevel.NONE) { + freqSequences.persist(data.getStorageLevel) + freqSequences.count() + } + dataInternalRepr.unpersist(false) + new PrefixSpanModel(freqSequences) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 6e68d9684a672..9cdf1944329b8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -117,7 +117,7 @@ sealed trait Vector extends Serializable { */ @Since("1.1.0") def copy: Vector = { - throw new NotImplementedError(s"copy is not implemented for ${this.getClass}.") + throw new UnsupportedOperationException(s"copy is not implemented for ${this.getClass}.") } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala index 7caacd13b3459..e58860fea97d0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala @@ -17,10 +17,9 @@ package org.apache.spark.mllib.linalg.distributed +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, Matrix => BM} import scala.collection.mutable.ArrayBuffer -import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, Matrix => BM, SparseVector => BSV, Vector => BV} - import org.apache.spark.{Partitioner, SparkException} import org.apache.spark.annotation.Since import org.apache.spark.internal.Logging @@ -28,6 +27,7 @@ import org.apache.spark.mllib.linalg._ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel + /** * A grid partitioner, which uses a regular grid to partition coordinates. 
* @@ -273,24 +273,37 @@ class BlockMatrix @Since("1.3.0") ( require(cols < Int.MaxValue, s"The number of columns should be less than Int.MaxValue ($cols).") val rows = blocks.flatMap { case ((blockRowIdx, blockColIdx), mat) => - mat.rowIter.zipWithIndex.map { + mat.rowIter.zipWithIndex.filter(_._1.size > 0).map { case (vector, rowIdx) => - blockRowIdx * rowsPerBlock + rowIdx -> ((blockColIdx, vector.asBreeze)) + blockRowIdx * rowsPerBlock + rowIdx -> ((blockColIdx, vector)) } }.groupByKey().map { case (rowIdx, vectors) => - val numberNonZeroPerRow = vectors.map(_._2.activeSize).sum.toDouble / cols.toDouble - - val wholeVector = if (numberNonZeroPerRow <= 0.1) { // Sparse at 1/10th nnz - BSV.zeros[Double](cols) - } else { - BDV.zeros[Double](cols) - } + val numberNonZero = vectors.map(_._2.numActives).sum + val numberNonZeroPerRow = numberNonZero.toDouble / cols.toDouble + + val wholeVector = + if (numberNonZeroPerRow <= 0.1) { // Sparse at 1/10th nnz + val arrBufferIndices = new ArrayBuffer[Int](numberNonZero) + val arrBufferValues = new ArrayBuffer[Double](numberNonZero) + + vectors.foreach { case (blockColIdx: Int, vec: Vector) => + val offset = colsPerBlock * blockColIdx + vec.foreachActive { case (colIdx: Int, value: Double) => + arrBufferIndices += offset + colIdx + arrBufferValues += value + } + } + Vectors.sparse(cols, arrBufferIndices.toArray, arrBufferValues.toArray) + } else { + val wholeVectorBuf = BDV.zeros[Double](cols) + vectors.foreach { case (blockColIdx: Int, vec: Vector) => + val offset = colsPerBlock * blockColIdx + wholeVectorBuf(offset until Math.min(cols, offset + colsPerBlock)) := vec.asBreeze + } + Vectors.fromBreeze(wholeVectorBuf) + } - vectors.foreach { case (blockColIdx: Int, vec: BV[_]) => - val offset = colsPerBlock * blockColIdx - wholeVector(offset until Math.min(cols, offset + colsPerBlock)) := vec - } - new IndexedRow(rowIdx, Vectors.fromBreeze(wholeVector)) + IndexedRow(rowIdx, wholeVector) } new IndexedRowMatrix(rows) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 82ab716ed96a8..c12b751bfb8e4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -540,7 +540,7 @@ class RowMatrix @Since("1.0.0") ( * decomposition (factorization) for the [[RowMatrix]] of a tall and skinny shape. * Reference: * Paul G. Constantine, David F. Gleich. "Tall and skinny QR factorizations in MapReduce - * architectures" (see here) + * architectures" (see here) * * @param computeQ whether to computeQ * @return QRDecomposition(Q, R), Q = null if computeQ = false. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index 14288221b6945..12870f819b147 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -54,7 +54,7 @@ case class Rating @Since("0.8.0") ( * * For implicit preference data, the algorithm used is based on * "Collaborative Filtering for Implicit Feedback Datasets", available at - * here, adapted for the blocked approach + * here, adapted for the blocked approach * used here. 
* * Essentially instead of finding the low-rank approximations to the rating matrix `R`, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala index 4381d6ab20cc0..b320057b25276 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala @@ -17,8 +17,6 @@ package org.apache.spark.mllib.regression -import scala.beans.BeanInfo - import org.apache.spark.SparkException import org.apache.spark.annotation.Since import org.apache.spark.ml.feature.{LabeledPoint => NewLabeledPoint} @@ -32,10 +30,14 @@ import org.apache.spark.mllib.util.NumericParser * @param features List of features for this data point. */ @Since("0.8.0") -@BeanInfo case class LabeledPoint @Since("1.0.0") ( @Since("0.8.0") label: Double, @Since("1.0.0") features: Vector) { + + def getLabel: Double = label + + def getFeatures: Vector = features + override def toString: String = { s"($label,$features)" } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala index 4cf662e036346..9a746dcf35556 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala @@ -43,7 +43,7 @@ class MultivariateGaussian @Since("1.3.0") ( require(sigma.numCols == sigma.numRows, "Covariance matrix must be square") require(mu.size == sigma.numCols, "Mean vector length must match covariance matrix size") - private val breezeMu = mu.asBreeze.toDenseVector + @transient private lazy val breezeMu = mu.asBreeze.toDenseVector /** * private[mllib] constructor @@ -60,7 +60,7 @@ class MultivariateGaussian @Since("1.3.0") ( * rootSigmaInv = D^(-1/2)^ * U.t, where sigma = U * D * U.t * u = log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^) */ - private val (rootSigmaInv: DBM[Double], u: Double) = calculateCovarianceConstants + @transient private lazy val (rootSigmaInv: DBM[Double], u: Double) = calculateCovarianceConstants /** * Returns density of this multivariate Gaussian at given point, x diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala index 80c6ef0ea1aa1..85ed11d6553d9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala @@ -17,8 +17,6 @@ package org.apache.spark.mllib.stat.test -import scala.beans.BeanInfo - import org.apache.spark.annotation.Since import org.apache.spark.internal.Logging import org.apache.spark.streaming.api.java.JavaDStream @@ -32,10 +30,11 @@ import org.apache.spark.util.StatCounter * @param value numeric value of the observation. 
*/ @Since("1.6.0") -@BeanInfo case class BinarySample @Since("1.6.0") ( @Since("1.6.0") isExperiment: Boolean, @Since("1.6.0") value: Double) { + def getIsExperiment: Boolean = isExperiment + def getValue: Double = value override def toString: String = { s"($isExperiment, $value)" } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 14af8b5c73870..6d15a6bb01e4e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -506,8 +506,6 @@ object MLUtils extends Logging { val n = v1.size require(v2.size == n) require(norm1 >= 0.0 && norm2 >= 0.0) - val sumSquaredNorm = norm1 * norm1 + norm2 * norm2 - val normDiff = norm1 - norm2 var sqDist = 0.0 /* * The relative error is @@ -521,19 +519,23 @@ object MLUtils extends Logging { * The bound doesn't need the inner product, so we can use it as a sufficient condition to * check quickly whether the inner product approach is accurate. */ - val precisionBound1 = 2.0 * EPSILON * sumSquaredNorm / (normDiff * normDiff + EPSILON) - if (precisionBound1 < precision) { - sqDist = sumSquaredNorm - 2.0 * dot(v1, v2) - } else if (v1.isInstanceOf[SparseVector] || v2.isInstanceOf[SparseVector]) { - val dotValue = dot(v1, v2) - sqDist = math.max(sumSquaredNorm - 2.0 * dotValue, 0.0) - val precisionBound2 = EPSILON * (sumSquaredNorm + 2.0 * math.abs(dotValue)) / - (sqDist + EPSILON) - if (precisionBound2 > precision) { - sqDist = Vectors.sqdist(v1, v2) - } - } else { + if (v1.isInstanceOf[DenseVector] && v2.isInstanceOf[DenseVector]) { sqDist = Vectors.sqdist(v1, v2) + } else { + val sumSquaredNorm = norm1 * norm1 + norm2 * norm2 + val normDiff = norm1 - norm2 + val precisionBound1 = 2.0 * EPSILON * sumSquaredNorm / (normDiff * normDiff + EPSILON) + if (precisionBound1 < precision) { + sqDist = sumSquaredNorm - 2.0 * dot(v1, v2) + } else { + val dotValue = dot(v1, v2) + sqDist = math.max(sumSquaredNorm - 2.0 * dotValue, 0.0) + val precisionBound2 = EPSILON * (sumSquaredNorm + 2.0 * math.abs(dotValue)) / + (sqDist + EPSILON) + if (precisionBound2 > precision) { + sqDist = Vectors.sqdist(v1, v2) + } + } } sqDist } diff --git a/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala index ec45e32d412a9..dff00eade620f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala @@ -73,7 +73,7 @@ object PredictorSuite { } override def copy(extra: ParamMap): MockPredictor = - throw new NotImplementedError() + throw new UnsupportedOperationException() } class MockPredictionModel(override val uid: String) @@ -82,9 +82,9 @@ object PredictorSuite { def this() = this(Identifiable.randomUID("mockpredictormodel")) override def predict(features: Vector): Double = - throw new NotImplementedError() + throw new UnsupportedOperationException() override def copy(extra: ParamMap): MockPredictionModel = - throw new NotImplementedError() + throw new UnsupportedOperationException() } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala index 6355e0f179496..eb5f3ca45940d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala @@ -17,7 +17,8 @@ 
package org.apache.spark.ml.attribute -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.serializer.KryoSerializer import org.apache.spark.sql.types._ class AttributeSuite extends SparkFunSuite { @@ -221,4 +222,20 @@ class AttributeSuite extends SparkFunSuite { val decimalFldWithMeta = new StructField("x", DecimalType(38, 18), false, metadata) assert(Attribute.fromStructField(decimalFldWithMeta).isNumeric) } + + test("Kryo class register") { + val conf = new SparkConf(false) + conf.set("spark.kryo.registrationRequired", "true") + + val ser = new KryoSerializer(conf).newInstance() + + val numericAttr = new NumericAttribute(Some("numeric"), Some(1), Some(1.0), Some(2.0)) + val nominalAttr = new NominalAttribute(Some("nominal"), Some(2), Some(false)) + val binaryAttr = new BinaryAttribute(Some("binary"), Some(3), Some(Array("i", "j"))) + + Seq(numericAttr, nominalAttr, binaryAttr).foreach { i => + val i2 = ser.deserialize[Attribute](ser.serialize(i)) + assert(i === i2) + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala index 87bf2be06c2be..be52d99e54d3b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala @@ -117,10 +117,10 @@ object ClassifierSuite { def this() = this(Identifiable.randomUID("mockclassifier")) - override def copy(extra: ParamMap): MockClassifier = throw new NotImplementedError() + override def copy(extra: ParamMap): MockClassifier = throw new UnsupportedOperationException() override def train(dataset: Dataset[_]): MockClassificationModel = - throw new NotImplementedError() + throw new UnsupportedOperationException() // Make methods public override def extractLabeledPoints(dataset: Dataset[_], numClasses: Int): RDD[LabeledPoint] = @@ -133,11 +133,12 @@ object ClassifierSuite { def this() = this(Identifiable.randomUID("mockclassificationmodel")) - protected def predictRaw(features: Vector): Vector = throw new NotImplementedError() + protected def predictRaw(features: Vector): Vector = throw new UnsupportedOperationException() - override def copy(extra: ParamMap): MockClassificationModel = throw new NotImplementedError() + override def copy(extra: ParamMap): MockClassificationModel = + throw new UnsupportedOperationException() - override def numClasses: Int = throw new NotImplementedError() + override def numClasses: Int = throw new UnsupportedOperationException() } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 304977634189c..cedbaf1858ef4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -448,6 +448,7 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest { model2: GBTClassificationModel): Unit = { TreeTests.checkEqual(model, model2) assert(model.numFeatures === model2.numFeatures) + assert(model.featureImportances == model2.featureImportances) } val gbt = new GBTClassifier() diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 5f9ab98a2c3ce..a8c4f091b2aed 100644 
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -103,7 +103,7 @@ class NaiveBayesSuite extends MLTest with DefaultReadWriteTest { case Bernoulli => expectedBernoulliProbabilities(model, features) case _ => - throw new UnknownError(s"Invalid modelType: $modelType.") + throw new IllegalArgumentException(s"Invalid modelType: $modelType.") } assert(probability ~== expected relTol 1.0e-10) } @@ -378,7 +378,7 @@ object NaiveBayesSuite { counts.toArray.sortBy(_._1).map(_._2) case _ => // This should never happen. - throw new UnknownError(s"Invalid modelType: $modelType.") + throw new IllegalArgumentException(s"Invalid modelType: $modelType.") } LabeledPoint(y, Vectors.dense(xi)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index 4fec5bf1eb812..259e7b62f8a66 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -134,8 +134,8 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest { assert(lrModel1.coefficients ~== lrModel2.coefficients relTol 1E-3) assert(lrModel1.intercept ~== lrModel2.intercept relTol 1E-3) case other => - throw new AssertionError(s"Loaded OneVsRestModel expected model of type" + - s" LogisticRegressionModel but found ${other.getClass.getName}") + fail("Loaded OneVsRestModel expected model of type LogisticRegressionModel " + + s"but found ${other.getClass.getName}") } } @@ -247,8 +247,8 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest { assert(lr.getMaxIter === lr2.getMaxIter) assert(lr.getRegParam === lr2.getRegParam) case other => - throw new AssertionError(s"Loaded OneVsRest expected classifier of type" + - s" LogisticRegression but found ${other.getClass.getName}") + fail("Loaded OneVsRest expected classifier of type LogisticRegression" + + s" but found ${other.getClass.getName}") } } @@ -267,8 +267,8 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest { assert(classifier.getMaxIter === lr2.getMaxIter) assert(classifier.getRegParam === lr2.getRegParam) case other => - throw new AssertionError(s"Loaded OneVsRestModel expected classifier of type" + - s" LogisticRegression but found ${other.getClass.getName}") + fail("Loaded OneVsRestModel expected classifier of type LogisticRegression" + + s" but found ${other.getClass.getName}") } assert(model.labelMetadata === model2.labelMetadata) @@ -278,8 +278,8 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest { assert(lrModel1.coefficients === lrModel2.coefficients) assert(lrModel1.intercept === lrModel2.intercept) case other => - throw new AssertionError(s"Loaded OneVsRestModel expected model of type" + - s" LogisticRegressionModel but found ${other.getClass.getName}") + fail(s"Loaded OneVsRestModel expected model of type LogisticRegressionModel" + + s" but found ${other.getClass.getName}") } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala index 6734336aac39c..985e396000d05 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala @@ -17,16 +17,16 @@ package org.apache.spark.ml.feature -import scala.beans.BeanInfo - import 
edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.sql.Row -@BeanInfo -case class DCTTestData(vec: Vector, wantedVec: Vector) +case class DCTTestData(vec: Vector, wantedVec: Vector) { + def getVec: Vector = vec + def getWantedVec: Vector = wantedVec +} class DCTSuite extends MLTest with DefaultReadWriteTest { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala index 201a335e0d7be..1483d5df4d224 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.ml.feature -import scala.beans.BeanInfo - import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.sql.{DataFrame, Row} - -@BeanInfo -case class NGramTestData(inputTokens: Array[String], wantedNGrams: Array[String]) +case class NGramTestData(inputTokens: Array[String], wantedNGrams: Array[String]) { + def getInputTokens: Array[String] = inputTokens + def getWantedNGrams: Array[String] = wantedNGrams +} class NGramSuite extends MLTest with DefaultReadWriteTest { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala index b009038bbd833..82af05039653e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala @@ -31,7 +31,7 @@ class QuantileDiscretizerSuite extends MLTest with DefaultReadWriteTest { val datasetSize = 100000 val numBuckets = 5 - val df = sc.parallelize(1.0 to datasetSize by 1.0).map(Tuple1.apply).toDF("input") + val df = sc.parallelize(1 to datasetSize).map(_.toDouble).map(Tuple1.apply).toDF("input") val discretizer = new QuantileDiscretizer() .setInputCol("input") .setOutputCol("result") @@ -114,8 +114,8 @@ class QuantileDiscretizerSuite extends MLTest with DefaultReadWriteTest { val spark = this.spark import spark.implicits._ - val trainDF = sc.parallelize(1.0 to 100.0 by 1.0).map(Tuple1.apply).toDF("input") - val testDF = sc.parallelize(-10.0 to 110.0 by 1.0).map(Tuple1.apply).toDF("input") + val trainDF = sc.parallelize((1 to 100).map(_.toDouble)).map(Tuple1.apply).toDF("input") + val testDF = sc.parallelize((-10 to 110).map(_.toDouble)).map(Tuple1.apply).toDF("input") val discretizer = new QuantileDiscretizer() .setInputCol("input") .setOutputCol("result") @@ -276,10 +276,10 @@ class QuantileDiscretizerSuite extends MLTest with DefaultReadWriteTest { 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0) val data2 = Array.range(1, 40, 2).map(_.toDouble) val expected2 = Array (0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, - 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0, 4.0, 4.0) + 2.0, 3.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0, 4.0, 4.0) val data3 = Array.range(1, 60, 3).map(_.toDouble) - val expected3 = Array (0.0, 1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 4.0, 4.0, 5.0, - 5.0, 5.0, 6.0, 6.0, 7.0, 8.0, 8.0, 9.0, 9.0, 9.0) + val expected3 = Array (0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0, + 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0, 9.0, 9.0) val data = (0 until 20).map { idx => (data1(idx), data2(idx), data3(idx), expected1(idx), expected2(idx), expected3(idx)) } diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala index be59b0af2c78e..ba8e79f14de95 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala @@ -17,14 +17,14 @@ package org.apache.spark.ml.feature -import scala.beans.BeanInfo - import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.sql.{DataFrame, Row} -@BeanInfo -case class TokenizerTestData(rawText: String, wantedTokens: Array[String]) +case class TokenizerTestData(rawText: String, wantedTokens: Array[String]) { + def getRawText: String = rawText + def getWantedTokens: Array[String] = wantedTokens +} class TokenizerSuite extends MLTest with DefaultReadWriteTest { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala index e5675e31bbecf..44b0f8f8ae7d8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.ml.feature -import scala.beans.{BeanInfo, BeanProperty} - import org.apache.spark.SparkException import org.apache.spark.internal.Logging import org.apache.spark.ml.attribute._ @@ -26,7 +24,7 @@ import org.apache.spark.ml.linalg.{SparseVector, Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.DataFrame class VectorIndexerSuite extends MLTest with DefaultReadWriteTest with Logging { @@ -283,7 +281,9 @@ class VectorIndexerSuite extends MLTest with DefaultReadWriteTest with Logging { points.zip(rows.map(_(0))).foreach { case (orig: SparseVector, indexed: SparseVector) => assert(orig.indices.length == indexed.indices.length) - case _ => throw new UnknownError("Unit test has a bug in it.") // should never happen + case _ => + // should never happen + fail("Unit test has a bug in it.") } } } @@ -337,6 +337,7 @@ class VectorIndexerSuite extends MLTest with DefaultReadWriteTest with Logging { } private[feature] object VectorIndexerSuite { - @BeanInfo - case class FeatureData(@BeanProperty features: Vector) + case class FeatureData(features: Vector) { + def getFeatures: Vector = features + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/linalg/MatrixUDTSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/linalg/MatrixUDTSuite.scala index bdceba7887cac..8371c33a209dc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/linalg/MatrixUDTSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/linalg/MatrixUDTSuite.scala @@ -31,7 +31,7 @@ class MatrixUDTSuite extends SparkFunSuite { val sm3 = dm3.toSparse for (m <- Seq(dm1, dm2, dm3, sm1, sm2, sm3)) { - val udt = UDTRegistration.getUDTFor(m.getClass.getName).get.newInstance() + val udt = UDTRegistration.getUDTFor(m.getClass.getName).get.getConstructor().newInstance() .asInstanceOf[MatrixUDT] assert(m === udt.deserialize(udt.serialize(m))) assert(udt.typeName == "matrix") diff --git a/mllib/src/test/scala/org/apache/spark/ml/linalg/VectorUDTSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/linalg/VectorUDTSuite.scala index 6ddb12cb76aac..67c64f762b25e 100644 
--- a/mllib/src/test/scala/org/apache/spark/ml/linalg/VectorUDTSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/linalg/VectorUDTSuite.scala @@ -31,7 +31,7 @@ class VectorUDTSuite extends SparkFunSuite { val sv2 = Vectors.sparse(2, Array(1), Array(2.0)) for (v <- Seq(dv1, dv2, sv1, sv2)) { - val udt = UDTRegistration.getUDTFor(v.getClass.getName).get.newInstance() + val udt = UDTRegistration.getUDTFor(v.getClass.getName).get.getConstructor().newInstance() .asInstanceOf[VectorUDT] assert(v === udt.deserialize(udt.serialize(v))) assert(udt.typeName == "vector") diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index 9a59c41740daf..2fc9754ecfe1e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -601,7 +601,7 @@ class ALSSuite extends MLTest with DefaultReadWriteTest with Logging { val df = maybeDf.get._2 val expected = estimator.fit(df) - val actuals = dfs.filter(_ != baseType).map(t => (t, estimator.fit(t._2))) + val actuals = dfs.map(t => (t, estimator.fit(t._2))) actuals.foreach { case (_, actual) => check(expected, actual) } actuals.foreach { case (t, actual) => check2(expected, actual, t._2, t._1.encoder) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index 743dacf146fe7..5caa5117d5752 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -417,9 +417,9 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { case n: InternalNode => n.split match { case s: CategoricalSplit => assert(s.leftCategories === Array(1.0)) - case _ => throw new AssertionError("model.rootNode.split was not a CategoricalSplit") + case _ => fail("model.rootNode.split was not a CategoricalSplit") } - case _ => throw new AssertionError("model.rootNode was not an InternalNode") + case _ => fail("model.rootNode was not an InternalNode") } } @@ -444,7 +444,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(n.leftChild.isInstanceOf[InternalNode]) assert(n.rightChild.isInstanceOf[InternalNode]) Array(n.leftChild.asInstanceOf[InternalNode], n.rightChild.asInstanceOf[InternalNode]) - case _ => throw new AssertionError("rootNode was not an InternalNode") + case _ => fail("rootNode was not an InternalNode") } // Single group second level tree construction. 
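Aside on the test changes above: ScalaTest's fail(...) raises TestFailedException with proper stack-depth reporting, so default match arms no longer need to construct AssertionError by hand. A minimal sketch of the idiom (hypothetical helper, not part of the patch):

import org.scalatest.Assertions.fail
import org.apache.spark.ml.tree.{InternalNode, LeafNode, Node}

// Narrow a tree node to InternalNode inside a test, failing the test otherwise.
def asInternal(node: Node): InternalNode = node match {
  case n: InternalNode => n
  case _: LeafNode => fail("expected an InternalNode but found a LeafNode")
}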
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala index b6894b30b0c2b..ae9794b87b08d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala @@ -112,7 +112,7 @@ private[ml] object TreeTests extends SparkFunSuite { checkEqual(a.rootNode, b.rootNode) } catch { case ex: Exception => - throw new AssertionError("checkEqual failed since the two trees were not identical.\n" + + fail("checkEqual failed since the two trees were not identical.\n" + "TREE A:\n" + a.toDebugString + "\n" + "TREE B:\n" + b.toDebugString + "\n", ex) } @@ -133,7 +133,7 @@ private[ml] object TreeTests extends SparkFunSuite { checkEqual(aye.rightChild, bee.rightChild) case (aye: LeafNode, bee: LeafNode) => // do nothing case _ => - throw new AssertionError("Found mismatched nodes") + fail("Found mismatched nodes") } } @@ -148,7 +148,7 @@ private[ml] object TreeTests extends SparkFunSuite { } assert(a.treeWeights === b.treeWeights) } catch { - case ex: Exception => throw new AssertionError( + case ex: Exception => fail( "checkEqual failed since the two tree ensembles were not identical") } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index e6ee7220d2279..a30428ec2d283 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -190,8 +190,8 @@ class CrossValidatorSuite assert(lr.uid === lr2.uid) assert(lr.getMaxIter === lr2.getMaxIter) case other => - throw new AssertionError(s"Loaded CrossValidator expected estimator of type" + - s" LogisticRegression but found ${other.getClass.getName}") + fail("Loaded CrossValidator expected estimator of type LogisticRegression" + + s" but found ${other.getClass.getName}") } ValidatorParamsSuiteHelpers @@ -281,13 +281,13 @@ class CrossValidatorSuite assert(ova.getClassifier.asInstanceOf[LogisticRegression].getMaxIter === lr.getMaxIter) case other => - throw new AssertionError(s"Loaded CrossValidator expected estimator of type" + - s" LogisticRegression but found ${other.getClass.getName}") + fail("Loaded CrossValidator expected estimator of type LogisticRegression" + + s" but found ${other.getClass.getName}") } case other => - throw new AssertionError(s"Loaded CrossValidator expected estimator of type" + - s" OneVsRest but found ${other.getClass.getName}") + fail("Loaded CrossValidator expected estimator of type OneVsRest but " + + s"found ${other.getClass.getName}") } ValidatorParamsSuiteHelpers @@ -364,8 +364,8 @@ class CrossValidatorSuite assert(lr.uid === lr2.uid) assert(lr.getMaxIter === lr2.getMaxIter) case other => - throw new AssertionError(s"Loaded internal CrossValidator expected to be" + - s" LogisticRegression but found type ${other.getClass.getName}") + fail("Loaded internal CrossValidator expected to be LogisticRegression" + + s" but found type ${other.getClass.getName}") } assert(lrcv.uid === lrcv2.uid) assert(lrcv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator]) @@ -373,12 +373,12 @@ class CrossValidatorSuite ValidatorParamsSuiteHelpers .compareParamMaps(lrParamMaps, lrcv2.getEstimatorParamMaps) case other => - throw new AssertionError("Loaded Pipeline expected stages (HashingTF, CrossValidator)" + - " but found: " + 
other.map(_.getClass.getName).mkString(", ")) + fail("Loaded Pipeline expected stages (HashingTF, CrossValidator) but found: " + + other.map(_.getClass.getName).mkString(", ")) } case other => - throw new AssertionError(s"Loaded CrossValidator expected estimator of type" + - s" CrossValidator but found ${other.getClass.getName}") + fail("Loaded CrossValidator expected estimator of type CrossValidator but found" + + s" ${other.getClass.getName}") } } @@ -433,8 +433,8 @@ class CrossValidatorSuite assert(lr.uid === lr2.uid) assert(lr.getThreshold === lr2.getThreshold) case other => - throw new AssertionError(s"Loaded CrossValidator expected estimator of type" + - s" LogisticRegression but found ${other.getClass.getName}") + fail("Loaded CrossValidator expected estimator of type LogisticRegression" + + s" but found ${other.getClass.getName}") } ValidatorParamsSuiteHelpers @@ -447,8 +447,8 @@ class CrossValidatorSuite assert(lrModel.coefficients === lrModel2.coefficients) assert(lrModel.intercept === lrModel2.intercept) case other => - throw new AssertionError(s"Loaded CrossValidator expected bestModel of type" + - s" LogisticRegressionModel but found ${other.getClass.getName}") + fail("Loaded CrossValidator expected bestModel of type LogisticRegressionModel" + + s" but found ${other.getClass.getName}") } assert(cv.avgMetrics === cv2.avgMetrics) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index cd76acf9c67bc..289db336eca5d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -187,8 +187,8 @@ class TrainValidationSplitSuite assert(lr.uid === lr2.uid) assert(lr.getMaxIter === lr2.getMaxIter) case other => - throw new AssertionError(s"Loaded TrainValidationSplit expected estimator of type" + - s" LogisticRegression but found ${other.getClass.getName}") + fail("Loaded TrainValidationSplit expected estimator of type LogisticRegression" + + s" but found ${other.getClass.getName}") } } @@ -264,13 +264,13 @@ class TrainValidationSplitSuite assert(ova.getClassifier.asInstanceOf[LogisticRegression].getMaxIter === lr.getMaxIter) case other => - throw new AssertionError(s"Loaded TrainValidationSplit expected estimator of type" + - s" LogisticRegression but found ${other.getClass.getName}") + fail(s"Loaded TrainValidationSplit expected estimator of type LogisticRegression" + + s" but found ${other.getClass.getName}") } case other => - throw new AssertionError(s"Loaded TrainValidationSplit expected estimator of type" + - s" OneVsRest but found ${other.getClass.getName}") + fail(s"Loaded TrainValidationSplit expected estimator of type OneVsRest" + + s" but found ${other.getClass.getName}") } ValidatorParamsSuiteHelpers diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/ValidatorParamsSuiteHelpers.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/ValidatorParamsSuiteHelpers.scala index eae1f5adc8842..cea2f50d3470c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/ValidatorParamsSuiteHelpers.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/ValidatorParamsSuiteHelpers.scala @@ -47,8 +47,7 @@ object ValidatorParamsSuiteHelpers extends Assertions { val estimatorParamMap2 = Array(estimator2.extractParamMap()) compareParamMaps(estimatorParamMap, estimatorParamMap2) case other => - throw new AssertionError(s"Expected 
parameter of type Params but" + - s" found ${otherParam.getClass.getName}") + fail(s"Expected parameter of type Params but found ${otherParam.getClass.getName}") } case _ => assert(otherParam === v) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index 5ec4c15387e94..8c7d583923b32 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -71,7 +71,7 @@ object NaiveBayesSuite { counts.toArray.sortBy(_._1).map(_._2) case _ => // This should never happen. - throw new UnknownError(s"Invalid modelType: $modelType.") + throw new IllegalArgumentException(s"Invalid modelType: $modelType.") } LabeledPoint(y, Vectors.dense(xi)) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala index 1b98250061c7a..d18cef7e264db 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala @@ -349,7 +349,7 @@ object KMeansSuite extends SparkFunSuite { case (ca: DenseVector, cb: DenseVector) => assert(ca === cb) case _ => - throw new AssertionError("checkEqual failed since the two clusters were not identical.\n") + fail("checkEqual failed since the two clusters were not identical.\n") } } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala index 5394baab94bcf..8779de590a256 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala @@ -18,10 +18,14 @@ package org.apache.spark.mllib.evaluation import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.linalg.Matrices +import org.apache.spark.ml.linalg.Matrices +import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { + + private val delta = 1e-7 + test("Multiclass evaluation metrics") { /* * Confusion matrix for 3-class classification with total 9 instances: @@ -35,7 +39,6 @@ class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { Seq((0.0, 0.0), (0.0, 1.0), (0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)), 2) val metrics = new MulticlassMetrics(predictionAndLabels) - val delta = 0.0000001 val tpRate0 = 2.0 / (2 + 2) val tpRate1 = 3.0 / (3 + 1) val tpRate2 = 1.0 / (1 + 0) @@ -55,41 +58,122 @@ class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { val f2measure1 = (1 + 2 * 2) * precision1 * recall1 / (2 * 2 * precision1 + recall1) val f2measure2 = (1 + 2 * 2) * precision2 * recall2 / (2 * 2 * precision2 + recall2) - assert(metrics.confusionMatrix.toArray.sameElements(confusionMatrix.toArray)) - assert(math.abs(metrics.truePositiveRate(0.0) - tpRate0) < delta) - assert(math.abs(metrics.truePositiveRate(1.0) - tpRate1) < delta) - assert(math.abs(metrics.truePositiveRate(2.0) - tpRate2) < delta) - assert(math.abs(metrics.falsePositiveRate(0.0) - fpRate0) < delta) - assert(math.abs(metrics.falsePositiveRate(1.0) - fpRate1) < delta) - 
assert(math.abs(metrics.falsePositiveRate(2.0) - fpRate2) < delta) - assert(math.abs(metrics.precision(0.0) - precision0) < delta) - assert(math.abs(metrics.precision(1.0) - precision1) < delta) - assert(math.abs(metrics.precision(2.0) - precision2) < delta) - assert(math.abs(metrics.recall(0.0) - recall0) < delta) - assert(math.abs(metrics.recall(1.0) - recall1) < delta) - assert(math.abs(metrics.recall(2.0) - recall2) < delta) - assert(math.abs(metrics.fMeasure(0.0) - f1measure0) < delta) - assert(math.abs(metrics.fMeasure(1.0) - f1measure1) < delta) - assert(math.abs(metrics.fMeasure(2.0) - f1measure2) < delta) - assert(math.abs(metrics.fMeasure(0.0, 2.0) - f2measure0) < delta) - assert(math.abs(metrics.fMeasure(1.0, 2.0) - f2measure1) < delta) - assert(math.abs(metrics.fMeasure(2.0, 2.0) - f2measure2) < delta) + assert(metrics.confusionMatrix.asML ~== confusionMatrix relTol delta) + assert(metrics.truePositiveRate(0.0) ~== tpRate0 relTol delta) + assert(metrics.truePositiveRate(1.0) ~== tpRate1 relTol delta) + assert(metrics.truePositiveRate(2.0) ~== tpRate2 relTol delta) + assert(metrics.falsePositiveRate(0.0) ~== fpRate0 relTol delta) + assert(metrics.falsePositiveRate(1.0) ~== fpRate1 relTol delta) + assert(metrics.falsePositiveRate(2.0) ~== fpRate2 relTol delta) + assert(metrics.precision(0.0) ~== precision0 relTol delta) + assert(metrics.precision(1.0) ~== precision1 relTol delta) + assert(metrics.precision(2.0) ~== precision2 relTol delta) + assert(metrics.recall(0.0) ~== recall0 relTol delta) + assert(metrics.recall(1.0) ~== recall1 relTol delta) + assert(metrics.recall(2.0) ~== recall2 relTol delta) + assert(metrics.fMeasure(0.0) ~== f1measure0 relTol delta) + assert(metrics.fMeasure(1.0) ~== f1measure1 relTol delta) + assert(metrics.fMeasure(2.0) ~== f1measure2 relTol delta) + assert(metrics.fMeasure(0.0, 2.0) ~== f2measure0 relTol delta) + assert(metrics.fMeasure(1.0, 2.0) ~== f2measure1 relTol delta) + assert(metrics.fMeasure(2.0, 2.0) ~== f2measure2 relTol delta) + + assert(metrics.accuracy ~== + (2.0 + 3.0 + 1.0) / ((2 + 3 + 1) + (1 + 1 + 1)) relTol delta) + assert(metrics.accuracy ~== metrics.weightedRecall relTol delta) + val weight0 = 4.0 / 9 + val weight1 = 4.0 / 9 + val weight2 = 1.0 / 9 + assert(metrics.weightedTruePositiveRate ~== + (weight0 * tpRate0 + weight1 * tpRate1 + weight2 * tpRate2) relTol delta) + assert(metrics.weightedFalsePositiveRate ~== + (weight0 * fpRate0 + weight1 * fpRate1 + weight2 * fpRate2) relTol delta) + assert(metrics.weightedPrecision ~== + (weight0 * precision0 + weight1 * precision1 + weight2 * precision2) relTol delta) + assert(metrics.weightedRecall ~== + (weight0 * recall0 + weight1 * recall1 + weight2 * recall2) relTol delta) + assert(metrics.weightedFMeasure ~== + (weight0 * f1measure0 + weight1 * f1measure1 + weight2 * f1measure2) relTol delta) + assert(metrics.weightedFMeasure(2.0) ~== + (weight0 * f2measure0 + weight1 * f2measure1 + weight2 * f2measure2) relTol delta) + assert(metrics.labels === labels) + } + + test("Multiclass evaluation metrics with weights") { + /* + * Confusion matrix for 3-class classification with total 9 instances with 2 weights: + * |2 * w1|1 * w2 |1 * w1| true class0 (4 instances) + * |1 * w2|2 * w1 + 1 * w2|0 | true class1 (4 instances) + * |0 |0 |1 * w2| true class2 (1 instance) + */ + val w1 = 2.2 + val w2 = 1.5 + val tw = 2.0 * w1 + 1.0 * w2 + 1.0 * w1 + 1.0 * w2 + 2.0 * w1 + 1.0 * w2 + 1.0 * w2 + val confusionMatrix = Matrices.dense(3, 3, + Array(2 * w1, 1 * w2, 0, 1 * w2, 2 * w1 + 1 * w2, 0, 1 * w1, 
0, 1 * w2)) + val labels = Array(0.0, 1.0, 2.0) + val predictionAndLabelsWithWeights = sc.parallelize( + Seq((0.0, 0.0, w1), (0.0, 1.0, w2), (0.0, 0.0, w1), (1.0, 0.0, w2), + (1.0, 1.0, w1), (1.0, 1.0, w2), (1.0, 1.0, w1), (2.0, 2.0, w2), + (2.0, 0.0, w1)), 2) + val metrics = new MulticlassMetrics(predictionAndLabelsWithWeights) + val tpRate0 = (2.0 * w1) / (2.0 * w1 + 1.0 * w2 + 1.0 * w1) + val tpRate1 = (2.0 * w1 + 1.0 * w2) / (2.0 * w1 + 1.0 * w2 + 1.0 * w2) + val tpRate2 = (1.0 * w2) / (1.0 * w2 + 0) + val fpRate0 = (1.0 * w2) / (tw - (2.0 * w1 + 1.0 * w2 + 1.0 * w1)) + val fpRate1 = (1.0 * w2) / (tw - (1.0 * w2 + 2.0 * w1 + 1.0 * w2)) + val fpRate2 = (1.0 * w1) / (tw - (1.0 * w2)) + val precision0 = (2.0 * w1) / (2 * w1 + 1 * w2) + val precision1 = (2.0 * w1 + 1.0 * w2) / (2.0 * w1 + 1.0 * w2 + 1.0 * w2) + val precision2 = (1.0 * w2) / (1 * w1 + 1 * w2) + val recall0 = (2.0 * w1) / (2.0 * w1 + 1.0 * w2 + 1.0 * w1) + val recall1 = (2.0 * w1 + 1.0 * w2) / (2.0 * w1 + 1.0 * w2 + 1.0 * w2) + val recall2 = (1.0 * w2) / (1.0 * w2 + 0) + val f1measure0 = 2 * precision0 * recall0 / (precision0 + recall0) + val f1measure1 = 2 * precision1 * recall1 / (precision1 + recall1) + val f1measure2 = 2 * precision2 * recall2 / (precision2 + recall2) + val f2measure0 = (1 + 2 * 2) * precision0 * recall0 / (2 * 2 * precision0 + recall0) + val f2measure1 = (1 + 2 * 2) * precision1 * recall1 / (2 * 2 * precision1 + recall1) + val f2measure2 = (1 + 2 * 2) * precision2 * recall2 / (2 * 2 * precision2 + recall2) + + assert(metrics.confusionMatrix.asML ~== confusionMatrix relTol delta) + assert(metrics.truePositiveRate(0.0) ~== tpRate0 relTol delta) + assert(metrics.truePositiveRate(1.0) ~== tpRate1 relTol delta) + assert(metrics.truePositiveRate(2.0) ~== tpRate2 relTol delta) + assert(metrics.falsePositiveRate(0.0) ~== fpRate0 relTol delta) + assert(metrics.falsePositiveRate(1.0) ~== fpRate1 relTol delta) + assert(metrics.falsePositiveRate(2.0) ~== fpRate2 relTol delta) + assert(metrics.precision(0.0) ~== precision0 relTol delta) + assert(metrics.precision(1.0) ~== precision1 relTol delta) + assert(metrics.precision(2.0) ~== precision2 relTol delta) + assert(metrics.recall(0.0) ~== recall0 relTol delta) + assert(metrics.recall(1.0) ~== recall1 relTol delta) + assert(metrics.recall(2.0) ~== recall2 relTol delta) + assert(metrics.fMeasure(0.0) ~== f1measure0 relTol delta) + assert(metrics.fMeasure(1.0) ~== f1measure1 relTol delta) + assert(metrics.fMeasure(2.0) ~== f1measure2 relTol delta) + assert(metrics.fMeasure(0.0, 2.0) ~== f2measure0 relTol delta) + assert(metrics.fMeasure(1.0, 2.0) ~== f2measure1 relTol delta) + assert(metrics.fMeasure(2.0, 2.0) ~== f2measure2 relTol delta) - assert(math.abs(metrics.accuracy - - (2.0 + 3.0 + 1.0) / ((2 + 3 + 1) + (1 + 1 + 1))) < delta) - assert(math.abs(metrics.accuracy - metrics.weightedRecall) < delta) - assert(math.abs(metrics.weightedTruePositiveRate - - ((4.0 / 9) * tpRate0 + (4.0 / 9) * tpRate1 + (1.0 / 9) * tpRate2)) < delta) - assert(math.abs(metrics.weightedFalsePositiveRate - - ((4.0 / 9) * fpRate0 + (4.0 / 9) * fpRate1 + (1.0 / 9) * fpRate2)) < delta) - assert(math.abs(metrics.weightedPrecision - - ((4.0 / 9) * precision0 + (4.0 / 9) * precision1 + (1.0 / 9) * precision2)) < delta) - assert(math.abs(metrics.weightedRecall - - ((4.0 / 9) * recall0 + (4.0 / 9) * recall1 + (1.0 / 9) * recall2)) < delta) - assert(math.abs(metrics.weightedFMeasure - - ((4.0 / 9) * f1measure0 + (4.0 / 9) * f1measure1 + (1.0 / 9) * f1measure2)) < delta) - 
assert(math.abs(metrics.weightedFMeasure(2.0) - - ((4.0 / 9) * f2measure0 + (4.0 / 9) * f2measure1 + (1.0 / 9) * f2measure2)) < delta) - assert(metrics.labels.sameElements(labels)) + assert(metrics.accuracy ~== + (2.0 * w1 + 2.0 * w1 + 1.0 * w2 + 1.0 * w2) / tw relTol delta) + assert(metrics.accuracy ~== metrics.weightedRecall relTol delta) + val weight0 = (2 * w1 + 1 * w2 + 1 * w1) / tw + val weight1 = (1 * w2 + 2 * w1 + 1 * w2) / tw + val weight2 = 1 * w2 / tw + assert(metrics.weightedTruePositiveRate ~== + (weight0 * tpRate0 + weight1 * tpRate1 + weight2 * tpRate2) relTol delta) + assert(metrics.weightedFalsePositiveRate ~== + (weight0 * fpRate0 + weight1 * fpRate1 + weight2 * fpRate2) relTol delta) + assert(metrics.weightedPrecision ~== + (weight0 * precision0 + weight1 * precision1 + weight2 * precision2) relTol delta) + assert(metrics.weightedRecall ~== + (weight0 * recall0 + weight1 * recall1 + weight2 * recall2) relTol delta) + assert(metrics.weightedFMeasure ~== + (weight0 * f1measure0 + weight1 * f1measure1 + weight2 * f1measure2) relTol delta) + assert(metrics.weightedFMeasure(2.0) ~== + (weight0 * f2measure0 + weight1 * f2measure1 + weight2 * f2measure2) relTol delta) + assert(metrics.labels === labels) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala index d76edb940b2bd..2c3f84617cfa5 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala @@ -511,10 +511,10 @@ class MatricesSuite extends SparkFunSuite { mat.toString(0, 0) mat.toString(Int.MinValue, Int.MinValue) mat.toString(Int.MaxValue, Int.MaxValue) - var lines = mat.toString(6, 50).lines.toArray + var lines = mat.toString(6, 50).split('\n') assert(lines.size == 5 && lines.forall(_.size <= 50)) - lines = mat.toString(5, 100).lines.toArray + lines = mat.toString(5, 100).split('\n') assert(lines.size == 5 && lines.forall(_.size <= 100)) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala index 669d44223d713..5b4a2607f0b25 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala @@ -17,10 +17,11 @@ package org.apache.spark.mllib.stat.distribution -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.mllib.linalg.{Matrices, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.serializer.KryoSerializer class MultivariateGaussianSuite extends SparkFunSuite with MLlibTestSparkContext { test("univariate") { @@ -80,4 +81,23 @@ class MultivariateGaussianSuite extends SparkFunSuite with MLlibTestSparkContext assert(dist.pdf(x) ~== 7.154782224045512E-5 absTol 1E-9) } + test("Kryo class register") { + val conf = new SparkConf(false) + conf.set("spark.kryo.registrationRequired", "true") + + val ser = new KryoSerializer(conf).newInstance() + + val mu = Vectors.dense(0.0, 0.0) + val sigma1 = Matrices.dense(2, 2, Array(1.0, 0.0, 0.0, 1.0)) + val dist1 = new MultivariateGaussian(mu, sigma1) + + val sigma2 = Matrices.dense(2, 2, Array(4.0, -1.0, -1.0, 2.0)) + val dist2 = new 
MultivariateGaussian(mu, sigma2) + + Seq(dist1, dist2).foreach { i => + val i2 = ser.deserialize[MultivariateGaussian](ser.serialize(i)) + assert(i.sigma === i2.sigma) + assert(i.mu === i2.mu) + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index bc59f3f4125fb..34bc303ac6079 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -607,7 +607,7 @@ object DecisionTreeSuite extends SparkFunSuite { checkEqual(a.topNode, b.topNode) } catch { case ex: Exception => - throw new AssertionError("checkEqual failed since the two trees were not identical.\n" + + fail("checkEqual failed since the two trees were not identical.\n" + "TREE A:\n" + a.toDebugString + "\n" + "TREE B:\n" + b.toDebugString + "\n", ex) } @@ -628,20 +628,21 @@ object DecisionTreeSuite extends SparkFunSuite { // TODO: Check other fields besides the information gain. case (Some(aStats), Some(bStats)) => assert(aStats.gain === bStats.gain) case (None, None) => - case _ => throw new AssertionError( - s"Only one instance has stats defined. (a.stats: ${a.stats}, b.stats: ${b.stats})") + case _ => fail(s"Only one instance has stats defined. (a.stats: ${a.stats}, " + + s"b.stats: ${b.stats})") } (a.leftNode, b.leftNode) match { case (Some(aNode), Some(bNode)) => checkEqual(aNode, bNode) case (None, None) => - case _ => throw new AssertionError("Only one instance has leftNode defined. " + - s"(a.leftNode: ${a.leftNode}, b.leftNode: ${b.leftNode})") + case _ => + fail("Only one instance has leftNode defined. (a.leftNode: ${a.leftNode}," + + " b.leftNode: ${b.leftNode})") } (a.rightNode, b.rightNode) match { case (Some(aNode: Node), Some(bNode: Node)) => checkEqual(aNode, bNode) case (None, None) => - case _ => throw new AssertionError("Only one instance has rightNode defined. " + - s"(a.rightNode: ${a.rightNode}, b.rightNode: ${b.rightNode})") + case _ => fail("Only one instance has rightNode defined. 
(a.rightNode: ${a.rightNode}, " + + "b.rightNode: ${b.rightNode})") } } } diff --git a/pom.xml b/pom.xml index 247b77c71928b..0654831b70582 100644 --- a/pom.xml +++ b/pom.xml @@ -25,7 +25,7 @@ 18 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT pom Spark Project Parent POM @@ -116,14 +116,14 @@ 1.8 ${java.version} ${java.version} - 3.5.4 + 3.6.0 spark 1.7.25 1.2.17 - 2.7.3 + 2.7.4 2.5.0 ${hadoop.version} - 3.4.6 + 3.4.7 2.7.1 org.spark-project.hive @@ -164,8 +164,8 @@ 3.6.1 3.2.2 - 2.11.12 - 2.11 + 2.12.7 + 2.12 1.9.13 2.9.7 1.1.7.2 @@ -176,8 +176,8 @@ 2.6 - 3.8 - 1.8.1 + 3.8.1 + 1.18 1.6 3.2.10 1.1.1 @@ -188,7 +188,7 @@ 3.5.2 3.0.2 0.9.3 - 4.7 + 4.7.1 3.4 1.1 2.52.0 @@ -482,6 +482,11 @@ commons-lang3 ${commons-lang3.version} + + org.apache.commons + commons-text + 1.6 + commons-lang commons-lang @@ -2264,7 +2269,7 @@ org.apache.maven.plugins maven-enforcer-plugin - 3.0.0-M1 + 3.0.0-M2 enforce-versions @@ -2290,6 +2295,7 @@ --> org.jboss.netty org.codehaus.groovy + *:*_2.11 *:*_2.10 true @@ -2307,8 +2313,7 @@ net.alchim31.maven scala-maven-plugin - - 3.2.2 + 3.4.4 eclipse-add-source @@ -2328,6 +2333,13 @@ testCompile + + attach-scaladocs + verify + + doc-jar + + ${scala.version} @@ -2357,7 +2369,7 @@ org.apache.maven.plugins maven-compiler-plugin - 3.7.0 + 3.8.0 ${java.version} ${java.version} @@ -2374,7 +2386,7 @@ org.apache.maven.plugins maven-surefire-plugin - 2.22.0 + 3.0.0-M1 @@ -2428,7 +2440,7 @@ org.scalatest scalatest-maven-plugin - 1.0 + 2.0.0 ${project.build.directory}/surefire-reports @@ -2475,7 +2487,7 @@ org.apache.maven.plugins maven-jar-plugin - 3.0.2 + 3.1.0 org.apache.maven.plugins @@ -2502,7 +2514,7 @@ org.apache.maven.plugins maven-clean-plugin - 3.0.0 + 3.1.0 @@ -2520,9 +2532,12 @@ org.apache.maven.plugins maven-javadoc-plugin - 3.0.0-M1 + 3.0.1 - -Xdoclint:all -Xdoclint:-missing + + -Xdoclint:all + -Xdoclint:-missing + example @@ -2573,22 +2588,34 @@ org.apache.maven.plugins maven-shade-plugin - 3.1.0 + 3.2.1 + + + org.ow2.asm + asm + 7.0 + + + org.ow2.asm + asm-commons + 7.0 + + org.apache.maven.plugins maven-install-plugin - 2.5.2 + 3.0.0-M1 org.apache.maven.plugins maven-deploy-plugin - 2.8.2 + 3.0.0-M1 org.apache.maven.plugins maven-dependency-plugin - 3.0.2 + 3.1.1 default-cli @@ -2629,7 +2656,7 @@ org.apache.maven.plugins maven-jar-plugin - [2.6,) + 3.1.0 test-jar @@ -2786,12 +2813,17 @@ org.apache.maven.plugins maven-checkstyle-plugin - 2.17 + 3.0.0 false true - ${basedir}/src/main/java,${basedir}/src/main/scala - ${basedir}/src/test/java + + ${basedir}/src/main/java + ${basedir}/src/main/scala + + + ${basedir}/src/test/java + dev/checkstyle.xml ${basedir}/target/checkstyle-output.xml ${project.build.sourceEncoding} @@ -2801,7 +2833,7 @@ com.puppycrawl.tools checkstyle - 8.2 + 8.14 @@ -3027,14 +3059,14 @@ - scala-2.11 + scala-2.12 - scala-2.12 + scala-2.11 - 2.12.7 - 2.12 + 2.11.12 + 2.11 @@ -3050,8 +3082,9 @@ - - *:*_2.11 + + org.jboss.netty + org.codehaus.groovy *:*_2.10 diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala index adde213e361f0..10c02103aeddb 100644 --- a/project/MimaBuild.scala +++ b/project/MimaBuild.scala @@ -88,9 +88,9 @@ object MimaBuild { def mimaSettings(sparkHome: File, projectRef: ProjectRef) = { val organization = "org.apache.spark" - val previousSparkVersion = "2.2.0" + val previousSparkVersion = "2.4.0" val project = projectRef.project - val fullId = "spark-" + project + "_2.11" + val fullId = "spark-" + project + "_2.12" mimaDefaultSettings ++ Seq(mimaPreviousArtifacts := Set(organization % fullId % 
previousSparkVersion), mimaBinaryIssueFilters ++= ignoredABIProblems(sparkHome, version.value)) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index f2e943cae6117..842730e7deb13 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,32 @@ object MimaExcludes { // Exclude rules for 3.0.x lazy val v30excludes = v24excludes ++ Seq( + // [SPARK-26124] Update plugins, including MiMa + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsPushDownRequiredColumns.build"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsReportStatistics.fullSchema"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsReportStatistics.planInputPartitions"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsReportPartitioning.fullSchema"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsReportPartitioning.planInputPartitions"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsPushDownFilters.build"), + + // [SPARK-26090] Resolve most miscellaneous deprecation and build warnings for Spark 3 + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.stat.test.BinarySampleBeanInfo"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.regression.LabeledPointBeanInfo"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ml.feature.LabeledPointBeanInfo"), + + // [SPARK-25959] GBTClassifier picks wrong impurity stats on loading + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setImpurity"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setImpurity"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setImpurity"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setImpurity"), + + // [SPARK-25908][CORE][SQL] Remove old deprecated items in Spark 3 + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.BarrierTaskContext.isRunningLocally"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskContext.isRunningLocally"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleWriteMetrics.shuffleBytesWritten"), 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleWriteMetrics.shuffleWriteTime"), @@ -50,10 +76,13 @@ object MimaExcludes { ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.evaluation.MulticlassMetrics.precision"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.util.MLWriter.context"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.util.MLReader.context"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.util.GeneralMLWriter.context"), + // [SPARK-25737] Remove JavaSparkContextVarargsWorkaround ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.api.java.JavaSparkContext"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.union"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.streaming.api.java.JavaStreamingContext.union"), + // [SPARK-16775] Remove deprecated accumulator v1 APIs ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.Accumulable"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.AccumulatorParam"), @@ -73,14 +102,61 @@ object MimaExcludes { ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.accumulable"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.doubleAccumulator"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.accumulator"), + // [SPARK-24109] Remove class SnappyOutputStreamWrapper ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.io.SnappyCompressionCodec.version"), + // [SPARK-19287] JavaPairRDD flatMapValues requires function returning Iterable, not Iterator ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.api.java.JavaPairRDD.flatMapValues"), ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.streaming.api.java.JavaPairDStream.flatMapValues"), + // [SPARK-25680] SQL execution listener shouldn't happen on execution thread ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.util.ExecutionListenerManager.clone"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.util.ExecutionListenerManager.this") + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.util.ExecutionListenerManager.this"), + + // [SPARK-23781][CORE] Merge token renewer functionality into HadoopDelegationTokenManager + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.nextCredentialRenewalTime"), + + // Data Source V2 API changes + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.ContinuousReadSupport"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.ReadSupport"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.WriteSupport"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.StreamWriteSupport"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.MicroBatchReadSupport"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.reader.SupportsScanColumnarBatch"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.reader.DataSourceReader"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.sources.v2.reader.SupportsPushDownRequiredColumns"), + 
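For context on the entries being added here: each element of these Seqs is a MiMa problem filter. The common form is ProblemFilters.exclude, parameterized by the kind of reported incompatibility and the fully qualified member it should silence; a plain predicate over Problem is also accepted, as the JPMML filter further down shows. A minimal sketch of both forms, assuming the usual mima-core import and using a hypothetical class name:

import com.typesafe.tools.mima.core._

object ExampleExcludes {
  val excludes = Seq(
    // Common form: silence one specific report, keyed by the incompatibility
    // kind and the fully qualified member. SomeClass.removedMethod is a
    // hypothetical name, not a real Spark API.
    ProblemFilters.exclude[DirectMissingMethodProblem](
      "org.apache.spark.example.SomeClass.removedMethod"),

    // General form: an ordinary predicate over Problem, the same pattern used
    // below for the shaded JPMML classes.
    (problem: Problem) => problem match {
      case MissingClassProblem(cls) =>
        !cls.fullName.startsWith("org.spark_project.example")
      case _ => true
    }
  )
}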
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.sql.sources.v2.reader.ScanConfigBuilder.build"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.sources.v2.reader.InputPartition.createPartitionReader"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.sources.v2.reader.SupportsReportStatistics"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsReportStatistics.estimateStatistics"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.sql.sources.v2.reader.ReadSupport.fullSchema"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.sql.sources.v2.reader.ReadSupport.planInputPartitions"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.sources.v2.reader.SupportsReportPartitioning"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsReportPartitioning.outputPartitioning"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsReportPartitioning.outputPartitioning"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.sql.sources.v2.reader.ReadSupport.fullSchema"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.sql.sources.v2.reader.ReadSupport.planInputPartitions"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.sources.v2.reader.SupportsPushDownFilters"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.sql.sources.v2.reader.ScanConfigBuilder.build"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.reader.ContinuousInputPartition"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.reader.InputPartitionReader"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.reader.streaming.ContinuousInputPartitionReader"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.writer.DataSourceWriter"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.v2.writer.DataWriterFactory.createWriter"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter"), + + // [SPARK-26141] Enable custom metrics implementation in shuffle write + // Following are Java private classes + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.shuffle.sort.UnsafeShuffleWriter.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.storage.TimeTrackingOutputStream.this"), + + // SafeLogging after MimaUpgrade + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkContext.initializeLogIfNecessary$default$2"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkContext.initializeLogIfNecessary$default$2"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.broadcast.Broadcast.initializeLogIfNecessary$default$2") ) // Exclude rules for 2.4.x @@ -214,7 +290,50 @@ object MimaExcludes { 
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.validationIndicatorCol"), // [SPARK-23042] Use OneHotEncoderModel to encode labels in MultilayerPerceptronClassifier - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ml.classification.LabelConverter") + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ml.classification.LabelConverter"), + + // [SPARK-21842][MESOS] Support Kerberos ticket renewal and creation in Mesos + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.getDateOfNextUpdate"), + + // [SPARK-23366] Improve hot reading path in ReadAheadInputStream + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.io.ReadAheadInputStream.this"), + + // [SPARK-22941][CORE] Do not exit JVM when submit fails with in-process launcher. + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.addJarToClasspath"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.mergeFileLists"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.prepareSubmitEnvironment$default$2"), + + // Data Source V2 API changes + // TODO: they are unstable APIs and should not be tracked by mima. + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.ReadSupportWithSchema"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsScanColumnarBatch.createDataReaderFactories"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsScanColumnarBatch.createBatchDataReaderFactories"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsScanColumnarBatch.planBatchInputPartitions"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.reader.SupportsScanUnsafeRow"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.sources.v2.reader.DataSourceReader.createDataReaderFactories"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.v2.reader.DataSourceReader.planInputPartitions"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.reader.SupportsPushDownCatalystFilters"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.reader.DataReader"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsReportStatistics.getStatistics"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsReportStatistics.estimateStatistics"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.reader.DataReaderFactory"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.reader.streaming.ContinuousDataReader"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.writer.SupportsWriteInternalRow"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.sources.v2.writer.DataWriterFactory.createDataWriter"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.v2.writer.DataWriterFactory.createDataWriter"), + + // Changes to HasRawPredictionCol. 
+ ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasRawPredictionCol.rawPredictionCol"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasRawPredictionCol.org$apache$spark$ml$param$shared$HasRawPredictionCol$_setter_$rawPredictionCol_="), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasRawPredictionCol.getRawPredictionCol"), + + // [SPARK-15526][ML][FOLLOWUP] Make JPMML provided scope to avoid including unshaded JARs + (problem: Problem) => problem match { + case MissingClassProblem(cls) => + !cls.fullName.startsWith("org.spark_project.jpmml") && + !cls.fullName.startsWith("org.spark_project.dmg.pmml") + case _ => true + } ) // Exclude rules for 2.3.x diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 2f11c5deda217..639d642b78d2a 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -17,6 +17,7 @@ import java.io._ import java.nio.file.Files +import java.util.Locale import scala.io.Source import scala.util.Properties @@ -95,15 +96,15 @@ object SparkBuild extends PomBuild { } Option(System.getProperty("scala.version")) - .filter(_.startsWith("2.12")) + .filter(_.startsWith("2.11")) .foreach { versionString => - System.setProperty("scala-2.12", "true") + System.setProperty("scala-2.11", "true") } - if (System.getProperty("scala-2.12") == "") { + if (System.getProperty("scala-2.11") == "") { // To activate scala-2.10 profile, replace empty property value to non-empty value // in the same way as Maven which handles -Dname as -Dname=true before executes build process. // see: https://github.com/apache/maven/blob/maven-3.0.4/maven-embedder/src/main/java/org/apache/maven/cli/MavenCli.java#L1082 - System.setProperty("scala-2.12", "true") + System.setProperty("scala-2.11", "true") } profiles } @@ -522,7 +523,8 @@ object KubernetesIntegrationTests { s"-Dspark.kubernetes.test.unpackSparkDir=$sparkHome" ), // Force packaging before building images, so that the latest code is tested. 
- dockerBuild := dockerBuild.dependsOn(packageBin in Compile in assembly).value + dockerBuild := dockerBuild.dependsOn(packageBin in Compile in assembly) + .dependsOn(packageBin in Compile in examples).value ) } @@ -657,10 +659,13 @@ object Assembly { }, jarName in (Test, assembly) := s"${moduleName.value}-test-${version.value}.jar", mergeStrategy in assembly := { - case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard - case m if m.toLowerCase.matches("meta-inf.*\\.sf$") => MergeStrategy.discard + case m if m.toLowerCase(Locale.ROOT).endsWith("manifest.mf") + => MergeStrategy.discard + case m if m.toLowerCase(Locale.ROOT).matches("meta-inf.*\\.sf$") + => MergeStrategy.discard case "log4j.properties" => MergeStrategy.discard - case m if m.toLowerCase.startsWith("meta-inf/services/") => MergeStrategy.filterDistinctLines + case m if m.toLowerCase(Locale.ROOT).startsWith("meta-inf/services/") + => MergeStrategy.filterDistinctLines case "reference.conf" => MergeStrategy.concat case _ => MergeStrategy.first } @@ -852,10 +857,10 @@ object TestSettings { import BuildCommons._ private val scalaBinaryVersion = - if (System.getProperty("scala-2.12") == "true") { - "2.12" - } else { + if (System.getProperty("scala-2.11") == "true") { "2.11" + } else { + "2.12" } lazy val settings = Seq ( // Fork new JVMs for tests and set Java options for those diff --git a/project/plugins.sbt b/project/plugins.sbt index da4fd524d6682..e2cfbe7bd6d4b 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -3,7 +3,7 @@ addSbtPlugin("com.typesafe.sbt" % "sbt-git" % "0.8.5") addSbtPlugin("com.etsy" % "sbt-checkstyle-plugin" % "3.1.1") // sbt-checkstyle-plugin uses an old version of checkstyle. Match it to Maven's. -libraryDependencies += "com.puppycrawl.tools" % "checkstyle" % "8.2" +libraryDependencies += "com.puppycrawl.tools" % "checkstyle" % "8.14" // checkstyle uses guava 23.0. libraryDependencies += "com.google.guava" % "guava" % "23.0" @@ -11,13 +11,13 @@ libraryDependencies += "com.google.guava" % "guava" % "23.0" // need to make changes to uptake sbt 1.0 support in "com.eed3si9n" % "sbt-assembly" % "1.14.5" addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.11.2") -addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "5.2.3") +addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "5.2.4") -addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.9.0") +addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.9.2") addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "1.0.0") -addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.17") +addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.3.0") // sbt 1.0.0 support: https://github.com/AlpineNow/junit_xml_listener/issues/6 addSbtPlugin("com.alpinenow" % "junit_xml_listener" % "0.5.1") @@ -30,12 +30,12 @@ addSbtPlugin("com.cavorite" % "sbt-avro" % "0.3.2") addSbtPlugin("io.spray" % "sbt-revolver" % "0.9.1") -libraryDependencies += "org.ow2.asm" % "asm" % "5.1" +libraryDependencies += "org.ow2.asm" % "asm" % "7.0" -libraryDependencies += "org.ow2.asm" % "asm-commons" % "5.1" +libraryDependencies += "org.ow2.asm" % "asm-commons" % "7.0" // sbt 1.0.0 support: https://github.com/ihji/sbt-antlr4/issues/14 -addSbtPlugin("com.simplytyped" % "sbt-antlr4" % "0.7.11") +addSbtPlugin("com.simplytyped" % "sbt-antlr4" % "0.7.12") // Spark uses a custom fork of the sbt-pom-reader plugin which contains a patch to fix issues // related to test-jar dependencies (https://github.com/sbt/sbt-pom-reader/pull/14). 
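The Assembly mergeStrategy hunk above switches from bare toLowerCase to toLowerCase(Locale.ROOT). The usual motivation is locale-sensitive case mapping: under a Turkish default locale an uppercase 'I' lowercases to the dotless 'ı' (U+0131), so suffix checks against names such as manifest.mf can stop matching. A small standalone sketch of the difference (not part of the patch):

import java.util.Locale

object LocaleLowercase {
  def main(args: Array[String]): Unit = {
    val entry = "META-INF/MANIFEST.MF"

    // Locale-sensitive mapping: under Turkish rules 'I' becomes the dotless
    // 'ı', so the manifest suffix check below would silently fail.
    println(entry.toLowerCase(new Locale("tr")))  // meta-ınf/manıfest.mf
    println(entry.toLowerCase(Locale.ROOT))       // meta-inf/manifest.mf

    assert(entry.toLowerCase(Locale.ROOT).endsWith("manifest.mf"))
    assert(!entry.toLowerCase(new Locale("tr")).endsWith("manifest.mf"))
  }
}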
The source for diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index b37129428f491..aaeeeb82d3d86 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -388,7 +388,7 @@ class KMeans(JavaEstimator, HasDistanceMeasure, HasFeaturesCol, HasPredictionCol >>> len(centers) 2 >>> model.computeCost(df) - 2.000... + 2.0 >>> transformed = model.transform(df).select("features", "prediction") >>> rows = transformed.collect() >>> rows[0].prediction == rows[1].prediction @@ -403,7 +403,7 @@ class KMeans(JavaEstimator, HasDistanceMeasure, HasFeaturesCol, HasPredictionCol >>> summary.clusterSizes [2, 2] >>> summary.trainingCost - 2.000... + 2.0 >>> kmeans_path = temp_path + "/kmeans" >>> kmeans.save(kmeans_path) >>> kmeans2 = KMeans.load(kmeans_path) @@ -595,7 +595,7 @@ class BisectingKMeans(JavaEstimator, HasDistanceMeasure, HasFeaturesCol, HasPred >>> len(centers) 2 >>> model.computeCost(df) - 2.000... + 2.0 >>> model.hasSummary True >>> summary = model.summary diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index eccb7acae5b98..3d23700242594 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -361,8 +361,9 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasHandleInvalid, "splits specified will be treated as errors.", typeConverter=TypeConverters.toListFloat) - handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. " + - "Options are 'skip' (filter out rows with invalid values), " + + handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries " + "containing NaN values. Values outside the splits will always be treated " + "as errors. Options are 'skip' (filter out rows with invalid values), " + "'error' (throw an error), or 'keep' (keep invalid values in a special " + "additional bucket).", typeConverter=TypeConverters.toString) diff --git a/python/pyspark/ml/fpm.py b/python/pyspark/ml/fpm.py index 886ad8409ca66..734763ebd3fa6 100644 --- a/python/pyspark/ml/fpm.py +++ b/python/pyspark/ml/fpm.py @@ -167,8 +167,8 @@ class FPGrowth(JavaEstimator, HasItemsCol, HasPredictionCol, independent group of mining tasks. The FP-Growth algorithm is described in Han et al., Mining frequent patterns without candidate generation [HAN2000]_ - .. [LI2008] http://dx.doi.org/10.1145/1454008.1454027 - .. [HAN2000] http://dx.doi.org/10.1145/335191.335372 + .. [LI2008] https://doi.org/10.1145/1454008.1454027 + .. [HAN2000] https://doi.org/10.1145/335191.335372 .. note:: null values in the feature column are ignored during fit(). .. note:: Internally `transform` `collects` and `broadcasts` association rules. @@ -254,7 +254,7 @@ class PrefixSpan(JavaParams): A parallel PrefixSpan algorithm to mine frequent sequential patterns. The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns Efficiently by Prefix-Projected Pattern Growth - (see here). + (see here). This class is not yet an Estimator/Transformer, use :py:func:`findFrequentSequentialPatterns` method to run the PrefixSpan algorithm. 
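The doctest updates above pin computeCost and summary.trainingCost to an exact 2.0. On the Scala side the same quantity, the sum of squared distances from each point to its closest center, is exposed through the model summary. A minimal spark-shell style sketch, assuming an active SparkSession named spark; the four sample points are an assumption chosen so that each point sits at squared distance 0.5 from its cluster center, giving a total cost of 2.0:

import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.ml.linalg.Vectors

// Two well-separated pairs of points; with k = 2 the centers converge to the
// pair midpoints, and each point contributes 0.5 to the cost.
val data = Seq(
  Vectors.dense(0.0, 0.0), Vectors.dense(1.0, 1.0),
  Vectors.dense(9.0, 8.0), Vectors.dense(8.0, 9.0))
val df = spark.createDataFrame(data.map(Tuple1.apply)).toDF("features")

val model = new KMeans().setK(2).setSeed(1L).fit(df)
println(model.summary.trainingCost)  // 4 * 0.5 = 2.0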
diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index a8eae9bd268d3..520d7912c1a10 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -57,7 +57,7 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha For implicit preference data, the algorithm used is based on `"Collaborative Filtering for Implicit Feedback Datasets", - `_, adapted for the blocked + `_, adapted for the blocked approach used here. Essentially instead of finding the low-rank approximations to the diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py deleted file mode 100755 index 8c4f02dd724b4..0000000000000 --- a/python/pyspark/ml/tests.py +++ /dev/null @@ -1,2761 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -""" -Unit tests for MLlib Python DataFrame-based APIs. -""" -import sys - -import unishark - -if sys.version > '3': - xrange = range - basestring = str - -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest - -from shutil import rmtree -import tempfile -import array as pyarray -import numpy as np -from numpy import abs, all, arange, array, array_equal, inf, ones, tile, zeros -import inspect -import py4j - -from pyspark import keyword_only, SparkContext -from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer, UnaryTransformer -from pyspark.ml.classification import * -from pyspark.ml.clustering import * -from pyspark.ml.common import _java2py, _py2java -from pyspark.ml.evaluation import BinaryClassificationEvaluator, ClusteringEvaluator, \ - MulticlassClassificationEvaluator, RegressionEvaluator -from pyspark.ml.feature import * -from pyspark.ml.fpm import FPGrowth, FPGrowthModel -from pyspark.ml.image import ImageSchema -from pyspark.ml.linalg import DenseMatrix, DenseMatrix, DenseVector, Matrices, MatrixUDT, \ - SparseMatrix, SparseVector, Vector, VectorUDT, Vectors -from pyspark.ml.param import Param, Params, TypeConverters -from pyspark.ml.param.shared import HasInputCol, HasMaxIter, HasSeed -from pyspark.ml.recommendation import ALS -from pyspark.ml.regression import DecisionTreeRegressor, GeneralizedLinearRegression, \ - LinearRegression -from pyspark.ml.stat import ChiSquareTest -from pyspark.ml.tuning import * -from pyspark.ml.util import * -from pyspark.ml.wrapper import JavaParams, JavaWrapper -from pyspark.serializers import PickleSerializer -from pyspark.sql import DataFrame, Row, SparkSession, HiveContext -from pyspark.sql.functions import rand -from pyspark.sql.types import DoubleType, IntegerType -from pyspark.storagelevel 
import * -from pyspark.tests import QuietTest, ReusedPySparkTestCase as PySparkTestCase - -ser = PickleSerializer() - - -class MLlibTestCase(unittest.TestCase): - def setUp(self): - self.sc = SparkContext('local[4]', "MLlib tests") - self.spark = SparkSession(self.sc) - - def tearDown(self): - self.spark.stop() - - -class SparkSessionTestCase(PySparkTestCase): - @classmethod - def setUpClass(cls): - PySparkTestCase.setUpClass() - cls.spark = SparkSession(cls.sc) - - @classmethod - def tearDownClass(cls): - PySparkTestCase.tearDownClass() - cls.spark.stop() - - -class MockDataset(DataFrame): - - def __init__(self): - self.index = 0 - - -class HasFake(Params): - - def __init__(self): - super(HasFake, self).__init__() - self.fake = Param(self, "fake", "fake param") - - def getFake(self): - return self.getOrDefault(self.fake) - - -class MockTransformer(Transformer, HasFake): - - def __init__(self): - super(MockTransformer, self).__init__() - self.dataset_index = None - - def _transform(self, dataset): - self.dataset_index = dataset.index - dataset.index += 1 - return dataset - - -class MockUnaryTransformer(UnaryTransformer, DefaultParamsReadable, DefaultParamsWritable): - - shift = Param(Params._dummy(), "shift", "The amount by which to shift " + - "data in a DataFrame", - typeConverter=TypeConverters.toFloat) - - def __init__(self, shiftVal=1): - super(MockUnaryTransformer, self).__init__() - self._setDefault(shift=1) - self._set(shift=shiftVal) - - def getShift(self): - return self.getOrDefault(self.shift) - - def setShift(self, shift): - self._set(shift=shift) - - def createTransformFunc(self): - shiftVal = self.getShift() - return lambda x: x + shiftVal - - def outputDataType(self): - return DoubleType() - - def validateInputType(self, inputType): - if inputType != DoubleType(): - raise TypeError("Bad input type: {}. 
".format(inputType) + - "Requires Double.") - - -class MockEstimator(Estimator, HasFake): - - def __init__(self): - super(MockEstimator, self).__init__() - self.dataset_index = None - - def _fit(self, dataset): - self.dataset_index = dataset.index - model = MockModel() - self._copyValues(model) - return model - - -class MockModel(MockTransformer, Model, HasFake): - pass - - -class JavaWrapperMemoryTests(SparkSessionTestCase): - - def test_java_object_gets_detached(self): - df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), - (0.0, 2.0, Vectors.sparse(1, [], []))], - ["label", "weight", "features"]) - lr = LinearRegression(maxIter=1, regParam=0.0, solver="normal", weightCol="weight", - fitIntercept=False) - - model = lr.fit(df) - summary = model.summary - - self.assertIsInstance(model, JavaWrapper) - self.assertIsInstance(summary, JavaWrapper) - self.assertIsInstance(model, JavaParams) - self.assertNotIsInstance(summary, JavaParams) - - error_no_object = 'Target Object ID does not exist for this gateway' - - self.assertIn("LinearRegression_", model._java_obj.toString()) - self.assertIn("LinearRegressionTrainingSummary", summary._java_obj.toString()) - - model.__del__() - - with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object): - model._java_obj.toString() - self.assertIn("LinearRegressionTrainingSummary", summary._java_obj.toString()) - - try: - summary.__del__() - except: - pass - - with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object): - model._java_obj.toString() - with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object): - summary._java_obj.toString() - - -class ParamTypeConversionTests(PySparkTestCase): - """ - Test that param type conversion happens. - """ - - def test_int(self): - lr = LogisticRegression(maxIter=5.0) - self.assertEqual(lr.getMaxIter(), 5) - self.assertTrue(type(lr.getMaxIter()) == int) - self.assertRaises(TypeError, lambda: LogisticRegression(maxIter="notAnInt")) - self.assertRaises(TypeError, lambda: LogisticRegression(maxIter=5.1)) - - def test_float(self): - lr = LogisticRegression(tol=1) - self.assertEqual(lr.getTol(), 1.0) - self.assertTrue(type(lr.getTol()) == float) - self.assertRaises(TypeError, lambda: LogisticRegression(tol="notAFloat")) - - def test_vector(self): - ewp = ElementwiseProduct(scalingVec=[1, 3]) - self.assertEqual(ewp.getScalingVec(), DenseVector([1.0, 3.0])) - ewp = ElementwiseProduct(scalingVec=np.array([1.2, 3.4])) - self.assertEqual(ewp.getScalingVec(), DenseVector([1.2, 3.4])) - self.assertRaises(TypeError, lambda: ElementwiseProduct(scalingVec=["a", "b"])) - - def test_list(self): - l = [0, 1] - for lst_like in [l, np.array(l), DenseVector(l), SparseVector(len(l), - range(len(l)), l), pyarray.array('l', l), xrange(2), tuple(l)]: - converted = TypeConverters.toList(lst_like) - self.assertEqual(type(converted), list) - self.assertListEqual(converted, l) - - def test_list_int(self): - for indices in [[1.0, 2.0], np.array([1.0, 2.0]), DenseVector([1.0, 2.0]), - SparseVector(2, {0: 1.0, 1: 2.0}), xrange(1, 3), (1.0, 2.0), - pyarray.array('d', [1.0, 2.0])]: - vs = VectorSlicer(indices=indices) - self.assertListEqual(vs.getIndices(), [1, 2]) - self.assertTrue(all([type(v) == int for v in vs.getIndices()])) - self.assertRaises(TypeError, lambda: VectorSlicer(indices=["a", "b"])) - - def test_list_float(self): - b = Bucketizer(splits=[1, 4]) - self.assertEqual(b.getSplits(), [1.0, 4.0]) - self.assertTrue(all([type(v) == float for v in b.getSplits()])) - self.assertRaises(TypeError, 
lambda: Bucketizer(splits=["a", 1.0])) - - def test_list_string(self): - for labels in [np.array(['a', u'b']), ['a', u'b'], np.array(['a', 'b'])]: - idx_to_string = IndexToString(labels=labels) - self.assertListEqual(idx_to_string.getLabels(), ['a', 'b']) - self.assertRaises(TypeError, lambda: IndexToString(labels=['a', 2])) - - def test_string(self): - lr = LogisticRegression() - for col in ['features', u'features', np.str_('features')]: - lr.setFeaturesCol(col) - self.assertEqual(lr.getFeaturesCol(), 'features') - self.assertRaises(TypeError, lambda: LogisticRegression(featuresCol=2.3)) - - def test_bool(self): - self.assertRaises(TypeError, lambda: LogisticRegression(fitIntercept=1)) - self.assertRaises(TypeError, lambda: LogisticRegression(fitIntercept="false")) - - -class PipelineTests(PySparkTestCase): - - def test_pipeline(self): - dataset = MockDataset() - estimator0 = MockEstimator() - transformer1 = MockTransformer() - estimator2 = MockEstimator() - transformer3 = MockTransformer() - pipeline = Pipeline(stages=[estimator0, transformer1, estimator2, transformer3]) - pipeline_model = pipeline.fit(dataset, {estimator0.fake: 0, transformer1.fake: 1}) - model0, transformer1, model2, transformer3 = pipeline_model.stages - self.assertEqual(0, model0.dataset_index) - self.assertEqual(0, model0.getFake()) - self.assertEqual(1, transformer1.dataset_index) - self.assertEqual(1, transformer1.getFake()) - self.assertEqual(2, dataset.index) - self.assertIsNone(model2.dataset_index, "The last model shouldn't be called in fit.") - self.assertIsNone(transformer3.dataset_index, - "The last transformer shouldn't be called in fit.") - dataset = pipeline_model.transform(dataset) - self.assertEqual(2, model0.dataset_index) - self.assertEqual(3, transformer1.dataset_index) - self.assertEqual(4, model2.dataset_index) - self.assertEqual(5, transformer3.dataset_index) - self.assertEqual(6, dataset.index) - - def test_identity_pipeline(self): - dataset = MockDataset() - - def doTransform(pipeline): - pipeline_model = pipeline.fit(dataset) - return pipeline_model.transform(dataset) - # check that empty pipeline did not perform any transformation - self.assertEqual(dataset.index, doTransform(Pipeline(stages=[])).index) - # check that failure to set stages param will raise KeyError for missing param - self.assertRaises(KeyError, lambda: doTransform(Pipeline())) - - -class TestParams(HasMaxIter, HasInputCol, HasSeed): - """ - A subclass of Params mixed with HasMaxIter, HasInputCol and HasSeed. - """ - @keyword_only - def __init__(self, seed=None): - super(TestParams, self).__init__() - self._setDefault(maxIter=10) - kwargs = self._input_kwargs - self.setParams(**kwargs) - - @keyword_only - def setParams(self, seed=None): - """ - setParams(self, seed=None) - Sets params for this test. - """ - kwargs = self._input_kwargs - return self._set(**kwargs) - - -class OtherTestParams(HasMaxIter, HasInputCol, HasSeed): - """ - A subclass of Params mixed with HasMaxIter, HasInputCol and HasSeed. - """ - @keyword_only - def __init__(self, seed=None): - super(OtherTestParams, self).__init__() - self._setDefault(maxIter=10) - kwargs = self._input_kwargs - self.setParams(**kwargs) - - @keyword_only - def setParams(self, seed=None): - """ - setParams(self, seed=None) - Sets params for this test. 
- """ - kwargs = self._input_kwargs - return self._set(**kwargs) - - -class HasThrowableProperty(Params): - - def __init__(self): - super(HasThrowableProperty, self).__init__() - self.p = Param(self, "none", "empty param") - - @property - def test_property(self): - raise RuntimeError("Test property to raise error when invoked") - - -class ParamTests(SparkSessionTestCase): - - def test_copy_new_parent(self): - testParams = TestParams() - # Copying an instantiated param should fail - with self.assertRaises(ValueError): - testParams.maxIter._copy_new_parent(testParams) - # Copying a dummy param should succeed - TestParams.maxIter._copy_new_parent(testParams) - maxIter = testParams.maxIter - self.assertEqual(maxIter.name, "maxIter") - self.assertEqual(maxIter.doc, "max number of iterations (>= 0).") - self.assertTrue(maxIter.parent == testParams.uid) - - def test_param(self): - testParams = TestParams() - maxIter = testParams.maxIter - self.assertEqual(maxIter.name, "maxIter") - self.assertEqual(maxIter.doc, "max number of iterations (>= 0).") - self.assertTrue(maxIter.parent == testParams.uid) - - def test_hasparam(self): - testParams = TestParams() - self.assertTrue(all([testParams.hasParam(p.name) for p in testParams.params])) - self.assertFalse(testParams.hasParam("notAParameter")) - self.assertTrue(testParams.hasParam(u"maxIter")) - - def test_resolveparam(self): - testParams = TestParams() - self.assertEqual(testParams._resolveParam(testParams.maxIter), testParams.maxIter) - self.assertEqual(testParams._resolveParam("maxIter"), testParams.maxIter) - - self.assertEqual(testParams._resolveParam(u"maxIter"), testParams.maxIter) - if sys.version_info[0] >= 3: - # In Python 3, it is allowed to get/set attributes with non-ascii characters. - e_cls = AttributeError - else: - e_cls = UnicodeEncodeError - self.assertRaises(e_cls, lambda: testParams._resolveParam(u"아")) - - def test_params(self): - testParams = TestParams() - maxIter = testParams.maxIter - inputCol = testParams.inputCol - seed = testParams.seed - - params = testParams.params - self.assertEqual(params, [inputCol, maxIter, seed]) - - self.assertTrue(testParams.hasParam(maxIter.name)) - self.assertTrue(testParams.hasDefault(maxIter)) - self.assertFalse(testParams.isSet(maxIter)) - self.assertTrue(testParams.isDefined(maxIter)) - self.assertEqual(testParams.getMaxIter(), 10) - testParams.setMaxIter(100) - self.assertTrue(testParams.isSet(maxIter)) - self.assertEqual(testParams.getMaxIter(), 100) - - self.assertTrue(testParams.hasParam(inputCol.name)) - self.assertFalse(testParams.hasDefault(inputCol)) - self.assertFalse(testParams.isSet(inputCol)) - self.assertFalse(testParams.isDefined(inputCol)) - with self.assertRaises(KeyError): - testParams.getInputCol() - - otherParam = Param(Params._dummy(), "otherParam", "Parameter used to test that " + - "set raises an error for a non-member parameter.", - typeConverter=TypeConverters.toString) - with self.assertRaises(ValueError): - testParams.set(otherParam, "value") - - # Since the default is normally random, set it to a known number for debug str - testParams._setDefault(seed=41) - testParams.setSeed(43) - - self.assertEqual( - testParams.explainParams(), - "\n".join(["inputCol: input column name. (undefined)", - "maxIter: max number of iterations (>= 0). (default: 10, current: 100)", - "seed: random seed. 
(default: 41, current: 43)"])) - - def test_kmeans_param(self): - algo = KMeans() - self.assertEqual(algo.getInitMode(), "k-means||") - algo.setK(10) - self.assertEqual(algo.getK(), 10) - algo.setInitSteps(10) - self.assertEqual(algo.getInitSteps(), 10) - self.assertEqual(algo.getDistanceMeasure(), "euclidean") - algo.setDistanceMeasure("cosine") - self.assertEqual(algo.getDistanceMeasure(), "cosine") - - def test_hasseed(self): - noSeedSpecd = TestParams() - withSeedSpecd = TestParams(seed=42) - other = OtherTestParams() - # Check that we no longer use 42 as the magic number - self.assertNotEqual(noSeedSpecd.getSeed(), 42) - origSeed = noSeedSpecd.getSeed() - # Check that we only compute the seed once - self.assertEqual(noSeedSpecd.getSeed(), origSeed) - # Check that a specified seed is honored - self.assertEqual(withSeedSpecd.getSeed(), 42) - # Check that a different class has a different seed - self.assertNotEqual(other.getSeed(), noSeedSpecd.getSeed()) - - def test_param_property_error(self): - param_store = HasThrowableProperty() - self.assertRaises(RuntimeError, lambda: param_store.test_property) - params = param_store.params # should not invoke the property 'test_property' - self.assertEqual(len(params), 1) - - def test_word2vec_param(self): - model = Word2Vec().setWindowSize(6) - # Check windowSize is set properly - self.assertEqual(model.getWindowSize(), 6) - - def test_copy_param_extras(self): - tp = TestParams(seed=42) - extra = {tp.getParam(TestParams.inputCol.name): "copy_input"} - tp_copy = tp.copy(extra=extra) - self.assertEqual(tp.uid, tp_copy.uid) - self.assertEqual(tp.params, tp_copy.params) - for k, v in extra.items(): - self.assertTrue(tp_copy.isDefined(k)) - self.assertEqual(tp_copy.getOrDefault(k), v) - copied_no_extra = {} - for k, v in tp_copy._paramMap.items(): - if k not in extra: - copied_no_extra[k] = v - self.assertEqual(tp._paramMap, copied_no_extra) - self.assertEqual(tp._defaultParamMap, tp_copy._defaultParamMap) - - def test_logistic_regression_check_thresholds(self): - self.assertIsInstance( - LogisticRegression(threshold=0.5, thresholds=[0.5, 0.5]), - LogisticRegression - ) - - self.assertRaisesRegexp( - ValueError, - "Logistic Regression getThreshold found inconsistent.*$", - LogisticRegression, threshold=0.42, thresholds=[0.5, 0.5] - ) - - def test_preserve_set_state(self): - dataset = self.spark.createDataFrame([(0.5,)], ["data"]) - binarizer = Binarizer(inputCol="data") - self.assertFalse(binarizer.isSet("threshold")) - binarizer.transform(dataset) - binarizer._transfer_params_from_java() - self.assertFalse(binarizer.isSet("threshold"), - "Params not explicitly set should remain unset after transform") - - def test_default_params_transferred(self): - dataset = self.spark.createDataFrame([(0.5,)], ["data"]) - binarizer = Binarizer(inputCol="data") - # intentionally change the pyspark default, but don't set it - binarizer._defaultParamMap[binarizer.outputCol] = "my_default" - result = binarizer.transform(dataset).select("my_default").collect() - self.assertFalse(binarizer.isSet(binarizer.outputCol)) - self.assertEqual(result[0][0], 1.0) - - @staticmethod - def check_params(test_self, py_stage, check_params_exist=True): - """ - Checks common requirements for Params.params: - - set of params exist in Java and Python and are ordered by names - - param parent has the same UID as the object's UID - - default param value from Java matches value in Python - - optionally check if all params from Java also exist in Python - """ - py_stage_str = "%s %s" % 
(type(py_stage), py_stage) - if not hasattr(py_stage, "_to_java"): - return - java_stage = py_stage._to_java() - if java_stage is None: - return - test_self.assertEqual(py_stage.uid, java_stage.uid(), msg=py_stage_str) - if check_params_exist: - param_names = [p.name for p in py_stage.params] - java_params = list(java_stage.params()) - java_param_names = [jp.name() for jp in java_params] - test_self.assertEqual( - param_names, sorted(java_param_names), - "Param list in Python does not match Java for %s:\nJava = %s\nPython = %s" - % (py_stage_str, java_param_names, param_names)) - for p in py_stage.params: - test_self.assertEqual(p.parent, py_stage.uid) - java_param = java_stage.getParam(p.name) - py_has_default = py_stage.hasDefault(p) - java_has_default = java_stage.hasDefault(java_param) - test_self.assertEqual(py_has_default, java_has_default, - "Default value mismatch of param %s for Params %s" - % (p.name, str(py_stage))) - if py_has_default: - if p.name == "seed": - continue # Random seeds between Spark and PySpark are different - java_default = _java2py(test_self.sc, - java_stage.clear(java_param).getOrDefault(java_param)) - py_stage._clear(p) - py_default = py_stage.getOrDefault(p) - # equality test for NaN is always False - if isinstance(java_default, float) and np.isnan(java_default): - java_default = "NaN" - py_default = "NaN" if np.isnan(py_default) else "not NaN" - test_self.assertEqual( - java_default, py_default, - "Java default %s != python default %s of param %s for Params %s" - % (str(java_default), str(py_default), p.name, str(py_stage))) - - -class EvaluatorTests(SparkSessionTestCase): - - def test_java_params(self): - """ - This tests a bug fixed by SPARK-18274 which causes multiple copies - of a Params instance in Python to be linked to the same Java instance. 
- """ - evaluator = RegressionEvaluator(metricName="r2") - df = self.spark.createDataFrame([Row(label=1.0, prediction=1.1)]) - evaluator.evaluate(df) - self.assertEqual(evaluator._java_obj.getMetricName(), "r2") - evaluatorCopy = evaluator.copy({evaluator.metricName: "mae"}) - evaluator.evaluate(df) - evaluatorCopy.evaluate(df) - self.assertEqual(evaluator._java_obj.getMetricName(), "r2") - self.assertEqual(evaluatorCopy._java_obj.getMetricName(), "mae") - - def test_clustering_evaluator_with_cosine_distance(self): - featureAndPredictions = map(lambda x: (Vectors.dense(x[0]), x[1]), - [([1.0, 1.0], 1.0), ([10.0, 10.0], 1.0), ([1.0, 0.5], 2.0), - ([10.0, 4.4], 2.0), ([-1.0, 1.0], 3.0), ([-100.0, 90.0], 3.0)]) - dataset = self.spark.createDataFrame(featureAndPredictions, ["features", "prediction"]) - evaluator = ClusteringEvaluator(predictionCol="prediction", distanceMeasure="cosine") - self.assertEqual(evaluator.getDistanceMeasure(), "cosine") - self.assertTrue(np.isclose(evaluator.evaluate(dataset), 0.992671213, atol=1e-5)) - - -class FeatureTests(SparkSessionTestCase): - - def test_binarizer(self): - b0 = Binarizer() - self.assertListEqual(b0.params, [b0.inputCol, b0.outputCol, b0.threshold]) - self.assertTrue(all([~b0.isSet(p) for p in b0.params])) - self.assertTrue(b0.hasDefault(b0.threshold)) - self.assertEqual(b0.getThreshold(), 0.0) - b0.setParams(inputCol="input", outputCol="output").setThreshold(1.0) - self.assertTrue(all([b0.isSet(p) for p in b0.params])) - self.assertEqual(b0.getThreshold(), 1.0) - self.assertEqual(b0.getInputCol(), "input") - self.assertEqual(b0.getOutputCol(), "output") - - b0c = b0.copy({b0.threshold: 2.0}) - self.assertEqual(b0c.uid, b0.uid) - self.assertListEqual(b0c.params, b0.params) - self.assertEqual(b0c.getThreshold(), 2.0) - - b1 = Binarizer(threshold=2.0, inputCol="input", outputCol="output") - self.assertNotEqual(b1.uid, b0.uid) - self.assertEqual(b1.getThreshold(), 2.0) - self.assertEqual(b1.getInputCol(), "input") - self.assertEqual(b1.getOutputCol(), "output") - - def test_idf(self): - dataset = self.spark.createDataFrame([ - (DenseVector([1.0, 2.0]),), - (DenseVector([0.0, 1.0]),), - (DenseVector([3.0, 0.2]),)], ["tf"]) - idf0 = IDF(inputCol="tf") - self.assertListEqual(idf0.params, [idf0.inputCol, idf0.minDocFreq, idf0.outputCol]) - idf0m = idf0.fit(dataset, {idf0.outputCol: "idf"}) - self.assertEqual(idf0m.uid, idf0.uid, - "Model should inherit the UID from its parent estimator.") - output = idf0m.transform(dataset) - self.assertIsNotNone(output.head().idf) - # Test that parameters transferred to Python Model - ParamTests.check_params(self, idf0m) - - def test_ngram(self): - dataset = self.spark.createDataFrame([ - Row(input=["a", "b", "c", "d", "e"])]) - ngram0 = NGram(n=4, inputCol="input", outputCol="output") - self.assertEqual(ngram0.getN(), 4) - self.assertEqual(ngram0.getInputCol(), "input") - self.assertEqual(ngram0.getOutputCol(), "output") - transformedDF = ngram0.transform(dataset) - self.assertEqual(transformedDF.head().output, ["a b c d", "b c d e"]) - - def test_stopwordsremover(self): - dataset = self.spark.createDataFrame([Row(input=["a", "panda"])]) - stopWordRemover = StopWordsRemover(inputCol="input", outputCol="output") - # Default - self.assertEqual(stopWordRemover.getInputCol(), "input") - transformedDF = stopWordRemover.transform(dataset) - self.assertEqual(transformedDF.head().output, ["panda"]) - self.assertEqual(type(stopWordRemover.getStopWords()), list) - self.assertTrue(isinstance(stopWordRemover.getStopWords()[0], 
basestring)) - # Custom - stopwords = ["panda"] - stopWordRemover.setStopWords(stopwords) - self.assertEqual(stopWordRemover.getInputCol(), "input") - self.assertEqual(stopWordRemover.getStopWords(), stopwords) - transformedDF = stopWordRemover.transform(dataset) - self.assertEqual(transformedDF.head().output, ["a"]) - # with language selection - stopwords = StopWordsRemover.loadDefaultStopWords("turkish") - dataset = self.spark.createDataFrame([Row(input=["acaba", "ama", "biri"])]) - stopWordRemover.setStopWords(stopwords) - self.assertEqual(stopWordRemover.getStopWords(), stopwords) - transformedDF = stopWordRemover.transform(dataset) - self.assertEqual(transformedDF.head().output, []) - # with locale - stopwords = ["BELKİ"] - dataset = self.spark.createDataFrame([Row(input=["belki"])]) - stopWordRemover.setStopWords(stopwords).setLocale("tr") - self.assertEqual(stopWordRemover.getStopWords(), stopwords) - transformedDF = stopWordRemover.transform(dataset) - self.assertEqual(transformedDF.head().output, []) - - def test_count_vectorizer_with_binary(self): - dataset = self.spark.createDataFrame([ - (0, "a a a b b c".split(' '), SparseVector(3, {0: 1.0, 1: 1.0, 2: 1.0}),), - (1, "a a".split(' '), SparseVector(3, {0: 1.0}),), - (2, "a b".split(' '), SparseVector(3, {0: 1.0, 1: 1.0}),), - (3, "c".split(' '), SparseVector(3, {2: 1.0}),)], ["id", "words", "expected"]) - cv = CountVectorizer(binary=True, inputCol="words", outputCol="features") - model = cv.fit(dataset) - - transformedList = model.transform(dataset).select("features", "expected").collect() - - for r in transformedList: - feature, expected = r - self.assertEqual(feature, expected) - - def test_count_vectorizer_with_maxDF(self): - dataset = self.spark.createDataFrame([ - (0, "a b c d".split(' '), SparseVector(3, {0: 1.0, 1: 1.0, 2: 1.0}),), - (1, "a b c".split(' '), SparseVector(3, {0: 1.0, 1: 1.0}),), - (2, "a b".split(' '), SparseVector(3, {0: 1.0}),), - (3, "a".split(' '), SparseVector(3, {}),)], ["id", "words", "expected"]) - cv = CountVectorizer(inputCol="words", outputCol="features") - model1 = cv.setMaxDF(3).fit(dataset) - self.assertEqual(model1.vocabulary, ['b', 'c', 'd']) - - transformedList1 = model1.transform(dataset).select("features", "expected").collect() - - for r in transformedList1: - feature, expected = r - self.assertEqual(feature, expected) - - model2 = cv.setMaxDF(0.75).fit(dataset) - self.assertEqual(model2.vocabulary, ['b', 'c', 'd']) - - transformedList2 = model2.transform(dataset).select("features", "expected").collect() - - for r in transformedList2: - feature, expected = r - self.assertEqual(feature, expected) - - def test_count_vectorizer_from_vocab(self): - model = CountVectorizerModel.from_vocabulary(["a", "b", "c"], inputCol="words", - outputCol="features", minTF=2) - self.assertEqual(model.vocabulary, ["a", "b", "c"]) - self.assertEqual(model.getMinTF(), 2) - - dataset = self.spark.createDataFrame([ - (0, "a a a b b c".split(' '), SparseVector(3, {0: 3.0, 1: 2.0}),), - (1, "a a".split(' '), SparseVector(3, {0: 2.0}),), - (2, "a b".split(' '), SparseVector(3, {}),)], ["id", "words", "expected"]) - - transformed_list = model.transform(dataset).select("features", "expected").collect() - - for r in transformed_list: - feature, expected = r - self.assertEqual(feature, expected) - - # Test an empty vocabulary - with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, "vocabSize.*invalid.*0"): - CountVectorizerModel.from_vocabulary([], inputCol="words") - - # Test model with default settings can 
transform - model_default = CountVectorizerModel.from_vocabulary(["a", "b", "c"], inputCol="words") - transformed_list = model_default.transform(dataset)\ - .select(model_default.getOrDefault(model_default.outputCol)).collect() - self.assertEqual(len(transformed_list), 3) - - def test_rformula_force_index_label(self): - df = self.spark.createDataFrame([ - (1.0, 1.0, "a"), - (0.0, 2.0, "b"), - (1.0, 0.0, "a")], ["y", "x", "s"]) - # Does not index label by default since it's numeric type. - rf = RFormula(formula="y ~ x + s") - model = rf.fit(df) - transformedDF = model.transform(df) - self.assertEqual(transformedDF.head().label, 1.0) - # Force to index label. - rf2 = RFormula(formula="y ~ x + s").setForceIndexLabel(True) - model2 = rf2.fit(df) - transformedDF2 = model2.transform(df) - self.assertEqual(transformedDF2.head().label, 0.0) - - def test_rformula_string_indexer_order_type(self): - df = self.spark.createDataFrame([ - (1.0, 1.0, "a"), - (0.0, 2.0, "b"), - (1.0, 0.0, "a")], ["y", "x", "s"]) - rf = RFormula(formula="y ~ x + s", stringIndexerOrderType="alphabetDesc") - self.assertEqual(rf.getStringIndexerOrderType(), 'alphabetDesc') - transformedDF = rf.fit(df).transform(df) - observed = transformedDF.select("features").collect() - expected = [[1.0, 0.0], [2.0, 1.0], [0.0, 0.0]] - for i in range(0, len(expected)): - self.assertTrue(all(observed[i]["features"].toArray() == expected[i])) - - def test_string_indexer_handle_invalid(self): - df = self.spark.createDataFrame([ - (0, "a"), - (1, "d"), - (2, None)], ["id", "label"]) - - si1 = StringIndexer(inputCol="label", outputCol="indexed", handleInvalid="keep", - stringOrderType="alphabetAsc") - model1 = si1.fit(df) - td1 = model1.transform(df) - actual1 = td1.select("id", "indexed").collect() - expected1 = [Row(id=0, indexed=0.0), Row(id=1, indexed=1.0), Row(id=2, indexed=2.0)] - self.assertEqual(actual1, expected1) - - si2 = si1.setHandleInvalid("skip") - model2 = si2.fit(df) - td2 = model2.transform(df) - actual2 = td2.select("id", "indexed").collect() - expected2 = [Row(id=0, indexed=0.0), Row(id=1, indexed=1.0)] - self.assertEqual(actual2, expected2) - - def test_string_indexer_from_labels(self): - model = StringIndexerModel.from_labels(["a", "b", "c"], inputCol="label", - outputCol="indexed", handleInvalid="keep") - self.assertEqual(model.labels, ["a", "b", "c"]) - - df1 = self.spark.createDataFrame([ - (0, "a"), - (1, "c"), - (2, None), - (3, "b"), - (4, "b")], ["id", "label"]) - - result1 = model.transform(df1) - actual1 = result1.select("id", "indexed").collect() - expected1 = [Row(id=0, indexed=0.0), Row(id=1, indexed=2.0), Row(id=2, indexed=3.0), - Row(id=3, indexed=1.0), Row(id=4, indexed=1.0)] - self.assertEqual(actual1, expected1) - - model_empty_labels = StringIndexerModel.from_labels( - [], inputCol="label", outputCol="indexed", handleInvalid="keep") - actual2 = model_empty_labels.transform(df1).select("id", "indexed").collect() - expected2 = [Row(id=0, indexed=0.0), Row(id=1, indexed=0.0), Row(id=2, indexed=0.0), - Row(id=3, indexed=0.0), Row(id=4, indexed=0.0)] - self.assertEqual(actual2, expected2) - - # Test model with default settings can transform - model_default = StringIndexerModel.from_labels(["a", "b", "c"], inputCol="label") - df2 = self.spark.createDataFrame([ - (0, "a"), - (1, "c"), - (2, "b"), - (3, "b"), - (4, "b")], ["id", "label"]) - transformed_list = model_default.transform(df2)\ - .select(model_default.getOrDefault(model_default.outputCol)).collect() - self.assertEqual(len(transformed_list), 5) - - def 
test_vector_size_hint(self): - df = self.spark.createDataFrame( - [(0, Vectors.dense([0.0, 10.0, 0.5])), - (1, Vectors.dense([1.0, 11.0, 0.5, 0.6])), - (2, Vectors.dense([2.0, 12.0]))], - ["id", "vector"]) - - sizeHint = VectorSizeHint( - inputCol="vector", - handleInvalid="skip") - sizeHint.setSize(3) - self.assertEqual(sizeHint.getSize(), 3) - - output = sizeHint.transform(df).head().vector - expected = DenseVector([0.0, 10.0, 0.5]) - self.assertEqual(output, expected) - - -class HasInducedError(Params): - - def __init__(self): - super(HasInducedError, self).__init__() - self.inducedError = Param(self, "inducedError", - "Uniformly-distributed error added to feature") - - def getInducedError(self): - return self.getOrDefault(self.inducedError) - - -class InducedErrorModel(Model, HasInducedError): - - def __init__(self): - super(InducedErrorModel, self).__init__() - - def _transform(self, dataset): - return dataset.withColumn("prediction", - dataset.feature + (rand(0) * self.getInducedError())) - - -class InducedErrorEstimator(Estimator, HasInducedError): - - def __init__(self, inducedError=1.0): - super(InducedErrorEstimator, self).__init__() - self._set(inducedError=inducedError) - - def _fit(self, dataset): - model = InducedErrorModel() - self._copyValues(model) - return model - - -class CrossValidatorTests(SparkSessionTestCase): - - def test_copy(self): - dataset = self.spark.createDataFrame([ - (10, 10.0), - (50, 50.0), - (100, 100.0), - (500, 500.0)] * 10, - ["feature", "label"]) - - iee = InducedErrorEstimator() - evaluator = RegressionEvaluator(metricName="rmse") - - grid = (ParamGridBuilder() - .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) - .build()) - cv = CrossValidator(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator) - cvCopied = cv.copy() - self.assertEqual(cv.getEstimator().uid, cvCopied.getEstimator().uid) - - cvModel = cv.fit(dataset) - cvModelCopied = cvModel.copy() - for index in range(len(cvModel.avgMetrics)): - self.assertTrue(abs(cvModel.avgMetrics[index] - cvModelCopied.avgMetrics[index]) - < 0.0001) - - def test_fit_minimize_metric(self): - dataset = self.spark.createDataFrame([ - (10, 10.0), - (50, 50.0), - (100, 100.0), - (500, 500.0)] * 10, - ["feature", "label"]) - - iee = InducedErrorEstimator() - evaluator = RegressionEvaluator(metricName="rmse") - - grid = (ParamGridBuilder() - .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) - .build()) - cv = CrossValidator(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator) - cvModel = cv.fit(dataset) - bestModel = cvModel.bestModel - bestModelMetric = evaluator.evaluate(bestModel.transform(dataset)) - - self.assertEqual(0.0, bestModel.getOrDefault('inducedError'), - "Best model should have zero induced error") - self.assertEqual(0.0, bestModelMetric, "Best model has RMSE of 0") - - def test_fit_maximize_metric(self): - dataset = self.spark.createDataFrame([ - (10, 10.0), - (50, 50.0), - (100, 100.0), - (500, 500.0)] * 10, - ["feature", "label"]) - - iee = InducedErrorEstimator() - evaluator = RegressionEvaluator(metricName="r2") - - grid = (ParamGridBuilder() - .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) - .build()) - cv = CrossValidator(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator) - cvModel = cv.fit(dataset) - bestModel = cvModel.bestModel - bestModelMetric = evaluator.evaluate(bestModel.transform(dataset)) - - self.assertEqual(0.0, bestModel.getOrDefault('inducedError'), - "Best model should have zero induced error") - self.assertEqual(1.0, bestModelMetric, "Best model has 
R-squared of 1") - - def test_param_grid_type_coercion(self): - lr = LogisticRegression(maxIter=10) - paramGrid = ParamGridBuilder().addGrid(lr.regParam, [0.5, 1]).build() - for param in paramGrid: - for v in param.values(): - assert(type(v) == float) - - def test_save_load_trained_model(self): - # This tests saving and loading the trained model only. - # Save/load for CrossValidator will be added later: SPARK-13786 - temp_path = tempfile.mkdtemp() - dataset = self.spark.createDataFrame( - [(Vectors.dense([0.0]), 0.0), - (Vectors.dense([0.4]), 1.0), - (Vectors.dense([0.5]), 0.0), - (Vectors.dense([0.6]), 1.0), - (Vectors.dense([1.0]), 1.0)] * 10, - ["features", "label"]) - lr = LogisticRegression() - grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() - evaluator = BinaryClassificationEvaluator() - cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) - cvModel = cv.fit(dataset) - lrModel = cvModel.bestModel - - cvModelPath = temp_path + "/cvModel" - lrModel.save(cvModelPath) - loadedLrModel = LogisticRegressionModel.load(cvModelPath) - self.assertEqual(loadedLrModel.uid, lrModel.uid) - self.assertEqual(loadedLrModel.intercept, lrModel.intercept) - - def test_save_load_simple_estimator(self): - temp_path = tempfile.mkdtemp() - dataset = self.spark.createDataFrame( - [(Vectors.dense([0.0]), 0.0), - (Vectors.dense([0.4]), 1.0), - (Vectors.dense([0.5]), 0.0), - (Vectors.dense([0.6]), 1.0), - (Vectors.dense([1.0]), 1.0)] * 10, - ["features", "label"]) - - lr = LogisticRegression() - grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() - evaluator = BinaryClassificationEvaluator() - - # test save/load of CrossValidator - cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) - cvModel = cv.fit(dataset) - cvPath = temp_path + "/cv" - cv.save(cvPath) - loadedCV = CrossValidator.load(cvPath) - self.assertEqual(loadedCV.getEstimator().uid, cv.getEstimator().uid) - self.assertEqual(loadedCV.getEvaluator().uid, cv.getEvaluator().uid) - self.assertEqual(loadedCV.getEstimatorParamMaps(), cv.getEstimatorParamMaps()) - - # test save/load of CrossValidatorModel - cvModelPath = temp_path + "/cvModel" - cvModel.save(cvModelPath) - loadedModel = CrossValidatorModel.load(cvModelPath) - self.assertEqual(loadedModel.bestModel.uid, cvModel.bestModel.uid) - - def test_parallel_evaluation(self): - dataset = self.spark.createDataFrame( - [(Vectors.dense([0.0]), 0.0), - (Vectors.dense([0.4]), 1.0), - (Vectors.dense([0.5]), 0.0), - (Vectors.dense([0.6]), 1.0), - (Vectors.dense([1.0]), 1.0)] * 10, - ["features", "label"]) - - lr = LogisticRegression() - grid = ParamGridBuilder().addGrid(lr.maxIter, [5, 6]).build() - evaluator = BinaryClassificationEvaluator() - - # test save/load of CrossValidator - cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) - cv.setParallelism(1) - cvSerialModel = cv.fit(dataset) - cv.setParallelism(2) - cvParallelModel = cv.fit(dataset) - self.assertEqual(cvSerialModel.avgMetrics, cvParallelModel.avgMetrics) - - def test_expose_sub_models(self): - temp_path = tempfile.mkdtemp() - dataset = self.spark.createDataFrame( - [(Vectors.dense([0.0]), 0.0), - (Vectors.dense([0.4]), 1.0), - (Vectors.dense([0.5]), 0.0), - (Vectors.dense([0.6]), 1.0), - (Vectors.dense([1.0]), 1.0)] * 10, - ["features", "label"]) - - lr = LogisticRegression() - grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() - evaluator = BinaryClassificationEvaluator() - - numFolds = 3 - cv = CrossValidator(estimator=lr, 
estimatorParamMaps=grid, evaluator=evaluator, - numFolds=numFolds, collectSubModels=True) - - def checkSubModels(subModels): - self.assertEqual(len(subModels), numFolds) - for i in range(numFolds): - self.assertEqual(len(subModels[i]), len(grid)) - - cvModel = cv.fit(dataset) - checkSubModels(cvModel.subModels) - - # Test the default value for option "persistSubModel" to be "true" - testSubPath = temp_path + "/testCrossValidatorSubModels" - savingPathWithSubModels = testSubPath + "cvModel3" - cvModel.save(savingPathWithSubModels) - cvModel3 = CrossValidatorModel.load(savingPathWithSubModels) - checkSubModels(cvModel3.subModels) - cvModel4 = cvModel3.copy() - checkSubModels(cvModel4.subModels) - - savingPathWithoutSubModels = testSubPath + "cvModel2" - cvModel.write().option("persistSubModels", "false").save(savingPathWithoutSubModels) - cvModel2 = CrossValidatorModel.load(savingPathWithoutSubModels) - self.assertEqual(cvModel2.subModels, None) - - for i in range(numFolds): - for j in range(len(grid)): - self.assertEqual(cvModel.subModels[i][j].uid, cvModel3.subModels[i][j].uid) - - def test_save_load_nested_estimator(self): - temp_path = tempfile.mkdtemp() - dataset = self.spark.createDataFrame( - [(Vectors.dense([0.0]), 0.0), - (Vectors.dense([0.4]), 1.0), - (Vectors.dense([0.5]), 0.0), - (Vectors.dense([0.6]), 1.0), - (Vectors.dense([1.0]), 1.0)] * 10, - ["features", "label"]) - - ova = OneVsRest(classifier=LogisticRegression()) - lr1 = LogisticRegression().setMaxIter(100) - lr2 = LogisticRegression().setMaxIter(150) - grid = ParamGridBuilder().addGrid(ova.classifier, [lr1, lr2]).build() - evaluator = MulticlassClassificationEvaluator() - - # test save/load of CrossValidator - cv = CrossValidator(estimator=ova, estimatorParamMaps=grid, evaluator=evaluator) - cvModel = cv.fit(dataset) - cvPath = temp_path + "/cv" - cv.save(cvPath) - loadedCV = CrossValidator.load(cvPath) - self.assertEqual(loadedCV.getEstimator().uid, cv.getEstimator().uid) - self.assertEqual(loadedCV.getEvaluator().uid, cv.getEvaluator().uid) - - originalParamMap = cv.getEstimatorParamMaps() - loadedParamMap = loadedCV.getEstimatorParamMaps() - for i, param in enumerate(loadedParamMap): - for p in param: - if p.name == "classifier": - self.assertEqual(param[p].uid, originalParamMap[i][p].uid) - else: - self.assertEqual(param[p], originalParamMap[i][p]) - - # test save/load of CrossValidatorModel - cvModelPath = temp_path + "/cvModel" - cvModel.save(cvModelPath) - loadedModel = CrossValidatorModel.load(cvModelPath) - self.assertEqual(loadedModel.bestModel.uid, cvModel.bestModel.uid) - - -class TrainValidationSplitTests(SparkSessionTestCase): - - def test_fit_minimize_metric(self): - dataset = self.spark.createDataFrame([ - (10, 10.0), - (50, 50.0), - (100, 100.0), - (500, 500.0)] * 10, - ["feature", "label"]) - - iee = InducedErrorEstimator() - evaluator = RegressionEvaluator(metricName="rmse") - - grid = ParamGridBuilder() \ - .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) \ - .build() - tvs = TrainValidationSplit(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator) - tvsModel = tvs.fit(dataset) - bestModel = tvsModel.bestModel - bestModelMetric = evaluator.evaluate(bestModel.transform(dataset)) - validationMetrics = tvsModel.validationMetrics - - self.assertEqual(0.0, bestModel.getOrDefault('inducedError'), - "Best model should have zero induced error") - self.assertEqual(0.0, bestModelMetric, "Best model has RMSE of 0") - self.assertEqual(len(grid), len(validationMetrics), - "validationMetrics has the same 
size of grid parameter") - self.assertEqual(0.0, min(validationMetrics)) - - def test_fit_maximize_metric(self): - dataset = self.spark.createDataFrame([ - (10, 10.0), - (50, 50.0), - (100, 100.0), - (500, 500.0)] * 10, - ["feature", "label"]) - - iee = InducedErrorEstimator() - evaluator = RegressionEvaluator(metricName="r2") - - grid = ParamGridBuilder() \ - .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) \ - .build() - tvs = TrainValidationSplit(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator) - tvsModel = tvs.fit(dataset) - bestModel = tvsModel.bestModel - bestModelMetric = evaluator.evaluate(bestModel.transform(dataset)) - validationMetrics = tvsModel.validationMetrics - - self.assertEqual(0.0, bestModel.getOrDefault('inducedError'), - "Best model should have zero induced error") - self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1") - self.assertEqual(len(grid), len(validationMetrics), - "validationMetrics has the same size of grid parameter") - self.assertEqual(1.0, max(validationMetrics)) - - def test_save_load_trained_model(self): - # This tests saving and loading the trained model only. - # Save/load for TrainValidationSplit will be added later: SPARK-13786 - temp_path = tempfile.mkdtemp() - dataset = self.spark.createDataFrame( - [(Vectors.dense([0.0]), 0.0), - (Vectors.dense([0.4]), 1.0), - (Vectors.dense([0.5]), 0.0), - (Vectors.dense([0.6]), 1.0), - (Vectors.dense([1.0]), 1.0)] * 10, - ["features", "label"]) - lr = LogisticRegression() - grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() - evaluator = BinaryClassificationEvaluator() - tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) - tvsModel = tvs.fit(dataset) - lrModel = tvsModel.bestModel - - tvsModelPath = temp_path + "/tvsModel" - lrModel.save(tvsModelPath) - loadedLrModel = LogisticRegressionModel.load(tvsModelPath) - self.assertEqual(loadedLrModel.uid, lrModel.uid) - self.assertEqual(loadedLrModel.intercept, lrModel.intercept) - - def test_save_load_simple_estimator(self): - # This tests saving and loading the trained model only. 
- # Save/load for TrainValidationSplit will be added later: SPARK-13786 - temp_path = tempfile.mkdtemp() - dataset = self.spark.createDataFrame( - [(Vectors.dense([0.0]), 0.0), - (Vectors.dense([0.4]), 1.0), - (Vectors.dense([0.5]), 0.0), - (Vectors.dense([0.6]), 1.0), - (Vectors.dense([1.0]), 1.0)] * 10, - ["features", "label"]) - lr = LogisticRegression() - grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() - evaluator = BinaryClassificationEvaluator() - tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) - tvsModel = tvs.fit(dataset) - - tvsPath = temp_path + "/tvs" - tvs.save(tvsPath) - loadedTvs = TrainValidationSplit.load(tvsPath) - self.assertEqual(loadedTvs.getEstimator().uid, tvs.getEstimator().uid) - self.assertEqual(loadedTvs.getEvaluator().uid, tvs.getEvaluator().uid) - self.assertEqual(loadedTvs.getEstimatorParamMaps(), tvs.getEstimatorParamMaps()) - - tvsModelPath = temp_path + "/tvsModel" - tvsModel.save(tvsModelPath) - loadedModel = TrainValidationSplitModel.load(tvsModelPath) - self.assertEqual(loadedModel.bestModel.uid, tvsModel.bestModel.uid) - - def test_parallel_evaluation(self): - dataset = self.spark.createDataFrame( - [(Vectors.dense([0.0]), 0.0), - (Vectors.dense([0.4]), 1.0), - (Vectors.dense([0.5]), 0.0), - (Vectors.dense([0.6]), 1.0), - (Vectors.dense([1.0]), 1.0)] * 10, - ["features", "label"]) - lr = LogisticRegression() - grid = ParamGridBuilder().addGrid(lr.maxIter, [5, 6]).build() - evaluator = BinaryClassificationEvaluator() - tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) - tvs.setParallelism(1) - tvsSerialModel = tvs.fit(dataset) - tvs.setParallelism(2) - tvsParallelModel = tvs.fit(dataset) - self.assertEqual(tvsSerialModel.validationMetrics, tvsParallelModel.validationMetrics) - - def test_expose_sub_models(self): - temp_path = tempfile.mkdtemp() - dataset = self.spark.createDataFrame( - [(Vectors.dense([0.0]), 0.0), - (Vectors.dense([0.4]), 1.0), - (Vectors.dense([0.5]), 0.0), - (Vectors.dense([0.6]), 1.0), - (Vectors.dense([1.0]), 1.0)] * 10, - ["features", "label"]) - lr = LogisticRegression() - grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() - evaluator = BinaryClassificationEvaluator() - tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator, - collectSubModels=True) - tvsModel = tvs.fit(dataset) - self.assertEqual(len(tvsModel.subModels), len(grid)) - - # Test the default value for option "persistSubModel" to be "true" - testSubPath = temp_path + "/testTrainValidationSplitSubModels" - savingPathWithSubModels = testSubPath + "cvModel3" - tvsModel.save(savingPathWithSubModels) - tvsModel3 = TrainValidationSplitModel.load(savingPathWithSubModels) - self.assertEqual(len(tvsModel3.subModels), len(grid)) - tvsModel4 = tvsModel3.copy() - self.assertEqual(len(tvsModel4.subModels), len(grid)) - - savingPathWithoutSubModels = testSubPath + "cvModel2" - tvsModel.write().option("persistSubModels", "false").save(savingPathWithoutSubModels) - tvsModel2 = TrainValidationSplitModel.load(savingPathWithoutSubModels) - self.assertEqual(tvsModel2.subModels, None) - - for i in range(len(grid)): - self.assertEqual(tvsModel.subModels[i].uid, tvsModel3.subModels[i].uid) - - def test_save_load_nested_estimator(self): - # This tests saving and loading the trained model only. 
- # Save/load for TrainValidationSplit will be added later: SPARK-13786 - temp_path = tempfile.mkdtemp() - dataset = self.spark.createDataFrame( - [(Vectors.dense([0.0]), 0.0), - (Vectors.dense([0.4]), 1.0), - (Vectors.dense([0.5]), 0.0), - (Vectors.dense([0.6]), 1.0), - (Vectors.dense([1.0]), 1.0)] * 10, - ["features", "label"]) - ova = OneVsRest(classifier=LogisticRegression()) - lr1 = LogisticRegression().setMaxIter(100) - lr2 = LogisticRegression().setMaxIter(150) - grid = ParamGridBuilder().addGrid(ova.classifier, [lr1, lr2]).build() - evaluator = MulticlassClassificationEvaluator() - - tvs = TrainValidationSplit(estimator=ova, estimatorParamMaps=grid, evaluator=evaluator) - tvsModel = tvs.fit(dataset) - tvsPath = temp_path + "/tvs" - tvs.save(tvsPath) - loadedTvs = TrainValidationSplit.load(tvsPath) - self.assertEqual(loadedTvs.getEstimator().uid, tvs.getEstimator().uid) - self.assertEqual(loadedTvs.getEvaluator().uid, tvs.getEvaluator().uid) - - originalParamMap = tvs.getEstimatorParamMaps() - loadedParamMap = loadedTvs.getEstimatorParamMaps() - for i, param in enumerate(loadedParamMap): - for p in param: - if p.name == "classifier": - self.assertEqual(param[p].uid, originalParamMap[i][p].uid) - else: - self.assertEqual(param[p], originalParamMap[i][p]) - - tvsModelPath = temp_path + "/tvsModel" - tvsModel.save(tvsModelPath) - loadedModel = TrainValidationSplitModel.load(tvsModelPath) - self.assertEqual(loadedModel.bestModel.uid, tvsModel.bestModel.uid) - - def test_copy(self): - dataset = self.spark.createDataFrame([ - (10, 10.0), - (50, 50.0), - (100, 100.0), - (500, 500.0)] * 10, - ["feature", "label"]) - - iee = InducedErrorEstimator() - evaluator = RegressionEvaluator(metricName="r2") - - grid = ParamGridBuilder() \ - .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) \ - .build() - tvs = TrainValidationSplit(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator) - tvsModel = tvs.fit(dataset) - tvsCopied = tvs.copy() - tvsModelCopied = tvsModel.copy() - - self.assertEqual(tvs.getEstimator().uid, tvsCopied.getEstimator().uid, - "Copied TrainValidationSplit has the same uid of Estimator") - - self.assertEqual(tvsModel.bestModel.uid, tvsModelCopied.bestModel.uid) - self.assertEqual(len(tvsModel.validationMetrics), - len(tvsModelCopied.validationMetrics), - "Copied validationMetrics has the same size of the original") - for index in range(len(tvsModel.validationMetrics)): - self.assertEqual(tvsModel.validationMetrics[index], - tvsModelCopied.validationMetrics[index]) - - -class PersistenceTest(SparkSessionTestCase): - - def test_linear_regression(self): - lr = LinearRegression(maxIter=1) - path = tempfile.mkdtemp() - lr_path = path + "/lr" - lr.save(lr_path) - lr2 = LinearRegression.load(lr_path) - self.assertEqual(lr.uid, lr2.uid) - self.assertEqual(type(lr.uid), type(lr2.uid)) - self.assertEqual(lr2.uid, lr2.maxIter.parent, - "Loaded LinearRegression instance uid (%s) did not match Param's uid (%s)" - % (lr2.uid, lr2.maxIter.parent)) - self.assertEqual(lr._defaultParamMap[lr.maxIter], lr2._defaultParamMap[lr2.maxIter], - "Loaded LinearRegression instance default params did not match " + - "original defaults") - try: - rmtree(path) - except OSError: - pass - - def test_linear_regression_pmml_basic(self): - # Most of the validation is done in the Scala side, here we just check - # that we output text rather than parquet (e.g. that the format flag - # was respected). 
- df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), - (0.0, 2.0, Vectors.sparse(1, [], []))], - ["label", "weight", "features"]) - lr = LinearRegression(maxIter=1) - model = lr.fit(df) - path = tempfile.mkdtemp() - lr_path = path + "/lr-pmml" - model.write().format("pmml").save(lr_path) - pmml_text_list = self.sc.textFile(lr_path).collect() - pmml_text = "\n".join(pmml_text_list) - self.assertIn("Apache Spark", pmml_text) - self.assertIn("PMML", pmml_text) - - def test_logistic_regression(self): - lr = LogisticRegression(maxIter=1) - path = tempfile.mkdtemp() - lr_path = path + "/logreg" - lr.save(lr_path) - lr2 = LogisticRegression.load(lr_path) - self.assertEqual(lr2.uid, lr2.maxIter.parent, - "Loaded LogisticRegression instance uid (%s) " - "did not match Param's uid (%s)" - % (lr2.uid, lr2.maxIter.parent)) - self.assertEqual(lr._defaultParamMap[lr.maxIter], lr2._defaultParamMap[lr2.maxIter], - "Loaded LogisticRegression instance default params did not match " + - "original defaults") - try: - rmtree(path) - except OSError: - pass - - def _compare_params(self, m1, m2, param): - """ - Compare 2 ML Params instances for the given param, and assert both have the same param value - and parent. The param must be a parameter of m1. - """ - # Prevent key not found error in case of some param in neither paramMap nor defaultParamMap. - if m1.isDefined(param): - paramValue1 = m1.getOrDefault(param) - paramValue2 = m2.getOrDefault(m2.getParam(param.name)) - if isinstance(paramValue1, Params): - self._compare_pipelines(paramValue1, paramValue2) - else: - self.assertEqual(paramValue1, paramValue2) # for general types param - # Assert parents are equal - self.assertEqual(param.parent, m2.getParam(param.name).parent) - else: - # If m1 is not defined param, then m2 should not, too. See SPARK-14931. - self.assertFalse(m2.isDefined(m2.getParam(param.name))) - - def _compare_pipelines(self, m1, m2): - """ - Compare 2 ML types, asserting that they are equivalent. 
- This currently supports: - - basic types - - Pipeline, PipelineModel - - OneVsRest, OneVsRestModel - This checks: - - uid - - type - - Param values and parents - """ - self.assertEqual(m1.uid, m2.uid) - self.assertEqual(type(m1), type(m2)) - if isinstance(m1, JavaParams) or isinstance(m1, Transformer): - self.assertEqual(len(m1.params), len(m2.params)) - for p in m1.params: - self._compare_params(m1, m2, p) - elif isinstance(m1, Pipeline): - self.assertEqual(len(m1.getStages()), len(m2.getStages())) - for s1, s2 in zip(m1.getStages(), m2.getStages()): - self._compare_pipelines(s1, s2) - elif isinstance(m1, PipelineModel): - self.assertEqual(len(m1.stages), len(m2.stages)) - for s1, s2 in zip(m1.stages, m2.stages): - self._compare_pipelines(s1, s2) - elif isinstance(m1, OneVsRest) or isinstance(m1, OneVsRestModel): - for p in m1.params: - self._compare_params(m1, m2, p) - if isinstance(m1, OneVsRestModel): - self.assertEqual(len(m1.models), len(m2.models)) - for x, y in zip(m1.models, m2.models): - self._compare_pipelines(x, y) - else: - raise RuntimeError("_compare_pipelines does not yet support type: %s" % type(m1)) - - def test_pipeline_persistence(self): - """ - Pipeline[HashingTF, PCA] - """ - temp_path = tempfile.mkdtemp() - - try: - df = self.spark.createDataFrame([(["a", "b", "c"],), (["c", "d", "e"],)], ["words"]) - tf = HashingTF(numFeatures=10, inputCol="words", outputCol="features") - pca = PCA(k=2, inputCol="features", outputCol="pca_features") - pl = Pipeline(stages=[tf, pca]) - model = pl.fit(df) - - pipeline_path = temp_path + "/pipeline" - pl.save(pipeline_path) - loaded_pipeline = Pipeline.load(pipeline_path) - self._compare_pipelines(pl, loaded_pipeline) - - model_path = temp_path + "/pipeline-model" - model.save(model_path) - loaded_model = PipelineModel.load(model_path) - self._compare_pipelines(model, loaded_model) - finally: - try: - rmtree(temp_path) - except OSError: - pass - - def test_nested_pipeline_persistence(self): - """ - Pipeline[HashingTF, Pipeline[PCA]] - """ - temp_path = tempfile.mkdtemp() - - try: - df = self.spark.createDataFrame([(["a", "b", "c"],), (["c", "d", "e"],)], ["words"]) - tf = HashingTF(numFeatures=10, inputCol="words", outputCol="features") - pca = PCA(k=2, inputCol="features", outputCol="pca_features") - p0 = Pipeline(stages=[pca]) - pl = Pipeline(stages=[tf, p0]) - model = pl.fit(df) - - pipeline_path = temp_path + "/pipeline" - pl.save(pipeline_path) - loaded_pipeline = Pipeline.load(pipeline_path) - self._compare_pipelines(pl, loaded_pipeline) - - model_path = temp_path + "/pipeline-model" - model.save(model_path) - loaded_model = PipelineModel.load(model_path) - self._compare_pipelines(model, loaded_model) - finally: - try: - rmtree(temp_path) - except OSError: - pass - - def test_python_transformer_pipeline_persistence(self): - """ - Pipeline[MockUnaryTransformer, Binarizer] - """ - temp_path = tempfile.mkdtemp() - - try: - df = self.spark.range(0, 10).toDF('input') - tf = MockUnaryTransformer(shiftVal=2)\ - .setInputCol("input").setOutputCol("shiftedInput") - tf2 = Binarizer(threshold=6, inputCol="shiftedInput", outputCol="binarized") - pl = Pipeline(stages=[tf, tf2]) - model = pl.fit(df) - - pipeline_path = temp_path + "/pipeline" - pl.save(pipeline_path) - loaded_pipeline = Pipeline.load(pipeline_path) - self._compare_pipelines(pl, loaded_pipeline) - - model_path = temp_path + "/pipeline-model" - model.save(model_path) - loaded_model = PipelineModel.load(model_path) - self._compare_pipelines(model, loaded_model) - finally: - 
-            try:
-                rmtree(temp_path)
-            except OSError:
-                pass
-
-    def test_onevsrest(self):
-        temp_path = tempfile.mkdtemp()
-        df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)),
-                                         (1.0, Vectors.sparse(2, [], [])),
-                                         (2.0, Vectors.dense(0.5, 0.5))] * 10,
-                                        ["label", "features"])
-        lr = LogisticRegression(maxIter=5, regParam=0.01)
-        ovr = OneVsRest(classifier=lr)
-        model = ovr.fit(df)
-        ovrPath = temp_path + "/ovr"
-        ovr.save(ovrPath)
-        loadedOvr = OneVsRest.load(ovrPath)
-        self._compare_pipelines(ovr, loadedOvr)
-        modelPath = temp_path + "/ovrModel"
-        model.save(modelPath)
-        loadedModel = OneVsRestModel.load(modelPath)
-        self._compare_pipelines(model, loadedModel)
-
-    def test_decisiontree_classifier(self):
-        dt = DecisionTreeClassifier(maxDepth=1)
-        path = tempfile.mkdtemp()
-        dtc_path = path + "/dtc"
-        dt.save(dtc_path)
-        dt2 = DecisionTreeClassifier.load(dtc_path)
-        self.assertEqual(dt2.uid, dt2.maxDepth.parent,
-                         "Loaded DecisionTreeClassifier instance uid (%s) "
-                         "did not match Param's uid (%s)"
-                         % (dt2.uid, dt2.maxDepth.parent))
-        self.assertEqual(dt._defaultParamMap[dt.maxDepth], dt2._defaultParamMap[dt2.maxDepth],
-                         "Loaded DecisionTreeClassifier instance default params did not match " +
-                         "original defaults")
-        try:
-            rmtree(path)
-        except OSError:
-            pass
-
-    def test_decisiontree_regressor(self):
-        dt = DecisionTreeRegressor(maxDepth=1)
-        path = tempfile.mkdtemp()
-        dtr_path = path + "/dtr"
-        dt.save(dtr_path)
-        dt2 = DecisionTreeRegressor.load(dtr_path)
-        self.assertEqual(dt2.uid, dt2.maxDepth.parent,
-                         "Loaded DecisionTreeRegressor instance uid (%s) "
-                         "did not match Param's uid (%s)"
-                         % (dt2.uid, dt2.maxDepth.parent))
-        self.assertEqual(dt._defaultParamMap[dt.maxDepth], dt2._defaultParamMap[dt2.maxDepth],
-                         "Loaded DecisionTreeRegressor instance default params did not match " +
-                         "original defaults")
-        try:
-            rmtree(path)
-        except OSError:
-            pass
-
-    def test_default_read_write(self):
-        temp_path = tempfile.mkdtemp()
-
-        lr = LogisticRegression()
-        lr.setMaxIter(50)
-        lr.setThreshold(.75)
-        writer = DefaultParamsWriter(lr)
-
-        savePath = temp_path + "/lr"
-        writer.save(savePath)
-
-        reader = DefaultParamsReadable.read()
-        lr2 = reader.load(savePath)
-
-        self.assertEqual(lr.uid, lr2.uid)
-        self.assertEqual(lr.extractParamMap(), lr2.extractParamMap())
-
-        # test overwrite
-        lr.setThreshold(.8)
-        writer.overwrite().save(savePath)
-
-        reader = DefaultParamsReadable.read()
-        lr3 = reader.load(savePath)
-
-        self.assertEqual(lr.uid, lr3.uid)
-        self.assertEqual(lr.extractParamMap(), lr3.extractParamMap())
-
-    def test_default_read_write_default_params(self):
-        lr = LogisticRegression()
-        self.assertFalse(lr.isSet(lr.getParam("threshold")))
-
-        lr.setMaxIter(50)
-        lr.setThreshold(.75)
-
-        # `threshold` is set by user, default param `predictionCol` is not set by user.
- self.assertTrue(lr.isSet(lr.getParam("threshold"))) - self.assertFalse(lr.isSet(lr.getParam("predictionCol"))) - self.assertTrue(lr.hasDefault(lr.getParam("predictionCol"))) - - writer = DefaultParamsWriter(lr) - metadata = json.loads(writer._get_metadata_to_save(lr, self.sc)) - self.assertTrue("defaultParamMap" in metadata) - - reader = DefaultParamsReadable.read() - metadataStr = json.dumps(metadata, separators=[',', ':']) - loadedMetadata = reader._parseMetaData(metadataStr, ) - reader.getAndSetParams(lr, loadedMetadata) - - self.assertTrue(lr.isSet(lr.getParam("threshold"))) - self.assertFalse(lr.isSet(lr.getParam("predictionCol"))) - self.assertTrue(lr.hasDefault(lr.getParam("predictionCol"))) - - # manually create metadata without `defaultParamMap` section. - del metadata['defaultParamMap'] - metadataStr = json.dumps(metadata, separators=[',', ':']) - loadedMetadata = reader._parseMetaData(metadataStr, ) - with self.assertRaisesRegexp(AssertionError, "`defaultParamMap` section not found"): - reader.getAndSetParams(lr, loadedMetadata) - - # Prior to 2.4.0, metadata doesn't have `defaultParamMap`. - metadata['sparkVersion'] = '2.3.0' - metadataStr = json.dumps(metadata, separators=[',', ':']) - loadedMetadata = reader._parseMetaData(metadataStr, ) - reader.getAndSetParams(lr, loadedMetadata) - - -class LDATest(SparkSessionTestCase): - - def _compare(self, m1, m2): - """ - Temp method for comparing instances. - TODO: Replace with generic implementation once SPARK-14706 is merged. - """ - self.assertEqual(m1.uid, m2.uid) - self.assertEqual(type(m1), type(m2)) - self.assertEqual(len(m1.params), len(m2.params)) - for p in m1.params: - if m1.isDefined(p): - self.assertEqual(m1.getOrDefault(p), m2.getOrDefault(p)) - self.assertEqual(p.parent, m2.getParam(p.name).parent) - if isinstance(m1, LDAModel): - self.assertEqual(m1.vocabSize(), m2.vocabSize()) - self.assertEqual(m1.topicsMatrix(), m2.topicsMatrix()) - - def test_persistence(self): - # Test save/load for LDA, LocalLDAModel, DistributedLDAModel. 
- df = self.spark.createDataFrame([ - [1, Vectors.dense([0.0, 1.0])], - [2, Vectors.sparse(2, {0: 1.0})], - ], ["id", "features"]) - # Fit model - lda = LDA(k=2, seed=1, optimizer="em") - distributedModel = lda.fit(df) - self.assertTrue(distributedModel.isDistributed()) - localModel = distributedModel.toLocal() - self.assertFalse(localModel.isDistributed()) - # Define paths - path = tempfile.mkdtemp() - lda_path = path + "/lda" - dist_model_path = path + "/distLDAModel" - local_model_path = path + "/localLDAModel" - # Test LDA - lda.save(lda_path) - lda2 = LDA.load(lda_path) - self._compare(lda, lda2) - # Test DistributedLDAModel - distributedModel.save(dist_model_path) - distributedModel2 = DistributedLDAModel.load(dist_model_path) - self._compare(distributedModel, distributedModel2) - # Test LocalLDAModel - localModel.save(local_model_path) - localModel2 = LocalLDAModel.load(local_model_path) - self._compare(localModel, localModel2) - # Clean up - try: - rmtree(path) - except OSError: - pass - - -class TrainingSummaryTest(SparkSessionTestCase): - - def test_linear_regression_summary(self): - df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), - (0.0, 2.0, Vectors.sparse(1, [], []))], - ["label", "weight", "features"]) - lr = LinearRegression(maxIter=5, regParam=0.0, solver="normal", weightCol="weight", - fitIntercept=False) - model = lr.fit(df) - self.assertTrue(model.hasSummary) - s = model.summary - # test that api is callable and returns expected types - self.assertGreater(s.totalIterations, 0) - self.assertTrue(isinstance(s.predictions, DataFrame)) - self.assertEqual(s.predictionCol, "prediction") - self.assertEqual(s.labelCol, "label") - self.assertEqual(s.featuresCol, "features") - objHist = s.objectiveHistory - self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float)) - self.assertAlmostEqual(s.explainedVariance, 0.25, 2) - self.assertAlmostEqual(s.meanAbsoluteError, 0.0) - self.assertAlmostEqual(s.meanSquaredError, 0.0) - self.assertAlmostEqual(s.rootMeanSquaredError, 0.0) - self.assertAlmostEqual(s.r2, 1.0, 2) - self.assertAlmostEqual(s.r2adj, 1.0, 2) - self.assertTrue(isinstance(s.residuals, DataFrame)) - self.assertEqual(s.numInstances, 2) - self.assertEqual(s.degreesOfFreedom, 1) - devResiduals = s.devianceResiduals - self.assertTrue(isinstance(devResiduals, list) and isinstance(devResiduals[0], float)) - coefStdErr = s.coefficientStandardErrors - self.assertTrue(isinstance(coefStdErr, list) and isinstance(coefStdErr[0], float)) - tValues = s.tValues - self.assertTrue(isinstance(tValues, list) and isinstance(tValues[0], float)) - pValues = s.pValues - self.assertTrue(isinstance(pValues, list) and isinstance(pValues[0], float)) - # test evaluation (with training dataset) produces a summary with same values - # one check is enough to verify a summary is returned - # The child class LinearRegressionTrainingSummary runs full test - sameSummary = model.evaluate(df) - self.assertAlmostEqual(sameSummary.explainedVariance, s.explainedVariance) - - def test_glr_summary(self): - from pyspark.ml.linalg import Vectors - df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), - (0.0, 2.0, Vectors.sparse(1, [], []))], - ["label", "weight", "features"]) - glr = GeneralizedLinearRegression(family="gaussian", link="identity", weightCol="weight", - fitIntercept=False) - model = glr.fit(df) - self.assertTrue(model.hasSummary) - s = model.summary - # test that api is callable and returns expected types - self.assertEqual(s.numIterations, 1) # this should 
default to a single iteration of WLS - self.assertTrue(isinstance(s.predictions, DataFrame)) - self.assertEqual(s.predictionCol, "prediction") - self.assertEqual(s.numInstances, 2) - self.assertTrue(isinstance(s.residuals(), DataFrame)) - self.assertTrue(isinstance(s.residuals("pearson"), DataFrame)) - coefStdErr = s.coefficientStandardErrors - self.assertTrue(isinstance(coefStdErr, list) and isinstance(coefStdErr[0], float)) - tValues = s.tValues - self.assertTrue(isinstance(tValues, list) and isinstance(tValues[0], float)) - pValues = s.pValues - self.assertTrue(isinstance(pValues, list) and isinstance(pValues[0], float)) - self.assertEqual(s.degreesOfFreedom, 1) - self.assertEqual(s.residualDegreeOfFreedom, 1) - self.assertEqual(s.residualDegreeOfFreedomNull, 2) - self.assertEqual(s.rank, 1) - self.assertTrue(isinstance(s.solver, basestring)) - self.assertTrue(isinstance(s.aic, float)) - self.assertTrue(isinstance(s.deviance, float)) - self.assertTrue(isinstance(s.nullDeviance, float)) - self.assertTrue(isinstance(s.dispersion, float)) - # test evaluation (with training dataset) produces a summary with same values - # one check is enough to verify a summary is returned - # The child class GeneralizedLinearRegressionTrainingSummary runs full test - sameSummary = model.evaluate(df) - self.assertAlmostEqual(sameSummary.deviance, s.deviance) - - def test_binary_logistic_regression_summary(self): - df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), - (0.0, 2.0, Vectors.sparse(1, [], []))], - ["label", "weight", "features"]) - lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight", fitIntercept=False) - model = lr.fit(df) - self.assertTrue(model.hasSummary) - s = model.summary - # test that api is callable and returns expected types - self.assertTrue(isinstance(s.predictions, DataFrame)) - self.assertEqual(s.probabilityCol, "probability") - self.assertEqual(s.labelCol, "label") - self.assertEqual(s.featuresCol, "features") - self.assertEqual(s.predictionCol, "prediction") - objHist = s.objectiveHistory - self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float)) - self.assertGreater(s.totalIterations, 0) - self.assertTrue(isinstance(s.labels, list)) - self.assertTrue(isinstance(s.truePositiveRateByLabel, list)) - self.assertTrue(isinstance(s.falsePositiveRateByLabel, list)) - self.assertTrue(isinstance(s.precisionByLabel, list)) - self.assertTrue(isinstance(s.recallByLabel, list)) - self.assertTrue(isinstance(s.fMeasureByLabel(), list)) - self.assertTrue(isinstance(s.fMeasureByLabel(1.0), list)) - self.assertTrue(isinstance(s.roc, DataFrame)) - self.assertAlmostEqual(s.areaUnderROC, 1.0, 2) - self.assertTrue(isinstance(s.pr, DataFrame)) - self.assertTrue(isinstance(s.fMeasureByThreshold, DataFrame)) - self.assertTrue(isinstance(s.precisionByThreshold, DataFrame)) - self.assertTrue(isinstance(s.recallByThreshold, DataFrame)) - self.assertAlmostEqual(s.accuracy, 1.0, 2) - self.assertAlmostEqual(s.weightedTruePositiveRate, 1.0, 2) - self.assertAlmostEqual(s.weightedFalsePositiveRate, 0.0, 2) - self.assertAlmostEqual(s.weightedRecall, 1.0, 2) - self.assertAlmostEqual(s.weightedPrecision, 1.0, 2) - self.assertAlmostEqual(s.weightedFMeasure(), 1.0, 2) - self.assertAlmostEqual(s.weightedFMeasure(1.0), 1.0, 2) - # test evaluation (with training dataset) produces a summary with same values - # one check is enough to verify a summary is returned, Scala version runs full test - sameSummary = model.evaluate(df) - 
self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC) - - def test_multiclass_logistic_regression_summary(self): - df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), - (0.0, 2.0, Vectors.sparse(1, [], [])), - (2.0, 2.0, Vectors.dense(2.0)), - (2.0, 2.0, Vectors.dense(1.9))], - ["label", "weight", "features"]) - lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight", fitIntercept=False) - model = lr.fit(df) - self.assertTrue(model.hasSummary) - s = model.summary - # test that api is callable and returns expected types - self.assertTrue(isinstance(s.predictions, DataFrame)) - self.assertEqual(s.probabilityCol, "probability") - self.assertEqual(s.labelCol, "label") - self.assertEqual(s.featuresCol, "features") - self.assertEqual(s.predictionCol, "prediction") - objHist = s.objectiveHistory - self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float)) - self.assertGreater(s.totalIterations, 0) - self.assertTrue(isinstance(s.labels, list)) - self.assertTrue(isinstance(s.truePositiveRateByLabel, list)) - self.assertTrue(isinstance(s.falsePositiveRateByLabel, list)) - self.assertTrue(isinstance(s.precisionByLabel, list)) - self.assertTrue(isinstance(s.recallByLabel, list)) - self.assertTrue(isinstance(s.fMeasureByLabel(), list)) - self.assertTrue(isinstance(s.fMeasureByLabel(1.0), list)) - self.assertAlmostEqual(s.accuracy, 0.75, 2) - self.assertAlmostEqual(s.weightedTruePositiveRate, 0.75, 2) - self.assertAlmostEqual(s.weightedFalsePositiveRate, 0.25, 2) - self.assertAlmostEqual(s.weightedRecall, 0.75, 2) - self.assertAlmostEqual(s.weightedPrecision, 0.583, 2) - self.assertAlmostEqual(s.weightedFMeasure(), 0.65, 2) - self.assertAlmostEqual(s.weightedFMeasure(1.0), 0.65, 2) - # test evaluation (with training dataset) produces a summary with same values - # one check is enough to verify a summary is returned, Scala version runs full test - sameSummary = model.evaluate(df) - self.assertAlmostEqual(sameSummary.accuracy, s.accuracy) - - def test_gaussian_mixture_summary(self): - data = [(Vectors.dense(1.0),), (Vectors.dense(5.0),), (Vectors.dense(10.0),), - (Vectors.sparse(1, [], []),)] - df = self.spark.createDataFrame(data, ["features"]) - gmm = GaussianMixture(k=2) - model = gmm.fit(df) - self.assertTrue(model.hasSummary) - s = model.summary - self.assertTrue(isinstance(s.predictions, DataFrame)) - self.assertEqual(s.probabilityCol, "probability") - self.assertTrue(isinstance(s.probability, DataFrame)) - self.assertEqual(s.featuresCol, "features") - self.assertEqual(s.predictionCol, "prediction") - self.assertTrue(isinstance(s.cluster, DataFrame)) - self.assertEqual(len(s.clusterSizes), 2) - self.assertEqual(s.k, 2) - self.assertEqual(s.numIter, 3) - - def test_bisecting_kmeans_summary(self): - data = [(Vectors.dense(1.0),), (Vectors.dense(5.0),), (Vectors.dense(10.0),), - (Vectors.sparse(1, [], []),)] - df = self.spark.createDataFrame(data, ["features"]) - bkm = BisectingKMeans(k=2) - model = bkm.fit(df) - self.assertTrue(model.hasSummary) - s = model.summary - self.assertTrue(isinstance(s.predictions, DataFrame)) - self.assertEqual(s.featuresCol, "features") - self.assertEqual(s.predictionCol, "prediction") - self.assertTrue(isinstance(s.cluster, DataFrame)) - self.assertEqual(len(s.clusterSizes), 2) - self.assertEqual(s.k, 2) - self.assertEqual(s.numIter, 20) - - def test_kmeans_summary(self): - data = [(Vectors.dense([0.0, 0.0]),), (Vectors.dense([1.0, 1.0]),), - (Vectors.dense([9.0, 8.0]),), (Vectors.dense([8.0, 9.0]),)] - df = 
self.spark.createDataFrame(data, ["features"]) - kmeans = KMeans(k=2, seed=1) - model = kmeans.fit(df) - self.assertTrue(model.hasSummary) - s = model.summary - self.assertTrue(isinstance(s.predictions, DataFrame)) - self.assertEqual(s.featuresCol, "features") - self.assertEqual(s.predictionCol, "prediction") - self.assertTrue(isinstance(s.cluster, DataFrame)) - self.assertEqual(len(s.clusterSizes), 2) - self.assertEqual(s.k, 2) - self.assertEqual(s.numIter, 1) - - -class KMeansTests(SparkSessionTestCase): - - def test_kmeans_cosine_distance(self): - data = [(Vectors.dense([1.0, 1.0]),), (Vectors.dense([10.0, 10.0]),), - (Vectors.dense([1.0, 0.5]),), (Vectors.dense([10.0, 4.4]),), - (Vectors.dense([-1.0, 1.0]),), (Vectors.dense([-100.0, 90.0]),)] - df = self.spark.createDataFrame(data, ["features"]) - kmeans = KMeans(k=3, seed=1, distanceMeasure="cosine") - model = kmeans.fit(df) - result = model.transform(df).collect() - self.assertTrue(result[0].prediction == result[1].prediction) - self.assertTrue(result[2].prediction == result[3].prediction) - self.assertTrue(result[4].prediction == result[5].prediction) - - -class OneVsRestTests(SparkSessionTestCase): - - def test_copy(self): - df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)), - (1.0, Vectors.sparse(2, [], [])), - (2.0, Vectors.dense(0.5, 0.5))], - ["label", "features"]) - lr = LogisticRegression(maxIter=5, regParam=0.01) - ovr = OneVsRest(classifier=lr) - ovr1 = ovr.copy({lr.maxIter: 10}) - self.assertEqual(ovr.getClassifier().getMaxIter(), 5) - self.assertEqual(ovr1.getClassifier().getMaxIter(), 10) - model = ovr.fit(df) - model1 = model.copy({model.predictionCol: "indexed"}) - self.assertEqual(model1.getPredictionCol(), "indexed") - - def test_output_columns(self): - df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)), - (1.0, Vectors.sparse(2, [], [])), - (2.0, Vectors.dense(0.5, 0.5))], - ["label", "features"]) - lr = LogisticRegression(maxIter=5, regParam=0.01) - ovr = OneVsRest(classifier=lr, parallelism=1) - model = ovr.fit(df) - output = model.transform(df) - self.assertEqual(output.columns, ["label", "features", "prediction"]) - - def test_parallelism_doesnt_change_output(self): - df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)), - (1.0, Vectors.sparse(2, [], [])), - (2.0, Vectors.dense(0.5, 0.5))], - ["label", "features"]) - ovrPar1 = OneVsRest(classifier=LogisticRegression(maxIter=5, regParam=.01), parallelism=1) - modelPar1 = ovrPar1.fit(df) - ovrPar2 = OneVsRest(classifier=LogisticRegression(maxIter=5, regParam=.01), parallelism=2) - modelPar2 = ovrPar2.fit(df) - for i, model in enumerate(modelPar1.models): - self.assertTrue(np.allclose(model.coefficients.toArray(), - modelPar2.models[i].coefficients.toArray(), atol=1E-4)) - self.assertTrue(np.allclose(model.intercept, modelPar2.models[i].intercept, atol=1E-4)) - - def test_support_for_weightCol(self): - df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8), 1.0), - (1.0, Vectors.sparse(2, [], []), 1.0), - (2.0, Vectors.dense(0.5, 0.5), 1.0)], - ["label", "features", "weight"]) - # classifier inherits hasWeightCol - lr = LogisticRegression(maxIter=5, regParam=0.01) - ovr = OneVsRest(classifier=lr, weightCol="weight") - self.assertIsNotNone(ovr.fit(df)) - # classifier doesn't inherit hasWeightCol - dt = DecisionTreeClassifier() - ovr2 = OneVsRest(classifier=dt, weightCol="weight") - self.assertIsNotNone(ovr2.fit(df)) - - -class HashingTFTest(SparkSessionTestCase): - - def test_apply_binary_term_freqs(self): - - df = 
self.spark.createDataFrame([(0, ["a", "a", "b", "c", "c", "c"])], ["id", "words"]) - n = 10 - hashingTF = HashingTF() - hashingTF.setInputCol("words").setOutputCol("features").setNumFeatures(n).setBinary(True) - output = hashingTF.transform(df) - features = output.select("features").first().features.toArray() - expected = Vectors.dense([1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]).toArray() - for i in range(0, n): - self.assertAlmostEqual(features[i], expected[i], 14, "Error at " + str(i) + - ": expected " + str(expected[i]) + ", got " + str(features[i])) - - -class GeneralizedLinearRegressionTest(SparkSessionTestCase): - - def test_tweedie_distribution(self): - - df = self.spark.createDataFrame( - [(1.0, Vectors.dense(0.0, 0.0)), - (1.0, Vectors.dense(1.0, 2.0)), - (2.0, Vectors.dense(0.0, 0.0)), - (2.0, Vectors.dense(1.0, 1.0)), ], ["label", "features"]) - - glr = GeneralizedLinearRegression(family="tweedie", variancePower=1.6) - model = glr.fit(df) - self.assertTrue(np.allclose(model.coefficients.toArray(), [-0.4645, 0.3402], atol=1E-4)) - self.assertTrue(np.isclose(model.intercept, 0.7841, atol=1E-4)) - - model2 = glr.setLinkPower(-1.0).fit(df) - self.assertTrue(np.allclose(model2.coefficients.toArray(), [-0.6667, 0.5], atol=1E-4)) - self.assertTrue(np.isclose(model2.intercept, 0.6667, atol=1E-4)) - - def test_offset(self): - - df = self.spark.createDataFrame( - [(0.2, 1.0, 2.0, Vectors.dense(0.0, 5.0)), - (0.5, 2.1, 0.5, Vectors.dense(1.0, 2.0)), - (0.9, 0.4, 1.0, Vectors.dense(2.0, 1.0)), - (0.7, 0.7, 0.0, Vectors.dense(3.0, 3.0))], ["label", "weight", "offset", "features"]) - - glr = GeneralizedLinearRegression(family="poisson", weightCol="weight", offsetCol="offset") - model = glr.fit(df) - self.assertTrue(np.allclose(model.coefficients.toArray(), [0.664647, -0.3192581], - atol=1E-4)) - self.assertTrue(np.isclose(model.intercept, -1.561613, atol=1E-4)) - - -class LinearRegressionTest(SparkSessionTestCase): - - def test_linear_regression_with_huber_loss(self): - - data_path = "data/mllib/sample_linear_regression_data.txt" - df = self.spark.read.format("libsvm").load(data_path) - - lir = LinearRegression(loss="huber", epsilon=2.0) - model = lir.fit(df) - - expectedCoefficients = [0.136, 0.7648, -0.7761, 2.4236, 0.537, - 1.2612, -0.333, -0.5694, -0.6311, 0.6053] - expectedIntercept = 0.1607 - expectedScale = 9.758 - - self.assertTrue( - np.allclose(model.coefficients.toArray(), expectedCoefficients, atol=1E-3)) - self.assertTrue(np.isclose(model.intercept, expectedIntercept, atol=1E-3)) - self.assertTrue(np.isclose(model.scale, expectedScale, atol=1E-3)) - - -class LogisticRegressionTest(SparkSessionTestCase): - - def test_binomial_logistic_regression_with_bound(self): - - df = self.spark.createDataFrame( - [(1.0, 1.0, Vectors.dense(0.0, 5.0)), - (0.0, 2.0, Vectors.dense(1.0, 2.0)), - (1.0, 3.0, Vectors.dense(2.0, 1.0)), - (0.0, 4.0, Vectors.dense(3.0, 3.0)), ], ["label", "weight", "features"]) - - lor = LogisticRegression(regParam=0.01, weightCol="weight", - lowerBoundsOnCoefficients=Matrices.dense(1, 2, [-1.0, -1.0]), - upperBoundsOnIntercepts=Vectors.dense(0.0)) - model = lor.fit(df) - self.assertTrue( - np.allclose(model.coefficients.toArray(), [-0.2944, -0.0484], atol=1E-4)) - self.assertTrue(np.isclose(model.intercept, 0.0, atol=1E-4)) - - def test_multinomial_logistic_regression_with_bound(self): - - data_path = "data/mllib/sample_multiclass_classification_data.txt" - df = self.spark.read.format("libsvm").load(data_path) - - lor = LogisticRegression(regParam=0.01, - 
-                                 lowerBoundsOnCoefficients=Matrices.dense(3, 4, range(12)),
-                                 upperBoundsOnIntercepts=Vectors.dense(0.0, 0.0, 0.0))
-        model = lor.fit(df)
-        expected = [[4.593, 4.5516, 9.0099, 12.2904],
-                    [1.0, 8.1093, 7.0, 10.0],
-                    [3.041, 5.0, 8.0, 11.0]]
-        for i in range(0, len(expected)):
-            self.assertTrue(
-                np.allclose(model.coefficientMatrix.toArray()[i], expected[i], atol=1E-4))
-        self.assertTrue(
-            np.allclose(model.interceptVector.toArray(), [-0.9057, -1.1392, -0.0033], atol=1E-4))
-
-
-class MultilayerPerceptronClassifierTest(SparkSessionTestCase):
-
-    def test_raw_and_probability_prediction(self):
-
-        data_path = "data/mllib/sample_multiclass_classification_data.txt"
-        df = self.spark.read.format("libsvm").load(data_path)
-
-        mlp = MultilayerPerceptronClassifier(maxIter=100, layers=[4, 5, 4, 3],
-                                             blockSize=128, seed=123)
-        model = mlp.fit(df)
-        test = self.sc.parallelize([Row(features=Vectors.dense(0.1, 0.1, 0.25, 0.25))]).toDF()
-        result = model.transform(test).head()
-        expected_prediction = 2.0
-        expected_probability = [0.0, 0.0, 1.0]
-        expected_rawPrediction = [57.3955, -124.5462, 67.9943]
-        self.assertEqual(result.prediction, expected_prediction)
-        self.assertTrue(np.allclose(result.probability, expected_probability, atol=1E-4))
-        self.assertTrue(np.allclose(result.rawPrediction, expected_rawPrediction, atol=1E-4))
-
-
-class FPGrowthTests(SparkSessionTestCase):
-    def setUp(self):
-        super(FPGrowthTests, self).setUp()
-        self.data = self.spark.createDataFrame(
-            [([1, 2], ), ([1, 2], ), ([1, 2, 3], ), ([1, 3], )],
-            ["items"])
-
-    def test_association_rules(self):
-        fp = FPGrowth()
-        fpm = fp.fit(self.data)
-
-        expected_association_rules = self.spark.createDataFrame(
-            [([3], [1], 1.0, 1.0), ([2], [1], 1.0, 1.0)],
-            ["antecedent", "consequent", "confidence", "lift"]
-        )
-        actual_association_rules = fpm.associationRules
-
-        self.assertEqual(actual_association_rules.subtract(expected_association_rules).count(), 0)
-        self.assertEqual(expected_association_rules.subtract(actual_association_rules).count(), 0)
-
-    def test_freq_itemsets(self):
-        fp = FPGrowth()
-        fpm = fp.fit(self.data)
-
-        expected_freq_itemsets = self.spark.createDataFrame(
-            [([1], 4), ([2], 3), ([2, 1], 3), ([3], 2), ([3, 1], 2)],
-            ["items", "freq"]
-        )
-        actual_freq_itemsets = fpm.freqItemsets
-
-        self.assertEqual(actual_freq_itemsets.subtract(expected_freq_itemsets).count(), 0)
-        self.assertEqual(expected_freq_itemsets.subtract(actual_freq_itemsets).count(), 0)
-
-    def tearDown(self):
-        del self.data
-
-
-class ImageReaderTest(SparkSessionTestCase):
-
-    def test_read_images(self):
-        data_path = 'data/mllib/images/origin/kittens'
-        df = ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True)
-        self.assertEqual(df.count(), 4)
-        first_row = df.take(1)[0][0]
-        array = ImageSchema.toNDArray(first_row)
-        self.assertEqual(len(array), first_row[1])
-        self.assertEqual(ImageSchema.toImage(array, origin=first_row[0]), first_row)
-        self.assertEqual(df.schema, ImageSchema.imageSchema)
-        self.assertEqual(df.schema["image"].dataType, ImageSchema.columnSchema)
-        expected = {'CV_8UC3': 16, 'Undefined': -1, 'CV_8U': 0, 'CV_8UC1': 0, 'CV_8UC4': 24}
-        self.assertEqual(ImageSchema.ocvTypes, expected)
-        expected = ['origin', 'height', 'width', 'nChannels', 'mode', 'data']
-        self.assertEqual(ImageSchema.imageFields, expected)
-        self.assertEqual(ImageSchema.undefinedImageType, "Undefined")
-
-        with QuietTest(self.sc):
-            self.assertRaisesRegexp(
-                TypeError,
-                "image argument should be pyspark.sql.types.Row; however",
-                lambda:
ImageSchema.toNDArray("a")) - - with QuietTest(self.sc): - self.assertRaisesRegexp( - ValueError, - "image argument should have attributes specified in", - lambda: ImageSchema.toNDArray(Row(a=1))) - - with QuietTest(self.sc): - self.assertRaisesRegexp( - TypeError, - "array argument should be numpy.ndarray; however, it got", - lambda: ImageSchema.toImage("a")) - - -class ImageReaderTest2(PySparkTestCase): - - @classmethod - def setUpClass(cls): - super(ImageReaderTest2, cls).setUpClass() - cls.hive_available = True - # Note that here we enable Hive's support. - cls.spark = None - try: - cls.sc._jvm.org.apache.hadoop.hive.conf.HiveConf() - except py4j.protocol.Py4JError: - cls.tearDownClass() - cls.hive_available = False - except TypeError: - cls.tearDownClass() - cls.hive_available = False - if cls.hive_available: - cls.spark = HiveContext._createForTesting(cls.sc) - - def setUp(self): - if not self.hive_available: - self.skipTest("Hive is not available.") - - @classmethod - def tearDownClass(cls): - super(ImageReaderTest2, cls).tearDownClass() - if cls.spark is not None: - cls.spark.sparkSession.stop() - cls.spark = None - - def test_read_images_multiple_times(self): - # This test case is to check if `ImageSchema.readImages` tries to - # initiate Hive client multiple times. See SPARK-22651. - data_path = 'data/mllib/images/origin/kittens' - ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True) - ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True) - - -class ALSTest(SparkSessionTestCase): - - def test_storage_levels(self): - df = self.spark.createDataFrame( - [(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0), (2, 1, 1.0), (2, 2, 5.0)], - ["user", "item", "rating"]) - als = ALS().setMaxIter(1).setRank(1) - # test default params - als.fit(df) - self.assertEqual(als.getIntermediateStorageLevel(), "MEMORY_AND_DISK") - self.assertEqual(als._java_obj.getIntermediateStorageLevel(), "MEMORY_AND_DISK") - self.assertEqual(als.getFinalStorageLevel(), "MEMORY_AND_DISK") - self.assertEqual(als._java_obj.getFinalStorageLevel(), "MEMORY_AND_DISK") - # test non-default params - als.setIntermediateStorageLevel("MEMORY_ONLY_2") - als.setFinalStorageLevel("DISK_ONLY") - als.fit(df) - self.assertEqual(als.getIntermediateStorageLevel(), "MEMORY_ONLY_2") - self.assertEqual(als._java_obj.getIntermediateStorageLevel(), "MEMORY_ONLY_2") - self.assertEqual(als.getFinalStorageLevel(), "DISK_ONLY") - self.assertEqual(als._java_obj.getFinalStorageLevel(), "DISK_ONLY") - - -class DefaultValuesTests(PySparkTestCase): - """ - Test :py:class:`JavaParams` classes to see if their default Param values match - those in their Scala counterparts. 
- """ - - def test_java_params(self): - import pyspark.ml.feature - import pyspark.ml.classification - import pyspark.ml.clustering - import pyspark.ml.evaluation - import pyspark.ml.pipeline - import pyspark.ml.recommendation - import pyspark.ml.regression - - modules = [pyspark.ml.feature, pyspark.ml.classification, pyspark.ml.clustering, - pyspark.ml.evaluation, pyspark.ml.pipeline, pyspark.ml.recommendation, - pyspark.ml.regression] - for module in modules: - for name, cls in inspect.getmembers(module, inspect.isclass): - if not name.endswith('Model') and not name.endswith('Params')\ - and issubclass(cls, JavaParams) and not inspect.isabstract(cls): - # NOTE: disable check_params_exist until there is parity with Scala API - ParamTests.check_params(self, cls(), check_params_exist=False) - - # Additional classes that need explicit construction - from pyspark.ml.feature import CountVectorizerModel, StringIndexerModel - ParamTests.check_params(self, CountVectorizerModel.from_vocabulary(['a'], 'input'), - check_params_exist=False) - ParamTests.check_params(self, StringIndexerModel.from_labels(['a', 'b'], 'input'), - check_params_exist=False) - - -def _squared_distance(a, b): - if isinstance(a, Vector): - return a.squared_distance(b) - else: - return b.squared_distance(a) - - -class VectorTests(MLlibTestCase): - - def _test_serialize(self, v): - self.assertEqual(v, ser.loads(ser.dumps(v))) - jvec = self.sc._jvm.org.apache.spark.ml.python.MLSerDe.loads(bytearray(ser.dumps(v))) - nv = ser.loads(bytes(self.sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(jvec))) - self.assertEqual(v, nv) - vs = [v] * 100 - jvecs = self.sc._jvm.org.apache.spark.ml.python.MLSerDe.loads(bytearray(ser.dumps(vs))) - nvs = ser.loads(bytes(self.sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(jvecs))) - self.assertEqual(vs, nvs) - - def test_serialize(self): - self._test_serialize(DenseVector(range(10))) - self._test_serialize(DenseVector(array([1., 2., 3., 4.]))) - self._test_serialize(DenseVector(pyarray.array('d', range(10)))) - self._test_serialize(SparseVector(4, {1: 1, 3: 2})) - self._test_serialize(SparseVector(3, {})) - self._test_serialize(DenseMatrix(2, 3, range(6))) - sm1 = SparseMatrix( - 3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0]) - self._test_serialize(sm1) - - def test_dot(self): - sv = SparseVector(4, {1: 1, 3: 2}) - dv = DenseVector(array([1., 2., 3., 4.])) - lst = DenseVector([1, 2, 3, 4]) - mat = array([[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]]) - arr = pyarray.array('d', [0, 1, 2, 3]) - self.assertEqual(10.0, sv.dot(dv)) - self.assertTrue(array_equal(array([3., 6., 9., 12.]), sv.dot(mat))) - self.assertEqual(30.0, dv.dot(dv)) - self.assertTrue(array_equal(array([10., 20., 30., 40.]), dv.dot(mat))) - self.assertEqual(30.0, lst.dot(dv)) - self.assertTrue(array_equal(array([10., 20., 30., 40.]), lst.dot(mat))) - self.assertEqual(7.0, sv.dot(arr)) - - def test_squared_distance(self): - sv = SparseVector(4, {1: 1, 3: 2}) - dv = DenseVector(array([1., 2., 3., 4.])) - lst = DenseVector([4, 3, 2, 1]) - lst1 = [4, 3, 2, 1] - arr = pyarray.array('d', [0, 2, 1, 3]) - narr = array([0, 2, 1, 3]) - self.assertEqual(15.0, _squared_distance(sv, dv)) - self.assertEqual(25.0, _squared_distance(sv, lst)) - self.assertEqual(20.0, _squared_distance(dv, lst)) - self.assertEqual(15.0, _squared_distance(dv, sv)) - self.assertEqual(25.0, _squared_distance(lst, sv)) - self.assertEqual(20.0, _squared_distance(lst, dv)) - self.assertEqual(0.0, _squared_distance(sv, sv)) - 
self.assertEqual(0.0, _squared_distance(dv, dv)) - self.assertEqual(0.0, _squared_distance(lst, lst)) - self.assertEqual(25.0, _squared_distance(sv, lst1)) - self.assertEqual(3.0, _squared_distance(sv, arr)) - self.assertEqual(3.0, _squared_distance(sv, narr)) - - def test_hash(self): - v1 = DenseVector([0.0, 1.0, 0.0, 5.5]) - v2 = SparseVector(4, [(1, 1.0), (3, 5.5)]) - v3 = DenseVector([0.0, 1.0, 0.0, 5.5]) - v4 = SparseVector(4, [(1, 1.0), (3, 2.5)]) - self.assertEqual(hash(v1), hash(v2)) - self.assertEqual(hash(v1), hash(v3)) - self.assertEqual(hash(v2), hash(v3)) - self.assertFalse(hash(v1) == hash(v4)) - self.assertFalse(hash(v2) == hash(v4)) - - def test_eq(self): - v1 = DenseVector([0.0, 1.0, 0.0, 5.5]) - v2 = SparseVector(4, [(1, 1.0), (3, 5.5)]) - v3 = DenseVector([0.0, 1.0, 0.0, 5.5]) - v4 = SparseVector(6, [(1, 1.0), (3, 5.5)]) - v5 = DenseVector([0.0, 1.0, 0.0, 2.5]) - v6 = SparseVector(4, [(1, 1.0), (3, 2.5)]) - self.assertEqual(v1, v2) - self.assertEqual(v1, v3) - self.assertFalse(v2 == v4) - self.assertFalse(v1 == v5) - self.assertFalse(v1 == v6) - - def test_equals(self): - indices = [1, 2, 4] - values = [1., 3., 2.] - self.assertTrue(Vectors._equals(indices, values, list(range(5)), [0., 1., 3., 0., 2.])) - self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 3., 1., 0., 2.])) - self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 3., 0., 2.])) - self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 1., 3., 2., 2.])) - - def test_conversion(self): - # numpy arrays should be automatically upcast to float64 - # tests for fix of [SPARK-5089] - v = array([1, 2, 3, 4], dtype='float64') - dv = DenseVector(v) - self.assertTrue(dv.array.dtype == 'float64') - v = array([1, 2, 3, 4], dtype='float32') - dv = DenseVector(v) - self.assertTrue(dv.array.dtype == 'float64') - - def test_sparse_vector_indexing(self): - sv = SparseVector(5, {1: 1, 3: 2}) - self.assertEqual(sv[0], 0.) - self.assertEqual(sv[3], 2.) - self.assertEqual(sv[1], 1.) - self.assertEqual(sv[2], 0.) - self.assertEqual(sv[4], 0.) - self.assertEqual(sv[-1], 0.) - self.assertEqual(sv[-2], 2.) - self.assertEqual(sv[-3], 0.) - self.assertEqual(sv[-5], 0.) 
- for ind in [5, -6]: - self.assertRaises(IndexError, sv.__getitem__, ind) - for ind in [7.8, '1']: - self.assertRaises(TypeError, sv.__getitem__, ind) - - zeros = SparseVector(4, {}) - self.assertEqual(zeros[0], 0.0) - self.assertEqual(zeros[3], 0.0) - for ind in [4, -5]: - self.assertRaises(IndexError, zeros.__getitem__, ind) - - empty = SparseVector(0, {}) - for ind in [-1, 0, 1]: - self.assertRaises(IndexError, empty.__getitem__, ind) - - def test_sparse_vector_iteration(self): - self.assertListEqual(list(SparseVector(3, [], [])), [0.0, 0.0, 0.0]) - self.assertListEqual(list(SparseVector(5, [0, 3], [1.0, 2.0])), [1.0, 0.0, 0.0, 2.0, 0.0]) - - def test_matrix_indexing(self): - mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) - expected = [[0, 6], [1, 8], [4, 10]] - for i in range(3): - for j in range(2): - self.assertEqual(mat[i, j], expected[i][j]) - - for i, j in [(-1, 0), (4, 1), (3, 4)]: - self.assertRaises(IndexError, mat.__getitem__, (i, j)) - - def test_repr_dense_matrix(self): - mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) - self.assertTrue( - repr(mat), - 'DenseMatrix(3, 2, [0.0, 1.0, 4.0, 6.0, 8.0, 10.0], False)') - - mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10], True) - self.assertTrue( - repr(mat), - 'DenseMatrix(3, 2, [0.0, 1.0, 4.0, 6.0, 8.0, 10.0], False)') - - mat = DenseMatrix(6, 3, zeros(18)) - self.assertTrue( - repr(mat), - 'DenseMatrix(6, 3, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..., \ - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], False)') - - def test_repr_sparse_matrix(self): - sm1t = SparseMatrix( - 3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], - isTransposed=True) - self.assertTrue( - repr(sm1t), - 'SparseMatrix(3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], True)') - - indices = tile(arange(6), 3) - values = ones(18) - sm = SparseMatrix(6, 3, [0, 6, 12, 18], indices, values) - self.assertTrue( - repr(sm), "SparseMatrix(6, 3, [0, 6, 12, 18], \ - [0, 1, 2, 3, 4, 5, 0, 1, ..., 4, 5, 0, 1, 2, 3, 4, 5], \ - [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ..., \ - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], False)") - - self.assertTrue( - str(sm), - "6 X 3 CSCMatrix\n\ - (0,0) 1.0\n(1,0) 1.0\n(2,0) 1.0\n(3,0) 1.0\n(4,0) 1.0\n(5,0) 1.0\n\ - (0,1) 1.0\n(1,1) 1.0\n(2,1) 1.0\n(3,1) 1.0\n(4,1) 1.0\n(5,1) 1.0\n\ - (0,2) 1.0\n(1,2) 1.0\n(2,2) 1.0\n(3,2) 1.0\n..\n..") - - sm = SparseMatrix(1, 18, zeros(19), [], []) - self.assertTrue( - repr(sm), - 'SparseMatrix(1, 18, \ - [0, 0, 0, 0, 0, 0, 0, 0, ..., 0, 0, 0, 0, 0, 0, 0, 0], [], [], False)') - - def test_sparse_matrix(self): - # Test sparse matrix creation. - sm1 = SparseMatrix( - 3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0]) - self.assertEqual(sm1.numRows, 3) - self.assertEqual(sm1.numCols, 4) - self.assertEqual(sm1.colPtrs.tolist(), [0, 2, 2, 4, 4]) - self.assertEqual(sm1.rowIndices.tolist(), [1, 2, 1, 2]) - self.assertEqual(sm1.values.tolist(), [1.0, 2.0, 4.0, 5.0]) - self.assertTrue( - repr(sm1), - 'SparseMatrix(3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0], False)') - - # Test indexing - expected = [ - [0, 0, 0, 0], - [1, 0, 4, 0], - [2, 0, 5, 0]] - - for i in range(3): - for j in range(4): - self.assertEqual(expected[i][j], sm1[i, j]) - self.assertTrue(array_equal(sm1.toArray(), expected)) - - for i, j in [(-1, 1), (4, 3), (3, 5)]: - self.assertRaises(IndexError, sm1.__getitem__, (i, j)) - - # Test conversion to dense and sparse. 
- smnew = sm1.toDense().toSparse() - self.assertEqual(sm1.numRows, smnew.numRows) - self.assertEqual(sm1.numCols, smnew.numCols) - self.assertTrue(array_equal(sm1.colPtrs, smnew.colPtrs)) - self.assertTrue(array_equal(sm1.rowIndices, smnew.rowIndices)) - self.assertTrue(array_equal(sm1.values, smnew.values)) - - sm1t = SparseMatrix( - 3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], - isTransposed=True) - self.assertEqual(sm1t.numRows, 3) - self.assertEqual(sm1t.numCols, 4) - self.assertEqual(sm1t.colPtrs.tolist(), [0, 2, 3, 5]) - self.assertEqual(sm1t.rowIndices.tolist(), [0, 1, 2, 0, 2]) - self.assertEqual(sm1t.values.tolist(), [3.0, 2.0, 4.0, 9.0, 8.0]) - - expected = [ - [3, 2, 0, 0], - [0, 0, 4, 0], - [9, 0, 8, 0]] - - for i in range(3): - for j in range(4): - self.assertEqual(expected[i][j], sm1t[i, j]) - self.assertTrue(array_equal(sm1t.toArray(), expected)) - - def test_dense_matrix_is_transposed(self): - mat1 = DenseMatrix(3, 2, [0, 4, 1, 6, 3, 9], isTransposed=True) - mat = DenseMatrix(3, 2, [0, 1, 3, 4, 6, 9]) - self.assertEqual(mat1, mat) - - expected = [[0, 4], [1, 6], [3, 9]] - for i in range(3): - for j in range(2): - self.assertEqual(mat1[i, j], expected[i][j]) - self.assertTrue(array_equal(mat1.toArray(), expected)) - - sm = mat1.toSparse() - self.assertTrue(array_equal(sm.rowIndices, [1, 2, 0, 1, 2])) - self.assertTrue(array_equal(sm.colPtrs, [0, 2, 5])) - self.assertTrue(array_equal(sm.values, [1, 3, 4, 6, 9])) - - def test_norms(self): - a = DenseVector([0, 2, 3, -1]) - self.assertAlmostEqual(a.norm(2), 3.742, 3) - self.assertTrue(a.norm(1), 6) - self.assertTrue(a.norm(inf), 3) - a = SparseVector(4, [0, 2], [3, -4]) - self.assertAlmostEqual(a.norm(2), 5) - self.assertTrue(a.norm(1), 7) - self.assertTrue(a.norm(inf), 4) - - tmp = SparseVector(4, [0, 2], [3, 0]) - self.assertEqual(tmp.numNonzeros(), 1) - - -class VectorUDTTests(MLlibTestCase): - - dv0 = DenseVector([]) - dv1 = DenseVector([1.0, 2.0]) - sv0 = SparseVector(2, [], []) - sv1 = SparseVector(2, [1], [2.0]) - udt = VectorUDT() - - def test_json_schema(self): - self.assertEqual(VectorUDT.fromJson(self.udt.jsonValue()), self.udt) - - def test_serialization(self): - for v in [self.dv0, self.dv1, self.sv0, self.sv1]: - self.assertEqual(v, self.udt.deserialize(self.udt.serialize(v))) - - def test_infer_schema(self): - rdd = self.sc.parallelize([Row(label=1.0, features=self.dv1), - Row(label=0.0, features=self.sv1)]) - df = rdd.toDF() - schema = df.schema - field = [f for f in schema.fields if f.name == "features"][0] - self.assertEqual(field.dataType, self.udt) - vectors = df.rdd.map(lambda p: p.features).collect() - self.assertEqual(len(vectors), 2) - for v in vectors: - if isinstance(v, SparseVector): - self.assertEqual(v, self.sv1) - elif isinstance(v, DenseVector): - self.assertEqual(v, self.dv1) - else: - raise TypeError("expecting a vector but got %r of type %r" % (v, type(v))) - - -class MatrixUDTTests(MLlibTestCase): - - dm1 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10]) - dm2 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10], isTransposed=True) - sm1 = SparseMatrix(1, 1, [0, 1], [0], [2.0]) - sm2 = SparseMatrix(2, 1, [0, 0, 1], [0], [5.0], isTransposed=True) - udt = MatrixUDT() - - def test_json_schema(self): - self.assertEqual(MatrixUDT.fromJson(self.udt.jsonValue()), self.udt) - - def test_serialization(self): - for m in [self.dm1, self.dm2, self.sm1, self.sm2]: - self.assertEqual(m, self.udt.deserialize(self.udt.serialize(m))) - - def test_infer_schema(self): - rdd = self.sc.parallelize([("dense", 
self.dm1), ("sparse", self.sm1)]) - df = rdd.toDF() - schema = df.schema - self.assertTrue(schema.fields[1].dataType, self.udt) - matrices = df.rdd.map(lambda x: x._2).collect() - self.assertEqual(len(matrices), 2) - for m in matrices: - if isinstance(m, DenseMatrix): - self.assertTrue(m, self.dm1) - elif isinstance(m, SparseMatrix): - self.assertTrue(m, self.sm1) - else: - raise ValueError("Expected a matrix but got type %r" % type(m)) - - -class WrapperTests(MLlibTestCase): - - def test_new_java_array(self): - # test array of strings - str_list = ["a", "b", "c"] - java_class = self.sc._gateway.jvm.java.lang.String - java_array = JavaWrapper._new_java_array(str_list, java_class) - self.assertEqual(_java2py(self.sc, java_array), str_list) - # test array of integers - int_list = [1, 2, 3] - java_class = self.sc._gateway.jvm.java.lang.Integer - java_array = JavaWrapper._new_java_array(int_list, java_class) - self.assertEqual(_java2py(self.sc, java_array), int_list) - # test array of floats - float_list = [0.1, 0.2, 0.3] - java_class = self.sc._gateway.jvm.java.lang.Double - java_array = JavaWrapper._new_java_array(float_list, java_class) - self.assertEqual(_java2py(self.sc, java_array), float_list) - # test array of bools - bool_list = [False, True, True] - java_class = self.sc._gateway.jvm.java.lang.Boolean - java_array = JavaWrapper._new_java_array(bool_list, java_class) - self.assertEqual(_java2py(self.sc, java_array), bool_list) - # test array of Java DenseVectors - v1 = DenseVector([0.0, 1.0]) - v2 = DenseVector([1.0, 0.0]) - vec_java_list = [_py2java(self.sc, v1), _py2java(self.sc, v2)] - java_class = self.sc._gateway.jvm.org.apache.spark.ml.linalg.DenseVector - java_array = JavaWrapper._new_java_array(vec_java_list, java_class) - self.assertEqual(_java2py(self.sc, java_array), [v1, v2]) - # test empty array - java_class = self.sc._gateway.jvm.java.lang.Integer - java_array = JavaWrapper._new_java_array([], java_class) - self.assertEqual(_java2py(self.sc, java_array), []) - - -class ChiSquareTestTests(SparkSessionTestCase): - - def test_chisquaretest(self): - data = [[0, Vectors.dense([0, 1, 2])], - [1, Vectors.dense([1, 1, 1])], - [2, Vectors.dense([2, 1, 0])]] - df = self.spark.createDataFrame(data, ['label', 'feat']) - res = ChiSquareTest.test(df, 'feat', 'label') - # This line is hitting the collect bug described in #17218, commented for now. 
- # pValues = res.select("degreesOfFreedom").collect()) - self.assertIsInstance(res, DataFrame) - fieldNames = set(field.name for field in res.schema.fields) - expectedFields = ["pValues", "degreesOfFreedom", "statistics"] - self.assertTrue(all(field in fieldNames for field in expectedFields)) - - -class UnaryTransformerTests(SparkSessionTestCase): - - def test_unary_transformer_validate_input_type(self): - shiftVal = 3 - transformer = MockUnaryTransformer(shiftVal=shiftVal)\ - .setInputCol("input").setOutputCol("output") - - # should not raise any errors - transformer.validateInputType(DoubleType()) - - with self.assertRaises(TypeError): - # passing the wrong input type should raise an error - transformer.validateInputType(IntegerType()) - - def test_unary_transformer_transform(self): - shiftVal = 3 - transformer = MockUnaryTransformer(shiftVal=shiftVal)\ - .setInputCol("input").setOutputCol("output") - - df = self.spark.range(0, 10).toDF('input') - df = df.withColumn("input", df.input.cast(dataType="double")) - - transformed_df = transformer.transform(df) - results = transformed_df.select("input", "output").collect() - - for res in results: - self.assertEqual(res.input + shiftVal, res.output) - - -class EstimatorTest(unittest.TestCase): - - def testDefaultFitMultiple(self): - N = 4 - data = MockDataset() - estimator = MockEstimator() - params = [{estimator.fake: i} for i in range(N)] - modelIter = estimator.fitMultiple(data, params) - indexList = [] - for index, model in modelIter: - self.assertEqual(model.getFake(), index) - indexList.append(index) - self.assertEqual(sorted(indexList), list(range(N))) - - -if __name__ == "__main__": - from pyspark.ml.tests import * - - runner = unishark.BufferedTestRunner( - reporters=[unishark.XUnitReporter('target/test-reports/pyspark.ml_{}'.format( - os.path.basename(os.environ.get("PYSPARK_PYTHON", ""))))]) - unittest.main(testRunner=runner, verbosity=2) diff --git a/python/pyspark/ml/tests/__init__.py b/python/pyspark/ml/tests/__init__.py new file mode 100644 index 0000000000000..cce3acad34a49 --- /dev/null +++ b/python/pyspark/ml/tests/__init__.py @@ -0,0 +1,16 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/python/pyspark/ml/tests/test_algorithms.py b/python/pyspark/ml/tests/test_algorithms.py new file mode 100644 index 0000000000000..516bb563402e0 --- /dev/null +++ b/python/pyspark/ml/tests/test_algorithms.py @@ -0,0 +1,340 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from shutil import rmtree +import tempfile +import unittest + +import numpy as np + +from pyspark.ml.classification import DecisionTreeClassifier, LogisticRegression, \ + MultilayerPerceptronClassifier, OneVsRest +from pyspark.ml.clustering import DistributedLDAModel, KMeans, LocalLDAModel, LDA, LDAModel +from pyspark.ml.fpm import FPGrowth +from pyspark.ml.linalg import Matrices, Vectors +from pyspark.ml.recommendation import ALS +from pyspark.ml.regression import GeneralizedLinearRegression, LinearRegression +from pyspark.sql import Row +from pyspark.testing.mlutils import SparkSessionTestCase + + +class LogisticRegressionTest(SparkSessionTestCase): + + def test_binomial_logistic_regression_with_bound(self): + + df = self.spark.createDataFrame( + [(1.0, 1.0, Vectors.dense(0.0, 5.0)), + (0.0, 2.0, Vectors.dense(1.0, 2.0)), + (1.0, 3.0, Vectors.dense(2.0, 1.0)), + (0.0, 4.0, Vectors.dense(3.0, 3.0)), ], ["label", "weight", "features"]) + + lor = LogisticRegression(regParam=0.01, weightCol="weight", + lowerBoundsOnCoefficients=Matrices.dense(1, 2, [-1.0, -1.0]), + upperBoundsOnIntercepts=Vectors.dense(0.0)) + model = lor.fit(df) + self.assertTrue( + np.allclose(model.coefficients.toArray(), [-0.2944, -0.0484], atol=1E-4)) + self.assertTrue(np.isclose(model.intercept, 0.0, atol=1E-4)) + + def test_multinomial_logistic_regression_with_bound(self): + + data_path = "data/mllib/sample_multiclass_classification_data.txt" + df = self.spark.read.format("libsvm").load(data_path) + + lor = LogisticRegression(regParam=0.01, + lowerBoundsOnCoefficients=Matrices.dense(3, 4, range(12)), + upperBoundsOnIntercepts=Vectors.dense(0.0, 0.0, 0.0)) + model = lor.fit(df) + expected = [[4.593, 4.5516, 9.0099, 12.2904], + [1.0, 8.1093, 7.0, 10.0], + [3.041, 5.0, 8.0, 11.0]] + for i in range(0, len(expected)): + self.assertTrue( + np.allclose(model.coefficientMatrix.toArray()[i], expected[i], atol=1E-4)) + self.assertTrue( + np.allclose(model.interceptVector.toArray(), [-0.9057, -1.1392, -0.0033], atol=1E-4)) + + +class MultilayerPerceptronClassifierTest(SparkSessionTestCase): + + def test_raw_and_probability_prediction(self): + + data_path = "data/mllib/sample_multiclass_classification_data.txt" + df = self.spark.read.format("libsvm").load(data_path) + + mlp = MultilayerPerceptronClassifier(maxIter=100, layers=[4, 5, 4, 3], + blockSize=128, seed=123) + model = mlp.fit(df) + test = self.sc.parallelize([Row(features=Vectors.dense(0.1, 0.1, 0.25, 0.25))]).toDF() + result = model.transform(test).head() + expected_prediction = 2.0 + expected_probability = [0.0, 0.0, 1.0] + expected_rawPrediction = [57.3955, -124.5462, 67.9943] + self.assertTrue(result.prediction, expected_prediction) + self.assertTrue(np.allclose(result.probability, expected_probability, atol=1E-4)) + self.assertTrue(np.allclose(result.rawPrediction, expected_rawPrediction, atol=1E-4)) + + +class OneVsRestTests(SparkSessionTestCase): + + def test_copy(self): + df = 
self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)), + (1.0, Vectors.sparse(2, [], [])), + (2.0, Vectors.dense(0.5, 0.5))], + ["label", "features"]) + lr = LogisticRegression(maxIter=5, regParam=0.01) + ovr = OneVsRest(classifier=lr) + ovr1 = ovr.copy({lr.maxIter: 10}) + self.assertEqual(ovr.getClassifier().getMaxIter(), 5) + self.assertEqual(ovr1.getClassifier().getMaxIter(), 10) + model = ovr.fit(df) + model1 = model.copy({model.predictionCol: "indexed"}) + self.assertEqual(model1.getPredictionCol(), "indexed") + + def test_output_columns(self): + df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)), + (1.0, Vectors.sparse(2, [], [])), + (2.0, Vectors.dense(0.5, 0.5))], + ["label", "features"]) + lr = LogisticRegression(maxIter=5, regParam=0.01) + ovr = OneVsRest(classifier=lr, parallelism=1) + model = ovr.fit(df) + output = model.transform(df) + self.assertEqual(output.columns, ["label", "features", "prediction"]) + + def test_parallelism_doesnt_change_output(self): + df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)), + (1.0, Vectors.sparse(2, [], [])), + (2.0, Vectors.dense(0.5, 0.5))], + ["label", "features"]) + ovrPar1 = OneVsRest(classifier=LogisticRegression(maxIter=5, regParam=.01), parallelism=1) + modelPar1 = ovrPar1.fit(df) + ovrPar2 = OneVsRest(classifier=LogisticRegression(maxIter=5, regParam=.01), parallelism=2) + modelPar2 = ovrPar2.fit(df) + for i, model in enumerate(modelPar1.models): + self.assertTrue(np.allclose(model.coefficients.toArray(), + modelPar2.models[i].coefficients.toArray(), atol=1E-4)) + self.assertTrue(np.allclose(model.intercept, modelPar2.models[i].intercept, atol=1E-4)) + + def test_support_for_weightCol(self): + df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8), 1.0), + (1.0, Vectors.sparse(2, [], []), 1.0), + (2.0, Vectors.dense(0.5, 0.5), 1.0)], + ["label", "features", "weight"]) + # classifier inherits hasWeightCol + lr = LogisticRegression(maxIter=5, regParam=0.01) + ovr = OneVsRest(classifier=lr, weightCol="weight") + self.assertIsNotNone(ovr.fit(df)) + # classifier doesn't inherit hasWeightCol + dt = DecisionTreeClassifier() + ovr2 = OneVsRest(classifier=dt, weightCol="weight") + self.assertIsNotNone(ovr2.fit(df)) + + +class KMeansTests(SparkSessionTestCase): + + def test_kmeans_cosine_distance(self): + data = [(Vectors.dense([1.0, 1.0]),), (Vectors.dense([10.0, 10.0]),), + (Vectors.dense([1.0, 0.5]),), (Vectors.dense([10.0, 4.4]),), + (Vectors.dense([-1.0, 1.0]),), (Vectors.dense([-100.0, 90.0]),)] + df = self.spark.createDataFrame(data, ["features"]) + kmeans = KMeans(k=3, seed=1, distanceMeasure="cosine") + model = kmeans.fit(df) + result = model.transform(df).collect() + self.assertTrue(result[0].prediction == result[1].prediction) + self.assertTrue(result[2].prediction == result[3].prediction) + self.assertTrue(result[4].prediction == result[5].prediction) + + +class LDATest(SparkSessionTestCase): + + def _compare(self, m1, m2): + """ + Temp method for comparing instances. + TODO: Replace with generic implementation once SPARK-14706 is merged. 
+ """ + self.assertEqual(m1.uid, m2.uid) + self.assertEqual(type(m1), type(m2)) + self.assertEqual(len(m1.params), len(m2.params)) + for p in m1.params: + if m1.isDefined(p): + self.assertEqual(m1.getOrDefault(p), m2.getOrDefault(p)) + self.assertEqual(p.parent, m2.getParam(p.name).parent) + if isinstance(m1, LDAModel): + self.assertEqual(m1.vocabSize(), m2.vocabSize()) + self.assertEqual(m1.topicsMatrix(), m2.topicsMatrix()) + + def test_persistence(self): + # Test save/load for LDA, LocalLDAModel, DistributedLDAModel. + df = self.spark.createDataFrame([ + [1, Vectors.dense([0.0, 1.0])], + [2, Vectors.sparse(2, {0: 1.0})], + ], ["id", "features"]) + # Fit model + lda = LDA(k=2, seed=1, optimizer="em") + distributedModel = lda.fit(df) + self.assertTrue(distributedModel.isDistributed()) + localModel = distributedModel.toLocal() + self.assertFalse(localModel.isDistributed()) + # Define paths + path = tempfile.mkdtemp() + lda_path = path + "/lda" + dist_model_path = path + "/distLDAModel" + local_model_path = path + "/localLDAModel" + # Test LDA + lda.save(lda_path) + lda2 = LDA.load(lda_path) + self._compare(lda, lda2) + # Test DistributedLDAModel + distributedModel.save(dist_model_path) + distributedModel2 = DistributedLDAModel.load(dist_model_path) + self._compare(distributedModel, distributedModel2) + # Test LocalLDAModel + localModel.save(local_model_path) + localModel2 = LocalLDAModel.load(local_model_path) + self._compare(localModel, localModel2) + # Clean up + try: + rmtree(path) + except OSError: + pass + + +class FPGrowthTests(SparkSessionTestCase): + def setUp(self): + super(FPGrowthTests, self).setUp() + self.data = self.spark.createDataFrame( + [([1, 2], ), ([1, 2], ), ([1, 2, 3], ), ([1, 3], )], + ["items"]) + + def test_association_rules(self): + fp = FPGrowth() + fpm = fp.fit(self.data) + + expected_association_rules = self.spark.createDataFrame( + [([3], [1], 1.0, 1.0), ([2], [1], 1.0, 1.0)], + ["antecedent", "consequent", "confidence", "lift"] + ) + actual_association_rules = fpm.associationRules + + self.assertEqual(actual_association_rules.subtract(expected_association_rules).count(), 0) + self.assertEqual(expected_association_rules.subtract(actual_association_rules).count(), 0) + + def test_freq_itemsets(self): + fp = FPGrowth() + fpm = fp.fit(self.data) + + expected_freq_itemsets = self.spark.createDataFrame( + [([1], 4), ([2], 3), ([2, 1], 3), ([3], 2), ([3, 1], 2)], + ["items", "freq"] + ) + actual_freq_itemsets = fpm.freqItemsets + + self.assertEqual(actual_freq_itemsets.subtract(expected_freq_itemsets).count(), 0) + self.assertEqual(expected_freq_itemsets.subtract(actual_freq_itemsets).count(), 0) + + def tearDown(self): + del self.data + + +class ALSTest(SparkSessionTestCase): + + def test_storage_levels(self): + df = self.spark.createDataFrame( + [(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0), (2, 1, 1.0), (2, 2, 5.0)], + ["user", "item", "rating"]) + als = ALS().setMaxIter(1).setRank(1) + # test default params + als.fit(df) + self.assertEqual(als.getIntermediateStorageLevel(), "MEMORY_AND_DISK") + self.assertEqual(als._java_obj.getIntermediateStorageLevel(), "MEMORY_AND_DISK") + self.assertEqual(als.getFinalStorageLevel(), "MEMORY_AND_DISK") + self.assertEqual(als._java_obj.getFinalStorageLevel(), "MEMORY_AND_DISK") + # test non-default params + als.setIntermediateStorageLevel("MEMORY_ONLY_2") + als.setFinalStorageLevel("DISK_ONLY") + als.fit(df) + self.assertEqual(als.getIntermediateStorageLevel(), "MEMORY_ONLY_2") + 
self.assertEqual(als._java_obj.getIntermediateStorageLevel(), "MEMORY_ONLY_2") + self.assertEqual(als.getFinalStorageLevel(), "DISK_ONLY") + self.assertEqual(als._java_obj.getFinalStorageLevel(), "DISK_ONLY") + + +class GeneralizedLinearRegressionTest(SparkSessionTestCase): + + def test_tweedie_distribution(self): + + df = self.spark.createDataFrame( + [(1.0, Vectors.dense(0.0, 0.0)), + (1.0, Vectors.dense(1.0, 2.0)), + (2.0, Vectors.dense(0.0, 0.0)), + (2.0, Vectors.dense(1.0, 1.0)), ], ["label", "features"]) + + glr = GeneralizedLinearRegression(family="tweedie", variancePower=1.6) + model = glr.fit(df) + self.assertTrue(np.allclose(model.coefficients.toArray(), [-0.4645, 0.3402], atol=1E-4)) + self.assertTrue(np.isclose(model.intercept, 0.7841, atol=1E-4)) + + model2 = glr.setLinkPower(-1.0).fit(df) + self.assertTrue(np.allclose(model2.coefficients.toArray(), [-0.6667, 0.5], atol=1E-4)) + self.assertTrue(np.isclose(model2.intercept, 0.6667, atol=1E-4)) + + def test_offset(self): + + df = self.spark.createDataFrame( + [(0.2, 1.0, 2.0, Vectors.dense(0.0, 5.0)), + (0.5, 2.1, 0.5, Vectors.dense(1.0, 2.0)), + (0.9, 0.4, 1.0, Vectors.dense(2.0, 1.0)), + (0.7, 0.7, 0.0, Vectors.dense(3.0, 3.0))], ["label", "weight", "offset", "features"]) + + glr = GeneralizedLinearRegression(family="poisson", weightCol="weight", offsetCol="offset") + model = glr.fit(df) + self.assertTrue(np.allclose(model.coefficients.toArray(), [0.664647, -0.3192581], + atol=1E-4)) + self.assertTrue(np.isclose(model.intercept, -1.561613, atol=1E-4)) + + +class LinearRegressionTest(SparkSessionTestCase): + + def test_linear_regression_with_huber_loss(self): + + data_path = "data/mllib/sample_linear_regression_data.txt" + df = self.spark.read.format("libsvm").load(data_path) + + lir = LinearRegression(loss="huber", epsilon=2.0) + model = lir.fit(df) + + expectedCoefficients = [0.136, 0.7648, -0.7761, 2.4236, 0.537, + 1.2612, -0.333, -0.5694, -0.6311, 0.6053] + expectedIntercept = 0.1607 + expectedScale = 9.758 + + self.assertTrue( + np.allclose(model.coefficients.toArray(), expectedCoefficients, atol=1E-3)) + self.assertTrue(np.isclose(model.intercept, expectedIntercept, atol=1E-3)) + self.assertTrue(np.isclose(model.scale, expectedScale, atol=1E-3)) + + +if __name__ == "__main__": + from pyspark.ml.tests.test_algorithms import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/ml/tests/test_base.py b/python/pyspark/ml/tests/test_base.py new file mode 100644 index 0000000000000..31e3deb53046c --- /dev/null +++ b/python/pyspark/ml/tests/test_base.py @@ -0,0 +1,77 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import unittest + +from pyspark.sql.types import DoubleType, IntegerType +from pyspark.testing.mlutils import MockDataset, MockEstimator, MockUnaryTransformer, \ + SparkSessionTestCase + + +class UnaryTransformerTests(SparkSessionTestCase): + + def test_unary_transformer_validate_input_type(self): + shiftVal = 3 + transformer = MockUnaryTransformer(shiftVal=shiftVal) \ + .setInputCol("input").setOutputCol("output") + + # should not raise any errors + transformer.validateInputType(DoubleType()) + + with self.assertRaises(TypeError): + # passing the wrong input type should raise an error + transformer.validateInputType(IntegerType()) + + def test_unary_transformer_transform(self): + shiftVal = 3 + transformer = MockUnaryTransformer(shiftVal=shiftVal) \ + .setInputCol("input").setOutputCol("output") + + df = self.spark.range(0, 10).toDF('input') + df = df.withColumn("input", df.input.cast(dataType="double")) + + transformed_df = transformer.transform(df) + results = transformed_df.select("input", "output").collect() + + for res in results: + self.assertEqual(res.input + shiftVal, res.output) + + +class EstimatorTest(unittest.TestCase): + + def testDefaultFitMultiple(self): + N = 4 + data = MockDataset() + estimator = MockEstimator() + params = [{estimator.fake: i} for i in range(N)] + modelIter = estimator.fitMultiple(data, params) + indexList = [] + for index, model in modelIter: + self.assertEqual(model.getFake(), index) + indexList.append(index) + self.assertEqual(sorted(indexList), list(range(N))) + + +if __name__ == "__main__": + from pyspark.ml.tests.test_base import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/ml/tests/test_evaluation.py b/python/pyspark/ml/tests/test_evaluation.py new file mode 100644 index 0000000000000..5438455a6f756 --- /dev/null +++ b/python/pyspark/ml/tests/test_evaluation.py @@ -0,0 +1,63 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest + +import numpy as np + +from pyspark.ml.evaluation import ClusteringEvaluator, RegressionEvaluator +from pyspark.ml.linalg import Vectors +from pyspark.sql import Row +from pyspark.testing.mlutils import SparkSessionTestCase + + +class EvaluatorTests(SparkSessionTestCase): + + def test_java_params(self): + """ + This tests a bug fixed by SPARK-18274 which causes multiple copies + of a Params instance in Python to be linked to the same Java instance. 
+ """ + evaluator = RegressionEvaluator(metricName="r2") + df = self.spark.createDataFrame([Row(label=1.0, prediction=1.1)]) + evaluator.evaluate(df) + self.assertEqual(evaluator._java_obj.getMetricName(), "r2") + evaluatorCopy = evaluator.copy({evaluator.metricName: "mae"}) + evaluator.evaluate(df) + evaluatorCopy.evaluate(df) + self.assertEqual(evaluator._java_obj.getMetricName(), "r2") + self.assertEqual(evaluatorCopy._java_obj.getMetricName(), "mae") + + def test_clustering_evaluator_with_cosine_distance(self): + featureAndPredictions = map(lambda x: (Vectors.dense(x[0]), x[1]), + [([1.0, 1.0], 1.0), ([10.0, 10.0], 1.0), ([1.0, 0.5], 2.0), + ([10.0, 4.4], 2.0), ([-1.0, 1.0], 3.0), ([-100.0, 90.0], 3.0)]) + dataset = self.spark.createDataFrame(featureAndPredictions, ["features", "prediction"]) + evaluator = ClusteringEvaluator(predictionCol="prediction", distanceMeasure="cosine") + self.assertEqual(evaluator.getDistanceMeasure(), "cosine") + self.assertTrue(np.isclose(evaluator.evaluate(dataset), 0.992671213, atol=1e-5)) + + +if __name__ == "__main__": + from pyspark.ml.tests.test_evaluation import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/ml/tests/test_feature.py b/python/pyspark/ml/tests/test_feature.py new file mode 100644 index 0000000000000..325feaba66957 --- /dev/null +++ b/python/pyspark/ml/tests/test_feature.py @@ -0,0 +1,311 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import sys +import unittest + +if sys.version > '3': + basestring = str + +from pyspark.ml.feature import Binarizer, CountVectorizer, CountVectorizerModel, HashingTF, IDF, \ + NGram, RFormula, StopWordsRemover, StringIndexer, StringIndexerModel, VectorSizeHint +from pyspark.ml.linalg import DenseVector, SparseVector, Vectors +from pyspark.sql import Row +from pyspark.testing.utils import QuietTest +from pyspark.testing.mlutils import check_params, SparkSessionTestCase + + +class FeatureTests(SparkSessionTestCase): + + def test_binarizer(self): + b0 = Binarizer() + self.assertListEqual(b0.params, [b0.inputCol, b0.outputCol, b0.threshold]) + self.assertTrue(all([~b0.isSet(p) for p in b0.params])) + self.assertTrue(b0.hasDefault(b0.threshold)) + self.assertEqual(b0.getThreshold(), 0.0) + b0.setParams(inputCol="input", outputCol="output").setThreshold(1.0) + self.assertTrue(all([b0.isSet(p) for p in b0.params])) + self.assertEqual(b0.getThreshold(), 1.0) + self.assertEqual(b0.getInputCol(), "input") + self.assertEqual(b0.getOutputCol(), "output") + + b0c = b0.copy({b0.threshold: 2.0}) + self.assertEqual(b0c.uid, b0.uid) + self.assertListEqual(b0c.params, b0.params) + self.assertEqual(b0c.getThreshold(), 2.0) + + b1 = Binarizer(threshold=2.0, inputCol="input", outputCol="output") + self.assertNotEqual(b1.uid, b0.uid) + self.assertEqual(b1.getThreshold(), 2.0) + self.assertEqual(b1.getInputCol(), "input") + self.assertEqual(b1.getOutputCol(), "output") + + def test_idf(self): + dataset = self.spark.createDataFrame([ + (DenseVector([1.0, 2.0]),), + (DenseVector([0.0, 1.0]),), + (DenseVector([3.0, 0.2]),)], ["tf"]) + idf0 = IDF(inputCol="tf") + self.assertListEqual(idf0.params, [idf0.inputCol, idf0.minDocFreq, idf0.outputCol]) + idf0m = idf0.fit(dataset, {idf0.outputCol: "idf"}) + self.assertEqual(idf0m.uid, idf0.uid, + "Model should inherit the UID from its parent estimator.") + output = idf0m.transform(dataset) + self.assertIsNotNone(output.head().idf) + # Test that parameters transferred to Python Model + check_params(self, idf0m) + + def test_ngram(self): + dataset = self.spark.createDataFrame([ + Row(input=["a", "b", "c", "d", "e"])]) + ngram0 = NGram(n=4, inputCol="input", outputCol="output") + self.assertEqual(ngram0.getN(), 4) + self.assertEqual(ngram0.getInputCol(), "input") + self.assertEqual(ngram0.getOutputCol(), "output") + transformedDF = ngram0.transform(dataset) + self.assertEqual(transformedDF.head().output, ["a b c d", "b c d e"]) + + def test_stopwordsremover(self): + dataset = self.spark.createDataFrame([Row(input=["a", "panda"])]) + stopWordRemover = StopWordsRemover(inputCol="input", outputCol="output") + # Default + self.assertEqual(stopWordRemover.getInputCol(), "input") + transformedDF = stopWordRemover.transform(dataset) + self.assertEqual(transformedDF.head().output, ["panda"]) + self.assertEqual(type(stopWordRemover.getStopWords()), list) + self.assertTrue(isinstance(stopWordRemover.getStopWords()[0], basestring)) + # Custom + stopwords = ["panda"] + stopWordRemover.setStopWords(stopwords) + self.assertEqual(stopWordRemover.getInputCol(), "input") + self.assertEqual(stopWordRemover.getStopWords(), stopwords) + transformedDF = stopWordRemover.transform(dataset) + self.assertEqual(transformedDF.head().output, ["a"]) + # with language selection + stopwords = StopWordsRemover.loadDefaultStopWords("turkish") + dataset = self.spark.createDataFrame([Row(input=["acaba", "ama", "biri"])]) + stopWordRemover.setStopWords(stopwords) + 
self.assertEqual(stopWordRemover.getStopWords(), stopwords) + transformedDF = stopWordRemover.transform(dataset) + self.assertEqual(transformedDF.head().output, []) + # with locale + stopwords = ["BELKİ"] + dataset = self.spark.createDataFrame([Row(input=["belki"])]) + stopWordRemover.setStopWords(stopwords).setLocale("tr") + self.assertEqual(stopWordRemover.getStopWords(), stopwords) + transformedDF = stopWordRemover.transform(dataset) + self.assertEqual(transformedDF.head().output, []) + + def test_count_vectorizer_with_binary(self): + dataset = self.spark.createDataFrame([ + (0, "a a a b b c".split(' '), SparseVector(3, {0: 1.0, 1: 1.0, 2: 1.0}),), + (1, "a a".split(' '), SparseVector(3, {0: 1.0}),), + (2, "a b".split(' '), SparseVector(3, {0: 1.0, 1: 1.0}),), + (3, "c".split(' '), SparseVector(3, {2: 1.0}),)], ["id", "words", "expected"]) + cv = CountVectorizer(binary=True, inputCol="words", outputCol="features") + model = cv.fit(dataset) + + transformedList = model.transform(dataset).select("features", "expected").collect() + + for r in transformedList: + feature, expected = r + self.assertEqual(feature, expected) + + def test_count_vectorizer_with_maxDF(self): + dataset = self.spark.createDataFrame([ + (0, "a b c d".split(' '), SparseVector(3, {0: 1.0, 1: 1.0, 2: 1.0}),), + (1, "a b c".split(' '), SparseVector(3, {0: 1.0, 1: 1.0}),), + (2, "a b".split(' '), SparseVector(3, {0: 1.0}),), + (3, "a".split(' '), SparseVector(3, {}),)], ["id", "words", "expected"]) + cv = CountVectorizer(inputCol="words", outputCol="features") + model1 = cv.setMaxDF(3).fit(dataset) + self.assertEqual(model1.vocabulary, ['b', 'c', 'd']) + + transformedList1 = model1.transform(dataset).select("features", "expected").collect() + + for r in transformedList1: + feature, expected = r + self.assertEqual(feature, expected) + + model2 = cv.setMaxDF(0.75).fit(dataset) + self.assertEqual(model2.vocabulary, ['b', 'c', 'd']) + + transformedList2 = model2.transform(dataset).select("features", "expected").collect() + + for r in transformedList2: + feature, expected = r + self.assertEqual(feature, expected) + + def test_count_vectorizer_from_vocab(self): + model = CountVectorizerModel.from_vocabulary(["a", "b", "c"], inputCol="words", + outputCol="features", minTF=2) + self.assertEqual(model.vocabulary, ["a", "b", "c"]) + self.assertEqual(model.getMinTF(), 2) + + dataset = self.spark.createDataFrame([ + (0, "a a a b b c".split(' '), SparseVector(3, {0: 3.0, 1: 2.0}),), + (1, "a a".split(' '), SparseVector(3, {0: 2.0}),), + (2, "a b".split(' '), SparseVector(3, {}),)], ["id", "words", "expected"]) + + transformed_list = model.transform(dataset).select("features", "expected").collect() + + for r in transformed_list: + feature, expected = r + self.assertEqual(feature, expected) + + # Test an empty vocabulary + with QuietTest(self.sc): + with self.assertRaisesRegexp(Exception, "vocabSize.*invalid.*0"): + CountVectorizerModel.from_vocabulary([], inputCol="words") + + # Test model with default settings can transform + model_default = CountVectorizerModel.from_vocabulary(["a", "b", "c"], inputCol="words") + transformed_list = model_default.transform(dataset) \ + .select(model_default.getOrDefault(model_default.outputCol)).collect() + self.assertEqual(len(transformed_list), 3) + + def test_rformula_force_index_label(self): + df = self.spark.createDataFrame([ + (1.0, 1.0, "a"), + (0.0, 2.0, "b"), + (1.0, 0.0, "a")], ["y", "x", "s"]) + # Does not index label by default since it's numeric type. 
+ rf = RFormula(formula="y ~ x + s") + model = rf.fit(df) + transformedDF = model.transform(df) + self.assertEqual(transformedDF.head().label, 1.0) + # Force to index label. + rf2 = RFormula(formula="y ~ x + s").setForceIndexLabel(True) + model2 = rf2.fit(df) + transformedDF2 = model2.transform(df) + self.assertEqual(transformedDF2.head().label, 0.0) + + def test_rformula_string_indexer_order_type(self): + df = self.spark.createDataFrame([ + (1.0, 1.0, "a"), + (0.0, 2.0, "b"), + (1.0, 0.0, "a")], ["y", "x", "s"]) + rf = RFormula(formula="y ~ x + s", stringIndexerOrderType="alphabetDesc") + self.assertEqual(rf.getStringIndexerOrderType(), 'alphabetDesc') + transformedDF = rf.fit(df).transform(df) + observed = transformedDF.select("features").collect() + expected = [[1.0, 0.0], [2.0, 1.0], [0.0, 0.0]] + for i in range(0, len(expected)): + self.assertTrue(all(observed[i]["features"].toArray() == expected[i])) + + def test_string_indexer_handle_invalid(self): + df = self.spark.createDataFrame([ + (0, "a"), + (1, "d"), + (2, None)], ["id", "label"]) + + si1 = StringIndexer(inputCol="label", outputCol="indexed", handleInvalid="keep", + stringOrderType="alphabetAsc") + model1 = si1.fit(df) + td1 = model1.transform(df) + actual1 = td1.select("id", "indexed").collect() + expected1 = [Row(id=0, indexed=0.0), Row(id=1, indexed=1.0), Row(id=2, indexed=2.0)] + self.assertEqual(actual1, expected1) + + si2 = si1.setHandleInvalid("skip") + model2 = si2.fit(df) + td2 = model2.transform(df) + actual2 = td2.select("id", "indexed").collect() + expected2 = [Row(id=0, indexed=0.0), Row(id=1, indexed=1.0)] + self.assertEqual(actual2, expected2) + + def test_string_indexer_from_labels(self): + model = StringIndexerModel.from_labels(["a", "b", "c"], inputCol="label", + outputCol="indexed", handleInvalid="keep") + self.assertEqual(model.labels, ["a", "b", "c"]) + + df1 = self.spark.createDataFrame([ + (0, "a"), + (1, "c"), + (2, None), + (3, "b"), + (4, "b")], ["id", "label"]) + + result1 = model.transform(df1) + actual1 = result1.select("id", "indexed").collect() + expected1 = [Row(id=0, indexed=0.0), Row(id=1, indexed=2.0), Row(id=2, indexed=3.0), + Row(id=3, indexed=1.0), Row(id=4, indexed=1.0)] + self.assertEqual(actual1, expected1) + + model_empty_labels = StringIndexerModel.from_labels( + [], inputCol="label", outputCol="indexed", handleInvalid="keep") + actual2 = model_empty_labels.transform(df1).select("id", "indexed").collect() + expected2 = [Row(id=0, indexed=0.0), Row(id=1, indexed=0.0), Row(id=2, indexed=0.0), + Row(id=3, indexed=0.0), Row(id=4, indexed=0.0)] + self.assertEqual(actual2, expected2) + + # Test model with default settings can transform + model_default = StringIndexerModel.from_labels(["a", "b", "c"], inputCol="label") + df2 = self.spark.createDataFrame([ + (0, "a"), + (1, "c"), + (2, "b"), + (3, "b"), + (4, "b")], ["id", "label"]) + transformed_list = model_default.transform(df2) \ + .select(model_default.getOrDefault(model_default.outputCol)).collect() + self.assertEqual(len(transformed_list), 5) + + def test_vector_size_hint(self): + df = self.spark.createDataFrame( + [(0, Vectors.dense([0.0, 10.0, 0.5])), + (1, Vectors.dense([1.0, 11.0, 0.5, 0.6])), + (2, Vectors.dense([2.0, 12.0]))], + ["id", "vector"]) + + sizeHint = VectorSizeHint( + inputCol="vector", + handleInvalid="skip") + sizeHint.setSize(3) + self.assertEqual(sizeHint.getSize(), 3) + + output = sizeHint.transform(df).head().vector + expected = DenseVector([0.0, 10.0, 0.5]) + self.assertEqual(output, expected) + + +class 
HashingTFTest(SparkSessionTestCase): + + def test_apply_binary_term_freqs(self): + + df = self.spark.createDataFrame([(0, ["a", "a", "b", "c", "c", "c"])], ["id", "words"]) + n = 10 + hashingTF = HashingTF() + hashingTF.setInputCol("words").setOutputCol("features").setNumFeatures(n).setBinary(True) + output = hashingTF.transform(df) + features = output.select("features").first().features.toArray() + expected = Vectors.dense([1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]).toArray() + for i in range(0, n): + self.assertAlmostEqual(features[i], expected[i], 14, "Error at " + str(i) + + ": expected " + str(expected[i]) + ", got " + str(features[i])) + + +if __name__ == "__main__": + from pyspark.ml.tests.test_feature import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/ml/tests/test_image.py b/python/pyspark/ml/tests/test_image.py new file mode 100644 index 0000000000000..4c280a4a67894 --- /dev/null +++ b/python/pyspark/ml/tests/test_image.py @@ -0,0 +1,110 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import unittest + +import py4j + +from pyspark.ml.image import ImageSchema +from pyspark.testing.mlutils import PySparkTestCase, SparkSessionTestCase +from pyspark.sql import HiveContext, Row +from pyspark.testing.utils import QuietTest + + +class ImageReaderTest(SparkSessionTestCase): + + def test_read_images(self): + data_path = 'data/mllib/images/origin/kittens' + df = ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True) + self.assertEqual(df.count(), 4) + first_row = df.take(1)[0][0] + array = ImageSchema.toNDArray(first_row) + self.assertEqual(len(array), first_row[1]) + self.assertEqual(ImageSchema.toImage(array, origin=first_row[0]), first_row) + self.assertEqual(df.schema, ImageSchema.imageSchema) + self.assertEqual(df.schema["image"].dataType, ImageSchema.columnSchema) + expected = {'CV_8UC3': 16, 'Undefined': -1, 'CV_8U': 0, 'CV_8UC1': 0, 'CV_8UC4': 24} + self.assertEqual(ImageSchema.ocvTypes, expected) + expected = ['origin', 'height', 'width', 'nChannels', 'mode', 'data'] + self.assertEqual(ImageSchema.imageFields, expected) + self.assertEqual(ImageSchema.undefinedImageType, "Undefined") + + with QuietTest(self.sc): + self.assertRaisesRegexp( + TypeError, + "image argument should be pyspark.sql.types.Row; however", + lambda: ImageSchema.toNDArray("a")) + + with QuietTest(self.sc): + self.assertRaisesRegexp( + ValueError, + "image argument should have attributes specified in", + lambda: ImageSchema.toNDArray(Row(a=1))) + + with QuietTest(self.sc): + self.assertRaisesRegexp( + TypeError, + "array argument should be numpy.ndarray; however, it got", + lambda: ImageSchema.toImage("a")) + + +class ImageReaderTest2(PySparkTestCase): + + @classmethod + def setUpClass(cls): + super(ImageReaderTest2, cls).setUpClass() + cls.hive_available = True + # Note that here we enable Hive's support. + cls.spark = None + try: + cls.sc._jvm.org.apache.hadoop.hive.conf.HiveConf() + except py4j.protocol.Py4JError: + cls.tearDownClass() + cls.hive_available = False + except TypeError: + cls.tearDownClass() + cls.hive_available = False + if cls.hive_available: + cls.spark = HiveContext._createForTesting(cls.sc) + + def setUp(self): + if not self.hive_available: + self.skipTest("Hive is not available.") + + @classmethod + def tearDownClass(cls): + super(ImageReaderTest2, cls).tearDownClass() + if cls.spark is not None: + cls.spark.sparkSession.stop() + cls.spark = None + + def test_read_images_multiple_times(self): + # This test case is to check if `ImageSchema.readImages` tries to + # initiate Hive client multiple times. See SPARK-22651. + data_path = 'data/mllib/images/origin/kittens' + ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True) + ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True) + + +if __name__ == "__main__": + from pyspark.ml.tests.test_image import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/ml/tests/test_linalg.py b/python/pyspark/ml/tests/test_linalg.py new file mode 100644 index 0000000000000..71cad5d7f5ad7 --- /dev/null +++ b/python/pyspark/ml/tests/test_linalg.py @@ -0,0 +1,384 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest +import array as pyarray + +from numpy import arange, array, array_equal, inf, ones, tile, zeros + +from pyspark.ml.linalg import DenseMatrix, DenseVector, MatrixUDT, SparseMatrix, SparseVector, \ + Vector, VectorUDT, Vectors +from pyspark.testing.mllibutils import make_serializer, MLlibTestCase +from pyspark.sql import Row + + +ser = make_serializer() + + +def _squared_distance(a, b): + if isinstance(a, Vector): + return a.squared_distance(b) + else: + return b.squared_distance(a) + + +class VectorTests(MLlibTestCase): + + def _test_serialize(self, v): + self.assertEqual(v, ser.loads(ser.dumps(v))) + jvec = self.sc._jvm.org.apache.spark.ml.python.MLSerDe.loads(bytearray(ser.dumps(v))) + nv = ser.loads(bytes(self.sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(jvec))) + self.assertEqual(v, nv) + vs = [v] * 100 + jvecs = self.sc._jvm.org.apache.spark.ml.python.MLSerDe.loads(bytearray(ser.dumps(vs))) + nvs = ser.loads(bytes(self.sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(jvecs))) + self.assertEqual(vs, nvs) + + def test_serialize(self): + self._test_serialize(DenseVector(range(10))) + self._test_serialize(DenseVector(array([1., 2., 3., 4.]))) + self._test_serialize(DenseVector(pyarray.array('d', range(10)))) + self._test_serialize(SparseVector(4, {1: 1, 3: 2})) + self._test_serialize(SparseVector(3, {})) + self._test_serialize(DenseMatrix(2, 3, range(6))) + sm1 = SparseMatrix( + 3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0]) + self._test_serialize(sm1) + + def test_dot(self): + sv = SparseVector(4, {1: 1, 3: 2}) + dv = DenseVector(array([1., 2., 3., 4.])) + lst = DenseVector([1, 2, 3, 4]) + mat = array([[1., 2., 3., 4.], + [1., 2., 3., 4.], + [1., 2., 3., 4.], + [1., 2., 3., 4.]]) + arr = pyarray.array('d', [0, 1, 2, 3]) + self.assertEqual(10.0, sv.dot(dv)) + self.assertTrue(array_equal(array([3., 6., 9., 12.]), sv.dot(mat))) + self.assertEqual(30.0, dv.dot(dv)) + self.assertTrue(array_equal(array([10., 20., 30., 40.]), dv.dot(mat))) + self.assertEqual(30.0, lst.dot(dv)) + self.assertTrue(array_equal(array([10., 20., 30., 40.]), lst.dot(mat))) + self.assertEqual(7.0, sv.dot(arr)) + + def test_squared_distance(self): + sv = SparseVector(4, {1: 1, 3: 2}) + dv = DenseVector(array([1., 2., 3., 4.])) + lst = DenseVector([4, 3, 2, 1]) + lst1 = [4, 3, 2, 1] + arr = pyarray.array('d', [0, 2, 1, 3]) + narr = array([0, 2, 1, 3]) + self.assertEqual(15.0, _squared_distance(sv, dv)) + self.assertEqual(25.0, _squared_distance(sv, lst)) + self.assertEqual(20.0, _squared_distance(dv, lst)) + self.assertEqual(15.0, _squared_distance(dv, sv)) + self.assertEqual(25.0, _squared_distance(lst, sv)) + self.assertEqual(20.0, _squared_distance(lst, dv)) + self.assertEqual(0.0, _squared_distance(sv, sv)) + self.assertEqual(0.0, _squared_distance(dv, dv)) + self.assertEqual(0.0, _squared_distance(lst, lst)) + self.assertEqual(25.0, _squared_distance(sv, lst1)) + self.assertEqual(3.0, _squared_distance(sv, arr)) + 
self.assertEqual(3.0, _squared_distance(sv, narr)) + + def test_hash(self): + v1 = DenseVector([0.0, 1.0, 0.0, 5.5]) + v2 = SparseVector(4, [(1, 1.0), (3, 5.5)]) + v3 = DenseVector([0.0, 1.0, 0.0, 5.5]) + v4 = SparseVector(4, [(1, 1.0), (3, 2.5)]) + self.assertEqual(hash(v1), hash(v2)) + self.assertEqual(hash(v1), hash(v3)) + self.assertEqual(hash(v2), hash(v3)) + self.assertFalse(hash(v1) == hash(v4)) + self.assertFalse(hash(v2) == hash(v4)) + + def test_eq(self): + v1 = DenseVector([0.0, 1.0, 0.0, 5.5]) + v2 = SparseVector(4, [(1, 1.0), (3, 5.5)]) + v3 = DenseVector([0.0, 1.0, 0.0, 5.5]) + v4 = SparseVector(6, [(1, 1.0), (3, 5.5)]) + v5 = DenseVector([0.0, 1.0, 0.0, 2.5]) + v6 = SparseVector(4, [(1, 1.0), (3, 2.5)]) + self.assertEqual(v1, v2) + self.assertEqual(v1, v3) + self.assertFalse(v2 == v4) + self.assertFalse(v1 == v5) + self.assertFalse(v1 == v6) + + def test_equals(self): + indices = [1, 2, 4] + values = [1., 3., 2.] + self.assertTrue(Vectors._equals(indices, values, list(range(5)), [0., 1., 3., 0., 2.])) + self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 3., 1., 0., 2.])) + self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 3., 0., 2.])) + self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 1., 3., 2., 2.])) + + def test_conversion(self): + # numpy arrays should be automatically upcast to float64 + # tests for fix of [SPARK-5089] + v = array([1, 2, 3, 4], dtype='float64') + dv = DenseVector(v) + self.assertTrue(dv.array.dtype == 'float64') + v = array([1, 2, 3, 4], dtype='float32') + dv = DenseVector(v) + self.assertTrue(dv.array.dtype == 'float64') + + def test_sparse_vector_indexing(self): + sv = SparseVector(5, {1: 1, 3: 2}) + self.assertEqual(sv[0], 0.) + self.assertEqual(sv[3], 2.) + self.assertEqual(sv[1], 1.) + self.assertEqual(sv[2], 0.) + self.assertEqual(sv[4], 0.) + self.assertEqual(sv[-1], 0.) + self.assertEqual(sv[-2], 2.) + self.assertEqual(sv[-3], 0.) + self.assertEqual(sv[-5], 0.) 
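+        # out-of-range indices raise IndexError; non-integer indices raise TypeError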
+ for ind in [5, -6]: + self.assertRaises(IndexError, sv.__getitem__, ind) + for ind in [7.8, '1']: + self.assertRaises(TypeError, sv.__getitem__, ind) + + zeros = SparseVector(4, {}) + self.assertEqual(zeros[0], 0.0) + self.assertEqual(zeros[3], 0.0) + for ind in [4, -5]: + self.assertRaises(IndexError, zeros.__getitem__, ind) + + empty = SparseVector(0, {}) + for ind in [-1, 0, 1]: + self.assertRaises(IndexError, empty.__getitem__, ind) + + def test_sparse_vector_iteration(self): + self.assertListEqual(list(SparseVector(3, [], [])), [0.0, 0.0, 0.0]) + self.assertListEqual(list(SparseVector(5, [0, 3], [1.0, 2.0])), [1.0, 0.0, 0.0, 2.0, 0.0]) + + def test_matrix_indexing(self): + mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) + expected = [[0, 6], [1, 8], [4, 10]] + for i in range(3): + for j in range(2): + self.assertEqual(mat[i, j], expected[i][j]) + + for i, j in [(-1, 0), (4, 1), (3, 4)]: + self.assertRaises(IndexError, mat.__getitem__, (i, j)) + + def test_repr_dense_matrix(self): + mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) + self.assertTrue( + repr(mat), + 'DenseMatrix(3, 2, [0.0, 1.0, 4.0, 6.0, 8.0, 10.0], False)') + + mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10], True) + self.assertTrue( + repr(mat), + 'DenseMatrix(3, 2, [0.0, 1.0, 4.0, 6.0, 8.0, 10.0], False)') + + mat = DenseMatrix(6, 3, zeros(18)) + self.assertTrue( + repr(mat), + 'DenseMatrix(6, 3, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..., \ + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], False)') + + def test_repr_sparse_matrix(self): + sm1t = SparseMatrix( + 3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], + isTransposed=True) + self.assertTrue( + repr(sm1t), + 'SparseMatrix(3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], True)') + + indices = tile(arange(6), 3) + values = ones(18) + sm = SparseMatrix(6, 3, [0, 6, 12, 18], indices, values) + self.assertTrue( + repr(sm), "SparseMatrix(6, 3, [0, 6, 12, 18], \ + [0, 1, 2, 3, 4, 5, 0, 1, ..., 4, 5, 0, 1, 2, 3, 4, 5], \ + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ..., \ + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], False)") + + self.assertTrue( + str(sm), + "6 X 3 CSCMatrix\n\ + (0,0) 1.0\n(1,0) 1.0\n(2,0) 1.0\n(3,0) 1.0\n(4,0) 1.0\n(5,0) 1.0\n\ + (0,1) 1.0\n(1,1) 1.0\n(2,1) 1.0\n(3,1) 1.0\n(4,1) 1.0\n(5,1) 1.0\n\ + (0,2) 1.0\n(1,2) 1.0\n(2,2) 1.0\n(3,2) 1.0\n..\n..") + + sm = SparseMatrix(1, 18, zeros(19), [], []) + self.assertTrue( + repr(sm), + 'SparseMatrix(1, 18, \ + [0, 0, 0, 0, 0, 0, 0, 0, ..., 0, 0, 0, 0, 0, 0, 0, 0], [], [], False)') + + def test_sparse_matrix(self): + # Test sparse matrix creation. + sm1 = SparseMatrix( + 3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0]) + self.assertEqual(sm1.numRows, 3) + self.assertEqual(sm1.numCols, 4) + self.assertEqual(sm1.colPtrs.tolist(), [0, 2, 2, 4, 4]) + self.assertEqual(sm1.rowIndices.tolist(), [1, 2, 1, 2]) + self.assertEqual(sm1.values.tolist(), [1.0, 2.0, 4.0, 5.0]) + self.assertTrue( + repr(sm1), + 'SparseMatrix(3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0], False)') + + # Test indexing + expected = [ + [0, 0, 0, 0], + [1, 0, 4, 0], + [2, 0, 5, 0]] + + for i in range(3): + for j in range(4): + self.assertEqual(expected[i][j], sm1[i, j]) + self.assertTrue(array_equal(sm1.toArray(), expected)) + + for i, j in [(-1, 1), (4, 3), (3, 5)]: + self.assertRaises(IndexError, sm1.__getitem__, (i, j)) + + # Test conversion to dense and sparse. 
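+        # Converting to dense and back should reproduce the original CSC structure.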
+ smnew = sm1.toDense().toSparse() + self.assertEqual(sm1.numRows, smnew.numRows) + self.assertEqual(sm1.numCols, smnew.numCols) + self.assertTrue(array_equal(sm1.colPtrs, smnew.colPtrs)) + self.assertTrue(array_equal(sm1.rowIndices, smnew.rowIndices)) + self.assertTrue(array_equal(sm1.values, smnew.values)) + + sm1t = SparseMatrix( + 3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], + isTransposed=True) + self.assertEqual(sm1t.numRows, 3) + self.assertEqual(sm1t.numCols, 4) + self.assertEqual(sm1t.colPtrs.tolist(), [0, 2, 3, 5]) + self.assertEqual(sm1t.rowIndices.tolist(), [0, 1, 2, 0, 2]) + self.assertEqual(sm1t.values.tolist(), [3.0, 2.0, 4.0, 9.0, 8.0]) + + expected = [ + [3, 2, 0, 0], + [0, 0, 4, 0], + [9, 0, 8, 0]] + + for i in range(3): + for j in range(4): + self.assertEqual(expected[i][j], sm1t[i, j]) + self.assertTrue(array_equal(sm1t.toArray(), expected)) + + def test_dense_matrix_is_transposed(self): + mat1 = DenseMatrix(3, 2, [0, 4, 1, 6, 3, 9], isTransposed=True) + mat = DenseMatrix(3, 2, [0, 1, 3, 4, 6, 9]) + self.assertEqual(mat1, mat) + + expected = [[0, 4], [1, 6], [3, 9]] + for i in range(3): + for j in range(2): + self.assertEqual(mat1[i, j], expected[i][j]) + self.assertTrue(array_equal(mat1.toArray(), expected)) + + sm = mat1.toSparse() + self.assertTrue(array_equal(sm.rowIndices, [1, 2, 0, 1, 2])) + self.assertTrue(array_equal(sm.colPtrs, [0, 2, 5])) + self.assertTrue(array_equal(sm.values, [1, 3, 4, 6, 9])) + + def test_norms(self): + a = DenseVector([0, 2, 3, -1]) + self.assertAlmostEqual(a.norm(2), 3.742, 3) + self.assertTrue(a.norm(1), 6) + self.assertTrue(a.norm(inf), 3) + a = SparseVector(4, [0, 2], [3, -4]) + self.assertAlmostEqual(a.norm(2), 5) + self.assertTrue(a.norm(1), 7) + self.assertTrue(a.norm(inf), 4) + + tmp = SparseVector(4, [0, 2], [3, 0]) + self.assertEqual(tmp.numNonzeros(), 1) + + +class VectorUDTTests(MLlibTestCase): + + dv0 = DenseVector([]) + dv1 = DenseVector([1.0, 2.0]) + sv0 = SparseVector(2, [], []) + sv1 = SparseVector(2, [1], [2.0]) + udt = VectorUDT() + + def test_json_schema(self): + self.assertEqual(VectorUDT.fromJson(self.udt.jsonValue()), self.udt) + + def test_serialization(self): + for v in [self.dv0, self.dv1, self.sv0, self.sv1]: + self.assertEqual(v, self.udt.deserialize(self.udt.serialize(v))) + + def test_infer_schema(self): + rdd = self.sc.parallelize([Row(label=1.0, features=self.dv1), + Row(label=0.0, features=self.sv1)]) + df = rdd.toDF() + schema = df.schema + field = [f for f in schema.fields if f.name == "features"][0] + self.assertEqual(field.dataType, self.udt) + vectors = df.rdd.map(lambda p: p.features).collect() + self.assertEqual(len(vectors), 2) + for v in vectors: + if isinstance(v, SparseVector): + self.assertEqual(v, self.sv1) + elif isinstance(v, DenseVector): + self.assertEqual(v, self.dv1) + else: + raise TypeError("expecting a vector but got %r of type %r" % (v, type(v))) + + +class MatrixUDTTests(MLlibTestCase): + + dm1 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10]) + dm2 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10], isTransposed=True) + sm1 = SparseMatrix(1, 1, [0, 1], [0], [2.0]) + sm2 = SparseMatrix(2, 1, [0, 0, 1], [0], [5.0], isTransposed=True) + udt = MatrixUDT() + + def test_json_schema(self): + self.assertEqual(MatrixUDT.fromJson(self.udt.jsonValue()), self.udt) + + def test_serialization(self): + for m in [self.dm1, self.dm2, self.sm1, self.sm2]: + self.assertEqual(m, self.udt.deserialize(self.udt.serialize(m))) + + def test_infer_schema(self): + rdd = self.sc.parallelize([("dense", 
self.dm1), ("sparse", self.sm1)]) + df = rdd.toDF() + schema = df.schema + self.assertTrue(schema.fields[1].dataType, self.udt) + matrices = df.rdd.map(lambda x: x._2).collect() + self.assertEqual(len(matrices), 2) + for m in matrices: + if isinstance(m, DenseMatrix): + self.assertTrue(m, self.dm1) + elif isinstance(m, SparseMatrix): + self.assertTrue(m, self.sm1) + else: + raise ValueError("Expected a matrix but got type %r" % type(m)) + + +if __name__ == "__main__": + from pyspark.ml.tests.test_linalg import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/ml/tests/test_param.py b/python/pyspark/ml/tests/test_param.py new file mode 100644 index 0000000000000..17c1b0bf65dde --- /dev/null +++ b/python/pyspark/ml/tests/test_param.py @@ -0,0 +1,366 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import inspect +import sys +import array as pyarray +import unittest + +import numpy as np + +from pyspark import keyword_only +from pyspark.ml.classification import LogisticRegression +from pyspark.ml.clustering import KMeans +from pyspark.ml.feature import Binarizer, Bucketizer, ElementwiseProduct, IndexToString, \ + VectorSlicer, Word2Vec +from pyspark.ml.linalg import DenseVector, SparseVector +from pyspark.ml.param import Param, Params, TypeConverters +from pyspark.ml.param.shared import HasInputCol, HasMaxIter, HasSeed +from pyspark.ml.wrapper import JavaParams +from pyspark.testing.mlutils import check_params, PySparkTestCase, SparkSessionTestCase + + +if sys.version > '3': + xrange = range + + +class ParamTypeConversionTests(PySparkTestCase): + """ + Test that param type conversion happens. 
+ """ + + def test_int(self): + lr = LogisticRegression(maxIter=5.0) + self.assertEqual(lr.getMaxIter(), 5) + self.assertTrue(type(lr.getMaxIter()) == int) + self.assertRaises(TypeError, lambda: LogisticRegression(maxIter="notAnInt")) + self.assertRaises(TypeError, lambda: LogisticRegression(maxIter=5.1)) + + def test_float(self): + lr = LogisticRegression(tol=1) + self.assertEqual(lr.getTol(), 1.0) + self.assertTrue(type(lr.getTol()) == float) + self.assertRaises(TypeError, lambda: LogisticRegression(tol="notAFloat")) + + def test_vector(self): + ewp = ElementwiseProduct(scalingVec=[1, 3]) + self.assertEqual(ewp.getScalingVec(), DenseVector([1.0, 3.0])) + ewp = ElementwiseProduct(scalingVec=np.array([1.2, 3.4])) + self.assertEqual(ewp.getScalingVec(), DenseVector([1.2, 3.4])) + self.assertRaises(TypeError, lambda: ElementwiseProduct(scalingVec=["a", "b"])) + + def test_list(self): + l = [0, 1] + for lst_like in [l, np.array(l), DenseVector(l), SparseVector(len(l), range(len(l)), l), + pyarray.array('l', l), xrange(2), tuple(l)]: + converted = TypeConverters.toList(lst_like) + self.assertEqual(type(converted), list) + self.assertListEqual(converted, l) + + def test_list_int(self): + for indices in [[1.0, 2.0], np.array([1.0, 2.0]), DenseVector([1.0, 2.0]), + SparseVector(2, {0: 1.0, 1: 2.0}), xrange(1, 3), (1.0, 2.0), + pyarray.array('d', [1.0, 2.0])]: + vs = VectorSlicer(indices=indices) + self.assertListEqual(vs.getIndices(), [1, 2]) + self.assertTrue(all([type(v) == int for v in vs.getIndices()])) + self.assertRaises(TypeError, lambda: VectorSlicer(indices=["a", "b"])) + + def test_list_float(self): + b = Bucketizer(splits=[1, 4]) + self.assertEqual(b.getSplits(), [1.0, 4.0]) + self.assertTrue(all([type(v) == float for v in b.getSplits()])) + self.assertRaises(TypeError, lambda: Bucketizer(splits=["a", 1.0])) + + def test_list_string(self): + for labels in [np.array(['a', u'b']), ['a', u'b'], np.array(['a', 'b'])]: + idx_to_string = IndexToString(labels=labels) + self.assertListEqual(idx_to_string.getLabels(), ['a', 'b']) + self.assertRaises(TypeError, lambda: IndexToString(labels=['a', 2])) + + def test_string(self): + lr = LogisticRegression() + for col in ['features', u'features', np.str_('features')]: + lr.setFeaturesCol(col) + self.assertEqual(lr.getFeaturesCol(), 'features') + self.assertRaises(TypeError, lambda: LogisticRegression(featuresCol=2.3)) + + def test_bool(self): + self.assertRaises(TypeError, lambda: LogisticRegression(fitIntercept=1)) + self.assertRaises(TypeError, lambda: LogisticRegression(fitIntercept="false")) + + +class TestParams(HasMaxIter, HasInputCol, HasSeed): + """ + A subclass of Params mixed with HasMaxIter, HasInputCol and HasSeed. + """ + @keyword_only + def __init__(self, seed=None): + super(TestParams, self).__init__() + self._setDefault(maxIter=10) + kwargs = self._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, seed=None): + """ + setParams(self, seed=None) + Sets params for this test. + """ + kwargs = self._input_kwargs + return self._set(**kwargs) + + +class OtherTestParams(HasMaxIter, HasInputCol, HasSeed): + """ + A subclass of Params mixed with HasMaxIter, HasInputCol and HasSeed. + """ + @keyword_only + def __init__(self, seed=None): + super(OtherTestParams, self).__init__() + self._setDefault(maxIter=10) + kwargs = self._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, seed=None): + """ + setParams(self, seed=None) + Sets params for this test. 
+ """ + kwargs = self._input_kwargs + return self._set(**kwargs) + + +class HasThrowableProperty(Params): + + def __init__(self): + super(HasThrowableProperty, self).__init__() + self.p = Param(self, "none", "empty param") + + @property + def test_property(self): + raise RuntimeError("Test property to raise error when invoked") + + +class ParamTests(SparkSessionTestCase): + + def test_copy_new_parent(self): + testParams = TestParams() + # Copying an instantiated param should fail + with self.assertRaises(ValueError): + testParams.maxIter._copy_new_parent(testParams) + # Copying a dummy param should succeed + TestParams.maxIter._copy_new_parent(testParams) + maxIter = testParams.maxIter + self.assertEqual(maxIter.name, "maxIter") + self.assertEqual(maxIter.doc, "max number of iterations (>= 0).") + self.assertTrue(maxIter.parent == testParams.uid) + + def test_param(self): + testParams = TestParams() + maxIter = testParams.maxIter + self.assertEqual(maxIter.name, "maxIter") + self.assertEqual(maxIter.doc, "max number of iterations (>= 0).") + self.assertTrue(maxIter.parent == testParams.uid) + + def test_hasparam(self): + testParams = TestParams() + self.assertTrue(all([testParams.hasParam(p.name) for p in testParams.params])) + self.assertFalse(testParams.hasParam("notAParameter")) + self.assertTrue(testParams.hasParam(u"maxIter")) + + def test_resolveparam(self): + testParams = TestParams() + self.assertEqual(testParams._resolveParam(testParams.maxIter), testParams.maxIter) + self.assertEqual(testParams._resolveParam("maxIter"), testParams.maxIter) + + self.assertEqual(testParams._resolveParam(u"maxIter"), testParams.maxIter) + if sys.version_info[0] >= 3: + # In Python 3, it is allowed to get/set attributes with non-ascii characters. + e_cls = AttributeError + else: + e_cls = UnicodeEncodeError + self.assertRaises(e_cls, lambda: testParams._resolveParam(u"아")) + + def test_params(self): + testParams = TestParams() + maxIter = testParams.maxIter + inputCol = testParams.inputCol + seed = testParams.seed + + params = testParams.params + self.assertEqual(params, [inputCol, maxIter, seed]) + + self.assertTrue(testParams.hasParam(maxIter.name)) + self.assertTrue(testParams.hasDefault(maxIter)) + self.assertFalse(testParams.isSet(maxIter)) + self.assertTrue(testParams.isDefined(maxIter)) + self.assertEqual(testParams.getMaxIter(), 10) + testParams.setMaxIter(100) + self.assertTrue(testParams.isSet(maxIter)) + self.assertEqual(testParams.getMaxIter(), 100) + + self.assertTrue(testParams.hasParam(inputCol.name)) + self.assertFalse(testParams.hasDefault(inputCol)) + self.assertFalse(testParams.isSet(inputCol)) + self.assertFalse(testParams.isDefined(inputCol)) + with self.assertRaises(KeyError): + testParams.getInputCol() + + otherParam = Param(Params._dummy(), "otherParam", "Parameter used to test that " + + "set raises an error for a non-member parameter.", + typeConverter=TypeConverters.toString) + with self.assertRaises(ValueError): + testParams.set(otherParam, "value") + + # Since the default is normally random, set it to a known number for debug str + testParams._setDefault(seed=41) + testParams.setSeed(43) + + self.assertEqual( + testParams.explainParams(), + "\n".join(["inputCol: input column name. (undefined)", + "maxIter: max number of iterations (>= 0). (default: 10, current: 100)", + "seed: random seed. 
(default: 41, current: 43)"])) + + def test_kmeans_param(self): + algo = KMeans() + self.assertEqual(algo.getInitMode(), "k-means||") + algo.setK(10) + self.assertEqual(algo.getK(), 10) + algo.setInitSteps(10) + self.assertEqual(algo.getInitSteps(), 10) + self.assertEqual(algo.getDistanceMeasure(), "euclidean") + algo.setDistanceMeasure("cosine") + self.assertEqual(algo.getDistanceMeasure(), "cosine") + + def test_hasseed(self): + noSeedSpecd = TestParams() + withSeedSpecd = TestParams(seed=42) + other = OtherTestParams() + # Check that we no longer use 42 as the magic number + self.assertNotEqual(noSeedSpecd.getSeed(), 42) + origSeed = noSeedSpecd.getSeed() + # Check that we only compute the seed once + self.assertEqual(noSeedSpecd.getSeed(), origSeed) + # Check that a specified seed is honored + self.assertEqual(withSeedSpecd.getSeed(), 42) + # Check that a different class has a different seed + self.assertNotEqual(other.getSeed(), noSeedSpecd.getSeed()) + + def test_param_property_error(self): + param_store = HasThrowableProperty() + self.assertRaises(RuntimeError, lambda: param_store.test_property) + params = param_store.params # should not invoke the property 'test_property' + self.assertEqual(len(params), 1) + + def test_word2vec_param(self): + model = Word2Vec().setWindowSize(6) + # Check windowSize is set properly + self.assertEqual(model.getWindowSize(), 6) + + def test_copy_param_extras(self): + tp = TestParams(seed=42) + extra = {tp.getParam(TestParams.inputCol.name): "copy_input"} + tp_copy = tp.copy(extra=extra) + self.assertEqual(tp.uid, tp_copy.uid) + self.assertEqual(tp.params, tp_copy.params) + for k, v in extra.items(): + self.assertTrue(tp_copy.isDefined(k)) + self.assertEqual(tp_copy.getOrDefault(k), v) + copied_no_extra = {} + for k, v in tp_copy._paramMap.items(): + if k not in extra: + copied_no_extra[k] = v + self.assertEqual(tp._paramMap, copied_no_extra) + self.assertEqual(tp._defaultParamMap, tp_copy._defaultParamMap) + + def test_logistic_regression_check_thresholds(self): + self.assertIsInstance( + LogisticRegression(threshold=0.5, thresholds=[0.5, 0.5]), + LogisticRegression + ) + + self.assertRaisesRegexp( + ValueError, + "Logistic Regression getThreshold found inconsistent.*$", + LogisticRegression, threshold=0.42, thresholds=[0.5, 0.5] + ) + + def test_preserve_set_state(self): + dataset = self.spark.createDataFrame([(0.5,)], ["data"]) + binarizer = Binarizer(inputCol="data") + self.assertFalse(binarizer.isSet("threshold")) + binarizer.transform(dataset) + binarizer._transfer_params_from_java() + self.assertFalse(binarizer.isSet("threshold"), + "Params not explicitly set should remain unset after transform") + + def test_default_params_transferred(self): + dataset = self.spark.createDataFrame([(0.5,)], ["data"]) + binarizer = Binarizer(inputCol="data") + # intentionally change the pyspark default, but don't set it + binarizer._defaultParamMap[binarizer.outputCol] = "my_default" + result = binarizer.transform(dataset).select("my_default").collect() + self.assertFalse(binarizer.isSet(binarizer.outputCol)) + self.assertEqual(result[0][0], 1.0) + + +class DefaultValuesTests(PySparkTestCase): + """ + Test :py:class:`JavaParams` classes to see if their default Param values match + those in their Scala counterparts. 
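+    Only non-abstract ``JavaParams`` subclasses are checked; classes whose names
+    end in 'Model' or 'Params' are skipped.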
+ """ + + def test_java_params(self): + import pyspark.ml.feature + import pyspark.ml.classification + import pyspark.ml.clustering + import pyspark.ml.evaluation + import pyspark.ml.pipeline + import pyspark.ml.recommendation + import pyspark.ml.regression + + modules = [pyspark.ml.feature, pyspark.ml.classification, pyspark.ml.clustering, + pyspark.ml.evaluation, pyspark.ml.pipeline, pyspark.ml.recommendation, + pyspark.ml.regression] + for module in modules: + for name, cls in inspect.getmembers(module, inspect.isclass): + if not name.endswith('Model') and not name.endswith('Params') \ + and issubclass(cls, JavaParams) and not inspect.isabstract(cls): + # NOTE: disable check_params_exist until there is parity with Scala API + check_params(self, cls(), check_params_exist=False) + + # Additional classes that need explicit construction + from pyspark.ml.feature import CountVectorizerModel, StringIndexerModel + check_params(self, CountVectorizerModel.from_vocabulary(['a'], 'input'), + check_params_exist=False) + check_params(self, StringIndexerModel.from_labels(['a', 'b'], 'input'), + check_params_exist=False) + + +if __name__ == "__main__": + from pyspark.ml.tests.test_param import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/ml/tests/test_persistence.py b/python/pyspark/ml/tests/test_persistence.py new file mode 100644 index 0000000000000..34d687039ab34 --- /dev/null +++ b/python/pyspark/ml/tests/test_persistence.py @@ -0,0 +1,361 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import json +from shutil import rmtree +import tempfile +import unittest + +from pyspark.ml import Transformer +from pyspark.ml.classification import DecisionTreeClassifier, LogisticRegression, OneVsRest, \ + OneVsRestModel +from pyspark.ml.feature import Binarizer, HashingTF, PCA +from pyspark.ml.linalg import Vectors +from pyspark.ml.param import Params +from pyspark.ml.pipeline import Pipeline, PipelineModel +from pyspark.ml.regression import DecisionTreeRegressor, LinearRegression +from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWriter +from pyspark.ml.wrapper import JavaParams +from pyspark.testing.mlutils import MockUnaryTransformer, SparkSessionTestCase + + +class PersistenceTest(SparkSessionTestCase): + + def test_linear_regression(self): + lr = LinearRegression(maxIter=1) + path = tempfile.mkdtemp() + lr_path = path + "/lr" + lr.save(lr_path) + lr2 = LinearRegression.load(lr_path) + self.assertEqual(lr.uid, lr2.uid) + self.assertEqual(type(lr.uid), type(lr2.uid)) + self.assertEqual(lr2.uid, lr2.maxIter.parent, + "Loaded LinearRegression instance uid (%s) did not match Param's uid (%s)" + % (lr2.uid, lr2.maxIter.parent)) + self.assertEqual(lr._defaultParamMap[lr.maxIter], lr2._defaultParamMap[lr2.maxIter], + "Loaded LinearRegression instance default params did not match " + + "original defaults") + try: + rmtree(path) + except OSError: + pass + + def test_linear_regression_pmml_basic(self): + # Most of the validation is done in the Scala side, here we just check + # that we output text rather than parquet (e.g. that the format flag + # was respected). + df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), + (0.0, 2.0, Vectors.sparse(1, [], []))], + ["label", "weight", "features"]) + lr = LinearRegression(maxIter=1) + model = lr.fit(df) + path = tempfile.mkdtemp() + lr_path = path + "/lr-pmml" + model.write().format("pmml").save(lr_path) + pmml_text_list = self.sc.textFile(lr_path).collect() + pmml_text = "\n".join(pmml_text_list) + self.assertIn("Apache Spark", pmml_text) + self.assertIn("PMML", pmml_text) + + def test_logistic_regression(self): + lr = LogisticRegression(maxIter=1) + path = tempfile.mkdtemp() + lr_path = path + "/logreg" + lr.save(lr_path) + lr2 = LogisticRegression.load(lr_path) + self.assertEqual(lr2.uid, lr2.maxIter.parent, + "Loaded LogisticRegression instance uid (%s) " + "did not match Param's uid (%s)" + % (lr2.uid, lr2.maxIter.parent)) + self.assertEqual(lr._defaultParamMap[lr.maxIter], lr2._defaultParamMap[lr2.maxIter], + "Loaded LogisticRegression instance default params did not match " + + "original defaults") + try: + rmtree(path) + except OSError: + pass + + def _compare_params(self, m1, m2, param): + """ + Compare 2 ML Params instances for the given param, and assert both have the same param value + and parent. The param must be a parameter of m1. + """ + # Prevent key not found error in case of some param in neither paramMap nor defaultParamMap. + if m1.isDefined(param): + paramValue1 = m1.getOrDefault(param) + paramValue2 = m2.getOrDefault(m2.getParam(param.name)) + if isinstance(paramValue1, Params): + self._compare_pipelines(paramValue1, paramValue2) + else: + self.assertEqual(paramValue1, paramValue2) # for general types param + # Assert parents are equal + self.assertEqual(param.parent, m2.getParam(param.name).parent) + else: + # If m1 is not defined param, then m2 should not, too. See SPARK-14931. 
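+            # getOrDefault would raise for such a param, so only check definedness here.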
+ self.assertFalse(m2.isDefined(m2.getParam(param.name))) + + def _compare_pipelines(self, m1, m2): + """ + Compare 2 ML types, asserting that they are equivalent. + This currently supports: + - basic types + - Pipeline, PipelineModel + - OneVsRest, OneVsRestModel + This checks: + - uid + - type + - Param values and parents + """ + self.assertEqual(m1.uid, m2.uid) + self.assertEqual(type(m1), type(m2)) + if isinstance(m1, JavaParams) or isinstance(m1, Transformer): + self.assertEqual(len(m1.params), len(m2.params)) + for p in m1.params: + self._compare_params(m1, m2, p) + elif isinstance(m1, Pipeline): + self.assertEqual(len(m1.getStages()), len(m2.getStages())) + for s1, s2 in zip(m1.getStages(), m2.getStages()): + self._compare_pipelines(s1, s2) + elif isinstance(m1, PipelineModel): + self.assertEqual(len(m1.stages), len(m2.stages)) + for s1, s2 in zip(m1.stages, m2.stages): + self._compare_pipelines(s1, s2) + elif isinstance(m1, OneVsRest) or isinstance(m1, OneVsRestModel): + for p in m1.params: + self._compare_params(m1, m2, p) + if isinstance(m1, OneVsRestModel): + self.assertEqual(len(m1.models), len(m2.models)) + for x, y in zip(m1.models, m2.models): + self._compare_pipelines(x, y) + else: + raise RuntimeError("_compare_pipelines does not yet support type: %s" % type(m1)) + + def test_pipeline_persistence(self): + """ + Pipeline[HashingTF, PCA] + """ + temp_path = tempfile.mkdtemp() + + try: + df = self.spark.createDataFrame([(["a", "b", "c"],), (["c", "d", "e"],)], ["words"]) + tf = HashingTF(numFeatures=10, inputCol="words", outputCol="features") + pca = PCA(k=2, inputCol="features", outputCol="pca_features") + pl = Pipeline(stages=[tf, pca]) + model = pl.fit(df) + + pipeline_path = temp_path + "/pipeline" + pl.save(pipeline_path) + loaded_pipeline = Pipeline.load(pipeline_path) + self._compare_pipelines(pl, loaded_pipeline) + + model_path = temp_path + "/pipeline-model" + model.save(model_path) + loaded_model = PipelineModel.load(model_path) + self._compare_pipelines(model, loaded_model) + finally: + try: + rmtree(temp_path) + except OSError: + pass + + def test_nested_pipeline_persistence(self): + """ + Pipeline[HashingTF, Pipeline[PCA]] + """ + temp_path = tempfile.mkdtemp() + + try: + df = self.spark.createDataFrame([(["a", "b", "c"],), (["c", "d", "e"],)], ["words"]) + tf = HashingTF(numFeatures=10, inputCol="words", outputCol="features") + pca = PCA(k=2, inputCol="features", outputCol="pca_features") + p0 = Pipeline(stages=[pca]) + pl = Pipeline(stages=[tf, p0]) + model = pl.fit(df) + + pipeline_path = temp_path + "/pipeline" + pl.save(pipeline_path) + loaded_pipeline = Pipeline.load(pipeline_path) + self._compare_pipelines(pl, loaded_pipeline) + + model_path = temp_path + "/pipeline-model" + model.save(model_path) + loaded_model = PipelineModel.load(model_path) + self._compare_pipelines(model, loaded_model) + finally: + try: + rmtree(temp_path) + except OSError: + pass + + def test_python_transformer_pipeline_persistence(self): + """ + Pipeline[MockUnaryTransformer, Binarizer] + """ + temp_path = tempfile.mkdtemp() + + try: + df = self.spark.range(0, 10).toDF('input') + tf = MockUnaryTransformer(shiftVal=2)\ + .setInputCol("input").setOutputCol("shiftedInput") + tf2 = Binarizer(threshold=6, inputCol="shiftedInput", outputCol="binarized") + pl = Pipeline(stages=[tf, tf2]) + model = pl.fit(df) + + pipeline_path = temp_path + "/pipeline" + pl.save(pipeline_path) + loaded_pipeline = Pipeline.load(pipeline_path) + self._compare_pipelines(pl, loaded_pipeline) + + model_path = 
temp_path + "/pipeline-model" + model.save(model_path) + loaded_model = PipelineModel.load(model_path) + self._compare_pipelines(model, loaded_model) + finally: + try: + rmtree(temp_path) + except OSError: + pass + + def test_onevsrest(self): + temp_path = tempfile.mkdtemp() + df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)), + (1.0, Vectors.sparse(2, [], [])), + (2.0, Vectors.dense(0.5, 0.5))] * 10, + ["label", "features"]) + lr = LogisticRegression(maxIter=5, regParam=0.01) + ovr = OneVsRest(classifier=lr) + model = ovr.fit(df) + ovrPath = temp_path + "/ovr" + ovr.save(ovrPath) + loadedOvr = OneVsRest.load(ovrPath) + self._compare_pipelines(ovr, loadedOvr) + modelPath = temp_path + "/ovrModel" + model.save(modelPath) + loadedModel = OneVsRestModel.load(modelPath) + self._compare_pipelines(model, loadedModel) + + def test_decisiontree_classifier(self): + dt = DecisionTreeClassifier(maxDepth=1) + path = tempfile.mkdtemp() + dtc_path = path + "/dtc" + dt.save(dtc_path) + dt2 = DecisionTreeClassifier.load(dtc_path) + self.assertEqual(dt2.uid, dt2.maxDepth.parent, + "Loaded DecisionTreeClassifier instance uid (%s) " + "did not match Param's uid (%s)" + % (dt2.uid, dt2.maxDepth.parent)) + self.assertEqual(dt._defaultParamMap[dt.maxDepth], dt2._defaultParamMap[dt2.maxDepth], + "Loaded DecisionTreeClassifier instance default params did not match " + + "original defaults") + try: + rmtree(path) + except OSError: + pass + + def test_decisiontree_regressor(self): + dt = DecisionTreeRegressor(maxDepth=1) + path = tempfile.mkdtemp() + dtr_path = path + "/dtr" + dt.save(dtr_path) + dt2 = DecisionTreeClassifier.load(dtr_path) + self.assertEqual(dt2.uid, dt2.maxDepth.parent, + "Loaded DecisionTreeRegressor instance uid (%s) " + "did not match Param's uid (%s)" + % (dt2.uid, dt2.maxDepth.parent)) + self.assertEqual(dt._defaultParamMap[dt.maxDepth], dt2._defaultParamMap[dt2.maxDepth], + "Loaded DecisionTreeRegressor instance default params did not match " + + "original defaults") + try: + rmtree(path) + except OSError: + pass + + def test_default_read_write(self): + temp_path = tempfile.mkdtemp() + + lr = LogisticRegression() + lr.setMaxIter(50) + lr.setThreshold(.75) + writer = DefaultParamsWriter(lr) + + savePath = temp_path + "/lr" + writer.save(savePath) + + reader = DefaultParamsReadable.read() + lr2 = reader.load(savePath) + + self.assertEqual(lr.uid, lr2.uid) + self.assertEqual(lr.extractParamMap(), lr2.extractParamMap()) + + # test overwrite + lr.setThreshold(.8) + writer.overwrite().save(savePath) + + reader = DefaultParamsReadable.read() + lr3 = reader.load(savePath) + + self.assertEqual(lr.uid, lr3.uid) + self.assertEqual(lr.extractParamMap(), lr3.extractParamMap()) + + def test_default_read_write_default_params(self): + lr = LogisticRegression() + self.assertFalse(lr.isSet(lr.getParam("threshold"))) + + lr.setMaxIter(50) + lr.setThreshold(.75) + + # `threshold` is set by user, default param `predictionCol` is not set by user. 
+ self.assertTrue(lr.isSet(lr.getParam("threshold"))) + self.assertFalse(lr.isSet(lr.getParam("predictionCol"))) + self.assertTrue(lr.hasDefault(lr.getParam("predictionCol"))) + + writer = DefaultParamsWriter(lr) + metadata = json.loads(writer._get_metadata_to_save(lr, self.sc)) + self.assertTrue("defaultParamMap" in metadata) + + reader = DefaultParamsReadable.read() + metadataStr = json.dumps(metadata, separators=[',', ':']) + loadedMetadata = reader._parseMetaData(metadataStr, ) + reader.getAndSetParams(lr, loadedMetadata) + + self.assertTrue(lr.isSet(lr.getParam("threshold"))) + self.assertFalse(lr.isSet(lr.getParam("predictionCol"))) + self.assertTrue(lr.hasDefault(lr.getParam("predictionCol"))) + + # manually create metadata without `defaultParamMap` section. + del metadata['defaultParamMap'] + metadataStr = json.dumps(metadata, separators=[',', ':']) + loadedMetadata = reader._parseMetaData(metadataStr, ) + with self.assertRaisesRegexp(AssertionError, "`defaultParamMap` section not found"): + reader.getAndSetParams(lr, loadedMetadata) + + # Prior to 2.4.0, metadata doesn't have `defaultParamMap`. + metadata['sparkVersion'] = '2.3.0' + metadataStr = json.dumps(metadata, separators=[',', ':']) + loadedMetadata = reader._parseMetaData(metadataStr, ) + reader.getAndSetParams(lr, loadedMetadata) + + +if __name__ == "__main__": + from pyspark.ml.tests.test_persistence import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/ml/tests/test_pipeline.py b/python/pyspark/ml/tests/test_pipeline.py new file mode 100644 index 0000000000000..9e3e6c4a75d7a --- /dev/null +++ b/python/pyspark/ml/tests/test_pipeline.py @@ -0,0 +1,69 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import unittest + +from pyspark.ml.pipeline import Pipeline +from pyspark.testing.mlutils import MockDataset, MockEstimator, MockTransformer, PySparkTestCase + + +class PipelineTests(PySparkTestCase): + + def test_pipeline(self): + dataset = MockDataset() + estimator0 = MockEstimator() + transformer1 = MockTransformer() + estimator2 = MockEstimator() + transformer3 = MockTransformer() + pipeline = Pipeline(stages=[estimator0, transformer1, estimator2, transformer3]) + pipeline_model = pipeline.fit(dataset, {estimator0.fake: 0, transformer1.fake: 1}) + model0, transformer1, model2, transformer3 = pipeline_model.stages + self.assertEqual(0, model0.dataset_index) + self.assertEqual(0, model0.getFake()) + self.assertEqual(1, transformer1.dataset_index) + self.assertEqual(1, transformer1.getFake()) + self.assertEqual(2, dataset.index) + self.assertIsNone(model2.dataset_index, "The last model shouldn't be called in fit.") + self.assertIsNone(transformer3.dataset_index, + "The last transformer shouldn't be called in fit.") + dataset = pipeline_model.transform(dataset) + self.assertEqual(2, model0.dataset_index) + self.assertEqual(3, transformer1.dataset_index) + self.assertEqual(4, model2.dataset_index) + self.assertEqual(5, transformer3.dataset_index) + self.assertEqual(6, dataset.index) + + def test_identity_pipeline(self): + dataset = MockDataset() + + def doTransform(pipeline): + pipeline_model = pipeline.fit(dataset) + return pipeline_model.transform(dataset) + # check that empty pipeline did not perform any transformation + self.assertEqual(dataset.index, doTransform(Pipeline(stages=[])).index) + # check that failure to set stages param will raise KeyError for missing param + self.assertRaises(KeyError, lambda: doTransform(Pipeline())) + + +if __name__ == "__main__": + from pyspark.ml.tests.test_pipeline import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/ml/tests/test_stat.py b/python/pyspark/ml/tests/test_stat.py new file mode 100644 index 0000000000000..11aaf2e8083e1 --- /dev/null +++ b/python/pyspark/ml/tests/test_stat.py @@ -0,0 +1,50 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import unittest + +from pyspark.ml.linalg import Vectors +from pyspark.ml.stat import ChiSquareTest +from pyspark.sql import DataFrame +from pyspark.testing.mlutils import SparkSessionTestCase + + +class ChiSquareTestTests(SparkSessionTestCase): + + def test_chisquaretest(self): + data = [[0, Vectors.dense([0, 1, 2])], + [1, Vectors.dense([1, 1, 1])], + [2, Vectors.dense([2, 1, 0])]] + df = self.spark.createDataFrame(data, ['label', 'feat']) + res = ChiSquareTest.test(df, 'feat', 'label') + # This line is hitting the collect bug described in #17218, commented for now. + # pValues = res.select("degreesOfFreedom").collect()) + self.assertIsInstance(res, DataFrame) + fieldNames = set(field.name for field in res.schema.fields) + expectedFields = ["pValues", "degreesOfFreedom", "statistics"] + self.assertTrue(all(field in fieldNames for field in expectedFields)) + + +if __name__ == "__main__": + from pyspark.ml.tests.test_stat import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/ml/tests/test_training_summary.py b/python/pyspark/ml/tests/test_training_summary.py new file mode 100644 index 0000000000000..8575111c84025 --- /dev/null +++ b/python/pyspark/ml/tests/test_training_summary.py @@ -0,0 +1,251 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import sys +import unittest + +if sys.version > '3': + basestring = str + +from pyspark.ml.classification import LogisticRegression +from pyspark.ml.clustering import BisectingKMeans, GaussianMixture, KMeans +from pyspark.ml.linalg import Vectors +from pyspark.ml.regression import GeneralizedLinearRegression, LinearRegression +from pyspark.sql import DataFrame +from pyspark.testing.mlutils import SparkSessionTestCase + + +class TrainingSummaryTest(SparkSessionTestCase): + + def test_linear_regression_summary(self): + df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), + (0.0, 2.0, Vectors.sparse(1, [], []))], + ["label", "weight", "features"]) + lr = LinearRegression(maxIter=5, regParam=0.0, solver="normal", weightCol="weight", + fitIntercept=False) + model = lr.fit(df) + self.assertTrue(model.hasSummary) + s = model.summary + # test that api is callable and returns expected types + self.assertGreater(s.totalIterations, 0) + self.assertTrue(isinstance(s.predictions, DataFrame)) + self.assertEqual(s.predictionCol, "prediction") + self.assertEqual(s.labelCol, "label") + self.assertEqual(s.featuresCol, "features") + objHist = s.objectiveHistory + self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float)) + self.assertAlmostEqual(s.explainedVariance, 0.25, 2) + self.assertAlmostEqual(s.meanAbsoluteError, 0.0) + self.assertAlmostEqual(s.meanSquaredError, 0.0) + self.assertAlmostEqual(s.rootMeanSquaredError, 0.0) + self.assertAlmostEqual(s.r2, 1.0, 2) + self.assertAlmostEqual(s.r2adj, 1.0, 2) + self.assertTrue(isinstance(s.residuals, DataFrame)) + self.assertEqual(s.numInstances, 2) + self.assertEqual(s.degreesOfFreedom, 1) + devResiduals = s.devianceResiduals + self.assertTrue(isinstance(devResiduals, list) and isinstance(devResiduals[0], float)) + coefStdErr = s.coefficientStandardErrors + self.assertTrue(isinstance(coefStdErr, list) and isinstance(coefStdErr[0], float)) + tValues = s.tValues + self.assertTrue(isinstance(tValues, list) and isinstance(tValues[0], float)) + pValues = s.pValues + self.assertTrue(isinstance(pValues, list) and isinstance(pValues[0], float)) + # test evaluation (with training dataset) produces a summary with same values + # one check is enough to verify a summary is returned + # The child class LinearRegressionTrainingSummary runs full test + sameSummary = model.evaluate(df) + self.assertAlmostEqual(sameSummary.explainedVariance, s.explainedVariance) + + def test_glr_summary(self): + from pyspark.ml.linalg import Vectors + df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), + (0.0, 2.0, Vectors.sparse(1, [], []))], + ["label", "weight", "features"]) + glr = GeneralizedLinearRegression(family="gaussian", link="identity", weightCol="weight", + fitIntercept=False) + model = glr.fit(df) + self.assertTrue(model.hasSummary) + s = model.summary + # test that api is callable and returns expected types + self.assertEqual(s.numIterations, 1) # this should default to a single iteration of WLS + self.assertTrue(isinstance(s.predictions, DataFrame)) + self.assertEqual(s.predictionCol, "prediction") + self.assertEqual(s.numInstances, 2) + self.assertTrue(isinstance(s.residuals(), DataFrame)) + self.assertTrue(isinstance(s.residuals("pearson"), DataFrame)) + coefStdErr = s.coefficientStandardErrors + self.assertTrue(isinstance(coefStdErr, list) and isinstance(coefStdErr[0], float)) + tValues = s.tValues + self.assertTrue(isinstance(tValues, list) and isinstance(tValues[0], float)) + pValues = s.pValues + 
self.assertTrue(isinstance(pValues, list) and isinstance(pValues[0], float)) + self.assertEqual(s.degreesOfFreedom, 1) + self.assertEqual(s.residualDegreeOfFreedom, 1) + self.assertEqual(s.residualDegreeOfFreedomNull, 2) + self.assertEqual(s.rank, 1) + self.assertTrue(isinstance(s.solver, basestring)) + self.assertTrue(isinstance(s.aic, float)) + self.assertTrue(isinstance(s.deviance, float)) + self.assertTrue(isinstance(s.nullDeviance, float)) + self.assertTrue(isinstance(s.dispersion, float)) + # test evaluation (with training dataset) produces a summary with same values + # one check is enough to verify a summary is returned + # The child class GeneralizedLinearRegressionTrainingSummary runs full test + sameSummary = model.evaluate(df) + self.assertAlmostEqual(sameSummary.deviance, s.deviance) + + def test_binary_logistic_regression_summary(self): + df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), + (0.0, 2.0, Vectors.sparse(1, [], []))], + ["label", "weight", "features"]) + lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight", fitIntercept=False) + model = lr.fit(df) + self.assertTrue(model.hasSummary) + s = model.summary + # test that api is callable and returns expected types + self.assertTrue(isinstance(s.predictions, DataFrame)) + self.assertEqual(s.probabilityCol, "probability") + self.assertEqual(s.labelCol, "label") + self.assertEqual(s.featuresCol, "features") + self.assertEqual(s.predictionCol, "prediction") + objHist = s.objectiveHistory + self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float)) + self.assertGreater(s.totalIterations, 0) + self.assertTrue(isinstance(s.labels, list)) + self.assertTrue(isinstance(s.truePositiveRateByLabel, list)) + self.assertTrue(isinstance(s.falsePositiveRateByLabel, list)) + self.assertTrue(isinstance(s.precisionByLabel, list)) + self.assertTrue(isinstance(s.recallByLabel, list)) + self.assertTrue(isinstance(s.fMeasureByLabel(), list)) + self.assertTrue(isinstance(s.fMeasureByLabel(1.0), list)) + self.assertTrue(isinstance(s.roc, DataFrame)) + self.assertAlmostEqual(s.areaUnderROC, 1.0, 2) + self.assertTrue(isinstance(s.pr, DataFrame)) + self.assertTrue(isinstance(s.fMeasureByThreshold, DataFrame)) + self.assertTrue(isinstance(s.precisionByThreshold, DataFrame)) + self.assertTrue(isinstance(s.recallByThreshold, DataFrame)) + self.assertAlmostEqual(s.accuracy, 1.0, 2) + self.assertAlmostEqual(s.weightedTruePositiveRate, 1.0, 2) + self.assertAlmostEqual(s.weightedFalsePositiveRate, 0.0, 2) + self.assertAlmostEqual(s.weightedRecall, 1.0, 2) + self.assertAlmostEqual(s.weightedPrecision, 1.0, 2) + self.assertAlmostEqual(s.weightedFMeasure(), 1.0, 2) + self.assertAlmostEqual(s.weightedFMeasure(1.0), 1.0, 2) + # test evaluation (with training dataset) produces a summary with same values + # one check is enough to verify a summary is returned, Scala version runs full test + sameSummary = model.evaluate(df) + self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC) + + def test_multiclass_logistic_regression_summary(self): + df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), + (0.0, 2.0, Vectors.sparse(1, [], [])), + (2.0, 2.0, Vectors.dense(2.0)), + (2.0, 2.0, Vectors.dense(1.9))], + ["label", "weight", "features"]) + lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight", fitIntercept=False) + model = lr.fit(df) + self.assertTrue(model.hasSummary) + s = model.summary + # test that api is callable and returns expected types + 
self.assertTrue(isinstance(s.predictions, DataFrame)) + self.assertEqual(s.probabilityCol, "probability") + self.assertEqual(s.labelCol, "label") + self.assertEqual(s.featuresCol, "features") + self.assertEqual(s.predictionCol, "prediction") + objHist = s.objectiveHistory + self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float)) + self.assertGreater(s.totalIterations, 0) + self.assertTrue(isinstance(s.labels, list)) + self.assertTrue(isinstance(s.truePositiveRateByLabel, list)) + self.assertTrue(isinstance(s.falsePositiveRateByLabel, list)) + self.assertTrue(isinstance(s.precisionByLabel, list)) + self.assertTrue(isinstance(s.recallByLabel, list)) + self.assertTrue(isinstance(s.fMeasureByLabel(), list)) + self.assertTrue(isinstance(s.fMeasureByLabel(1.0), list)) + self.assertAlmostEqual(s.accuracy, 0.75, 2) + self.assertAlmostEqual(s.weightedTruePositiveRate, 0.75, 2) + self.assertAlmostEqual(s.weightedFalsePositiveRate, 0.25, 2) + self.assertAlmostEqual(s.weightedRecall, 0.75, 2) + self.assertAlmostEqual(s.weightedPrecision, 0.583, 2) + self.assertAlmostEqual(s.weightedFMeasure(), 0.65, 2) + self.assertAlmostEqual(s.weightedFMeasure(1.0), 0.65, 2) + # test evaluation (with training dataset) produces a summary with same values + # one check is enough to verify a summary is returned, Scala version runs full test + sameSummary = model.evaluate(df) + self.assertAlmostEqual(sameSummary.accuracy, s.accuracy) + + def test_gaussian_mixture_summary(self): + data = [(Vectors.dense(1.0),), (Vectors.dense(5.0),), (Vectors.dense(10.0),), + (Vectors.sparse(1, [], []),)] + df = self.spark.createDataFrame(data, ["features"]) + gmm = GaussianMixture(k=2) + model = gmm.fit(df) + self.assertTrue(model.hasSummary) + s = model.summary + self.assertTrue(isinstance(s.predictions, DataFrame)) + self.assertEqual(s.probabilityCol, "probability") + self.assertTrue(isinstance(s.probability, DataFrame)) + self.assertEqual(s.featuresCol, "features") + self.assertEqual(s.predictionCol, "prediction") + self.assertTrue(isinstance(s.cluster, DataFrame)) + self.assertEqual(len(s.clusterSizes), 2) + self.assertEqual(s.k, 2) + self.assertEqual(s.numIter, 3) + + def test_bisecting_kmeans_summary(self): + data = [(Vectors.dense(1.0),), (Vectors.dense(5.0),), (Vectors.dense(10.0),), + (Vectors.sparse(1, [], []),)] + df = self.spark.createDataFrame(data, ["features"]) + bkm = BisectingKMeans(k=2) + model = bkm.fit(df) + self.assertTrue(model.hasSummary) + s = model.summary + self.assertTrue(isinstance(s.predictions, DataFrame)) + self.assertEqual(s.featuresCol, "features") + self.assertEqual(s.predictionCol, "prediction") + self.assertTrue(isinstance(s.cluster, DataFrame)) + self.assertEqual(len(s.clusterSizes), 2) + self.assertEqual(s.k, 2) + self.assertEqual(s.numIter, 20) + + def test_kmeans_summary(self): + data = [(Vectors.dense([0.0, 0.0]),), (Vectors.dense([1.0, 1.0]),), + (Vectors.dense([9.0, 8.0]),), (Vectors.dense([8.0, 9.0]),)] + df = self.spark.createDataFrame(data, ["features"]) + kmeans = KMeans(k=2, seed=1) + model = kmeans.fit(df) + self.assertTrue(model.hasSummary) + s = model.summary + self.assertTrue(isinstance(s.predictions, DataFrame)) + self.assertEqual(s.featuresCol, "features") + self.assertEqual(s.predictionCol, "prediction") + self.assertTrue(isinstance(s.cluster, DataFrame)) + self.assertEqual(len(s.clusterSizes), 2) + self.assertEqual(s.k, 2) + self.assertEqual(s.numIter, 1) + + +if __name__ == "__main__": + from pyspark.ml.tests.test_training_summary import * + + try: + import 
xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/ml/tests/test_tuning.py b/python/pyspark/ml/tests/test_tuning.py new file mode 100644 index 0000000000000..39bb921aaf43d --- /dev/null +++ b/python/pyspark/ml/tests/test_tuning.py @@ -0,0 +1,544 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import tempfile +import unittest + +from pyspark.ml import Estimator, Model +from pyspark.ml.classification import LogisticRegression, LogisticRegressionModel, OneVsRest +from pyspark.ml.evaluation import BinaryClassificationEvaluator, \ + MulticlassClassificationEvaluator, RegressionEvaluator +from pyspark.ml.linalg import Vectors +from pyspark.ml.param import Param, Params +from pyspark.ml.tuning import CrossValidator, CrossValidatorModel, ParamGridBuilder, \ + TrainValidationSplit, TrainValidationSplitModel +from pyspark.sql.functions import rand +from pyspark.testing.mlutils import SparkSessionTestCase + + +class HasInducedError(Params): + + def __init__(self): + super(HasInducedError, self).__init__() + self.inducedError = Param(self, "inducedError", + "Uniformly-distributed error added to feature") + + def getInducedError(self): + return self.getOrDefault(self.inducedError) + + +class InducedErrorModel(Model, HasInducedError): + + def __init__(self): + super(InducedErrorModel, self).__init__() + + def _transform(self, dataset): + return dataset.withColumn("prediction", + dataset.feature + (rand(0) * self.getInducedError())) + + +class InducedErrorEstimator(Estimator, HasInducedError): + + def __init__(self, inducedError=1.0): + super(InducedErrorEstimator, self).__init__() + self._set(inducedError=inducedError) + + def _fit(self, dataset): + model = InducedErrorModel() + self._copyValues(model) + return model + + +class CrossValidatorTests(SparkSessionTestCase): + + def test_copy(self): + dataset = self.spark.createDataFrame([ + (10, 10.0), + (50, 50.0), + (100, 100.0), + (500, 500.0)] * 10, + ["feature", "label"]) + + iee = InducedErrorEstimator() + evaluator = RegressionEvaluator(metricName="rmse") + + grid = (ParamGridBuilder() + .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) + .build()) + cv = CrossValidator(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator) + cvCopied = cv.copy() + self.assertEqual(cv.getEstimator().uid, cvCopied.getEstimator().uid) + + cvModel = cv.fit(dataset) + cvModelCopied = cvModel.copy() + for index in range(len(cvModel.avgMetrics)): + self.assertTrue(abs(cvModel.avgMetrics[index] - cvModelCopied.avgMetrics[index]) + < 0.0001) + + def test_fit_minimize_metric(self): + dataset = self.spark.createDataFrame([ + (10, 10.0), + (50, 50.0), + (100, 100.0), + (500, 500.0)] * 
10, + ["feature", "label"]) + + iee = InducedErrorEstimator() + evaluator = RegressionEvaluator(metricName="rmse") + + grid = (ParamGridBuilder() + .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) + .build()) + cv = CrossValidator(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator) + cvModel = cv.fit(dataset) + bestModel = cvModel.bestModel + bestModelMetric = evaluator.evaluate(bestModel.transform(dataset)) + + self.assertEqual(0.0, bestModel.getOrDefault('inducedError'), + "Best model should have zero induced error") + self.assertEqual(0.0, bestModelMetric, "Best model has RMSE of 0") + + def test_fit_maximize_metric(self): + dataset = self.spark.createDataFrame([ + (10, 10.0), + (50, 50.0), + (100, 100.0), + (500, 500.0)] * 10, + ["feature", "label"]) + + iee = InducedErrorEstimator() + evaluator = RegressionEvaluator(metricName="r2") + + grid = (ParamGridBuilder() + .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) + .build()) + cv = CrossValidator(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator) + cvModel = cv.fit(dataset) + bestModel = cvModel.bestModel + bestModelMetric = evaluator.evaluate(bestModel.transform(dataset)) + + self.assertEqual(0.0, bestModel.getOrDefault('inducedError'), + "Best model should have zero induced error") + self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1") + + def test_param_grid_type_coercion(self): + lr = LogisticRegression(maxIter=10) + paramGrid = ParamGridBuilder().addGrid(lr.regParam, [0.5, 1]).build() + for param in paramGrid: + for v in param.values(): + assert(type(v) == float) + + def test_save_load_trained_model(self): + # This tests saving and loading the trained model only. + # Save/load for CrossValidator will be added later: SPARK-13786 + temp_path = tempfile.mkdtemp() + dataset = self.spark.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + lr = LogisticRegression() + grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() + evaluator = BinaryClassificationEvaluator() + cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) + cvModel = cv.fit(dataset) + lrModel = cvModel.bestModel + + cvModelPath = temp_path + "/cvModel" + lrModel.save(cvModelPath) + loadedLrModel = LogisticRegressionModel.load(cvModelPath) + self.assertEqual(loadedLrModel.uid, lrModel.uid) + self.assertEqual(loadedLrModel.intercept, lrModel.intercept) + + def test_save_load_simple_estimator(self): + temp_path = tempfile.mkdtemp() + dataset = self.spark.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + + lr = LogisticRegression() + grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() + evaluator = BinaryClassificationEvaluator() + + # test save/load of CrossValidator + cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) + cvModel = cv.fit(dataset) + cvPath = temp_path + "/cv" + cv.save(cvPath) + loadedCV = CrossValidator.load(cvPath) + self.assertEqual(loadedCV.getEstimator().uid, cv.getEstimator().uid) + self.assertEqual(loadedCV.getEvaluator().uid, cv.getEvaluator().uid) + self.assertEqual(loadedCV.getEstimatorParamMaps(), cv.getEstimatorParamMaps()) + + # test save/load of CrossValidatorModel + cvModelPath = temp_path + "/cvModel" + 
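+        # the loaded CrossValidatorModel should expose the same best model uid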
cvModel.save(cvModelPath) + loadedModel = CrossValidatorModel.load(cvModelPath) + self.assertEqual(loadedModel.bestModel.uid, cvModel.bestModel.uid) + + def test_parallel_evaluation(self): + dataset = self.spark.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + + lr = LogisticRegression() + grid = ParamGridBuilder().addGrid(lr.maxIter, [5, 6]).build() + evaluator = BinaryClassificationEvaluator() + + # test save/load of CrossValidator + cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) + cv.setParallelism(1) + cvSerialModel = cv.fit(dataset) + cv.setParallelism(2) + cvParallelModel = cv.fit(dataset) + self.assertEqual(cvSerialModel.avgMetrics, cvParallelModel.avgMetrics) + + def test_expose_sub_models(self): + temp_path = tempfile.mkdtemp() + dataset = self.spark.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + + lr = LogisticRegression() + grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() + evaluator = BinaryClassificationEvaluator() + + numFolds = 3 + cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator, + numFolds=numFolds, collectSubModels=True) + + def checkSubModels(subModels): + self.assertEqual(len(subModels), numFolds) + for i in range(numFolds): + self.assertEqual(len(subModels[i]), len(grid)) + + cvModel = cv.fit(dataset) + checkSubModels(cvModel.subModels) + + # Test the default value for option "persistSubModel" to be "true" + testSubPath = temp_path + "/testCrossValidatorSubModels" + savingPathWithSubModels = testSubPath + "cvModel3" + cvModel.save(savingPathWithSubModels) + cvModel3 = CrossValidatorModel.load(savingPathWithSubModels) + checkSubModels(cvModel3.subModels) + cvModel4 = cvModel3.copy() + checkSubModels(cvModel4.subModels) + + savingPathWithoutSubModels = testSubPath + "cvModel2" + cvModel.write().option("persistSubModels", "false").save(savingPathWithoutSubModels) + cvModel2 = CrossValidatorModel.load(savingPathWithoutSubModels) + self.assertEqual(cvModel2.subModels, None) + + for i in range(numFolds): + for j in range(len(grid)): + self.assertEqual(cvModel.subModels[i][j].uid, cvModel3.subModels[i][j].uid) + + def test_save_load_nested_estimator(self): + temp_path = tempfile.mkdtemp() + dataset = self.spark.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + + ova = OneVsRest(classifier=LogisticRegression()) + lr1 = LogisticRegression().setMaxIter(100) + lr2 = LogisticRegression().setMaxIter(150) + grid = ParamGridBuilder().addGrid(ova.classifier, [lr1, lr2]).build() + evaluator = MulticlassClassificationEvaluator() + + # test save/load of CrossValidator + cv = CrossValidator(estimator=ova, estimatorParamMaps=grid, evaluator=evaluator) + cvModel = cv.fit(dataset) + cvPath = temp_path + "/cv" + cv.save(cvPath) + loadedCV = CrossValidator.load(cvPath) + self.assertEqual(loadedCV.getEstimator().uid, cv.getEstimator().uid) + self.assertEqual(loadedCV.getEvaluator().uid, cv.getEvaluator().uid) + + originalParamMap = cv.getEstimatorParamMaps() + loadedParamMap = loadedCV.getEstimatorParamMaps() + for i, param in enumerate(loadedParamMap): + 
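
`test_save_load_nested_estimator` above tunes over whole classifiers rather than scalar parameters by placing estimator instances in the grid for the OneVsRest `classifier` Param. A hedged sketch of that pattern, again assuming a SparkSession `spark` and a labelled `dataset`:

```python
# Sketch of the "nested estimator" grid used in test_save_load_nested_estimator:
# the grid values are whole LogisticRegression instances swapped into the
# OneVsRest `classifier` Param. Assumes `spark` and a labelled `dataset`.
from pyspark.ml.classification import LogisticRegression, OneVsRest
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

ova = OneVsRest(classifier=LogisticRegression())
grid = (ParamGridBuilder()
        .addGrid(ova.classifier, [LogisticRegression(maxIter=100),
                                  LogisticRegression(maxIter=150)])
        .build())
cv = CrossValidator(estimator=ova,
                    estimatorParamMaps=grid,
                    evaluator=MulticlassClassificationEvaluator())
# cv.fit(dataset) then selects between the two inner classifiers.
```
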
for p in param: + if p.name == "classifier": + self.assertEqual(param[p].uid, originalParamMap[i][p].uid) + else: + self.assertEqual(param[p], originalParamMap[i][p]) + + # test save/load of CrossValidatorModel + cvModelPath = temp_path + "/cvModel" + cvModel.save(cvModelPath) + loadedModel = CrossValidatorModel.load(cvModelPath) + self.assertEqual(loadedModel.bestModel.uid, cvModel.bestModel.uid) + + +class TrainValidationSplitTests(SparkSessionTestCase): + + def test_fit_minimize_metric(self): + dataset = self.spark.createDataFrame([ + (10, 10.0), + (50, 50.0), + (100, 100.0), + (500, 500.0)] * 10, + ["feature", "label"]) + + iee = InducedErrorEstimator() + evaluator = RegressionEvaluator(metricName="rmse") + + grid = ParamGridBuilder() \ + .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) \ + .build() + tvs = TrainValidationSplit(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator) + tvsModel = tvs.fit(dataset) + bestModel = tvsModel.bestModel + bestModelMetric = evaluator.evaluate(bestModel.transform(dataset)) + validationMetrics = tvsModel.validationMetrics + + self.assertEqual(0.0, bestModel.getOrDefault('inducedError'), + "Best model should have zero induced error") + self.assertEqual(0.0, bestModelMetric, "Best model has RMSE of 0") + self.assertEqual(len(grid), len(validationMetrics), + "validationMetrics has the same size of grid parameter") + self.assertEqual(0.0, min(validationMetrics)) + + def test_fit_maximize_metric(self): + dataset = self.spark.createDataFrame([ + (10, 10.0), + (50, 50.0), + (100, 100.0), + (500, 500.0)] * 10, + ["feature", "label"]) + + iee = InducedErrorEstimator() + evaluator = RegressionEvaluator(metricName="r2") + + grid = ParamGridBuilder() \ + .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) \ + .build() + tvs = TrainValidationSplit(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator) + tvsModel = tvs.fit(dataset) + bestModel = tvsModel.bestModel + bestModelMetric = evaluator.evaluate(bestModel.transform(dataset)) + validationMetrics = tvsModel.validationMetrics + + self.assertEqual(0.0, bestModel.getOrDefault('inducedError'), + "Best model should have zero induced error") + self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1") + self.assertEqual(len(grid), len(validationMetrics), + "validationMetrics has the same size of grid parameter") + self.assertEqual(1.0, max(validationMetrics)) + + def test_save_load_trained_model(self): + # This tests saving and loading the trained model only. + # Save/load for TrainValidationSplit will be added later: SPARK-13786 + temp_path = tempfile.mkdtemp() + dataset = self.spark.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + lr = LogisticRegression() + grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() + evaluator = BinaryClassificationEvaluator() + tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) + tvsModel = tvs.fit(dataset) + lrModel = tvsModel.bestModel + + tvsModelPath = temp_path + "/tvsModel" + lrModel.save(tvsModelPath) + loadedLrModel = LogisticRegressionModel.load(tvsModelPath) + self.assertEqual(loadedLrModel.uid, lrModel.uid) + self.assertEqual(loadedLrModel.intercept, lrModel.intercept) + + def test_save_load_simple_estimator(self): + # This tests saving and loading the trained model only. 
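
The TrainValidationSplit tests that start here mirror the CrossValidator ones, but each grid point is evaluated once on a single train/validation split, and the assertions focus on `validationMetrics`, which carries one metric per grid entry. A minimal sketch under the same assumption of a SparkSession bound to `spark`:

```python
# Sketch of the TrainValidationSplit workflow checked by the tests below.
# Assumes a SparkSession bound to `spark`; names are illustrative only.
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.linalg import Vectors
from pyspark.ml.tuning import ParamGridBuilder, TrainValidationSplit

train = spark.createDataFrame(
    [(Vectors.dense([0.0]), 0.0),
     (Vectors.dense([0.6]), 1.0),
     (Vectors.dense([1.0]), 1.0)] * 10,
    ["features", "label"])

lr = LogisticRegression()
grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid,
                           evaluator=BinaryClassificationEvaluator())
tvsModel = tvs.fit(train)
assert len(tvsModel.validationMetrics) == len(grid)  # one metric per grid entry
best = tvsModel.bestModel
```
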
+ # Save/load for TrainValidationSplit will be added later: SPARK-13786 + temp_path = tempfile.mkdtemp() + dataset = self.spark.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + lr = LogisticRegression() + grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() + evaluator = BinaryClassificationEvaluator() + tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) + tvsModel = tvs.fit(dataset) + + tvsPath = temp_path + "/tvs" + tvs.save(tvsPath) + loadedTvs = TrainValidationSplit.load(tvsPath) + self.assertEqual(loadedTvs.getEstimator().uid, tvs.getEstimator().uid) + self.assertEqual(loadedTvs.getEvaluator().uid, tvs.getEvaluator().uid) + self.assertEqual(loadedTvs.getEstimatorParamMaps(), tvs.getEstimatorParamMaps()) + + tvsModelPath = temp_path + "/tvsModel" + tvsModel.save(tvsModelPath) + loadedModel = TrainValidationSplitModel.load(tvsModelPath) + self.assertEqual(loadedModel.bestModel.uid, tvsModel.bestModel.uid) + + def test_parallel_evaluation(self): + dataset = self.spark.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + lr = LogisticRegression() + grid = ParamGridBuilder().addGrid(lr.maxIter, [5, 6]).build() + evaluator = BinaryClassificationEvaluator() + tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) + tvs.setParallelism(1) + tvsSerialModel = tvs.fit(dataset) + tvs.setParallelism(2) + tvsParallelModel = tvs.fit(dataset) + self.assertEqual(tvsSerialModel.validationMetrics, tvsParallelModel.validationMetrics) + + def test_expose_sub_models(self): + temp_path = tempfile.mkdtemp() + dataset = self.spark.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + lr = LogisticRegression() + grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() + evaluator = BinaryClassificationEvaluator() + tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator, + collectSubModels=True) + tvsModel = tvs.fit(dataset) + self.assertEqual(len(tvsModel.subModels), len(grid)) + + # Test the default value for option "persistSubModel" to be "true" + testSubPath = temp_path + "/testTrainValidationSplitSubModels" + savingPathWithSubModels = testSubPath + "cvModel3" + tvsModel.save(savingPathWithSubModels) + tvsModel3 = TrainValidationSplitModel.load(savingPathWithSubModels) + self.assertEqual(len(tvsModel3.subModels), len(grid)) + tvsModel4 = tvsModel3.copy() + self.assertEqual(len(tvsModel4.subModels), len(grid)) + + savingPathWithoutSubModels = testSubPath + "cvModel2" + tvsModel.write().option("persistSubModels", "false").save(savingPathWithoutSubModels) + tvsModel2 = TrainValidationSplitModel.load(savingPathWithoutSubModels) + self.assertEqual(tvsModel2.subModels, None) + + for i in range(len(grid)): + self.assertEqual(tvsModel.subModels[i].uid, tvsModel3.subModels[i].uid) + + def test_save_load_nested_estimator(self): + # This tests saving and loading the trained model only. 
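
Both tuners can retain the per-split fitted models when `collectSubModels=True`, and `test_expose_sub_models` checks that these survive save/load unless the writer is told to drop them. A hedged sketch of the writer option being asserted (paths and names are placeholders, and a SparkSession `spark` is assumed):

```python
# Sketch of sub-model collection and the "persistSubModels" writer option
# exercised in test_expose_sub_models. Assumes a SparkSession bound to
# `spark`; the paths below are placeholders.
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.linalg import Vectors
from pyspark.ml.tuning import (ParamGridBuilder, TrainValidationSplit,
                               TrainValidationSplitModel)

train = spark.createDataFrame(
    [(Vectors.dense([0.0]), 0.0), (Vectors.dense([1.0]), 1.0)] * 10,
    ["features", "label"])
lr = LogisticRegression()
grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()

tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid,
                           evaluator=BinaryClassificationEvaluator(),
                           collectSubModels=True)
tvsModel = tvs.fit(train)
print(len(tvsModel.subModels))                      # one sub-model per grid entry

tvsModel.save("/tmp/tvs-with-submodels")            # sub-models persisted by default
tvsModel.write().option("persistSubModels", "false") \
        .save("/tmp/tvs-without-submodels")         # drop them to keep the save small
reloaded = TrainValidationSplitModel.load("/tmp/tvs-without-submodels")
print(reloaded.subModels)                           # None when not persisted
```
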
+ # Save/load for TrainValidationSplit will be added later: SPARK-13786 + temp_path = tempfile.mkdtemp() + dataset = self.spark.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + ova = OneVsRest(classifier=LogisticRegression()) + lr1 = LogisticRegression().setMaxIter(100) + lr2 = LogisticRegression().setMaxIter(150) + grid = ParamGridBuilder().addGrid(ova.classifier, [lr1, lr2]).build() + evaluator = MulticlassClassificationEvaluator() + + tvs = TrainValidationSplit(estimator=ova, estimatorParamMaps=grid, evaluator=evaluator) + tvsModel = tvs.fit(dataset) + tvsPath = temp_path + "/tvs" + tvs.save(tvsPath) + loadedTvs = TrainValidationSplit.load(tvsPath) + self.assertEqual(loadedTvs.getEstimator().uid, tvs.getEstimator().uid) + self.assertEqual(loadedTvs.getEvaluator().uid, tvs.getEvaluator().uid) + + originalParamMap = tvs.getEstimatorParamMaps() + loadedParamMap = loadedTvs.getEstimatorParamMaps() + for i, param in enumerate(loadedParamMap): + for p in param: + if p.name == "classifier": + self.assertEqual(param[p].uid, originalParamMap[i][p].uid) + else: + self.assertEqual(param[p], originalParamMap[i][p]) + + tvsModelPath = temp_path + "/tvsModel" + tvsModel.save(tvsModelPath) + loadedModel = TrainValidationSplitModel.load(tvsModelPath) + self.assertEqual(loadedModel.bestModel.uid, tvsModel.bestModel.uid) + + def test_copy(self): + dataset = self.spark.createDataFrame([ + (10, 10.0), + (50, 50.0), + (100, 100.0), + (500, 500.0)] * 10, + ["feature", "label"]) + + iee = InducedErrorEstimator() + evaluator = RegressionEvaluator(metricName="r2") + + grid = ParamGridBuilder() \ + .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) \ + .build() + tvs = TrainValidationSplit(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator) + tvsModel = tvs.fit(dataset) + tvsCopied = tvs.copy() + tvsModelCopied = tvsModel.copy() + + self.assertEqual(tvs.getEstimator().uid, tvsCopied.getEstimator().uid, + "Copied TrainValidationSplit has the same uid of Estimator") + + self.assertEqual(tvsModel.bestModel.uid, tvsModelCopied.bestModel.uid) + self.assertEqual(len(tvsModel.validationMetrics), + len(tvsModelCopied.validationMetrics), + "Copied validationMetrics has the same size of the original") + for index in range(len(tvsModel.validationMetrics)): + self.assertEqual(tvsModel.validationMetrics[index], + tvsModelCopied.validationMetrics[index]) + + +if __name__ == "__main__": + from pyspark.ml.tests.test_tuning import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/ml/tests/test_wrapper.py b/python/pyspark/ml/tests/test_wrapper.py new file mode 100644 index 0000000000000..ae672a00c1dc1 --- /dev/null +++ b/python/pyspark/ml/tests/test_wrapper.py @@ -0,0 +1,112 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest + +import py4j + +from pyspark.ml.linalg import DenseVector, Vectors +from pyspark.ml.regression import LinearRegression +from pyspark.ml.wrapper import _java2py, _py2java, JavaParams, JavaWrapper +from pyspark.testing.mllibutils import MLlibTestCase +from pyspark.testing.mlutils import SparkSessionTestCase + + +class JavaWrapperMemoryTests(SparkSessionTestCase): + + def test_java_object_gets_detached(self): + df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), + (0.0, 2.0, Vectors.sparse(1, [], []))], + ["label", "weight", "features"]) + lr = LinearRegression(maxIter=1, regParam=0.0, solver="normal", weightCol="weight", + fitIntercept=False) + + model = lr.fit(df) + summary = model.summary + + self.assertIsInstance(model, JavaWrapper) + self.assertIsInstance(summary, JavaWrapper) + self.assertIsInstance(model, JavaParams) + self.assertNotIsInstance(summary, JavaParams) + + error_no_object = 'Target Object ID does not exist for this gateway' + + self.assertIn("LinearRegression_", model._java_obj.toString()) + self.assertIn("LinearRegressionTrainingSummary", summary._java_obj.toString()) + + model.__del__() + + with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object): + model._java_obj.toString() + self.assertIn("LinearRegressionTrainingSummary", summary._java_obj.toString()) + + try: + summary.__del__() + except: + pass + + with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object): + model._java_obj.toString() + with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object): + summary._java_obj.toString() + + +class WrapperTests(MLlibTestCase): + + def test_new_java_array(self): + # test array of strings + str_list = ["a", "b", "c"] + java_class = self.sc._gateway.jvm.java.lang.String + java_array = JavaWrapper._new_java_array(str_list, java_class) + self.assertEqual(_java2py(self.sc, java_array), str_list) + # test array of integers + int_list = [1, 2, 3] + java_class = self.sc._gateway.jvm.java.lang.Integer + java_array = JavaWrapper._new_java_array(int_list, java_class) + self.assertEqual(_java2py(self.sc, java_array), int_list) + # test array of floats + float_list = [0.1, 0.2, 0.3] + java_class = self.sc._gateway.jvm.java.lang.Double + java_array = JavaWrapper._new_java_array(float_list, java_class) + self.assertEqual(_java2py(self.sc, java_array), float_list) + # test array of bools + bool_list = [False, True, True] + java_class = self.sc._gateway.jvm.java.lang.Boolean + java_array = JavaWrapper._new_java_array(bool_list, java_class) + self.assertEqual(_java2py(self.sc, java_array), bool_list) + # test array of Java DenseVectors + v1 = DenseVector([0.0, 1.0]) + v2 = DenseVector([1.0, 0.0]) + vec_java_list = [_py2java(self.sc, v1), _py2java(self.sc, v2)] + java_class = self.sc._gateway.jvm.org.apache.spark.ml.linalg.DenseVector + java_array = JavaWrapper._new_java_array(vec_java_list, java_class) + self.assertEqual(_java2py(self.sc, java_array), [v1, v2]) + # test empty array + java_class = self.sc._gateway.jvm.java.lang.Integer + java_array = JavaWrapper._new_java_array([], java_class) + 
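
`WrapperTests` covers the private `_new_java_array` helper that PySpark uses when it needs to hand a typed Java array across Py4J. These are internal, underscore-prefixed helpers, so the sketch below is only a reference for the round trip the tests assert; it assumes an active SparkContext bound to `sc`.

```python
# Round trip exercised by WrapperTests: build a java.lang.String[] from a
# Python list, then convert it back with _java2py. Internal helpers;
# an active SparkContext bound to `sc` is assumed.
from pyspark.ml.wrapper import JavaWrapper, _java2py

str_list = ["a", "b", "c"]
java_class = sc._gateway.jvm.java.lang.String   # Py4J view of java.lang.String
java_array = JavaWrapper._new_java_array(str_list, java_class)
assert _java2py(sc, java_array) == str_list
```
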
self.assertEqual(_java2py(self.sc, java_array), []) + + +if __name__ == "__main__": + from pyspark.ml.tests.test_wrapper import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index b1a8af6bcc094..4f4355ddb60ee 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -184,7 +184,7 @@ class KMeansModel(Saveable, Loader): >>> model.k 2 >>> model.computeCost(sc.parallelize(data)) - 2.0000000000000004 + 2.0 >>> model = KMeans.train(sc.parallelize(data), 2) >>> sparse_data = [ ... SparseVector(3, {1: 1.0}), diff --git a/python/pyspark/mllib/fpm.py b/python/pyspark/mllib/fpm.py index de18dad1f675d..6accb9b4926e8 100644 --- a/python/pyspark/mllib/fpm.py +++ b/python/pyspark/mllib/fpm.py @@ -132,7 +132,7 @@ class PrefixSpan(object): A parallel PrefixSpan algorithm to mine frequent sequential patterns. The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns Efficiently by Prefix-Projected Pattern Growth - ([[http://doi.org/10.1109/ICDE.2001.914830]]). + ([[https://doi.org/10.1109/ICDE.2001.914830]]). .. versionadded:: 1.6.0 """ diff --git a/python/pyspark/mllib/linalg/distributed.py b/python/pyspark/mllib/linalg/distributed.py index 7e8b15056cabe..b7f09782be9dd 100644 --- a/python/pyspark/mllib/linalg/distributed.py +++ b/python/pyspark/mllib/linalg/distributed.py @@ -270,7 +270,7 @@ def tallSkinnyQR(self, computeQ=False): Reference: Paul G. Constantine, David F. Gleich. "Tall and skinny QR factorizations in MapReduce architectures" - ([[http://dx.doi.org/10.1145/1996092.1996103]]) + ([[https://doi.org/10.1145/1996092.1996103]]) :param: computeQ: whether to computeQ :return: QRDecomposition(Q: RowMatrix, R: Matrix), where diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py deleted file mode 100644 index 653f5cb9ff4a2..0000000000000 --- a/python/pyspark/mllib/tests.py +++ /dev/null @@ -1,1788 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -""" -Fuller unit tests for Python MLlib. 
-""" - -import os -import sys -import tempfile -import array as pyarray -from math import sqrt -from time import time, sleep -from shutil import rmtree - -import unishark -from numpy import ( - array, array_equal, zeros, inf, random, exp, dot, all, mean, abs, arange, tile, ones) -from numpy import sum as array_sum - -from py4j.protocol import Py4JJavaError - -if sys.version > '3': - basestring = str - -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest - -from pyspark import SparkContext -import pyspark.ml.linalg as newlinalg -from pyspark.mllib.common import _to_java_object_rdd -from pyspark.mllib.clustering import StreamingKMeans, StreamingKMeansModel -from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\ - DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT -from pyspark.mllib.linalg.distributed import RowMatrix -from pyspark.mllib.classification import StreamingLogisticRegressionWithSGD -from pyspark.mllib.fpm import FPGrowth -from pyspark.mllib.recommendation import Rating -from pyspark.mllib.regression import LabeledPoint, StreamingLinearRegressionWithSGD -from pyspark.mllib.random import RandomRDDs -from pyspark.mllib.stat import Statistics -from pyspark.mllib.feature import HashingTF -from pyspark.mllib.feature import Word2Vec -from pyspark.mllib.feature import IDF -from pyspark.mllib.feature import StandardScaler, ElementwiseProduct -from pyspark.mllib.util import LinearDataGenerator -from pyspark.mllib.util import MLUtils -from pyspark.serializers import PickleSerializer -from pyspark.streaming import StreamingContext -from pyspark.sql import SparkSession -from pyspark.sql.utils import IllegalArgumentException -from pyspark.streaming import StreamingContext - -_have_scipy = False -try: - import scipy.sparse - _have_scipy = True -except: - # No SciPy, but that's okay, we'll skip those tests - pass - -ser = PickleSerializer() - - -class MLlibTestCase(unittest.TestCase): - def setUp(self): - self.sc = SparkContext('local[4]', "MLlib tests") - self.spark = SparkSession(self.sc) - - def tearDown(self): - self.spark.stop() - - -class MLLibStreamingTestCase(unittest.TestCase): - def setUp(self): - self.sc = SparkContext('local[4]', "MLlib tests") - self.ssc = StreamingContext(self.sc, 1.0) - - def tearDown(self): - self.ssc.stop(False) - self.sc.stop() - - @staticmethod - def _eventually(condition, timeout=30.0, catch_assertions=False): - """ - Wait a given amount of time for a condition to pass, else fail with an error. - This is a helper utility for streaming ML tests. - :param condition: Function that checks for termination conditions. - condition() can return: - - True: Conditions met. Return without error. - - other value: Conditions not met yet. Continue. Upon timeout, - include last such value in error message. - Note that this method may be called at any time during - streaming execution (e.g., even before any results - have been created). - :param timeout: Number of seconds to wait. Default 30 seconds. - :param catch_assertions: If False (default), do not catch AssertionErrors. - If True, catch AssertionErrors; continue, but save - error to throw upon timeout. 
- """ - start_time = time() - lastValue = None - while time() - start_time < timeout: - if catch_assertions: - try: - lastValue = condition() - except AssertionError as e: - lastValue = e - else: - lastValue = condition() - if lastValue is True: - return - sleep(0.01) - if isinstance(lastValue, AssertionError): - raise lastValue - else: - raise AssertionError( - "Test failed due to timeout after %g sec, with last condition returning: %s" - % (timeout, lastValue)) - - -def _squared_distance(a, b): - if isinstance(a, Vector): - return a.squared_distance(b) - else: - return b.squared_distance(a) - - -class VectorTests(MLlibTestCase): - - def _test_serialize(self, v): - self.assertEqual(v, ser.loads(ser.dumps(v))) - jvec = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(v))) - nv = ser.loads(bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(jvec))) - self.assertEqual(v, nv) - vs = [v] * 100 - jvecs = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(vs))) - nvs = ser.loads(bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(jvecs))) - self.assertEqual(vs, nvs) - - def test_serialize(self): - self._test_serialize(DenseVector(range(10))) - self._test_serialize(DenseVector(array([1., 2., 3., 4.]))) - self._test_serialize(DenseVector(pyarray.array('d', range(10)))) - self._test_serialize(SparseVector(4, {1: 1, 3: 2})) - self._test_serialize(SparseVector(3, {})) - self._test_serialize(DenseMatrix(2, 3, range(6))) - sm1 = SparseMatrix( - 3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0]) - self._test_serialize(sm1) - - def test_dot(self): - sv = SparseVector(4, {1: 1, 3: 2}) - dv = DenseVector(array([1., 2., 3., 4.])) - lst = DenseVector([1, 2, 3, 4]) - mat = array([[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]]) - arr = pyarray.array('d', [0, 1, 2, 3]) - self.assertEqual(10.0, sv.dot(dv)) - self.assertTrue(array_equal(array([3., 6., 9., 12.]), sv.dot(mat))) - self.assertEqual(30.0, dv.dot(dv)) - self.assertTrue(array_equal(array([10., 20., 30., 40.]), dv.dot(mat))) - self.assertEqual(30.0, lst.dot(dv)) - self.assertTrue(array_equal(array([10., 20., 30., 40.]), lst.dot(mat))) - self.assertEqual(7.0, sv.dot(arr)) - - def test_squared_distance(self): - sv = SparseVector(4, {1: 1, 3: 2}) - dv = DenseVector(array([1., 2., 3., 4.])) - lst = DenseVector([4, 3, 2, 1]) - lst1 = [4, 3, 2, 1] - arr = pyarray.array('d', [0, 2, 1, 3]) - narr = array([0, 2, 1, 3]) - self.assertEqual(15.0, _squared_distance(sv, dv)) - self.assertEqual(25.0, _squared_distance(sv, lst)) - self.assertEqual(20.0, _squared_distance(dv, lst)) - self.assertEqual(15.0, _squared_distance(dv, sv)) - self.assertEqual(25.0, _squared_distance(lst, sv)) - self.assertEqual(20.0, _squared_distance(lst, dv)) - self.assertEqual(0.0, _squared_distance(sv, sv)) - self.assertEqual(0.0, _squared_distance(dv, dv)) - self.assertEqual(0.0, _squared_distance(lst, lst)) - self.assertEqual(25.0, _squared_distance(sv, lst1)) - self.assertEqual(3.0, _squared_distance(sv, arr)) - self.assertEqual(3.0, _squared_distance(sv, narr)) - - def test_hash(self): - v1 = DenseVector([0.0, 1.0, 0.0, 5.5]) - v2 = SparseVector(4, [(1, 1.0), (3, 5.5)]) - v3 = DenseVector([0.0, 1.0, 0.0, 5.5]) - v4 = SparseVector(4, [(1, 1.0), (3, 2.5)]) - self.assertEqual(hash(v1), hash(v2)) - self.assertEqual(hash(v1), hash(v3)) - self.assertEqual(hash(v2), hash(v3)) - self.assertFalse(hash(v1) == hash(v4)) - self.assertFalse(hash(v2) == hash(v4)) - - def 
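
The `_eventually` helper removed above is the polling pattern the streaming MLlib tests depended on: keep re-evaluating a condition until it returns True or a timeout expires, optionally swallowing assertion errors until the deadline. For reference, an equivalent standalone sketch of the same idea; the names here are illustrative and not part of PySpark.

```python
# Standalone sketch of the polling pattern implemented by the removed
# _eventually helper. Names are illustrative, not part of PySpark.
import time


def wait_until(condition, timeout=30.0, interval=0.01):
    """Poll `condition` until it returns True or `timeout` seconds elapse."""
    deadline = time.time() + timeout
    last = None
    while time.time() < deadline:
        last = condition()
        if last is True:
            return
        time.sleep(interval)
    raise AssertionError(
        "Condition not met within %g sec; last value: %s" % (timeout, last))
```
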
test_eq(self): - v1 = DenseVector([0.0, 1.0, 0.0, 5.5]) - v2 = SparseVector(4, [(1, 1.0), (3, 5.5)]) - v3 = DenseVector([0.0, 1.0, 0.0, 5.5]) - v4 = SparseVector(6, [(1, 1.0), (3, 5.5)]) - v5 = DenseVector([0.0, 1.0, 0.0, 2.5]) - v6 = SparseVector(4, [(1, 1.0), (3, 2.5)]) - self.assertEqual(v1, v2) - self.assertEqual(v1, v3) - self.assertFalse(v2 == v4) - self.assertFalse(v1 == v5) - self.assertFalse(v1 == v6) - - def test_equals(self): - indices = [1, 2, 4] - values = [1., 3., 2.] - self.assertTrue(Vectors._equals(indices, values, list(range(5)), [0., 1., 3., 0., 2.])) - self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 3., 1., 0., 2.])) - self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 3., 0., 2.])) - self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 1., 3., 2., 2.])) - - def test_conversion(self): - # numpy arrays should be automatically upcast to float64 - # tests for fix of [SPARK-5089] - v = array([1, 2, 3, 4], dtype='float64') - dv = DenseVector(v) - self.assertTrue(dv.array.dtype == 'float64') - v = array([1, 2, 3, 4], dtype='float32') - dv = DenseVector(v) - self.assertTrue(dv.array.dtype == 'float64') - - def test_sparse_vector_indexing(self): - sv = SparseVector(5, {1: 1, 3: 2}) - self.assertEqual(sv[0], 0.) - self.assertEqual(sv[3], 2.) - self.assertEqual(sv[1], 1.) - self.assertEqual(sv[2], 0.) - self.assertEqual(sv[4], 0.) - self.assertEqual(sv[-1], 0.) - self.assertEqual(sv[-2], 2.) - self.assertEqual(sv[-3], 0.) - self.assertEqual(sv[-5], 0.) - for ind in [5, -6]: - self.assertRaises(IndexError, sv.__getitem__, ind) - for ind in [7.8, '1']: - self.assertRaises(TypeError, sv.__getitem__, ind) - - zeros = SparseVector(4, {}) - self.assertEqual(zeros[0], 0.0) - self.assertEqual(zeros[3], 0.0) - for ind in [4, -5]: - self.assertRaises(IndexError, zeros.__getitem__, ind) - - empty = SparseVector(0, {}) - for ind in [-1, 0, 1]: - self.assertRaises(IndexError, empty.__getitem__, ind) - - def test_sparse_vector_iteration(self): - self.assertListEqual(list(SparseVector(3, [], [])), [0.0, 0.0, 0.0]) - self.assertListEqual(list(SparseVector(5, [0, 3], [1.0, 2.0])), [1.0, 0.0, 0.0, 2.0, 0.0]) - - def test_matrix_indexing(self): - mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) - expected = [[0, 6], [1, 8], [4, 10]] - for i in range(3): - for j in range(2): - self.assertEqual(mat[i, j], expected[i][j]) - - for i, j in [(-1, 0), (4, 1), (3, 4)]: - self.assertRaises(IndexError, mat.__getitem__, (i, j)) - - def test_repr_dense_matrix(self): - mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) - self.assertTrue( - repr(mat), - 'DenseMatrix(3, 2, [0.0, 1.0, 4.0, 6.0, 8.0, 10.0], False)') - - mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10], True) - self.assertTrue( - repr(mat), - 'DenseMatrix(3, 2, [0.0, 1.0, 4.0, 6.0, 8.0, 10.0], False)') - - mat = DenseMatrix(6, 3, zeros(18)) - self.assertTrue( - repr(mat), - 'DenseMatrix(6, 3, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..., \ - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], False)') - - def test_repr_sparse_matrix(self): - sm1t = SparseMatrix( - 3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], - isTransposed=True) - self.assertTrue( - repr(sm1t), - 'SparseMatrix(3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], True)') - - indices = tile(arange(6), 3) - values = ones(18) - sm = SparseMatrix(6, 3, [0, 6, 12, 18], indices, values) - self.assertTrue( - repr(sm), "SparseMatrix(6, 3, [0, 6, 12, 18], \ - [0, 1, 2, 3, 4, 5, 0, 1, ..., 4, 5, 0, 1, 2, 3, 4, 5], \ - [1.0, 1.0, 
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ..., \ - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], False)") - - self.assertTrue( - str(sm), - "6 X 3 CSCMatrix\n\ - (0,0) 1.0\n(1,0) 1.0\n(2,0) 1.0\n(3,0) 1.0\n(4,0) 1.0\n(5,0) 1.0\n\ - (0,1) 1.0\n(1,1) 1.0\n(2,1) 1.0\n(3,1) 1.0\n(4,1) 1.0\n(5,1) 1.0\n\ - (0,2) 1.0\n(1,2) 1.0\n(2,2) 1.0\n(3,2) 1.0\n..\n..") - - sm = SparseMatrix(1, 18, zeros(19), [], []) - self.assertTrue( - repr(sm), - 'SparseMatrix(1, 18, \ - [0, 0, 0, 0, 0, 0, 0, 0, ..., 0, 0, 0, 0, 0, 0, 0, 0], [], [], False)') - - def test_sparse_matrix(self): - # Test sparse matrix creation. - sm1 = SparseMatrix( - 3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0]) - self.assertEqual(sm1.numRows, 3) - self.assertEqual(sm1.numCols, 4) - self.assertEqual(sm1.colPtrs.tolist(), [0, 2, 2, 4, 4]) - self.assertEqual(sm1.rowIndices.tolist(), [1, 2, 1, 2]) - self.assertEqual(sm1.values.tolist(), [1.0, 2.0, 4.0, 5.0]) - self.assertTrue( - repr(sm1), - 'SparseMatrix(3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0], False)') - - # Test indexing - expected = [ - [0, 0, 0, 0], - [1, 0, 4, 0], - [2, 0, 5, 0]] - - for i in range(3): - for j in range(4): - self.assertEqual(expected[i][j], sm1[i, j]) - self.assertTrue(array_equal(sm1.toArray(), expected)) - - for i, j in [(-1, 1), (4, 3), (3, 5)]: - self.assertRaises(IndexError, sm1.__getitem__, (i, j)) - - # Test conversion to dense and sparse. - smnew = sm1.toDense().toSparse() - self.assertEqual(sm1.numRows, smnew.numRows) - self.assertEqual(sm1.numCols, smnew.numCols) - self.assertTrue(array_equal(sm1.colPtrs, smnew.colPtrs)) - self.assertTrue(array_equal(sm1.rowIndices, smnew.rowIndices)) - self.assertTrue(array_equal(sm1.values, smnew.values)) - - sm1t = SparseMatrix( - 3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], - isTransposed=True) - self.assertEqual(sm1t.numRows, 3) - self.assertEqual(sm1t.numCols, 4) - self.assertEqual(sm1t.colPtrs.tolist(), [0, 2, 3, 5]) - self.assertEqual(sm1t.rowIndices.tolist(), [0, 1, 2, 0, 2]) - self.assertEqual(sm1t.values.tolist(), [3.0, 2.0, 4.0, 9.0, 8.0]) - - expected = [ - [3, 2, 0, 0], - [0, 0, 4, 0], - [9, 0, 8, 0]] - - for i in range(3): - for j in range(4): - self.assertEqual(expected[i][j], sm1t[i, j]) - self.assertTrue(array_equal(sm1t.toArray(), expected)) - - def test_dense_matrix_is_transposed(self): - mat1 = DenseMatrix(3, 2, [0, 4, 1, 6, 3, 9], isTransposed=True) - mat = DenseMatrix(3, 2, [0, 1, 3, 4, 6, 9]) - self.assertEqual(mat1, mat) - - expected = [[0, 4], [1, 6], [3, 9]] - for i in range(3): - for j in range(2): - self.assertEqual(mat1[i, j], expected[i][j]) - self.assertTrue(array_equal(mat1.toArray(), expected)) - - sm = mat1.toSparse() - self.assertTrue(array_equal(sm.rowIndices, [1, 2, 0, 1, 2])) - self.assertTrue(array_equal(sm.colPtrs, [0, 2, 5])) - self.assertTrue(array_equal(sm.values, [1, 3, 4, 6, 9])) - - def test_parse_vector(self): - a = DenseVector([]) - self.assertEqual(str(a), '[]') - self.assertEqual(Vectors.parse(str(a)), a) - a = DenseVector([3, 4, 6, 7]) - self.assertEqual(str(a), '[3.0,4.0,6.0,7.0]') - self.assertEqual(Vectors.parse(str(a)), a) - a = SparseVector(4, [], []) - self.assertEqual(str(a), '(4,[],[])') - self.assertEqual(SparseVector.parse(str(a)), a) - a = SparseVector(4, [0, 2], [3, 4]) - self.assertEqual(str(a), '(4,[0,2],[3.0,4.0])') - self.assertEqual(Vectors.parse(str(a)), a) - a = SparseVector(10, [0, 1], [4, 5]) - self.assertEqual(SparseVector.parse(' (10, [0,1 ],[ 4.0,5.0] )'), a) - - def test_norms(self): - a = DenseVector([0, 2, 3, -1]) - 
self.assertAlmostEqual(a.norm(2), 3.742, 3) - self.assertTrue(a.norm(1), 6) - self.assertTrue(a.norm(inf), 3) - a = SparseVector(4, [0, 2], [3, -4]) - self.assertAlmostEqual(a.norm(2), 5) - self.assertTrue(a.norm(1), 7) - self.assertTrue(a.norm(inf), 4) - - tmp = SparseVector(4, [0, 2], [3, 0]) - self.assertEqual(tmp.numNonzeros(), 1) - - def test_ml_mllib_vector_conversion(self): - # to ml - # dense - mllibDV = Vectors.dense([1, 2, 3]) - mlDV1 = newlinalg.Vectors.dense([1, 2, 3]) - mlDV2 = mllibDV.asML() - self.assertEqual(mlDV2, mlDV1) - # sparse - mllibSV = Vectors.sparse(4, {1: 1.0, 3: 5.5}) - mlSV1 = newlinalg.Vectors.sparse(4, {1: 1.0, 3: 5.5}) - mlSV2 = mllibSV.asML() - self.assertEqual(mlSV2, mlSV1) - # from ml - # dense - mllibDV1 = Vectors.dense([1, 2, 3]) - mlDV = newlinalg.Vectors.dense([1, 2, 3]) - mllibDV2 = Vectors.fromML(mlDV) - self.assertEqual(mllibDV1, mllibDV2) - # sparse - mllibSV1 = Vectors.sparse(4, {1: 1.0, 3: 5.5}) - mlSV = newlinalg.Vectors.sparse(4, {1: 1.0, 3: 5.5}) - mllibSV2 = Vectors.fromML(mlSV) - self.assertEqual(mllibSV1, mllibSV2) - - def test_ml_mllib_matrix_conversion(self): - # to ml - # dense - mllibDM = Matrices.dense(2, 2, [0, 1, 2, 3]) - mlDM1 = newlinalg.Matrices.dense(2, 2, [0, 1, 2, 3]) - mlDM2 = mllibDM.asML() - self.assertEqual(mlDM2, mlDM1) - # transposed - mllibDMt = DenseMatrix(2, 2, [0, 1, 2, 3], True) - mlDMt1 = newlinalg.DenseMatrix(2, 2, [0, 1, 2, 3], True) - mlDMt2 = mllibDMt.asML() - self.assertEqual(mlDMt2, mlDMt1) - # sparse - mllibSM = Matrices.sparse(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4]) - mlSM1 = newlinalg.Matrices.sparse(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4]) - mlSM2 = mllibSM.asML() - self.assertEqual(mlSM2, mlSM1) - # transposed - mllibSMt = SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4], True) - mlSMt1 = newlinalg.SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4], True) - mlSMt2 = mllibSMt.asML() - self.assertEqual(mlSMt2, mlSMt1) - # from ml - # dense - mllibDM1 = Matrices.dense(2, 2, [1, 2, 3, 4]) - mlDM = newlinalg.Matrices.dense(2, 2, [1, 2, 3, 4]) - mllibDM2 = Matrices.fromML(mlDM) - self.assertEqual(mllibDM1, mllibDM2) - # transposed - mllibDMt1 = DenseMatrix(2, 2, [1, 2, 3, 4], True) - mlDMt = newlinalg.DenseMatrix(2, 2, [1, 2, 3, 4], True) - mllibDMt2 = Matrices.fromML(mlDMt) - self.assertEqual(mllibDMt1, mllibDMt2) - # sparse - mllibSM1 = Matrices.sparse(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4]) - mlSM = newlinalg.Matrices.sparse(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4]) - mllibSM2 = Matrices.fromML(mlSM) - self.assertEqual(mllibSM1, mllibSM2) - # transposed - mllibSMt1 = SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4], True) - mlSMt = newlinalg.SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4], True) - mllibSMt2 = Matrices.fromML(mlSMt) - self.assertEqual(mllibSMt1, mllibSMt2) - - -class ListTests(MLlibTestCase): - - """ - Test MLlib algorithms on plain lists, to make sure they're passed through - as NumPy arrays. 
- """ - - def test_bisecting_kmeans(self): - from pyspark.mllib.clustering import BisectingKMeans - data = array([0.0, 0.0, 1.0, 1.0, 9.0, 8.0, 8.0, 9.0]).reshape(4, 2) - bskm = BisectingKMeans() - model = bskm.train(self.sc.parallelize(data, 2), k=4) - p = array([0.0, 0.0]) - rdd_p = self.sc.parallelize([p]) - self.assertEqual(model.predict(p), model.predict(rdd_p).first()) - self.assertEqual(model.computeCost(p), model.computeCost(rdd_p)) - self.assertEqual(model.k, len(model.clusterCenters)) - - def test_kmeans(self): - from pyspark.mllib.clustering import KMeans - data = [ - [0, 1.1], - [0, 1.2], - [1.1, 0], - [1.2, 0], - ] - clusters = KMeans.train(self.sc.parallelize(data), 2, initializationMode="k-means||", - initializationSteps=7, epsilon=1e-4) - self.assertEqual(clusters.predict(data[0]), clusters.predict(data[1])) - self.assertEqual(clusters.predict(data[2]), clusters.predict(data[3])) - - def test_kmeans_deterministic(self): - from pyspark.mllib.clustering import KMeans - X = range(0, 100, 10) - Y = range(0, 100, 10) - data = [[x, y] for x, y in zip(X, Y)] - clusters1 = KMeans.train(self.sc.parallelize(data), - 3, initializationMode="k-means||", - seed=42, initializationSteps=7, epsilon=1e-4) - clusters2 = KMeans.train(self.sc.parallelize(data), - 3, initializationMode="k-means||", - seed=42, initializationSteps=7, epsilon=1e-4) - centers1 = clusters1.centers - centers2 = clusters2.centers - for c1, c2 in zip(centers1, centers2): - # TODO: Allow small numeric difference. - self.assertTrue(array_equal(c1, c2)) - - def test_gmm(self): - from pyspark.mllib.clustering import GaussianMixture - data = self.sc.parallelize([ - [1, 2], - [8, 9], - [-4, -3], - [-6, -7], - ]) - clusters = GaussianMixture.train(data, 2, convergenceTol=0.001, - maxIterations=10, seed=1) - labels = clusters.predict(data).collect() - self.assertEqual(labels[0], labels[1]) - self.assertEqual(labels[2], labels[3]) - - def test_gmm_deterministic(self): - from pyspark.mllib.clustering import GaussianMixture - x = range(0, 100, 10) - y = range(0, 100, 10) - data = self.sc.parallelize([[a, b] for a, b in zip(x, y)]) - clusters1 = GaussianMixture.train(data, 5, convergenceTol=0.001, - maxIterations=10, seed=63) - clusters2 = GaussianMixture.train(data, 5, convergenceTol=0.001, - maxIterations=10, seed=63) - for c1, c2 in zip(clusters1.weights, clusters2.weights): - self.assertEqual(round(c1, 7), round(c2, 7)) - - def test_gmm_with_initial_model(self): - from pyspark.mllib.clustering import GaussianMixture - data = self.sc.parallelize([ - (-10, -5), (-9, -4), (10, 5), (9, 4) - ]) - - gmm1 = GaussianMixture.train(data, 2, convergenceTol=0.001, - maxIterations=10, seed=63) - gmm2 = GaussianMixture.train(data, 2, convergenceTol=0.001, - maxIterations=10, seed=63, initialModel=gmm1) - self.assertAlmostEqual((gmm1.weights - gmm2.weights).sum(), 0.0) - - def test_classification(self): - from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes - from pyspark.mllib.tree import DecisionTree, DecisionTreeModel, RandomForest,\ - RandomForestModel, GradientBoostedTrees, GradientBoostedTreesModel - data = [ - LabeledPoint(0.0, [1, 0, 0]), - LabeledPoint(1.0, [0, 1, 1]), - LabeledPoint(0.0, [2, 0, 0]), - LabeledPoint(1.0, [0, 2, 1]) - ] - rdd = self.sc.parallelize(data) - features = [p.features.tolist() for p in data] - - temp_dir = tempfile.mkdtemp() - - lr_model = LogisticRegressionWithSGD.train(rdd, iterations=10) - self.assertTrue(lr_model.predict(features[0]) <= 0) - 
self.assertTrue(lr_model.predict(features[1]) > 0) - self.assertTrue(lr_model.predict(features[2]) <= 0) - self.assertTrue(lr_model.predict(features[3]) > 0) - - svm_model = SVMWithSGD.train(rdd, iterations=10) - self.assertTrue(svm_model.predict(features[0]) <= 0) - self.assertTrue(svm_model.predict(features[1]) > 0) - self.assertTrue(svm_model.predict(features[2]) <= 0) - self.assertTrue(svm_model.predict(features[3]) > 0) - - nb_model = NaiveBayes.train(rdd) - self.assertTrue(nb_model.predict(features[0]) <= 0) - self.assertTrue(nb_model.predict(features[1]) > 0) - self.assertTrue(nb_model.predict(features[2]) <= 0) - self.assertTrue(nb_model.predict(features[3]) > 0) - - categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories - dt_model = DecisionTree.trainClassifier( - rdd, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo, maxBins=4) - self.assertTrue(dt_model.predict(features[0]) <= 0) - self.assertTrue(dt_model.predict(features[1]) > 0) - self.assertTrue(dt_model.predict(features[2]) <= 0) - self.assertTrue(dt_model.predict(features[3]) > 0) - - dt_model_dir = os.path.join(temp_dir, "dt") - dt_model.save(self.sc, dt_model_dir) - same_dt_model = DecisionTreeModel.load(self.sc, dt_model_dir) - self.assertEqual(same_dt_model.toDebugString(), dt_model.toDebugString()) - - rf_model = RandomForest.trainClassifier( - rdd, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo, numTrees=10, - maxBins=4, seed=1) - self.assertTrue(rf_model.predict(features[0]) <= 0) - self.assertTrue(rf_model.predict(features[1]) > 0) - self.assertTrue(rf_model.predict(features[2]) <= 0) - self.assertTrue(rf_model.predict(features[3]) > 0) - - rf_model_dir = os.path.join(temp_dir, "rf") - rf_model.save(self.sc, rf_model_dir) - same_rf_model = RandomForestModel.load(self.sc, rf_model_dir) - self.assertEqual(same_rf_model.toDebugString(), rf_model.toDebugString()) - - gbt_model = GradientBoostedTrees.trainClassifier( - rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4) - self.assertTrue(gbt_model.predict(features[0]) <= 0) - self.assertTrue(gbt_model.predict(features[1]) > 0) - self.assertTrue(gbt_model.predict(features[2]) <= 0) - self.assertTrue(gbt_model.predict(features[3]) > 0) - - gbt_model_dir = os.path.join(temp_dir, "gbt") - gbt_model.save(self.sc, gbt_model_dir) - same_gbt_model = GradientBoostedTreesModel.load(self.sc, gbt_model_dir) - self.assertEqual(same_gbt_model.toDebugString(), gbt_model.toDebugString()) - - try: - rmtree(temp_dir) - except OSError: - pass - - def test_regression(self): - from pyspark.mllib.regression import LinearRegressionWithSGD, LassoWithSGD, \ - RidgeRegressionWithSGD - from pyspark.mllib.tree import DecisionTree, RandomForest, GradientBoostedTrees - data = [ - LabeledPoint(-1.0, [0, -1]), - LabeledPoint(1.0, [0, 1]), - LabeledPoint(-1.0, [0, -2]), - LabeledPoint(1.0, [0, 2]) - ] - rdd = self.sc.parallelize(data) - features = [p.features.tolist() for p in data] - - lr_model = LinearRegressionWithSGD.train(rdd, iterations=10) - self.assertTrue(lr_model.predict(features[0]) <= 0) - self.assertTrue(lr_model.predict(features[1]) > 0) - self.assertTrue(lr_model.predict(features[2]) <= 0) - self.assertTrue(lr_model.predict(features[3]) > 0) - - lasso_model = LassoWithSGD.train(rdd, iterations=10) - self.assertTrue(lasso_model.predict(features[0]) <= 0) - self.assertTrue(lasso_model.predict(features[1]) > 0) - self.assertTrue(lasso_model.predict(features[2]) <= 0) - self.assertTrue(lasso_model.predict(features[3]) > 0) - - rr_model = 
RidgeRegressionWithSGD.train(rdd, iterations=10) - self.assertTrue(rr_model.predict(features[0]) <= 0) - self.assertTrue(rr_model.predict(features[1]) > 0) - self.assertTrue(rr_model.predict(features[2]) <= 0) - self.assertTrue(rr_model.predict(features[3]) > 0) - - categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories - dt_model = DecisionTree.trainRegressor( - rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, maxBins=4) - self.assertTrue(dt_model.predict(features[0]) <= 0) - self.assertTrue(dt_model.predict(features[1]) > 0) - self.assertTrue(dt_model.predict(features[2]) <= 0) - self.assertTrue(dt_model.predict(features[3]) > 0) - - rf_model = RandomForest.trainRegressor( - rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numTrees=10, maxBins=4, seed=1) - self.assertTrue(rf_model.predict(features[0]) <= 0) - self.assertTrue(rf_model.predict(features[1]) > 0) - self.assertTrue(rf_model.predict(features[2]) <= 0) - self.assertTrue(rf_model.predict(features[3]) > 0) - - gbt_model = GradientBoostedTrees.trainRegressor( - rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4) - self.assertTrue(gbt_model.predict(features[0]) <= 0) - self.assertTrue(gbt_model.predict(features[1]) > 0) - self.assertTrue(gbt_model.predict(features[2]) <= 0) - self.assertTrue(gbt_model.predict(features[3]) > 0) - - try: - LinearRegressionWithSGD.train(rdd, initialWeights=array([1.0, 1.0]), iterations=10) - LassoWithSGD.train(rdd, initialWeights=array([1.0, 1.0]), iterations=10) - RidgeRegressionWithSGD.train(rdd, initialWeights=array([1.0, 1.0]), iterations=10) - except ValueError: - self.fail() - - # Verify that maxBins is being passed through - GradientBoostedTrees.trainRegressor( - rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4, maxBins=32) - with self.assertRaises(Exception) as cm: - GradientBoostedTrees.trainRegressor( - rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4, maxBins=1) - - -class StatTests(MLlibTestCase): - # SPARK-4023 - def test_col_with_different_rdds(self): - # numpy - data = RandomRDDs.normalVectorRDD(self.sc, 1000, 10, 10) - summary = Statistics.colStats(data) - self.assertEqual(1000, summary.count()) - # array - data = self.sc.parallelize([range(10)] * 10) - summary = Statistics.colStats(data) - self.assertEqual(10, summary.count()) - # array - data = self.sc.parallelize([pyarray.array("d", range(10))] * 10) - summary = Statistics.colStats(data) - self.assertEqual(10, summary.count()) - - def test_col_norms(self): - data = RandomRDDs.normalVectorRDD(self.sc, 1000, 10, 10) - summary = Statistics.colStats(data) - self.assertEqual(10, len(summary.normL1())) - self.assertEqual(10, len(summary.normL2())) - - data2 = self.sc.parallelize(range(10)).map(lambda x: Vectors.dense(x)) - summary2 = Statistics.colStats(data2) - self.assertEqual(array([45.0]), summary2.normL1()) - import math - expectedNormL2 = math.sqrt(sum(map(lambda x: x*x, range(10)))) - self.assertTrue(math.fabs(summary2.normL2()[0] - expectedNormL2) < 1e-14) - - -class VectorUDTTests(MLlibTestCase): - - dv0 = DenseVector([]) - dv1 = DenseVector([1.0, 2.0]) - sv0 = SparseVector(2, [], []) - sv1 = SparseVector(2, [1], [2.0]) - udt = VectorUDT() - - def test_json_schema(self): - self.assertEqual(VectorUDT.fromJson(self.udt.jsonValue()), self.udt) - - def test_serialization(self): - for v in [self.dv0, self.dv1, self.sv0, self.sv1]: - self.assertEqual(v, self.udt.deserialize(self.udt.serialize(v))) - - def test_infer_schema(self): - rdd = 
self.sc.parallelize([LabeledPoint(1.0, self.dv1), LabeledPoint(0.0, self.sv1)]) - df = rdd.toDF() - schema = df.schema - field = [f for f in schema.fields if f.name == "features"][0] - self.assertEqual(field.dataType, self.udt) - vectors = df.rdd.map(lambda p: p.features).collect() - self.assertEqual(len(vectors), 2) - for v in vectors: - if isinstance(v, SparseVector): - self.assertEqual(v, self.sv1) - elif isinstance(v, DenseVector): - self.assertEqual(v, self.dv1) - else: - raise TypeError("expecting a vector but got %r of type %r" % (v, type(v))) - - -class MatrixUDTTests(MLlibTestCase): - - dm1 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10]) - dm2 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10], isTransposed=True) - sm1 = SparseMatrix(1, 1, [0, 1], [0], [2.0]) - sm2 = SparseMatrix(2, 1, [0, 0, 1], [0], [5.0], isTransposed=True) - udt = MatrixUDT() - - def test_json_schema(self): - self.assertEqual(MatrixUDT.fromJson(self.udt.jsonValue()), self.udt) - - def test_serialization(self): - for m in [self.dm1, self.dm2, self.sm1, self.sm2]: - self.assertEqual(m, self.udt.deserialize(self.udt.serialize(m))) - - def test_infer_schema(self): - rdd = self.sc.parallelize([("dense", self.dm1), ("sparse", self.sm1)]) - df = rdd.toDF() - schema = df.schema - self.assertTrue(schema.fields[1].dataType, self.udt) - matrices = df.rdd.map(lambda x: x._2).collect() - self.assertEqual(len(matrices), 2) - for m in matrices: - if isinstance(m, DenseMatrix): - self.assertTrue(m, self.dm1) - elif isinstance(m, SparseMatrix): - self.assertTrue(m, self.sm1) - else: - raise ValueError("Expected a matrix but got type %r" % type(m)) - - -@unittest.skipIf(not _have_scipy, "SciPy not installed") -class SciPyTests(MLlibTestCase): - - """ - Test both vector operations and MLlib algorithms with SciPy sparse matrices, - if SciPy is available. 
- """ - - def test_serialize(self): - from scipy.sparse import lil_matrix - lil = lil_matrix((4, 1)) - lil[1, 0] = 1 - lil[3, 0] = 2 - sv = SparseVector(4, {1: 1, 3: 2}) - self.assertEqual(sv, _convert_to_vector(lil)) - self.assertEqual(sv, _convert_to_vector(lil.tocsc())) - self.assertEqual(sv, _convert_to_vector(lil.tocoo())) - self.assertEqual(sv, _convert_to_vector(lil.tocsr())) - self.assertEqual(sv, _convert_to_vector(lil.todok())) - - def serialize(l): - return ser.loads(ser.dumps(_convert_to_vector(l))) - self.assertEqual(sv, serialize(lil)) - self.assertEqual(sv, serialize(lil.tocsc())) - self.assertEqual(sv, serialize(lil.tocsr())) - self.assertEqual(sv, serialize(lil.todok())) - - def test_convert_to_vector(self): - from scipy.sparse import csc_matrix - # Create a CSC matrix with non-sorted indices - indptr = array([0, 2]) - indices = array([3, 1]) - data = array([2.0, 1.0]) - csc = csc_matrix((data, indices, indptr)) - self.assertFalse(csc.has_sorted_indices) - sv = SparseVector(4, {1: 1, 3: 2}) - self.assertEqual(sv, _convert_to_vector(csc)) - - def test_dot(self): - from scipy.sparse import lil_matrix - lil = lil_matrix((4, 1)) - lil[1, 0] = 1 - lil[3, 0] = 2 - dv = DenseVector(array([1., 2., 3., 4.])) - self.assertEqual(10.0, dv.dot(lil)) - - def test_squared_distance(self): - from scipy.sparse import lil_matrix - lil = lil_matrix((4, 1)) - lil[1, 0] = 3 - lil[3, 0] = 2 - dv = DenseVector(array([1., 2., 3., 4.])) - sv = SparseVector(4, {0: 1, 1: 2, 2: 3, 3: 4}) - self.assertEqual(15.0, dv.squared_distance(lil)) - self.assertEqual(15.0, sv.squared_distance(lil)) - - def scipy_matrix(self, size, values): - """Create a column SciPy matrix from a dictionary of values""" - from scipy.sparse import lil_matrix - lil = lil_matrix((size, 1)) - for key, value in values.items(): - lil[key, 0] = value - return lil - - def test_clustering(self): - from pyspark.mllib.clustering import KMeans - data = [ - self.scipy_matrix(3, {1: 1.0}), - self.scipy_matrix(3, {1: 1.1}), - self.scipy_matrix(3, {2: 1.0}), - self.scipy_matrix(3, {2: 1.1}) - ] - clusters = KMeans.train(self.sc.parallelize(data), 2, initializationMode="k-means||") - self.assertEqual(clusters.predict(data[0]), clusters.predict(data[1])) - self.assertEqual(clusters.predict(data[2]), clusters.predict(data[3])) - - def test_classification(self): - from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes - from pyspark.mllib.tree import DecisionTree - data = [ - LabeledPoint(0.0, self.scipy_matrix(2, {0: 1.0})), - LabeledPoint(1.0, self.scipy_matrix(2, {1: 1.0})), - LabeledPoint(0.0, self.scipy_matrix(2, {0: 2.0})), - LabeledPoint(1.0, self.scipy_matrix(2, {1: 2.0})) - ] - rdd = self.sc.parallelize(data) - features = [p.features for p in data] - - lr_model = LogisticRegressionWithSGD.train(rdd) - self.assertTrue(lr_model.predict(features[0]) <= 0) - self.assertTrue(lr_model.predict(features[1]) > 0) - self.assertTrue(lr_model.predict(features[2]) <= 0) - self.assertTrue(lr_model.predict(features[3]) > 0) - - svm_model = SVMWithSGD.train(rdd) - self.assertTrue(svm_model.predict(features[0]) <= 0) - self.assertTrue(svm_model.predict(features[1]) > 0) - self.assertTrue(svm_model.predict(features[2]) <= 0) - self.assertTrue(svm_model.predict(features[3]) > 0) - - nb_model = NaiveBayes.train(rdd) - self.assertTrue(nb_model.predict(features[0]) <= 0) - self.assertTrue(nb_model.predict(features[1]) > 0) - self.assertTrue(nb_model.predict(features[2]) <= 0) - self.assertTrue(nb_model.predict(features[3]) > 
0) - - categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories - dt_model = DecisionTree.trainClassifier(rdd, numClasses=2, - categoricalFeaturesInfo=categoricalFeaturesInfo) - self.assertTrue(dt_model.predict(features[0]) <= 0) - self.assertTrue(dt_model.predict(features[1]) > 0) - self.assertTrue(dt_model.predict(features[2]) <= 0) - self.assertTrue(dt_model.predict(features[3]) > 0) - - def test_regression(self): - from pyspark.mllib.regression import LinearRegressionWithSGD, LassoWithSGD, \ - RidgeRegressionWithSGD - from pyspark.mllib.tree import DecisionTree - data = [ - LabeledPoint(-1.0, self.scipy_matrix(2, {1: -1.0})), - LabeledPoint(1.0, self.scipy_matrix(2, {1: 1.0})), - LabeledPoint(-1.0, self.scipy_matrix(2, {1: -2.0})), - LabeledPoint(1.0, self.scipy_matrix(2, {1: 2.0})) - ] - rdd = self.sc.parallelize(data) - features = [p.features for p in data] - - lr_model = LinearRegressionWithSGD.train(rdd) - self.assertTrue(lr_model.predict(features[0]) <= 0) - self.assertTrue(lr_model.predict(features[1]) > 0) - self.assertTrue(lr_model.predict(features[2]) <= 0) - self.assertTrue(lr_model.predict(features[3]) > 0) - - lasso_model = LassoWithSGD.train(rdd) - self.assertTrue(lasso_model.predict(features[0]) <= 0) - self.assertTrue(lasso_model.predict(features[1]) > 0) - self.assertTrue(lasso_model.predict(features[2]) <= 0) - self.assertTrue(lasso_model.predict(features[3]) > 0) - - rr_model = RidgeRegressionWithSGD.train(rdd) - self.assertTrue(rr_model.predict(features[0]) <= 0) - self.assertTrue(rr_model.predict(features[1]) > 0) - self.assertTrue(rr_model.predict(features[2]) <= 0) - self.assertTrue(rr_model.predict(features[3]) > 0) - - categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories - dt_model = DecisionTree.trainRegressor(rdd, categoricalFeaturesInfo=categoricalFeaturesInfo) - self.assertTrue(dt_model.predict(features[0]) <= 0) - self.assertTrue(dt_model.predict(features[1]) > 0) - self.assertTrue(dt_model.predict(features[2]) <= 0) - self.assertTrue(dt_model.predict(features[3]) > 0) - - -class ChiSqTestTests(MLlibTestCase): - def test_goodness_of_fit(self): - from numpy import inf - - observed = Vectors.dense([4, 6, 5]) - pearson = Statistics.chiSqTest(observed) - - # Validated against the R command `chisq.test(c(4, 6, 5), p=c(1/3, 1/3, 1/3))` - self.assertEqual(pearson.statistic, 0.4) - self.assertEqual(pearson.degreesOfFreedom, 2) - self.assertAlmostEqual(pearson.pValue, 0.8187, 4) - - # Different expected and observed sum - observed1 = Vectors.dense([21, 38, 43, 80]) - expected1 = Vectors.dense([3, 5, 7, 20]) - pearson1 = Statistics.chiSqTest(observed1, expected1) - - # Results validated against the R command - # `chisq.test(c(21, 38, 43, 80), p=c(3/35, 1/7, 1/5, 4/7))` - self.assertAlmostEqual(pearson1.statistic, 14.1429, 4) - self.assertEqual(pearson1.degreesOfFreedom, 3) - self.assertAlmostEqual(pearson1.pValue, 0.002717, 4) - - # Vectors with different sizes - observed3 = Vectors.dense([1.0, 2.0, 3.0]) - expected3 = Vectors.dense([1.0, 2.0, 3.0, 4.0]) - self.assertRaises(ValueError, Statistics.chiSqTest, observed3, expected3) - - # Negative counts in observed - neg_obs = Vectors.dense([1.0, 2.0, 3.0, -4.0]) - self.assertRaises(IllegalArgumentException, Statistics.chiSqTest, neg_obs, expected1) - - # Count = 0.0 in expected but not observed - zero_expected = Vectors.dense([1.0, 0.0, 3.0]) - pearson_inf = Statistics.chiSqTest(observed, zero_expected) - self.assertEqual(pearson_inf.statistic, inf) - self.assertEqual(pearson_inf.degreesOfFreedom, 2) - 
self.assertEqual(pearson_inf.pValue, 0.0) - - # 0.0 in expected and observed simultaneously - zero_observed = Vectors.dense([2.0, 0.0, 1.0]) - self.assertRaises( - IllegalArgumentException, Statistics.chiSqTest, zero_observed, zero_expected) - - def test_matrix_independence(self): - data = [40.0, 24.0, 29.0, 56.0, 32.0, 42.0, 31.0, 10.0, 0.0, 30.0, 15.0, 12.0] - chi = Statistics.chiSqTest(Matrices.dense(3, 4, data)) - - # Results validated against R command - # `chisq.test(rbind(c(40, 56, 31, 30),c(24, 32, 10, 15), c(29, 42, 0, 12)))` - self.assertAlmostEqual(chi.statistic, 21.9958, 4) - self.assertEqual(chi.degreesOfFreedom, 6) - self.assertAlmostEqual(chi.pValue, 0.001213, 4) - - # Negative counts - neg_counts = Matrices.dense(2, 2, [4.0, 5.0, 3.0, -3.0]) - self.assertRaises(IllegalArgumentException, Statistics.chiSqTest, neg_counts) - - # Row sum = 0.0 - row_zero = Matrices.dense(2, 2, [0.0, 1.0, 0.0, 2.0]) - self.assertRaises(IllegalArgumentException, Statistics.chiSqTest, row_zero) - - # Column sum = 0.0 - col_zero = Matrices.dense(2, 2, [0.0, 0.0, 2.0, 2.0]) - self.assertRaises(IllegalArgumentException, Statistics.chiSqTest, col_zero) - - def test_chi_sq_pearson(self): - data = [ - LabeledPoint(0.0, Vectors.dense([0.5, 10.0])), - LabeledPoint(0.0, Vectors.dense([1.5, 20.0])), - LabeledPoint(1.0, Vectors.dense([1.5, 30.0])), - LabeledPoint(0.0, Vectors.dense([3.5, 30.0])), - LabeledPoint(0.0, Vectors.dense([3.5, 40.0])), - LabeledPoint(1.0, Vectors.dense([3.5, 40.0])) - ] - - for numParts in [2, 4, 6, 8]: - chi = Statistics.chiSqTest(self.sc.parallelize(data, numParts)) - feature1 = chi[0] - self.assertEqual(feature1.statistic, 0.75) - self.assertEqual(feature1.degreesOfFreedom, 2) - self.assertAlmostEqual(feature1.pValue, 0.6873, 4) - - feature2 = chi[1] - self.assertEqual(feature2.statistic, 1.5) - self.assertEqual(feature2.degreesOfFreedom, 3) - self.assertAlmostEqual(feature2.pValue, 0.6823, 4) - - def test_right_number_of_results(self): - num_cols = 1001 - sparse_data = [ - LabeledPoint(0.0, Vectors.sparse(num_cols, [(100, 2.0)])), - LabeledPoint(0.1, Vectors.sparse(num_cols, [(200, 1.0)])) - ] - chi = Statistics.chiSqTest(self.sc.parallelize(sparse_data)) - self.assertEqual(len(chi), num_cols) - self.assertIsNotNone(chi[1000]) - - -class KolmogorovSmirnovTest(MLlibTestCase): - - def test_R_implementation_equivalence(self): - data = self.sc.parallelize([ - 1.1626852897838, -0.585924465893051, 1.78546500331661, -1.33259371048501, - -0.446566766553219, 0.569606122374976, -2.88971761441412, -0.869018343326555, - -0.461702683149641, -0.555540910137444, -0.0201353678515895, -0.150382224136063, - -0.628126755843964, 1.32322085193283, -1.52135057001199, -0.437427868856691, - 0.970577579543399, 0.0282226444247749, -0.0857821886527593, 0.389214404984942 - ]) - model = Statistics.kolmogorovSmirnovTest(data, "norm") - self.assertAlmostEqual(model.statistic, 0.189, 3) - self.assertAlmostEqual(model.pValue, 0.422, 3) - - model = Statistics.kolmogorovSmirnovTest(data, "norm", 0, 1) - self.assertAlmostEqual(model.statistic, 0.189, 3) - self.assertAlmostEqual(model.pValue, 0.422, 3) - - -class SerDeTest(MLlibTestCase): - def test_to_java_object_rdd(self): # SPARK-6660 - data = RandomRDDs.uniformRDD(self.sc, 10, 5, seed=0) - self.assertEqual(_to_java_object_rdd(data).count(), 10) - - -class FeatureTest(MLlibTestCase): - def test_idf_model(self): - data = [ - Vectors.dense([1, 2, 6, 0, 2, 3, 1, 1, 0, 0, 3]), - Vectors.dense([1, 3, 0, 1, 3, 0, 0, 2, 0, 0, 1]), - Vectors.dense([1, 4, 1, 0, 0, 4, 9, 
0, 1, 2, 0]), - Vectors.dense([2, 1, 0, 3, 0, 0, 5, 0, 2, 3, 9]) - ] - model = IDF().fit(self.sc.parallelize(data, 2)) - idf = model.idf() - self.assertEqual(len(idf), 11) - - -class Word2VecTests(MLlibTestCase): - def test_word2vec_setters(self): - model = Word2Vec() \ - .setVectorSize(2) \ - .setLearningRate(0.01) \ - .setNumPartitions(2) \ - .setNumIterations(10) \ - .setSeed(1024) \ - .setMinCount(3) \ - .setWindowSize(6) - self.assertEqual(model.vectorSize, 2) - self.assertTrue(model.learningRate < 0.02) - self.assertEqual(model.numPartitions, 2) - self.assertEqual(model.numIterations, 10) - self.assertEqual(model.seed, 1024) - self.assertEqual(model.minCount, 3) - self.assertEqual(model.windowSize, 6) - - def test_word2vec_get_vectors(self): - data = [ - ["a", "b", "c", "d", "e", "f", "g"], - ["a", "b", "c", "d", "e", "f"], - ["a", "b", "c", "d", "e"], - ["a", "b", "c", "d"], - ["a", "b", "c"], - ["a", "b"], - ["a"] - ] - model = Word2Vec().fit(self.sc.parallelize(data)) - self.assertEqual(len(model.getVectors()), 3) - - -class StandardScalerTests(MLlibTestCase): - def test_model_setters(self): - data = [ - [1.0, 2.0, 3.0], - [2.0, 3.0, 4.0], - [3.0, 4.0, 5.0] - ] - model = StandardScaler().fit(self.sc.parallelize(data)) - self.assertIsNotNone(model.setWithMean(True)) - self.assertIsNotNone(model.setWithStd(True)) - self.assertEqual(model.transform([1.0, 2.0, 3.0]), DenseVector([-1.0, -1.0, -1.0])) - - def test_model_transform(self): - data = [ - [1.0, 2.0, 3.0], - [2.0, 3.0, 4.0], - [3.0, 4.0, 5.0] - ] - model = StandardScaler().fit(self.sc.parallelize(data)) - self.assertEqual(model.transform([1.0, 2.0, 3.0]), DenseVector([1.0, 2.0, 3.0])) - - -class ElementwiseProductTests(MLlibTestCase): - def test_model_transform(self): - weight = Vectors.dense([3, 2, 1]) - - densevec = Vectors.dense([4, 5, 6]) - sparsevec = Vectors.sparse(3, [0], [1]) - eprod = ElementwiseProduct(weight) - self.assertEqual(eprod.transform(densevec), DenseVector([12, 10, 6])) - self.assertEqual( - eprod.transform(sparsevec), SparseVector(3, [0], [3])) - - -class StreamingKMeansTest(MLLibStreamingTestCase): - def test_model_params(self): - """Test that the model params are set correctly""" - stkm = StreamingKMeans() - stkm.setK(5).setDecayFactor(0.0) - self.assertEqual(stkm._k, 5) - self.assertEqual(stkm._decayFactor, 0.0) - - # Model not set yet. 
- self.assertIsNone(stkm.latestModel()) - self.assertRaises(ValueError, stkm.trainOn, [0.0, 1.0]) - - stkm.setInitialCenters( - centers=[[0.0, 0.0], [1.0, 1.0]], weights=[1.0, 1.0]) - self.assertEqual( - stkm.latestModel().centers, [[0.0, 0.0], [1.0, 1.0]]) - self.assertEqual(stkm.latestModel().clusterWeights, [1.0, 1.0]) - - def test_accuracy_for_single_center(self): - """Test that parameters obtained are correct for a single center.""" - centers, batches = self.streamingKMeansDataGenerator( - batches=5, numPoints=5, k=1, d=5, r=0.1, seed=0) - stkm = StreamingKMeans(1) - stkm.setInitialCenters([[0., 0., 0., 0., 0.]], [0.]) - input_stream = self.ssc.queueStream( - [self.sc.parallelize(batch, 1) for batch in batches]) - stkm.trainOn(input_stream) - - self.ssc.start() - - def condition(): - self.assertEqual(stkm.latestModel().clusterWeights, [25.0]) - return True - self._eventually(condition, catch_assertions=True) - - realCenters = array_sum(array(centers), axis=0) - for i in range(5): - modelCenters = stkm.latestModel().centers[0][i] - self.assertAlmostEqual(centers[0][i], modelCenters, 1) - self.assertAlmostEqual(realCenters[i], modelCenters, 1) - - def streamingKMeansDataGenerator(self, batches, numPoints, - k, d, r, seed, centers=None): - rng = random.RandomState(seed) - - # Generate centers. - centers = [rng.randn(d) for i in range(k)] - - return centers, [[Vectors.dense(centers[j % k] + r * rng.randn(d)) - for j in range(numPoints)] - for i in range(batches)] - - def test_trainOn_model(self): - """Test the model on toy data with four clusters.""" - stkm = StreamingKMeans() - initCenters = [[1.0, 1.0], [-1.0, 1.0], [-1.0, -1.0], [1.0, -1.0]] - stkm.setInitialCenters( - centers=initCenters, weights=[1.0, 1.0, 1.0, 1.0]) - - # Create a toy dataset by setting a tiny offset for each point. - offsets = [[0, 0.1], [0, -0.1], [0.1, 0], [-0.1, 0]] - batches = [] - for offset in offsets: - batches.append([[offset[0] + center[0], offset[1] + center[1]] - for center in initCenters]) - - batches = [self.sc.parallelize(batch, 1) for batch in batches] - input_stream = self.ssc.queueStream(batches) - stkm.trainOn(input_stream) - self.ssc.start() - - # Give enough time to train the model. 
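
# Editor's aside (illustrative sketch, not part of the patch): the expected
# clusterWeights of [5.0, 5.0, 5.0, 5.0] in the condition that follows come
# from the streaming k-means update rule. Per cluster, with decay factor a,
# previous weight n, previous center c, and a batch contributing m points
# with centroid x, the model is assumed to update roughly as
#     c' = (c * n * a + x * m) / (n * a + m)    and    n' = n * a + m.
# With the default a = 1.0, each of the four batches adds one point per
# cluster (1.0 + 4 = 5.0), and the symmetric +/-0.1 offsets cancel out, so
# the centers end up back at their initial values. A standalone NumPy form:
import numpy as np

def streaming_kmeans_update(center, weight, batch_points, decay=1.0):
    pts = np.asarray(batch_points, dtype=float)
    m = len(pts)                                  # points assigned this batch
    x = pts.mean(axis=0)                          # centroid of the new points
    new_weight = weight * decay + m
    new_center = (np.asarray(center, dtype=float) * weight * decay + x * m) / new_weight
    return new_center, new_weight
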
- def condition(): - finalModel = stkm.latestModel() - self.assertTrue(all(finalModel.centers == array(initCenters))) - self.assertEqual(finalModel.clusterWeights, [5.0, 5.0, 5.0, 5.0]) - return True - self._eventually(condition, catch_assertions=True) - - def test_predictOn_model(self): - """Test that the model predicts correctly on toy data.""" - stkm = StreamingKMeans() - stkm._model = StreamingKMeansModel( - clusterCenters=[[1.0, 1.0], [-1.0, 1.0], [-1.0, -1.0], [1.0, -1.0]], - clusterWeights=[1.0, 1.0, 1.0, 1.0]) - - predict_data = [[[1.5, 1.5]], [[-1.5, 1.5]], [[-1.5, -1.5]], [[1.5, -1.5]]] - predict_data = [self.sc.parallelize(batch, 1) for batch in predict_data] - predict_stream = self.ssc.queueStream(predict_data) - predict_val = stkm.predictOn(predict_stream) - - result = [] - - def update(rdd): - rdd_collect = rdd.collect() - if rdd_collect: - result.append(rdd_collect) - - predict_val.foreachRDD(update) - self.ssc.start() - - def condition(): - self.assertEqual(result, [[0], [1], [2], [3]]) - return True - - self._eventually(condition, catch_assertions=True) - - @unittest.skip("SPARK-10086: Flaky StreamingKMeans test in PySpark") - def test_trainOn_predictOn(self): - """Test that prediction happens on the updated model.""" - stkm = StreamingKMeans(decayFactor=0.0, k=2) - stkm.setInitialCenters([[0.0], [1.0]], [1.0, 1.0]) - - # Since decay factor is set to zero, once the first batch - # is passed the clusterCenters are updated to [-0.5, 0.7] - # which causes 0.2 & 0.3 to be classified as 1, even though the - # classification based in the initial model would have been 0 - # proving that the model is updated. - batches = [[[-0.5], [0.6], [0.8]], [[0.2], [-0.1], [0.3]]] - batches = [self.sc.parallelize(batch) for batch in batches] - input_stream = self.ssc.queueStream(batches) - predict_results = [] - - def collect(rdd): - rdd_collect = rdd.collect() - if rdd_collect: - predict_results.append(rdd_collect) - - stkm.trainOn(input_stream) - predict_stream = stkm.predictOn(input_stream) - predict_stream.foreachRDD(collect) - - self.ssc.start() - - def condition(): - self.assertEqual(predict_results, [[0, 1, 1], [1, 0, 1]]) - return True - - self._eventually(condition, catch_assertions=True) - - -class LinearDataGeneratorTests(MLlibTestCase): - def test_dim(self): - linear_data = LinearDataGenerator.generateLinearInput( - intercept=0.0, weights=[0.0, 0.0, 0.0], - xMean=[0.0, 0.0, 0.0], xVariance=[0.33, 0.33, 0.33], - nPoints=4, seed=0, eps=0.1) - self.assertEqual(len(linear_data), 4) - for point in linear_data: - self.assertEqual(len(point.features), 3) - - linear_data = LinearDataGenerator.generateLinearRDD( - sc=self.sc, nexamples=6, nfeatures=2, eps=0.1, - nParts=2, intercept=0.0).collect() - self.assertEqual(len(linear_data), 6) - for point in linear_data: - self.assertEqual(len(point.features), 2) - - -class StreamingLogisticRegressionWithSGDTests(MLLibStreamingTestCase): - - @staticmethod - def generateLogisticInput(offset, scale, nPoints, seed): - """ - Generate 1 / (1 + exp(-x * scale + offset)) - - where, - x is randomnly distributed and the threshold - and labels for each sample in x is obtained from a random uniform - distribution. - """ - rng = random.RandomState(seed) - x = rng.randn(nPoints) - sigmoid = 1. 
/ (1 + exp(-(dot(x, scale) + offset))) - y_p = rng.rand(nPoints) - cut_off = y_p <= sigmoid - y_p[cut_off] = 1.0 - y_p[~cut_off] = 0.0 - return [ - LabeledPoint(y_p[i], Vectors.dense([x[i]])) - for i in range(nPoints)] - - @unittest.skip("Super flaky test") - def test_parameter_accuracy(self): - """ - Test that the final value of weights is close to the desired value. - """ - input_batches = [ - self.sc.parallelize(self.generateLogisticInput(0, 1.5, 100, 42 + i)) - for i in range(20)] - input_stream = self.ssc.queueStream(input_batches) - - slr = StreamingLogisticRegressionWithSGD( - stepSize=0.2, numIterations=25) - slr.setInitialWeights([0.0]) - slr.trainOn(input_stream) - - self.ssc.start() - - def condition(): - rel = (1.5 - slr.latestModel().weights.array[0]) / 1.5 - self.assertAlmostEqual(rel, 0.1, 1) - return True - - self._eventually(condition, catch_assertions=True) - - def test_convergence(self): - """ - Test that weights converge to the required value on toy data. - """ - input_batches = [ - self.sc.parallelize(self.generateLogisticInput(0, 1.5, 100, 42 + i)) - for i in range(20)] - input_stream = self.ssc.queueStream(input_batches) - models = [] - - slr = StreamingLogisticRegressionWithSGD( - stepSize=0.2, numIterations=25) - slr.setInitialWeights([0.0]) - slr.trainOn(input_stream) - input_stream.foreachRDD( - lambda x: models.append(slr.latestModel().weights[0])) - - self.ssc.start() - - def condition(): - self.assertEqual(len(models), len(input_batches)) - return True - - # We want all batches to finish for this test. - self._eventually(condition, 60.0, catch_assertions=True) - - t_models = array(models) - diff = t_models[1:] - t_models[:-1] - # Test that weights improve with a small tolerance - self.assertTrue(all(diff >= -0.1)) - self.assertTrue(array_sum(diff > 0) > 1) - - @staticmethod - def calculate_accuracy_error(true, predicted): - return sum(abs(array(true) - array(predicted))) / len(true) - - def test_predictions(self): - """Test predicted values on a toy model.""" - input_batches = [] - for i in range(20): - batch = self.sc.parallelize( - self.generateLogisticInput(0, 1.5, 100, 42 + i)) - input_batches.append(batch.map(lambda x: (x.label, x.features))) - input_stream = self.ssc.queueStream(input_batches) - - slr = StreamingLogisticRegressionWithSGD( - stepSize=0.2, numIterations=25) - slr.setInitialWeights([1.5]) - predict_stream = slr.predictOnValues(input_stream) - true_predicted = [] - predict_stream.foreachRDD(lambda x: true_predicted.append(x.collect())) - self.ssc.start() - - def condition(): - self.assertEqual(len(true_predicted), len(input_batches)) - return True - - self._eventually(condition, catch_assertions=True) - - # Test that the accuracy error is no more than 0.4 on each batch. - for batch in true_predicted: - true, predicted = zip(*batch) - self.assertTrue( - self.calculate_accuracy_error(true, predicted) < 0.4) - - @unittest.skip("Super flaky test") - def test_training_and_prediction(self): - """Test that the model improves on toy data with no. 
of batches""" - input_batches = [ - self.sc.parallelize(self.generateLogisticInput(0, 1.5, 100, 42 + i)) - for i in range(20)] - predict_batches = [ - b.map(lambda lp: (lp.label, lp.features)) for b in input_batches] - - slr = StreamingLogisticRegressionWithSGD( - stepSize=0.01, numIterations=25) - slr.setInitialWeights([-0.1]) - errors = [] - - def collect_errors(rdd): - true, predicted = zip(*rdd.collect()) - errors.append(self.calculate_accuracy_error(true, predicted)) - - true_predicted = [] - input_stream = self.ssc.queueStream(input_batches) - predict_stream = self.ssc.queueStream(predict_batches) - slr.trainOn(input_stream) - ps = slr.predictOnValues(predict_stream) - ps.foreachRDD(lambda x: collect_errors(x)) - - self.ssc.start() - - def condition(): - # Test that the improvement in error is > 0.3 - if len(errors) == len(predict_batches): - self.assertGreater(errors[1] - errors[-1], 0.3) - if len(errors) >= 3 and errors[1] - errors[-1] > 0.3: - return True - return "Latest errors: " + ", ".join(map(lambda x: str(x), errors)) - - self._eventually(condition) - - -class StreamingLinearRegressionWithTests(MLLibStreamingTestCase): - - def assertArrayAlmostEqual(self, array1, array2, dec): - for i, j in array1, array2: - self.assertAlmostEqual(i, j, dec) - - @unittest.skip("Super flaky test") - def test_parameter_accuracy(self): - """Test that coefs are predicted accurately by fitting on toy data.""" - - # Test that fitting (10*X1 + 10*X2), (X1, X2) gives coefficients - # (10, 10) - slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25) - slr.setInitialWeights([0.0, 0.0]) - xMean = [0.0, 0.0] - xVariance = [1.0 / 3.0, 1.0 / 3.0] - - # Create ten batches with 100 sample points in each. - batches = [] - for i in range(10): - batch = LinearDataGenerator.generateLinearInput( - 0.0, [10.0, 10.0], xMean, xVariance, 100, 42 + i, 0.1) - batches.append(self.sc.parallelize(batch)) - - input_stream = self.ssc.queueStream(batches) - slr.trainOn(input_stream) - self.ssc.start() - - def condition(): - self.assertArrayAlmostEqual( - slr.latestModel().weights.array, [10., 10.], 1) - self.assertAlmostEqual(slr.latestModel().intercept, 0.0, 1) - return True - - self._eventually(condition, catch_assertions=True) - - def test_parameter_convergence(self): - """Test that the model parameters improve with streaming data.""" - slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25) - slr.setInitialWeights([0.0]) - - # Create ten batches with 100 sample points in each. - batches = [] - for i in range(10): - batch = LinearDataGenerator.generateLinearInput( - 0.0, [10.0], [0.0], [1.0 / 3.0], 100, 42 + i, 0.1) - batches.append(self.sc.parallelize(batch)) - - model_weights = [] - input_stream = self.ssc.queueStream(batches) - input_stream.foreachRDD( - lambda x: model_weights.append(slr.latestModel().weights[0])) - slr.trainOn(input_stream) - self.ssc.start() - - def condition(): - self.assertEqual(len(model_weights), len(batches)) - return True - - # We want all batches to finish for this test. - self._eventually(condition, catch_assertions=True) - - w = array(model_weights) - diff = w[1:] - w[:-1] - self.assertTrue(all(diff >= -0.1)) - - def test_prediction(self): - """Test prediction on a model with weights already set.""" - # Create a model with initial Weights equal to coefs - slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25) - slr.setInitialWeights([10.0, 10.0]) - - # Create ten batches with 100 sample points in each. 
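
# Editor's aside (illustrative sketch, not part of the patch): the actual
# LinearDataGenerator lives on the JVM side. The NumPy sketch below shows
# roughly what generateLinearInput(intercept, weights, xMean, xVariance,
# nPoints, seed, eps) is expected to yield: features centered on xMean with
# the given per-feature variance (a normal draw is used here for simplicity;
# the real generator may use a different distribution) and labels
# y = intercept + weights . x + Gaussian noise of scale eps.
import numpy as np

def generate_linear_input(intercept, weights, x_mean, x_variance, n_points, seed, eps):
    rng = np.random.RandomState(seed)
    w = np.asarray(weights, dtype=float)
    x = rng.normal(loc=x_mean, scale=np.sqrt(x_variance), size=(n_points, len(w)))
    y = intercept + x.dot(w) + rng.normal(scale=eps, size=n_points)
    return list(zip(y, x))  # analogous to LabeledPoint(label, features) pairs
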
- batches = [] - for i in range(10): - batch = LinearDataGenerator.generateLinearInput( - 0.0, [10.0, 10.0], [0.0, 0.0], [1.0 / 3.0, 1.0 / 3.0], - 100, 42 + i, 0.1) - batches.append( - self.sc.parallelize(batch).map(lambda lp: (lp.label, lp.features))) - - input_stream = self.ssc.queueStream(batches) - output_stream = slr.predictOnValues(input_stream) - samples = [] - output_stream.foreachRDD(lambda x: samples.append(x.collect())) - - self.ssc.start() - - def condition(): - self.assertEqual(len(samples), len(batches)) - return True - - # We want all batches to finish for this test. - self._eventually(condition, catch_assertions=True) - - # Test that mean absolute error on each batch is less than 0.1 - for batch in samples: - true, predicted = zip(*batch) - self.assertTrue(mean(abs(array(true) - array(predicted))) < 0.1) - - @unittest.skip("Super flaky test") - def test_train_prediction(self): - """Test that error on test data improves as model is trained.""" - slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25) - slr.setInitialWeights([0.0]) - - # Create ten batches with 100 sample points in each. - batches = [] - for i in range(10): - batch = LinearDataGenerator.generateLinearInput( - 0.0, [10.0], [0.0], [1.0 / 3.0], 100, 42 + i, 0.1) - batches.append(self.sc.parallelize(batch)) - - predict_batches = [ - b.map(lambda lp: (lp.label, lp.features)) for b in batches] - errors = [] - - def func(rdd): - true, predicted = zip(*rdd.collect()) - errors.append(mean(abs(true) - abs(predicted))) - - input_stream = self.ssc.queueStream(batches) - output_stream = self.ssc.queueStream(predict_batches) - slr.trainOn(input_stream) - output_stream = slr.predictOnValues(output_stream) - output_stream.foreachRDD(func) - self.ssc.start() - - def condition(): - if len(errors) == len(predict_batches): - self.assertGreater(errors[1] - errors[-1], 2) - if len(errors) >= 3 and errors[1] - errors[-1] > 2: - return True - return "Latest errors: " + ", ".join(map(lambda x: str(x), errors)) - - self._eventually(condition) - - -class MLUtilsTests(MLlibTestCase): - def test_append_bias(self): - data = [2.0, 2.0, 2.0] - ret = MLUtils.appendBias(data) - self.assertEqual(ret[3], 1.0) - self.assertEqual(type(ret), DenseVector) - - def test_append_bias_with_vector(self): - data = Vectors.dense([2.0, 2.0, 2.0]) - ret = MLUtils.appendBias(data) - self.assertEqual(ret[3], 1.0) - self.assertEqual(type(ret), DenseVector) - - def test_append_bias_with_sp_vector(self): - data = Vectors.sparse(3, {0: 2.0, 2: 2.0}) - expected = Vectors.sparse(4, {0: 2.0, 2: 2.0, 3: 1.0}) - # Returned value must be SparseVector - ret = MLUtils.appendBias(data) - self.assertEqual(ret, expected) - self.assertEqual(type(ret), SparseVector) - - def test_load_vectors(self): - import shutil - data = [ - [1.0, 2.0, 3.0], - [1.0, 2.0, 3.0] - ] - temp_dir = tempfile.mkdtemp() - load_vectors_path = os.path.join(temp_dir, "test_load_vectors") - try: - self.sc.parallelize(data).saveAsTextFile(load_vectors_path) - ret_rdd = MLUtils.loadVectors(self.sc, load_vectors_path) - ret = ret_rdd.collect() - self.assertEqual(len(ret), 2) - self.assertEqual(ret[0], DenseVector([1.0, 2.0, 3.0])) - self.assertEqual(ret[1], DenseVector([1.0, 2.0, 3.0])) - except: - self.fail() - finally: - shutil.rmtree(load_vectors_path) - - -class ALSTests(MLlibTestCase): - - def test_als_ratings_serialize(self): - r = Rating(7, 1123, 3.14) - jr = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(r))) - nr = 
ser.loads(bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(jr))) - self.assertEqual(r.user, nr.user) - self.assertEqual(r.product, nr.product) - self.assertAlmostEqual(r.rating, nr.rating, 2) - - def test_als_ratings_id_long_error(self): - r = Rating(1205640308657491975, 50233468418, 1.0) - # rating user id exceeds max int value, should fail when pickled - self.assertRaises(Py4JJavaError, self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads, - bytearray(ser.dumps(r))) - - -class HashingTFTest(MLlibTestCase): - - def test_binary_term_freqs(self): - hashingTF = HashingTF(100).setBinary(True) - doc = "a a b c c c".split(" ") - n = hashingTF.numFeatures - output = hashingTF.transform(doc).toArray() - expected = Vectors.sparse(n, {hashingTF.indexOf("a"): 1.0, - hashingTF.indexOf("b"): 1.0, - hashingTF.indexOf("c"): 1.0}).toArray() - for i in range(0, n): - self.assertAlmostEqual(output[i], expected[i], 14, "Error at " + str(i) + - ": expected " + str(expected[i]) + ", got " + str(output[i])) - - -class DimensionalityReductionTests(MLlibTestCase): - - denseData = [ - Vectors.dense([0.0, 1.0, 2.0]), - Vectors.dense([3.0, 4.0, 5.0]), - Vectors.dense([6.0, 7.0, 8.0]), - Vectors.dense([9.0, 0.0, 1.0]) - ] - sparseData = [ - Vectors.sparse(3, [(1, 1.0), (2, 2.0)]), - Vectors.sparse(3, [(0, 3.0), (1, 4.0), (2, 5.0)]), - Vectors.sparse(3, [(0, 6.0), (1, 7.0), (2, 8.0)]), - Vectors.sparse(3, [(0, 9.0), (2, 1.0)]) - ] - - def assertEqualUpToSign(self, vecA, vecB): - eq1 = vecA - vecB - eq2 = vecA + vecB - self.assertTrue(sum(abs(eq1)) < 1e-6 or sum(abs(eq2)) < 1e-6) - - def test_svd(self): - denseMat = RowMatrix(self.sc.parallelize(self.denseData)) - sparseMat = RowMatrix(self.sc.parallelize(self.sparseData)) - m = 4 - n = 3 - for mat in [denseMat, sparseMat]: - for k in range(1, 4): - rm = mat.computeSVD(k, computeU=True) - self.assertEqual(rm.s.size, k) - self.assertEqual(rm.U.numRows(), m) - self.assertEqual(rm.U.numCols(), k) - self.assertEqual(rm.V.numRows, n) - self.assertEqual(rm.V.numCols, k) - - # Test that U returned is None if computeU is set to False. - self.assertEqual(mat.computeSVD(1).U, None) - - # Test that low rank matrices cannot have number of singular values - # greater than a limit. - rm = RowMatrix(self.sc.parallelize(tile([1, 2, 3], (3, 1)))) - self.assertEqual(rm.computeSVD(3, False, 1e-6).s.size, 1) - - def test_pca(self): - expected_pcs = array([ - [0.0, 1.0, 0.0], - [sqrt(2.0) / 2.0, 0.0, sqrt(2.0) / 2.0], - [sqrt(2.0) / 2.0, 0.0, -sqrt(2.0) / 2.0] - ]) - n = 3 - denseMat = RowMatrix(self.sc.parallelize(self.denseData)) - sparseMat = RowMatrix(self.sc.parallelize(self.sparseData)) - for mat in [denseMat, sparseMat]: - for k in range(1, 4): - pcs = mat.computePrincipalComponents(k) - self.assertEqual(pcs.numRows, n) - self.assertEqual(pcs.numCols, k) - - # We can just test the updated principal component for equality. 
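
# Editor's aside (illustrative sketch, not part of the patch): the
# assertEqualUpToSign helper is needed because a principal component, like
# any eigenvector, is only determined up to a factor of -1. The expected_pcs
# matrix above can be cross-checked locally with plain NumPy by taking the
# eigenvectors of the column-covariance matrix:
import numpy as np

def principal_components(rows, k):
    X = np.asarray(rows, dtype=float)
    cov = np.cov(X, rowvar=False)              # covariance of the columns
    eigvals, eigvecs = np.linalg.eigh(cov)     # eigenvalues in ascending order
    order = np.argsort(eigvals)[::-1][:k]      # top-k variance directions
    return eigvecs[:, order]                   # one component per column
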
- self.assertEqualUpToSign(pcs.toArray()[:, k - 1], expected_pcs[:, k - 1]) - - -class FPGrowthTest(MLlibTestCase): - - def test_fpgrowth(self): - data = [["a", "b", "c"], ["a", "b", "d", "e"], ["a", "c", "e"], ["a", "c", "f"]] - rdd = self.sc.parallelize(data, 2) - model1 = FPGrowth.train(rdd, 0.6, 2) - # use default data partition number when numPartitions is not specified - model2 = FPGrowth.train(rdd, 0.6) - self.assertEqual(sorted(model1.freqItemsets().collect()), - sorted(model2.freqItemsets().collect())) - -if __name__ == "__main__": - from pyspark.mllib.tests import * - if not _have_scipy: - print("NOTE: Skipping SciPy tests as it does not seem to be installed") - runner = unishark.BufferedTestRunner( - reporters=[unishark.XUnitReporter('target/test-reports/pyspark.mllib_{}'.format( - os.path.basename(os.environ.get("PYSPARK_PYTHON", ""))))]) - unittest.main(testRunner=runner, verbosity=2) - if not _have_scipy: - print("NOTE: SciPy tests were skipped as it does not seem to be installed") - sc.stop() diff --git a/python/pyspark/mllib/tests/__init__.py b/python/pyspark/mllib/tests/__init__.py new file mode 100644 index 0000000000000..cce3acad34a49 --- /dev/null +++ b/python/pyspark/mllib/tests/__init__.py @@ -0,0 +1,16 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/python/pyspark/mllib/tests/test_algorithms.py b/python/pyspark/mllib/tests/test_algorithms.py new file mode 100644 index 0000000000000..cc3b64b1cb284 --- /dev/null +++ b/python/pyspark/mllib/tests/test_algorithms.py @@ -0,0 +1,302 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+
+import os
+import tempfile
+from shutil import rmtree
+import unittest
+
+from numpy import array, array_equal
+from py4j.protocol import Py4JJavaError
+
+from pyspark.mllib.fpm import FPGrowth
+from pyspark.mllib.recommendation import Rating
+from pyspark.mllib.regression import LabeledPoint
+from pyspark.testing.mllibutils import make_serializer, MLlibTestCase
+
+
+ser = make_serializer()
+
+
+class ListTests(MLlibTestCase):
+
+    """
+    Test MLlib algorithms on plain lists, to make sure they're passed through
+    as NumPy arrays.
+    """
+
+    def test_bisecting_kmeans(self):
+        from pyspark.mllib.clustering import BisectingKMeans
+        data = array([0.0, 0.0, 1.0, 1.0, 9.0, 8.0, 8.0, 9.0]).reshape(4, 2)
+        bskm = BisectingKMeans()
+        model = bskm.train(self.sc.parallelize(data, 2), k=4)
+        p = array([0.0, 0.0])
+        rdd_p = self.sc.parallelize([p])
+        self.assertEqual(model.predict(p), model.predict(rdd_p).first())
+        self.assertEqual(model.computeCost(p), model.computeCost(rdd_p))
+        self.assertEqual(model.k, len(model.clusterCenters))
+
+    def test_kmeans(self):
+        from pyspark.mllib.clustering import KMeans
+        data = [
+            [0, 1.1],
+            [0, 1.2],
+            [1.1, 0],
+            [1.2, 0],
+        ]
+        clusters = KMeans.train(self.sc.parallelize(data), 2, initializationMode="k-means||",
+                                initializationSteps=7, epsilon=1e-4)
+        self.assertEqual(clusters.predict(data[0]), clusters.predict(data[1]))
+        self.assertEqual(clusters.predict(data[2]), clusters.predict(data[3]))
+
+    def test_kmeans_deterministic(self):
+        from pyspark.mllib.clustering import KMeans
+        X = range(0, 100, 10)
+        Y = range(0, 100, 10)
+        data = [[x, y] for x, y in zip(X, Y)]
+        clusters1 = KMeans.train(self.sc.parallelize(data),
+                                 3, initializationMode="k-means||",
+                                 seed=42, initializationSteps=7, epsilon=1e-4)
+        clusters2 = KMeans.train(self.sc.parallelize(data),
+                                 3, initializationMode="k-means||",
+                                 seed=42, initializationSteps=7, epsilon=1e-4)
+        centers1 = clusters1.centers
+        centers2 = clusters2.centers
+        for c1, c2 in zip(centers1, centers2):
+            # TODO: Allow small numeric difference.
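
# Editor's note (illustrative, not part of the patch): the TODO above could
# be addressed by replacing the exact array_equal check on the next line with
# a tolerance-based comparison such as numpy.allclose, e.g.
# self.assertTrue(np.allclose(c1, c2, atol=1e-6)). Standalone illustration
# with hypothetical values:
import numpy as np

a1 = np.array([10.0, 10.0])
a2 = np.array([10.0 + 1e-9, 10.0 - 1e-9])
assert not np.array_equal(a1, a2)        # exact equality is brittle
assert np.allclose(a1, a2, atol=1e-6)    # tolerance absorbs tiny drift
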
+ self.assertTrue(array_equal(c1, c2)) + + def test_gmm(self): + from pyspark.mllib.clustering import GaussianMixture + data = self.sc.parallelize([ + [1, 2], + [8, 9], + [-4, -3], + [-6, -7], + ]) + clusters = GaussianMixture.train(data, 2, convergenceTol=0.001, + maxIterations=10, seed=1) + labels = clusters.predict(data).collect() + self.assertEqual(labels[0], labels[1]) + self.assertEqual(labels[2], labels[3]) + + def test_gmm_deterministic(self): + from pyspark.mllib.clustering import GaussianMixture + x = range(0, 100, 10) + y = range(0, 100, 10) + data = self.sc.parallelize([[a, b] for a, b in zip(x, y)]) + clusters1 = GaussianMixture.train(data, 5, convergenceTol=0.001, + maxIterations=10, seed=63) + clusters2 = GaussianMixture.train(data, 5, convergenceTol=0.001, + maxIterations=10, seed=63) + for c1, c2 in zip(clusters1.weights, clusters2.weights): + self.assertEqual(round(c1, 7), round(c2, 7)) + + def test_gmm_with_initial_model(self): + from pyspark.mllib.clustering import GaussianMixture + data = self.sc.parallelize([ + (-10, -5), (-9, -4), (10, 5), (9, 4) + ]) + + gmm1 = GaussianMixture.train(data, 2, convergenceTol=0.001, + maxIterations=10, seed=63) + gmm2 = GaussianMixture.train(data, 2, convergenceTol=0.001, + maxIterations=10, seed=63, initialModel=gmm1) + self.assertAlmostEqual((gmm1.weights - gmm2.weights).sum(), 0.0) + + def test_classification(self): + from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes + from pyspark.mllib.tree import DecisionTree, DecisionTreeModel, RandomForest, \ + RandomForestModel, GradientBoostedTrees, GradientBoostedTreesModel + data = [ + LabeledPoint(0.0, [1, 0, 0]), + LabeledPoint(1.0, [0, 1, 1]), + LabeledPoint(0.0, [2, 0, 0]), + LabeledPoint(1.0, [0, 2, 1]) + ] + rdd = self.sc.parallelize(data) + features = [p.features.tolist() for p in data] + + temp_dir = tempfile.mkdtemp() + + lr_model = LogisticRegressionWithSGD.train(rdd, iterations=10) + self.assertTrue(lr_model.predict(features[0]) <= 0) + self.assertTrue(lr_model.predict(features[1]) > 0) + self.assertTrue(lr_model.predict(features[2]) <= 0) + self.assertTrue(lr_model.predict(features[3]) > 0) + + svm_model = SVMWithSGD.train(rdd, iterations=10) + self.assertTrue(svm_model.predict(features[0]) <= 0) + self.assertTrue(svm_model.predict(features[1]) > 0) + self.assertTrue(svm_model.predict(features[2]) <= 0) + self.assertTrue(svm_model.predict(features[3]) > 0) + + nb_model = NaiveBayes.train(rdd) + self.assertTrue(nb_model.predict(features[0]) <= 0) + self.assertTrue(nb_model.predict(features[1]) > 0) + self.assertTrue(nb_model.predict(features[2]) <= 0) + self.assertTrue(nb_model.predict(features[3]) > 0) + + categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories + dt_model = DecisionTree.trainClassifier( + rdd, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo, maxBins=4) + self.assertTrue(dt_model.predict(features[0]) <= 0) + self.assertTrue(dt_model.predict(features[1]) > 0) + self.assertTrue(dt_model.predict(features[2]) <= 0) + self.assertTrue(dt_model.predict(features[3]) > 0) + + dt_model_dir = os.path.join(temp_dir, "dt") + dt_model.save(self.sc, dt_model_dir) + same_dt_model = DecisionTreeModel.load(self.sc, dt_model_dir) + self.assertEqual(same_dt_model.toDebugString(), dt_model.toDebugString()) + + rf_model = RandomForest.trainClassifier( + rdd, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo, numTrees=10, + maxBins=4, seed=1) + self.assertTrue(rf_model.predict(features[0]) <= 0) + 
self.assertTrue(rf_model.predict(features[1]) > 0) + self.assertTrue(rf_model.predict(features[2]) <= 0) + self.assertTrue(rf_model.predict(features[3]) > 0) + + rf_model_dir = os.path.join(temp_dir, "rf") + rf_model.save(self.sc, rf_model_dir) + same_rf_model = RandomForestModel.load(self.sc, rf_model_dir) + self.assertEqual(same_rf_model.toDebugString(), rf_model.toDebugString()) + + gbt_model = GradientBoostedTrees.trainClassifier( + rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4) + self.assertTrue(gbt_model.predict(features[0]) <= 0) + self.assertTrue(gbt_model.predict(features[1]) > 0) + self.assertTrue(gbt_model.predict(features[2]) <= 0) + self.assertTrue(gbt_model.predict(features[3]) > 0) + + gbt_model_dir = os.path.join(temp_dir, "gbt") + gbt_model.save(self.sc, gbt_model_dir) + same_gbt_model = GradientBoostedTreesModel.load(self.sc, gbt_model_dir) + self.assertEqual(same_gbt_model.toDebugString(), gbt_model.toDebugString()) + + try: + rmtree(temp_dir) + except OSError: + pass + + def test_regression(self): + from pyspark.mllib.regression import LinearRegressionWithSGD, LassoWithSGD, \ + RidgeRegressionWithSGD + from pyspark.mllib.tree import DecisionTree, RandomForest, GradientBoostedTrees + data = [ + LabeledPoint(-1.0, [0, -1]), + LabeledPoint(1.0, [0, 1]), + LabeledPoint(-1.0, [0, -2]), + LabeledPoint(1.0, [0, 2]) + ] + rdd = self.sc.parallelize(data) + features = [p.features.tolist() for p in data] + + lr_model = LinearRegressionWithSGD.train(rdd, iterations=10) + self.assertTrue(lr_model.predict(features[0]) <= 0) + self.assertTrue(lr_model.predict(features[1]) > 0) + self.assertTrue(lr_model.predict(features[2]) <= 0) + self.assertTrue(lr_model.predict(features[3]) > 0) + + lasso_model = LassoWithSGD.train(rdd, iterations=10) + self.assertTrue(lasso_model.predict(features[0]) <= 0) + self.assertTrue(lasso_model.predict(features[1]) > 0) + self.assertTrue(lasso_model.predict(features[2]) <= 0) + self.assertTrue(lasso_model.predict(features[3]) > 0) + + rr_model = RidgeRegressionWithSGD.train(rdd, iterations=10) + self.assertTrue(rr_model.predict(features[0]) <= 0) + self.assertTrue(rr_model.predict(features[1]) > 0) + self.assertTrue(rr_model.predict(features[2]) <= 0) + self.assertTrue(rr_model.predict(features[3]) > 0) + + categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories + dt_model = DecisionTree.trainRegressor( + rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, maxBins=4) + self.assertTrue(dt_model.predict(features[0]) <= 0) + self.assertTrue(dt_model.predict(features[1]) > 0) + self.assertTrue(dt_model.predict(features[2]) <= 0) + self.assertTrue(dt_model.predict(features[3]) > 0) + + rf_model = RandomForest.trainRegressor( + rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numTrees=10, maxBins=4, seed=1) + self.assertTrue(rf_model.predict(features[0]) <= 0) + self.assertTrue(rf_model.predict(features[1]) > 0) + self.assertTrue(rf_model.predict(features[2]) <= 0) + self.assertTrue(rf_model.predict(features[3]) > 0) + + gbt_model = GradientBoostedTrees.trainRegressor( + rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4) + self.assertTrue(gbt_model.predict(features[0]) <= 0) + self.assertTrue(gbt_model.predict(features[1]) > 0) + self.assertTrue(gbt_model.predict(features[2]) <= 0) + self.assertTrue(gbt_model.predict(features[3]) > 0) + + try: + LinearRegressionWithSGD.train(rdd, initialWeights=array([1.0, 1.0]), iterations=10) + LassoWithSGD.train(rdd, initialWeights=array([1.0, 1.0]), iterations=10) + 
RidgeRegressionWithSGD.train(rdd, initialWeights=array([1.0, 1.0]), iterations=10) + except ValueError: + self.fail() + + # Verify that maxBins is being passed through + GradientBoostedTrees.trainRegressor( + rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4, maxBins=32) + with self.assertRaises(Exception) as cm: + GradientBoostedTrees.trainRegressor( + rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4, maxBins=1) + + +class ALSTests(MLlibTestCase): + + def test_als_ratings_serialize(self): + r = Rating(7, 1123, 3.14) + jr = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(r))) + nr = ser.loads(bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(jr))) + self.assertEqual(r.user, nr.user) + self.assertEqual(r.product, nr.product) + self.assertAlmostEqual(r.rating, nr.rating, 2) + + def test_als_ratings_id_long_error(self): + r = Rating(1205640308657491975, 50233468418, 1.0) + # rating user id exceeds max int value, should fail when pickled + self.assertRaises(Py4JJavaError, self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads, + bytearray(ser.dumps(r))) + + +class FPGrowthTest(MLlibTestCase): + + def test_fpgrowth(self): + data = [["a", "b", "c"], ["a", "b", "d", "e"], ["a", "c", "e"], ["a", "c", "f"]] + rdd = self.sc.parallelize(data, 2) + model1 = FPGrowth.train(rdd, 0.6, 2) + # use default data partition number when numPartitions is not specified + model2 = FPGrowth.train(rdd, 0.6) + self.assertEqual(sorted(model1.freqItemsets().collect()), + sorted(model2.freqItemsets().collect())) + + +if __name__ == "__main__": + from pyspark.mllib.tests.test_algorithms import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/mllib/tests/test_feature.py b/python/pyspark/mllib/tests/test_feature.py new file mode 100644 index 0000000000000..3da841c408558 --- /dev/null +++ b/python/pyspark/mllib/tests/test_feature.py @@ -0,0 +1,192 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from math import sqrt +import unittest + +from numpy import array, random, exp, abs, tile + +from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, Vectors +from pyspark.mllib.linalg.distributed import RowMatrix +from pyspark.mllib.feature import HashingTF, IDF, StandardScaler, ElementwiseProduct, Word2Vec +from pyspark.testing.mllibutils import MLlibTestCase + + +class FeatureTest(MLlibTestCase): + def test_idf_model(self): + data = [ + Vectors.dense([1, 2, 6, 0, 2, 3, 1, 1, 0, 0, 3]), + Vectors.dense([1, 3, 0, 1, 3, 0, 0, 2, 0, 0, 1]), + Vectors.dense([1, 4, 1, 0, 0, 4, 9, 0, 1, 2, 0]), + Vectors.dense([2, 1, 0, 3, 0, 0, 5, 0, 2, 3, 9]) + ] + model = IDF().fit(self.sc.parallelize(data, 2)) + idf = model.idf() + self.assertEqual(len(idf), 11) + + +class Word2VecTests(MLlibTestCase): + def test_word2vec_setters(self): + model = Word2Vec() \ + .setVectorSize(2) \ + .setLearningRate(0.01) \ + .setNumPartitions(2) \ + .setNumIterations(10) \ + .setSeed(1024) \ + .setMinCount(3) \ + .setWindowSize(6) + self.assertEqual(model.vectorSize, 2) + self.assertTrue(model.learningRate < 0.02) + self.assertEqual(model.numPartitions, 2) + self.assertEqual(model.numIterations, 10) + self.assertEqual(model.seed, 1024) + self.assertEqual(model.minCount, 3) + self.assertEqual(model.windowSize, 6) + + def test_word2vec_get_vectors(self): + data = [ + ["a", "b", "c", "d", "e", "f", "g"], + ["a", "b", "c", "d", "e", "f"], + ["a", "b", "c", "d", "e"], + ["a", "b", "c", "d"], + ["a", "b", "c"], + ["a", "b"], + ["a"] + ] + model = Word2Vec().fit(self.sc.parallelize(data)) + self.assertEqual(len(model.getVectors()), 3) + + +class StandardScalerTests(MLlibTestCase): + def test_model_setters(self): + data = [ + [1.0, 2.0, 3.0], + [2.0, 3.0, 4.0], + [3.0, 4.0, 5.0] + ] + model = StandardScaler().fit(self.sc.parallelize(data)) + self.assertIsNotNone(model.setWithMean(True)) + self.assertIsNotNone(model.setWithStd(True)) + self.assertEqual(model.transform([1.0, 2.0, 3.0]), DenseVector([-1.0, -1.0, -1.0])) + + def test_model_transform(self): + data = [ + [1.0, 2.0, 3.0], + [2.0, 3.0, 4.0], + [3.0, 4.0, 5.0] + ] + model = StandardScaler().fit(self.sc.parallelize(data)) + self.assertEqual(model.transform([1.0, 2.0, 3.0]), DenseVector([1.0, 2.0, 3.0])) + + +class ElementwiseProductTests(MLlibTestCase): + def test_model_transform(self): + weight = Vectors.dense([3, 2, 1]) + + densevec = Vectors.dense([4, 5, 6]) + sparsevec = Vectors.sparse(3, [0], [1]) + eprod = ElementwiseProduct(weight) + self.assertEqual(eprod.transform(densevec), DenseVector([12, 10, 6])) + self.assertEqual( + eprod.transform(sparsevec), SparseVector(3, [0], [3])) + + +class HashingTFTest(MLlibTestCase): + + def test_binary_term_freqs(self): + hashingTF = HashingTF(100).setBinary(True) + doc = "a a b c c c".split(" ") + n = hashingTF.numFeatures + output = hashingTF.transform(doc).toArray() + expected = Vectors.sparse(n, {hashingTF.indexOf("a"): 1.0, + hashingTF.indexOf("b"): 1.0, + hashingTF.indexOf("c"): 1.0}).toArray() + for i in range(0, n): + self.assertAlmostEqual(output[i], expected[i], 14, "Error at " + str(i) + + ": expected " + str(expected[i]) + ", got " + str(output[i])) + + +class DimensionalityReductionTests(MLlibTestCase): + + denseData = [ + Vectors.dense([0.0, 1.0, 2.0]), + Vectors.dense([3.0, 4.0, 5.0]), + Vectors.dense([6.0, 7.0, 8.0]), + Vectors.dense([9.0, 0.0, 1.0]) + ] + sparseData = [ + Vectors.sparse(3, [(1, 1.0), (2, 2.0)]), + Vectors.sparse(3, [(0, 3.0), (1, 4.0), (2, 5.0)]), + Vectors.sparse(3, 
[(0, 6.0), (1, 7.0), (2, 8.0)]), + Vectors.sparse(3, [(0, 9.0), (2, 1.0)]) + ] + + def assertEqualUpToSign(self, vecA, vecB): + eq1 = vecA - vecB + eq2 = vecA + vecB + self.assertTrue(sum(abs(eq1)) < 1e-6 or sum(abs(eq2)) < 1e-6) + + def test_svd(self): + denseMat = RowMatrix(self.sc.parallelize(self.denseData)) + sparseMat = RowMatrix(self.sc.parallelize(self.sparseData)) + m = 4 + n = 3 + for mat in [denseMat, sparseMat]: + for k in range(1, 4): + rm = mat.computeSVD(k, computeU=True) + self.assertEqual(rm.s.size, k) + self.assertEqual(rm.U.numRows(), m) + self.assertEqual(rm.U.numCols(), k) + self.assertEqual(rm.V.numRows, n) + self.assertEqual(rm.V.numCols, k) + + # Test that U returned is None if computeU is set to False. + self.assertEqual(mat.computeSVD(1).U, None) + + # Test that low rank matrices cannot have number of singular values + # greater than a limit. + rm = RowMatrix(self.sc.parallelize(tile([1, 2, 3], (3, 1)))) + self.assertEqual(rm.computeSVD(3, False, 1e-6).s.size, 1) + + def test_pca(self): + expected_pcs = array([ + [0.0, 1.0, 0.0], + [sqrt(2.0) / 2.0, 0.0, sqrt(2.0) / 2.0], + [sqrt(2.0) / 2.0, 0.0, -sqrt(2.0) / 2.0] + ]) + n = 3 + denseMat = RowMatrix(self.sc.parallelize(self.denseData)) + sparseMat = RowMatrix(self.sc.parallelize(self.sparseData)) + for mat in [denseMat, sparseMat]: + for k in range(1, 4): + pcs = mat.computePrincipalComponents(k) + self.assertEqual(pcs.numRows, n) + self.assertEqual(pcs.numCols, k) + + # We can just test the updated principal component for equality. + self.assertEqualUpToSign(pcs.toArray()[:, k - 1], expected_pcs[:, k - 1]) + + +if __name__ == "__main__": + from pyspark.mllib.tests.test_feature import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/mllib/tests/test_linalg.py b/python/pyspark/mllib/tests/test_linalg.py new file mode 100644 index 0000000000000..d0ebd9bc3db79 --- /dev/null +++ b/python/pyspark/mllib/tests/test_linalg.py @@ -0,0 +1,633 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import sys +import array as pyarray +import unittest + +from numpy import array, array_equal, zeros, arange, tile, ones, inf + +import pyspark.ml.linalg as newlinalg +from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector, \ + DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT +from pyspark.mllib.regression import LabeledPoint +from pyspark.testing.mllibutils import make_serializer, MLlibTestCase + +_have_scipy = False +try: + import scipy.sparse + _have_scipy = True +except: + # No SciPy, but that's okay, we'll skip those tests + pass + + +ser = make_serializer() + + +def _squared_distance(a, b): + if isinstance(a, Vector): + return a.squared_distance(b) + else: + return b.squared_distance(a) + + +class VectorTests(MLlibTestCase): + + def _test_serialize(self, v): + self.assertEqual(v, ser.loads(ser.dumps(v))) + jvec = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(v))) + nv = ser.loads(bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(jvec))) + self.assertEqual(v, nv) + vs = [v] * 100 + jvecs = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(vs))) + nvs = ser.loads(bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(jvecs))) + self.assertEqual(vs, nvs) + + def test_serialize(self): + self._test_serialize(DenseVector(range(10))) + self._test_serialize(DenseVector(array([1., 2., 3., 4.]))) + self._test_serialize(DenseVector(pyarray.array('d', range(10)))) + self._test_serialize(SparseVector(4, {1: 1, 3: 2})) + self._test_serialize(SparseVector(3, {})) + self._test_serialize(DenseMatrix(2, 3, range(6))) + sm1 = SparseMatrix( + 3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0]) + self._test_serialize(sm1) + + def test_dot(self): + sv = SparseVector(4, {1: 1, 3: 2}) + dv = DenseVector(array([1., 2., 3., 4.])) + lst = DenseVector([1, 2, 3, 4]) + mat = array([[1., 2., 3., 4.], + [1., 2., 3., 4.], + [1., 2., 3., 4.], + [1., 2., 3., 4.]]) + arr = pyarray.array('d', [0, 1, 2, 3]) + self.assertEqual(10.0, sv.dot(dv)) + self.assertTrue(array_equal(array([3., 6., 9., 12.]), sv.dot(mat))) + self.assertEqual(30.0, dv.dot(dv)) + self.assertTrue(array_equal(array([10., 20., 30., 40.]), dv.dot(mat))) + self.assertEqual(30.0, lst.dot(dv)) + self.assertTrue(array_equal(array([10., 20., 30., 40.]), lst.dot(mat))) + self.assertEqual(7.0, sv.dot(arr)) + + def test_squared_distance(self): + sv = SparseVector(4, {1: 1, 3: 2}) + dv = DenseVector(array([1., 2., 3., 4.])) + lst = DenseVector([4, 3, 2, 1]) + lst1 = [4, 3, 2, 1] + arr = pyarray.array('d', [0, 2, 1, 3]) + narr = array([0, 2, 1, 3]) + self.assertEqual(15.0, _squared_distance(sv, dv)) + self.assertEqual(25.0, _squared_distance(sv, lst)) + self.assertEqual(20.0, _squared_distance(dv, lst)) + self.assertEqual(15.0, _squared_distance(dv, sv)) + self.assertEqual(25.0, _squared_distance(lst, sv)) + self.assertEqual(20.0, _squared_distance(lst, dv)) + self.assertEqual(0.0, _squared_distance(sv, sv)) + self.assertEqual(0.0, _squared_distance(dv, dv)) + self.assertEqual(0.0, _squared_distance(lst, lst)) + self.assertEqual(25.0, _squared_distance(sv, lst1)) + self.assertEqual(3.0, _squared_distance(sv, arr)) + self.assertEqual(3.0, _squared_distance(sv, narr)) + + def test_hash(self): + v1 = DenseVector([0.0, 1.0, 0.0, 5.5]) + v2 = SparseVector(4, [(1, 1.0), (3, 5.5)]) + v3 = DenseVector([0.0, 1.0, 0.0, 5.5]) + v4 = SparseVector(4, [(1, 1.0), (3, 2.5)]) + self.assertEqual(hash(v1), hash(v2)) + 
self.assertEqual(hash(v1), hash(v3)) + self.assertEqual(hash(v2), hash(v3)) + self.assertFalse(hash(v1) == hash(v4)) + self.assertFalse(hash(v2) == hash(v4)) + + def test_eq(self): + v1 = DenseVector([0.0, 1.0, 0.0, 5.5]) + v2 = SparseVector(4, [(1, 1.0), (3, 5.5)]) + v3 = DenseVector([0.0, 1.0, 0.0, 5.5]) + v4 = SparseVector(6, [(1, 1.0), (3, 5.5)]) + v5 = DenseVector([0.0, 1.0, 0.0, 2.5]) + v6 = SparseVector(4, [(1, 1.0), (3, 2.5)]) + self.assertEqual(v1, v2) + self.assertEqual(v1, v3) + self.assertFalse(v2 == v4) + self.assertFalse(v1 == v5) + self.assertFalse(v1 == v6) + + def test_equals(self): + indices = [1, 2, 4] + values = [1., 3., 2.] + self.assertTrue(Vectors._equals(indices, values, list(range(5)), [0., 1., 3., 0., 2.])) + self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 3., 1., 0., 2.])) + self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 3., 0., 2.])) + self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 1., 3., 2., 2.])) + + def test_conversion(self): + # numpy arrays should be automatically upcast to float64 + # tests for fix of [SPARK-5089] + v = array([1, 2, 3, 4], dtype='float64') + dv = DenseVector(v) + self.assertTrue(dv.array.dtype == 'float64') + v = array([1, 2, 3, 4], dtype='float32') + dv = DenseVector(v) + self.assertTrue(dv.array.dtype == 'float64') + + def test_sparse_vector_indexing(self): + sv = SparseVector(5, {1: 1, 3: 2}) + self.assertEqual(sv[0], 0.) + self.assertEqual(sv[3], 2.) + self.assertEqual(sv[1], 1.) + self.assertEqual(sv[2], 0.) + self.assertEqual(sv[4], 0.) + self.assertEqual(sv[-1], 0.) + self.assertEqual(sv[-2], 2.) + self.assertEqual(sv[-3], 0.) + self.assertEqual(sv[-5], 0.) + for ind in [5, -6]: + self.assertRaises(IndexError, sv.__getitem__, ind) + for ind in [7.8, '1']: + self.assertRaises(TypeError, sv.__getitem__, ind) + + zeros = SparseVector(4, {}) + self.assertEqual(zeros[0], 0.0) + self.assertEqual(zeros[3], 0.0) + for ind in [4, -5]: + self.assertRaises(IndexError, zeros.__getitem__, ind) + + empty = SparseVector(0, {}) + for ind in [-1, 0, 1]: + self.assertRaises(IndexError, empty.__getitem__, ind) + + def test_sparse_vector_iteration(self): + self.assertListEqual(list(SparseVector(3, [], [])), [0.0, 0.0, 0.0]) + self.assertListEqual(list(SparseVector(5, [0, 3], [1.0, 2.0])), [1.0, 0.0, 0.0, 2.0, 0.0]) + + def test_matrix_indexing(self): + mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) + expected = [[0, 6], [1, 8], [4, 10]] + for i in range(3): + for j in range(2): + self.assertEqual(mat[i, j], expected[i][j]) + + for i, j in [(-1, 0), (4, 1), (3, 4)]: + self.assertRaises(IndexError, mat.__getitem__, (i, j)) + + def test_repr_dense_matrix(self): + mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) + self.assertTrue( + repr(mat), + 'DenseMatrix(3, 2, [0.0, 1.0, 4.0, 6.0, 8.0, 10.0], False)') + + mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10], True) + self.assertTrue( + repr(mat), + 'DenseMatrix(3, 2, [0.0, 1.0, 4.0, 6.0, 8.0, 10.0], False)') + + mat = DenseMatrix(6, 3, zeros(18)) + self.assertTrue( + repr(mat), + 'DenseMatrix(6, 3, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..., \ + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], False)') + + def test_repr_sparse_matrix(self): + sm1t = SparseMatrix( + 3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], + isTransposed=True) + self.assertTrue( + repr(sm1t), + 'SparseMatrix(3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], True)') + + indices = tile(arange(6), 3) + values = ones(18) + sm = SparseMatrix(6, 3, [0, 6, 
12, 18], indices, values) + self.assertTrue( + repr(sm), "SparseMatrix(6, 3, [0, 6, 12, 18], \ + [0, 1, 2, 3, 4, 5, 0, 1, ..., 4, 5, 0, 1, 2, 3, 4, 5], \ + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ..., \ + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], False)") + + self.assertTrue( + str(sm), + "6 X 3 CSCMatrix\n\ + (0,0) 1.0\n(1,0) 1.0\n(2,0) 1.0\n(3,0) 1.0\n(4,0) 1.0\n(5,0) 1.0\n\ + (0,1) 1.0\n(1,1) 1.0\n(2,1) 1.0\n(3,1) 1.0\n(4,1) 1.0\n(5,1) 1.0\n\ + (0,2) 1.0\n(1,2) 1.0\n(2,2) 1.0\n(3,2) 1.0\n..\n..") + + sm = SparseMatrix(1, 18, zeros(19), [], []) + self.assertTrue( + repr(sm), + 'SparseMatrix(1, 18, \ + [0, 0, 0, 0, 0, 0, 0, 0, ..., 0, 0, 0, 0, 0, 0, 0, 0], [], [], False)') + + def test_sparse_matrix(self): + # Test sparse matrix creation. + sm1 = SparseMatrix( + 3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0]) + self.assertEqual(sm1.numRows, 3) + self.assertEqual(sm1.numCols, 4) + self.assertEqual(sm1.colPtrs.tolist(), [0, 2, 2, 4, 4]) + self.assertEqual(sm1.rowIndices.tolist(), [1, 2, 1, 2]) + self.assertEqual(sm1.values.tolist(), [1.0, 2.0, 4.0, 5.0]) + self.assertTrue( + repr(sm1), + 'SparseMatrix(3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0], False)') + + # Test indexing + expected = [ + [0, 0, 0, 0], + [1, 0, 4, 0], + [2, 0, 5, 0]] + + for i in range(3): + for j in range(4): + self.assertEqual(expected[i][j], sm1[i, j]) + self.assertTrue(array_equal(sm1.toArray(), expected)) + + for i, j in [(-1, 1), (4, 3), (3, 5)]: + self.assertRaises(IndexError, sm1.__getitem__, (i, j)) + + # Test conversion to dense and sparse. + smnew = sm1.toDense().toSparse() + self.assertEqual(sm1.numRows, smnew.numRows) + self.assertEqual(sm1.numCols, smnew.numCols) + self.assertTrue(array_equal(sm1.colPtrs, smnew.colPtrs)) + self.assertTrue(array_equal(sm1.rowIndices, smnew.rowIndices)) + self.assertTrue(array_equal(sm1.values, smnew.values)) + + sm1t = SparseMatrix( + 3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], + isTransposed=True) + self.assertEqual(sm1t.numRows, 3) + self.assertEqual(sm1t.numCols, 4) + self.assertEqual(sm1t.colPtrs.tolist(), [0, 2, 3, 5]) + self.assertEqual(sm1t.rowIndices.tolist(), [0, 1, 2, 0, 2]) + self.assertEqual(sm1t.values.tolist(), [3.0, 2.0, 4.0, 9.0, 8.0]) + + expected = [ + [3, 2, 0, 0], + [0, 0, 4, 0], + [9, 0, 8, 0]] + + for i in range(3): + for j in range(4): + self.assertEqual(expected[i][j], sm1t[i, j]) + self.assertTrue(array_equal(sm1t.toArray(), expected)) + + def test_dense_matrix_is_transposed(self): + mat1 = DenseMatrix(3, 2, [0, 4, 1, 6, 3, 9], isTransposed=True) + mat = DenseMatrix(3, 2, [0, 1, 3, 4, 6, 9]) + self.assertEqual(mat1, mat) + + expected = [[0, 4], [1, 6], [3, 9]] + for i in range(3): + for j in range(2): + self.assertEqual(mat1[i, j], expected[i][j]) + self.assertTrue(array_equal(mat1.toArray(), expected)) + + sm = mat1.toSparse() + self.assertTrue(array_equal(sm.rowIndices, [1, 2, 0, 1, 2])) + self.assertTrue(array_equal(sm.colPtrs, [0, 2, 5])) + self.assertTrue(array_equal(sm.values, [1, 3, 4, 6, 9])) + + def test_parse_vector(self): + a = DenseVector([]) + self.assertEqual(str(a), '[]') + self.assertEqual(Vectors.parse(str(a)), a) + a = DenseVector([3, 4, 6, 7]) + self.assertEqual(str(a), '[3.0,4.0,6.0,7.0]') + self.assertEqual(Vectors.parse(str(a)), a) + a = SparseVector(4, [], []) + self.assertEqual(str(a), '(4,[],[])') + self.assertEqual(SparseVector.parse(str(a)), a) + a = SparseVector(4, [0, 2], [3, 4]) + self.assertEqual(str(a), '(4,[0,2],[3.0,4.0])') + self.assertEqual(Vectors.parse(str(a)), a) + a = 
SparseVector(10, [0, 1], [4, 5]) + self.assertEqual(SparseVector.parse(' (10, [0,1 ],[ 4.0,5.0] )'), a) + + def test_norms(self): + a = DenseVector([0, 2, 3, -1]) + self.assertAlmostEqual(a.norm(2), 3.742, 3) + self.assertTrue(a.norm(1), 6) + self.assertTrue(a.norm(inf), 3) + a = SparseVector(4, [0, 2], [3, -4]) + self.assertAlmostEqual(a.norm(2), 5) + self.assertTrue(a.norm(1), 7) + self.assertTrue(a.norm(inf), 4) + + tmp = SparseVector(4, [0, 2], [3, 0]) + self.assertEqual(tmp.numNonzeros(), 1) + + def test_ml_mllib_vector_conversion(self): + # to ml + # dense + mllibDV = Vectors.dense([1, 2, 3]) + mlDV1 = newlinalg.Vectors.dense([1, 2, 3]) + mlDV2 = mllibDV.asML() + self.assertEqual(mlDV2, mlDV1) + # sparse + mllibSV = Vectors.sparse(4, {1: 1.0, 3: 5.5}) + mlSV1 = newlinalg.Vectors.sparse(4, {1: 1.0, 3: 5.5}) + mlSV2 = mllibSV.asML() + self.assertEqual(mlSV2, mlSV1) + # from ml + # dense + mllibDV1 = Vectors.dense([1, 2, 3]) + mlDV = newlinalg.Vectors.dense([1, 2, 3]) + mllibDV2 = Vectors.fromML(mlDV) + self.assertEqual(mllibDV1, mllibDV2) + # sparse + mllibSV1 = Vectors.sparse(4, {1: 1.0, 3: 5.5}) + mlSV = newlinalg.Vectors.sparse(4, {1: 1.0, 3: 5.5}) + mllibSV2 = Vectors.fromML(mlSV) + self.assertEqual(mllibSV1, mllibSV2) + + def test_ml_mllib_matrix_conversion(self): + # to ml + # dense + mllibDM = Matrices.dense(2, 2, [0, 1, 2, 3]) + mlDM1 = newlinalg.Matrices.dense(2, 2, [0, 1, 2, 3]) + mlDM2 = mllibDM.asML() + self.assertEqual(mlDM2, mlDM1) + # transposed + mllibDMt = DenseMatrix(2, 2, [0, 1, 2, 3], True) + mlDMt1 = newlinalg.DenseMatrix(2, 2, [0, 1, 2, 3], True) + mlDMt2 = mllibDMt.asML() + self.assertEqual(mlDMt2, mlDMt1) + # sparse + mllibSM = Matrices.sparse(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4]) + mlSM1 = newlinalg.Matrices.sparse(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4]) + mlSM2 = mllibSM.asML() + self.assertEqual(mlSM2, mlSM1) + # transposed + mllibSMt = SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4], True) + mlSMt1 = newlinalg.SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4], True) + mlSMt2 = mllibSMt.asML() + self.assertEqual(mlSMt2, mlSMt1) + # from ml + # dense + mllibDM1 = Matrices.dense(2, 2, [1, 2, 3, 4]) + mlDM = newlinalg.Matrices.dense(2, 2, [1, 2, 3, 4]) + mllibDM2 = Matrices.fromML(mlDM) + self.assertEqual(mllibDM1, mllibDM2) + # transposed + mllibDMt1 = DenseMatrix(2, 2, [1, 2, 3, 4], True) + mlDMt = newlinalg.DenseMatrix(2, 2, [1, 2, 3, 4], True) + mllibDMt2 = Matrices.fromML(mlDMt) + self.assertEqual(mllibDMt1, mllibDMt2) + # sparse + mllibSM1 = Matrices.sparse(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4]) + mlSM = newlinalg.Matrices.sparse(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4]) + mllibSM2 = Matrices.fromML(mlSM) + self.assertEqual(mllibSM1, mllibSM2) + # transposed + mllibSMt1 = SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4], True) + mlSMt = newlinalg.SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4], True) + mllibSMt2 = Matrices.fromML(mlSMt) + self.assertEqual(mllibSMt1, mllibSMt2) + + +class VectorUDTTests(MLlibTestCase): + + dv0 = DenseVector([]) + dv1 = DenseVector([1.0, 2.0]) + sv0 = SparseVector(2, [], []) + sv1 = SparseVector(2, [1], [2.0]) + udt = VectorUDT() + + def test_json_schema(self): + self.assertEqual(VectorUDT.fromJson(self.udt.jsonValue()), self.udt) + + def test_serialization(self): + for v in [self.dv0, self.dv1, self.sv0, self.sv1]: + self.assertEqual(v, self.udt.deserialize(self.udt.serialize(v))) + + def test_infer_schema(self): + rdd = self.sc.parallelize([LabeledPoint(1.0, self.dv1), LabeledPoint(0.0, self.sv1)]) + df = rdd.toDF() 
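
# Editor's aside (illustrative sketch, not part of the patch): toDF() above
# is expected to infer VectorUDT() for the "features" column, which is what
# the assertion that follows verifies. Building an equivalent DataFrame with
# an explicit schema, assuming an active SparkSession, would look roughly
# like this:
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, DoubleType
from pyspark.mllib.linalg import Vectors, VectorUDT

spark = SparkSession.builder.getOrCreate()
explicit_schema = StructType([
    StructField("label", DoubleType(), False),
    StructField("features", VectorUDT(), True),
])
explicit_df = spark.createDataFrame(
    [(1.0, Vectors.dense([1.0, 2.0])), (0.0, Vectors.sparse(2, [1], [2.0]))],
    explicit_schema)
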
+ schema = df.schema + field = [f for f in schema.fields if f.name == "features"][0] + self.assertEqual(field.dataType, self.udt) + vectors = df.rdd.map(lambda p: p.features).collect() + self.assertEqual(len(vectors), 2) + for v in vectors: + if isinstance(v, SparseVector): + self.assertEqual(v, self.sv1) + elif isinstance(v, DenseVector): + self.assertEqual(v, self.dv1) + else: + raise TypeError("expecting a vector but got %r of type %r" % (v, type(v))) + + +class MatrixUDTTests(MLlibTestCase): + + dm1 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10]) + dm2 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10], isTransposed=True) + sm1 = SparseMatrix(1, 1, [0, 1], [0], [2.0]) + sm2 = SparseMatrix(2, 1, [0, 0, 1], [0], [5.0], isTransposed=True) + udt = MatrixUDT() + + def test_json_schema(self): + self.assertEqual(MatrixUDT.fromJson(self.udt.jsonValue()), self.udt) + + def test_serialization(self): + for m in [self.dm1, self.dm2, self.sm1, self.sm2]: + self.assertEqual(m, self.udt.deserialize(self.udt.serialize(m))) + + def test_infer_schema(self): + rdd = self.sc.parallelize([("dense", self.dm1), ("sparse", self.sm1)]) + df = rdd.toDF() + schema = df.schema + self.assertTrue(schema.fields[1].dataType, self.udt) + matrices = df.rdd.map(lambda x: x._2).collect() + self.assertEqual(len(matrices), 2) + for m in matrices: + if isinstance(m, DenseMatrix): + self.assertTrue(m, self.dm1) + elif isinstance(m, SparseMatrix): + self.assertTrue(m, self.sm1) + else: + raise ValueError("Expected a matrix but got type %r" % type(m)) + + +@unittest.skipIf(not _have_scipy, "SciPy not installed") +class SciPyTests(MLlibTestCase): + + """ + Test both vector operations and MLlib algorithms with SciPy sparse matrices, + if SciPy is available. + """ + + def test_serialize(self): + from scipy.sparse import lil_matrix + lil = lil_matrix((4, 1)) + lil[1, 0] = 1 + lil[3, 0] = 2 + sv = SparseVector(4, {1: 1, 3: 2}) + self.assertEqual(sv, _convert_to_vector(lil)) + self.assertEqual(sv, _convert_to_vector(lil.tocsc())) + self.assertEqual(sv, _convert_to_vector(lil.tocoo())) + self.assertEqual(sv, _convert_to_vector(lil.tocsr())) + self.assertEqual(sv, _convert_to_vector(lil.todok())) + + def serialize(l): + return ser.loads(ser.dumps(_convert_to_vector(l))) + self.assertEqual(sv, serialize(lil)) + self.assertEqual(sv, serialize(lil.tocsc())) + self.assertEqual(sv, serialize(lil.tocsr())) + self.assertEqual(sv, serialize(lil.todok())) + + def test_convert_to_vector(self): + from scipy.sparse import csc_matrix + # Create a CSC matrix with non-sorted indices + indptr = array([0, 2]) + indices = array([3, 1]) + data = array([2.0, 1.0]) + csc = csc_matrix((data, indices, indptr)) + self.assertFalse(csc.has_sorted_indices) + sv = SparseVector(4, {1: 1, 3: 2}) + self.assertEqual(sv, _convert_to_vector(csc)) + + def test_dot(self): + from scipy.sparse import lil_matrix + lil = lil_matrix((4, 1)) + lil[1, 0] = 1 + lil[3, 0] = 2 + dv = DenseVector(array([1., 2., 3., 4.])) + self.assertEqual(10.0, dv.dot(lil)) + + def test_squared_distance(self): + from scipy.sparse import lil_matrix + lil = lil_matrix((4, 1)) + lil[1, 0] = 3 + lil[3, 0] = 2 + dv = DenseVector(array([1., 2., 3., 4.])) + sv = SparseVector(4, {0: 1, 1: 2, 2: 3, 3: 4}) + self.assertEqual(15.0, dv.squared_distance(lil)) + self.assertEqual(15.0, sv.squared_distance(lil)) + + def scipy_matrix(self, size, values): + """Create a column SciPy matrix from a dictionary of values""" + from scipy.sparse import lil_matrix + lil = lil_matrix((size, 1)) + for key, value in values.items(): + lil[key, 
0] = value + return lil + + def test_clustering(self): + from pyspark.mllib.clustering import KMeans + data = [ + self.scipy_matrix(3, {1: 1.0}), + self.scipy_matrix(3, {1: 1.1}), + self.scipy_matrix(3, {2: 1.0}), + self.scipy_matrix(3, {2: 1.1}) + ] + clusters = KMeans.train(self.sc.parallelize(data), 2, initializationMode="k-means||") + self.assertEqual(clusters.predict(data[0]), clusters.predict(data[1])) + self.assertEqual(clusters.predict(data[2]), clusters.predict(data[3])) + + def test_classification(self): + from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes + from pyspark.mllib.tree import DecisionTree + data = [ + LabeledPoint(0.0, self.scipy_matrix(2, {0: 1.0})), + LabeledPoint(1.0, self.scipy_matrix(2, {1: 1.0})), + LabeledPoint(0.0, self.scipy_matrix(2, {0: 2.0})), + LabeledPoint(1.0, self.scipy_matrix(2, {1: 2.0})) + ] + rdd = self.sc.parallelize(data) + features = [p.features for p in data] + + lr_model = LogisticRegressionWithSGD.train(rdd) + self.assertTrue(lr_model.predict(features[0]) <= 0) + self.assertTrue(lr_model.predict(features[1]) > 0) + self.assertTrue(lr_model.predict(features[2]) <= 0) + self.assertTrue(lr_model.predict(features[3]) > 0) + + svm_model = SVMWithSGD.train(rdd) + self.assertTrue(svm_model.predict(features[0]) <= 0) + self.assertTrue(svm_model.predict(features[1]) > 0) + self.assertTrue(svm_model.predict(features[2]) <= 0) + self.assertTrue(svm_model.predict(features[3]) > 0) + + nb_model = NaiveBayes.train(rdd) + self.assertTrue(nb_model.predict(features[0]) <= 0) + self.assertTrue(nb_model.predict(features[1]) > 0) + self.assertTrue(nb_model.predict(features[2]) <= 0) + self.assertTrue(nb_model.predict(features[3]) > 0) + + categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories + dt_model = DecisionTree.trainClassifier(rdd, numClasses=2, + categoricalFeaturesInfo=categoricalFeaturesInfo) + self.assertTrue(dt_model.predict(features[0]) <= 0) + self.assertTrue(dt_model.predict(features[1]) > 0) + self.assertTrue(dt_model.predict(features[2]) <= 0) + self.assertTrue(dt_model.predict(features[3]) > 0) + + def test_regression(self): + from pyspark.mllib.regression import LinearRegressionWithSGD, LassoWithSGD, \ + RidgeRegressionWithSGD + from pyspark.mllib.tree import DecisionTree + data = [ + LabeledPoint(-1.0, self.scipy_matrix(2, {1: -1.0})), + LabeledPoint(1.0, self.scipy_matrix(2, {1: 1.0})), + LabeledPoint(-1.0, self.scipy_matrix(2, {1: -2.0})), + LabeledPoint(1.0, self.scipy_matrix(2, {1: 2.0})) + ] + rdd = self.sc.parallelize(data) + features = [p.features for p in data] + + lr_model = LinearRegressionWithSGD.train(rdd) + self.assertTrue(lr_model.predict(features[0]) <= 0) + self.assertTrue(lr_model.predict(features[1]) > 0) + self.assertTrue(lr_model.predict(features[2]) <= 0) + self.assertTrue(lr_model.predict(features[3]) > 0) + + lasso_model = LassoWithSGD.train(rdd) + self.assertTrue(lasso_model.predict(features[0]) <= 0) + self.assertTrue(lasso_model.predict(features[1]) > 0) + self.assertTrue(lasso_model.predict(features[2]) <= 0) + self.assertTrue(lasso_model.predict(features[3]) > 0) + + rr_model = RidgeRegressionWithSGD.train(rdd) + self.assertTrue(rr_model.predict(features[0]) <= 0) + self.assertTrue(rr_model.predict(features[1]) > 0) + self.assertTrue(rr_model.predict(features[2]) <= 0) + self.assertTrue(rr_model.predict(features[3]) > 0) + + categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories + dt_model = DecisionTree.trainRegressor(rdd, 
categoricalFeaturesInfo=categoricalFeaturesInfo) + self.assertTrue(dt_model.predict(features[0]) <= 0) + self.assertTrue(dt_model.predict(features[1]) > 0) + self.assertTrue(dt_model.predict(features[2]) <= 0) + self.assertTrue(dt_model.predict(features[3]) > 0) + + +if __name__ == "__main__": + from pyspark.mllib.tests.test_linalg import * + if not _have_scipy: + print("NOTE: Skipping SciPy tests as it does not seem to be installed") + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) + if not _have_scipy: + print("NOTE: SciPy tests were skipped as it does not seem to be installed") diff --git a/python/pyspark/mllib/tests/test_stat.py b/python/pyspark/mllib/tests/test_stat.py new file mode 100644 index 0000000000000..f23ae291d317a --- /dev/null +++ b/python/pyspark/mllib/tests/test_stat.py @@ -0,0 +1,188 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import array as pyarray +import unittest + +from numpy import array + +from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector, \ + DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT +from pyspark.mllib.random import RandomRDDs +from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.stat import Statistics +from pyspark.sql.utils import IllegalArgumentException +from pyspark.testing.mllibutils import MLlibTestCase + + +class StatTests(MLlibTestCase): + # SPARK-4023 + def test_col_with_different_rdds(self): + # numpy + data = RandomRDDs.normalVectorRDD(self.sc, 1000, 10, 10) + summary = Statistics.colStats(data) + self.assertEqual(1000, summary.count()) + # array + data = self.sc.parallelize([range(10)] * 10) + summary = Statistics.colStats(data) + self.assertEqual(10, summary.count()) + # array + data = self.sc.parallelize([pyarray.array("d", range(10))] * 10) + summary = Statistics.colStats(data) + self.assertEqual(10, summary.count()) + + def test_col_norms(self): + data = RandomRDDs.normalVectorRDD(self.sc, 1000, 10, 10) + summary = Statistics.colStats(data) + self.assertEqual(10, len(summary.normL1())) + self.assertEqual(10, len(summary.normL2())) + + data2 = self.sc.parallelize(range(10)).map(lambda x: Vectors.dense(x)) + summary2 = Statistics.colStats(data2) + self.assertEqual(array([45.0]), summary2.normL1()) + import math + expectedNormL2 = math.sqrt(sum(map(lambda x: x*x, range(10)))) + self.assertTrue(math.fabs(summary2.normL2()[0] - expectedNormL2) < 1e-14) + + +class ChiSqTestTests(MLlibTestCase): + def test_goodness_of_fit(self): + from numpy import inf + + observed = Vectors.dense([4, 6, 5]) + pearson = Statistics.chiSqTest(observed) + + # Validated against the R command `chisq.test(c(4, 6, 5), 
p=c(1/3, 1/3, 1/3))` + self.assertEqual(pearson.statistic, 0.4) + self.assertEqual(pearson.degreesOfFreedom, 2) + self.assertAlmostEqual(pearson.pValue, 0.8187, 4) + + # Different expected and observed sum + observed1 = Vectors.dense([21, 38, 43, 80]) + expected1 = Vectors.dense([3, 5, 7, 20]) + pearson1 = Statistics.chiSqTest(observed1, expected1) + + # Results validated against the R command + # `chisq.test(c(21, 38, 43, 80), p=c(3/35, 1/7, 1/5, 4/7))` + self.assertAlmostEqual(pearson1.statistic, 14.1429, 4) + self.assertEqual(pearson1.degreesOfFreedom, 3) + self.assertAlmostEqual(pearson1.pValue, 0.002717, 4) + + # Vectors with different sizes + observed3 = Vectors.dense([1.0, 2.0, 3.0]) + expected3 = Vectors.dense([1.0, 2.0, 3.0, 4.0]) + self.assertRaises(ValueError, Statistics.chiSqTest, observed3, expected3) + + # Negative counts in observed + neg_obs = Vectors.dense([1.0, 2.0, 3.0, -4.0]) + self.assertRaises(IllegalArgumentException, Statistics.chiSqTest, neg_obs, expected1) + + # Count = 0.0 in expected but not observed + zero_expected = Vectors.dense([1.0, 0.0, 3.0]) + pearson_inf = Statistics.chiSqTest(observed, zero_expected) + self.assertEqual(pearson_inf.statistic, inf) + self.assertEqual(pearson_inf.degreesOfFreedom, 2) + self.assertEqual(pearson_inf.pValue, 0.0) + + # 0.0 in expected and observed simultaneously + zero_observed = Vectors.dense([2.0, 0.0, 1.0]) + self.assertRaises( + IllegalArgumentException, Statistics.chiSqTest, zero_observed, zero_expected) + + def test_matrix_independence(self): + data = [40.0, 24.0, 29.0, 56.0, 32.0, 42.0, 31.0, 10.0, 0.0, 30.0, 15.0, 12.0] + chi = Statistics.chiSqTest(Matrices.dense(3, 4, data)) + + # Results validated against R command + # `chisq.test(rbind(c(40, 56, 31, 30),c(24, 32, 10, 15), c(29, 42, 0, 12)))` + self.assertAlmostEqual(chi.statistic, 21.9958, 4) + self.assertEqual(chi.degreesOfFreedom, 6) + self.assertAlmostEqual(chi.pValue, 0.001213, 4) + + # Negative counts + neg_counts = Matrices.dense(2, 2, [4.0, 5.0, 3.0, -3.0]) + self.assertRaises(IllegalArgumentException, Statistics.chiSqTest, neg_counts) + + # Row sum = 0.0 + row_zero = Matrices.dense(2, 2, [0.0, 1.0, 0.0, 2.0]) + self.assertRaises(IllegalArgumentException, Statistics.chiSqTest, row_zero) + + # Column sum = 0.0 + col_zero = Matrices.dense(2, 2, [0.0, 0.0, 2.0, 2.0]) + self.assertRaises(IllegalArgumentException, Statistics.chiSqTest, col_zero) + + def test_chi_sq_pearson(self): + data = [ + LabeledPoint(0.0, Vectors.dense([0.5, 10.0])), + LabeledPoint(0.0, Vectors.dense([1.5, 20.0])), + LabeledPoint(1.0, Vectors.dense([1.5, 30.0])), + LabeledPoint(0.0, Vectors.dense([3.5, 30.0])), + LabeledPoint(0.0, Vectors.dense([3.5, 40.0])), + LabeledPoint(1.0, Vectors.dense([3.5, 40.0])) + ] + + for numParts in [2, 4, 6, 8]: + chi = Statistics.chiSqTest(self.sc.parallelize(data, numParts)) + feature1 = chi[0] + self.assertEqual(feature1.statistic, 0.75) + self.assertEqual(feature1.degreesOfFreedom, 2) + self.assertAlmostEqual(feature1.pValue, 0.6873, 4) + + feature2 = chi[1] + self.assertEqual(feature2.statistic, 1.5) + self.assertEqual(feature2.degreesOfFreedom, 3) + self.assertAlmostEqual(feature2.pValue, 0.6823, 4) + + def test_right_number_of_results(self): + num_cols = 1001 + sparse_data = [ + LabeledPoint(0.0, Vectors.sparse(num_cols, [(100, 2.0)])), + LabeledPoint(0.1, Vectors.sparse(num_cols, [(200, 1.0)])) + ] + chi = Statistics.chiSqTest(self.sc.parallelize(sparse_data)) + self.assertEqual(len(chi), num_cols) + self.assertIsNotNone(chi[1000]) + + +class 
KolmogorovSmirnovTest(MLlibTestCase): + + def test_R_implementation_equivalence(self): + data = self.sc.parallelize([ + 1.1626852897838, -0.585924465893051, 1.78546500331661, -1.33259371048501, + -0.446566766553219, 0.569606122374976, -2.88971761441412, -0.869018343326555, + -0.461702683149641, -0.555540910137444, -0.0201353678515895, -0.150382224136063, + -0.628126755843964, 1.32322085193283, -1.52135057001199, -0.437427868856691, + 0.970577579543399, 0.0282226444247749, -0.0857821886527593, 0.389214404984942 + ]) + model = Statistics.kolmogorovSmirnovTest(data, "norm") + self.assertAlmostEqual(model.statistic, 0.189, 3) + self.assertAlmostEqual(model.pValue, 0.422, 3) + + model = Statistics.kolmogorovSmirnovTest(data, "norm", 0, 1) + self.assertAlmostEqual(model.statistic, 0.189, 3) + self.assertAlmostEqual(model.pValue, 0.422, 3) + + +if __name__ == "__main__": + from pyspark.mllib.tests.test_stat import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/mllib/tests/test_streaming_algorithms.py b/python/pyspark/mllib/tests/test_streaming_algorithms.py new file mode 100644 index 0000000000000..4bc8904acd31c --- /dev/null +++ b/python/pyspark/mllib/tests/test_streaming_algorithms.py @@ -0,0 +1,514 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from time import time, sleep +import unittest + +from numpy import array, random, exp, dot, all, mean, abs +from numpy import sum as array_sum + +from pyspark import SparkContext +from pyspark.mllib.clustering import StreamingKMeans, StreamingKMeansModel +from pyspark.mllib.classification import StreamingLogisticRegressionWithSGD +from pyspark.mllib.linalg import Vectors +from pyspark.mllib.regression import LabeledPoint, StreamingLinearRegressionWithSGD +from pyspark.mllib.util import LinearDataGenerator +from pyspark.streaming import StreamingContext + + +class MLLibStreamingTestCase(unittest.TestCase): + def setUp(self): + self.sc = SparkContext('local[4]', "MLlib tests") + self.ssc = StreamingContext(self.sc, 1.0) + + def tearDown(self): + self.ssc.stop(False) + self.sc.stop() + + @staticmethod + def _eventually(condition, timeout=30.0, catch_assertions=False): + """ + Wait a given amount of time for a condition to pass, else fail with an error. + This is a helper utility for streaming ML tests. + :param condition: Function that checks for termination conditions. + condition() can return: + - True: Conditions met. Return without error. + - other value: Conditions not met yet. Continue. Upon timeout, + include last such value in error message. 
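+ (Returning a short descriptive string is therefore a convenient way to
+ surface progress in that timeout message; test_training_and_prediction
+ below relies on this.)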
+ Note that this method may be called at any time during + streaming execution (e.g., even before any results + have been created). + :param timeout: Number of seconds to wait. Default 30 seconds. + :param catch_assertions: If False (default), do not catch AssertionErrors. + If True, catch AssertionErrors; continue, but save + error to throw upon timeout. + """ + start_time = time() + lastValue = None + while time() - start_time < timeout: + if catch_assertions: + try: + lastValue = condition() + except AssertionError as e: + lastValue = e + else: + lastValue = condition() + if lastValue is True: + return + sleep(0.01) + if isinstance(lastValue, AssertionError): + raise lastValue + else: + raise AssertionError( + "Test failed due to timeout after %g sec, with last condition returning: %s" + % (timeout, lastValue)) + + +class StreamingKMeansTest(MLLibStreamingTestCase): + def test_model_params(self): + """Test that the model params are set correctly""" + stkm = StreamingKMeans() + stkm.setK(5).setDecayFactor(0.0) + self.assertEqual(stkm._k, 5) + self.assertEqual(stkm._decayFactor, 0.0) + + # Model not set yet. + self.assertIsNone(stkm.latestModel()) + self.assertRaises(ValueError, stkm.trainOn, [0.0, 1.0]) + + stkm.setInitialCenters( + centers=[[0.0, 0.0], [1.0, 1.0]], weights=[1.0, 1.0]) + self.assertEqual( + stkm.latestModel().centers, [[0.0, 0.0], [1.0, 1.0]]) + self.assertEqual(stkm.latestModel().clusterWeights, [1.0, 1.0]) + + def test_accuracy_for_single_center(self): + """Test that parameters obtained are correct for a single center.""" + centers, batches = self.streamingKMeansDataGenerator( + batches=5, numPoints=5, k=1, d=5, r=0.1, seed=0) + stkm = StreamingKMeans(1) + stkm.setInitialCenters([[0., 0., 0., 0., 0.]], [0.]) + input_stream = self.ssc.queueStream( + [self.sc.parallelize(batch, 1) for batch in batches]) + stkm.trainOn(input_stream) + + self.ssc.start() + + def condition(): + self.assertEqual(stkm.latestModel().clusterWeights, [25.0]) + return True + self._eventually(condition, catch_assertions=True) + + realCenters = array_sum(array(centers), axis=0) + for i in range(5): + modelCenters = stkm.latestModel().centers[0][i] + self.assertAlmostEqual(centers[0][i], modelCenters, 1) + self.assertAlmostEqual(realCenters[i], modelCenters, 1) + + def streamingKMeansDataGenerator(self, batches, numPoints, + k, d, r, seed, centers=None): + rng = random.RandomState(seed) + + # Generate centers. + centers = [rng.randn(d) for i in range(k)] + + return centers, [[Vectors.dense(centers[j % k] + r * rng.randn(d)) + for j in range(numPoints)] + for i in range(batches)] + + def test_trainOn_model(self): + """Test the model on toy data with four clusters.""" + stkm = StreamingKMeans() + initCenters = [[1.0, 1.0], [-1.0, 1.0], [-1.0, -1.0], [1.0, -1.0]] + stkm.setInitialCenters( + centers=initCenters, weights=[1.0, 1.0, 1.0, 1.0]) + + # Create a toy dataset by setting a tiny offset for each point. + offsets = [[0, 0.1], [0, -0.1], [0.1, 0], [-0.1, 0]] + batches = [] + for offset in offsets: + batches.append([[offset[0] + center[0], offset[1] + center[1]] + for center in initCenters]) + + batches = [self.sc.parallelize(batch, 1) for batch in batches] + input_stream = self.ssc.queueStream(batches) + stkm.trainOn(input_stream) + self.ssc.start() + + # Give enough time to train the model. 
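+ # _eventually() polls the condition below (default timeout 30 seconds, with
+ # catch_assertions=True re-raising the last AssertionError on timeout), so the
+ # assertions are retried until all four offset batches have been consumed.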
+ def condition(): + finalModel = stkm.latestModel() + self.assertTrue(all(finalModel.centers == array(initCenters))) + self.assertEqual(finalModel.clusterWeights, [5.0, 5.0, 5.0, 5.0]) + return True + self._eventually(condition, catch_assertions=True) + + def test_predictOn_model(self): + """Test that the model predicts correctly on toy data.""" + stkm = StreamingKMeans() + stkm._model = StreamingKMeansModel( + clusterCenters=[[1.0, 1.0], [-1.0, 1.0], [-1.0, -1.0], [1.0, -1.0]], + clusterWeights=[1.0, 1.0, 1.0, 1.0]) + + predict_data = [[[1.5, 1.5]], [[-1.5, 1.5]], [[-1.5, -1.5]], [[1.5, -1.5]]] + predict_data = [self.sc.parallelize(batch, 1) for batch in predict_data] + predict_stream = self.ssc.queueStream(predict_data) + predict_val = stkm.predictOn(predict_stream) + + result = [] + + def update(rdd): + rdd_collect = rdd.collect() + if rdd_collect: + result.append(rdd_collect) + + predict_val.foreachRDD(update) + self.ssc.start() + + def condition(): + self.assertEqual(result, [[0], [1], [2], [3]]) + return True + + self._eventually(condition, catch_assertions=True) + + @unittest.skip("SPARK-10086: Flaky StreamingKMeans test in PySpark") + def test_trainOn_predictOn(self): + """Test that prediction happens on the updated model.""" + stkm = StreamingKMeans(decayFactor=0.0, k=2) + stkm.setInitialCenters([[0.0], [1.0]], [1.0, 1.0]) + + # Since decay factor is set to zero, once the first batch + # is passed the clusterCenters are updated to [-0.5, 0.7] + # which causes 0.2 & 0.3 to be classified as 1, even though the + # classification based in the initial model would have been 0 + # proving that the model is updated. + batches = [[[-0.5], [0.6], [0.8]], [[0.2], [-0.1], [0.3]]] + batches = [self.sc.parallelize(batch) for batch in batches] + input_stream = self.ssc.queueStream(batches) + predict_results = [] + + def collect(rdd): + rdd_collect = rdd.collect() + if rdd_collect: + predict_results.append(rdd_collect) + + stkm.trainOn(input_stream) + predict_stream = stkm.predictOn(input_stream) + predict_stream.foreachRDD(collect) + + self.ssc.start() + + def condition(): + self.assertEqual(predict_results, [[0, 1, 1], [1, 0, 1]]) + return True + + self._eventually(condition, catch_assertions=True) + + +class StreamingLogisticRegressionWithSGDTests(MLLibStreamingTestCase): + + @staticmethod + def generateLogisticInput(offset, scale, nPoints, seed): + """ + Generate 1 / (1 + exp(-x * scale + offset)) + + where, + x is randomnly distributed and the threshold + and labels for each sample in x is obtained from a random uniform + distribution. + """ + rng = random.RandomState(seed) + x = rng.randn(nPoints) + sigmoid = 1. / (1 + exp(-(dot(x, scale) + offset))) + y_p = rng.rand(nPoints) + cut_off = y_p <= sigmoid + y_p[cut_off] = 1.0 + y_p[~cut_off] = 0.0 + return [ + LabeledPoint(y_p[i], Vectors.dense([x[i]])) + for i in range(nPoints)] + + def test_parameter_accuracy(self): + """ + Test that the final value of weights is close to the desired value. 
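+ (generateLogisticInput below uses scale=1.5; the condition checks the
+ relative error of the learned weight against that value.)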
+ """ + input_batches = [ + self.sc.parallelize(self.generateLogisticInput(0, 1.5, 100, 42 + i)) + for i in range(20)] + input_stream = self.ssc.queueStream(input_batches) + + slr = StreamingLogisticRegressionWithSGD( + stepSize=0.2, numIterations=25) + slr.setInitialWeights([0.0]) + slr.trainOn(input_stream) + + self.ssc.start() + + def condition(): + rel = (1.5 - slr.latestModel().weights.array[0]) / 1.5 + self.assertAlmostEqual(rel, 0.1, 1) + return True + + self._eventually(condition, catch_assertions=True) + + def test_convergence(self): + """ + Test that weights converge to the required value on toy data. + """ + input_batches = [ + self.sc.parallelize(self.generateLogisticInput(0, 1.5, 100, 42 + i)) + for i in range(20)] + input_stream = self.ssc.queueStream(input_batches) + models = [] + + slr = StreamingLogisticRegressionWithSGD( + stepSize=0.2, numIterations=25) + slr.setInitialWeights([0.0]) + slr.trainOn(input_stream) + input_stream.foreachRDD( + lambda x: models.append(slr.latestModel().weights[0])) + + self.ssc.start() + + def condition(): + self.assertEqual(len(models), len(input_batches)) + return True + + # We want all batches to finish for this test. + self._eventually(condition, 60.0, catch_assertions=True) + + t_models = array(models) + diff = t_models[1:] - t_models[:-1] + # Test that weights improve with a small tolerance + self.assertTrue(all(diff >= -0.1)) + self.assertTrue(array_sum(diff > 0) > 1) + + @staticmethod + def calculate_accuracy_error(true, predicted): + return sum(abs(array(true) - array(predicted))) / len(true) + + def test_predictions(self): + """Test predicted values on a toy model.""" + input_batches = [] + for i in range(20): + batch = self.sc.parallelize( + self.generateLogisticInput(0, 1.5, 100, 42 + i)) + input_batches.append(batch.map(lambda x: (x.label, x.features))) + input_stream = self.ssc.queueStream(input_batches) + + slr = StreamingLogisticRegressionWithSGD( + stepSize=0.2, numIterations=25) + slr.setInitialWeights([1.5]) + predict_stream = slr.predictOnValues(input_stream) + true_predicted = [] + predict_stream.foreachRDD(lambda x: true_predicted.append(x.collect())) + self.ssc.start() + + def condition(): + self.assertEqual(len(true_predicted), len(input_batches)) + return True + + self._eventually(condition, catch_assertions=True) + + # Test that the accuracy error is no more than 0.4 on each batch. + for batch in true_predicted: + true, predicted = zip(*batch) + self.assertTrue( + self.calculate_accuracy_error(true, predicted) < 0.4) + + def test_training_and_prediction(self): + """Test that the model improves on toy data with no. 
of batches""" + input_batches = [ + self.sc.parallelize(self.generateLogisticInput(0, 1.5, 100, 42 + i)) + for i in range(20)] + predict_batches = [ + b.map(lambda lp: (lp.label, lp.features)) for b in input_batches] + + slr = StreamingLogisticRegressionWithSGD( + stepSize=0.01, numIterations=25) + slr.setInitialWeights([-0.1]) + errors = [] + + def collect_errors(rdd): + true, predicted = zip(*rdd.collect()) + errors.append(self.calculate_accuracy_error(true, predicted)) + + true_predicted = [] + input_stream = self.ssc.queueStream(input_batches) + predict_stream = self.ssc.queueStream(predict_batches) + slr.trainOn(input_stream) + ps = slr.predictOnValues(predict_stream) + ps.foreachRDD(lambda x: collect_errors(x)) + + self.ssc.start() + + def condition(): + # Test that the improvement in error is > 0.3 + if len(errors) == len(predict_batches): + self.assertGreater(errors[1] - errors[-1], 0.3) + if len(errors) >= 3 and errors[1] - errors[-1] > 0.3: + return True + return "Latest errors: " + ", ".join(map(lambda x: str(x), errors)) + + self._eventually(condition) + + +class StreamingLinearRegressionWithTests(MLLibStreamingTestCase): + + def assertArrayAlmostEqual(self, array1, array2, dec): + for i, j in array1, array2: + self.assertAlmostEqual(i, j, dec) + + def test_parameter_accuracy(self): + """Test that coefs are predicted accurately by fitting on toy data.""" + + # Test that fitting (10*X1 + 10*X2), (X1, X2) gives coefficients + # (10, 10) + slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25) + slr.setInitialWeights([0.0, 0.0]) + xMean = [0.0, 0.0] + xVariance = [1.0 / 3.0, 1.0 / 3.0] + + # Create ten batches with 100 sample points in each. + batches = [] + for i in range(10): + batch = LinearDataGenerator.generateLinearInput( + 0.0, [10.0, 10.0], xMean, xVariance, 100, 42 + i, 0.1) + batches.append(self.sc.parallelize(batch)) + + input_stream = self.ssc.queueStream(batches) + slr.trainOn(input_stream) + self.ssc.start() + + def condition(): + self.assertArrayAlmostEqual( + slr.latestModel().weights.array, [10., 10.], 1) + self.assertAlmostEqual(slr.latestModel().intercept, 0.0, 1) + return True + + self._eventually(condition, catch_assertions=True) + + def test_parameter_convergence(self): + """Test that the model parameters improve with streaming data.""" + slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25) + slr.setInitialWeights([0.0]) + + # Create ten batches with 100 sample points in each. + batches = [] + for i in range(10): + batch = LinearDataGenerator.generateLinearInput( + 0.0, [10.0], [0.0], [1.0 / 3.0], 100, 42 + i, 0.1) + batches.append(self.sc.parallelize(batch)) + + model_weights = [] + input_stream = self.ssc.queueStream(batches) + input_stream.foreachRDD( + lambda x: model_weights.append(slr.latestModel().weights[0])) + slr.trainOn(input_stream) + self.ssc.start() + + def condition(): + self.assertEqual(len(model_weights), len(batches)) + return True + + # We want all batches to finish for this test. + self._eventually(condition, catch_assertions=True) + + w = array(model_weights) + diff = w[1:] - w[:-1] + self.assertTrue(all(diff >= -0.1)) + + def test_prediction(self): + """Test prediction on a model with weights already set.""" + # Create a model with initial Weights equal to coefs + slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25) + slr.setInitialWeights([10.0, 10.0]) + + # Create ten batches with 100 sample points in each. 
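+ # Each batch comes from LinearDataGenerator.generateLinearInput using the same
+ # weights the model was initialised with ([10.0, 10.0]) and noise eps=0.1, so
+ # the mean absolute prediction error checked at the end is expected to stay
+ # below 0.1.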
+ batches = [] + for i in range(10): + batch = LinearDataGenerator.generateLinearInput( + 0.0, [10.0, 10.0], [0.0, 0.0], [1.0 / 3.0, 1.0 / 3.0], + 100, 42 + i, 0.1) + batches.append( + self.sc.parallelize(batch).map(lambda lp: (lp.label, lp.features))) + + input_stream = self.ssc.queueStream(batches) + output_stream = slr.predictOnValues(input_stream) + samples = [] + output_stream.foreachRDD(lambda x: samples.append(x.collect())) + + self.ssc.start() + + def condition(): + self.assertEqual(len(samples), len(batches)) + return True + + # We want all batches to finish for this test. + self._eventually(condition, catch_assertions=True) + + # Test that mean absolute error on each batch is less than 0.1 + for batch in samples: + true, predicted = zip(*batch) + self.assertTrue(mean(abs(array(true) - array(predicted))) < 0.1) + + def test_train_prediction(self): + """Test that error on test data improves as model is trained.""" + slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25) + slr.setInitialWeights([0.0]) + + # Create ten batches with 100 sample points in each. + batches = [] + for i in range(10): + batch = LinearDataGenerator.generateLinearInput( + 0.0, [10.0], [0.0], [1.0 / 3.0], 100, 42 + i, 0.1) + batches.append(self.sc.parallelize(batch)) + + predict_batches = [ + b.map(lambda lp: (lp.label, lp.features)) for b in batches] + errors = [] + + def func(rdd): + true, predicted = zip(*rdd.collect()) + errors.append(mean(abs(true) - abs(predicted))) + + input_stream = self.ssc.queueStream(batches) + output_stream = self.ssc.queueStream(predict_batches) + slr.trainOn(input_stream) + output_stream = slr.predictOnValues(output_stream) + output_stream.foreachRDD(func) + self.ssc.start() + + def condition(): + if len(errors) == len(predict_batches): + self.assertGreater(errors[1] - errors[-1], 2) + if len(errors) >= 3 and errors[1] - errors[-1] > 2: + return True + return "Latest errors: " + ", ".join(map(lambda x: str(x), errors)) + + self._eventually(condition) + + +if __name__ == "__main__": + from pyspark.mllib.tests.test_streaming_algorithms import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/mllib/tests/test_util.py b/python/pyspark/mllib/tests/test_util.py new file mode 100644 index 0000000000000..e95716278f122 --- /dev/null +++ b/python/pyspark/mllib/tests/test_util.py @@ -0,0 +1,104 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import os +import tempfile +import unittest + +from pyspark.mllib.common import _to_java_object_rdd +from pyspark.mllib.util import LinearDataGenerator +from pyspark.mllib.util import MLUtils +from pyspark.mllib.linalg import SparseVector, DenseVector, Vectors +from pyspark.mllib.random import RandomRDDs +from pyspark.testing.mllibutils import MLlibTestCase + + +class MLUtilsTests(MLlibTestCase): + def test_append_bias(self): + data = [2.0, 2.0, 2.0] + ret = MLUtils.appendBias(data) + self.assertEqual(ret[3], 1.0) + self.assertEqual(type(ret), DenseVector) + + def test_append_bias_with_vector(self): + data = Vectors.dense([2.0, 2.0, 2.0]) + ret = MLUtils.appendBias(data) + self.assertEqual(ret[3], 1.0) + self.assertEqual(type(ret), DenseVector) + + def test_append_bias_with_sp_vector(self): + data = Vectors.sparse(3, {0: 2.0, 2: 2.0}) + expected = Vectors.sparse(4, {0: 2.0, 2: 2.0, 3: 1.0}) + # Returned value must be SparseVector + ret = MLUtils.appendBias(data) + self.assertEqual(ret, expected) + self.assertEqual(type(ret), SparseVector) + + def test_load_vectors(self): + import shutil + data = [ + [1.0, 2.0, 3.0], + [1.0, 2.0, 3.0] + ] + temp_dir = tempfile.mkdtemp() + load_vectors_path = os.path.join(temp_dir, "test_load_vectors") + try: + self.sc.parallelize(data).saveAsTextFile(load_vectors_path) + ret_rdd = MLUtils.loadVectors(self.sc, load_vectors_path) + ret = ret_rdd.collect() + self.assertEqual(len(ret), 2) + self.assertEqual(ret[0], DenseVector([1.0, 2.0, 3.0])) + self.assertEqual(ret[1], DenseVector([1.0, 2.0, 3.0])) + except: + self.fail() + finally: + shutil.rmtree(load_vectors_path) + + +class LinearDataGeneratorTests(MLlibTestCase): + def test_dim(self): + linear_data = LinearDataGenerator.generateLinearInput( + intercept=0.0, weights=[0.0, 0.0, 0.0], + xMean=[0.0, 0.0, 0.0], xVariance=[0.33, 0.33, 0.33], + nPoints=4, seed=0, eps=0.1) + self.assertEqual(len(linear_data), 4) + for point in linear_data: + self.assertEqual(len(point.features), 3) + + linear_data = LinearDataGenerator.generateLinearRDD( + sc=self.sc, nexamples=6, nfeatures=2, eps=0.1, + nParts=2, intercept=0.0).collect() + self.assertEqual(len(linear_data), 6) + for point in linear_data: + self.assertEqual(len(point.features), 2) + + +class SerDeTest(MLlibTestCase): + def test_to_java_object_rdd(self): # SPARK-6660 + data = RandomRDDs.uniformRDD(self.sc, 10, 5, seed=0) + self.assertEqual(_to_java_object_rdd(data).count(), 10) + + +if __name__ == "__main__": + from pyspark.mllib.tests.test_util import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 8815ade20e92c..e69ae768efd6d 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -2354,7 +2354,7 @@ def countApproxDistinct(self, relativeSD=0.05): The algorithm used is based on streamlib's implementation of `"HyperLogLog in Practice: Algorithmic Engineering of a State of The Art Cardinality Estimation Algorithm", available here - `_. + `_. :param relativeSD: Relative accuracy. Smaller values create counters that require more space. 
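A rough illustration of the HyperLogLog-backed estimator documented above (a sketch, assuming a live SparkContext bound to ``sc``; the bounds are loose sanity checks only, since the estimate is probabilistic):

    # Estimate the number of distinct strings with the default ~5% relative accuracy.
    n = sc.parallelize(range(1000)).map(str).countApproxDistinct(relativeSD=0.05)
    assert 900 < n < 1100  # illustrative bounds, not a guarantee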
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 6202de36a479e..b8833a39078ba 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -257,7 +257,7 @@ def explain(self, extended=False): >>> df.explain() == Physical Plan == - Scan ExistingRDD[age#0,name#1] + *(1) Scan ExistingRDD[age#0,name#1] >>> df.explain(True) == Parsed Logical Plan == @@ -732,6 +732,11 @@ def repartitionByRange(self, numPartitions, *cols): At least one partition-by expression must be specified. When no explicit sort order is specified, "ascending nulls first" is assumed. + Note that due to performance reasons this method uses sampling to estimate the ranges. + Hence, the output may not be consistent, since sampling can return different values. + The sample size can be controlled by the config + `spark.sql.execution.rangeExchange.sampleSizePerPartition`. + >>> df.repartitionByRange(2, "age").rdd.getNumPartitions() 2 >>> df.show() @@ -1812,7 +1817,7 @@ def approxQuantile(self, col, probabilities, relativeError): This method implements a variation of the Greenwald-Khanna algorithm (with some speed optimizations). The algorithm was first - present in [[http://dx.doi.org/10.1145/375663.375670 + present in [[https://doi.org/10.1145/375663.375670 Space-efficient Online Computation of Quantile Summaries]] by Greenwald and Khanna. @@ -1934,7 +1939,7 @@ def freqItems(self, cols, support=None): """ Finding frequent items for columns, possibly with false positives. Using the frequent element count algorithm described in - "http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou". + "https://doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou". :func:`DataFrame.freqItems` and :func:`DataFrameStatFunctions.freqItems` are aliases. .. note:: This function is meant for exploratory data analysis, as we make no diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 519fb07352a73..9690c63f988fe 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2606,7 +2606,7 @@ def map_values(col): return Column(sc._jvm.functions.map_values(_to_java_column(col))) -@since(2.4) +@since(3.0) def map_entries(col): """ Collection function: Returns an unordered array of all entries in the given map. diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 690b13072244b..1d2dd4d808930 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -177,7 +177,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, multiLine=None, allowUnquotedControlChars=None, lineSep=None, samplingRatio=None, - dropFieldIfAllNull=None, encoding=None): + dropFieldIfAllNull=None, encoding=None, locale=None): """ Loads JSON files and returns the results as a :class:`DataFrame`. @@ -249,6 +249,9 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, :param dropFieldIfAllNull: whether to ignore column of all null values or empty array/struct during schema inference. If None is set, it uses the default value, ``false``. + :param locale: sets a locale as language tag in IETF BCP 47 format. If None is set, + it uses the default value, ``en-US``. For instance, ``locale`` is used while + parsing dates and timestamps. 
>>> df1 = spark.read.json('python/test_support/sql/people.json') >>> df1.dtypes @@ -267,7 +270,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat, timestampFormat=timestampFormat, multiLine=multiLine, allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep, - samplingRatio=samplingRatio, dropFieldIfAllNull=dropFieldIfAllNull, encoding=encoding) + samplingRatio=samplingRatio, dropFieldIfAllNull=dropFieldIfAllNull, encoding=encoding, + locale=locale) if isinstance(path, basestring): path = [path] if type(path) == list: @@ -349,7 +353,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None, - samplingRatio=None, enforceSchema=None, emptyValue=None): + samplingRatio=None, enforceSchema=None, emptyValue=None, locale=None, lineSep=None): r"""Loads a CSV file and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if @@ -446,6 +450,12 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non If None is set, it uses the default value, ``1.0``. :param emptyValue: sets the string representation of an empty value. If None is set, it uses the default value, empty string. + :param locale: sets a locale as language tag in IETF BCP 47 format. If None is set, + it uses the default value, ``en-US``. For instance, ``locale`` is used while + parsing dates and timestamps. + :param lineSep: defines the line separator that should be used for parsing. If None is + set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``. + Maximum length is 1 character. >>> df = spark.read.csv('python/test_support/sql/ages.csv') >>> df.dtypes @@ -465,7 +475,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine, charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, samplingRatio=samplingRatio, - enforceSchema=enforceSchema, emptyValue=emptyValue) + enforceSchema=enforceSchema, emptyValue=emptyValue, locale=locale, lineSep=lineSep) if isinstance(path, basestring): path = [path] if type(path) == list: @@ -861,7 +871,7 @@ def text(self, path, compression=None, lineSep=None): def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=None, header=None, nullValue=None, escapeQuotes=None, quoteAll=None, dateFormat=None, timestampFormat=None, ignoreLeadingWhiteSpace=None, ignoreTrailingWhiteSpace=None, - charToEscapeQuoteEscaping=None, encoding=None, emptyValue=None): + charToEscapeQuoteEscaping=None, encoding=None, emptyValue=None, lineSep=None): r"""Saves the content of the :class:`DataFrame` in CSV format at the specified path. :param path: the path in any Hadoop supported file system @@ -915,6 +925,8 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No the default UTF-8 charset will be used. :param emptyValue: sets the string representation of an empty value. If None is set, it uses the default value, ``""``. + :param lineSep: defines the line separator that should be used for writing. If None is + set, it uses the default value, ``\\n``. 
Maximum length is 1 character. >>> df.write.csv(os.path.join(tempfile.mkdtemp(), 'data')) """ @@ -925,7 +937,7 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No ignoreLeadingWhiteSpace=ignoreLeadingWhiteSpace, ignoreTrailingWhiteSpace=ignoreTrailingWhiteSpace, charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, - encoding=encoding, emptyValue=emptyValue) + encoding=encoding, emptyValue=emptyValue, lineSep=lineSep) self._jwrite.csv(path) @since(1.5) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index b18453b2a4f96..d92b0d5677e25 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -404,7 +404,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, - multiLine=None, allowUnquotedControlChars=None, lineSep=None): + multiLine=None, allowUnquotedControlChars=None, lineSep=None, locale=None, + dropFieldIfAllNull=None, encoding=None): """ Loads a JSON file stream and returns the results as a :class:`DataFrame`. @@ -469,6 +470,16 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, including tab and line feed characters) or not. :param lineSep: defines the line separator that should be used for parsing. If None is set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``. + :param locale: sets a locale as language tag in IETF BCP 47 format. If None is set, + it uses the default value, ``en-US``. For instance, ``locale`` is used while + parsing dates and timestamps. + :param dropFieldIfAllNull: whether to ignore column of all null values or empty + array/struct during schema inference. If None is set, it + uses the default value, ``false``. + :param encoding: allows to forcibly set one of standard basic or extended encoding for + the JSON files. For example UTF-16BE, UTF-32LE. If None is set, + the encoding of input JSON will be detected automatically + when the multiLine option is set to ``true``. >>> json_sdf = spark.readStream.json(tempfile.mkdtemp(), schema = sdf_schema) >>> json_sdf.isStreaming @@ -483,7 +494,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat, timestampFormat=timestampFormat, multiLine=multiLine, - allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep) + allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep, locale=locale, + dropFieldIfAllNull=dropFieldIfAllNull, encoding=encoding) if isinstance(path, basestring): return self._df(self._jreader.json(path)) else: @@ -564,7 +576,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None, - enforceSchema=None, emptyValue=None): + enforceSchema=None, emptyValue=None, locale=None, lineSep=None): r"""Loads a CSV file stream and returns the result as a :class:`DataFrame`. 
This function will go through the input once to determine the input schema if @@ -660,6 +672,12 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non different, ``\0`` otherwise.. :param emptyValue: sets the string representation of an empty value. If None is set, it uses the default value, empty string. + :param locale: sets a locale as language tag in IETF BCP 47 format. If None is set, + it uses the default value, ``en-US``. For instance, ``locale`` is used while + parsing dates and timestamps. + :param lineSep: defines the line separator that should be used for parsing. If None is + set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``. + Maximum length is 1 character. >>> csv_sdf = spark.readStream.csv(tempfile.mkdtemp(), schema = sdf_schema) >>> csv_sdf.isStreaming @@ -677,7 +695,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine, charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, enforceSchema=enforceSchema, - emptyValue=emptyValue) + emptyValue=emptyValue, locale=locale, lineSep=lineSep) if isinstance(path, basestring): return self._df(self._jreader.csv(path)) else: diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py deleted file mode 100644 index 206d529454f3d..0000000000000 --- a/python/pyspark/sql/tests.py +++ /dev/null @@ -1,7109 +0,0 @@ -# -*- encoding: utf-8 -*- -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -""" -Unit tests for pyspark.sql; additional tests are implemented as doctests in -individual modules. -""" -import os -import sys -import subprocess -import pydoc -import shutil -import tempfile -import threading -import pickle -import functools -import time -import datetime -import array -import ctypes -import warnings -import py4j -from contextlib import contextmanager -import unishark - -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest - -from pyspark.util import _exception_message - -_pandas_requirement_message = None -try: - from pyspark.sql.utils import require_minimum_pandas_version - require_minimum_pandas_version() -except ImportError as e: - # If Pandas version requirement is not satisfied, skip related tests. - _pandas_requirement_message = _exception_message(e) - -_pyarrow_requirement_message = None -try: - from pyspark.sql.utils import require_minimum_pyarrow_version - require_minimum_pyarrow_version() -except ImportError as e: - # If Arrow version requirement is not satisfied, skip related tests. 
- _pyarrow_requirement_message = _exception_message(e) - -_test_not_compiled_message = None -try: - from pyspark.sql.utils import require_test_compiled - require_test_compiled() -except Exception as e: - _test_not_compiled_message = _exception_message(e) - -_have_pandas = _pandas_requirement_message is None -_have_pyarrow = _pyarrow_requirement_message is None -_test_compiled = _test_not_compiled_message is None - -from pyspark import SparkConf, SparkContext -from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row -from pyspark.sql.types import * -from pyspark.sql.types import UserDefinedType, _infer_type, _make_type_verifier -from pyspark.sql.types import _array_signed_int_typecode_ctype_mappings, _array_type_mappings -from pyspark.sql.types import _array_unsigned_int_typecode_ctype_mappings -from pyspark.sql.types import _merge_type -from pyspark.tests import QuietTest, ReusedPySparkTestCase, PySparkTestCase, SparkSubmitTests -from pyspark.sql.functions import UserDefinedFunction, sha2, lit -from pyspark.sql.window import Window -from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException - - -class UTCOffsetTimezone(datetime.tzinfo): - """ - Specifies timezone in UTC offset - """ - - def __init__(self, offset=0): - self.ZERO = datetime.timedelta(hours=offset) - - def utcoffset(self, dt): - return self.ZERO - - def dst(self, dt): - return self.ZERO - - -class ExamplePointUDT(UserDefinedType): - """ - User-defined type (UDT) for ExamplePoint. - """ - - @classmethod - def sqlType(self): - return ArrayType(DoubleType(), False) - - @classmethod - def module(cls): - return 'pyspark.sql.tests' - - @classmethod - def scalaUDT(cls): - return 'org.apache.spark.sql.test.ExamplePointUDT' - - def serialize(self, obj): - return [obj.x, obj.y] - - def deserialize(self, datum): - return ExamplePoint(datum[0], datum[1]) - - -class ExamplePoint: - """ - An example class to demonstrate UDT in Scala, Java, and Python. - """ - - __UDT__ = ExamplePointUDT() - - def __init__(self, x, y): - self.x = x - self.y = y - - def __repr__(self): - return "ExamplePoint(%s,%s)" % (self.x, self.y) - - def __str__(self): - return "(%s,%s)" % (self.x, self.y) - - def __eq__(self, other): - return isinstance(other, self.__class__) and \ - other.x == self.x and other.y == self.y - - -class PythonOnlyUDT(UserDefinedType): - """ - User-defined type (UDT) for ExamplePoint. - """ - - @classmethod - def sqlType(self): - return ArrayType(DoubleType(), False) - - @classmethod - def module(cls): - return '__main__' - - def serialize(self, obj): - return [obj.x, obj.y] - - def deserialize(self, datum): - return PythonOnlyPoint(datum[0], datum[1]) - - @staticmethod - def foo(): - pass - - @property - def props(self): - return {} - - -class PythonOnlyPoint(ExamplePoint): - """ - An example class to demonstrate UDT in only Python - """ - __UDT__ = PythonOnlyUDT() - - -class MyObject(object): - def __init__(self, key, value): - self.key = key - self.value = value - - -class SQLTestUtils(object): - """ - This util assumes the instance of this to have 'spark' attribute, having a spark session. - It is usually used with 'ReusedSQLTestCase' class but can be used if you feel sure the - the implementation of this class has 'spark' attribute. - """ - - @contextmanager - def sql_conf(self, pairs): - """ - A convenient context manager to test some configuration specific logic. This sets - `value` to the configuration `key` and then restores it back when it exits. 
- """ - assert isinstance(pairs, dict), "pairs should be a dictionary." - assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session." - - keys = pairs.keys() - new_values = pairs.values() - old_values = [self.spark.conf.get(key, None) for key in keys] - for key, new_value in zip(keys, new_values): - self.spark.conf.set(key, new_value) - try: - yield - finally: - for key, old_value in zip(keys, old_values): - if old_value is None: - self.spark.conf.unset(key) - else: - self.spark.conf.set(key, old_value) - - @contextmanager - def database(self, *databases): - """ - A convenient context manager to test with some specific databases. This drops the given - databases if exist and sets current database to "default" when it exits. - """ - assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session." - - try: - yield - finally: - for db in databases: - self.spark.sql("DROP DATABASE IF EXISTS %s CASCADE" % db) - self.spark.catalog.setCurrentDatabase("default") - - @contextmanager - def table(self, *tables): - """ - A convenient context manager to test with some specific tables. This drops the given tables - if exist when it exits. - """ - assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session." - - try: - yield - finally: - for t in tables: - self.spark.sql("DROP TABLE IF EXISTS %s" % t) - - @contextmanager - def tempView(self, *views): - """ - A convenient context manager to test with some specific views. This drops the given views - if exist when it exits. - """ - assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session." - - try: - yield - finally: - for v in views: - self.spark.catalog.dropTempView(v) - - @contextmanager - def function(self, *functions): - """ - A convenient context manager to test with some specific functions. This drops the given - functions if exist when it exits. - """ - assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session." 
- - try: - yield - finally: - for f in functions: - self.spark.sql("DROP FUNCTION IF EXISTS %s" % f) - - -class ReusedSQLTestCase(ReusedPySparkTestCase, SQLTestUtils): - @classmethod - def setUpClass(cls): - super(ReusedSQLTestCase, cls).setUpClass() - cls.spark = SparkSession(cls.sc) - - @classmethod - def tearDownClass(cls): - super(ReusedSQLTestCase, cls).tearDownClass() - cls.spark.stop() - - def assertPandasEqual(self, expected, result): - msg = ("DataFrames are not equal: " + - "\n\nExpected:\n%s\n%s" % (expected, expected.dtypes) + - "\n\nResult:\n%s\n%s" % (result, result.dtypes)) - self.assertTrue(expected.equals(result), msg=msg) - - -class DataTypeTests(unittest.TestCase): - # regression test for SPARK-6055 - def test_data_type_eq(self): - lt = LongType() - lt2 = pickle.loads(pickle.dumps(LongType())) - self.assertEqual(lt, lt2) - - # regression test for SPARK-7978 - def test_decimal_type(self): - t1 = DecimalType() - t2 = DecimalType(10, 2) - self.assertTrue(t2 is not t1) - self.assertNotEqual(t1, t2) - t3 = DecimalType(8) - self.assertNotEqual(t2, t3) - - # regression test for SPARK-10392 - def test_datetype_equal_zero(self): - dt = DateType() - self.assertEqual(dt.fromInternal(0), datetime.date(1970, 1, 1)) - - # regression test for SPARK-17035 - def test_timestamp_microsecond(self): - tst = TimestampType() - self.assertEqual(tst.toInternal(datetime.datetime.max) % 1000000, 999999) - - def test_empty_row(self): - row = Row() - self.assertEqual(len(row), 0) - - def test_struct_field_type_name(self): - struct_field = StructField("a", IntegerType()) - self.assertRaises(TypeError, struct_field.typeName) - - def test_invalid_create_row(self): - row_class = Row("c1", "c2") - self.assertRaises(ValueError, lambda: row_class(1, 2, 3)) - - -class SparkSessionBuilderTests(unittest.TestCase): - - def test_create_spark_context_first_then_spark_session(self): - sc = None - session = None - try: - conf = SparkConf().set("key1", "value1") - sc = SparkContext('local[4]', "SessionBuilderTests", conf=conf) - session = SparkSession.builder.config("key2", "value2").getOrCreate() - - self.assertEqual(session.conf.get("key1"), "value1") - self.assertEqual(session.conf.get("key2"), "value2") - self.assertEqual(session.sparkContext, sc) - - self.assertFalse(sc.getConf().contains("key2")) - self.assertEqual(sc.getConf().get("key1"), "value1") - finally: - if session is not None: - session.stop() - if sc is not None: - sc.stop() - - def test_another_spark_session(self): - session1 = None - session2 = None - try: - session1 = SparkSession.builder.config("key1", "value1").getOrCreate() - session2 = SparkSession.builder.config("key2", "value2").getOrCreate() - - self.assertEqual(session1.conf.get("key1"), "value1") - self.assertEqual(session2.conf.get("key1"), "value1") - self.assertEqual(session1.conf.get("key2"), "value2") - self.assertEqual(session2.conf.get("key2"), "value2") - self.assertEqual(session1.sparkContext, session2.sparkContext) - - self.assertEqual(session1.sparkContext.getConf().get("key1"), "value1") - self.assertFalse(session1.sparkContext.getConf().contains("key2")) - finally: - if session1 is not None: - session1.stop() - if session2 is not None: - session2.stop() - - -class SQLTests(ReusedSQLTestCase): - - @classmethod - def setUpClass(cls): - ReusedSQLTestCase.setUpClass() - cls.spark.catalog._reset() - cls.tempdir = tempfile.NamedTemporaryFile(delete=False) - os.unlink(cls.tempdir.name) - cls.testData = [Row(key=i, value=str(i)) for i in range(100)] - cls.df = 
cls.spark.createDataFrame(cls.testData) - - @classmethod - def tearDownClass(cls): - ReusedSQLTestCase.tearDownClass() - shutil.rmtree(cls.tempdir.name, ignore_errors=True) - - def test_sqlcontext_reuses_sparksession(self): - sqlContext1 = SQLContext(self.sc) - sqlContext2 = SQLContext(self.sc) - self.assertTrue(sqlContext1.sparkSession is sqlContext2.sparkSession) - - def test_row_should_be_read_only(self): - row = Row(a=1, b=2) - self.assertEqual(1, row.a) - - def foo(): - row.a = 3 - self.assertRaises(Exception, foo) - - row2 = self.spark.range(10).first() - self.assertEqual(0, row2.id) - - def foo2(): - row2.id = 2 - self.assertRaises(Exception, foo2) - - def test_range(self): - self.assertEqual(self.spark.range(1, 1).count(), 0) - self.assertEqual(self.spark.range(1, 0, -1).count(), 1) - self.assertEqual(self.spark.range(0, 1 << 40, 1 << 39).count(), 2) - self.assertEqual(self.spark.range(-2).count(), 0) - self.assertEqual(self.spark.range(3).count(), 3) - - def test_duplicated_column_names(self): - df = self.spark.createDataFrame([(1, 2)], ["c", "c"]) - row = df.select('*').first() - self.assertEqual(1, row[0]) - self.assertEqual(2, row[1]) - self.assertEqual("Row(c=1, c=2)", str(row)) - # Cannot access columns - self.assertRaises(AnalysisException, lambda: df.select(df[0]).first()) - self.assertRaises(AnalysisException, lambda: df.select(df.c).first()) - self.assertRaises(AnalysisException, lambda: df.select(df["c"]).first()) - - def test_column_name_encoding(self): - """Ensure that created columns has `str` type consistently.""" - columns = self.spark.createDataFrame([('Alice', 1)], ['name', u'age']).columns - self.assertEqual(columns, ['name', 'age']) - self.assertTrue(isinstance(columns[0], str)) - self.assertTrue(isinstance(columns[1], str)) - - def test_explode(self): - from pyspark.sql.functions import explode, explode_outer, posexplode_outer - d = [ - Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"}), - Row(a=1, intlist=[], mapfield={}), - Row(a=1, intlist=None, mapfield=None), - ] - rdd = self.sc.parallelize(d) - data = self.spark.createDataFrame(rdd) - - result = data.select(explode(data.intlist).alias("a")).select("a").collect() - self.assertEqual(result[0][0], 1) - self.assertEqual(result[1][0], 2) - self.assertEqual(result[2][0], 3) - - result = data.select(explode(data.mapfield).alias("a", "b")).select("a", "b").collect() - self.assertEqual(result[0][0], "a") - self.assertEqual(result[0][1], "b") - - result = [tuple(x) for x in data.select(posexplode_outer("intlist")).collect()] - self.assertEqual(result, [(0, 1), (1, 2), (2, 3), (None, None), (None, None)]) - - result = [tuple(x) for x in data.select(posexplode_outer("mapfield")).collect()] - self.assertEqual(result, [(0, 'a', 'b'), (None, None, None), (None, None, None)]) - - result = [x[0] for x in data.select(explode_outer("intlist")).collect()] - self.assertEqual(result, [1, 2, 3, None, None]) - - result = [tuple(x) for x in data.select(explode_outer("mapfield")).collect()] - self.assertEqual(result, [('a', 'b'), (None, None), (None, None)]) - - def test_and_in_expression(self): - self.assertEqual(4, self.df.filter((self.df.key <= 10) & (self.df.value <= "2")).count()) - self.assertRaises(ValueError, lambda: (self.df.key <= 10) and (self.df.value <= "2")) - self.assertEqual(14, self.df.filter((self.df.key <= 3) | (self.df.value < "2")).count()) - self.assertRaises(ValueError, lambda: self.df.key <= 3 or self.df.value < "2") - self.assertEqual(99, self.df.filter(~(self.df.key == 1)).count()) - 
self.assertRaises(ValueError, lambda: not self.df.key == 1) - - def test_udf_with_callable(self): - d = [Row(number=i, squared=i**2) for i in range(10)] - rdd = self.sc.parallelize(d) - data = self.spark.createDataFrame(rdd) - - class PlusFour: - def __call__(self, col): - if col is not None: - return col + 4 - - call = PlusFour() - pudf = UserDefinedFunction(call, LongType()) - res = data.select(pudf(data['number']).alias('plus_four')) - self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85) - - def test_udf_with_partial_function(self): - d = [Row(number=i, squared=i**2) for i in range(10)] - rdd = self.sc.parallelize(d) - data = self.spark.createDataFrame(rdd) - - def some_func(col, param): - if col is not None: - return col + param - - pfunc = functools.partial(some_func, param=4) - pudf = UserDefinedFunction(pfunc, LongType()) - res = data.select(pudf(data['number']).alias('plus_four')) - self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85) - - def test_udf(self): - self.spark.catalog.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType()) - [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect() - self.assertEqual(row[0], 5) - - # This is to check if a deprecated 'SQLContext.registerFunction' can call its alias. - sqlContext = self.spark._wrapped - sqlContext.registerFunction("oneArg", lambda x: len(x), IntegerType()) - [row] = sqlContext.sql("SELECT oneArg('test')").collect() - self.assertEqual(row[0], 4) - - def test_udf2(self): - with self.tempView("test"): - self.spark.catalog.registerFunction("strlen", lambda string: len(string), IntegerType()) - self.spark.createDataFrame(self.sc.parallelize([Row(a="test")]))\ - .createOrReplaceTempView("test") - [res] = self.spark.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect() - self.assertEqual(4, res[0]) - - def test_udf3(self): - two_args = self.spark.catalog.registerFunction( - "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y)) - self.assertEqual(two_args.deterministic, True) - [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect() - self.assertEqual(row[0], u'5') - - def test_udf_registration_return_type_none(self): - two_args = self.spark.catalog.registerFunction( - "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y, "integer"), None) - self.assertEqual(two_args.deterministic, True) - [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect() - self.assertEqual(row[0], 5) - - def test_udf_registration_return_type_not_none(self): - with QuietTest(self.sc): - with self.assertRaisesRegexp(TypeError, "Invalid returnType"): - self.spark.catalog.registerFunction( - "f", UserDefinedFunction(lambda x, y: len(x) + y, StringType()), StringType()) - - def test_nondeterministic_udf(self): - # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations - from pyspark.sql.functions import udf - import random - udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic() - self.assertEqual(udf_random_col.deterministic, False) - df = self.spark.createDataFrame([Row(1)]).select(udf_random_col().alias('RAND')) - udf_add_ten = udf(lambda rand: rand + 10, IntegerType()) - [row] = df.withColumn('RAND_PLUS_TEN', udf_add_ten('RAND')).collect() - self.assertEqual(row[0] + 10, row[1]) - - def test_nondeterministic_udf2(self): - import random - from pyspark.sql.functions import udf - random_udf = udf(lambda: random.randint(6, 6), IntegerType()).asNondeterministic() - self.assertEqual(random_udf.deterministic, False) - 
random_udf1 = self.spark.catalog.registerFunction("randInt", random_udf) - self.assertEqual(random_udf1.deterministic, False) - [row] = self.spark.sql("SELECT randInt()").collect() - self.assertEqual(row[0], 6) - [row] = self.spark.range(1).select(random_udf1()).collect() - self.assertEqual(row[0], 6) - [row] = self.spark.range(1).select(random_udf()).collect() - self.assertEqual(row[0], 6) - # render_doc() reproduces the help() exception without printing output - pydoc.render_doc(udf(lambda: random.randint(6, 6), IntegerType())) - pydoc.render_doc(random_udf) - pydoc.render_doc(random_udf1) - pydoc.render_doc(udf(lambda x: x).asNondeterministic) - - def test_nondeterministic_udf3(self): - # regression test for SPARK-23233 - from pyspark.sql.functions import udf - f = udf(lambda x: x) - # Here we cache the JVM UDF instance. - self.spark.range(1).select(f("id")) - # This should reset the cache to set the deterministic status correctly. - f = f.asNondeterministic() - # Check the deterministic status of udf. - df = self.spark.range(1).select(f("id")) - deterministic = df._jdf.logicalPlan().projectList().head().deterministic() - self.assertFalse(deterministic) - - def test_nondeterministic_udf_in_aggregate(self): - from pyspark.sql.functions import udf, sum - import random - udf_random_col = udf(lambda: int(100 * random.random()), 'int').asNondeterministic() - df = self.spark.range(10) - - with QuietTest(self.sc): - with self.assertRaisesRegexp(AnalysisException, "nondeterministic"): - df.groupby('id').agg(sum(udf_random_col())).collect() - with self.assertRaisesRegexp(AnalysisException, "nondeterministic"): - df.agg(sum(udf_random_col())).collect() - - def test_chained_udf(self): - self.spark.catalog.registerFunction("double", lambda x: x + x, IntegerType()) - [row] = self.spark.sql("SELECT double(1)").collect() - self.assertEqual(row[0], 2) - [row] = self.spark.sql("SELECT double(double(1))").collect() - self.assertEqual(row[0], 4) - [row] = self.spark.sql("SELECT double(double(1) + 1)").collect() - self.assertEqual(row[0], 6) - - def test_single_udf_with_repeated_argument(self): - # regression test for SPARK-20685 - self.spark.catalog.registerFunction("add", lambda x, y: x + y, IntegerType()) - row = self.spark.sql("SELECT add(1, 1)").first() - self.assertEqual(tuple(row), (2, )) - - def test_multiple_udfs(self): - self.spark.catalog.registerFunction("double", lambda x: x * 2, IntegerType()) - [row] = self.spark.sql("SELECT double(1), double(2)").collect() - self.assertEqual(tuple(row), (2, 4)) - [row] = self.spark.sql("SELECT double(double(1)), double(double(2) + 2)").collect() - self.assertEqual(tuple(row), (4, 12)) - self.spark.catalog.registerFunction("add", lambda x, y: x + y, IntegerType()) - [row] = self.spark.sql("SELECT double(add(1, 2)), add(double(2), 1)").collect() - self.assertEqual(tuple(row), (6, 5)) - - def test_udf_in_filter_on_top_of_outer_join(self): - from pyspark.sql.functions import udf - left = self.spark.createDataFrame([Row(a=1)]) - right = self.spark.createDataFrame([Row(a=1)]) - df = left.join(right, on='a', how='left_outer') - df = df.withColumn('b', udf(lambda x: 'x')(df.a)) - self.assertEqual(df.filter('b = "x"').collect(), [Row(a=1, b='x')]) - - def test_udf_in_filter_on_top_of_join(self): - # regression test for SPARK-18589 - from pyspark.sql.functions import udf - left = self.spark.createDataFrame([Row(a=1)]) - right = self.spark.createDataFrame([Row(b=1)]) - f = udf(lambda a, b: a == b, BooleanType()) - df = left.crossJoin(right).filter(f("a", "b")) - 
self.assertEqual(df.collect(), [Row(a=1, b=1)]) - - def test_udf_in_join_condition(self): - # regression test for SPARK-25314 - from pyspark.sql.functions import udf - left = self.spark.createDataFrame([Row(a=1)]) - right = self.spark.createDataFrame([Row(b=1)]) - f = udf(lambda a, b: a == b, BooleanType()) - df = left.join(right, f("a", "b")) - with self.assertRaisesRegexp(AnalysisException, 'Detected implicit cartesian product'): - df.collect() - with self.sql_conf({"spark.sql.crossJoin.enabled": True}): - self.assertEqual(df.collect(), [Row(a=1, b=1)]) - - def test_udf_in_left_semi_join_condition(self): - # regression test for SPARK-25314 - from pyspark.sql.functions import udf - left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)]) - right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1)]) - f = udf(lambda a, b: a == b, BooleanType()) - df = left.join(right, f("a", "b"), "leftsemi") - with self.assertRaisesRegexp(AnalysisException, 'Detected implicit cartesian product'): - df.collect() - with self.sql_conf({"spark.sql.crossJoin.enabled": True}): - self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1)]) - - def test_udf_and_common_filter_in_join_condition(self): - # regression test for SPARK-25314 - # test the complex scenario with both udf and common filter - from pyspark.sql.functions import udf - left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)]) - right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)]) - f = udf(lambda a, b: a == b, BooleanType()) - df = left.join(right, [f("a", "b"), left.a1 == right.b1]) - # do not need spark.sql.crossJoin.enabled=true for udf is not the only join condition. - self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1, b=1, b1=1, b2=1)]) - - def test_udf_and_common_filter_in_left_semi_join_condition(self): - # regression test for SPARK-25314 - # test the complex scenario with both udf and common filter - from pyspark.sql.functions import udf - left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)]) - right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)]) - f = udf(lambda a, b: a == b, BooleanType()) - df = left.join(right, [f("a", "b"), left.a1 == right.b1], "left_semi") - # do not need spark.sql.crossJoin.enabled=true for udf is not the only join condition. - self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1)]) - - def test_udf_not_supported_in_join_condition(self): - # regression test for SPARK-25314 - # test python udf is not supported in join type besides left_semi and inner join. - from pyspark.sql.functions import udf - left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)]) - right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)]) - f = udf(lambda a, b: a == b, BooleanType()) - - def runWithJoinType(join_type, type_string): - with self.assertRaisesRegexp( - AnalysisException, - 'Using PythonUDF.*%s is not supported.' 
% type_string): - left.join(right, [f("a", "b"), left.a1 == right.b1], join_type).collect() - runWithJoinType("full", "FullOuter") - runWithJoinType("left", "LeftOuter") - runWithJoinType("right", "RightOuter") - runWithJoinType("leftanti", "LeftAnti") - - def test_udf_without_arguments(self): - self.spark.catalog.registerFunction("foo", lambda: "bar") - [row] = self.spark.sql("SELECT foo()").collect() - self.assertEqual(row[0], "bar") - - def test_udf_with_array_type(self): - with self.tempView("test"): - d = [Row(l=list(range(3)), d={"key": list(range(5))})] - rdd = self.sc.parallelize(d) - self.spark.createDataFrame(rdd).createOrReplaceTempView("test") - self.spark.catalog.registerFunction( - "copylist", lambda l: list(l), ArrayType(IntegerType())) - self.spark.catalog.registerFunction("maplen", lambda d: len(d), IntegerType()) - [(l1, l2)] = self.spark.sql("select copylist(l), maplen(d) from test").collect() - self.assertEqual(list(range(3)), l1) - self.assertEqual(1, l2) - - def test_broadcast_in_udf(self): - bar = {"a": "aa", "b": "bb", "c": "abc"} - foo = self.sc.broadcast(bar) - self.spark.catalog.registerFunction("MYUDF", lambda x: foo.value[x] if x else '') - [res] = self.spark.sql("SELECT MYUDF('c')").collect() - self.assertEqual("abc", res[0]) - [res] = self.spark.sql("SELECT MYUDF('')").collect() - self.assertEqual("", res[0]) - - def test_udf_with_filter_function(self): - df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) - from pyspark.sql.functions import udf, col - from pyspark.sql.types import BooleanType - - my_filter = udf(lambda a: a < 2, BooleanType()) - sel = df.select(col("key"), col("value")).filter((my_filter(col("key"))) & (df.value < "2")) - self.assertEqual(sel.collect(), [Row(key=1, value='1')]) - - def test_udf_with_aggregate_function(self): - df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) - from pyspark.sql.functions import udf, col, sum - from pyspark.sql.types import BooleanType - - my_filter = udf(lambda a: a == 1, BooleanType()) - sel = df.select(col("key")).distinct().filter(my_filter(col("key"))) - self.assertEqual(sel.collect(), [Row(key=1)]) - - my_copy = udf(lambda x: x, IntegerType()) - my_add = udf(lambda a, b: int(a + b), IntegerType()) - my_strlen = udf(lambda x: len(x), IntegerType()) - sel = df.groupBy(my_copy(col("key")).alias("k"))\ - .agg(sum(my_strlen(col("value"))).alias("s"))\ - .select(my_add(col("k"), col("s")).alias("t")) - self.assertEqual(sel.collect(), [Row(t=4), Row(t=3)]) - - def test_udf_in_generate(self): - from pyspark.sql.functions import udf, explode - df = self.spark.range(5) - f = udf(lambda x: list(range(x)), ArrayType(LongType())) - row = df.select(explode(f(*df))).groupBy().sum().first() - self.assertEqual(row[0], 10) - - df = self.spark.range(3) - res = df.select("id", explode(f(df.id))).collect() - self.assertEqual(res[0][0], 1) - self.assertEqual(res[0][1], 0) - self.assertEqual(res[1][0], 2) - self.assertEqual(res[1][1], 0) - self.assertEqual(res[2][0], 2) - self.assertEqual(res[2][1], 1) - - range_udf = udf(lambda value: list(range(value - 1, value + 1)), ArrayType(IntegerType())) - res = df.select("id", explode(range_udf(df.id))).collect() - self.assertEqual(res[0][0], 0) - self.assertEqual(res[0][1], -1) - self.assertEqual(res[1][0], 0) - self.assertEqual(res[1][1], 0) - self.assertEqual(res[2][0], 1) - self.assertEqual(res[2][1], 0) - self.assertEqual(res[3][0], 1) - self.assertEqual(res[3][1], 1) - - def 
test_udf_with_order_by_and_limit(self): - from pyspark.sql.functions import udf - my_copy = udf(lambda x: x, IntegerType()) - df = self.spark.range(10).orderBy("id") - res = df.select(df.id, my_copy(df.id).alias("copy")).limit(1) - res.explain(True) - self.assertEqual(res.collect(), [Row(id=0, copy=0)]) - - def test_udf_registration_returns_udf(self): - df = self.spark.range(10) - add_three = self.spark.udf.register("add_three", lambda x: x + 3, IntegerType()) - - self.assertListEqual( - df.selectExpr("add_three(id) AS plus_three").collect(), - df.select(add_three("id").alias("plus_three")).collect() - ) - - # This is to check if a 'SQLContext.udf' can call its alias. - sqlContext = self.spark._wrapped - add_four = sqlContext.udf.register("add_four", lambda x: x + 4, IntegerType()) - - self.assertListEqual( - df.selectExpr("add_four(id) AS plus_four").collect(), - df.select(add_four("id").alias("plus_four")).collect() - ) - - def test_non_existed_udf(self): - spark = self.spark - self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf", - lambda: spark.udf.registerJavaFunction("udf1", "non_existed_udf")) - - # This is to check if a deprecated 'SQLContext.registerJavaFunction' can call its alias. - sqlContext = spark._wrapped - self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf", - lambda: sqlContext.registerJavaFunction("udf1", "non_existed_udf")) - - def test_non_existed_udaf(self): - spark = self.spark - self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udaf", - lambda: spark.udf.registerJavaUDAF("udaf1", "non_existed_udaf")) - - def test_linesep_text(self): - df = self.spark.read.text("python/test_support/sql/ages_newlines.csv", lineSep=",") - expected = [Row(value=u'Joe'), Row(value=u'20'), Row(value=u'"Hi'), - Row(value=u'\nI am Jeo"\nTom'), Row(value=u'30'), - Row(value=u'"My name is Tom"\nHyukjin'), Row(value=u'25'), - Row(value=u'"I am Hyukjin\n\nI love Spark!"\n')] - self.assertEqual(df.collect(), expected) - - tpath = tempfile.mkdtemp() - shutil.rmtree(tpath) - try: - df.write.text(tpath, lineSep="!") - expected = [Row(value=u'Joe!20!"Hi!'), Row(value=u'I am Jeo"'), - Row(value=u'Tom!30!"My name is Tom"'), - Row(value=u'Hyukjin!25!"I am Hyukjin'), - Row(value=u''), Row(value=u'I love Spark!"'), - Row(value=u'!')] - readback = self.spark.read.text(tpath) - self.assertEqual(readback.collect(), expected) - finally: - shutil.rmtree(tpath) - - def test_multiline_json(self): - people1 = self.spark.read.json("python/test_support/sql/people.json") - people_array = self.spark.read.json("python/test_support/sql/people_array.json", - multiLine=True) - self.assertEqual(people1.collect(), people_array.collect()) - - def test_encoding_json(self): - people_array = self.spark.read\ - .json("python/test_support/sql/people_array_utf16le.json", - multiLine=True, encoding="UTF-16LE") - expected = [Row(age=30, name=u'Andy'), Row(age=19, name=u'Justin')] - self.assertEqual(people_array.collect(), expected) - - def test_linesep_json(self): - df = self.spark.read.json("python/test_support/sql/people.json", lineSep=",") - expected = [Row(_corrupt_record=None, name=u'Michael'), - Row(_corrupt_record=u' "age":30}\n{"name":"Justin"', name=None), - Row(_corrupt_record=u' "age":19}\n', name=None)] - self.assertEqual(df.collect(), expected) - - tpath = tempfile.mkdtemp() - shutil.rmtree(tpath) - try: - df = self.spark.read.json("python/test_support/sql/people.json") - df.write.json(tpath, lineSep="!!") - readback = 
self.spark.read.json(tpath, lineSep="!!") - self.assertEqual(readback.collect(), df.collect()) - finally: - shutil.rmtree(tpath) - - def test_multiline_csv(self): - ages_newlines = self.spark.read.csv( - "python/test_support/sql/ages_newlines.csv", multiLine=True) - expected = [Row(_c0=u'Joe', _c1=u'20', _c2=u'Hi,\nI am Jeo'), - Row(_c0=u'Tom', _c1=u'30', _c2=u'My name is Tom'), - Row(_c0=u'Hyukjin', _c1=u'25', _c2=u'I am Hyukjin\n\nI love Spark!')] - self.assertEqual(ages_newlines.collect(), expected) - - def test_ignorewhitespace_csv(self): - tmpPath = tempfile.mkdtemp() - shutil.rmtree(tmpPath) - self.spark.createDataFrame([[" a", "b ", " c "]]).write.csv( - tmpPath, - ignoreLeadingWhiteSpace=False, - ignoreTrailingWhiteSpace=False) - - expected = [Row(value=u' a,b , c ')] - readback = self.spark.read.text(tmpPath) - self.assertEqual(readback.collect(), expected) - shutil.rmtree(tmpPath) - - def test_read_multiple_orc_file(self): - df = self.spark.read.orc(["python/test_support/sql/orc_partitioned/b=0/c=0", - "python/test_support/sql/orc_partitioned/b=1/c=1"]) - self.assertEqual(2, df.count()) - - def test_udf_with_input_file_name(self): - from pyspark.sql.functions import udf, input_file_name - sourceFile = udf(lambda path: path, StringType()) - filePath = "python/test_support/sql/people1.json" - row = self.spark.read.json(filePath).select(sourceFile(input_file_name())).first() - self.assertTrue(row[0].find("people1.json") != -1) - - def test_udf_with_input_file_name_for_hadooprdd(self): - from pyspark.sql.functions import udf, input_file_name - - def filename(path): - return path - - sameText = udf(filename, StringType()) - - rdd = self.sc.textFile('python/test_support/sql/people.json') - df = self.spark.read.json(rdd).select(input_file_name().alias('file')) - row = df.select(sameText(df['file'])).first() - self.assertTrue(row[0].find("people.json") != -1) - - rdd2 = self.sc.newAPIHadoopFile( - 'python/test_support/sql/people.json', - 'org.apache.hadoop.mapreduce.lib.input.TextInputFormat', - 'org.apache.hadoop.io.LongWritable', - 'org.apache.hadoop.io.Text') - - df2 = self.spark.read.json(rdd2).select(input_file_name().alias('file')) - row2 = df2.select(sameText(df2['file'])).first() - self.assertTrue(row2[0].find("people.json") != -1) - - def test_udf_defers_judf_initialization(self): - # This is separate of UDFInitializationTests - # to avoid context initialization - # when udf is called - - from pyspark.sql.functions import UserDefinedFunction - - f = UserDefinedFunction(lambda x: x, StringType()) - - self.assertIsNone( - f._judf_placeholder, - "judf should not be initialized before the first call." - ) - - self.assertIsInstance(f("foo"), Column, "UDF call should return a Column.") - - self.assertIsNotNone( - f._judf_placeholder, - "judf should be initialized after UDF has been called." 
- ) - - def test_udf_with_string_return_type(self): - from pyspark.sql.functions import UserDefinedFunction - - add_one = UserDefinedFunction(lambda x: x + 1, "integer") - make_pair = UserDefinedFunction(lambda x: (-x, x), "struct") - make_array = UserDefinedFunction( - lambda x: [float(x) for x in range(x, x + 3)], "array") - - expected = (2, Row(x=-1, y=1), [1.0, 2.0, 3.0]) - actual = (self.spark.range(1, 2).toDF("x") - .select(add_one("x"), make_pair("x"), make_array("x")) - .first()) - - self.assertTupleEqual(expected, actual) - - def test_udf_shouldnt_accept_noncallable_object(self): - from pyspark.sql.functions import UserDefinedFunction - - non_callable = None - self.assertRaises(TypeError, UserDefinedFunction, non_callable, StringType()) - - def test_udf_with_decorator(self): - from pyspark.sql.functions import lit, udf - from pyspark.sql.types import IntegerType, DoubleType - - @udf(IntegerType()) - def add_one(x): - if x is not None: - return x + 1 - - @udf(returnType=DoubleType()) - def add_two(x): - if x is not None: - return float(x + 2) - - @udf - def to_upper(x): - if x is not None: - return x.upper() - - @udf() - def to_lower(x): - if x is not None: - return x.lower() - - @udf - def substr(x, start, end): - if x is not None: - return x[start:end] - - @udf("long") - def trunc(x): - return int(x) - - @udf(returnType="double") - def as_double(x): - return float(x) - - df = ( - self.spark - .createDataFrame( - [(1, "Foo", "foobar", 3.0)], ("one", "Foo", "foobar", "float")) - .select( - add_one("one"), add_two("one"), - to_upper("Foo"), to_lower("Foo"), - substr("foobar", lit(0), lit(3)), - trunc("float"), as_double("one"))) - - self.assertListEqual( - [tpe for _, tpe in df.dtypes], - ["int", "double", "string", "string", "string", "bigint", "double"] - ) - - self.assertListEqual( - list(df.first()), - [2, 3.0, "FOO", "foo", "foo", 3, 1.0] - ) - - def test_udf_wrapper(self): - from pyspark.sql.functions import udf - from pyspark.sql.types import IntegerType - - def f(x): - """Identity""" - return x - - return_type = IntegerType() - f_ = udf(f, return_type) - - self.assertTrue(f.__doc__ in f_.__doc__) - self.assertEqual(f, f_.func) - self.assertEqual(return_type, f_.returnType) - - class F(object): - """Identity""" - def __call__(self, x): - return x - - f = F() - return_type = IntegerType() - f_ = udf(f, return_type) - - self.assertTrue(f.__doc__ in f_.__doc__) - self.assertEqual(f, f_.func) - self.assertEqual(return_type, f_.returnType) - - f = functools.partial(f, x=1) - return_type = IntegerType() - f_ = udf(f, return_type) - - self.assertTrue(f.__doc__ in f_.__doc__) - self.assertEqual(f, f_.func) - self.assertEqual(return_type, f_.returnType) - - def test_validate_column_types(self): - from pyspark.sql.functions import udf, to_json - from pyspark.sql.column import _to_java_column - - self.assertTrue("Column" in _to_java_column("a").getClass().toString()) - self.assertTrue("Column" in _to_java_column(u"a").getClass().toString()) - self.assertTrue("Column" in _to_java_column(self.spark.range(1).id).getClass().toString()) - - self.assertRaisesRegexp( - TypeError, - "Invalid argument, not a string or column", - lambda: _to_java_column(1)) - - class A(): - pass - - self.assertRaises(TypeError, lambda: _to_java_column(A())) - self.assertRaises(TypeError, lambda: _to_java_column([])) - - self.assertRaisesRegexp( - TypeError, - "Invalid argument, not a string or column", - lambda: udf(lambda x: x)(None)) - self.assertRaises(TypeError, lambda: to_json(1)) - - def 
test_basic_functions(self): - rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) - df = self.spark.read.json(rdd) - df.count() - df.collect() - df.schema - - # cache and checkpoint - self.assertFalse(df.is_cached) - df.persist() - df.unpersist(True) - df.cache() - self.assertTrue(df.is_cached) - self.assertEqual(2, df.count()) - - with self.tempView("temp"): - df.createOrReplaceTempView("temp") - df = self.spark.sql("select foo from temp") - df.count() - df.collect() - - def test_apply_schema_to_row(self): - df = self.spark.read.json(self.sc.parallelize(["""{"a":2}"""])) - df2 = self.spark.createDataFrame(df.rdd.map(lambda x: x), df.schema) - self.assertEqual(df.collect(), df2.collect()) - - rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x)) - df3 = self.spark.createDataFrame(rdd, df.schema) - self.assertEqual(10, df3.count()) - - def test_infer_schema_to_local(self): - input = [{"a": 1}, {"b": "coffee"}] - rdd = self.sc.parallelize(input) - df = self.spark.createDataFrame(input) - df2 = self.spark.createDataFrame(rdd, samplingRatio=1.0) - self.assertEqual(df.schema, df2.schema) - - rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x, b=None)) - df3 = self.spark.createDataFrame(rdd, df.schema) - self.assertEqual(10, df3.count()) - - def test_apply_schema_to_dict_and_rows(self): - schema = StructType().add("b", StringType()).add("a", IntegerType()) - input = [{"a": 1}, {"b": "coffee"}] - rdd = self.sc.parallelize(input) - for verify in [False, True]: - df = self.spark.createDataFrame(input, schema, verifySchema=verify) - df2 = self.spark.createDataFrame(rdd, schema, verifySchema=verify) - self.assertEqual(df.schema, df2.schema) - - rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x, b=None)) - df3 = self.spark.createDataFrame(rdd, schema, verifySchema=verify) - self.assertEqual(10, df3.count()) - input = [Row(a=x, b=str(x)) for x in range(10)] - df4 = self.spark.createDataFrame(input, schema, verifySchema=verify) - self.assertEqual(10, df4.count()) - - def test_create_dataframe_schema_mismatch(self): - input = [Row(a=1)] - rdd = self.sc.parallelize(range(3)).map(lambda i: Row(a=i)) - schema = StructType([StructField("a", IntegerType()), StructField("b", StringType())]) - df = self.spark.createDataFrame(rdd, schema) - self.assertRaises(Exception, lambda: df.show()) - - def test_serialize_nested_array_and_map(self): - d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})] - rdd = self.sc.parallelize(d) - df = self.spark.createDataFrame(rdd) - row = df.head() - self.assertEqual(1, len(row.l)) - self.assertEqual(1, row.l[0].a) - self.assertEqual("2", row.d["key"].d) - - l = df.rdd.map(lambda x: x.l).first() - self.assertEqual(1, len(l)) - self.assertEqual('s', l[0].b) - - d = df.rdd.map(lambda x: x.d).first() - self.assertEqual(1, len(d)) - self.assertEqual(1.0, d["key"].c) - - row = df.rdd.map(lambda x: x.d["key"]).first() - self.assertEqual(1.0, row.c) - self.assertEqual("2", row.d) - - def test_infer_schema(self): - d = [Row(l=[], d={}, s=None), - Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")] - rdd = self.sc.parallelize(d) - df = self.spark.createDataFrame(rdd) - self.assertEqual([], df.rdd.map(lambda r: r.l).first()) - self.assertEqual([None, ""], df.rdd.map(lambda r: r.s).collect()) - - with self.tempView("test"): - df.createOrReplaceTempView("test") - result = self.spark.sql("SELECT l[0].a from test where d['key'].d = '2'") - self.assertEqual(1, result.head()[0]) - - df2 = self.spark.createDataFrame(rdd, samplingRatio=1.0) - 
self.assertEqual(df.schema, df2.schema) - self.assertEqual({}, df2.rdd.map(lambda r: r.d).first()) - self.assertEqual([None, ""], df2.rdd.map(lambda r: r.s).collect()) - - with self.tempView("test2"): - df2.createOrReplaceTempView("test2") - result = self.spark.sql("SELECT l[0].a from test2 where d['key'].d = '2'") - self.assertEqual(1, result.head()[0]) - - def test_infer_schema_specification(self): - from decimal import Decimal - - class A(object): - def __init__(self): - self.a = 1 - - data = [ - True, - 1, - "a", - u"a", - datetime.date(1970, 1, 1), - datetime.datetime(1970, 1, 1, 0, 0), - 1.0, - array.array("d", [1]), - [1], - (1, ), - {"a": 1}, - bytearray(1), - Decimal(1), - Row(a=1), - Row("a")(1), - A(), - ] - - df = self.spark.createDataFrame([data]) - actual = list(map(lambda x: x.dataType.simpleString(), df.schema)) - expected = [ - 'boolean', - 'bigint', - 'string', - 'string', - 'date', - 'timestamp', - 'double', - 'array', - 'array', - 'struct<_1:bigint>', - 'map', - 'binary', - 'decimal(38,18)', - 'struct', - 'struct', - 'struct', - ] - self.assertEqual(actual, expected) - - actual = list(df.first()) - expected = [ - True, - 1, - 'a', - u"a", - datetime.date(1970, 1, 1), - datetime.datetime(1970, 1, 1, 0, 0), - 1.0, - [1.0], - [1], - Row(_1=1), - {"a": 1}, - bytearray(b'\x00'), - Decimal('1.000000000000000000'), - Row(a=1), - Row(a=1), - Row(a=1), - ] - self.assertEqual(actual, expected) - - def test_infer_schema_not_enough_names(self): - df = self.spark.createDataFrame([["a", "b"]], ["col1"]) - self.assertEqual(df.columns, ['col1', '_2']) - - def test_infer_schema_fails(self): - with self.assertRaisesRegexp(TypeError, 'field a'): - self.spark.createDataFrame(self.spark.sparkContext.parallelize([[1, 1], ["x", 1]]), - schema=["a", "b"], samplingRatio=0.99) - - def test_infer_nested_schema(self): - NestedRow = Row("f1", "f2") - nestedRdd1 = self.sc.parallelize([NestedRow([1, 2], {"row1": 1.0}), - NestedRow([2, 3], {"row2": 2.0})]) - df = self.spark.createDataFrame(nestedRdd1) - self.assertEqual(Row(f1=[1, 2], f2={u'row1': 1.0}), df.collect()[0]) - - nestedRdd2 = self.sc.parallelize([NestedRow([[1, 2], [2, 3]], [1, 2]), - NestedRow([[2, 3], [3, 4]], [2, 3])]) - df = self.spark.createDataFrame(nestedRdd2) - self.assertEqual(Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), df.collect()[0]) - - from collections import namedtuple - CustomRow = namedtuple('CustomRow', 'field1 field2') - rdd = self.sc.parallelize([CustomRow(field1=1, field2="row1"), - CustomRow(field1=2, field2="row2"), - CustomRow(field1=3, field2="row3")]) - df = self.spark.createDataFrame(rdd) - self.assertEqual(Row(field1=1, field2=u'row1'), df.first()) - - def test_create_dataframe_from_dict_respects_schema(self): - df = self.spark.createDataFrame([{'a': 1}], ["b"]) - self.assertEqual(df.columns, ['b']) - - def test_create_dataframe_from_objects(self): - data = [MyObject(1, "1"), MyObject(2, "2")] - df = self.spark.createDataFrame(data) - self.assertEqual(df.dtypes, [("key", "bigint"), ("value", "string")]) - self.assertEqual(df.first(), Row(key=1, value="1")) - - def test_select_null_literal(self): - df = self.spark.sql("select null as col") - self.assertEqual(Row(col=None), df.first()) - - def test_apply_schema(self): - from datetime import date, datetime - rdd = self.sc.parallelize([(127, -128, -32768, 32767, 2147483647, 1.0, - date(2010, 1, 1), datetime(2010, 1, 1, 1, 1, 1), - {"a": 1}, (2,), [1, 2, 3], None)]) - schema = StructType([ - StructField("byte1", ByteType(), False), - StructField("byte2", ByteType(), False), 
- StructField("short1", ShortType(), False), - StructField("short2", ShortType(), False), - StructField("int1", IntegerType(), False), - StructField("float1", FloatType(), False), - StructField("date1", DateType(), False), - StructField("time1", TimestampType(), False), - StructField("map1", MapType(StringType(), IntegerType(), False), False), - StructField("struct1", StructType([StructField("b", ShortType(), False)]), False), - StructField("list1", ArrayType(ByteType(), False), False), - StructField("null1", DoubleType(), True)]) - df = self.spark.createDataFrame(rdd, schema) - results = df.rdd.map(lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int1, x.float1, - x.date1, x.time1, x.map1["a"], x.struct1.b, x.list1, x.null1)) - r = (127, -128, -32768, 32767, 2147483647, 1.0, date(2010, 1, 1), - datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None) - self.assertEqual(r, results.first()) - - with self.tempView("table2"): - df.createOrReplaceTempView("table2") - r = self.spark.sql("SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " + - "short1 + 1 AS short1, short2 - 1 AS short2, int1 - 1 AS int1, " + - "float1 + 1.5 as float1 FROM table2").first() - - self.assertEqual((126, -127, -32767, 32766, 2147483646, 2.5), tuple(r)) - - def test_struct_in_map(self): - d = [Row(m={Row(i=1): Row(s="")})] - df = self.sc.parallelize(d).toDF() - k, v = list(df.head().m.items())[0] - self.assertEqual(1, k.i) - self.assertEqual("", v.s) - - def test_convert_row_to_dict(self): - row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}) - self.assertEqual(1, row.asDict()['l'][0].a) - df = self.sc.parallelize([row]).toDF() - - with self.tempView("test"): - df.createOrReplaceTempView("test") - row = self.spark.sql("select l, d from test").head() - self.assertEqual(1, row.asDict()["l"][0].a) - self.assertEqual(1.0, row.asDict()['d']['key'].c) - - def test_udt(self): - from pyspark.sql.types import _parse_datatype_json_string, _infer_type, _make_type_verifier - from pyspark.sql.tests import ExamplePointUDT, ExamplePoint - - def check_datatype(datatype): - pickled = pickle.loads(pickle.dumps(datatype)) - assert datatype == pickled - scala_datatype = self.spark._jsparkSession.parseDataType(datatype.json()) - python_datatype = _parse_datatype_json_string(scala_datatype.json()) - assert datatype == python_datatype - - check_datatype(ExamplePointUDT()) - structtype_with_udt = StructType([StructField("label", DoubleType(), False), - StructField("point", ExamplePointUDT(), False)]) - check_datatype(structtype_with_udt) - p = ExamplePoint(1.0, 2.0) - self.assertEqual(_infer_type(p), ExamplePointUDT()) - _make_type_verifier(ExamplePointUDT())(ExamplePoint(1.0, 2.0)) - self.assertRaises(ValueError, lambda: _make_type_verifier(ExamplePointUDT())([1.0, 2.0])) - - check_datatype(PythonOnlyUDT()) - structtype_with_udt = StructType([StructField("label", DoubleType(), False), - StructField("point", PythonOnlyUDT(), False)]) - check_datatype(structtype_with_udt) - p = PythonOnlyPoint(1.0, 2.0) - self.assertEqual(_infer_type(p), PythonOnlyUDT()) - _make_type_verifier(PythonOnlyUDT())(PythonOnlyPoint(1.0, 2.0)) - self.assertRaises( - ValueError, - lambda: _make_type_verifier(PythonOnlyUDT())([1.0, 2.0])) - - def test_simple_udt_in_df(self): - schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT()) - df = self.spark.createDataFrame( - [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)], - schema=schema) - df.collect() - - def test_nested_udt_in_df(self): - schema = StructType().add("key", 
LongType()).add("val", ArrayType(PythonOnlyUDT())) - df = self.spark.createDataFrame( - [(i % 3, [PythonOnlyPoint(float(i), float(i))]) for i in range(10)], - schema=schema) - df.collect() - - schema = StructType().add("key", LongType()).add("val", - MapType(LongType(), PythonOnlyUDT())) - df = self.spark.createDataFrame( - [(i % 3, {i % 3: PythonOnlyPoint(float(i + 1), float(i + 1))}) for i in range(10)], - schema=schema) - df.collect() - - def test_complex_nested_udt_in_df(self): - from pyspark.sql.functions import udf - - schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT()) - df = self.spark.createDataFrame( - [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)], - schema=schema) - df.collect() - - gd = df.groupby("key").agg({"val": "collect_list"}) - gd.collect() - udf = udf(lambda k, v: [(k, v[0])], ArrayType(df.schema)) - gd.select(udf(*gd)).collect() - - def test_udt_with_none(self): - df = self.spark.range(0, 10, 1, 1) - - def myudf(x): - if x > 0: - return PythonOnlyPoint(float(x), float(x)) - - self.spark.catalog.registerFunction("udf", myudf, PythonOnlyUDT()) - rows = [r[0] for r in df.selectExpr("udf(id)").take(2)] - self.assertEqual(rows, [None, PythonOnlyPoint(1, 1)]) - - def test_nonparam_udf_with_aggregate(self): - import pyspark.sql.functions as f - - df = self.spark.createDataFrame([(1, 2), (1, 2)]) - f_udf = f.udf(lambda: "const_str") - rows = df.distinct().withColumn("a", f_udf()).collect() - self.assertEqual(rows, [Row(_1=1, _2=2, a=u'const_str')]) - - def test_infer_schema_with_udt(self): - from pyspark.sql.tests import ExamplePoint, ExamplePointUDT - row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) - df = self.spark.createDataFrame([row]) - schema = df.schema - field = [f for f in schema.fields if f.name == "point"][0] - self.assertEqual(type(field.dataType), ExamplePointUDT) - - with self.tempView("labeled_point"): - df.createOrReplaceTempView("labeled_point") - point = self.spark.sql("SELECT point FROM labeled_point").head().point - self.assertEqual(point, ExamplePoint(1.0, 2.0)) - - row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0)) - df = self.spark.createDataFrame([row]) - schema = df.schema - field = [f for f in schema.fields if f.name == "point"][0] - self.assertEqual(type(field.dataType), PythonOnlyUDT) - - with self.tempView("labeled_point"): - df.createOrReplaceTempView("labeled_point") - point = self.spark.sql("SELECT point FROM labeled_point").head().point - self.assertEqual(point, PythonOnlyPoint(1.0, 2.0)) - - def test_apply_schema_with_udt(self): - from pyspark.sql.tests import ExamplePoint, ExamplePointUDT - row = (1.0, ExamplePoint(1.0, 2.0)) - schema = StructType([StructField("label", DoubleType(), False), - StructField("point", ExamplePointUDT(), False)]) - df = self.spark.createDataFrame([row], schema) - point = df.head().point - self.assertEqual(point, ExamplePoint(1.0, 2.0)) - - row = (1.0, PythonOnlyPoint(1.0, 2.0)) - schema = StructType([StructField("label", DoubleType(), False), - StructField("point", PythonOnlyUDT(), False)]) - df = self.spark.createDataFrame([row], schema) - point = df.head().point - self.assertEqual(point, PythonOnlyPoint(1.0, 2.0)) - - def test_udf_with_udt(self): - from pyspark.sql.tests import ExamplePoint, ExamplePointUDT - row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) - df = self.spark.createDataFrame([row]) - self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first()) - udf = UserDefinedFunction(lambda p: p.y, DoubleType()) - self.assertEqual(2.0, 
df.select(udf(df.point)).first()[0]) - udf2 = UserDefinedFunction(lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT()) - self.assertEqual(ExamplePoint(2.0, 3.0), df.select(udf2(df.point)).first()[0]) - - row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0)) - df = self.spark.createDataFrame([row]) - self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first()) - udf = UserDefinedFunction(lambda p: p.y, DoubleType()) - self.assertEqual(2.0, df.select(udf(df.point)).first()[0]) - udf2 = UserDefinedFunction(lambda p: PythonOnlyPoint(p.x + 1, p.y + 1), PythonOnlyUDT()) - self.assertEqual(PythonOnlyPoint(2.0, 3.0), df.select(udf2(df.point)).first()[0]) - - def test_parquet_with_udt(self): - from pyspark.sql.tests import ExamplePoint, ExamplePointUDT - row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) - df0 = self.spark.createDataFrame([row]) - output_dir = os.path.join(self.tempdir.name, "labeled_point") - df0.write.parquet(output_dir) - df1 = self.spark.read.parquet(output_dir) - point = df1.head().point - self.assertEqual(point, ExamplePoint(1.0, 2.0)) - - row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0)) - df0 = self.spark.createDataFrame([row]) - df0.write.parquet(output_dir, mode='overwrite') - df1 = self.spark.read.parquet(output_dir) - point = df1.head().point - self.assertEqual(point, PythonOnlyPoint(1.0, 2.0)) - - def test_union_with_udt(self): - from pyspark.sql.tests import ExamplePoint, ExamplePointUDT - row1 = (1.0, ExamplePoint(1.0, 2.0)) - row2 = (2.0, ExamplePoint(3.0, 4.0)) - schema = StructType([StructField("label", DoubleType(), False), - StructField("point", ExamplePointUDT(), False)]) - df1 = self.spark.createDataFrame([row1], schema) - df2 = self.spark.createDataFrame([row2], schema) - - result = df1.union(df2).orderBy("label").collect() - self.assertEqual( - result, - [ - Row(label=1.0, point=ExamplePoint(1.0, 2.0)), - Row(label=2.0, point=ExamplePoint(3.0, 4.0)) - ] - ) - - def test_cast_to_string_with_udt(self): - from pyspark.sql.tests import ExamplePointUDT, ExamplePoint - from pyspark.sql.functions import col - row = (ExamplePoint(1.0, 2.0), PythonOnlyPoint(3.0, 4.0)) - schema = StructType([StructField("point", ExamplePointUDT(), False), - StructField("pypoint", PythonOnlyUDT(), False)]) - df = self.spark.createDataFrame([row], schema) - - result = df.select(col('point').cast('string'), col('pypoint').cast('string')).head() - self.assertEqual(result, Row(point=u'(1.0, 2.0)', pypoint=u'[3.0, 4.0]')) - - def test_column_operators(self): - ci = self.df.key - cs = self.df.value - c = ci == cs - self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column)) - rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci), (1 ** ci), (ci ** 1) - self.assertTrue(all(isinstance(c, Column) for c in rcc)) - cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7] - self.assertTrue(all(isinstance(c, Column) for c in cb)) - cbool = (ci & ci), (ci | ci), (~ci) - self.assertTrue(all(isinstance(c, Column) for c in cbool)) - css = cs.contains('a'), cs.like('a'), cs.rlike('a'), cs.asc(), cs.desc(),\ - cs.startswith('a'), cs.endswith('a'), ci.eqNullSafe(cs) - self.assertTrue(all(isinstance(c, Column) for c in css)) - self.assertTrue(isinstance(ci.cast(LongType()), Column)) - self.assertRaisesRegexp(ValueError, - "Cannot apply 'in' operator against a column", - lambda: 1 in cs) - - def test_column_getitem(self): - from pyspark.sql.functions import col - - self.assertIsInstance(col("foo")[1:3], Column) - self.assertIsInstance(col("foo")[0], Column) - 
self.assertIsInstance(col("foo")["bar"], Column) - self.assertRaises(ValueError, lambda: col("foo")[0:10:2]) - - def test_column_select(self): - df = self.df - self.assertEqual(self.testData, df.select("*").collect()) - self.assertEqual(self.testData, df.select(df.key, df.value).collect()) - self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect()) - - def test_freqItems(self): - vals = [Row(a=1, b=-2.0) if i % 2 == 0 else Row(a=i, b=i * 1.0) for i in range(100)] - df = self.sc.parallelize(vals).toDF() - items = df.stat.freqItems(("a", "b"), 0.4).collect()[0] - self.assertTrue(1 in items[0]) - self.assertTrue(-2.0 in items[1]) - - def test_aggregator(self): - df = self.df - g = df.groupBy() - self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0])) - self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect()) - - from pyspark.sql import functions - self.assertEqual((0, u'99'), - tuple(g.agg(functions.first(df.key), functions.last(df.value)).first())) - self.assertTrue(95 < g.agg(functions.approx_count_distinct(df.key)).first()[0]) - self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0]) - - def test_first_last_ignorenulls(self): - from pyspark.sql import functions - df = self.spark.range(0, 100) - df2 = df.select(functions.when(df.id % 3 == 0, None).otherwise(df.id).alias("id")) - df3 = df2.select(functions.first(df2.id, False).alias('a'), - functions.first(df2.id, True).alias('b'), - functions.last(df2.id, False).alias('c'), - functions.last(df2.id, True).alias('d')) - self.assertEqual([Row(a=None, b=1, c=None, d=98)], df3.collect()) - - def test_approxQuantile(self): - df = self.sc.parallelize([Row(a=i, b=i+10) for i in range(10)]).toDF() - for f in ["a", u"a"]: - aq = df.stat.approxQuantile(f, [0.1, 0.5, 0.9], 0.1) - self.assertTrue(isinstance(aq, list)) - self.assertEqual(len(aq), 3) - self.assertTrue(all(isinstance(q, float) for q in aq)) - aqs = df.stat.approxQuantile(["a", u"b"], [0.1, 0.5, 0.9], 0.1) - self.assertTrue(isinstance(aqs, list)) - self.assertEqual(len(aqs), 2) - self.assertTrue(isinstance(aqs[0], list)) - self.assertEqual(len(aqs[0]), 3) - self.assertTrue(all(isinstance(q, float) for q in aqs[0])) - self.assertTrue(isinstance(aqs[1], list)) - self.assertEqual(len(aqs[1]), 3) - self.assertTrue(all(isinstance(q, float) for q in aqs[1])) - aqt = df.stat.approxQuantile((u"a", "b"), [0.1, 0.5, 0.9], 0.1) - self.assertTrue(isinstance(aqt, list)) - self.assertEqual(len(aqt), 2) - self.assertTrue(isinstance(aqt[0], list)) - self.assertEqual(len(aqt[0]), 3) - self.assertTrue(all(isinstance(q, float) for q in aqt[0])) - self.assertTrue(isinstance(aqt[1], list)) - self.assertEqual(len(aqt[1]), 3) - self.assertTrue(all(isinstance(q, float) for q in aqt[1])) - self.assertRaises(ValueError, lambda: df.stat.approxQuantile(123, [0.1, 0.9], 0.1)) - self.assertRaises(ValueError, lambda: df.stat.approxQuantile(("a", 123), [0.1, 0.9], 0.1)) - self.assertRaises(ValueError, lambda: df.stat.approxQuantile(["a", 123], [0.1, 0.9], 0.1)) - - def test_corr(self): - import math - df = self.sc.parallelize([Row(a=i, b=math.sqrt(i)) for i in range(10)]).toDF() - corr = df.stat.corr(u"a", "b") - self.assertTrue(abs(corr - 0.95734012) < 1e-6) - - def test_sampleby(self): - df = self.sc.parallelize([Row(a=i, b=(i % 3)) for i in range(10)]).toDF() - sampled = df.stat.sampleBy(u"b", fractions={0: 0.5, 1: 0.5}, seed=0) - self.assertTrue(sampled.count() == 3) - - def test_cov(self): - df = self.sc.parallelize([Row(a=i, 
b=2 * i) for i in range(10)]).toDF() - cov = df.stat.cov(u"a", "b") - self.assertTrue(abs(cov - 55.0 / 3) < 1e-6) - - def test_crosstab(self): - df = self.sc.parallelize([Row(a=i % 3, b=i % 2) for i in range(1, 7)]).toDF() - ct = df.stat.crosstab(u"a", "b").collect() - ct = sorted(ct, key=lambda x: x[0]) - for i, row in enumerate(ct): - self.assertEqual(row[0], str(i)) - self.assertTrue(row[1], 1) - self.assertTrue(row[2], 1) - - def test_math_functions(self): - df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF() - from pyspark.sql import functions - import math - - def get_values(l): - return [j[0] for j in l] - - def assert_close(a, b): - c = get_values(b) - diff = [abs(v - c[k]) < 1e-6 for k, v in enumerate(a)] - return sum(diff) == len(a) - assert_close([math.cos(i) for i in range(10)], - df.select(functions.cos(df.a)).collect()) - assert_close([math.cos(i) for i in range(10)], - df.select(functions.cos("a")).collect()) - assert_close([math.sin(i) for i in range(10)], - df.select(functions.sin(df.a)).collect()) - assert_close([math.sin(i) for i in range(10)], - df.select(functions.sin(df['a'])).collect()) - assert_close([math.pow(i, 2 * i) for i in range(10)], - df.select(functions.pow(df.a, df.b)).collect()) - assert_close([math.pow(i, 2) for i in range(10)], - df.select(functions.pow(df.a, 2)).collect()) - assert_close([math.pow(i, 2) for i in range(10)], - df.select(functions.pow(df.a, 2.0)).collect()) - assert_close([math.hypot(i, 2 * i) for i in range(10)], - df.select(functions.hypot(df.a, df.b)).collect()) - - def test_rand_functions(self): - df = self.df - from pyspark.sql import functions - rnd = df.select('key', functions.rand()).collect() - for row in rnd: - assert row[1] >= 0.0 and row[1] <= 1.0, "got: %s" % row[1] - rndn = df.select('key', functions.randn(5)).collect() - for row in rndn: - assert row[1] >= -4.0 and row[1] <= 4.0, "got: %s" % row[1] - - # If the specified seed is 0, we should use it. 
- # https://issues.apache.org/jira/browse/SPARK-9691 - rnd1 = df.select('key', functions.rand(0)).collect() - rnd2 = df.select('key', functions.rand(0)).collect() - self.assertEqual(sorted(rnd1), sorted(rnd2)) - - rndn1 = df.select('key', functions.randn(0)).collect() - rndn2 = df.select('key', functions.randn(0)).collect() - self.assertEqual(sorted(rndn1), sorted(rndn2)) - - def test_string_functions(self): - from pyspark.sql.functions import col, lit - df = self.spark.createDataFrame([['nick']], schema=['name']) - self.assertRaisesRegexp( - TypeError, - "must be the same type", - lambda: df.select(col('name').substr(0, lit(1)))) - if sys.version_info.major == 2: - self.assertRaises( - TypeError, - lambda: df.select(col('name').substr(long(0), long(1)))) - - def test_array_contains_function(self): - from pyspark.sql.functions import array_contains - - df = self.spark.createDataFrame([(["1", "2", "3"],), ([],)], ['data']) - actual = df.select(array_contains(df.data, "1").alias('b')).collect() - self.assertEqual([Row(b=True), Row(b=False)], actual) - - def test_between_function(self): - df = self.sc.parallelize([ - Row(a=1, b=2, c=3), - Row(a=2, b=1, c=3), - Row(a=4, b=1, c=4)]).toDF() - self.assertEqual([Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)], - df.filter(df.a.between(df.b, df.c)).collect()) - - def test_struct_type(self): - struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) - struct2 = StructType([StructField("f1", StringType(), True), - StructField("f2", StringType(), True, None)]) - self.assertEqual(struct1.fieldNames(), struct2.names) - self.assertEqual(struct1, struct2) - - struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) - struct2 = StructType([StructField("f1", StringType(), True)]) - self.assertNotEqual(struct1.fieldNames(), struct2.names) - self.assertNotEqual(struct1, struct2) - - struct1 = (StructType().add(StructField("f1", StringType(), True)) - .add(StructField("f2", StringType(), True, None))) - struct2 = StructType([StructField("f1", StringType(), True), - StructField("f2", StringType(), True, None)]) - self.assertEqual(struct1.fieldNames(), struct2.names) - self.assertEqual(struct1, struct2) - - struct1 = (StructType().add(StructField("f1", StringType(), True)) - .add(StructField("f2", StringType(), True, None))) - struct2 = StructType([StructField("f1", StringType(), True)]) - self.assertNotEqual(struct1.fieldNames(), struct2.names) - self.assertNotEqual(struct1, struct2) - - # Catch exception raised during improper construction - self.assertRaises(ValueError, lambda: StructType().add("name")) - - struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) - for field in struct1: - self.assertIsInstance(field, StructField) - - struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) - self.assertEqual(len(struct1), 2) - - struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) - self.assertIs(struct1["f1"], struct1.fields[0]) - self.assertIs(struct1[0], struct1.fields[0]) - self.assertEqual(struct1[0:1], StructType(struct1.fields[0:1])) - self.assertRaises(KeyError, lambda: struct1["f9"]) - self.assertRaises(IndexError, lambda: struct1[9]) - self.assertRaises(TypeError, lambda: struct1[9.9]) - - def test_struct_type_to_internal(self): - # Verify when not needSerializeAnyField - struct = StructType().add("b", StringType()).add("a", StringType()) - string_a = "value_a" - string_b = "value_b" - row = 
Row(a=string_a, b=string_b) - tupleResult = struct.toInternal(row) - # Reversed because of struct - self.assertEqual(tupleResult, (string_b, string_a)) - - # Verify when needSerializeAnyField - struct1 = StructType().add("b", TimestampType()).add("a", TimestampType()) - timestamp_a = datetime.datetime(2018, 1, 1, 1, 1, 1) - timestamp_b = datetime.datetime(2019, 1, 1, 1, 1, 1) - row = Row(a=timestamp_a, b=timestamp_b) - tupleResult = struct1.toInternal(row) - # Reversed because of struct - d = 1000000 - ts_b = tupleResult[0] - new_timestamp_b = datetime.datetime.fromtimestamp(ts_b // d).replace(microsecond=ts_b % d) - ts_a = tupleResult[1] - new_timestamp_a = datetime.datetime.fromtimestamp(ts_a // d).replace(microsecond=ts_a % d) - self.assertEqual(timestamp_a, new_timestamp_a) - self.assertEqual(timestamp_b, new_timestamp_b) - - def test_parse_datatype_string(self): - from pyspark.sql.types import _all_atomic_types, _parse_datatype_string - for k, t in _all_atomic_types.items(): - if t != NullType: - self.assertEqual(t(), _parse_datatype_string(k)) - self.assertEqual(IntegerType(), _parse_datatype_string("int")) - self.assertEqual(DecimalType(1, 1), _parse_datatype_string("decimal(1 ,1)")) - self.assertEqual(DecimalType(10, 1), _parse_datatype_string("decimal( 10,1 )")) - self.assertEqual(DecimalType(11, 1), _parse_datatype_string("decimal(11,1)")) - self.assertEqual( - ArrayType(IntegerType()), - _parse_datatype_string("array")) - self.assertEqual( - MapType(IntegerType(), DoubleType()), - _parse_datatype_string("map< int, double >")) - self.assertEqual( - StructType([StructField("a", IntegerType()), StructField("c", DoubleType())]), - _parse_datatype_string("struct")) - self.assertEqual( - StructType([StructField("a", IntegerType()), StructField("c", DoubleType())]), - _parse_datatype_string("a:int, c:double")) - self.assertEqual( - StructType([StructField("a", IntegerType()), StructField("c", DoubleType())]), - _parse_datatype_string("a INT, c DOUBLE")) - - def test_metadata_null(self): - schema = StructType([StructField("f1", StringType(), True, None), - StructField("f2", StringType(), True, {'a': None})]) - rdd = self.sc.parallelize([["a", "b"], ["c", "d"]]) - self.spark.createDataFrame(rdd, schema) - - def test_save_and_load(self): - df = self.df - tmpPath = tempfile.mkdtemp() - shutil.rmtree(tmpPath) - df.write.json(tmpPath) - actual = self.spark.read.json(tmpPath) - self.assertEqual(sorted(df.collect()), sorted(actual.collect())) - - schema = StructType([StructField("value", StringType(), True)]) - actual = self.spark.read.json(tmpPath, schema) - self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect())) - - df.write.json(tmpPath, "overwrite") - actual = self.spark.read.json(tmpPath) - self.assertEqual(sorted(df.collect()), sorted(actual.collect())) - - df.write.save(format="json", mode="overwrite", path=tmpPath, - noUse="this options will not be used in save.") - actual = self.spark.read.load(format="json", path=tmpPath, - noUse="this options will not be used in load.") - self.assertEqual(sorted(df.collect()), sorted(actual.collect())) - - defaultDataSourceName = self.spark.conf.get("spark.sql.sources.default", - "org.apache.spark.sql.parquet") - self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json") - actual = self.spark.read.load(path=tmpPath) - self.assertEqual(sorted(df.collect()), sorted(actual.collect())) - self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName) - - csvpath = os.path.join(tempfile.mkdtemp(), 'data') 
- df.write.option('quote', None).format('csv').save(csvpath) - - shutil.rmtree(tmpPath) - - def test_save_and_load_builder(self): - df = self.df - tmpPath = tempfile.mkdtemp() - shutil.rmtree(tmpPath) - df.write.json(tmpPath) - actual = self.spark.read.json(tmpPath) - self.assertEqual(sorted(df.collect()), sorted(actual.collect())) - - schema = StructType([StructField("value", StringType(), True)]) - actual = self.spark.read.json(tmpPath, schema) - self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect())) - - df.write.mode("overwrite").json(tmpPath) - actual = self.spark.read.json(tmpPath) - self.assertEqual(sorted(df.collect()), sorted(actual.collect())) - - df.write.mode("overwrite").options(noUse="this options will not be used in save.")\ - .option("noUse", "this option will not be used in save.")\ - .format("json").save(path=tmpPath) - actual =\ - self.spark.read.format("json")\ - .load(path=tmpPath, noUse="this options will not be used in load.") - self.assertEqual(sorted(df.collect()), sorted(actual.collect())) - - defaultDataSourceName = self.spark.conf.get("spark.sql.sources.default", - "org.apache.spark.sql.parquet") - self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json") - actual = self.spark.read.load(path=tmpPath) - self.assertEqual(sorted(df.collect()), sorted(actual.collect())) - self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName) - - shutil.rmtree(tmpPath) - - def test_stream_trigger(self): - df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') - - # Should take at least one arg - try: - df.writeStream.trigger() - except ValueError: - pass - - # Should not take multiple args - try: - df.writeStream.trigger(once=True, processingTime='5 seconds') - except ValueError: - pass - - # Should not take multiple args - try: - df.writeStream.trigger(processingTime='5 seconds', continuous='1 second') - except ValueError: - pass - - # Should take only keyword args - try: - df.writeStream.trigger('5 seconds') - self.fail("Should have thrown an exception") - except TypeError: - pass - - def test_stream_read_options(self): - schema = StructType([StructField("data", StringType(), False)]) - df = self.spark.readStream\ - .format('text')\ - .option('path', 'python/test_support/sql/streaming')\ - .schema(schema)\ - .load() - self.assertTrue(df.isStreaming) - self.assertEqual(df.schema.simpleString(), "struct") - - def test_stream_read_options_overwrite(self): - bad_schema = StructType([StructField("test", IntegerType(), False)]) - schema = StructType([StructField("data", StringType(), False)]) - df = self.spark.readStream.format('csv').option('path', 'python/test_support/sql/fake') \ - .schema(bad_schema)\ - .load(path='python/test_support/sql/streaming', schema=schema, format='text') - self.assertTrue(df.isStreaming) - self.assertEqual(df.schema.simpleString(), "struct") - - def test_stream_save_options(self): - df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') \ - .withColumn('id', lit(1)) - for q in self.spark._wrapped.streams.active: - q.stop() - tmpPath = tempfile.mkdtemp() - shutil.rmtree(tmpPath) - self.assertTrue(df.isStreaming) - out = os.path.join(tmpPath, 'out') - chk = os.path.join(tmpPath, 'chk') - q = df.writeStream.option('checkpointLocation', chk).queryName('this_query') \ - .format('parquet').partitionBy('id').outputMode('append').option('path', out).start() - try: - self.assertEqual(q.name, 'this_query') - self.assertTrue(q.isActive) - 
q.processAllAvailable() - output_files = [] - for _, _, files in os.walk(out): - output_files.extend([f for f in files if not f.startswith('.')]) - self.assertTrue(len(output_files) > 0) - self.assertTrue(len(os.listdir(chk)) > 0) - finally: - q.stop() - shutil.rmtree(tmpPath) - - def test_stream_save_options_overwrite(self): - df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') - for q in self.spark._wrapped.streams.active: - q.stop() - tmpPath = tempfile.mkdtemp() - shutil.rmtree(tmpPath) - self.assertTrue(df.isStreaming) - out = os.path.join(tmpPath, 'out') - chk = os.path.join(tmpPath, 'chk') - fake1 = os.path.join(tmpPath, 'fake1') - fake2 = os.path.join(tmpPath, 'fake2') - q = df.writeStream.option('checkpointLocation', fake1)\ - .format('memory').option('path', fake2) \ - .queryName('fake_query').outputMode('append') \ - .start(path=out, format='parquet', queryName='this_query', checkpointLocation=chk) - - try: - self.assertEqual(q.name, 'this_query') - self.assertTrue(q.isActive) - q.processAllAvailable() - output_files = [] - for _, _, files in os.walk(out): - output_files.extend([f for f in files if not f.startswith('.')]) - self.assertTrue(len(output_files) > 0) - self.assertTrue(len(os.listdir(chk)) > 0) - self.assertFalse(os.path.isdir(fake1)) # should not have been created - self.assertFalse(os.path.isdir(fake2)) # should not have been created - finally: - q.stop() - shutil.rmtree(tmpPath) - - def test_stream_status_and_progress(self): - df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') - for q in self.spark._wrapped.streams.active: - q.stop() - tmpPath = tempfile.mkdtemp() - shutil.rmtree(tmpPath) - self.assertTrue(df.isStreaming) - out = os.path.join(tmpPath, 'out') - chk = os.path.join(tmpPath, 'chk') - - def func(x): - time.sleep(1) - return x - - from pyspark.sql.functions import col, udf - sleep_udf = udf(func) - - # Use "sleep_udf" to delay the progress update so that we can test `lastProgress` when there - # were no updates. - q = df.select(sleep_udf(col("value")).alias('value')).writeStream \ - .start(path=out, format='parquet', queryName='this_query', checkpointLocation=chk) - try: - # "lastProgress" will return None in most cases. However, as it may be flaky when - # Jenkins is very slow, we don't assert it. If there is something wrong, "lastProgress" - # may throw error with a high chance and make this test flaky, so we should still be - # able to detect broken codes. 
- q.lastProgress - - q.processAllAvailable() - lastProgress = q.lastProgress - recentProgress = q.recentProgress - status = q.status - self.assertEqual(lastProgress['name'], q.name) - self.assertEqual(lastProgress['id'], q.id) - self.assertTrue(any(p == lastProgress for p in recentProgress)) - self.assertTrue( - "message" in status and - "isDataAvailable" in status and - "isTriggerActive" in status) - finally: - q.stop() - shutil.rmtree(tmpPath) - - def test_stream_await_termination(self): - df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') - for q in self.spark._wrapped.streams.active: - q.stop() - tmpPath = tempfile.mkdtemp() - shutil.rmtree(tmpPath) - self.assertTrue(df.isStreaming) - out = os.path.join(tmpPath, 'out') - chk = os.path.join(tmpPath, 'chk') - q = df.writeStream\ - .start(path=out, format='parquet', queryName='this_query', checkpointLocation=chk) - try: - self.assertTrue(q.isActive) - try: - q.awaitTermination("hello") - self.fail("Expected a value exception") - except ValueError: - pass - now = time.time() - # test should take at least 2 seconds - res = q.awaitTermination(2.6) - duration = time.time() - now - self.assertTrue(duration >= 2) - self.assertFalse(res) - finally: - q.stop() - shutil.rmtree(tmpPath) - - def test_stream_exception(self): - sdf = self.spark.readStream.format('text').load('python/test_support/sql/streaming') - sq = sdf.writeStream.format('memory').queryName('query_explain').start() - try: - sq.processAllAvailable() - self.assertEqual(sq.exception(), None) - finally: - sq.stop() - - from pyspark.sql.functions import col, udf - from pyspark.sql.utils import StreamingQueryException - bad_udf = udf(lambda x: 1 / 0) - sq = sdf.select(bad_udf(col("value")))\ - .writeStream\ - .format('memory')\ - .queryName('this_query')\ - .start() - try: - # Process some data to fail the query - sq.processAllAvailable() - self.fail("bad udf should fail the query") - except StreamingQueryException as e: - # This is expected - self.assertTrue("ZeroDivisionError" in e.desc) - finally: - sq.stop() - self.assertTrue(type(sq.exception()) is StreamingQueryException) - self.assertTrue("ZeroDivisionError" in sq.exception().desc) - - def test_query_manager_await_termination(self): - df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') - for q in self.spark._wrapped.streams.active: - q.stop() - tmpPath = tempfile.mkdtemp() - shutil.rmtree(tmpPath) - self.assertTrue(df.isStreaming) - out = os.path.join(tmpPath, 'out') - chk = os.path.join(tmpPath, 'chk') - q = df.writeStream\ - .start(path=out, format='parquet', queryName='this_query', checkpointLocation=chk) - try: - self.assertTrue(q.isActive) - try: - self.spark._wrapped.streams.awaitAnyTermination("hello") - self.fail("Expected a value exception") - except ValueError: - pass - now = time.time() - # test should take at least 2 seconds - res = self.spark._wrapped.streams.awaitAnyTermination(2.6) - duration = time.time() - now - self.assertTrue(duration >= 2) - self.assertFalse(res) - finally: - q.stop() - shutil.rmtree(tmpPath) - - class ForeachWriterTester: - - def __init__(self, spark): - self.spark = spark - - def write_open_event(self, partitionId, epochId): - self._write_event( - self.open_events_dir, - {'partition': partitionId, 'epoch': epochId}) - - def write_process_event(self, row): - self._write_event(self.process_events_dir, {'value': 'text'}) - - def write_close_event(self, error): - self._write_event(self.close_events_dir, {'error': str(error)}) - - 
def write_input_file(self): - self._write_event(self.input_dir, "text") - - def open_events(self): - return self._read_events(self.open_events_dir, 'partition INT, epoch INT') - - def process_events(self): - return self._read_events(self.process_events_dir, 'value STRING') - - def close_events(self): - return self._read_events(self.close_events_dir, 'error STRING') - - def run_streaming_query_on_writer(self, writer, num_files): - self._reset() - try: - sdf = self.spark.readStream.format('text').load(self.input_dir) - sq = sdf.writeStream.foreach(writer).start() - for i in range(num_files): - self.write_input_file() - sq.processAllAvailable() - finally: - self.stop_all() - - def assert_invalid_writer(self, writer, msg=None): - self._reset() - try: - sdf = self.spark.readStream.format('text').load(self.input_dir) - sq = sdf.writeStream.foreach(writer).start() - self.write_input_file() - sq.processAllAvailable() - self.fail("invalid writer %s did not fail the query" % str(writer)) # not expected - except Exception as e: - if msg: - assert msg in str(e), "%s not in %s" % (msg, str(e)) - - finally: - self.stop_all() - - def stop_all(self): - for q in self.spark._wrapped.streams.active: - q.stop() - - def _reset(self): - self.input_dir = tempfile.mkdtemp() - self.open_events_dir = tempfile.mkdtemp() - self.process_events_dir = tempfile.mkdtemp() - self.close_events_dir = tempfile.mkdtemp() - - def _read_events(self, dir, json): - rows = self.spark.read.schema(json).json(dir).collect() - dicts = [row.asDict() for row in rows] - return dicts - - def _write_event(self, dir, event): - import uuid - with open(os.path.join(dir, str(uuid.uuid4())), 'w') as f: - f.write("%s\n" % str(event)) - - def __getstate__(self): - return (self.open_events_dir, self.process_events_dir, self.close_events_dir) - - def __setstate__(self, state): - self.open_events_dir, self.process_events_dir, self.close_events_dir = state - - # Those foreach tests are failed in Python 3.6 and macOS High Sierra by defined rules - # at http://sealiesoftware.com/blog/archive/2017/6/5/Objective-C_and_fork_in_macOS_1013.html - # To work around this, OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES. 
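    # Illustrative sketch (not part of the removed test code above, added for clarity): the
    # workaround named in the preceding comment is an environment variable. It is normally
    # exported in the shell before the test run; setting it from Python, as assumed below,
    # only helps if it happens before any worker processes are forked.
    #
    #     import os
    #     os.environ["OBJC_DISABLE_INITIALIZE_FORK_SAFETY"] = "YES"  # macOS fork-safety workaround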
- def test_streaming_foreach_with_simple_function(self): - tester = self.ForeachWriterTester(self.spark) - - def foreach_func(row): - tester.write_process_event(row) - - tester.run_streaming_query_on_writer(foreach_func, 2) - self.assertEqual(len(tester.process_events()), 2) - - def test_streaming_foreach_with_basic_open_process_close(self): - tester = self.ForeachWriterTester(self.spark) - - class ForeachWriter: - def open(self, partitionId, epochId): - tester.write_open_event(partitionId, epochId) - return True - - def process(self, row): - tester.write_process_event(row) - - def close(self, error): - tester.write_close_event(error) - - tester.run_streaming_query_on_writer(ForeachWriter(), 2) - - open_events = tester.open_events() - self.assertEqual(len(open_events), 2) - self.assertSetEqual(set([e['epoch'] for e in open_events]), {0, 1}) - - self.assertEqual(len(tester.process_events()), 2) - - close_events = tester.close_events() - self.assertEqual(len(close_events), 2) - self.assertSetEqual(set([e['error'] for e in close_events]), {'None'}) - - def test_streaming_foreach_with_open_returning_false(self): - tester = self.ForeachWriterTester(self.spark) - - class ForeachWriter: - def open(self, partition_id, epoch_id): - tester.write_open_event(partition_id, epoch_id) - return False - - def process(self, row): - tester.write_process_event(row) - - def close(self, error): - tester.write_close_event(error) - - tester.run_streaming_query_on_writer(ForeachWriter(), 2) - - self.assertEqual(len(tester.open_events()), 2) - - self.assertEqual(len(tester.process_events()), 0) # no row was processed - - close_events = tester.close_events() - self.assertEqual(len(close_events), 2) - self.assertSetEqual(set([e['error'] for e in close_events]), {'None'}) - - def test_streaming_foreach_without_open_method(self): - tester = self.ForeachWriterTester(self.spark) - - class ForeachWriter: - def process(self, row): - tester.write_process_event(row) - - def close(self, error): - tester.write_close_event(error) - - tester.run_streaming_query_on_writer(ForeachWriter(), 2) - self.assertEqual(len(tester.open_events()), 0) # no open events - self.assertEqual(len(tester.process_events()), 2) - self.assertEqual(len(tester.close_events()), 2) - - def test_streaming_foreach_without_close_method(self): - tester = self.ForeachWriterTester(self.spark) - - class ForeachWriter: - def open(self, partition_id, epoch_id): - tester.write_open_event(partition_id, epoch_id) - return True - - def process(self, row): - tester.write_process_event(row) - - tester.run_streaming_query_on_writer(ForeachWriter(), 2) - self.assertEqual(len(tester.open_events()), 2) # no open events - self.assertEqual(len(tester.process_events()), 2) - self.assertEqual(len(tester.close_events()), 0) - - def test_streaming_foreach_without_open_and_close_methods(self): - tester = self.ForeachWriterTester(self.spark) - - class ForeachWriter: - def process(self, row): - tester.write_process_event(row) - - tester.run_streaming_query_on_writer(ForeachWriter(), 2) - self.assertEqual(len(tester.open_events()), 0) # no open events - self.assertEqual(len(tester.process_events()), 2) - self.assertEqual(len(tester.close_events()), 0) - - def test_streaming_foreach_with_process_throwing_error(self): - from pyspark.sql.utils import StreamingQueryException - - tester = self.ForeachWriterTester(self.spark) - - class ForeachWriter: - def process(self, row): - raise Exception("test error") - - def close(self, error): - tester.write_close_event(error) - - try: - 
tester.run_streaming_query_on_writer(ForeachWriter(), 1) - self.fail("bad writer did not fail the query") # this is not expected - except StreamingQueryException as e: - # TODO: Verify whether original error message is inside the exception - pass - - self.assertEqual(len(tester.process_events()), 0) # no row was processed - close_events = tester.close_events() - self.assertEqual(len(close_events), 1) - # TODO: Verify whether original error message is inside the exception - - def test_streaming_foreach_with_invalid_writers(self): - - tester = self.ForeachWriterTester(self.spark) - - def func_with_iterator_input(iter): - for x in iter: - print(x) - - tester.assert_invalid_writer(func_with_iterator_input) - - class WriterWithoutProcess: - def open(self, partition): - pass - - tester.assert_invalid_writer(WriterWithoutProcess(), "does not have a 'process'") - - class WriterWithNonCallableProcess(): - process = True - - tester.assert_invalid_writer(WriterWithNonCallableProcess(), - "'process' in provided object is not callable") - - class WriterWithNoParamProcess(): - def process(self): - pass - - tester.assert_invalid_writer(WriterWithNoParamProcess()) - - # Abstract class for tests below - class WithProcess(): - def process(self, row): - pass - - class WriterWithNonCallableOpen(WithProcess): - open = True - - tester.assert_invalid_writer(WriterWithNonCallableOpen(), - "'open' in provided object is not callable") - - class WriterWithNoParamOpen(WithProcess): - def open(self): - pass - - tester.assert_invalid_writer(WriterWithNoParamOpen()) - - class WriterWithNonCallableClose(WithProcess): - close = True - - tester.assert_invalid_writer(WriterWithNonCallableClose(), - "'close' in provided object is not callable") - - def test_streaming_foreachBatch(self): - q = None - collected = dict() - - def collectBatch(batch_df, batch_id): - collected[batch_id] = batch_df.collect() - - try: - df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') - q = df.writeStream.foreachBatch(collectBatch).start() - q.processAllAvailable() - self.assertTrue(0 in collected) - self.assertTrue(len(collected[0]), 2) - finally: - if q: - q.stop() - - def test_streaming_foreachBatch_propagates_python_errors(self): - from pyspark.sql.utils import StreamingQueryException - - q = None - - def collectBatch(df, id): - raise Exception("this should fail the query") - - try: - df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') - q = df.writeStream.foreachBatch(collectBatch).start() - q.processAllAvailable() - self.fail("Expected a failure") - except StreamingQueryException as e: - self.assertTrue("this should fail" in str(e)) - finally: - if q: - q.stop() - - def test_help_command(self): - # Regression test for SPARK-5464 - rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) - df = self.spark.read.json(rdd) - # render_doc() reproduces the help() exception without printing output - pydoc.render_doc(df) - pydoc.render_doc(df.foo) - pydoc.render_doc(df.take(1)) - - def test_access_column(self): - df = self.df - self.assertTrue(isinstance(df.key, Column)) - self.assertTrue(isinstance(df['key'], Column)) - self.assertTrue(isinstance(df[0], Column)) - self.assertRaises(IndexError, lambda: df[2]) - self.assertRaises(AnalysisException, lambda: df["bad_key"]) - self.assertRaises(TypeError, lambda: df[{}]) - - def test_column_name_with_non_ascii(self): - if sys.version >= '3': - columnName = "数量" - self.assertTrue(isinstance(columnName, str)) - else: - columnName = 
unicode("数量", "utf-8") - self.assertTrue(isinstance(columnName, unicode)) - schema = StructType([StructField(columnName, LongType(), True)]) - df = self.spark.createDataFrame([(1,)], schema) - self.assertEqual(schema, df.schema) - self.assertEqual("DataFrame[数量: bigint]", str(df)) - self.assertEqual([("数量", 'bigint')], df.dtypes) - self.assertEqual(1, df.select("数量").first()[0]) - self.assertEqual(1, df.select(df["数量"]).first()[0]) - - def test_access_nested_types(self): - df = self.sc.parallelize([Row(l=[1], r=Row(a=1, b="b"), d={"k": "v"})]).toDF() - self.assertEqual(1, df.select(df.l[0]).first()[0]) - self.assertEqual(1, df.select(df.l.getItem(0)).first()[0]) - self.assertEqual(1, df.select(df.r.a).first()[0]) - self.assertEqual("b", df.select(df.r.getField("b")).first()[0]) - self.assertEqual("v", df.select(df.d["k"]).first()[0]) - self.assertEqual("v", df.select(df.d.getItem("k")).first()[0]) - - def test_field_accessor(self): - df = self.sc.parallelize([Row(l=[1], r=Row(a=1, b="b"), d={"k": "v"})]).toDF() - self.assertEqual(1, df.select(df.l[0]).first()[0]) - self.assertEqual(1, df.select(df.r["a"]).first()[0]) - self.assertEqual(1, df.select(df["r.a"]).first()[0]) - self.assertEqual("b", df.select(df.r["b"]).first()[0]) - self.assertEqual("b", df.select(df["r.b"]).first()[0]) - self.assertEqual("v", df.select(df.d["k"]).first()[0]) - - def test_infer_long_type(self): - longrow = [Row(f1='a', f2=100000000000000)] - df = self.sc.parallelize(longrow).toDF() - self.assertEqual(df.schema.fields[1].dataType, LongType()) - - # this saving as Parquet caused issues as well. - output_dir = os.path.join(self.tempdir.name, "infer_long_type") - df.write.parquet(output_dir) - df1 = self.spark.read.parquet(output_dir) - self.assertEqual('a', df1.first().f1) - self.assertEqual(100000000000000, df1.first().f2) - - self.assertEqual(_infer_type(1), LongType()) - self.assertEqual(_infer_type(2**10), LongType()) - self.assertEqual(_infer_type(2**20), LongType()) - self.assertEqual(_infer_type(2**31 - 1), LongType()) - self.assertEqual(_infer_type(2**31), LongType()) - self.assertEqual(_infer_type(2**61), LongType()) - self.assertEqual(_infer_type(2**71), LongType()) - - def test_merge_type(self): - self.assertEqual(_merge_type(LongType(), NullType()), LongType()) - self.assertEqual(_merge_type(NullType(), LongType()), LongType()) - - self.assertEqual(_merge_type(LongType(), LongType()), LongType()) - - self.assertEqual(_merge_type( - ArrayType(LongType()), - ArrayType(LongType()) - ), ArrayType(LongType())) - with self.assertRaisesRegexp(TypeError, 'element in array'): - _merge_type(ArrayType(LongType()), ArrayType(DoubleType())) - - self.assertEqual(_merge_type( - MapType(StringType(), LongType()), - MapType(StringType(), LongType()) - ), MapType(StringType(), LongType())) - with self.assertRaisesRegexp(TypeError, 'key of map'): - _merge_type( - MapType(StringType(), LongType()), - MapType(DoubleType(), LongType())) - with self.assertRaisesRegexp(TypeError, 'value of map'): - _merge_type( - MapType(StringType(), LongType()), - MapType(StringType(), DoubleType())) - - self.assertEqual(_merge_type( - StructType([StructField("f1", LongType()), StructField("f2", StringType())]), - StructType([StructField("f1", LongType()), StructField("f2", StringType())]) - ), StructType([StructField("f1", LongType()), StructField("f2", StringType())])) - with self.assertRaisesRegexp(TypeError, 'field f1'): - _merge_type( - StructType([StructField("f1", LongType()), StructField("f2", StringType())]), - 
StructType([StructField("f1", DoubleType()), StructField("f2", StringType())])) - - self.assertEqual(_merge_type( - StructType([StructField("f1", StructType([StructField("f2", LongType())]))]), - StructType([StructField("f1", StructType([StructField("f2", LongType())]))]) - ), StructType([StructField("f1", StructType([StructField("f2", LongType())]))])) - with self.assertRaisesRegexp(TypeError, 'field f2 in field f1'): - _merge_type( - StructType([StructField("f1", StructType([StructField("f2", LongType())]))]), - StructType([StructField("f1", StructType([StructField("f2", StringType())]))])) - - self.assertEqual(_merge_type( - StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())]), - StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())]) - ), StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())])) - with self.assertRaisesRegexp(TypeError, 'element in array field f1'): - _merge_type( - StructType([ - StructField("f1", ArrayType(LongType())), - StructField("f2", StringType())]), - StructType([ - StructField("f1", ArrayType(DoubleType())), - StructField("f2", StringType())])) - - self.assertEqual(_merge_type( - StructType([ - StructField("f1", MapType(StringType(), LongType())), - StructField("f2", StringType())]), - StructType([ - StructField("f1", MapType(StringType(), LongType())), - StructField("f2", StringType())]) - ), StructType([ - StructField("f1", MapType(StringType(), LongType())), - StructField("f2", StringType())])) - with self.assertRaisesRegexp(TypeError, 'value of map field f1'): - _merge_type( - StructType([ - StructField("f1", MapType(StringType(), LongType())), - StructField("f2", StringType())]), - StructType([ - StructField("f1", MapType(StringType(), DoubleType())), - StructField("f2", StringType())])) - - self.assertEqual(_merge_type( - StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]), - StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]) - ), StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))])) - with self.assertRaisesRegexp(TypeError, 'key of map element in array field f1'): - _merge_type( - StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]), - StructType([StructField("f1", ArrayType(MapType(DoubleType(), LongType())))]) - ) - - def test_filter_with_datetime(self): - time = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000) - date = time.date() - row = Row(date=date, time=time) - df = self.spark.createDataFrame([row]) - self.assertEqual(1, df.filter(df.date == date).count()) - self.assertEqual(1, df.filter(df.time == time).count()) - self.assertEqual(0, df.filter(df.date > date).count()) - self.assertEqual(0, df.filter(df.time > time).count()) - - def test_filter_with_datetime_timezone(self): - dt1 = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000, tzinfo=UTCOffsetTimezone(0)) - dt2 = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000, tzinfo=UTCOffsetTimezone(1)) - row = Row(date=dt1) - df = self.spark.createDataFrame([row]) - self.assertEqual(0, df.filter(df.date == dt2).count()) - self.assertEqual(1, df.filter(df.date > dt2).count()) - self.assertEqual(0, df.filter(df.date < dt2).count()) - - def test_time_with_timezone(self): - day = datetime.date.today() - now = datetime.datetime.now() - ts = time.mktime(now.timetuple()) - # class in __main__ is not serializable - from pyspark.sql.tests import UTCOffsetTimezone - utc = UTCOffsetTimezone() - utcnow = 
datetime.datetime.utcfromtimestamp(ts) # without microseconds - # add microseconds to utcnow (keeping year,month,day,hour,minute,second) - utcnow = datetime.datetime(*(utcnow.timetuple()[:6] + (now.microsecond, utc))) - df = self.spark.createDataFrame([(day, now, utcnow)]) - day1, now1, utcnow1 = df.first() - self.assertEqual(day1, day) - self.assertEqual(now, now1) - self.assertEqual(now, utcnow1) - - # regression test for SPARK-19561 - def test_datetime_at_epoch(self): - epoch = datetime.datetime.fromtimestamp(0) - df = self.spark.createDataFrame([Row(date=epoch)]) - first = df.select('date', lit(epoch).alias('lit_date')).first() - self.assertEqual(first['date'], epoch) - self.assertEqual(first['lit_date'], epoch) - - def test_dayofweek(self): - from pyspark.sql.functions import dayofweek - dt = datetime.datetime(2017, 11, 6) - df = self.spark.createDataFrame([Row(date=dt)]) - row = df.select(dayofweek(df.date)).first() - self.assertEqual(row[0], 2) - - def test_decimal(self): - from decimal import Decimal - schema = StructType([StructField("decimal", DecimalType(10, 5))]) - df = self.spark.createDataFrame([(Decimal("3.14159"),)], schema) - row = df.select(df.decimal + 1).first() - self.assertEqual(row[0], Decimal("4.14159")) - tmpPath = tempfile.mkdtemp() - shutil.rmtree(tmpPath) - df.write.parquet(tmpPath) - df2 = self.spark.read.parquet(tmpPath) - row = df2.first() - self.assertEqual(row[0], Decimal("3.14159")) - - def test_dropna(self): - schema = StructType([ - StructField("name", StringType(), True), - StructField("age", IntegerType(), True), - StructField("height", DoubleType(), True)]) - - # shouldn't drop a non-null row - self.assertEqual(self.spark.createDataFrame( - [(u'Alice', 50, 80.1)], schema).dropna().count(), - 1) - - # dropping rows with a single null value - self.assertEqual(self.spark.createDataFrame( - [(u'Alice', None, 80.1)], schema).dropna().count(), - 0) - self.assertEqual(self.spark.createDataFrame( - [(u'Alice', None, 80.1)], schema).dropna(how='any').count(), - 0) - - # if how = 'all', only drop rows if all values are null - self.assertEqual(self.spark.createDataFrame( - [(u'Alice', None, 80.1)], schema).dropna(how='all').count(), - 1) - self.assertEqual(self.spark.createDataFrame( - [(None, None, None)], schema).dropna(how='all').count(), - 0) - - # how and subset - self.assertEqual(self.spark.createDataFrame( - [(u'Alice', 50, None)], schema).dropna(how='any', subset=['name', 'age']).count(), - 1) - self.assertEqual(self.spark.createDataFrame( - [(u'Alice', None, None)], schema).dropna(how='any', subset=['name', 'age']).count(), - 0) - - # threshold - self.assertEqual(self.spark.createDataFrame( - [(u'Alice', None, 80.1)], schema).dropna(thresh=2).count(), - 1) - self.assertEqual(self.spark.createDataFrame( - [(u'Alice', None, None)], schema).dropna(thresh=2).count(), - 0) - - # threshold and subset - self.assertEqual(self.spark.createDataFrame( - [(u'Alice', 50, None)], schema).dropna(thresh=2, subset=['name', 'age']).count(), - 1) - self.assertEqual(self.spark.createDataFrame( - [(u'Alice', None, 180.9)], schema).dropna(thresh=2, subset=['name', 'age']).count(), - 0) - - # thresh should take precedence over how - self.assertEqual(self.spark.createDataFrame( - [(u'Alice', 50, None)], schema).dropna( - how='any', thresh=2, subset=['name', 'age']).count(), - 1) - - def test_fillna(self): - schema = StructType([ - StructField("name", StringType(), True), - StructField("age", IntegerType(), True), - StructField("height", DoubleType(), True), - 
StructField("spy", BooleanType(), True)]) - - # fillna shouldn't change non-null values - row = self.spark.createDataFrame([(u'Alice', 10, 80.1, True)], schema).fillna(50).first() - self.assertEqual(row.age, 10) - - # fillna with int - row = self.spark.createDataFrame([(u'Alice', None, None, None)], schema).fillna(50).first() - self.assertEqual(row.age, 50) - self.assertEqual(row.height, 50.0) - - # fillna with double - row = self.spark.createDataFrame( - [(u'Alice', None, None, None)], schema).fillna(50.1).first() - self.assertEqual(row.age, 50) - self.assertEqual(row.height, 50.1) - - # fillna with bool - row = self.spark.createDataFrame( - [(u'Alice', None, None, None)], schema).fillna(True).first() - self.assertEqual(row.age, None) - self.assertEqual(row.spy, True) - - # fillna with string - row = self.spark.createDataFrame([(None, None, None, None)], schema).fillna("hello").first() - self.assertEqual(row.name, u"hello") - self.assertEqual(row.age, None) - - # fillna with subset specified for numeric cols - row = self.spark.createDataFrame( - [(None, None, None, None)], schema).fillna(50, subset=['name', 'age']).first() - self.assertEqual(row.name, None) - self.assertEqual(row.age, 50) - self.assertEqual(row.height, None) - self.assertEqual(row.spy, None) - - # fillna with subset specified for string cols - row = self.spark.createDataFrame( - [(None, None, None, None)], schema).fillna("haha", subset=['name', 'age']).first() - self.assertEqual(row.name, "haha") - self.assertEqual(row.age, None) - self.assertEqual(row.height, None) - self.assertEqual(row.spy, None) - - # fillna with subset specified for bool cols - row = self.spark.createDataFrame( - [(None, None, None, None)], schema).fillna(True, subset=['name', 'spy']).first() - self.assertEqual(row.name, None) - self.assertEqual(row.age, None) - self.assertEqual(row.height, None) - self.assertEqual(row.spy, True) - - # fillna with dictionary for boolean types - row = self.spark.createDataFrame([Row(a=None), Row(a=True)]).fillna({"a": True}).first() - self.assertEqual(row.a, True) - - def test_bitwise_operations(self): - from pyspark.sql import functions - row = Row(a=170, b=75) - df = self.spark.createDataFrame([row]) - result = df.select(df.a.bitwiseAND(df.b)).collect()[0].asDict() - self.assertEqual(170 & 75, result['(a & b)']) - result = df.select(df.a.bitwiseOR(df.b)).collect()[0].asDict() - self.assertEqual(170 | 75, result['(a | b)']) - result = df.select(df.a.bitwiseXOR(df.b)).collect()[0].asDict() - self.assertEqual(170 ^ 75, result['(a ^ b)']) - result = df.select(functions.bitwiseNOT(df.b)).collect()[0].asDict() - self.assertEqual(~75, result['~b']) - - def test_expr(self): - from pyspark.sql import functions - row = Row(a="length string", b=75) - df = self.spark.createDataFrame([row]) - result = df.select(functions.expr("length(a)")).collect()[0].asDict() - self.assertEqual(13, result["length(a)"]) - - def test_repartitionByRange_dataframe(self): - schema = StructType([ - StructField("name", StringType(), True), - StructField("age", IntegerType(), True), - StructField("height", DoubleType(), True)]) - - df1 = self.spark.createDataFrame( - [(u'Bob', 27, 66.0), (u'Alice', 10, 10.0), (u'Bob', 10, 66.0)], schema) - df2 = self.spark.createDataFrame( - [(u'Alice', 10, 10.0), (u'Bob', 10, 66.0), (u'Bob', 27, 66.0)], schema) - - # test repartitionByRange(numPartitions, *cols) - df3 = df1.repartitionByRange(2, "name", "age") - self.assertEqual(df3.rdd.getNumPartitions(), 2) - self.assertEqual(df3.rdd.first(), df2.rdd.first()) - 
self.assertEqual(df3.rdd.take(3), df2.rdd.take(3)) - - # test repartitionByRange(numPartitions, *cols) - df4 = df1.repartitionByRange(3, "name", "age") - self.assertEqual(df4.rdd.getNumPartitions(), 3) - self.assertEqual(df4.rdd.first(), df2.rdd.first()) - self.assertEqual(df4.rdd.take(3), df2.rdd.take(3)) - - # test repartitionByRange(*cols) - df5 = df1.repartitionByRange("name", "age") - self.assertEqual(df5.rdd.first(), df2.rdd.first()) - self.assertEqual(df5.rdd.take(3), df2.rdd.take(3)) - - def test_replace(self): - schema = StructType([ - StructField("name", StringType(), True), - StructField("age", IntegerType(), True), - StructField("height", DoubleType(), True)]) - - # replace with int - row = self.spark.createDataFrame([(u'Alice', 10, 10.0)], schema).replace(10, 20).first() - self.assertEqual(row.age, 20) - self.assertEqual(row.height, 20.0) - - # replace with double - row = self.spark.createDataFrame( - [(u'Alice', 80, 80.0)], schema).replace(80.0, 82.1).first() - self.assertEqual(row.age, 82) - self.assertEqual(row.height, 82.1) - - # replace with string - row = self.spark.createDataFrame( - [(u'Alice', 10, 80.1)], schema).replace(u'Alice', u'Ann').first() - self.assertEqual(row.name, u"Ann") - self.assertEqual(row.age, 10) - - # replace with subset specified by a string of a column name w/ actual change - row = self.spark.createDataFrame( - [(u'Alice', 10, 80.1)], schema).replace(10, 20, subset='age').first() - self.assertEqual(row.age, 20) - - # replace with subset specified by a string of a column name w/o actual change - row = self.spark.createDataFrame( - [(u'Alice', 10, 80.1)], schema).replace(10, 20, subset='height').first() - self.assertEqual(row.age, 10) - - # replace with subset specified with one column replaced, another column not in subset - # stays unchanged. 
- row = self.spark.createDataFrame( - [(u'Alice', 10, 10.0)], schema).replace(10, 20, subset=['name', 'age']).first() - self.assertEqual(row.name, u'Alice') - self.assertEqual(row.age, 20) - self.assertEqual(row.height, 10.0) - - # replace with subset specified but no column will be replaced - row = self.spark.createDataFrame( - [(u'Alice', 10, None)], schema).replace(10, 20, subset=['name', 'height']).first() - self.assertEqual(row.name, u'Alice') - self.assertEqual(row.age, 10) - self.assertEqual(row.height, None) - - # replace with lists - row = self.spark.createDataFrame( - [(u'Alice', 10, 80.1)], schema).replace([u'Alice'], [u'Ann']).first() - self.assertTupleEqual(row, (u'Ann', 10, 80.1)) - - # replace with dict - row = self.spark.createDataFrame( - [(u'Alice', 10, 80.1)], schema).replace({10: 11}).first() - self.assertTupleEqual(row, (u'Alice', 11, 80.1)) - - # test backward compatibility with dummy value - dummy_value = 1 - row = self.spark.createDataFrame( - [(u'Alice', 10, 80.1)], schema).replace({'Alice': 'Bob'}, dummy_value).first() - self.assertTupleEqual(row, (u'Bob', 10, 80.1)) - - # test dict with mixed numerics - row = self.spark.createDataFrame( - [(u'Alice', 10, 80.1)], schema).replace({10: -10, 80.1: 90.5}).first() - self.assertTupleEqual(row, (u'Alice', -10, 90.5)) - - # replace with tuples - row = self.spark.createDataFrame( - [(u'Alice', 10, 80.1)], schema).replace((u'Alice', ), (u'Bob', )).first() - self.assertTupleEqual(row, (u'Bob', 10, 80.1)) - - # replace multiple columns - row = self.spark.createDataFrame( - [(u'Alice', 10, 80.0)], schema).replace((10, 80.0), (20, 90)).first() - self.assertTupleEqual(row, (u'Alice', 20, 90.0)) - - # test for mixed numerics - row = self.spark.createDataFrame( - [(u'Alice', 10, 80.0)], schema).replace((10, 80), (20, 90.5)).first() - self.assertTupleEqual(row, (u'Alice', 20, 90.5)) - - row = self.spark.createDataFrame( - [(u'Alice', 10, 80.0)], schema).replace({10: 20, 80: 90.5}).first() - self.assertTupleEqual(row, (u'Alice', 20, 90.5)) - - # replace with boolean - row = (self - .spark.createDataFrame([(u'Alice', 10, 80.0)], schema) - .selectExpr("name = 'Bob'", 'age <= 15') - .replace(False, True).first()) - self.assertTupleEqual(row, (True, True)) - - # replace string with None and then drop None rows - row = self.spark.createDataFrame( - [(u'Alice', 10, 80.0)], schema).replace(u'Alice', None).dropna() - self.assertEqual(row.count(), 0) - - # replace with number and None - row = self.spark.createDataFrame( - [(u'Alice', 10, 80.0)], schema).replace([10, 80], [20, None]).first() - self.assertTupleEqual(row, (u'Alice', 20, None)) - - # should fail if subset is not list, tuple or None - with self.assertRaises(ValueError): - self.spark.createDataFrame( - [(u'Alice', 10, 80.1)], schema).replace({10: 11}, subset=1).first() - - # should fail if to_replace and value have different length - with self.assertRaises(ValueError): - self.spark.createDataFrame( - [(u'Alice', 10, 80.1)], schema).replace(["Alice", "Bob"], ["Eve"]).first() - - # should fail if when received unexpected type - with self.assertRaises(ValueError): - from datetime import datetime - self.spark.createDataFrame( - [(u'Alice', 10, 80.1)], schema).replace(datetime.now(), datetime.now()).first() - - # should fail if provided mixed type replacements - with self.assertRaises(ValueError): - self.spark.createDataFrame( - [(u'Alice', 10, 80.1)], schema).replace(["Alice", 10], ["Eve", 20]).first() - - with self.assertRaises(ValueError): - self.spark.createDataFrame( - [(u'Alice', 
10, 80.1)], schema).replace({u"Alice": u"Bob", 10: 20}).first() - - with self.assertRaisesRegexp( - TypeError, - 'value argument is required when to_replace is not a dictionary.'): - self.spark.createDataFrame( - [(u'Alice', 10, 80.0)], schema).replace(["Alice", "Bob"]).first() - - def test_capture_analysis_exception(self): - self.assertRaises(AnalysisException, lambda: self.spark.sql("select abc")) - self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + b")) - - def test_capture_parse_exception(self): - self.assertRaises(ParseException, lambda: self.spark.sql("abc")) - - def test_capture_illegalargument_exception(self): - self.assertRaisesRegexp(IllegalArgumentException, "Setting negative mapred.reduce.tasks", - lambda: self.spark.sql("SET mapred.reduce.tasks=-1")) - df = self.spark.createDataFrame([(1, 2)], ["a", "b"]) - self.assertRaisesRegexp(IllegalArgumentException, "1024 is not in the permitted values", - lambda: df.select(sha2(df.a, 1024)).collect()) - try: - df.select(sha2(df.a, 1024)).collect() - except IllegalArgumentException as e: - self.assertRegexpMatches(e.desc, "1024 is not in the permitted values") - self.assertRegexpMatches(e.stackTrace, - "org.apache.spark.sql.functions") - - def test_with_column_with_existing_name(self): - keys = self.df.withColumn("key", self.df.key).select("key").collect() - self.assertEqual([r.key for r in keys], list(range(100))) - - # regression test for SPARK-10417 - def test_column_iterator(self): - - def foo(): - for x in self.df.key: - break - - self.assertRaises(TypeError, foo) - - # add test for SPARK-10577 (test broadcast join hint) - def test_functions_broadcast(self): - from pyspark.sql.functions import broadcast - - df1 = self.spark.createDataFrame([(1, "1"), (2, "2")], ("key", "value")) - df2 = self.spark.createDataFrame([(1, "1"), (2, "2")], ("key", "value")) - - # equijoin - should be converted into broadcast join - plan1 = df1.join(broadcast(df2), "key")._jdf.queryExecution().executedPlan() - self.assertEqual(1, plan1.toString().count("BroadcastHashJoin")) - - # no join key -- should not be a broadcast join - plan2 = df1.crossJoin(broadcast(df2))._jdf.queryExecution().executedPlan() - self.assertEqual(0, plan2.toString().count("BroadcastHashJoin")) - - # planner should not crash without a join - broadcast(df1)._jdf.queryExecution().executedPlan() - - def test_generic_hints(self): - from pyspark.sql import DataFrame - - df1 = self.spark.range(10e10).toDF("id") - df2 = self.spark.range(10e10).toDF("id") - - self.assertIsInstance(df1.hint("broadcast"), DataFrame) - self.assertIsInstance(df1.hint("broadcast", []), DataFrame) - - # Dummy rules - self.assertIsInstance(df1.hint("broadcast", "foo", "bar"), DataFrame) - self.assertIsInstance(df1.hint("broadcast", ["foo", "bar"]), DataFrame) - - plan = df1.join(df2.hint("broadcast"), "id")._jdf.queryExecution().executedPlan() - self.assertEqual(1, plan.toString().count("BroadcastHashJoin")) - - def test_sample(self): - self.assertRaisesRegexp( - TypeError, - "should be a bool, float and number", - lambda: self.spark.range(1).sample()) - - self.assertRaises( - TypeError, - lambda: self.spark.range(1).sample("a")) - - self.assertRaises( - TypeError, - lambda: self.spark.range(1).sample(seed="abc")) - - self.assertRaises( - IllegalArgumentException, - lambda: self.spark.range(1).sample(-1.0)) - - def test_toDF_with_schema_string(self): - data = [Row(key=i, value=str(i)) for i in range(100)] - rdd = self.sc.parallelize(data, 5) - - df = rdd.toDF("key: int, value: string") - 
self.assertEqual(df.schema.simpleString(), "struct<key:int,value:string>")
-        self.assertEqual(df.collect(), data)
-
-        # different but compatible field types can be used.
-        df = rdd.toDF("key: string, value: string")
-        self.assertEqual(df.schema.simpleString(), "struct<key:string,value:string>")
-        self.assertEqual(df.collect(), [Row(key=str(i), value=str(i)) for i in range(100)])
-
-        # field names can differ.
-        df = rdd.toDF(" a: int, b: string ")
-        self.assertEqual(df.schema.simpleString(), "struct<a:int,b:string>")
-        self.assertEqual(df.collect(), data)
-
-        # number of fields must match.
-        self.assertRaisesRegexp(Exception, "Length of object",
-                                lambda: rdd.toDF("key: int").collect())
-
-        # field types mismatch will cause exception at runtime.
-        self.assertRaisesRegexp(Exception, "FloatType can not accept",
-                                lambda: rdd.toDF("key: float, value: string").collect())
-
-        # flat schema values will be wrapped into row.
-        df = rdd.map(lambda row: row.key).toDF("int")
-        self.assertEqual(df.schema.simpleString(), "struct<value:int>")
-        self.assertEqual(df.collect(), [Row(key=i) for i in range(100)])
-
-        # users can use DataType directly instead of data type string.
-        df = rdd.map(lambda row: row.key).toDF(IntegerType())
-        self.assertEqual(df.schema.simpleString(), "struct<value:int>")
-        self.assertEqual(df.collect(), [Row(key=i) for i in range(100)])
-
-    def test_join_without_on(self):
-        df1 = self.spark.range(1).toDF("a")
-        df2 = self.spark.range(1).toDF("b")
-
-        with self.sql_conf({"spark.sql.crossJoin.enabled": False}):
-            self.assertRaises(AnalysisException, lambda: df1.join(df2, how="inner").collect())
-
-        with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
-            actual = df1.join(df2, how="inner").collect()
-            expected = [Row(a=0, b=0)]
-            self.assertEqual(actual, expected)
-
-    # Regression test for invalid join methods when on is None, Spark-14761
-    def test_invalid_join_method(self):
-        df1 = self.spark.createDataFrame([("Alice", 5), ("Bob", 8)], ["name", "age"])
-        df2 = self.spark.createDataFrame([("Alice", 80), ("Bob", 90)], ["name", "height"])
-        self.assertRaises(IllegalArgumentException, lambda: df1.join(df2, how="invalid-join-type"))
-
-    # Cartesian products require cross join syntax
-    def test_require_cross(self):
-        from pyspark.sql.functions import broadcast
-
-        df1 = self.spark.createDataFrame([(1, "1")], ("key", "value"))
-        df2 = self.spark.createDataFrame([(1, "1")], ("key", "value"))
-
-        # joins without conditions require cross join syntax
-        self.assertRaises(AnalysisException, lambda: df1.join(df2).collect())
-
-        # works with crossJoin
-        self.assertEqual(1, df1.crossJoin(df2).count())
-
-    def test_conf(self):
-        spark = self.spark
-        spark.conf.set("bogo", "sipeo")
-        self.assertEqual(spark.conf.get("bogo"), "sipeo")
-        spark.conf.set("bogo", "ta")
-        self.assertEqual(spark.conf.get("bogo"), "ta")
-        self.assertEqual(spark.conf.get("bogo", "not.read"), "ta")
-        self.assertEqual(spark.conf.get("not.set", "ta"), "ta")
-        self.assertRaisesRegexp(Exception, "not.set", lambda: spark.conf.get("not.set"))
-        spark.conf.unset("bogo")
-        self.assertEqual(spark.conf.get("bogo", "colombia"), "colombia")
-
-        self.assertEqual(spark.conf.get("hyukjin", None), None)
-
-        # This returns 'STATIC' because it's the default value of
-        # 'spark.sql.sources.partitionOverwriteMode', and `defaultValue` in
-        # `spark.conf.get` is unset.
-        self.assertEqual(spark.conf.get("spark.sql.sources.partitionOverwriteMode"), "STATIC")
-
-        # This returns None because 'spark.sql.sources.partitionOverwriteMode' is unset, but
-        # `defaultValue` in `spark.conf.get` is set to None.
- self.assertEqual(spark.conf.get("spark.sql.sources.partitionOverwriteMode", None), None) - - def test_current_database(self): - spark = self.spark - with self.database("some_db"): - self.assertEquals(spark.catalog.currentDatabase(), "default") - spark.sql("CREATE DATABASE some_db") - spark.catalog.setCurrentDatabase("some_db") - self.assertEquals(spark.catalog.currentDatabase(), "some_db") - self.assertRaisesRegexp( - AnalysisException, - "does_not_exist", - lambda: spark.catalog.setCurrentDatabase("does_not_exist")) - - def test_list_databases(self): - spark = self.spark - with self.database("some_db"): - databases = [db.name for db in spark.catalog.listDatabases()] - self.assertEquals(databases, ["default"]) - spark.sql("CREATE DATABASE some_db") - databases = [db.name for db in spark.catalog.listDatabases()] - self.assertEquals(sorted(databases), ["default", "some_db"]) - - def test_list_tables(self): - from pyspark.sql.catalog import Table - spark = self.spark - with self.database("some_db"): - spark.sql("CREATE DATABASE some_db") - with self.table("tab1", "some_db.tab2"): - with self.tempView("temp_tab"): - self.assertEquals(spark.catalog.listTables(), []) - self.assertEquals(spark.catalog.listTables("some_db"), []) - spark.createDataFrame([(1, 1)]).createOrReplaceTempView("temp_tab") - spark.sql("CREATE TABLE tab1 (name STRING, age INT) USING parquet") - spark.sql("CREATE TABLE some_db.tab2 (name STRING, age INT) USING parquet") - tables = sorted(spark.catalog.listTables(), key=lambda t: t.name) - tablesDefault = \ - sorted(spark.catalog.listTables("default"), key=lambda t: t.name) - tablesSomeDb = \ - sorted(spark.catalog.listTables("some_db"), key=lambda t: t.name) - self.assertEquals(tables, tablesDefault) - self.assertEquals(len(tables), 2) - self.assertEquals(len(tablesSomeDb), 2) - self.assertEquals(tables[0], Table( - name="tab1", - database="default", - description=None, - tableType="MANAGED", - isTemporary=False)) - self.assertEquals(tables[1], Table( - name="temp_tab", - database=None, - description=None, - tableType="TEMPORARY", - isTemporary=True)) - self.assertEquals(tablesSomeDb[0], Table( - name="tab2", - database="some_db", - description=None, - tableType="MANAGED", - isTemporary=False)) - self.assertEquals(tablesSomeDb[1], Table( - name="temp_tab", - database=None, - description=None, - tableType="TEMPORARY", - isTemporary=True)) - self.assertRaisesRegexp( - AnalysisException, - "does_not_exist", - lambda: spark.catalog.listTables("does_not_exist")) - - def test_list_functions(self): - from pyspark.sql.catalog import Function - spark = self.spark - with self.database("some_db"): - spark.sql("CREATE DATABASE some_db") - functions = dict((f.name, f) for f in spark.catalog.listFunctions()) - functionsDefault = dict((f.name, f) for f in spark.catalog.listFunctions("default")) - self.assertTrue(len(functions) > 200) - self.assertTrue("+" in functions) - self.assertTrue("like" in functions) - self.assertTrue("month" in functions) - self.assertTrue("to_date" in functions) - self.assertTrue("to_timestamp" in functions) - self.assertTrue("to_unix_timestamp" in functions) - self.assertTrue("current_database" in functions) - self.assertEquals(functions["+"], Function( - name="+", - description=None, - className="org.apache.spark.sql.catalyst.expressions.Add", - isTemporary=True)) - self.assertEquals(functions, functionsDefault) - - with self.function("func1", "some_db.func2"): - spark.catalog.registerFunction("temp_func", lambda x: str(x)) - spark.sql("CREATE FUNCTION func1 
AS 'org.apache.spark.data.bricks'") - spark.sql("CREATE FUNCTION some_db.func2 AS 'org.apache.spark.data.bricks'") - newFunctions = dict((f.name, f) for f in spark.catalog.listFunctions()) - newFunctionsSomeDb = \ - dict((f.name, f) for f in spark.catalog.listFunctions("some_db")) - self.assertTrue(set(functions).issubset(set(newFunctions))) - self.assertTrue(set(functions).issubset(set(newFunctionsSomeDb))) - self.assertTrue("temp_func" in newFunctions) - self.assertTrue("func1" in newFunctions) - self.assertTrue("func2" not in newFunctions) - self.assertTrue("temp_func" in newFunctionsSomeDb) - self.assertTrue("func1" not in newFunctionsSomeDb) - self.assertTrue("func2" in newFunctionsSomeDb) - self.assertRaisesRegexp( - AnalysisException, - "does_not_exist", - lambda: spark.catalog.listFunctions("does_not_exist")) - - def test_list_columns(self): - from pyspark.sql.catalog import Column - spark = self.spark - with self.database("some_db"): - spark.sql("CREATE DATABASE some_db") - with self.table("tab1", "some_db.tab2"): - spark.sql("CREATE TABLE tab1 (name STRING, age INT) USING parquet") - spark.sql( - "CREATE TABLE some_db.tab2 (nickname STRING, tolerance FLOAT) USING parquet") - columns = sorted(spark.catalog.listColumns("tab1"), key=lambda c: c.name) - columnsDefault = \ - sorted(spark.catalog.listColumns("tab1", "default"), key=lambda c: c.name) - self.assertEquals(columns, columnsDefault) - self.assertEquals(len(columns), 2) - self.assertEquals(columns[0], Column( - name="age", - description=None, - dataType="int", - nullable=True, - isPartition=False, - isBucket=False)) - self.assertEquals(columns[1], Column( - name="name", - description=None, - dataType="string", - nullable=True, - isPartition=False, - isBucket=False)) - columns2 = \ - sorted(spark.catalog.listColumns("tab2", "some_db"), key=lambda c: c.name) - self.assertEquals(len(columns2), 2) - self.assertEquals(columns2[0], Column( - name="nickname", - description=None, - dataType="string", - nullable=True, - isPartition=False, - isBucket=False)) - self.assertEquals(columns2[1], Column( - name="tolerance", - description=None, - dataType="float", - nullable=True, - isPartition=False, - isBucket=False)) - self.assertRaisesRegexp( - AnalysisException, - "tab2", - lambda: spark.catalog.listColumns("tab2")) - self.assertRaisesRegexp( - AnalysisException, - "does_not_exist", - lambda: spark.catalog.listColumns("does_not_exist")) - - def test_cache(self): - spark = self.spark - with self.tempView("tab1", "tab2"): - spark.createDataFrame([(2, 2), (3, 3)]).createOrReplaceTempView("tab1") - spark.createDataFrame([(2, 2), (3, 3)]).createOrReplaceTempView("tab2") - self.assertFalse(spark.catalog.isCached("tab1")) - self.assertFalse(spark.catalog.isCached("tab2")) - spark.catalog.cacheTable("tab1") - self.assertTrue(spark.catalog.isCached("tab1")) - self.assertFalse(spark.catalog.isCached("tab2")) - spark.catalog.cacheTable("tab2") - spark.catalog.uncacheTable("tab1") - self.assertFalse(spark.catalog.isCached("tab1")) - self.assertTrue(spark.catalog.isCached("tab2")) - spark.catalog.clearCache() - self.assertFalse(spark.catalog.isCached("tab1")) - self.assertFalse(spark.catalog.isCached("tab2")) - self.assertRaisesRegexp( - AnalysisException, - "does_not_exist", - lambda: spark.catalog.isCached("does_not_exist")) - self.assertRaisesRegexp( - AnalysisException, - "does_not_exist", - lambda: spark.catalog.cacheTable("does_not_exist")) - self.assertRaisesRegexp( - AnalysisException, - "does_not_exist", - lambda: 
spark.catalog.uncacheTable("does_not_exist")) - - def test_read_text_file_list(self): - df = self.spark.read.text(['python/test_support/sql/text-test.txt', - 'python/test_support/sql/text-test.txt']) - count = df.count() - self.assertEquals(count, 4) - - def test_BinaryType_serialization(self): - # Pyrolite version <= 4.9 could not serialize BinaryType with Python3 SPARK-17808 - # The empty bytearray is test for SPARK-21534. - schema = StructType([StructField('mybytes', BinaryType())]) - data = [[bytearray(b'here is my data')], - [bytearray(b'and here is some more')], - [bytearray(b'')]] - df = self.spark.createDataFrame(data, schema=schema) - df.collect() - - # test for SPARK-16542 - def test_array_types(self): - # This test need to make sure that the Scala type selected is at least - # as large as the python's types. This is necessary because python's - # array types depend on C implementation on the machine. Therefore there - # is no machine independent correspondence between python's array types - # and Scala types. - # See: https://docs.python.org/2/library/array.html - - def assertCollectSuccess(typecode, value): - row = Row(myarray=array.array(typecode, [value])) - df = self.spark.createDataFrame([row]) - self.assertEqual(df.first()["myarray"][0], value) - - # supported string types - # - # String types in python's array are "u" for Py_UNICODE and "c" for char. - # "u" will be removed in python 4, and "c" is not supported in python 3. - supported_string_types = [] - if sys.version_info[0] < 4: - supported_string_types += ['u'] - # test unicode - assertCollectSuccess('u', u'a') - if sys.version_info[0] < 3: - supported_string_types += ['c'] - # test string - assertCollectSuccess('c', 'a') - - # supported float and double - # - # Test max, min, and precision for float and double, assuming IEEE 754 - # floating-point format. - supported_fractional_types = ['f', 'd'] - assertCollectSuccess('f', ctypes.c_float(1e+38).value) - assertCollectSuccess('f', ctypes.c_float(1e-38).value) - assertCollectSuccess('f', ctypes.c_float(1.123456).value) - assertCollectSuccess('d', sys.float_info.max) - assertCollectSuccess('d', sys.float_info.min) - assertCollectSuccess('d', sys.float_info.epsilon) - - # supported signed int types - # - # The size of C types changes with implementation, we need to make sure - # that there is no overflow error on the platform running this test. - supported_signed_int_types = list( - set(_array_signed_int_typecode_ctype_mappings.keys()) - .intersection(set(_array_type_mappings.keys()))) - for t in supported_signed_int_types: - ctype = _array_signed_int_typecode_ctype_mappings[t] - max_val = 2 ** (ctypes.sizeof(ctype) * 8 - 1) - assertCollectSuccess(t, max_val - 1) - assertCollectSuccess(t, -max_val) - - # supported unsigned int types - # - # JVM does not have unsigned types. We need to be very careful to make - # sure that there is no overflow error. - supported_unsigned_int_types = list( - set(_array_unsigned_int_typecode_ctype_mappings.keys()) - .intersection(set(_array_type_mappings.keys()))) - for t in supported_unsigned_int_types: - ctype = _array_unsigned_int_typecode_ctype_mappings[t] - assertCollectSuccess(t, 2 ** (ctypes.sizeof(ctype) * 8) - 1) - - # all supported types - # - # Make sure the types tested above: - # 1. are all supported types - # 2. 
cover all supported types - supported_types = (supported_string_types + - supported_fractional_types + - supported_signed_int_types + - supported_unsigned_int_types) - self.assertEqual(set(supported_types), set(_array_type_mappings.keys())) - - # all unsupported types - # - # Keys in _array_type_mappings is a complete list of all supported types, - # and types not in _array_type_mappings are considered unsupported. - # `array.typecodes` are not supported in python 2. - if sys.version_info[0] < 3: - all_types = set(['c', 'b', 'B', 'u', 'h', 'H', 'i', 'I', 'l', 'L', 'f', 'd']) - else: - all_types = set(array.typecodes) - unsupported_types = all_types - set(supported_types) - # test unsupported types - for t in unsupported_types: - with self.assertRaises(TypeError): - a = array.array(t) - self.spark.createDataFrame([Row(myarray=a)]).collect() - - def test_bucketed_write(self): - data = [ - (1, "foo", 3.0), (2, "foo", 5.0), - (3, "bar", -1.0), (4, "bar", 6.0), - ] - df = self.spark.createDataFrame(data, ["x", "y", "z"]) - - def count_bucketed_cols(names, table="pyspark_bucket"): - """Given a sequence of column names and a table name - query the catalog and return number o columns which are - used for bucketing - """ - cols = self.spark.catalog.listColumns(table) - num = len([c for c in cols if c.name in names and c.isBucket]) - return num - - with self.table("pyspark_bucket"): - # Test write with one bucketing column - df.write.bucketBy(3, "x").mode("overwrite").saveAsTable("pyspark_bucket") - self.assertEqual(count_bucketed_cols(["x"]), 1) - self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect())) - - # Test write two bucketing columns - df.write.bucketBy(3, "x", "y").mode("overwrite").saveAsTable("pyspark_bucket") - self.assertEqual(count_bucketed_cols(["x", "y"]), 2) - self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect())) - - # Test write with bucket and sort - df.write.bucketBy(2, "x").sortBy("z").mode("overwrite").saveAsTable("pyspark_bucket") - self.assertEqual(count_bucketed_cols(["x"]), 1) - self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect())) - - # Test write with a list of columns - df.write.bucketBy(3, ["x", "y"]).mode("overwrite").saveAsTable("pyspark_bucket") - self.assertEqual(count_bucketed_cols(["x", "y"]), 2) - self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect())) - - # Test write with bucket and sort with a list of columns - (df.write.bucketBy(2, "x") - .sortBy(["y", "z"]) - .mode("overwrite").saveAsTable("pyspark_bucket")) - self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect())) - - # Test write with bucket and sort with multiple columns - (df.write.bucketBy(2, "x") - .sortBy("y", "z") - .mode("overwrite").saveAsTable("pyspark_bucket")) - self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect())) - - def _to_pandas(self): - from datetime import datetime, date - schema = StructType().add("a", IntegerType()).add("b", StringType())\ - .add("c", BooleanType()).add("d", FloatType())\ - .add("dt", DateType()).add("ts", TimestampType()) - data = [ - (1, "foo", True, 3.0, date(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)), - (2, "foo", True, 5.0, None, None), - (3, "bar", False, -1.0, date(2012, 3, 3), datetime(2012, 3, 3, 3, 3, 3)), - (4, "bar", False, 6.0, date(2100, 4, 4), datetime(2100, 4, 4, 4, 4, 4)), - ] - df = self.spark.createDataFrame(data, schema) - return df.toPandas() - - @unittest.skipIf(not _have_pandas, 
_pandas_requirement_message) - def test_to_pandas(self): - import numpy as np - pdf = self._to_pandas() - types = pdf.dtypes - self.assertEquals(types[0], np.int32) - self.assertEquals(types[1], np.object) - self.assertEquals(types[2], np.bool) - self.assertEquals(types[3], np.float32) - self.assertEquals(types[4], np.object) # datetime.date - self.assertEquals(types[5], 'datetime64[ns]') - - @unittest.skipIf(_have_pandas, "Required Pandas was found.") - def test_to_pandas_required_pandas_not_found(self): - with QuietTest(self.sc): - with self.assertRaisesRegexp(ImportError, 'Pandas >= .* must be installed'): - self._to_pandas() - - @unittest.skipIf(not _have_pandas, _pandas_requirement_message) - def test_to_pandas_avoid_astype(self): - import numpy as np - schema = StructType().add("a", IntegerType()).add("b", StringType())\ - .add("c", IntegerType()) - data = [(1, "foo", 16777220), (None, "bar", None)] - df = self.spark.createDataFrame(data, schema) - types = df.toPandas().dtypes - self.assertEquals(types[0], np.float64) # doesn't convert to np.int32 due to NaN value. - self.assertEquals(types[1], np.object) - self.assertEquals(types[2], np.float64) - - def test_create_dataframe_from_array_of_long(self): - import array - data = [Row(longarray=array.array('l', [-9223372036854775808, 0, 9223372036854775807]))] - df = self.spark.createDataFrame(data) - self.assertEqual(df.first(), Row(longarray=[-9223372036854775808, 0, 9223372036854775807])) - - @unittest.skipIf(not _have_pandas, _pandas_requirement_message) - def test_create_dataframe_from_pandas_with_timestamp(self): - import pandas as pd - from datetime import datetime - pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)], - "d": [pd.Timestamp.now().date()]})[["d", "ts"]] - # test types are inferred correctly without specifying schema - df = self.spark.createDataFrame(pdf) - self.assertTrue(isinstance(df.schema['ts'].dataType, TimestampType)) - self.assertTrue(isinstance(df.schema['d'].dataType, DateType)) - # test with schema will accept pdf as input - df = self.spark.createDataFrame(pdf, schema="d date, ts timestamp") - self.assertTrue(isinstance(df.schema['ts'].dataType, TimestampType)) - self.assertTrue(isinstance(df.schema['d'].dataType, DateType)) - - @unittest.skipIf(_have_pandas, "Required Pandas was found.") - def test_create_dataframe_required_pandas_not_found(self): - with QuietTest(self.sc): - with self.assertRaisesRegexp( - ImportError, - "(Pandas >= .* must be installed|No module named '?pandas'?)"): - import pandas as pd - from datetime import datetime - pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)], - "d": [pd.Timestamp.now().date()]}) - self.spark.createDataFrame(pdf) - - # Regression test for SPARK-23360 - @unittest.skipIf(not _have_pandas, _pandas_requirement_message) - def test_create_dateframe_from_pandas_with_dst(self): - import pandas as pd - from datetime import datetime - - pdf = pd.DataFrame({'time': [datetime(2015, 10, 31, 22, 30)]}) - - df = self.spark.createDataFrame(pdf) - self.assertPandasEqual(pdf, df.toPandas()) - - orig_env_tz = os.environ.get('TZ', None) - try: - tz = 'America/Los_Angeles' - os.environ['TZ'] = tz - time.tzset() - with self.sql_conf({'spark.sql.session.timeZone': tz}): - df = self.spark.createDataFrame(pdf) - self.assertPandasEqual(pdf, df.toPandas()) - finally: - del os.environ['TZ'] - if orig_env_tz is not None: - os.environ['TZ'] = orig_env_tz - time.tzset() - - def test_sort_with_nulls_order(self): - from pyspark.sql import functions - - df = 
self.spark.createDataFrame( - [('Tom', 80), (None, 60), ('Alice', 50)], ["name", "height"]) - self.assertEquals( - df.select(df.name).orderBy(functions.asc_nulls_first('name')).collect(), - [Row(name=None), Row(name=u'Alice'), Row(name=u'Tom')]) - self.assertEquals( - df.select(df.name).orderBy(functions.asc_nulls_last('name')).collect(), - [Row(name=u'Alice'), Row(name=u'Tom'), Row(name=None)]) - self.assertEquals( - df.select(df.name).orderBy(functions.desc_nulls_first('name')).collect(), - [Row(name=None), Row(name=u'Tom'), Row(name=u'Alice')]) - self.assertEquals( - df.select(df.name).orderBy(functions.desc_nulls_last('name')).collect(), - [Row(name=u'Tom'), Row(name=u'Alice'), Row(name=None)]) - - def test_json_sampling_ratio(self): - rdd = self.spark.sparkContext.range(0, 100, 1, 1) \ - .map(lambda x: '{"a":0.1}' if x == 1 else '{"a":%s}' % str(x)) - schema = self.spark.read.option('inferSchema', True) \ - .option('samplingRatio', 0.5) \ - .json(rdd).schema - self.assertEquals(schema, StructType([StructField("a", LongType(), True)])) - - def test_csv_sampling_ratio(self): - rdd = self.spark.sparkContext.range(0, 100, 1, 1) \ - .map(lambda x: '0.1' if x == 1 else str(x)) - schema = self.spark.read.option('inferSchema', True)\ - .csv(rdd, samplingRatio=0.5).schema - self.assertEquals(schema, StructType([StructField("_c0", IntegerType(), True)])) - - def test_checking_csv_header(self): - path = tempfile.mkdtemp() - shutil.rmtree(path) - try: - self.spark.createDataFrame([[1, 1000], [2000, 2]])\ - .toDF('f1', 'f2').write.option("header", "true").csv(path) - schema = StructType([ - StructField('f2', IntegerType(), nullable=True), - StructField('f1', IntegerType(), nullable=True)]) - df = self.spark.read.option('header', 'true').schema(schema)\ - .csv(path, enforceSchema=False) - self.assertRaisesRegexp( - Exception, - "CSV header does not conform to the schema", - lambda: df.collect()) - finally: - shutil.rmtree(path) - - def test_ignore_column_of_all_nulls(self): - path = tempfile.mkdtemp() - shutil.rmtree(path) - try: - df = self.spark.createDataFrame([["""{"a":null, "b":1, "c":3.0}"""], - ["""{"a":null, "b":null, "c":"string"}"""], - ["""{"a":null, "b":null, "c":null}"""]]) - df.write.text(path) - schema = StructType([ - StructField('b', LongType(), nullable=True), - StructField('c', StringType(), nullable=True)]) - readback = self.spark.read.json(path, dropFieldIfAllNull=True) - self.assertEquals(readback.schema, schema) - finally: - shutil.rmtree(path) - - # SPARK-24721 - @unittest.skipIf(not _test_compiled, _test_not_compiled_message) - def test_datasource_with_udf(self): - from pyspark.sql.functions import udf, lit, col - - path = tempfile.mkdtemp() - shutil.rmtree(path) - - try: - self.spark.range(1).write.mode("overwrite").format('csv').save(path) - filesource_df = self.spark.read.option('inferSchema', True).csv(path).toDF('i') - datasource_df = self.spark.read \ - .format("org.apache.spark.sql.sources.SimpleScanSource") \ - .option('from', 0).option('to', 1).load().toDF('i') - datasource_v2_df = self.spark.read \ - .format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2") \ - .load().toDF('i', 'j') - - c1 = udf(lambda x: x + 1, 'int')(lit(1)) - c2 = udf(lambda x: x + 1, 'int')(col('i')) - - f1 = udf(lambda x: False, 'boolean')(lit(1)) - f2 = udf(lambda x: False, 'boolean')(col('i')) - - for df in [filesource_df, datasource_df, datasource_v2_df]: - result = df.withColumn('c', c1) - expected = df.withColumn('c', lit(2)) - self.assertEquals(expected.collect(), result.collect()) 
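# Editor's aside: a minimal standalone sketch (not part of the original test suite) of the
# null-ordering column functions exercised by test_sort_with_nulls_order above. It builds its
# own local SparkSession; names such as `people` are illustrative only.
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.master("local[1]").appName("nulls-order-sketch").getOrCreate()
people = spark.createDataFrame([("Tom", 80), (None, 60), ("Alice", 50)], ["name", "height"])
# Nulls can be pushed to either end of the sort order explicitly:
people.orderBy(F.asc_nulls_first("name")).show()   # the None row comes first
people.orderBy(F.desc_nulls_last("name")).show()   # the None row comes last
spark.stop()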
- - for df in [filesource_df, datasource_df, datasource_v2_df]: - result = df.withColumn('c', c2) - expected = df.withColumn('c', col('i') + 1) - self.assertEquals(expected.collect(), result.collect()) - - for df in [filesource_df, datasource_df, datasource_v2_df]: - for f in [f1, f2]: - result = df.filter(f) - self.assertEquals(0, result.count()) - finally: - shutil.rmtree(path) - - def test_repr_behaviors(self): - import re - pattern = re.compile(r'^ *\|', re.MULTILINE) - df = self.spark.createDataFrame([(1, "1"), (22222, "22222")], ("key", "value")) - - # test when eager evaluation is enabled and _repr_html_ will not be called - with self.sql_conf({"spark.sql.repl.eagerEval.enabled": True}): - expected1 = """+-----+-----+ - || key|value| - |+-----+-----+ - || 1| 1| - ||22222|22222| - |+-----+-----+ - |""" - self.assertEquals(re.sub(pattern, '', expected1), df.__repr__()) - with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}): - expected2 = """+---+-----+ - ||key|value| - |+---+-----+ - || 1| 1| - ||222| 222| - |+---+-----+ - |""" - self.assertEquals(re.sub(pattern, '', expected2), df.__repr__()) - with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}): - expected3 = """+---+-----+ - ||key|value| - |+---+-----+ - || 1| 1| - |+---+-----+ - |only showing top 1 row - |""" - self.assertEquals(re.sub(pattern, '', expected3), df.__repr__()) - - # test when eager evaluation is enabled and _repr_html_ will be called - with self.sql_conf({"spark.sql.repl.eagerEval.enabled": True}): - expected1 = """ - | - | - | - |
    <tr><th>key</th><th>value</th></tr>
    <tr><td>1</td><td>1</td></tr>
    <tr><td>22222</td><td>22222</td></tr>
    - |""" - self.assertEquals(re.sub(pattern, '', expected1), df._repr_html_()) - with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}): - expected2 = """ - | - | - | - |
    <tr><th>key</th><th>value</th></tr>
    <tr><td>1</td><td>1</td></tr>
    <tr><td>222</td><td>222</td></tr>
    - |""" - self.assertEquals(re.sub(pattern, '', expected2), df._repr_html_()) - with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}): - expected3 = """ - | - | - |
    <tr><th>key</th><th>value</th></tr>
    <tr><td>1</td><td>1</td></tr>
    - |only showing top 1 row - |""" - self.assertEquals(re.sub(pattern, '', expected3), df._repr_html_()) - - # test when eager evaluation is disabled and _repr_html_ will be called - with self.sql_conf({"spark.sql.repl.eagerEval.enabled": False}): - expected = "DataFrame[key: bigint, value: string]" - self.assertEquals(None, df._repr_html_()) - self.assertEquals(expected, df.__repr__()) - with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}): - self.assertEquals(None, df._repr_html_()) - self.assertEquals(expected, df.__repr__()) - with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}): - self.assertEquals(None, df._repr_html_()) - self.assertEquals(expected, df.__repr__()) - - # SPARK-25591 - def test_same_accumulator_in_udfs(self): - from pyspark.sql.functions import udf - - data_schema = StructType([StructField("a", IntegerType(), True), - StructField("b", IntegerType(), True)]) - data = self.spark.createDataFrame([[1, 2]], schema=data_schema) - - test_accum = self.sc.accumulator(0) - - def first_udf(x): - test_accum.add(1) - return x - - def second_udf(x): - test_accum.add(100) - return x - - func_udf = udf(first_udf, IntegerType()) - func_udf2 = udf(second_udf, IntegerType()) - data = data.withColumn("out1", func_udf(data["a"])) - data = data.withColumn("out2", func_udf2(data["b"])) - data.collect() - self.assertEqual(test_accum.value, 101) - - -class HiveSparkSubmitTests(SparkSubmitTests): - - @classmethod - def setUpClass(cls): - # get a SparkContext to check for availability of Hive - sc = SparkContext('local[4]', cls.__name__) - cls.hive_available = True - try: - sc._jvm.org.apache.hadoop.hive.conf.HiveConf() - except py4j.protocol.Py4JError: - cls.hive_available = False - except TypeError: - cls.hive_available = False - finally: - # we don't need this SparkContext for the test - sc.stop() - - def setUp(self): - super(HiveSparkSubmitTests, self).setUp() - if not self.hive_available: - self.skipTest("Hive is not available.") - - def test_hivecontext(self): - # This test checks that HiveContext is using Hive metastore (SPARK-16224). - # It sets a metastore url and checks if there is a derby dir created by - # Hive metastore. If this derby dir exists, HiveContext is using - # Hive metastore. - metastore_path = os.path.join(tempfile.mkdtemp(), "spark16224_metastore_db") - metastore_URL = "jdbc:derby:;databaseName=" + metastore_path + ";create=true" - hive_site_dir = os.path.join(self.programDir, "conf") - hive_site_file = self.createTempFile("hive-site.xml", (""" - | - | - | javax.jdo.option.ConnectionURL - | %s - | - | - """ % metastore_URL).lstrip(), "conf") - script = self.createTempFile("test.py", """ - |import os - | - |from pyspark.conf import SparkConf - |from pyspark.context import SparkContext - |from pyspark.sql import HiveContext - | - |conf = SparkConf() - |sc = SparkContext(conf=conf) - |hive_context = HiveContext(sc) - |print(hive_context.sql("show databases").collect()) - """) - proc = subprocess.Popen( - self.sparkSubmit + ["--master", "local-cluster[1,1,1024]", - "--driver-class-path", hive_site_dir, script], - stdout=subprocess.PIPE) - out, err = proc.communicate() - self.assertEqual(0, proc.returncode) - self.assertIn("default", out.decode('utf-8')) - self.assertTrue(os.path.exists(metastore_path)) - - -class SQLTests2(ReusedSQLTestCase): - - # We can't include this test into SQLTests because we will stop class's SparkContext and cause - # other tests failed. 
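# Editor's aside: an illustrative sketch (not part of the original test suite) of the scenario
# the following SQLTests2 case covers: once the active SparkContext is stopped, a new context
# can be created and SparkSession.builder.getOrCreate() attaches a fresh session to it.
from pyspark import SparkContext
from pyspark.sql import SparkSession

sc = SparkContext("local[1]", "restart-sketch")
sc.stop()                                      # the old context is gone
sc = SparkContext("local[1]", "restart-sketch")
spark = SparkSession.builder.getOrCreate()     # builds a session on the new context
print(spark.range(3).count())                  # 3
spark.stop()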
- def test_sparksession_with_stopped_sparkcontext(self): - self.sc.stop() - sc = SparkContext('local[4]', self.sc.appName) - spark = SparkSession.builder.getOrCreate() - try: - df = spark.createDataFrame([(1, 2)], ["c", "c"]) - df.collect() - finally: - spark.stop() - sc.stop() - - -class QueryExecutionListenerTests(unittest.TestCase, SQLTestUtils): - # These tests are separate because it uses 'spark.sql.queryExecutionListeners' which is - # static and immutable. This can't be set or unset, for example, via `spark.conf`. - - @classmethod - def setUpClass(cls): - import glob - from pyspark.find_spark_home import _find_spark_home - - SPARK_HOME = _find_spark_home() - filename_pattern = ( - "sql/core/target/scala-*/test-classes/org/apache/spark/sql/" - "TestQueryExecutionListener.class") - cls.has_listener = bool(glob.glob(os.path.join(SPARK_HOME, filename_pattern))) - - if cls.has_listener: - # Note that 'spark.sql.queryExecutionListeners' is a static immutable configuration. - cls.spark = SparkSession.builder \ - .master("local[4]") \ - .appName(cls.__name__) \ - .config( - "spark.sql.queryExecutionListeners", - "org.apache.spark.sql.TestQueryExecutionListener") \ - .getOrCreate() - - def setUp(self): - if not self.has_listener: - raise self.skipTest( - "'org.apache.spark.sql.TestQueryExecutionListener' is not " - "available. Will skip the related tests.") - - @classmethod - def tearDownClass(cls): - if hasattr(cls, "spark"): - cls.spark.stop() - - def tearDown(self): - self.spark._jvm.OnSuccessCall.clear() - - def test_query_execution_listener_on_collect(self): - self.assertFalse( - self.spark._jvm.OnSuccessCall.isCalled(), - "The callback from the query execution listener should not be called before 'collect'") - self.spark.sql("SELECT * FROM range(1)").collect() - self.assertTrue( - self.spark._jvm.OnSuccessCall.isCalled(), - "The callback from the query execution listener should be called after 'collect'") - - @unittest.skipIf( - not _have_pandas or not _have_pyarrow, - _pandas_requirement_message or _pyarrow_requirement_message) - def test_query_execution_listener_on_collect_with_arrow(self): - with self.sql_conf({"spark.sql.execution.arrow.enabled": True}): - self.assertFalse( - self.spark._jvm.OnSuccessCall.isCalled(), - "The callback from the query execution listener should not be " - "called before 'toPandas'") - self.spark.sql("SELECT * FROM range(1)").toPandas() - self.assertTrue( - self.spark._jvm.OnSuccessCall.isCalled(), - "The callback from the query execution listener should be called after 'toPandas'") - - -class SparkExtensionsTest(unittest.TestCase): - # These tests are separate because it uses 'spark.sql.extensions' which is - # static and immutable. This can't be set or unset, for example, via `spark.conf`. - - @classmethod - def setUpClass(cls): - import glob - from pyspark.find_spark_home import _find_spark_home - - SPARK_HOME = _find_spark_home() - filename_pattern = ( - "sql/core/target/scala-*/test-classes/org/apache/spark/sql/" - "SparkSessionExtensionSuite.class") - if not glob.glob(os.path.join(SPARK_HOME, filename_pattern)): - raise unittest.SkipTest( - "'org.apache.spark.sql.SparkSessionExtensionSuite' is not " - "available. Will skip the related tests.") - - # Note that 'spark.sql.extensions' is a static immutable configuration. 
- cls.spark = SparkSession.builder \ - .master("local[4]") \ - .appName(cls.__name__) \ - .config( - "spark.sql.extensions", - "org.apache.spark.sql.MyExtensions") \ - .getOrCreate() - - @classmethod - def tearDownClass(cls): - cls.spark.stop() - - def test_use_custom_class_for_extensions(self): - self.assertTrue( - self.spark._jsparkSession.sessionState().planner().strategies().contains( - self.spark._jvm.org.apache.spark.sql.MySparkStrategy(self.spark._jsparkSession)), - "MySparkStrategy not found in active planner strategies") - self.assertTrue( - self.spark._jsparkSession.sessionState().analyzer().extendedResolutionRules().contains( - self.spark._jvm.org.apache.spark.sql.MyRule(self.spark._jsparkSession)), - "MyRule not found in extended resolution rules") - - -class SparkSessionTests(PySparkTestCase): - - # This test is separate because it's closely related with session's start and stop. - # See SPARK-23228. - def test_set_jvm_default_session(self): - spark = SparkSession.builder.getOrCreate() - try: - self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isDefined()) - finally: - spark.stop() - self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isEmpty()) - - def test_jvm_default_session_already_set(self): - # Here, we assume there is the default session already set in JVM. - jsession = self.sc._jvm.SparkSession(self.sc._jsc.sc()) - self.sc._jvm.SparkSession.setDefaultSession(jsession) - - spark = SparkSession.builder.getOrCreate() - try: - self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isDefined()) - # The session should be the same with the exiting one. - self.assertTrue(jsession.equals(spark._jvm.SparkSession.getDefaultSession().get())) - finally: - spark.stop() - - -class SparkSessionTests2(unittest.TestCase): - - def test_active_session(self): - spark = SparkSession.builder \ - .master("local") \ - .getOrCreate() - try: - activeSession = SparkSession.getActiveSession() - df = activeSession.createDataFrame([(1, 'Alice')], ['age', 'name']) - self.assertEqual(df.collect(), [Row(age=1, name=u'Alice')]) - finally: - spark.stop() - - def test_get_active_session_when_no_active_session(self): - active = SparkSession.getActiveSession() - self.assertEqual(active, None) - spark = SparkSession.builder \ - .master("local") \ - .getOrCreate() - active = SparkSession.getActiveSession() - self.assertEqual(active, spark) - spark.stop() - active = SparkSession.getActiveSession() - self.assertEqual(active, None) - - def test_SparkSession(self): - spark = SparkSession.builder \ - .master("local") \ - .config("some-config", "v2") \ - .getOrCreate() - try: - self.assertEqual(spark.conf.get("some-config"), "v2") - self.assertEqual(spark.sparkContext._conf.get("some-config"), "v2") - self.assertEqual(spark.version, spark.sparkContext.version) - spark.sql("CREATE DATABASE test_db") - spark.catalog.setCurrentDatabase("test_db") - self.assertEqual(spark.catalog.currentDatabase(), "test_db") - spark.sql("CREATE TABLE table1 (name STRING, age INT) USING parquet") - self.assertEqual(spark.table("table1").columns, ['name', 'age']) - self.assertEqual(spark.range(3).count(), 3) - finally: - spark.stop() - - def test_global_default_session(self): - spark = SparkSession.builder \ - .master("local") \ - .getOrCreate() - try: - self.assertEqual(SparkSession.builder.getOrCreate(), spark) - finally: - spark.stop() - - def test_default_and_active_session(self): - spark = SparkSession.builder \ - .master("local") \ - .getOrCreate() - activeSession = spark._jvm.SparkSession.getActiveSession() 
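# Editor's aside: a small sketch (not part of the original test suite) of the default/active
# session behaviour asserted in the surrounding tests: getOrCreate() registers the session as
# both the default and the active one, and getActiveSession() hands it back.
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()
assert SparkSession.getActiveSession() == spark      # active session is set
assert SparkSession.builder.getOrCreate() == spark   # default session is reused
spark.stop()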
- defaultSession = spark._jvm.SparkSession.getDefaultSession() - try: - self.assertEqual(activeSession, defaultSession) - finally: - spark.stop() - - def test_config_option_propagated_to_existing_session(self): - session1 = SparkSession.builder \ - .master("local") \ - .config("spark-config1", "a") \ - .getOrCreate() - self.assertEqual(session1.conf.get("spark-config1"), "a") - session2 = SparkSession.builder \ - .config("spark-config1", "b") \ - .getOrCreate() - try: - self.assertEqual(session1, session2) - self.assertEqual(session1.conf.get("spark-config1"), "b") - finally: - session1.stop() - - def test_new_session(self): - session = SparkSession.builder \ - .master("local") \ - .getOrCreate() - newSession = session.newSession() - try: - self.assertNotEqual(session, newSession) - finally: - session.stop() - newSession.stop() - - def test_create_new_session_if_old_session_stopped(self): - session = SparkSession.builder \ - .master("local") \ - .getOrCreate() - session.stop() - newSession = SparkSession.builder \ - .master("local") \ - .getOrCreate() - try: - self.assertNotEqual(session, newSession) - finally: - newSession.stop() - - def test_active_session_with_None_and_not_None_context(self): - from pyspark.context import SparkContext - from pyspark.conf import SparkConf - sc = None - session = None - try: - sc = SparkContext._active_spark_context - self.assertEqual(sc, None) - activeSession = SparkSession.getActiveSession() - self.assertEqual(activeSession, None) - sparkConf = SparkConf() - sc = SparkContext.getOrCreate(sparkConf) - activeSession = sc._jvm.SparkSession.getActiveSession() - self.assertFalse(activeSession.isDefined()) - session = SparkSession(sc) - activeSession = sc._jvm.SparkSession.getActiveSession() - self.assertTrue(activeSession.isDefined()) - activeSession2 = SparkSession.getActiveSession() - self.assertNotEqual(activeSession2, None) - finally: - if session is not None: - session.stop() - if sc is not None: - sc.stop() - - -class SparkSessionTests3(ReusedSQLTestCase): - - def test_get_active_session_after_create_dataframe(self): - session2 = None - try: - activeSession1 = SparkSession.getActiveSession() - session1 = self.spark - self.assertEqual(session1, activeSession1) - session2 = self.spark.newSession() - activeSession2 = SparkSession.getActiveSession() - self.assertEqual(session1, activeSession2) - self.assertNotEqual(session2, activeSession2) - session2.createDataFrame([(1, 'Alice')], ['age', 'name']) - activeSession3 = SparkSession.getActiveSession() - self.assertEqual(session2, activeSession3) - session1.createDataFrame([(1, 'Alice')], ['age', 'name']) - activeSession4 = SparkSession.getActiveSession() - self.assertEqual(session1, activeSession4) - finally: - if session2 is not None: - session2.stop() - - -class UDFInitializationTests(unittest.TestCase): - def tearDown(self): - if SparkSession._instantiatedSession is not None: - SparkSession._instantiatedSession.stop() - - if SparkContext._active_spark_context is not None: - SparkContext._active_spark_context.stop() - - def test_udf_init_shouldnt_initialize_context(self): - from pyspark.sql.functions import UserDefinedFunction - - UserDefinedFunction(lambda x: x, StringType()) - - self.assertIsNone( - SparkContext._active_spark_context, - "SparkContext shouldn't be initialized when UserDefinedFunction is created." - ) - self.assertIsNone( - SparkSession._instantiatedSession, - "SparkSession shouldn't be initialized when UserDefinedFunction is created." 
- ) - - -class HiveContextSQLTests(ReusedPySparkTestCase): - - @classmethod - def setUpClass(cls): - ReusedPySparkTestCase.setUpClass() - cls.tempdir = tempfile.NamedTemporaryFile(delete=False) - cls.hive_available = True - try: - cls.sc._jvm.org.apache.hadoop.hive.conf.HiveConf() - except py4j.protocol.Py4JError: - cls.hive_available = False - except TypeError: - cls.hive_available = False - os.unlink(cls.tempdir.name) - if cls.hive_available: - cls.spark = HiveContext._createForTesting(cls.sc) - cls.testData = [Row(key=i, value=str(i)) for i in range(100)] - cls.df = cls.sc.parallelize(cls.testData).toDF() - - def setUp(self): - if not self.hive_available: - self.skipTest("Hive is not available.") - - @classmethod - def tearDownClass(cls): - ReusedPySparkTestCase.tearDownClass() - shutil.rmtree(cls.tempdir.name, ignore_errors=True) - - def test_save_and_load_table(self): - df = self.df - tmpPath = tempfile.mkdtemp() - shutil.rmtree(tmpPath) - df.write.saveAsTable("savedJsonTable", "json", "append", path=tmpPath) - actual = self.spark.createExternalTable("externalJsonTable", tmpPath, "json") - self.assertEqual(sorted(df.collect()), - sorted(self.spark.sql("SELECT * FROM savedJsonTable").collect())) - self.assertEqual(sorted(df.collect()), - sorted(self.spark.sql("SELECT * FROM externalJsonTable").collect())) - self.assertEqual(sorted(df.collect()), sorted(actual.collect())) - self.spark.sql("DROP TABLE externalJsonTable") - - df.write.saveAsTable("savedJsonTable", "json", "overwrite", path=tmpPath) - schema = StructType([StructField("value", StringType(), True)]) - actual = self.spark.createExternalTable("externalJsonTable", source="json", - schema=schema, path=tmpPath, - noUse="this options will not be used") - self.assertEqual(sorted(df.collect()), - sorted(self.spark.sql("SELECT * FROM savedJsonTable").collect())) - self.assertEqual(sorted(df.select("value").collect()), - sorted(self.spark.sql("SELECT * FROM externalJsonTable").collect())) - self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect())) - self.spark.sql("DROP TABLE savedJsonTable") - self.spark.sql("DROP TABLE externalJsonTable") - - defaultDataSourceName = self.spark.getConf("spark.sql.sources.default", - "org.apache.spark.sql.parquet") - self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json") - df.write.saveAsTable("savedJsonTable", path=tmpPath, mode="overwrite") - actual = self.spark.createExternalTable("externalJsonTable", path=tmpPath) - self.assertEqual(sorted(df.collect()), - sorted(self.spark.sql("SELECT * FROM savedJsonTable").collect())) - self.assertEqual(sorted(df.collect()), - sorted(self.spark.sql("SELECT * FROM externalJsonTable").collect())) - self.assertEqual(sorted(df.collect()), sorted(actual.collect())) - self.spark.sql("DROP TABLE savedJsonTable") - self.spark.sql("DROP TABLE externalJsonTable") - self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName) - - shutil.rmtree(tmpPath) - - def test_window_functions(self): - df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) - w = Window.partitionBy("value").orderBy("key") - from pyspark.sql import functions as F - sel = df.select(df.value, df.key, - F.max("key").over(w.rowsBetween(0, 1)), - F.min("key").over(w.rowsBetween(0, 1)), - F.count("key").over(w.rowsBetween(float('-inf'), float('inf'))), - F.row_number().over(w), - F.rank().over(w), - F.dense_rank().over(w), - F.ntile(2).over(w)) - rs = sorted(sel.collect()) - expected = [ - ("1", 1, 1, 1, 1, 1, 1, 1, 
1), - ("2", 1, 1, 1, 3, 1, 1, 1, 1), - ("2", 1, 2, 1, 3, 2, 1, 1, 1), - ("2", 2, 2, 2, 3, 3, 3, 2, 2) - ] - for r, ex in zip(rs, expected): - self.assertEqual(tuple(r), ex[:len(r)]) - - def test_window_functions_without_partitionBy(self): - df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) - w = Window.orderBy("key", df.value) - from pyspark.sql import functions as F - sel = df.select(df.value, df.key, - F.max("key").over(w.rowsBetween(0, 1)), - F.min("key").over(w.rowsBetween(0, 1)), - F.count("key").over(w.rowsBetween(float('-inf'), float('inf'))), - F.row_number().over(w), - F.rank().over(w), - F.dense_rank().over(w), - F.ntile(2).over(w)) - rs = sorted(sel.collect()) - expected = [ - ("1", 1, 1, 1, 4, 1, 1, 1, 1), - ("2", 1, 1, 1, 4, 2, 2, 2, 1), - ("2", 1, 2, 1, 4, 3, 2, 2, 2), - ("2", 2, 2, 2, 4, 4, 4, 3, 2) - ] - for r, ex in zip(rs, expected): - self.assertEqual(tuple(r), ex[:len(r)]) - - def test_window_functions_cumulative_sum(self): - df = self.spark.createDataFrame([("one", 1), ("two", 2)], ["key", "value"]) - from pyspark.sql import functions as F - - # Test cumulative sum - sel = df.select( - df.key, - F.sum(df.value).over(Window.rowsBetween(Window.unboundedPreceding, 0))) - rs = sorted(sel.collect()) - expected = [("one", 1), ("two", 3)] - for r, ex in zip(rs, expected): - self.assertEqual(tuple(r), ex[:len(r)]) - - # Test boundary values less than JVM's Long.MinValue and make sure we don't overflow - sel = df.select( - df.key, - F.sum(df.value).over(Window.rowsBetween(Window.unboundedPreceding - 1, 0))) - rs = sorted(sel.collect()) - expected = [("one", 1), ("two", 3)] - for r, ex in zip(rs, expected): - self.assertEqual(tuple(r), ex[:len(r)]) - - # Test boundary values greater than JVM's Long.MaxValue and make sure we don't overflow - frame_end = Window.unboundedFollowing + 1 - sel = df.select( - df.key, - F.sum(df.value).over(Window.rowsBetween(Window.currentRow, frame_end))) - rs = sorted(sel.collect()) - expected = [("one", 3), ("two", 2)] - for r, ex in zip(rs, expected): - self.assertEqual(tuple(r), ex[:len(r)]) - - def test_collect_functions(self): - df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) - from pyspark.sql import functions - - self.assertEqual( - sorted(df.select(functions.collect_set(df.key).alias('r')).collect()[0].r), - [1, 2]) - self.assertEqual( - sorted(df.select(functions.collect_list(df.key).alias('r')).collect()[0].r), - [1, 1, 1, 2]) - self.assertEqual( - sorted(df.select(functions.collect_set(df.value).alias('r')).collect()[0].r), - ["1", "2"]) - self.assertEqual( - sorted(df.select(functions.collect_list(df.value).alias('r')).collect()[0].r), - ["1", "2", "2", "2"]) - - def test_limit_and_take(self): - df = self.spark.range(1, 1000, numPartitions=10) - - def assert_runs_only_one_job_stage_and_task(job_group_name, f): - tracker = self.sc.statusTracker() - self.sc.setJobGroup(job_group_name, description="") - f() - jobs = tracker.getJobIdsForGroup(job_group_name) - self.assertEqual(1, len(jobs)) - stages = tracker.getJobInfo(jobs[0]).stageIds - self.assertEqual(1, len(stages)) - self.assertEqual(1, tracker.getStageInfo(stages[0]).numTasks) - - # Regression test for SPARK-10731: take should delegate to Scala implementation - assert_runs_only_one_job_stage_and_task("take", lambda: df.take(1)) - # Regression test for SPARK-17514: limit(n).collect() should the perform same as take(n) - assert_runs_only_one_job_stage_and_task("collect_limit", lambda: 
df.limit(1).collect()) - - def test_datetime_functions(self): - from pyspark.sql import functions - from datetime import date, datetime - df = self.spark.range(1).selectExpr("'2017-01-22' as dateCol") - parse_result = df.select(functions.to_date(functions.col("dateCol"))).first() - self.assertEquals(date(2017, 1, 22), parse_result['to_date(`dateCol`)']) - - @unittest.skipIf(sys.version_info < (3, 3), "Unittest < 3.3 doesn't support mocking") - def test_unbounded_frames(self): - from unittest.mock import patch - from pyspark.sql import functions as F - from pyspark.sql import window - import importlib - - df = self.spark.range(0, 3) - - def rows_frame_match(): - return "ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" in df.select( - F.count("*").over(window.Window.rowsBetween(-sys.maxsize, sys.maxsize)) - ).columns[0] - - def range_frame_match(): - return "RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" in df.select( - F.count("*").over(window.Window.rangeBetween(-sys.maxsize, sys.maxsize)) - ).columns[0] - - with patch("sys.maxsize", 2 ** 31 - 1): - importlib.reload(window) - self.assertTrue(rows_frame_match()) - self.assertTrue(range_frame_match()) - - with patch("sys.maxsize", 2 ** 63 - 1): - importlib.reload(window) - self.assertTrue(rows_frame_match()) - self.assertTrue(range_frame_match()) - - with patch("sys.maxsize", 2 ** 127 - 1): - importlib.reload(window) - self.assertTrue(rows_frame_match()) - self.assertTrue(range_frame_match()) - - importlib.reload(window) - - -class DataTypeVerificationTests(unittest.TestCase): - - def test_verify_type_exception_msg(self): - self.assertRaisesRegexp( - ValueError, - "test_name", - lambda: _make_type_verifier(StringType(), nullable=False, name="test_name")(None)) - - schema = StructType([StructField('a', StructType([StructField('b', IntegerType())]))]) - self.assertRaisesRegexp( - TypeError, - "field b in field a", - lambda: _make_type_verifier(schema)([["data"]])) - - def test_verify_type_ok_nullable(self): - obj = None - types = [IntegerType(), FloatType(), StringType(), StructType([])] - for data_type in types: - try: - _make_type_verifier(data_type, nullable=True)(obj) - except Exception: - self.fail("verify_type(%s, %s, nullable=True)" % (obj, data_type)) - - def test_verify_type_not_nullable(self): - import array - import datetime - import decimal - - schema = StructType([ - StructField('s', StringType(), nullable=False), - StructField('i', IntegerType(), nullable=True)]) - - class MyObj: - def __init__(self, **kwargs): - for k, v in kwargs.items(): - setattr(self, k, v) - - # obj, data_type - success_spec = [ - # String - ("", StringType()), - (u"", StringType()), - (1, StringType()), - (1.0, StringType()), - ([], StringType()), - ({}, StringType()), - - # UDT - (ExamplePoint(1.0, 2.0), ExamplePointUDT()), - - # Boolean - (True, BooleanType()), - - # Byte - (-(2**7), ByteType()), - (2**7 - 1, ByteType()), - - # Short - (-(2**15), ShortType()), - (2**15 - 1, ShortType()), - - # Integer - (-(2**31), IntegerType()), - (2**31 - 1, IntegerType()), - - # Long - (2**64, LongType()), - - # Float & Double - (1.0, FloatType()), - (1.0, DoubleType()), - - # Decimal - (decimal.Decimal("1.0"), DecimalType()), - - # Binary - (bytearray([1, 2]), BinaryType()), - - # Date/Timestamp - (datetime.date(2000, 1, 2), DateType()), - (datetime.datetime(2000, 1, 2, 3, 4), DateType()), - (datetime.datetime(2000, 1, 2, 3, 4), TimestampType()), - - # Array - ([], ArrayType(IntegerType())), - (["1", None], ArrayType(StringType(), 
containsNull=True)), - ([1, 2], ArrayType(IntegerType())), - ((1, 2), ArrayType(IntegerType())), - (array.array('h', [1, 2]), ArrayType(IntegerType())), - - # Map - ({}, MapType(StringType(), IntegerType())), - ({"a": 1}, MapType(StringType(), IntegerType())), - ({"a": None}, MapType(StringType(), IntegerType(), valueContainsNull=True)), - - # Struct - ({"s": "a", "i": 1}, schema), - ({"s": "a", "i": None}, schema), - ({"s": "a"}, schema), - ({"s": "a", "f": 1.0}, schema), - (Row(s="a", i=1), schema), - (Row(s="a", i=None), schema), - (Row(s="a", i=1, f=1.0), schema), - (["a", 1], schema), - (["a", None], schema), - (("a", 1), schema), - (MyObj(s="a", i=1), schema), - (MyObj(s="a", i=None), schema), - (MyObj(s="a"), schema), - ] - - # obj, data_type, exception class - failure_spec = [ - # String (match anything but None) - (None, StringType(), ValueError), - - # UDT - (ExamplePoint(1.0, 2.0), PythonOnlyUDT(), ValueError), - - # Boolean - (1, BooleanType(), TypeError), - ("True", BooleanType(), TypeError), - ([1], BooleanType(), TypeError), - - # Byte - (-(2**7) - 1, ByteType(), ValueError), - (2**7, ByteType(), ValueError), - ("1", ByteType(), TypeError), - (1.0, ByteType(), TypeError), - - # Short - (-(2**15) - 1, ShortType(), ValueError), - (2**15, ShortType(), ValueError), - - # Integer - (-(2**31) - 1, IntegerType(), ValueError), - (2**31, IntegerType(), ValueError), - - # Float & Double - (1, FloatType(), TypeError), - (1, DoubleType(), TypeError), - - # Decimal - (1.0, DecimalType(), TypeError), - (1, DecimalType(), TypeError), - ("1.0", DecimalType(), TypeError), - - # Binary - (1, BinaryType(), TypeError), - - # Date/Timestamp - ("2000-01-02", DateType(), TypeError), - (946811040, TimestampType(), TypeError), - - # Array - (["1", None], ArrayType(StringType(), containsNull=False), ValueError), - ([1, "2"], ArrayType(IntegerType()), TypeError), - - # Map - ({"a": 1}, MapType(IntegerType(), IntegerType()), TypeError), - ({"a": "1"}, MapType(StringType(), IntegerType()), TypeError), - ({"a": None}, MapType(StringType(), IntegerType(), valueContainsNull=False), - ValueError), - - # Struct - ({"s": "a", "i": "1"}, schema, TypeError), - (Row(s="a"), schema, ValueError), # Row can't have missing field - (Row(s="a", i="1"), schema, TypeError), - (["a"], schema, ValueError), - (["a", "1"], schema, TypeError), - (MyObj(s="a", i="1"), schema, TypeError), - (MyObj(s=None, i="1"), schema, ValueError), - ] - - # Check success cases - for obj, data_type in success_spec: - try: - _make_type_verifier(data_type, nullable=False)(obj) - except Exception: - self.fail("verify_type(%s, %s, nullable=False)" % (obj, data_type)) - - # Check failure cases - for obj, data_type, exp in failure_spec: - msg = "verify_type(%s, %s, nullable=False) == %s" % (obj, data_type, exp) - with self.assertRaises(exp, msg=msg): - _make_type_verifier(data_type, nullable=False)(obj) - - -@unittest.skipIf( - not _have_pandas or not _have_pyarrow, - _pandas_requirement_message or _pyarrow_requirement_message) -class ArrowTests(ReusedSQLTestCase): - - @classmethod - def setUpClass(cls): - from datetime import date, datetime - from decimal import Decimal - from distutils.version import LooseVersion - import pyarrow as pa - super(ArrowTests, cls).setUpClass() - cls.warnings_lock = threading.Lock() - - # Synchronize default timezone between Python and Java - cls.tz_prev = os.environ.get("TZ", None) # save current tz if set - tz = "America/Los_Angeles" - os.environ["TZ"] = tz - time.tzset() - - 
cls.spark.conf.set("spark.sql.session.timeZone", tz) - cls.spark.conf.set("spark.sql.execution.arrow.enabled", "true") - # Disable fallback by default to easily detect the failures. - cls.spark.conf.set("spark.sql.execution.arrow.fallback.enabled", "false") - cls.schema = StructType([ - StructField("1_str_t", StringType(), True), - StructField("2_int_t", IntegerType(), True), - StructField("3_long_t", LongType(), True), - StructField("4_float_t", FloatType(), True), - StructField("5_double_t", DoubleType(), True), - StructField("6_decimal_t", DecimalType(38, 18), True), - StructField("7_date_t", DateType(), True), - StructField("8_timestamp_t", TimestampType(), True)]) - cls.data = [(u"a", 1, 10, 0.2, 2.0, Decimal("2.0"), - date(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)), - (u"b", 2, 20, 0.4, 4.0, Decimal("4.0"), - date(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)), - (u"c", 3, 30, 0.8, 6.0, Decimal("6.0"), - date(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))] - - # TODO: remove version check once minimum pyarrow version is 0.10.0 - if LooseVersion("0.10.0") <= LooseVersion(pa.__version__): - cls.schema.add(StructField("9_binary_t", BinaryType(), True)) - cls.data[0] = cls.data[0] + (bytearray(b"a"),) - cls.data[1] = cls.data[1] + (bytearray(b"bb"),) - cls.data[2] = cls.data[2] + (bytearray(b"ccc"),) - - @classmethod - def tearDownClass(cls): - del os.environ["TZ"] - if cls.tz_prev is not None: - os.environ["TZ"] = cls.tz_prev - time.tzset() - super(ArrowTests, cls).tearDownClass() - - def create_pandas_data_frame(self): - import pandas as pd - import numpy as np - data_dict = {} - for j, name in enumerate(self.schema.names): - data_dict[name] = [self.data[i][j] for i in range(len(self.data))] - # need to convert these to numpy types first - data_dict["2_int_t"] = np.int32(data_dict["2_int_t"]) - data_dict["4_float_t"] = np.float32(data_dict["4_float_t"]) - return pd.DataFrame(data=data_dict) - - def test_toPandas_fallback_enabled(self): - import pandas as pd - - with self.sql_conf({"spark.sql.execution.arrow.fallback.enabled": True}): - schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)]) - df = self.spark.createDataFrame([({u'a': 1},)], schema=schema) - with QuietTest(self.sc): - with self.warnings_lock: - with warnings.catch_warnings(record=True) as warns: - # we want the warnings to appear even if this test is run from a subclass - warnings.simplefilter("always") - pdf = df.toPandas() - # Catch and check the last UserWarning. 
- user_warns = [ - warn.message for warn in warns if isinstance(warn.message, UserWarning)] - self.assertTrue(len(user_warns) > 0) - self.assertTrue( - "Attempting non-optimization" in _exception_message(user_warns[-1])) - self.assertPandasEqual(pdf, pd.DataFrame({u'map': [{u'a': 1}]})) - - def test_toPandas_fallback_disabled(self): - from distutils.version import LooseVersion - import pyarrow as pa - - schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)]) - df = self.spark.createDataFrame([(None,)], schema=schema) - with QuietTest(self.sc): - with self.warnings_lock: - with self.assertRaisesRegexp(Exception, 'Unsupported type'): - df.toPandas() - - # TODO: remove BinaryType check once minimum pyarrow version is 0.10.0 - if LooseVersion(pa.__version__) < LooseVersion("0.10.0"): - schema = StructType([StructField("binary", BinaryType(), True)]) - df = self.spark.createDataFrame([(None,)], schema=schema) - with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, 'Unsupported type.*BinaryType'): - df.toPandas() - - def test_null_conversion(self): - df_null = self.spark.createDataFrame([tuple([None for _ in range(len(self.data[0]))])] + - self.data) - pdf = df_null.toPandas() - null_counts = pdf.isnull().sum().tolist() - self.assertTrue(all([c == 1 for c in null_counts])) - - def _toPandas_arrow_toggle(self, df): - with self.sql_conf({"spark.sql.execution.arrow.enabled": False}): - pdf = df.toPandas() - - pdf_arrow = df.toPandas() - - return pdf, pdf_arrow - - @unittest.skip("This test flakes depending on system timezone") - def test_toPandas_arrow_toggle(self): - df = self.spark.createDataFrame(self.data, schema=self.schema) - pdf, pdf_arrow = self._toPandas_arrow_toggle(df) - expected = self.create_pandas_data_frame() - self.assertPandasEqual(expected, pdf) - self.assertPandasEqual(expected, pdf_arrow) - - @unittest.skip("This test flakes depending on system timezone") - def test_toPandas_respect_session_timezone(self): - df = self.spark.createDataFrame(self.data, schema=self.schema) - - timezone = "America/New_York" - with self.sql_conf({ - "spark.sql.execution.pandas.respectSessionTimeZone": False, - "spark.sql.session.timeZone": timezone}): - pdf_la, pdf_arrow_la = self._toPandas_arrow_toggle(df) - self.assertPandasEqual(pdf_arrow_la, pdf_la) - - with self.sql_conf({ - "spark.sql.execution.pandas.respectSessionTimeZone": True, - "spark.sql.session.timeZone": timezone}): - pdf_ny, pdf_arrow_ny = self._toPandas_arrow_toggle(df) - self.assertPandasEqual(pdf_arrow_ny, pdf_ny) - - self.assertFalse(pdf_ny.equals(pdf_la)) - - from pyspark.sql.types import _check_series_convert_timestamps_local_tz - pdf_la_corrected = pdf_la.copy() - for field in self.schema: - if isinstance(field.dataType, TimestampType): - pdf_la_corrected[field.name] = _check_series_convert_timestamps_local_tz( - pdf_la_corrected[field.name], timezone) - self.assertPandasEqual(pdf_ny, pdf_la_corrected) - - @unittest.skip("This test flakes depending on system timezone") - def test_pandas_round_trip(self): - pdf = self.create_pandas_data_frame() - df = self.spark.createDataFrame(self.data, schema=self.schema) - pdf_arrow = df.toPandas() - self.assertPandasEqual(pdf_arrow, pdf) - - def test_filtered_frame(self): - df = self.spark.range(3).toDF("i") - pdf = df.filter("i < 0").toPandas() - self.assertEqual(len(pdf.columns), 1) - self.assertEqual(pdf.columns[0], "i") - self.assertTrue(pdf.empty) - - def _createDataFrame_toggle(self, pdf, schema=None): - with 
self.sql_conf({"spark.sql.execution.arrow.enabled": False}): - df_no_arrow = self.spark.createDataFrame(pdf, schema=schema) - - df_arrow = self.spark.createDataFrame(pdf, schema=schema) - - return df_no_arrow, df_arrow - - @unittest.skip("This test flakes depending on system timezone") - def test_createDataFrame_toggle(self): - pdf = self.create_pandas_data_frame() - df_no_arrow, df_arrow = self._createDataFrame_toggle(pdf, schema=self.schema) - self.assertEquals(df_no_arrow.collect(), df_arrow.collect()) - - @unittest.skip("This test flakes depending on system timezone") - def test_createDataFrame_respect_session_timezone(self): - from datetime import timedelta - pdf = self.create_pandas_data_frame() - timezone = "America/New_York" - with self.sql_conf({ - "spark.sql.execution.pandas.respectSessionTimeZone": False, - "spark.sql.session.timeZone": timezone}): - df_no_arrow_la, df_arrow_la = self._createDataFrame_toggle(pdf, schema=self.schema) - result_la = df_no_arrow_la.collect() - result_arrow_la = df_arrow_la.collect() - self.assertEqual(result_la, result_arrow_la) - - with self.sql_conf({ - "spark.sql.execution.pandas.respectSessionTimeZone": True, - "spark.sql.session.timeZone": timezone}): - df_no_arrow_ny, df_arrow_ny = self._createDataFrame_toggle(pdf, schema=self.schema) - result_ny = df_no_arrow_ny.collect() - result_arrow_ny = df_arrow_ny.collect() - self.assertEqual(result_ny, result_arrow_ny) - - self.assertNotEqual(result_ny, result_la) - - # Correct result_la by adjusting 3 hours difference between Los Angeles and New York - result_la_corrected = [Row(**{k: v - timedelta(hours=3) if k == '8_timestamp_t' else v - for k, v in row.asDict().items()}) - for row in result_la] - self.assertEqual(result_ny, result_la_corrected) - - def test_createDataFrame_with_schema(self): - pdf = self.create_pandas_data_frame() - df = self.spark.createDataFrame(pdf, schema=self.schema) - self.assertEquals(self.schema, df.schema) - pdf_arrow = df.toPandas() - self.assertPandasEqual(pdf_arrow, pdf) - - def test_createDataFrame_with_incorrect_schema(self): - pdf = self.create_pandas_data_frame() - fields = list(self.schema) - fields[0], fields[7] = fields[7], fields[0] # swap str with timestamp - wrong_schema = StructType(fields) - with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, ".*No cast.*string.*timestamp.*"): - self.spark.createDataFrame(pdf, schema=wrong_schema) - - def test_createDataFrame_with_names(self): - pdf = self.create_pandas_data_frame() - new_names = list(map(str, range(len(self.schema.fieldNames())))) - # Test that schema as a list of column names gets applied - df = self.spark.createDataFrame(pdf, schema=list(new_names)) - self.assertEquals(df.schema.fieldNames(), new_names) - # Test that schema as tuple of column names gets applied - df = self.spark.createDataFrame(pdf, schema=tuple(new_names)) - self.assertEquals(df.schema.fieldNames(), new_names) - - def test_createDataFrame_column_name_encoding(self): - import pandas as pd - pdf = pd.DataFrame({u'a': [1]}) - columns = self.spark.createDataFrame(pdf).columns - self.assertTrue(isinstance(columns[0], str)) - self.assertEquals(columns[0], 'a') - columns = self.spark.createDataFrame(pdf, [u'b']).columns - self.assertTrue(isinstance(columns[0], str)) - self.assertEquals(columns[0], 'b') - - def test_createDataFrame_with_single_data_type(self): - import pandas as pd - with QuietTest(self.sc): - with self.assertRaisesRegexp(ValueError, ".*IntegerType.*not supported.*"): - 
self.spark.createDataFrame(pd.DataFrame({"a": [1]}), schema="int") - - def test_createDataFrame_does_not_modify_input(self): - import pandas as pd - # Some series get converted for Spark to consume, this makes sure input is unchanged - pdf = self.create_pandas_data_frame() - # Use a nanosecond value to make sure it is not truncated - pdf.ix[0, '8_timestamp_t'] = pd.Timestamp(1) - # Integers with nulls will get NaNs filled with 0 and will be casted - pdf.ix[1, '2_int_t'] = None - pdf_copy = pdf.copy(deep=True) - self.spark.createDataFrame(pdf, schema=self.schema) - self.assertTrue(pdf.equals(pdf_copy)) - - def test_schema_conversion_roundtrip(self): - from pyspark.sql.types import from_arrow_schema, to_arrow_schema - arrow_schema = to_arrow_schema(self.schema) - schema_rt = from_arrow_schema(arrow_schema) - self.assertEquals(self.schema, schema_rt) - - def test_createDataFrame_with_array_type(self): - import pandas as pd - pdf = pd.DataFrame({"a": [[1, 2], [3, 4]], "b": [[u"x", u"y"], [u"y", u"z"]]}) - df, df_arrow = self._createDataFrame_toggle(pdf) - result = df.collect() - result_arrow = df_arrow.collect() - expected = [tuple(list(e) for e in rec) for rec in pdf.to_records(index=False)] - for r in range(len(expected)): - for e in range(len(expected[r])): - self.assertTrue(expected[r][e] == result_arrow[r][e] and - result[r][e] == result_arrow[r][e]) - - def test_toPandas_with_array_type(self): - expected = [([1, 2], [u"x", u"y"]), ([3, 4], [u"y", u"z"])] - array_schema = StructType([StructField("a", ArrayType(IntegerType())), - StructField("b", ArrayType(StringType()))]) - df = self.spark.createDataFrame(expected, schema=array_schema) - pdf, pdf_arrow = self._toPandas_arrow_toggle(df) - result = [tuple(list(e) for e in rec) for rec in pdf.to_records(index=False)] - result_arrow = [tuple(list(e) for e in rec) for rec in pdf_arrow.to_records(index=False)] - for r in range(len(expected)): - for e in range(len(expected[r])): - self.assertTrue(expected[r][e] == result_arrow[r][e] and - result[r][e] == result_arrow[r][e]) - - def test_createDataFrame_with_int_col_names(self): - import numpy as np - import pandas as pd - pdf = pd.DataFrame(np.random.rand(4, 2)) - df, df_arrow = self._createDataFrame_toggle(pdf) - pdf_col_names = [str(c) for c in pdf.columns] - self.assertEqual(pdf_col_names, df.columns) - self.assertEqual(pdf_col_names, df_arrow.columns) - - @unittest.skip("This test flakes depending on system timezone") - def test_createDataFrame_fallback_enabled(self): - import pandas as pd - - with QuietTest(self.sc): - with self.sql_conf({"spark.sql.execution.arrow.fallback.enabled": True}): - with warnings.catch_warnings(record=True) as warns: - # we want the warnings to appear even if this test is run from a subclass - warnings.simplefilter("always") - df = self.spark.createDataFrame( - pd.DataFrame([[{u'a': 1}]]), "a: map") - # Catch and check the last UserWarning. 
- user_warns = [ - warn.message for warn in warns if isinstance(warn.message, UserWarning)] - self.assertTrue(len(user_warns) > 0) - self.assertTrue( - "Attempting non-optimization" in _exception_message(user_warns[-1])) - self.assertEqual(df.collect(), [Row(a={u'a': 1})]) - - def test_createDataFrame_fallback_disabled(self): - from distutils.version import LooseVersion - import pandas as pd - import pyarrow as pa - - with QuietTest(self.sc): - with self.assertRaisesRegexp(TypeError, 'Unsupported type'): - self.spark.createDataFrame( - pd.DataFrame([[{u'a': 1}]]), "a: map") - - # TODO: remove BinaryType check once minimum pyarrow version is 0.10.0 - if LooseVersion(pa.__version__) < LooseVersion("0.10.0"): - with QuietTest(self.sc): - with self.assertRaisesRegexp(TypeError, 'Unsupported type.*BinaryType'): - self.spark.createDataFrame( - pd.DataFrame([[{'a': b'aaa'}]]), "a: binary") - - # Regression test for SPARK-23314 - @unittest.skip("This test flakes depending on system timezone") - def test_timestamp_dst(self): - import pandas as pd - # Daylight saving time for Los Angeles for 2015 is Sun, Nov 1 at 2:00 am - dt = [datetime.datetime(2015, 11, 1, 0, 30), - datetime.datetime(2015, 11, 1, 1, 30), - datetime.datetime(2015, 11, 1, 2, 30)] - pdf = pd.DataFrame({'time': dt}) - - df_from_python = self.spark.createDataFrame(dt, 'timestamp').toDF('time') - df_from_pandas = self.spark.createDataFrame(pdf) - - self.assertPandasEqual(pdf, df_from_python.toPandas()) - self.assertPandasEqual(pdf, df_from_pandas.toPandas()) - - -class EncryptionArrowTests(ArrowTests): - - @classmethod - def conf(cls): - return super(EncryptionArrowTests, cls).conf().set("spark.io.encryption.enabled", "true") - - -@unittest.skipIf( - not _have_pandas or not _have_pyarrow, - _pandas_requirement_message or _pyarrow_requirement_message) -class PandasUDFTests(ReusedSQLTestCase): - - def test_pandas_udf_basic(self): - from pyspark.rdd import PythonEvalType - from pyspark.sql.functions import pandas_udf, PandasUDFType - - udf = pandas_udf(lambda x: x, DoubleType()) - self.assertEqual(udf.returnType, DoubleType()) - self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) - - udf = pandas_udf(lambda x: x, DoubleType(), PandasUDFType.SCALAR) - self.assertEqual(udf.returnType, DoubleType()) - self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) - - udf = pandas_udf(lambda x: x, 'double', PandasUDFType.SCALAR) - self.assertEqual(udf.returnType, DoubleType()) - self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) - - udf = pandas_udf(lambda x: x, StructType([StructField("v", DoubleType())]), - PandasUDFType.GROUPED_MAP) - self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())])) - self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) - - udf = pandas_udf(lambda x: x, 'v double', PandasUDFType.GROUPED_MAP) - self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())])) - self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) - - udf = pandas_udf(lambda x: x, 'v double', - functionType=PandasUDFType.GROUPED_MAP) - self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())])) - self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) - - udf = pandas_udf(lambda x: x, returnType='v double', - functionType=PandasUDFType.GROUPED_MAP) - self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())])) - self.assertEqual(udf.evalType, 
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) - - def test_pandas_udf_decorator(self): - from pyspark.rdd import PythonEvalType - from pyspark.sql.functions import pandas_udf, PandasUDFType - from pyspark.sql.types import StructType, StructField, DoubleType - - @pandas_udf(DoubleType()) - def foo(x): - return x - self.assertEqual(foo.returnType, DoubleType()) - self.assertEqual(foo.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) - - @pandas_udf(returnType=DoubleType()) - def foo(x): - return x - self.assertEqual(foo.returnType, DoubleType()) - self.assertEqual(foo.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) - - schema = StructType([StructField("v", DoubleType())]) - - @pandas_udf(schema, PandasUDFType.GROUPED_MAP) - def foo(x): - return x - self.assertEqual(foo.returnType, schema) - self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) - - @pandas_udf('v double', PandasUDFType.GROUPED_MAP) - def foo(x): - return x - self.assertEqual(foo.returnType, schema) - self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) - - @pandas_udf(schema, functionType=PandasUDFType.GROUPED_MAP) - def foo(x): - return x - self.assertEqual(foo.returnType, schema) - self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) - - @pandas_udf(returnType='double', functionType=PandasUDFType.SCALAR) - def foo(x): - return x - self.assertEqual(foo.returnType, DoubleType()) - self.assertEqual(foo.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) - - @pandas_udf(returnType=schema, functionType=PandasUDFType.GROUPED_MAP) - def foo(x): - return x - self.assertEqual(foo.returnType, schema) - self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) - - def test_udf_wrong_arg(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType - - with QuietTest(self.sc): - with self.assertRaises(ParseException): - @pandas_udf('blah') - def foo(x): - return x - with self.assertRaisesRegexp(ValueError, 'Invalid returnType.*None'): - @pandas_udf(functionType=PandasUDFType.SCALAR) - def foo(x): - return x - with self.assertRaisesRegexp(ValueError, 'Invalid functionType'): - @pandas_udf('double', 100) - def foo(x): - return x - - with self.assertRaisesRegexp(ValueError, '0-arg pandas_udfs.*not.*supported'): - pandas_udf(lambda: 1, LongType(), PandasUDFType.SCALAR) - with self.assertRaisesRegexp(ValueError, '0-arg pandas_udfs.*not.*supported'): - @pandas_udf(LongType(), PandasUDFType.SCALAR) - def zero_with_type(): - return 1 - - with self.assertRaisesRegexp(TypeError, 'Invalid returnType'): - @pandas_udf(returnType=PandasUDFType.GROUPED_MAP) - def foo(df): - return df - with self.assertRaisesRegexp(TypeError, 'Invalid returnType'): - @pandas_udf(returnType='double', functionType=PandasUDFType.GROUPED_MAP) - def foo(df): - return df - with self.assertRaisesRegexp(ValueError, 'Invalid function'): - @pandas_udf(returnType='k int, v double', functionType=PandasUDFType.GROUPED_MAP) - def foo(k, v, w): - return k - - def test_stopiteration_in_udf(self): - from pyspark.sql.functions import udf, pandas_udf, PandasUDFType - from py4j.protocol import Py4JJavaError - - def foo(x): - raise StopIteration() - - def foofoo(x, y): - raise StopIteration() - - exc_message = "Caught StopIteration thrown from user's code; failing the task" - df = self.spark.range(0, 100) - - # plain udf (test for SPARK-23754) - self.assertRaisesRegexp( - Py4JJavaError, - exc_message, - df.withColumn('v', udf(foo)('id')).collect - ) - - # pandas scalar udf - self.assertRaisesRegexp( - 
Py4JJavaError, - exc_message, - df.withColumn( - 'v', pandas_udf(foo, 'double', PandasUDFType.SCALAR)('id') - ).collect - ) - - # pandas grouped map - self.assertRaisesRegexp( - Py4JJavaError, - exc_message, - df.groupBy('id').apply( - pandas_udf(foo, df.schema, PandasUDFType.GROUPED_MAP) - ).collect - ) - - self.assertRaisesRegexp( - Py4JJavaError, - exc_message, - df.groupBy('id').apply( - pandas_udf(foofoo, df.schema, PandasUDFType.GROUPED_MAP) - ).collect - ) - - # pandas grouped agg - self.assertRaisesRegexp( - Py4JJavaError, - exc_message, - df.groupBy('id').agg( - pandas_udf(foo, 'double', PandasUDFType.GROUPED_AGG)('id') - ).collect - ) - - -@unittest.skipIf( - not _have_pandas or not _have_pyarrow, - _pandas_requirement_message or _pyarrow_requirement_message) -class ScalarPandasUDFTests(ReusedSQLTestCase): - - @classmethod - def setUpClass(cls): - ReusedSQLTestCase.setUpClass() - - # Synchronize default timezone between Python and Java - cls.tz_prev = os.environ.get("TZ", None) # save current tz if set - tz = "America/Los_Angeles" - os.environ["TZ"] = tz - time.tzset() - - cls.sc.environment["TZ"] = tz - cls.spark.conf.set("spark.sql.session.timeZone", tz) - - @classmethod - def tearDownClass(cls): - del os.environ["TZ"] - if cls.tz_prev is not None: - os.environ["TZ"] = cls.tz_prev - time.tzset() - ReusedSQLTestCase.tearDownClass() - - @property - def nondeterministic_vectorized_udf(self): - from pyspark.sql.functions import pandas_udf - - @pandas_udf('double') - def random_udf(v): - import pandas as pd - import numpy as np - return pd.Series(np.random.random(len(v))) - random_udf = random_udf.asNondeterministic() - return random_udf - - def test_pandas_udf_tokenize(self): - from pyspark.sql.functions import pandas_udf - tokenize = pandas_udf(lambda s: s.apply(lambda str: str.split(' ')), - ArrayType(StringType())) - self.assertEqual(tokenize.returnType, ArrayType(StringType())) - df = self.spark.createDataFrame([("hi boo",), ("bye boo",)], ["vals"]) - result = df.select(tokenize("vals").alias("hi")) - self.assertEqual([Row(hi=[u'hi', u'boo']), Row(hi=[u'bye', u'boo'])], result.collect()) - - def test_pandas_udf_nested_arrays(self): - from pyspark.sql.functions import pandas_udf - tokenize = pandas_udf(lambda s: s.apply(lambda str: [str.split(' ')]), - ArrayType(ArrayType(StringType()))) - self.assertEqual(tokenize.returnType, ArrayType(ArrayType(StringType()))) - df = self.spark.createDataFrame([("hi boo",), ("bye boo",)], ["vals"]) - result = df.select(tokenize("vals").alias("hi")) - self.assertEqual([Row(hi=[[u'hi', u'boo']]), Row(hi=[[u'bye', u'boo']])], result.collect()) - - def test_vectorized_udf_basic(self): - from pyspark.sql.functions import pandas_udf, col, array - df = self.spark.range(10).select( - col('id').cast('string').alias('str'), - col('id').cast('int').alias('int'), - col('id').alias('long'), - col('id').cast('float').alias('float'), - col('id').cast('double').alias('double'), - col('id').cast('decimal').alias('decimal'), - col('id').cast('boolean').alias('bool'), - array(col('id')).alias('array_long')) - f = lambda x: x - str_f = pandas_udf(f, StringType()) - int_f = pandas_udf(f, IntegerType()) - long_f = pandas_udf(f, LongType()) - float_f = pandas_udf(f, FloatType()) - double_f = pandas_udf(f, DoubleType()) - decimal_f = pandas_udf(f, DecimalType()) - bool_f = pandas_udf(f, BooleanType()) - array_long_f = pandas_udf(f, ArrayType(LongType())) - res = df.select(str_f(col('str')), int_f(col('int')), - long_f(col('long')), float_f(col('float')), - 
double_f(col('double')), decimal_f('decimal'), - bool_f(col('bool')), array_long_f('array_long')) - self.assertEquals(df.collect(), res.collect()) - - def test_register_nondeterministic_vectorized_udf_basic(self): - from pyspark.sql.functions import pandas_udf - from pyspark.rdd import PythonEvalType - import random - random_pandas_udf = pandas_udf( - lambda x: random.randint(6, 6) + x, IntegerType()).asNondeterministic() - self.assertEqual(random_pandas_udf.deterministic, False) - self.assertEqual(random_pandas_udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) - nondeterministic_pandas_udf = self.spark.catalog.registerFunction( - "randomPandasUDF", random_pandas_udf) - self.assertEqual(nondeterministic_pandas_udf.deterministic, False) - self.assertEqual(nondeterministic_pandas_udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) - [row] = self.spark.sql("SELECT randomPandasUDF(1)").collect() - self.assertEqual(row[0], 7) - - def test_vectorized_udf_null_boolean(self): - from pyspark.sql.functions import pandas_udf, col - data = [(True,), (True,), (None,), (False,)] - schema = StructType().add("bool", BooleanType()) - df = self.spark.createDataFrame(data, schema) - bool_f = pandas_udf(lambda x: x, BooleanType()) - res = df.select(bool_f(col('bool'))) - self.assertEquals(df.collect(), res.collect()) - - def test_vectorized_udf_null_byte(self): - from pyspark.sql.functions import pandas_udf, col - data = [(None,), (2,), (3,), (4,)] - schema = StructType().add("byte", ByteType()) - df = self.spark.createDataFrame(data, schema) - byte_f = pandas_udf(lambda x: x, ByteType()) - res = df.select(byte_f(col('byte'))) - self.assertEquals(df.collect(), res.collect()) - - def test_vectorized_udf_null_short(self): - from pyspark.sql.functions import pandas_udf, col - data = [(None,), (2,), (3,), (4,)] - schema = StructType().add("short", ShortType()) - df = self.spark.createDataFrame(data, schema) - short_f = pandas_udf(lambda x: x, ShortType()) - res = df.select(short_f(col('short'))) - self.assertEquals(df.collect(), res.collect()) - - def test_vectorized_udf_null_int(self): - from pyspark.sql.functions import pandas_udf, col - data = [(None,), (2,), (3,), (4,)] - schema = StructType().add("int", IntegerType()) - df = self.spark.createDataFrame(data, schema) - int_f = pandas_udf(lambda x: x, IntegerType()) - res = df.select(int_f(col('int'))) - self.assertEquals(df.collect(), res.collect()) - - def test_vectorized_udf_null_long(self): - from pyspark.sql.functions import pandas_udf, col - data = [(None,), (2,), (3,), (4,)] - schema = StructType().add("long", LongType()) - df = self.spark.createDataFrame(data, schema) - long_f = pandas_udf(lambda x: x, LongType()) - res = df.select(long_f(col('long'))) - self.assertEquals(df.collect(), res.collect()) - - def test_vectorized_udf_null_float(self): - from pyspark.sql.functions import pandas_udf, col - data = [(3.0,), (5.0,), (-1.0,), (None,)] - schema = StructType().add("float", FloatType()) - df = self.spark.createDataFrame(data, schema) - float_f = pandas_udf(lambda x: x, FloatType()) - res = df.select(float_f(col('float'))) - self.assertEquals(df.collect(), res.collect()) - - def test_vectorized_udf_null_double(self): - from pyspark.sql.functions import pandas_udf, col - data = [(3.0,), (5.0,), (-1.0,), (None,)] - schema = StructType().add("double", DoubleType()) - df = self.spark.createDataFrame(data, schema) - double_f = pandas_udf(lambda x: x, DoubleType()) - res = df.select(double_f(col('double'))) - self.assertEquals(df.collect(), res.collect()) 
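# Editor's aside: a compact sketch (not part of the original test suite) of the scalar
# pandas_udf pattern these null-handling tests repeat: the UDF receives a pandas Series per
# batch and must return a Series of the same length. Assumes pandas and pyarrow are installed.
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf, col
from pyspark.sql.types import LongType

spark = SparkSession.builder.master("local[1]").getOrCreate()
plus_one = pandas_udf(lambda s: s + 1, LongType())   # s is a pandas.Series
df = spark.range(5).select(plus_one(col("id")).alias("id_plus_one"))
df.show()
spark.stop()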
- - def test_vectorized_udf_null_decimal(self): - from decimal import Decimal - from pyspark.sql.functions import pandas_udf, col - data = [(Decimal(3.0),), (Decimal(5.0),), (Decimal(-1.0),), (None,)] - schema = StructType().add("decimal", DecimalType(38, 18)) - df = self.spark.createDataFrame(data, schema) - decimal_f = pandas_udf(lambda x: x, DecimalType(38, 18)) - res = df.select(decimal_f(col('decimal'))) - self.assertEquals(df.collect(), res.collect()) - - def test_vectorized_udf_null_string(self): - from pyspark.sql.functions import pandas_udf, col - data = [("foo",), (None,), ("bar",), ("bar",)] - schema = StructType().add("str", StringType()) - df = self.spark.createDataFrame(data, schema) - str_f = pandas_udf(lambda x: x, StringType()) - res = df.select(str_f(col('str'))) - self.assertEquals(df.collect(), res.collect()) - - def test_vectorized_udf_string_in_udf(self): - from pyspark.sql.functions import pandas_udf, col - import pandas as pd - df = self.spark.range(10) - str_f = pandas_udf(lambda x: pd.Series(map(str, x)), StringType()) - actual = df.select(str_f(col('id'))) - expected = df.select(col('id').cast('string')) - self.assertEquals(expected.collect(), actual.collect()) - - def test_vectorized_udf_datatype_string(self): - from pyspark.sql.functions import pandas_udf, col - df = self.spark.range(10).select( - col('id').cast('string').alias('str'), - col('id').cast('int').alias('int'), - col('id').alias('long'), - col('id').cast('float').alias('float'), - col('id').cast('double').alias('double'), - col('id').cast('decimal').alias('decimal'), - col('id').cast('boolean').alias('bool')) - f = lambda x: x - str_f = pandas_udf(f, 'string') - int_f = pandas_udf(f, 'integer') - long_f = pandas_udf(f, 'long') - float_f = pandas_udf(f, 'float') - double_f = pandas_udf(f, 'double') - decimal_f = pandas_udf(f, 'decimal(38, 18)') - bool_f = pandas_udf(f, 'boolean') - res = df.select(str_f(col('str')), int_f(col('int')), - long_f(col('long')), float_f(col('float')), - double_f(col('double')), decimal_f('decimal'), - bool_f(col('bool'))) - self.assertEquals(df.collect(), res.collect()) - - def test_vectorized_udf_null_binary(self): - from distutils.version import LooseVersion - import pyarrow as pa - from pyspark.sql.functions import pandas_udf, col - if LooseVersion(pa.__version__) < LooseVersion("0.10.0"): - with QuietTest(self.sc): - with self.assertRaisesRegexp( - NotImplementedError, - 'Invalid returnType.*scalar Pandas UDF.*BinaryType'): - pandas_udf(lambda x: x, BinaryType()) - else: - data = [(bytearray(b"a"),), (None,), (bytearray(b"bb"),), (bytearray(b"ccc"),)] - schema = StructType().add("binary", BinaryType()) - df = self.spark.createDataFrame(data, schema) - str_f = pandas_udf(lambda x: x, BinaryType()) - res = df.select(str_f(col('binary'))) - self.assertEquals(df.collect(), res.collect()) - - def test_vectorized_udf_array_type(self): - from pyspark.sql.functions import pandas_udf, col - data = [([1, 2],), ([3, 4],)] - array_schema = StructType([StructField("array", ArrayType(IntegerType()))]) - df = self.spark.createDataFrame(data, schema=array_schema) - array_f = pandas_udf(lambda x: x, ArrayType(IntegerType())) - result = df.select(array_f(col('array'))) - self.assertEquals(df.collect(), result.collect()) - - def test_vectorized_udf_null_array(self): - from pyspark.sql.functions import pandas_udf, col - data = [([1, 2],), (None,), (None,), ([3, 4],), (None,)] - array_schema = StructType([StructField("array", ArrayType(IntegerType()))]) - df = 
self.spark.createDataFrame(data, schema=array_schema) - array_f = pandas_udf(lambda x: x, ArrayType(IntegerType())) - result = df.select(array_f(col('array'))) - self.assertEquals(df.collect(), result.collect()) - - def test_vectorized_udf_complex(self): - from pyspark.sql.functions import pandas_udf, col, expr - df = self.spark.range(10).select( - col('id').cast('int').alias('a'), - col('id').cast('int').alias('b'), - col('id').cast('double').alias('c')) - add = pandas_udf(lambda x, y: x + y, IntegerType()) - power2 = pandas_udf(lambda x: 2 ** x, IntegerType()) - mul = pandas_udf(lambda x, y: x * y, DoubleType()) - res = df.select(add(col('a'), col('b')), power2(col('a')), mul(col('b'), col('c'))) - expected = df.select(expr('a + b'), expr('power(2, a)'), expr('b * c')) - self.assertEquals(expected.collect(), res.collect()) - - def test_vectorized_udf_exception(self): - from pyspark.sql.functions import pandas_udf, col - df = self.spark.range(10) - raise_exception = pandas_udf(lambda x: x * (1 / 0), LongType()) - with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, 'division( or modulo)? by zero'): - df.select(raise_exception(col('id'))).collect() - - def test_vectorized_udf_invalid_length(self): - from pyspark.sql.functions import pandas_udf, col - import pandas as pd - df = self.spark.range(10) - raise_exception = pandas_udf(lambda _: pd.Series(1), LongType()) - with QuietTest(self.sc): - with self.assertRaisesRegexp( - Exception, - 'Result vector from pandas_udf was not the required length'): - df.select(raise_exception(col('id'))).collect() - - def test_vectorized_udf_chained(self): - from pyspark.sql.functions import pandas_udf, col - df = self.spark.range(10) - f = pandas_udf(lambda x: x + 1, LongType()) - g = pandas_udf(lambda x: x - 1, LongType()) - res = df.select(g(f(col('id')))) - self.assertEquals(df.collect(), res.collect()) - - def test_vectorized_udf_wrong_return_type(self): - from pyspark.sql.functions import pandas_udf, col - with QuietTest(self.sc): - with self.assertRaisesRegexp( - NotImplementedError, - 'Invalid returnType.*scalar Pandas UDF.*MapType'): - pandas_udf(lambda x: x * 1.0, MapType(LongType(), LongType())) - - def test_vectorized_udf_return_scalar(self): - from pyspark.sql.functions import pandas_udf, col - df = self.spark.range(10) - f = pandas_udf(lambda x: 1.0, DoubleType()) - with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, 'Return.*type.*Series'): - df.select(f(col('id'))).collect() - - def test_vectorized_udf_decorator(self): - from pyspark.sql.functions import pandas_udf, col - df = self.spark.range(10) - - @pandas_udf(returnType=LongType()) - def identity(x): - return x - res = df.select(identity(col('id'))) - self.assertEquals(df.collect(), res.collect()) - - def test_vectorized_udf_empty_partition(self): - from pyspark.sql.functions import pandas_udf, col - df = self.spark.createDataFrame(self.sc.parallelize([Row(id=1)], 2)) - f = pandas_udf(lambda x: x, LongType()) - res = df.select(f(col('id'))) - self.assertEquals(df.collect(), res.collect()) - - def test_vectorized_udf_varargs(self): - from pyspark.sql.functions import pandas_udf, col - df = self.spark.createDataFrame(self.sc.parallelize([Row(id=1)], 2)) - f = pandas_udf(lambda *v: v[0], LongType()) - res = df.select(f(col('id'))) - self.assertEquals(df.collect(), res.collect()) - - def test_vectorized_udf_unsupported_types(self): - from pyspark.sql.functions import pandas_udf - with QuietTest(self.sc): - with self.assertRaisesRegexp( - NotImplementedError, - 
'Invalid returnType.*scalar Pandas UDF.*MapType'): - pandas_udf(lambda x: x, MapType(StringType(), IntegerType())) - - def test_vectorized_udf_dates(self): - from pyspark.sql.functions import pandas_udf, col - from datetime import date - schema = StructType().add("idx", LongType()).add("date", DateType()) - data = [(0, date(1969, 1, 1),), - (1, date(2012, 2, 2),), - (2, None,), - (3, date(2100, 4, 4),)] - df = self.spark.createDataFrame(data, schema=schema) - - date_copy = pandas_udf(lambda t: t, returnType=DateType()) - df = df.withColumn("date_copy", date_copy(col("date"))) - - @pandas_udf(returnType=StringType()) - def check_data(idx, date, date_copy): - import pandas as pd - msgs = [] - is_equal = date.isnull() - for i in range(len(idx)): - if (is_equal[i] and data[idx[i]][1] is None) or \ - date[i] == data[idx[i]][1]: - msgs.append(None) - else: - msgs.append( - "date values are not equal (date='%s': data[%d][1]='%s')" - % (date[i], idx[i], data[idx[i]][1])) - return pd.Series(msgs) - - result = df.withColumn("check_data", - check_data(col("idx"), col("date"), col("date_copy"))).collect() - - self.assertEquals(len(data), len(result)) - for i in range(len(result)): - self.assertEquals(data[i][1], result[i][1]) # "date" col - self.assertEquals(data[i][1], result[i][2]) # "date_copy" col - self.assertIsNone(result[i][3]) # "check_data" col - - @unittest.skip("This test flakes depending on system timezone") - def test_vectorized_udf_timestamps(self): - from pyspark.sql.functions import pandas_udf, col - from datetime import datetime - schema = StructType([ - StructField("idx", LongType(), True), - StructField("timestamp", TimestampType(), True)]) - data = [(0, datetime(1969, 1, 1, 1, 1, 1)), - (1, datetime(2012, 2, 2, 2, 2, 2)), - (2, None), - (3, datetime(2100, 3, 3, 3, 3, 3))] - - df = self.spark.createDataFrame(data, schema=schema) - - # Check that a timestamp passed through a pandas_udf will not be altered by timezone calc - f_timestamp_copy = pandas_udf(lambda t: t, returnType=TimestampType()) - df = df.withColumn("timestamp_copy", f_timestamp_copy(col("timestamp"))) - - @pandas_udf(returnType=StringType()) - def check_data(idx, timestamp, timestamp_copy): - import pandas as pd - msgs = [] - is_equal = timestamp.isnull() # use this array to check values are equal - for i in range(len(idx)): - # Check that timestamps are as expected in the UDF - if (is_equal[i] and data[idx[i]][1] is None) or \ - timestamp[i].to_pydatetime() == data[idx[i]][1]: - msgs.append(None) - else: - msgs.append( - "timestamp values are not equal (timestamp='%s': data[%d][1]='%s')" - % (timestamp[i], idx[i], data[idx[i]][1])) - return pd.Series(msgs) - - result = df.withColumn("check_data", check_data(col("idx"), col("timestamp"), - col("timestamp_copy"))).collect() - # Check that collection values are correct - self.assertEquals(len(data), len(result)) - for i in range(len(result)): - self.assertEquals(data[i][1], result[i][1]) # "timestamp" col - self.assertEquals(data[i][1], result[i][2]) # "timestamp_copy" col - self.assertIsNone(result[i][3]) # "check_data" col - - def test_vectorized_udf_return_timestamp_tz(self): - from pyspark.sql.functions import pandas_udf, col - import pandas as pd - df = self.spark.range(10) - - @pandas_udf(returnType=TimestampType()) - def gen_timestamps(id): - ts = [pd.Timestamp(i, unit='D', tz='America/Los_Angeles') for i in id] - return pd.Series(ts) - - result = df.withColumn("ts", gen_timestamps(col("id"))).collect() - spark_ts_t = TimestampType() - for r in result: - i, ts = 
r - ts_tz = pd.Timestamp(i, unit='D', tz='America/Los_Angeles').to_pydatetime() - expected = spark_ts_t.fromInternal(spark_ts_t.toInternal(ts_tz)) - self.assertEquals(expected, ts) - - def test_vectorized_udf_check_config(self): - from pyspark.sql.functions import pandas_udf, col - import pandas as pd - with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 3}): - df = self.spark.range(10, numPartitions=1) - - @pandas_udf(returnType=LongType()) - def check_records_per_batch(x): - return pd.Series(x.size).repeat(x.size) - - result = df.select(check_records_per_batch(col("id"))).collect() - for (r,) in result: - self.assertTrue(r <= 3) - - def test_vectorized_udf_timestamps_respect_session_timezone(self): - from pyspark.sql.functions import pandas_udf, col - from datetime import datetime - import pandas as pd - schema = StructType([ - StructField("idx", LongType(), True), - StructField("timestamp", TimestampType(), True)]) - data = [(1, datetime(1969, 1, 1, 1, 1, 1)), - (2, datetime(2012, 2, 2, 2, 2, 2)), - (3, None), - (4, datetime(2100, 3, 3, 3, 3, 3))] - df = self.spark.createDataFrame(data, schema=schema) - - f_timestamp_copy = pandas_udf(lambda ts: ts, TimestampType()) - internal_value = pandas_udf( - lambda ts: ts.apply(lambda ts: ts.value if ts is not pd.NaT else None), LongType()) - - timezone = "America/New_York" - with self.sql_conf({ - "spark.sql.execution.pandas.respectSessionTimeZone": False, - "spark.sql.session.timeZone": timezone}): - df_la = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \ - .withColumn("internal_value", internal_value(col("timestamp"))) - result_la = df_la.select(col("idx"), col("internal_value")).collect() - # Correct result_la by adjusting 3 hours difference between Los Angeles and New York - diff = 3 * 60 * 60 * 1000 * 1000 * 1000 - result_la_corrected = \ - df_la.select(col("idx"), col("tscopy"), col("internal_value") + diff).collect() - - with self.sql_conf({ - "spark.sql.execution.pandas.respectSessionTimeZone": True, - "spark.sql.session.timeZone": timezone}): - df_ny = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \ - .withColumn("internal_value", internal_value(col("timestamp"))) - result_ny = df_ny.select(col("idx"), col("tscopy"), col("internal_value")).collect() - - self.assertNotEqual(result_ny, result_la) - self.assertEqual(result_ny, result_la_corrected) - - def test_nondeterministic_vectorized_udf(self): - # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations - from pyspark.sql.functions import udf, pandas_udf, col - - @pandas_udf('double') - def plus_ten(v): - return v + 10 - random_udf = self.nondeterministic_vectorized_udf - - df = self.spark.range(10).withColumn('rand', random_udf(col('id'))) - result1 = df.withColumn('plus_ten(rand)', plus_ten(df['rand'])).toPandas() - - self.assertEqual(random_udf.deterministic, False) - self.assertTrue(result1['plus_ten(rand)'].equals(result1['rand'] + 10)) - - def test_nondeterministic_vectorized_udf_in_aggregate(self): - from pyspark.sql.functions import pandas_udf, sum - - df = self.spark.range(10) - random_udf = self.nondeterministic_vectorized_udf - - with QuietTest(self.sc): - with self.assertRaisesRegexp(AnalysisException, 'nondeterministic'): - df.groupby(df.id).agg(sum(random_udf(df.id))).collect() - with self.assertRaisesRegexp(AnalysisException, 'nondeterministic'): - df.agg(sum(random_udf(df.id))).collect() - - def test_register_vectorized_udf_basic(self): - from pyspark.rdd import PythonEvalType - from 
pyspark.sql.functions import pandas_udf, col, expr - df = self.spark.range(10).select( - col('id').cast('int').alias('a'), - col('id').cast('int').alias('b')) - original_add = pandas_udf(lambda x, y: x + y, IntegerType()) - self.assertEqual(original_add.deterministic, True) - self.assertEqual(original_add.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) - new_add = self.spark.catalog.registerFunction("add1", original_add) - res1 = df.select(new_add(col('a'), col('b'))) - res2 = self.spark.sql( - "SELECT add1(t.a, t.b) FROM (SELECT id as a, id as b FROM range(10)) t") - expected = df.select(expr('a + b')) - self.assertEquals(expected.collect(), res1.collect()) - self.assertEquals(expected.collect(), res2.collect()) - - # Regression test for SPARK-23314 - def test_timestamp_dst(self): - from pyspark.sql.functions import pandas_udf - # Daylight saving time for Los Angeles for 2015 is Sun, Nov 1 at 2:00 am - dt = [datetime.datetime(2015, 11, 1, 0, 30), - datetime.datetime(2015, 11, 1, 1, 30), - datetime.datetime(2015, 11, 1, 2, 30)] - df = self.spark.createDataFrame(dt, 'timestamp').toDF('time') - foo_udf = pandas_udf(lambda x: x, 'timestamp') - result = df.withColumn('time', foo_udf(df.time)) - self.assertEquals(df.collect(), result.collect()) - - @unittest.skipIf(sys.version_info[:2] < (3, 5), "Type hints are supported from Python 3.5.") - def test_type_annotation(self): - from pyspark.sql.functions import pandas_udf - # Regression test to check if type hints can be used. See SPARK-23569. - # Note that it throws an error during compilation in lower Python versions if 'exec' - # is not used. Also, note that we explicitly use another dictionary to avoid modifications - # in the current 'locals()'. - # - # Hyukjin: I think it's an ugly way to test issues about syntax specific in - # higher versions of Python, which we shouldn't encourage. This was the last resort - # I could come up with at that time. - _locals = {} - exec( - "import pandas as pd\ndef noop(col: pd.Series) -> pd.Series: return col", - _locals) - df = self.spark.range(1).select(pandas_udf(f=_locals['noop'], returnType='bigint')('id')) - self.assertEqual(df.first()[0], 0) - - def test_mixed_udf(self): - import pandas as pd - from pyspark.sql.functions import col, udf, pandas_udf - - df = self.spark.range(0, 1).toDF('v') - - # Test mixture of multiple UDFs and Pandas UDFs. 
- - @udf('int') - def f1(x): - assert type(x) == int - return x + 1 - - @pandas_udf('int') - def f2(x): - assert type(x) == pd.Series - return x + 10 - - @udf('int') - def f3(x): - assert type(x) == int - return x + 100 - - @pandas_udf('int') - def f4(x): - assert type(x) == pd.Series - return x + 1000 - - # Test single expression with chained UDFs - df_chained_1 = df.withColumn('f2_f1', f2(f1(df['v']))) - df_chained_2 = df.withColumn('f3_f2_f1', f3(f2(f1(df['v'])))) - df_chained_3 = df.withColumn('f4_f3_f2_f1', f4(f3(f2(f1(df['v']))))) - df_chained_4 = df.withColumn('f4_f2_f1', f4(f2(f1(df['v'])))) - df_chained_5 = df.withColumn('f4_f3_f1', f4(f3(f1(df['v'])))) - - expected_chained_1 = df.withColumn('f2_f1', df['v'] + 11) - expected_chained_2 = df.withColumn('f3_f2_f1', df['v'] + 111) - expected_chained_3 = df.withColumn('f4_f3_f2_f1', df['v'] + 1111) - expected_chained_4 = df.withColumn('f4_f2_f1', df['v'] + 1011) - expected_chained_5 = df.withColumn('f4_f3_f1', df['v'] + 1101) - - self.assertEquals(expected_chained_1.collect(), df_chained_1.collect()) - self.assertEquals(expected_chained_2.collect(), df_chained_2.collect()) - self.assertEquals(expected_chained_3.collect(), df_chained_3.collect()) - self.assertEquals(expected_chained_4.collect(), df_chained_4.collect()) - self.assertEquals(expected_chained_5.collect(), df_chained_5.collect()) - - # Test multiple mixed UDF expressions in a single projection - df_multi_1 = df \ - .withColumn('f1', f1(col('v'))) \ - .withColumn('f2', f2(col('v'))) \ - .withColumn('f3', f3(col('v'))) \ - .withColumn('f4', f4(col('v'))) \ - .withColumn('f2_f1', f2(col('f1'))) \ - .withColumn('f3_f1', f3(col('f1'))) \ - .withColumn('f4_f1', f4(col('f1'))) \ - .withColumn('f3_f2', f3(col('f2'))) \ - .withColumn('f4_f2', f4(col('f2'))) \ - .withColumn('f4_f3', f4(col('f3'))) \ - .withColumn('f3_f2_f1', f3(col('f2_f1'))) \ - .withColumn('f4_f2_f1', f4(col('f2_f1'))) \ - .withColumn('f4_f3_f1', f4(col('f3_f1'))) \ - .withColumn('f4_f3_f2', f4(col('f3_f2'))) \ - .withColumn('f4_f3_f2_f1', f4(col('f3_f2_f1'))) - - # Test mixed udfs in a single expression - df_multi_2 = df \ - .withColumn('f1', f1(col('v'))) \ - .withColumn('f2', f2(col('v'))) \ - .withColumn('f3', f3(col('v'))) \ - .withColumn('f4', f4(col('v'))) \ - .withColumn('f2_f1', f2(f1(col('v')))) \ - .withColumn('f3_f1', f3(f1(col('v')))) \ - .withColumn('f4_f1', f4(f1(col('v')))) \ - .withColumn('f3_f2', f3(f2(col('v')))) \ - .withColumn('f4_f2', f4(f2(col('v')))) \ - .withColumn('f4_f3', f4(f3(col('v')))) \ - .withColumn('f3_f2_f1', f3(f2(f1(col('v'))))) \ - .withColumn('f4_f2_f1', f4(f2(f1(col('v'))))) \ - .withColumn('f4_f3_f1', f4(f3(f1(col('v'))))) \ - .withColumn('f4_f3_f2', f4(f3(f2(col('v'))))) \ - .withColumn('f4_f3_f2_f1', f4(f3(f2(f1(col('v')))))) - - expected = df \ - .withColumn('f1', df['v'] + 1) \ - .withColumn('f2', df['v'] + 10) \ - .withColumn('f3', df['v'] + 100) \ - .withColumn('f4', df['v'] + 1000) \ - .withColumn('f2_f1', df['v'] + 11) \ - .withColumn('f3_f1', df['v'] + 101) \ - .withColumn('f4_f1', df['v'] + 1001) \ - .withColumn('f3_f2', df['v'] + 110) \ - .withColumn('f4_f2', df['v'] + 1010) \ - .withColumn('f4_f3', df['v'] + 1100) \ - .withColumn('f3_f2_f1', df['v'] + 111) \ - .withColumn('f4_f2_f1', df['v'] + 1011) \ - .withColumn('f4_f3_f1', df['v'] + 1101) \ - .withColumn('f4_f3_f2', df['v'] + 1110) \ - .withColumn('f4_f3_f2_f1', df['v'] + 1111) - - self.assertEquals(expected.collect(), df_multi_1.collect()) - self.assertEquals(expected.collect(), df_multi_2.collect()) - - def 
test_mixed_udf_and_sql(self): - import pandas as pd - from pyspark.sql import Column - from pyspark.sql.functions import udf, pandas_udf - - df = self.spark.range(0, 1).toDF('v') - - # Test mixture of UDFs, Pandas UDFs and SQL expression. - - @udf('int') - def f1(x): - assert type(x) == int - return x + 1 - - def f2(x): - assert type(x) == Column - return x + 10 - - @pandas_udf('int') - def f3(x): - assert type(x) == pd.Series - return x + 100 - - df1 = df.withColumn('f1', f1(df['v'])) \ - .withColumn('f2', f2(df['v'])) \ - .withColumn('f3', f3(df['v'])) \ - .withColumn('f1_f2', f1(f2(df['v']))) \ - .withColumn('f1_f3', f1(f3(df['v']))) \ - .withColumn('f2_f1', f2(f1(df['v']))) \ - .withColumn('f2_f3', f2(f3(df['v']))) \ - .withColumn('f3_f1', f3(f1(df['v']))) \ - .withColumn('f3_f2', f3(f2(df['v']))) \ - .withColumn('f1_f2_f3', f1(f2(f3(df['v'])))) \ - .withColumn('f1_f3_f2', f1(f3(f2(df['v'])))) \ - .withColumn('f2_f1_f3', f2(f1(f3(df['v'])))) \ - .withColumn('f2_f3_f1', f2(f3(f1(df['v'])))) \ - .withColumn('f3_f1_f2', f3(f1(f2(df['v'])))) \ - .withColumn('f3_f2_f1', f3(f2(f1(df['v'])))) - - expected = df.withColumn('f1', df['v'] + 1) \ - .withColumn('f2', df['v'] + 10) \ - .withColumn('f3', df['v'] + 100) \ - .withColumn('f1_f2', df['v'] + 11) \ - .withColumn('f1_f3', df['v'] + 101) \ - .withColumn('f2_f1', df['v'] + 11) \ - .withColumn('f2_f3', df['v'] + 110) \ - .withColumn('f3_f1', df['v'] + 101) \ - .withColumn('f3_f2', df['v'] + 110) \ - .withColumn('f1_f2_f3', df['v'] + 111) \ - .withColumn('f1_f3_f2', df['v'] + 111) \ - .withColumn('f2_f1_f3', df['v'] + 111) \ - .withColumn('f2_f3_f1', df['v'] + 111) \ - .withColumn('f3_f1_f2', df['v'] + 111) \ - .withColumn('f3_f2_f1', df['v'] + 111) - - self.assertEquals(expected.collect(), df1.collect()) - - # SPARK-24721 - @unittest.skipIf(not _test_compiled, _test_not_compiled_message) - def test_datasource_with_udf(self): - # Same as SQLTests.test_datasource_with_udf, but with Pandas UDF - # This needs to a separate test because Arrow dependency is optional - import pandas as pd - import numpy as np - from pyspark.sql.functions import pandas_udf, lit, col - - path = tempfile.mkdtemp() - shutil.rmtree(path) - - try: - self.spark.range(1).write.mode("overwrite").format('csv').save(path) - filesource_df = self.spark.read.option('inferSchema', True).csv(path).toDF('i') - datasource_df = self.spark.read \ - .format("org.apache.spark.sql.sources.SimpleScanSource") \ - .option('from', 0).option('to', 1).load().toDF('i') - datasource_v2_df = self.spark.read \ - .format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2") \ - .load().toDF('i', 'j') - - c1 = pandas_udf(lambda x: x + 1, 'int')(lit(1)) - c2 = pandas_udf(lambda x: x + 1, 'int')(col('i')) - - f1 = pandas_udf(lambda x: pd.Series(np.repeat(False, len(x))), 'boolean')(lit(1)) - f2 = pandas_udf(lambda x: pd.Series(np.repeat(False, len(x))), 'boolean')(col('i')) - - for df in [filesource_df, datasource_df, datasource_v2_df]: - result = df.withColumn('c', c1) - expected = df.withColumn('c', lit(2)) - self.assertEquals(expected.collect(), result.collect()) - - for df in [filesource_df, datasource_df, datasource_v2_df]: - result = df.withColumn('c', c2) - expected = df.withColumn('c', col('i') + 1) - self.assertEquals(expected.collect(), result.collect()) - - for df in [filesource_df, datasource_df, datasource_v2_df]: - for f in [f1, f2]: - result = df.filter(f) - self.assertEquals(0, result.count()) - finally: - shutil.rmtree(path) - - -@unittest.skipIf( - not _have_pandas or not 
_have_pyarrow, - _pandas_requirement_message or _pyarrow_requirement_message) -class GroupedMapPandasUDFTests(ReusedSQLTestCase): - - @property - def data(self): - from pyspark.sql.functions import array, explode, col, lit - return self.spark.range(10).toDF('id') \ - .withColumn("vs", array([lit(i) for i in range(20, 30)])) \ - .withColumn("v", explode(col('vs'))).drop('vs') - - def test_supported_types(self): - from decimal import Decimal - from distutils.version import LooseVersion - import pyarrow as pa - from pyspark.sql.functions import pandas_udf, PandasUDFType - - values = [ - 1, 2, 3, - 4, 5, 1.1, - 2.2, Decimal(1.123), - [1, 2, 2], True, 'hello' - ] - output_fields = [ - ('id', IntegerType()), ('byte', ByteType()), ('short', ShortType()), - ('int', IntegerType()), ('long', LongType()), ('float', FloatType()), - ('double', DoubleType()), ('decim', DecimalType(10, 3)), - ('array', ArrayType(IntegerType())), ('bool', BooleanType()), ('str', StringType()) - ] - - # TODO: Add BinaryType to variables above once minimum pyarrow version is 0.10.0 - if LooseVersion(pa.__version__) >= LooseVersion("0.10.0"): - values.append(bytearray([0x01, 0x02])) - output_fields.append(('bin', BinaryType())) - - output_schema = StructType([StructField(*x) for x in output_fields]) - df = self.spark.createDataFrame([values], schema=output_schema) - - # Different forms of group map pandas UDF, results of these are the same - udf1 = pandas_udf( - lambda pdf: pdf.assign( - byte=pdf.byte * 2, - short=pdf.short * 2, - int=pdf.int * 2, - long=pdf.long * 2, - float=pdf.float * 2, - double=pdf.double * 2, - decim=pdf.decim * 2, - bool=False if pdf.bool else True, - str=pdf.str + 'there', - array=pdf.array, - ), - output_schema, - PandasUDFType.GROUPED_MAP - ) - - udf2 = pandas_udf( - lambda _, pdf: pdf.assign( - byte=pdf.byte * 2, - short=pdf.short * 2, - int=pdf.int * 2, - long=pdf.long * 2, - float=pdf.float * 2, - double=pdf.double * 2, - decim=pdf.decim * 2, - bool=False if pdf.bool else True, - str=pdf.str + 'there', - array=pdf.array, - ), - output_schema, - PandasUDFType.GROUPED_MAP - ) - - udf3 = pandas_udf( - lambda key, pdf: pdf.assign( - id=key[0], - byte=pdf.byte * 2, - short=pdf.short * 2, - int=pdf.int * 2, - long=pdf.long * 2, - float=pdf.float * 2, - double=pdf.double * 2, - decim=pdf.decim * 2, - bool=False if pdf.bool else True, - str=pdf.str + 'there', - array=pdf.array, - ), - output_schema, - PandasUDFType.GROUPED_MAP - ) - - result1 = df.groupby('id').apply(udf1).sort('id').toPandas() - expected1 = df.toPandas().groupby('id').apply(udf1.func).reset_index(drop=True) - - result2 = df.groupby('id').apply(udf2).sort('id').toPandas() - expected2 = expected1 - - result3 = df.groupby('id').apply(udf3).sort('id').toPandas() - expected3 = expected1 - - self.assertPandasEqual(expected1, result1) - self.assertPandasEqual(expected2, result2) - self.assertPandasEqual(expected3, result3) - - def test_array_type_correct(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType, array, col - - df = self.data.withColumn("arr", array(col("id"))).repartition(1, "id") - - output_schema = StructType( - [StructField('id', LongType()), - StructField('v', IntegerType()), - StructField('arr', ArrayType(LongType()))]) - - udf = pandas_udf( - lambda pdf: pdf, - output_schema, - PandasUDFType.GROUPED_MAP - ) - - result = df.groupby('id').apply(udf).sort('id').toPandas() - expected = df.toPandas().groupby('id').apply(udf.func).reset_index(drop=True) - self.assertPandasEqual(expected, result) - - def 
test_register_grouped_map_udf(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType - - foo_udf = pandas_udf(lambda x: x, "id long", PandasUDFType.GROUPED_MAP) - with QuietTest(self.sc): - with self.assertRaisesRegexp( - ValueError, - 'f.*SQL_BATCHED_UDF.*SQL_SCALAR_PANDAS_UDF.*SQL_GROUPED_AGG_PANDAS_UDF.*'): - self.spark.catalog.registerFunction("foo_udf", foo_udf) - - def test_decorator(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType - df = self.data - - @pandas_udf( - 'id long, v int, v1 double, v2 long', - PandasUDFType.GROUPED_MAP - ) - def foo(pdf): - return pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id) - - result = df.groupby('id').apply(foo).sort('id').toPandas() - expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True) - self.assertPandasEqual(expected, result) - - def test_coerce(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType - df = self.data - - foo = pandas_udf( - lambda pdf: pdf, - 'id long, v double', - PandasUDFType.GROUPED_MAP - ) - - result = df.groupby('id').apply(foo).sort('id').toPandas() - expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True) - expected = expected.assign(v=expected.v.astype('float64')) - self.assertPandasEqual(expected, result) - - def test_complex_groupby(self): - from pyspark.sql.functions import pandas_udf, col, PandasUDFType - df = self.data - - @pandas_udf( - 'id long, v int, norm double', - PandasUDFType.GROUPED_MAP - ) - def normalize(pdf): - v = pdf.v - return pdf.assign(norm=(v - v.mean()) / v.std()) - - result = df.groupby(col('id') % 2 == 0).apply(normalize).sort('id', 'v').toPandas() - pdf = df.toPandas() - expected = pdf.groupby(pdf['id'] % 2 == 0).apply(normalize.func) - expected = expected.sort_values(['id', 'v']).reset_index(drop=True) - expected = expected.assign(norm=expected.norm.astype('float64')) - self.assertPandasEqual(expected, result) - - def test_empty_groupby(self): - from pyspark.sql.functions import pandas_udf, col, PandasUDFType - df = self.data - - @pandas_udf( - 'id long, v int, norm double', - PandasUDFType.GROUPED_MAP - ) - def normalize(pdf): - v = pdf.v - return pdf.assign(norm=(v - v.mean()) / v.std()) - - result = df.groupby().apply(normalize).sort('id', 'v').toPandas() - pdf = df.toPandas() - expected = normalize.func(pdf) - expected = expected.sort_values(['id', 'v']).reset_index(drop=True) - expected = expected.assign(norm=expected.norm.astype('float64')) - self.assertPandasEqual(expected, result) - - def test_datatype_string(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType - df = self.data - - foo_udf = pandas_udf( - lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id), - 'id long, v int, v1 double, v2 long', - PandasUDFType.GROUPED_MAP - ) - - result = df.groupby('id').apply(foo_udf).sort('id').toPandas() - expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True) - self.assertPandasEqual(expected, result) - - def test_wrong_return_type(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType - - with QuietTest(self.sc): - with self.assertRaisesRegexp( - NotImplementedError, - 'Invalid returnType.*grouped map Pandas UDF.*MapType'): - pandas_udf( - lambda pdf: pdf, - 'id long, v map', - PandasUDFType.GROUPED_MAP) - - def test_wrong_args(self): - from pyspark.sql.functions import udf, pandas_udf, sum, PandasUDFType - df = self.data - - with QuietTest(self.sc): - with self.assertRaisesRegexp(ValueError, 'Invalid udf'): - 
df.groupby('id').apply(lambda x: x) - with self.assertRaisesRegexp(ValueError, 'Invalid udf'): - df.groupby('id').apply(udf(lambda x: x, DoubleType())) - with self.assertRaisesRegexp(ValueError, 'Invalid udf'): - df.groupby('id').apply(sum(df.v)) - with self.assertRaisesRegexp(ValueError, 'Invalid udf'): - df.groupby('id').apply(df.v + 1) - with self.assertRaisesRegexp(ValueError, 'Invalid function'): - df.groupby('id').apply( - pandas_udf(lambda: 1, StructType([StructField("d", DoubleType())]))) - with self.assertRaisesRegexp(ValueError, 'Invalid udf'): - df.groupby('id').apply(pandas_udf(lambda x, y: x, DoubleType())) - with self.assertRaisesRegexp(ValueError, 'Invalid udf.*GROUPED_MAP'): - df.groupby('id').apply( - pandas_udf(lambda x, y: x, DoubleType(), PandasUDFType.SCALAR)) - - def test_unsupported_types(self): - from distutils.version import LooseVersion - import pyarrow as pa - from pyspark.sql.functions import pandas_udf, PandasUDFType - - common_err_msg = 'Invalid returnType.*grouped map Pandas UDF.*' - unsupported_types = [ - StructField('map', MapType(StringType(), IntegerType())), - StructField('arr_ts', ArrayType(TimestampType())), - StructField('null', NullType()), - ] - - # TODO: Remove this if-statement once minimum pyarrow version is 0.10.0 - if LooseVersion(pa.__version__) < LooseVersion("0.10.0"): - unsupported_types.append(StructField('bin', BinaryType())) - - for unsupported_type in unsupported_types: - schema = StructType([StructField('id', LongType(), True), unsupported_type]) - with QuietTest(self.sc): - with self.assertRaisesRegexp(NotImplementedError, common_err_msg): - pandas_udf(lambda x: x, schema, PandasUDFType.GROUPED_MAP) - - # Regression test for SPARK-23314 - def test_timestamp_dst(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType - # Daylight saving time for Los Angeles for 2015 is Sun, Nov 1 at 2:00 am - dt = [datetime.datetime(2015, 11, 1, 0, 30), - datetime.datetime(2015, 11, 1, 1, 30), - datetime.datetime(2015, 11, 1, 2, 30)] - df = self.spark.createDataFrame(dt, 'timestamp').toDF('time') - foo_udf = pandas_udf(lambda pdf: pdf, 'time timestamp', PandasUDFType.GROUPED_MAP) - result = df.groupby('time').apply(foo_udf).sort('time') - self.assertPandasEqual(df.toPandas(), result.toPandas()) - - def test_udf_with_key(self): - from pyspark.sql.functions import pandas_udf, col, PandasUDFType - df = self.data - pdf = df.toPandas() - - def foo1(key, pdf): - import numpy as np - assert type(key) == tuple - assert type(key[0]) == np.int64 - - return pdf.assign(v1=key[0], - v2=pdf.v * key[0], - v3=pdf.v * pdf.id, - v4=pdf.v * pdf.id.mean()) - - def foo2(key, pdf): - import numpy as np - assert type(key) == tuple - assert type(key[0]) == np.int64 - assert type(key[1]) == np.int32 - - return pdf.assign(v1=key[0], - v2=key[1], - v3=pdf.v * key[0], - v4=pdf.v + key[1]) - - def foo3(key, pdf): - assert type(key) == tuple - assert len(key) == 0 - return pdf.assign(v1=pdf.v * pdf.id) - - # v2 is int because numpy.int64 * pd.Series results in pd.Series - # v3 is long because pd.Series * pd.Series results in pd.Series - udf1 = pandas_udf( - foo1, - 'id long, v int, v1 long, v2 int, v3 long, v4 double', - PandasUDFType.GROUPED_MAP) - - udf2 = pandas_udf( - foo2, - 'id long, v int, v1 long, v2 int, v3 int, v4 int', - PandasUDFType.GROUPED_MAP) - - udf3 = pandas_udf( - foo3, - 'id long, v int, v1 long', - PandasUDFType.GROUPED_MAP) - - # Test groupby column - result1 = df.groupby('id').apply(udf1).sort('id', 'v').toPandas() - expected1 = 
pdf.groupby('id')\ - .apply(lambda x: udf1.func((x.id.iloc[0],), x))\ - .sort_values(['id', 'v']).reset_index(drop=True) - self.assertPandasEqual(expected1, result1) - - # Test groupby expression - result2 = df.groupby(df.id % 2).apply(udf1).sort('id', 'v').toPandas() - expected2 = pdf.groupby(pdf.id % 2)\ - .apply(lambda x: udf1.func((x.id.iloc[0] % 2,), x))\ - .sort_values(['id', 'v']).reset_index(drop=True) - self.assertPandasEqual(expected2, result2) - - # Test complex groupby - result3 = df.groupby(df.id, df.v % 2).apply(udf2).sort('id', 'v').toPandas() - expected3 = pdf.groupby([pdf.id, pdf.v % 2])\ - .apply(lambda x: udf2.func((x.id.iloc[0], (x.v % 2).iloc[0],), x))\ - .sort_values(['id', 'v']).reset_index(drop=True) - self.assertPandasEqual(expected3, result3) - - # Test empty groupby - result4 = df.groupby().apply(udf3).sort('id', 'v').toPandas() - expected4 = udf3.func((), pdf) - self.assertPandasEqual(expected4, result4) - - def test_column_order(self): - from collections import OrderedDict - import pandas as pd - from pyspark.sql.functions import pandas_udf, PandasUDFType - - # Helper function to set column names from a list - def rename_pdf(pdf, names): - pdf.rename(columns={old: new for old, new in - zip(pd_result.columns, names)}, inplace=True) - - df = self.data - grouped_df = df.groupby('id') - grouped_pdf = df.toPandas().groupby('id') - - # Function returns a pdf with required column names, but order could be arbitrary using dict - def change_col_order(pdf): - # Constructing a DataFrame from a dict should result in the same order, - # but use from_items to ensure the pdf column order is different than schema - return pd.DataFrame.from_items([ - ('id', pdf.id), - ('u', pdf.v * 2), - ('v', pdf.v)]) - - ordered_udf = pandas_udf( - change_col_order, - 'id long, v int, u int', - PandasUDFType.GROUPED_MAP - ) - - # The UDF result should assign columns by name from the pdf - result = grouped_df.apply(ordered_udf).sort('id', 'v')\ - .select('id', 'u', 'v').toPandas() - pd_result = grouped_pdf.apply(change_col_order) - expected = pd_result.sort_values(['id', 'v']).reset_index(drop=True) - self.assertPandasEqual(expected, result) - - # Function returns a pdf with positional columns, indexed by range - def range_col_order(pdf): - # Create a DataFrame with positional columns, fix types to long - return pd.DataFrame(list(zip(pdf.id, pdf.v * 3, pdf.v)), dtype='int64') - - range_udf = pandas_udf( - range_col_order, - 'id long, u long, v long', - PandasUDFType.GROUPED_MAP - ) - - # The UDF result uses positional columns from the pdf - result = grouped_df.apply(range_udf).sort('id', 'v') \ - .select('id', 'u', 'v').toPandas() - pd_result = grouped_pdf.apply(range_col_order) - rename_pdf(pd_result, ['id', 'u', 'v']) - expected = pd_result.sort_values(['id', 'v']).reset_index(drop=True) - self.assertPandasEqual(expected, result) - - # Function returns a pdf with columns indexed with integers - def int_index(pdf): - return pd.DataFrame(OrderedDict([(0, pdf.id), (1, pdf.v * 4), (2, pdf.v)])) - - int_index_udf = pandas_udf( - int_index, - 'id long, u int, v int', - PandasUDFType.GROUPED_MAP - ) - - # The UDF result should assign columns by position of integer index - result = grouped_df.apply(int_index_udf).sort('id', 'v') \ - .select('id', 'u', 'v').toPandas() - pd_result = grouped_pdf.apply(int_index) - rename_pdf(pd_result, ['id', 'u', 'v']) - expected = pd_result.sort_values(['id', 'v']).reset_index(drop=True) - self.assertPandasEqual(expected, result) - - @pandas_udf('id long, v int', 
PandasUDFType.GROUPED_MAP) - def column_name_typo(pdf): - return pd.DataFrame({'iid': pdf.id, 'v': pdf.v}) - - @pandas_udf('id long, v int', PandasUDFType.GROUPED_MAP) - def invalid_positional_types(pdf): - return pd.DataFrame([(u'a', 1.2)]) - - with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, "KeyError: 'id'"): - grouped_df.apply(column_name_typo).collect() - with self.assertRaisesRegexp(Exception, "No cast implemented"): - grouped_df.apply(invalid_positional_types).collect() - - def test_positional_assignment_conf(self): - import pandas as pd - from pyspark.sql.functions import pandas_udf, PandasUDFType - - with self.sql_conf({ - "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName": False}): - - @pandas_udf("a string, b float", PandasUDFType.GROUPED_MAP) - def foo(_): - return pd.DataFrame([('hi', 1)], columns=['x', 'y']) - - df = self.data - result = df.groupBy('id').apply(foo).select('a', 'b').collect() - for r in result: - self.assertEqual(r.a, 'hi') - self.assertEqual(r.b, 1) - - def test_self_join_with_pandas(self): - import pyspark.sql.functions as F - - @F.pandas_udf('key long, col string', F.PandasUDFType.GROUPED_MAP) - def dummy_pandas_udf(df): - return df[['key', 'col']] - - df = self.spark.createDataFrame([Row(key=1, col='A'), Row(key=1, col='B'), - Row(key=2, col='C')]) - df_with_pandas = df.groupBy('key').apply(dummy_pandas_udf) - - # this was throwing an AnalysisException before SPARK-24208 - res = df_with_pandas.alias('temp0').join(df_with_pandas.alias('temp1'), - F.col('temp0.key') == F.col('temp1.key')) - self.assertEquals(res.count(), 5) - - def test_mixed_scalar_udfs_followed_by_grouby_apply(self): - import pandas as pd - from pyspark.sql.functions import udf, pandas_udf, PandasUDFType - - df = self.spark.range(0, 10).toDF('v1') - df = df.withColumn('v2', udf(lambda x: x + 1, 'int')(df['v1'])) \ - .withColumn('v3', pandas_udf(lambda x: x + 2, 'int')(df['v1'])) - - result = df.groupby() \ - .apply(pandas_udf(lambda x: pd.DataFrame([x.sum().sum()]), - 'sum int', - PandasUDFType.GROUPED_MAP)) - - self.assertEquals(result.collect()[0]['sum'], 165) - - -@unittest.skipIf( - not _have_pandas or not _have_pyarrow, - _pandas_requirement_message or _pyarrow_requirement_message) -class GroupedAggPandasUDFTests(ReusedSQLTestCase): - - @property - def data(self): - from pyspark.sql.functions import array, explode, col, lit - return self.spark.range(10).toDF('id') \ - .withColumn("vs", array([lit(i * 1.0) + col('id') for i in range(20, 30)])) \ - .withColumn("v", explode(col('vs'))) \ - .drop('vs') \ - .withColumn('w', lit(1.0)) - - @property - def python_plus_one(self): - from pyspark.sql.functions import udf - - @udf('double') - def plus_one(v): - assert isinstance(v, (int, float)) - return v + 1 - return plus_one - - @property - def pandas_scalar_plus_two(self): - import pandas as pd - from pyspark.sql.functions import pandas_udf, PandasUDFType - - @pandas_udf('double', PandasUDFType.SCALAR) - def plus_two(v): - assert isinstance(v, pd.Series) - return v + 2 - return plus_two - - @property - def pandas_agg_mean_udf(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType - - @pandas_udf('double', PandasUDFType.GROUPED_AGG) - def avg(v): - return v.mean() - return avg - - @property - def pandas_agg_sum_udf(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType - - @pandas_udf('double', PandasUDFType.GROUPED_AGG) - def sum(v): - return v.sum() - return sum - - @property - def pandas_agg_weighted_mean_udf(self): - import 
numpy as np - from pyspark.sql.functions import pandas_udf, PandasUDFType - - @pandas_udf('double', PandasUDFType.GROUPED_AGG) - def weighted_mean(v, w): - return np.average(v, weights=w) - return weighted_mean - - def test_manual(self): - from pyspark.sql.functions import pandas_udf, array - - df = self.data - sum_udf = self.pandas_agg_sum_udf - mean_udf = self.pandas_agg_mean_udf - mean_arr_udf = pandas_udf( - self.pandas_agg_mean_udf.func, - ArrayType(self.pandas_agg_mean_udf.returnType), - self.pandas_agg_mean_udf.evalType) - - result1 = df.groupby('id').agg( - sum_udf(df.v), - mean_udf(df.v), - mean_arr_udf(array(df.v))).sort('id') - expected1 = self.spark.createDataFrame( - [[0, 245.0, 24.5, [24.5]], - [1, 255.0, 25.5, [25.5]], - [2, 265.0, 26.5, [26.5]], - [3, 275.0, 27.5, [27.5]], - [4, 285.0, 28.5, [28.5]], - [5, 295.0, 29.5, [29.5]], - [6, 305.0, 30.5, [30.5]], - [7, 315.0, 31.5, [31.5]], - [8, 325.0, 32.5, [32.5]], - [9, 335.0, 33.5, [33.5]]], - ['id', 'sum(v)', 'avg(v)', 'avg(array(v))']) - - self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) - - def test_basic(self): - from pyspark.sql.functions import col, lit, sum, mean - - df = self.data - weighted_mean_udf = self.pandas_agg_weighted_mean_udf - - # Groupby one column and aggregate one UDF with literal - result1 = df.groupby('id').agg(weighted_mean_udf(df.v, lit(1.0))).sort('id') - expected1 = df.groupby('id').agg(mean(df.v).alias('weighted_mean(v, 1.0)')).sort('id') - self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) - - # Groupby one expression and aggregate one UDF with literal - result2 = df.groupby((col('id') + 1)).agg(weighted_mean_udf(df.v, lit(1.0)))\ - .sort(df.id + 1) - expected2 = df.groupby((col('id') + 1))\ - .agg(mean(df.v).alias('weighted_mean(v, 1.0)')).sort(df.id + 1) - self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) - - # Groupby one column and aggregate one UDF without literal - result3 = df.groupby('id').agg(weighted_mean_udf(df.v, df.w)).sort('id') - expected3 = df.groupby('id').agg(mean(df.v).alias('weighted_mean(v, w)')).sort('id') - self.assertPandasEqual(expected3.toPandas(), result3.toPandas()) - - # Groupby one expression and aggregate one UDF without literal - result4 = df.groupby((col('id') + 1).alias('id'))\ - .agg(weighted_mean_udf(df.v, df.w))\ - .sort('id') - expected4 = df.groupby((col('id') + 1).alias('id'))\ - .agg(mean(df.v).alias('weighted_mean(v, w)'))\ - .sort('id') - self.assertPandasEqual(expected4.toPandas(), result4.toPandas()) - - def test_unsupported_types(self): - from pyspark.sql.types import DoubleType, MapType - from pyspark.sql.functions import pandas_udf, PandasUDFType - - with QuietTest(self.sc): - with self.assertRaisesRegexp(NotImplementedError, 'not supported'): - pandas_udf( - lambda x: x, - ArrayType(ArrayType(TimestampType())), - PandasUDFType.GROUPED_AGG) - - with QuietTest(self.sc): - with self.assertRaisesRegexp(NotImplementedError, 'not supported'): - @pandas_udf('mean double, std double', PandasUDFType.GROUPED_AGG) - def mean_and_std_udf(v): - return v.mean(), v.std() - - with QuietTest(self.sc): - with self.assertRaisesRegexp(NotImplementedError, 'not supported'): - @pandas_udf(MapType(DoubleType(), DoubleType()), PandasUDFType.GROUPED_AGG) - def mean_and_std_udf(v): - return {v.mean(): v.std()} - - def test_alias(self): - from pyspark.sql.functions import mean - - df = self.data - mean_udf = self.pandas_agg_mean_udf - - result1 = df.groupby('id').agg(mean_udf(df.v).alias('mean_alias')) - expected1 = 
df.groupby('id').agg(mean(df.v).alias('mean_alias')) - - self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) - - def test_mixed_sql(self): - """ - Test mixing group aggregate pandas UDF with sql expression. - """ - from pyspark.sql.functions import sum, mean - - df = self.data - sum_udf = self.pandas_agg_sum_udf - - # Mix group aggregate pandas UDF with sql expression - result1 = (df.groupby('id') - .agg(sum_udf(df.v) + 1) - .sort('id')) - expected1 = (df.groupby('id') - .agg(sum(df.v) + 1) - .sort('id')) - - # Mix group aggregate pandas UDF with sql expression (order swapped) - result2 = (df.groupby('id') - .agg(sum_udf(df.v + 1)) - .sort('id')) - - expected2 = (df.groupby('id') - .agg(sum(df.v + 1)) - .sort('id')) - - # Wrap group aggregate pandas UDF with two sql expressions - result3 = (df.groupby('id') - .agg(sum_udf(df.v + 1) + 2) - .sort('id')) - expected3 = (df.groupby('id') - .agg(sum(df.v + 1) + 2) - .sort('id')) - - self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) - self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) - self.assertPandasEqual(expected3.toPandas(), result3.toPandas()) - - def test_mixed_udfs(self): - """ - Test mixing group aggregate pandas UDF with python UDF and scalar pandas UDF. - """ - from pyspark.sql.functions import sum, mean - - df = self.data - plus_one = self.python_plus_one - plus_two = self.pandas_scalar_plus_two - sum_udf = self.pandas_agg_sum_udf - - # Mix group aggregate pandas UDF and python UDF - result1 = (df.groupby('id') - .agg(plus_one(sum_udf(df.v))) - .sort('id')) - expected1 = (df.groupby('id') - .agg(plus_one(sum(df.v))) - .sort('id')) - - # Mix group aggregate pandas UDF and python UDF (order swapped) - result2 = (df.groupby('id') - .agg(sum_udf(plus_one(df.v))) - .sort('id')) - expected2 = (df.groupby('id') - .agg(sum(plus_one(df.v))) - .sort('id')) - - # Mix group aggregate pandas UDF and scalar pandas UDF - result3 = (df.groupby('id') - .agg(sum_udf(plus_two(df.v))) - .sort('id')) - expected3 = (df.groupby('id') - .agg(sum(plus_two(df.v))) - .sort('id')) - - # Mix group aggregate pandas UDF and scalar pandas UDF (order swapped) - result4 = (df.groupby('id') - .agg(plus_two(sum_udf(df.v))) - .sort('id')) - expected4 = (df.groupby('id') - .agg(plus_two(sum(df.v))) - .sort('id')) - - # Wrap group aggregate pandas UDF with two python UDFs and use python UDF in groupby - result5 = (df.groupby(plus_one(df.id)) - .agg(plus_one(sum_udf(plus_one(df.v)))) - .sort('plus_one(id)')) - expected5 = (df.groupby(plus_one(df.id)) - .agg(plus_one(sum(plus_one(df.v)))) - .sort('plus_one(id)')) - - # Wrap group aggregate pandas UDF with two scala pandas UDF and user scala pandas UDF in - # groupby - result6 = (df.groupby(plus_two(df.id)) - .agg(plus_two(sum_udf(plus_two(df.v)))) - .sort('plus_two(id)')) - expected6 = (df.groupby(plus_two(df.id)) - .agg(plus_two(sum(plus_two(df.v)))) - .sort('plus_two(id)')) - - self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) - self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) - self.assertPandasEqual(expected3.toPandas(), result3.toPandas()) - self.assertPandasEqual(expected4.toPandas(), result4.toPandas()) - self.assertPandasEqual(expected5.toPandas(), result5.toPandas()) - self.assertPandasEqual(expected6.toPandas(), result6.toPandas()) - - def test_multiple_udfs(self): - """ - Test multiple group aggregate pandas UDFs in one agg function. 
- """ - from pyspark.sql.functions import col, lit, sum, mean - - df = self.data - mean_udf = self.pandas_agg_mean_udf - sum_udf = self.pandas_agg_sum_udf - weighted_mean_udf = self.pandas_agg_weighted_mean_udf - - result1 = (df.groupBy('id') - .agg(mean_udf(df.v), - sum_udf(df.v), - weighted_mean_udf(df.v, df.w)) - .sort('id') - .toPandas()) - expected1 = (df.groupBy('id') - .agg(mean(df.v), - sum(df.v), - mean(df.v).alias('weighted_mean(v, w)')) - .sort('id') - .toPandas()) - - self.assertPandasEqual(expected1, result1) - - def test_complex_groupby(self): - from pyspark.sql.functions import lit, sum - - df = self.data - sum_udf = self.pandas_agg_sum_udf - plus_one = self.python_plus_one - plus_two = self.pandas_scalar_plus_two - - # groupby one expression - result1 = df.groupby(df.v % 2).agg(sum_udf(df.v)) - expected1 = df.groupby(df.v % 2).agg(sum(df.v)) - - # empty groupby - result2 = df.groupby().agg(sum_udf(df.v)) - expected2 = df.groupby().agg(sum(df.v)) - - # groupby one column and one sql expression - result3 = df.groupby(df.id, df.v % 2).agg(sum_udf(df.v)).orderBy(df.id, df.v % 2) - expected3 = df.groupby(df.id, df.v % 2).agg(sum(df.v)).orderBy(df.id, df.v % 2) - - # groupby one python UDF - result4 = df.groupby(plus_one(df.id)).agg(sum_udf(df.v)) - expected4 = df.groupby(plus_one(df.id)).agg(sum(df.v)) - - # groupby one scalar pandas UDF - result5 = df.groupby(plus_two(df.id)).agg(sum_udf(df.v)) - expected5 = df.groupby(plus_two(df.id)).agg(sum(df.v)) - - # groupby one expression and one python UDF - result6 = df.groupby(df.v % 2, plus_one(df.id)).agg(sum_udf(df.v)) - expected6 = df.groupby(df.v % 2, plus_one(df.id)).agg(sum(df.v)) - - # groupby one expression and one scalar pandas UDF - result7 = df.groupby(df.v % 2, plus_two(df.id)).agg(sum_udf(df.v)).sort('sum(v)') - expected7 = df.groupby(df.v % 2, plus_two(df.id)).agg(sum(df.v)).sort('sum(v)') - - self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) - self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) - self.assertPandasEqual(expected3.toPandas(), result3.toPandas()) - self.assertPandasEqual(expected4.toPandas(), result4.toPandas()) - self.assertPandasEqual(expected5.toPandas(), result5.toPandas()) - self.assertPandasEqual(expected6.toPandas(), result6.toPandas()) - self.assertPandasEqual(expected7.toPandas(), result7.toPandas()) - - def test_complex_expressions(self): - from pyspark.sql.functions import col, sum - - df = self.data - plus_one = self.python_plus_one - plus_two = self.pandas_scalar_plus_two - sum_udf = self.pandas_agg_sum_udf - - # Test complex expressions with sql expression, python UDF and - # group aggregate pandas UDF - result1 = (df.withColumn('v1', plus_one(df.v)) - .withColumn('v2', df.v + 2) - .groupby(df.id, df.v % 2) - .agg(sum_udf(col('v')), - sum_udf(col('v1') + 3), - sum_udf(col('v2')) + 5, - plus_one(sum_udf(col('v1'))), - sum_udf(plus_one(col('v2')))) - .sort('id') - .toPandas()) - - expected1 = (df.withColumn('v1', df.v + 1) - .withColumn('v2', df.v + 2) - .groupby(df.id, df.v % 2) - .agg(sum(col('v')), - sum(col('v1') + 3), - sum(col('v2')) + 5, - plus_one(sum(col('v1'))), - sum(plus_one(col('v2')))) - .sort('id') - .toPandas()) - - # Test complex expressions with sql expression, scala pandas UDF and - # group aggregate pandas UDF - result2 = (df.withColumn('v1', plus_one(df.v)) - .withColumn('v2', df.v + 2) - .groupby(df.id, df.v % 2) - .agg(sum_udf(col('v')), - sum_udf(col('v1') + 3), - sum_udf(col('v2')) + 5, - plus_two(sum_udf(col('v1'))), - 
sum_udf(plus_two(col('v2')))) - .sort('id') - .toPandas()) - - expected2 = (df.withColumn('v1', df.v + 1) - .withColumn('v2', df.v + 2) - .groupby(df.id, df.v % 2) - .agg(sum(col('v')), - sum(col('v1') + 3), - sum(col('v2')) + 5, - plus_two(sum(col('v1'))), - sum(plus_two(col('v2')))) - .sort('id') - .toPandas()) - - # Test sequential groupby aggregate - result3 = (df.groupby('id') - .agg(sum_udf(df.v).alias('v')) - .groupby('id') - .agg(sum_udf(col('v'))) - .sort('id') - .toPandas()) - - expected3 = (df.groupby('id') - .agg(sum(df.v).alias('v')) - .groupby('id') - .agg(sum(col('v'))) - .sort('id') - .toPandas()) - - self.assertPandasEqual(expected1, result1) - self.assertPandasEqual(expected2, result2) - self.assertPandasEqual(expected3, result3) - - def test_retain_group_columns(self): - from pyspark.sql.functions import sum, lit, col - with self.sql_conf({"spark.sql.retainGroupColumns": False}): - df = self.data - sum_udf = self.pandas_agg_sum_udf - - result1 = df.groupby(df.id).agg(sum_udf(df.v)) - expected1 = df.groupby(df.id).agg(sum(df.v)) - self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) - - def test_array_type(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType - - df = self.data - - array_udf = pandas_udf(lambda x: [1.0, 2.0], 'array', PandasUDFType.GROUPED_AGG) - result1 = df.groupby('id').agg(array_udf(df['v']).alias('v2')) - self.assertEquals(result1.first()['v2'], [1.0, 2.0]) - - def test_invalid_args(self): - from pyspark.sql.functions import mean - - df = self.data - plus_one = self.python_plus_one - mean_udf = self.pandas_agg_mean_udf - - with QuietTest(self.sc): - with self.assertRaisesRegexp( - AnalysisException, - 'nor.*aggregate function'): - df.groupby(df.id).agg(plus_one(df.v)).collect() - - with QuietTest(self.sc): - with self.assertRaisesRegexp( - AnalysisException, - 'aggregate function.*argument.*aggregate function'): - df.groupby(df.id).agg(mean_udf(mean_udf(df.v))).collect() - - with QuietTest(self.sc): - with self.assertRaisesRegexp( - AnalysisException, - 'mixture.*aggregate function.*group aggregate pandas UDF'): - df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect() - - def test_register_vectorized_udf_basic(self): - from pyspark.sql.functions import pandas_udf - from pyspark.rdd import PythonEvalType - - sum_pandas_udf = pandas_udf( - lambda v: v.sum(), "integer", PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF) - - self.assertEqual(sum_pandas_udf.evalType, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF) - group_agg_pandas_udf = self.spark.udf.register("sum_pandas_udf", sum_pandas_udf) - self.assertEqual(group_agg_pandas_udf.evalType, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF) - q = "SELECT sum_pandas_udf(v1) FROM VALUES (3, 0), (2, 0), (1, 1) tbl(v1, v2) GROUP BY v2" - actual = sorted(map(lambda r: r[0], self.spark.sql(q).collect())) - expected = [1, 5] - self.assertEqual(actual, expected) - - -@unittest.skipIf( - not _have_pandas or not _have_pyarrow, - _pandas_requirement_message or _pyarrow_requirement_message) -class WindowPandasUDFTests(ReusedSQLTestCase): - @property - def data(self): - from pyspark.sql.functions import array, explode, col, lit - return self.spark.range(10).toDF('id') \ - .withColumn("vs", array([lit(i * 1.0) + col('id') for i in range(20, 30)])) \ - .withColumn("v", explode(col('vs'))) \ - .drop('vs') \ - .withColumn('w', lit(1.0)) - - @property - def python_plus_one(self): - from pyspark.sql.functions import udf - return udf(lambda v: v + 1, 'double') - - @property - def pandas_scalar_time_two(self): 
- from pyspark.sql.functions import pandas_udf, PandasUDFType - return pandas_udf(lambda v: v * 2, 'double') - - @property - def pandas_agg_mean_udf(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType - - @pandas_udf('double', PandasUDFType.GROUPED_AGG) - def avg(v): - return v.mean() - return avg - - @property - def pandas_agg_max_udf(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType - - @pandas_udf('double', PandasUDFType.GROUPED_AGG) - def max(v): - return v.max() - return max - - @property - def pandas_agg_min_udf(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType - - @pandas_udf('double', PandasUDFType.GROUPED_AGG) - def min(v): - return v.min() - return min - - @property - def unbounded_window(self): - return Window.partitionBy('id') \ - .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) - - @property - def ordered_window(self): - return Window.partitionBy('id').orderBy('v') - - @property - def unpartitioned_window(self): - return Window.partitionBy() - - def test_simple(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType, percent_rank, mean, max - - df = self.data - w = self.unbounded_window - - mean_udf = self.pandas_agg_mean_udf - - result1 = df.withColumn('mean_v', mean_udf(df['v']).over(w)) - expected1 = df.withColumn('mean_v', mean(df['v']).over(w)) - - result2 = df.select(mean_udf(df['v']).over(w)) - expected2 = df.select(mean(df['v']).over(w)) - - self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) - self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) - - def test_multiple_udfs(self): - from pyspark.sql.functions import max, min, mean - - df = self.data - w = self.unbounded_window - - result1 = df.withColumn('mean_v', self.pandas_agg_mean_udf(df['v']).over(w)) \ - .withColumn('max_v', self.pandas_agg_max_udf(df['v']).over(w)) \ - .withColumn('min_w', self.pandas_agg_min_udf(df['w']).over(w)) - - expected1 = df.withColumn('mean_v', mean(df['v']).over(w)) \ - .withColumn('max_v', max(df['v']).over(w)) \ - .withColumn('min_w', min(df['w']).over(w)) - - self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) - - def test_replace_existing(self): - from pyspark.sql.functions import mean - - df = self.data - w = self.unbounded_window - - result1 = df.withColumn('v', self.pandas_agg_mean_udf(df['v']).over(w)) - expected1 = df.withColumn('v', mean(df['v']).over(w)) - - self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) - - def test_mixed_sql(self): - from pyspark.sql.functions import mean - - df = self.data - w = self.unbounded_window - mean_udf = self.pandas_agg_mean_udf - - result1 = df.withColumn('v', mean_udf(df['v'] * 2).over(w) + 1) - expected1 = df.withColumn('v', mean(df['v'] * 2).over(w) + 1) - - self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) - - def test_mixed_udf(self): - from pyspark.sql.functions import mean - - df = self.data - w = self.unbounded_window - - plus_one = self.python_plus_one - time_two = self.pandas_scalar_time_two - mean_udf = self.pandas_agg_mean_udf - - result1 = df.withColumn( - 'v2', - plus_one(mean_udf(plus_one(df['v'])).over(w))) - expected1 = df.withColumn( - 'v2', - plus_one(mean(plus_one(df['v'])).over(w))) - - result2 = df.withColumn( - 'v2', - time_two(mean_udf(time_two(df['v'])).over(w))) - expected2 = df.withColumn( - 'v2', - time_two(mean(time_two(df['v'])).over(w))) - - self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) - self.assertPandasEqual(expected2.toPandas(), 
result2.toPandas()) - - def test_without_partitionBy(self): - from pyspark.sql.functions import mean - - df = self.data - w = self.unpartitioned_window - mean_udf = self.pandas_agg_mean_udf - - result1 = df.withColumn('v2', mean_udf(df['v']).over(w)) - expected1 = df.withColumn('v2', mean(df['v']).over(w)) - - result2 = df.select(mean_udf(df['v']).over(w)) - expected2 = df.select(mean(df['v']).over(w)) - - self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) - self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) - - def test_mixed_sql_and_udf(self): - from pyspark.sql.functions import max, min, rank, col - - df = self.data - w = self.unbounded_window - ow = self.ordered_window - max_udf = self.pandas_agg_max_udf - min_udf = self.pandas_agg_min_udf - - result1 = df.withColumn('v_diff', max_udf(df['v']).over(w) - min_udf(df['v']).over(w)) - expected1 = df.withColumn('v_diff', max(df['v']).over(w) - min(df['v']).over(w)) - - # Test mixing sql window function and window udf in the same expression - result2 = df.withColumn('v_diff', max_udf(df['v']).over(w) - min(df['v']).over(w)) - expected2 = expected1 - - # Test chaining sql aggregate function and udf - result3 = df.withColumn('max_v', max_udf(df['v']).over(w)) \ - .withColumn('min_v', min(df['v']).over(w)) \ - .withColumn('v_diff', col('max_v') - col('min_v')) \ - .drop('max_v', 'min_v') - expected3 = expected1 - - # Test mixing sql window function and udf - result4 = df.withColumn('max_v', max_udf(df['v']).over(w)) \ - .withColumn('rank', rank().over(ow)) - expected4 = df.withColumn('max_v', max(df['v']).over(w)) \ - .withColumn('rank', rank().over(ow)) - - self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) - self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) - self.assertPandasEqual(expected3.toPandas(), result3.toPandas()) - self.assertPandasEqual(expected4.toPandas(), result4.toPandas()) - - def test_array_type(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType - - df = self.data - w = self.unbounded_window - - array_udf = pandas_udf(lambda x: [1.0, 2.0], 'array', PandasUDFType.GROUPED_AGG) - result1 = df.withColumn('v2', array_udf(df['v']).over(w)) - self.assertEquals(result1.first()['v2'], [1.0, 2.0]) - - def test_invalid_args(self): - from pyspark.sql.functions import mean, pandas_udf, PandasUDFType - - df = self.data - w = self.unbounded_window - ow = self.ordered_window - mean_udf = self.pandas_agg_mean_udf - - with QuietTest(self.sc): - with self.assertRaisesRegexp( - AnalysisException, - '.*not supported within a window function'): - foo_udf = pandas_udf(lambda x: x, 'v double', PandasUDFType.GROUPED_MAP) - df.withColumn('v2', foo_udf(df['v']).over(w)) - - with QuietTest(self.sc): - with self.assertRaisesRegexp( - AnalysisException, - '.*Only unbounded window frame is supported.*'): - df.withColumn('mean_v', mean_udf(df['v']).over(ow)) - - -if __name__ == "__main__": - from pyspark.sql.tests import * - - runner = unishark.BufferedTestRunner( - reporters=[unishark.XUnitReporter('target/test-reports/pyspark.sql_{}'.format( - os.path.basename(os.environ.get("PYSPARK_PYTHON", ""))))]) - unittest.main(testRunner=runner, verbosity=2) diff --git a/python/pyspark/sql/tests/__init__.py b/python/pyspark/sql/tests/__init__.py new file mode 100644 index 0000000000000..cce3acad34a49 --- /dev/null +++ b/python/pyspark/sql/tests/__init__.py @@ -0,0 +1,16 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
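The grouped-aggregate and window pandas UDF tests relocated above all reduce to one pattern: a `PandasUDFType.GROUPED_AGG` function receives a pandas Series per group (or per unbounded window frame) and returns a single scalar. A minimal sketch of that pattern, assuming a local SparkSession with pandas and pyarrow installed (the `local[2]` master, column names, and data are illustrative, not part of this patch):

```python
from pyspark.sql import SparkSession, Window
from pyspark.sql.functions import pandas_udf, PandasUDFType, col

spark = SparkSession.builder.master("local[2]").getOrCreate()
df = (spark.range(6)
      .withColumn("v", col("id") * 1.0)   # a double-valued column to aggregate
      .withColumn("g", col("id") % 2))    # a grouping key

@pandas_udf("double", PandasUDFType.GROUPED_AGG)
def mean_udf(v):
    # Receives a pandas.Series per group (or per window frame) and returns one scalar.
    return v.mean()

# Group-aggregate usage, as in the groupby tests above.
df.groupby("g").agg(mean_udf(df["v"]).alias("mean_v")).show()

# Window usage: only unbounded frames are accepted for group-aggregate pandas UDFs,
# which is what the window test_invalid_args above asserts.
w = Window.partitionBy("g").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
df.withColumn("mean_v", mean_udf(df["v"]).over(w)).show()
```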
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/python/pyspark/sql/tests/test_appsubmit.py b/python/pyspark/sql/tests/test_appsubmit.py new file mode 100644 index 0000000000000..43abcde7785d8 --- /dev/null +++ b/python/pyspark/sql/tests/test_appsubmit.py @@ -0,0 +1,97 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import subprocess +import tempfile + +import py4j + +from pyspark import SparkContext +from pyspark.tests.test_appsubmit import SparkSubmitTests + + +class HiveSparkSubmitTests(SparkSubmitTests): + + @classmethod + def setUpClass(cls): + # get a SparkContext to check for availability of Hive + sc = SparkContext('local[4]', cls.__name__) + cls.hive_available = True + try: + sc._jvm.org.apache.hadoop.hive.conf.HiveConf() + except py4j.protocol.Py4JError: + cls.hive_available = False + except TypeError: + cls.hive_available = False + finally: + # we don't need this SparkContext for the test + sc.stop() + + def setUp(self): + super(HiveSparkSubmitTests, self).setUp() + if not self.hive_available: + self.skipTest("Hive is not available.") + + def test_hivecontext(self): + # This test checks that HiveContext is using Hive metastore (SPARK-16224). + # It sets a metastore url and checks if there is a derby dir created by + # Hive metastore. If this derby dir exists, HiveContext is using + # Hive metastore. 
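The `setUpClass` hook above decides whether to skip the Hive-backed submit tests by probing the JVM for `HiveConf`. The same probe can be written as a short standalone check; this is a sketch only, assuming a plain local SparkContext:

```python
import py4j
from pyspark import SparkContext

sc = SparkContext("local[2]", "hive-availability-probe")
try:
    # Same check as the test: constructing HiveConf fails if Hive classes are absent.
    sc._jvm.org.apache.hadoop.hive.conf.HiveConf()
    hive_available = True
except (py4j.protocol.Py4JError, TypeError):
    hive_available = False
finally:
    sc.stop()

print("Hive available:", hive_available)
```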
+ metastore_path = os.path.join(tempfile.mkdtemp(), "spark16224_metastore_db") + metastore_URL = "jdbc:derby:;databaseName=" + metastore_path + ";create=true" + hive_site_dir = os.path.join(self.programDir, "conf") + hive_site_file = self.createTempFile("hive-site.xml", (""" + | + | + | javax.jdo.option.ConnectionURL + | %s + | + | + """ % metastore_URL).lstrip(), "conf") + script = self.createTempFile("test.py", """ + |import os + | + |from pyspark.conf import SparkConf + |from pyspark.context import SparkContext + |from pyspark.sql import HiveContext + | + |conf = SparkConf() + |sc = SparkContext(conf=conf) + |hive_context = HiveContext(sc) + |print(hive_context.sql("show databases").collect()) + """) + proc = subprocess.Popen( + self.sparkSubmit + ["--master", "local-cluster[1,1,1024]", + "--driver-class-path", hive_site_dir, script], + stdout=subprocess.PIPE) + out, err = proc.communicate() + self.assertEqual(0, proc.returncode) + self.assertIn("default", out.decode('utf-8')) + self.assertTrue(os.path.exists(metastore_path)) + + +if __name__ == "__main__": + import unittest + from pyspark.sql.tests.test_appsubmit import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py new file mode 100644 index 0000000000000..751c2eddf5cb6 --- /dev/null +++ b/python/pyspark/sql/tests/test_arrow.py @@ -0,0 +1,400 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import datetime +import os +import threading +import time +import unittest +import warnings + +from pyspark.sql import Row +from pyspark.sql.types import * +from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \ + pandas_requirement_message, pyarrow_requirement_message +from pyspark.testing.utils import QuietTest +from pyspark.util import _exception_message + + +@unittest.skipIf( + not have_pandas or not have_pyarrow, + pandas_requirement_message or pyarrow_requirement_message) +class ArrowTests(ReusedSQLTestCase): + + @classmethod + def setUpClass(cls): + from datetime import date, datetime + from decimal import Decimal + from distutils.version import LooseVersion + import pyarrow as pa + super(ArrowTests, cls).setUpClass() + cls.warnings_lock = threading.Lock() + + # Synchronize default timezone between Python and Java + cls.tz_prev = os.environ.get("TZ", None) # save current tz if set + tz = "UTC" + os.environ["TZ"] = tz + time.tzset() + + cls.spark.conf.set("spark.sql.session.timeZone", tz) + cls.spark.conf.set("spark.sql.execution.arrow.enabled", "true") + # Disable fallback by default to easily detect the failures. 
+ cls.spark.conf.set("spark.sql.execution.arrow.fallback.enabled", "false") + cls.schema = StructType([ + StructField("1_str_t", StringType(), True), + StructField("2_int_t", IntegerType(), True), + StructField("3_long_t", LongType(), True), + StructField("4_float_t", FloatType(), True), + StructField("5_double_t", DoubleType(), True), + StructField("6_decimal_t", DecimalType(38, 18), True), + StructField("7_date_t", DateType(), True), + StructField("8_timestamp_t", TimestampType(), True)]) + cls.data = [(u"a", 1, 10, 0.2, 2.0, Decimal("2.0"), + date(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)), + (u"b", 2, 20, 0.4, 4.0, Decimal("4.0"), + date(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)), + (u"c", 3, 30, 0.8, 6.0, Decimal("6.0"), + date(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))] + + # TODO: remove version check once minimum pyarrow version is 0.10.0 + if LooseVersion("0.10.0") <= LooseVersion(pa.__version__): + cls.schema.add(StructField("9_binary_t", BinaryType(), True)) + cls.data[0] = cls.data[0] + (bytearray(b"a"),) + cls.data[1] = cls.data[1] + (bytearray(b"bb"),) + cls.data[2] = cls.data[2] + (bytearray(b"ccc"),) + + @classmethod + def tearDownClass(cls): + del os.environ["TZ"] + if cls.tz_prev is not None: + os.environ["TZ"] = cls.tz_prev + time.tzset() + super(ArrowTests, cls).tearDownClass() + + def create_pandas_data_frame(self): + import pandas as pd + import numpy as np + data_dict = {} + for j, name in enumerate(self.schema.names): + data_dict[name] = [self.data[i][j] for i in range(len(self.data))] + # need to convert these to numpy types first + data_dict["2_int_t"] = np.int32(data_dict["2_int_t"]) + data_dict["4_float_t"] = np.float32(data_dict["4_float_t"]) + return pd.DataFrame(data=data_dict) + + def test_toPandas_fallback_enabled(self): + import pandas as pd + + with self.sql_conf({"spark.sql.execution.arrow.fallback.enabled": True}): + schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)]) + df = self.spark.createDataFrame([({u'a': 1},)], schema=schema) + with QuietTest(self.sc): + with self.warnings_lock: + with warnings.catch_warnings(record=True) as warns: + # we want the warnings to appear even if this test is run from a subclass + warnings.simplefilter("always") + pdf = df.toPandas() + # Catch and check the last UserWarning. 
+ user_warns = [ + warn.message for warn in warns if isinstance(warn.message, UserWarning)] + self.assertTrue(len(user_warns) > 0) + self.assertTrue( + "Attempting non-optimization" in _exception_message(user_warns[-1])) + self.assertPandasEqual(pdf, pd.DataFrame({u'map': [{u'a': 1}]})) + + def test_toPandas_fallback_disabled(self): + from distutils.version import LooseVersion + import pyarrow as pa + + schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)]) + df = self.spark.createDataFrame([(None,)], schema=schema) + with QuietTest(self.sc): + with self.warnings_lock: + with self.assertRaisesRegexp(Exception, 'Unsupported type'): + df.toPandas() + + # TODO: remove BinaryType check once minimum pyarrow version is 0.10.0 + if LooseVersion(pa.__version__) < LooseVersion("0.10.0"): + schema = StructType([StructField("binary", BinaryType(), True)]) + df = self.spark.createDataFrame([(None,)], schema=schema) + with QuietTest(self.sc): + with self.assertRaisesRegexp(Exception, 'Unsupported type.*BinaryType'): + df.toPandas() + + def test_null_conversion(self): + df_null = self.spark.createDataFrame([tuple([None for _ in range(len(self.data[0]))])] + + self.data) + pdf = df_null.toPandas() + null_counts = pdf.isnull().sum().tolist() + self.assertTrue(all([c == 1 for c in null_counts])) + + def _toPandas_arrow_toggle(self, df): + with self.sql_conf({"spark.sql.execution.arrow.enabled": False}): + pdf = df.toPandas() + + pdf_arrow = df.toPandas() + + return pdf, pdf_arrow + + def test_toPandas_arrow_toggle(self): + df = self.spark.createDataFrame(self.data, schema=self.schema) + pdf, pdf_arrow = self._toPandas_arrow_toggle(df) + expected = self.create_pandas_data_frame() + self.assertPandasEqual(expected, pdf) + self.assertPandasEqual(expected, pdf_arrow) + + def test_toPandas_respect_session_timezone(self): + df = self.spark.createDataFrame(self.data, schema=self.schema) + + timezone = "America/New_York" + with self.sql_conf({ + "spark.sql.execution.pandas.respectSessionTimeZone": False, + "spark.sql.session.timeZone": timezone}): + pdf_la, pdf_arrow_la = self._toPandas_arrow_toggle(df) + self.assertPandasEqual(pdf_arrow_la, pdf_la) + + with self.sql_conf({ + "spark.sql.execution.pandas.respectSessionTimeZone": True, + "spark.sql.session.timeZone": timezone}): + pdf_ny, pdf_arrow_ny = self._toPandas_arrow_toggle(df) + self.assertPandasEqual(pdf_arrow_ny, pdf_ny) + + self.assertFalse(pdf_ny.equals(pdf_la)) + + from pyspark.sql.types import _check_series_convert_timestamps_local_tz + pdf_la_corrected = pdf_la.copy() + for field in self.schema: + if isinstance(field.dataType, TimestampType): + pdf_la_corrected[field.name] = _check_series_convert_timestamps_local_tz( + pdf_la_corrected[field.name], timezone) + self.assertPandasEqual(pdf_ny, pdf_la_corrected) + + def test_pandas_round_trip(self): + pdf = self.create_pandas_data_frame() + df = self.spark.createDataFrame(self.data, schema=self.schema) + pdf_arrow = df.toPandas() + self.assertPandasEqual(pdf_arrow, pdf) + + def test_filtered_frame(self): + df = self.spark.range(3).toDF("i") + pdf = df.filter("i < 0").toPandas() + self.assertEqual(len(pdf.columns), 1) + self.assertEqual(pdf.columns[0], "i") + self.assertTrue(pdf.empty) + + def _createDataFrame_toggle(self, pdf, schema=None): + with self.sql_conf({"spark.sql.execution.arrow.enabled": False}): + df_no_arrow = self.spark.createDataFrame(pdf, schema=schema) + + df_arrow = self.spark.createDataFrame(pdf, schema=schema) + + return df_no_arrow, df_arrow + + def 
test_createDataFrame_toggle(self): + pdf = self.create_pandas_data_frame() + df_no_arrow, df_arrow = self._createDataFrame_toggle(pdf, schema=self.schema) + self.assertEquals(df_no_arrow.collect(), df_arrow.collect()) + + def test_createDataFrame_respect_session_timezone(self): + from datetime import timedelta + pdf = self.create_pandas_data_frame() + timezone = "America/New_York" + with self.sql_conf({ + "spark.sql.execution.pandas.respectSessionTimeZone": False, + "spark.sql.session.timeZone": timezone}): + df_no_arrow_la, df_arrow_la = self._createDataFrame_toggle(pdf, schema=self.schema) + result_la = df_no_arrow_la.collect() + result_arrow_la = df_arrow_la.collect() + self.assertEqual(result_la, result_arrow_la) + + with self.sql_conf({ + "spark.sql.execution.pandas.respectSessionTimeZone": True, + "spark.sql.session.timeZone": timezone}): + df_no_arrow_ny, df_arrow_ny = self._createDataFrame_toggle(pdf, schema=self.schema) + result_ny = df_no_arrow_ny.collect() + result_arrow_ny = df_arrow_ny.collect() + self.assertEqual(result_ny, result_arrow_ny) + + self.assertNotEqual(result_ny, result_la) + + # Correct result_la by adjusting 3 hours difference between Los Angeles and New York + result_la_corrected = [Row(**{k: v + timedelta(hours=5) if k == '8_timestamp_t' else v + for k, v in row.asDict().items()}) + for row in result_la] + self.assertEqual(result_ny, result_la_corrected) + + def test_createDataFrame_with_schema(self): + pdf = self.create_pandas_data_frame() + df = self.spark.createDataFrame(pdf, schema=self.schema) + self.assertEquals(self.schema, df.schema) + pdf_arrow = df.toPandas() + self.assertPandasEqual(pdf_arrow, pdf) + + def test_createDataFrame_with_incorrect_schema(self): + pdf = self.create_pandas_data_frame() + fields = list(self.schema) + fields[0], fields[7] = fields[7], fields[0] # swap str with timestamp + wrong_schema = StructType(fields) + with QuietTest(self.sc): + with self.assertRaisesRegexp(Exception, ".*No cast.*string.*timestamp.*"): + self.spark.createDataFrame(pdf, schema=wrong_schema) + + def test_createDataFrame_with_names(self): + pdf = self.create_pandas_data_frame() + new_names = list(map(str, range(len(self.schema.fieldNames())))) + # Test that schema as a list of column names gets applied + df = self.spark.createDataFrame(pdf, schema=list(new_names)) + self.assertEquals(df.schema.fieldNames(), new_names) + # Test that schema as tuple of column names gets applied + df = self.spark.createDataFrame(pdf, schema=tuple(new_names)) + self.assertEquals(df.schema.fieldNames(), new_names) + + def test_createDataFrame_column_name_encoding(self): + import pandas as pd + pdf = pd.DataFrame({u'a': [1]}) + columns = self.spark.createDataFrame(pdf).columns + self.assertTrue(isinstance(columns[0], str)) + self.assertEquals(columns[0], 'a') + columns = self.spark.createDataFrame(pdf, [u'b']).columns + self.assertTrue(isinstance(columns[0], str)) + self.assertEquals(columns[0], 'b') + + def test_createDataFrame_with_single_data_type(self): + import pandas as pd + with QuietTest(self.sc): + with self.assertRaisesRegexp(ValueError, ".*IntegerType.*not supported.*"): + self.spark.createDataFrame(pd.DataFrame({"a": [1]}), schema="int") + + def test_createDataFrame_does_not_modify_input(self): + import pandas as pd + # Some series get converted for Spark to consume, this makes sure input is unchanged + pdf = self.create_pandas_data_frame() + # Use a nanosecond value to make sure it is not truncated + pdf.ix[0, '8_timestamp_t'] = pd.Timestamp(1) + # Integers with 
nulls will get NaNs filled with 0 and will be casted + pdf.ix[1, '2_int_t'] = None + pdf_copy = pdf.copy(deep=True) + self.spark.createDataFrame(pdf, schema=self.schema) + self.assertTrue(pdf.equals(pdf_copy)) + + def test_schema_conversion_roundtrip(self): + from pyspark.sql.types import from_arrow_schema, to_arrow_schema + arrow_schema = to_arrow_schema(self.schema) + schema_rt = from_arrow_schema(arrow_schema) + self.assertEquals(self.schema, schema_rt) + + def test_createDataFrame_with_array_type(self): + import pandas as pd + pdf = pd.DataFrame({"a": [[1, 2], [3, 4]], "b": [[u"x", u"y"], [u"y", u"z"]]}) + df, df_arrow = self._createDataFrame_toggle(pdf) + result = df.collect() + result_arrow = df_arrow.collect() + expected = [tuple(list(e) for e in rec) for rec in pdf.to_records(index=False)] + for r in range(len(expected)): + for e in range(len(expected[r])): + self.assertTrue(expected[r][e] == result_arrow[r][e] and + result[r][e] == result_arrow[r][e]) + + def test_toPandas_with_array_type(self): + expected = [([1, 2], [u"x", u"y"]), ([3, 4], [u"y", u"z"])] + array_schema = StructType([StructField("a", ArrayType(IntegerType())), + StructField("b", ArrayType(StringType()))]) + df = self.spark.createDataFrame(expected, schema=array_schema) + pdf, pdf_arrow = self._toPandas_arrow_toggle(df) + result = [tuple(list(e) for e in rec) for rec in pdf.to_records(index=False)] + result_arrow = [tuple(list(e) for e in rec) for rec in pdf_arrow.to_records(index=False)] + for r in range(len(expected)): + for e in range(len(expected[r])): + self.assertTrue(expected[r][e] == result_arrow[r][e] and + result[r][e] == result_arrow[r][e]) + + def test_createDataFrame_with_int_col_names(self): + import numpy as np + import pandas as pd + pdf = pd.DataFrame(np.random.rand(4, 2)) + df, df_arrow = self._createDataFrame_toggle(pdf) + pdf_col_names = [str(c) for c in pdf.columns] + self.assertEqual(pdf_col_names, df.columns) + self.assertEqual(pdf_col_names, df_arrow.columns) + + def test_createDataFrame_fallback_enabled(self): + import pandas as pd + + with QuietTest(self.sc): + with self.sql_conf({"spark.sql.execution.arrow.fallback.enabled": True}): + with warnings.catch_warnings(record=True) as warns: + # we want the warnings to appear even if this test is run from a subclass + warnings.simplefilter("always") + df = self.spark.createDataFrame( + pd.DataFrame([[{u'a': 1}]]), "a: map") + # Catch and check the last UserWarning. 
+ user_warns = [ + warn.message for warn in warns if isinstance(warn.message, UserWarning)] + self.assertTrue(len(user_warns) > 0) + self.assertTrue( + "Attempting non-optimization" in _exception_message(user_warns[-1])) + self.assertEqual(df.collect(), [Row(a={u'a': 1})]) + + def test_createDataFrame_fallback_disabled(self): + from distutils.version import LooseVersion + import pandas as pd + import pyarrow as pa + + with QuietTest(self.sc): + with self.assertRaisesRegexp(TypeError, 'Unsupported type'): + self.spark.createDataFrame( + pd.DataFrame([[{u'a': 1}]]), "a: map") + + # TODO: remove BinaryType check once minimum pyarrow version is 0.10.0 + if LooseVersion(pa.__version__) < LooseVersion("0.10.0"): + with QuietTest(self.sc): + with self.assertRaisesRegexp(TypeError, 'Unsupported type.*BinaryType'): + self.spark.createDataFrame( + pd.DataFrame([[{'a': b'aaa'}]]), "a: binary") + + # Regression test for SPARK-23314 + def test_timestamp_dst(self): + import pandas as pd + # Daylight saving time for Los Angeles for 2015 is Sun, Nov 1 at 2:00 am + dt = [datetime.datetime(2015, 11, 1, 0, 30), + datetime.datetime(2015, 11, 1, 1, 30), + datetime.datetime(2015, 11, 1, 2, 30)] + pdf = pd.DataFrame({'time': dt}) + + df_from_python = self.spark.createDataFrame(dt, 'timestamp').toDF('time') + df_from_pandas = self.spark.createDataFrame(pdf) + + self.assertPandasEqual(pdf, df_from_python.toPandas()) + self.assertPandasEqual(pdf, df_from_pandas.toPandas()) + + +class EncryptionArrowTests(ArrowTests): + + @classmethod + def conf(cls): + return super(EncryptionArrowTests, cls).conf().set("spark.io.encryption.enabled", "true") + + +if __name__ == "__main__": + from pyspark.sql.tests.test_arrow import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_catalog.py b/python/pyspark/sql/tests/test_catalog.py new file mode 100644 index 0000000000000..873405a2c6aa3 --- /dev/null +++ b/python/pyspark/sql/tests/test_catalog.py @@ -0,0 +1,200 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
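The Arrow tests above repeatedly toggle `spark.sql.execution.arrow.enabled` and compare the two `toPandas()` code paths. A minimal sketch of that toggle, assuming pandas and pyarrow are installed and an illustrative local session; for the simple long/double columns used here the two conversions should agree:

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[2]").getOrCreate()
df = spark.range(5).selectExpr("id", "id * 1.5 AS v")

spark.conf.set("spark.sql.execution.arrow.enabled", "false")
pdf_plain = df.toPandas()       # row-by-row conversion through the non-Arrow path

spark.conf.set("spark.sql.execution.arrow.enabled", "true")
spark.conf.set("spark.sql.execution.arrow.fallback.enabled", "false")  # fail loudly instead of falling back
pdf_arrow = df.toPandas()       # Arrow-based conversion

assert pdf_plain.equals(pdf_arrow)
```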
+# + +from pyspark.sql.utils import AnalysisException +from pyspark.testing.sqlutils import ReusedSQLTestCase + + +class CatalogTests(ReusedSQLTestCase): + + def test_current_database(self): + spark = self.spark + with self.database("some_db"): + self.assertEquals(spark.catalog.currentDatabase(), "default") + spark.sql("CREATE DATABASE some_db") + spark.catalog.setCurrentDatabase("some_db") + self.assertEquals(spark.catalog.currentDatabase(), "some_db") + self.assertRaisesRegexp( + AnalysisException, + "does_not_exist", + lambda: spark.catalog.setCurrentDatabase("does_not_exist")) + + def test_list_databases(self): + spark = self.spark + with self.database("some_db"): + databases = [db.name for db in spark.catalog.listDatabases()] + self.assertEquals(databases, ["default"]) + spark.sql("CREATE DATABASE some_db") + databases = [db.name for db in spark.catalog.listDatabases()] + self.assertEquals(sorted(databases), ["default", "some_db"]) + + def test_list_tables(self): + from pyspark.sql.catalog import Table + spark = self.spark + with self.database("some_db"): + spark.sql("CREATE DATABASE some_db") + with self.table("tab1", "some_db.tab2"): + with self.tempView("temp_tab"): + self.assertEquals(spark.catalog.listTables(), []) + self.assertEquals(spark.catalog.listTables("some_db"), []) + spark.createDataFrame([(1, 1)]).createOrReplaceTempView("temp_tab") + spark.sql("CREATE TABLE tab1 (name STRING, age INT) USING parquet") + spark.sql("CREATE TABLE some_db.tab2 (name STRING, age INT) USING parquet") + tables = sorted(spark.catalog.listTables(), key=lambda t: t.name) + tablesDefault = \ + sorted(spark.catalog.listTables("default"), key=lambda t: t.name) + tablesSomeDb = \ + sorted(spark.catalog.listTables("some_db"), key=lambda t: t.name) + self.assertEquals(tables, tablesDefault) + self.assertEquals(len(tables), 2) + self.assertEquals(len(tablesSomeDb), 2) + self.assertEquals(tables[0], Table( + name="tab1", + database="default", + description=None, + tableType="MANAGED", + isTemporary=False)) + self.assertEquals(tables[1], Table( + name="temp_tab", + database=None, + description=None, + tableType="TEMPORARY", + isTemporary=True)) + self.assertEquals(tablesSomeDb[0], Table( + name="tab2", + database="some_db", + description=None, + tableType="MANAGED", + isTemporary=False)) + self.assertEquals(tablesSomeDb[1], Table( + name="temp_tab", + database=None, + description=None, + tableType="TEMPORARY", + isTemporary=True)) + self.assertRaisesRegexp( + AnalysisException, + "does_not_exist", + lambda: spark.catalog.listTables("does_not_exist")) + + def test_list_functions(self): + from pyspark.sql.catalog import Function + spark = self.spark + with self.database("some_db"): + spark.sql("CREATE DATABASE some_db") + functions = dict((f.name, f) for f in spark.catalog.listFunctions()) + functionsDefault = dict((f.name, f) for f in spark.catalog.listFunctions("default")) + self.assertTrue(len(functions) > 200) + self.assertTrue("+" in functions) + self.assertTrue("like" in functions) + self.assertTrue("month" in functions) + self.assertTrue("to_date" in functions) + self.assertTrue("to_timestamp" in functions) + self.assertTrue("to_unix_timestamp" in functions) + self.assertTrue("current_database" in functions) + self.assertEquals(functions["+"], Function( + name="+", + description=None, + className="org.apache.spark.sql.catalyst.expressions.Add", + isTemporary=True)) + self.assertEquals(functions, functionsDefault) + + with self.function("func1", "some_db.func2"): + 
spark.catalog.registerFunction("temp_func", lambda x: str(x)) + spark.sql("CREATE FUNCTION func1 AS 'org.apache.spark.data.bricks'") + spark.sql("CREATE FUNCTION some_db.func2 AS 'org.apache.spark.data.bricks'") + newFunctions = dict((f.name, f) for f in spark.catalog.listFunctions()) + newFunctionsSomeDb = \ + dict((f.name, f) for f in spark.catalog.listFunctions("some_db")) + self.assertTrue(set(functions).issubset(set(newFunctions))) + self.assertTrue(set(functions).issubset(set(newFunctionsSomeDb))) + self.assertTrue("temp_func" in newFunctions) + self.assertTrue("func1" in newFunctions) + self.assertTrue("func2" not in newFunctions) + self.assertTrue("temp_func" in newFunctionsSomeDb) + self.assertTrue("func1" not in newFunctionsSomeDb) + self.assertTrue("func2" in newFunctionsSomeDb) + self.assertRaisesRegexp( + AnalysisException, + "does_not_exist", + lambda: spark.catalog.listFunctions("does_not_exist")) + + def test_list_columns(self): + from pyspark.sql.catalog import Column + spark = self.spark + with self.database("some_db"): + spark.sql("CREATE DATABASE some_db") + with self.table("tab1", "some_db.tab2"): + spark.sql("CREATE TABLE tab1 (name STRING, age INT) USING parquet") + spark.sql( + "CREATE TABLE some_db.tab2 (nickname STRING, tolerance FLOAT) USING parquet") + columns = sorted(spark.catalog.listColumns("tab1"), key=lambda c: c.name) + columnsDefault = \ + sorted(spark.catalog.listColumns("tab1", "default"), key=lambda c: c.name) + self.assertEquals(columns, columnsDefault) + self.assertEquals(len(columns), 2) + self.assertEquals(columns[0], Column( + name="age", + description=None, + dataType="int", + nullable=True, + isPartition=False, + isBucket=False)) + self.assertEquals(columns[1], Column( + name="name", + description=None, + dataType="string", + nullable=True, + isPartition=False, + isBucket=False)) + columns2 = \ + sorted(spark.catalog.listColumns("tab2", "some_db"), key=lambda c: c.name) + self.assertEquals(len(columns2), 2) + self.assertEquals(columns2[0], Column( + name="nickname", + description=None, + dataType="string", + nullable=True, + isPartition=False, + isBucket=False)) + self.assertEquals(columns2[1], Column( + name="tolerance", + description=None, + dataType="float", + nullable=True, + isPartition=False, + isBucket=False)) + self.assertRaisesRegexp( + AnalysisException, + "tab2", + lambda: spark.catalog.listColumns("tab2")) + self.assertRaisesRegexp( + AnalysisException, + "does_not_exist", + lambda: spark.catalog.listColumns("does_not_exist")) + + +if __name__ == "__main__": + import unittest + from pyspark.sql.tests.test_catalog import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_column.py b/python/pyspark/sql/tests/test_column.py new file mode 100644 index 0000000000000..01d4f7e223a41 --- /dev/null +++ b/python/pyspark/sql/tests/test_column.py @@ -0,0 +1,158 @@ +# -*- encoding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
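The catalog assertions above map directly onto the public `spark.catalog` API. A hedged sketch, assuming a fresh local session with the default in-memory catalog (database and view names below are illustrative):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[2]").getOrCreate()

print(spark.catalog.currentDatabase())                     # 'default' on a fresh session
print([db.name for db in spark.catalog.listDatabases()])

spark.range(3).createOrReplaceTempView("temp_tab")
for table in spark.catalog.listTables():                   # temporary views are listed too
    print(table.name, table.tableType, table.isTemporary)

functions = {f.name for f in spark.catalog.listFunctions()}
print(len(functions) > 200, "+" in functions, "to_date" in functions)
```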
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys + +from pyspark.sql import Column, Row +from pyspark.sql.types import * +from pyspark.sql.utils import AnalysisException +from pyspark.testing.sqlutils import ReusedSQLTestCase + + +class ColumnTests(ReusedSQLTestCase): + + def test_column_name_encoding(self): + """Ensure that created columns has `str` type consistently.""" + columns = self.spark.createDataFrame([('Alice', 1)], ['name', u'age']).columns + self.assertEqual(columns, ['name', 'age']) + self.assertTrue(isinstance(columns[0], str)) + self.assertTrue(isinstance(columns[1], str)) + + def test_and_in_expression(self): + self.assertEqual(4, self.df.filter((self.df.key <= 10) & (self.df.value <= "2")).count()) + self.assertRaises(ValueError, lambda: (self.df.key <= 10) and (self.df.value <= "2")) + self.assertEqual(14, self.df.filter((self.df.key <= 3) | (self.df.value < "2")).count()) + self.assertRaises(ValueError, lambda: self.df.key <= 3 or self.df.value < "2") + self.assertEqual(99, self.df.filter(~(self.df.key == 1)).count()) + self.assertRaises(ValueError, lambda: not self.df.key == 1) + + def test_validate_column_types(self): + from pyspark.sql.functions import udf, to_json + from pyspark.sql.column import _to_java_column + + self.assertTrue("Column" in _to_java_column("a").getClass().toString()) + self.assertTrue("Column" in _to_java_column(u"a").getClass().toString()) + self.assertTrue("Column" in _to_java_column(self.spark.range(1).id).getClass().toString()) + + self.assertRaisesRegexp( + TypeError, + "Invalid argument, not a string or column", + lambda: _to_java_column(1)) + + class A(): + pass + + self.assertRaises(TypeError, lambda: _to_java_column(A())) + self.assertRaises(TypeError, lambda: _to_java_column([])) + + self.assertRaisesRegexp( + TypeError, + "Invalid argument, not a string or column", + lambda: udf(lambda x: x)(None)) + self.assertRaises(TypeError, lambda: to_json(1)) + + def test_column_operators(self): + ci = self.df.key + cs = self.df.value + c = ci == cs + self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column)) + rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci), (1 ** ci), (ci ** 1) + self.assertTrue(all(isinstance(c, Column) for c in rcc)) + cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7] + self.assertTrue(all(isinstance(c, Column) for c in cb)) + cbool = (ci & ci), (ci | ci), (~ci) + self.assertTrue(all(isinstance(c, Column) for c in cbool)) + css = cs.contains('a'), cs.like('a'), cs.rlike('a'), cs.asc(), cs.desc(),\ + cs.startswith('a'), cs.endswith('a'), ci.eqNullSafe(cs) + self.assertTrue(all(isinstance(c, Column) for c in css)) + self.assertTrue(isinstance(ci.cast(LongType()), Column)) + self.assertRaisesRegexp(ValueError, + "Cannot apply 'in' operator against a column", + lambda: 1 in cs) + + def test_column_getitem(self): + from pyspark.sql.functions import col + + self.assertIsInstance(col("foo")[1:3], Column) + self.assertIsInstance(col("foo")[0], Column) + self.assertIsInstance(col("foo")["bar"], Column) + self.assertRaises(ValueError, lambda: col("foo")[0:10:2]) + + def test_column_select(self): + df = self.df + 
self.assertEqual(self.testData, df.select("*").collect()) + self.assertEqual(self.testData, df.select(df.key, df.value).collect()) + self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect()) + + def test_access_column(self): + df = self.df + self.assertTrue(isinstance(df.key, Column)) + self.assertTrue(isinstance(df['key'], Column)) + self.assertTrue(isinstance(df[0], Column)) + self.assertRaises(IndexError, lambda: df[2]) + self.assertRaises(AnalysisException, lambda: df["bad_key"]) + self.assertRaises(TypeError, lambda: df[{}]) + + def test_column_name_with_non_ascii(self): + if sys.version >= '3': + columnName = "数量" + self.assertTrue(isinstance(columnName, str)) + else: + columnName = unicode("数量", "utf-8") + self.assertTrue(isinstance(columnName, unicode)) + schema = StructType([StructField(columnName, LongType(), True)]) + df = self.spark.createDataFrame([(1,)], schema) + self.assertEqual(schema, df.schema) + self.assertEqual("DataFrame[数量: bigint]", str(df)) + self.assertEqual([("数量", 'bigint')], df.dtypes) + self.assertEqual(1, df.select("数量").first()[0]) + self.assertEqual(1, df.select(df["数量"]).first()[0]) + + def test_field_accessor(self): + df = self.sc.parallelize([Row(l=[1], r=Row(a=1, b="b"), d={"k": "v"})]).toDF() + self.assertEqual(1, df.select(df.l[0]).first()[0]) + self.assertEqual(1, df.select(df.r["a"]).first()[0]) + self.assertEqual(1, df.select(df["r.a"]).first()[0]) + self.assertEqual("b", df.select(df.r["b"]).first()[0]) + self.assertEqual("b", df.select(df["r.b"]).first()[0]) + self.assertEqual("v", df.select(df.d["k"]).first()[0]) + + def test_bitwise_operations(self): + from pyspark.sql import functions + row = Row(a=170, b=75) + df = self.spark.createDataFrame([row]) + result = df.select(df.a.bitwiseAND(df.b)).collect()[0].asDict() + self.assertEqual(170 & 75, result['(a & b)']) + result = df.select(df.a.bitwiseOR(df.b)).collect()[0].asDict() + self.assertEqual(170 | 75, result['(a | b)']) + result = df.select(df.a.bitwiseXOR(df.b)).collect()[0].asDict() + self.assertEqual(170 ^ 75, result['(a ^ b)']) + result = df.select(functions.bitwiseNOT(df.b)).collect()[0].asDict() + self.assertEqual(~75, result['~b']) + + +if __name__ == "__main__": + import unittest + from pyspark.sql.tests.test_column import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_conf.py b/python/pyspark/sql/tests/test_conf.py new file mode 100644 index 0000000000000..53ac4a66f4645 --- /dev/null +++ b/python/pyspark/sql/tests/test_conf.py @@ -0,0 +1,56 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
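ColumnTests above checks that arithmetic, comparison, and bitwise operators all return `Column` objects and evaluate as expected. A compact sketch of the bitwise cases, assuming a local session (the 170/75 operands come straight from the test):

```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.master("local[2]").getOrCreate()
df = spark.createDataFrame([(170, 75)], ["a", "b"])

df.select(
    df.a.bitwiseAND(df.b).alias("a_and_b"),   # 170 & 75  == 10
    df.a.bitwiseOR(df.b).alias("a_or_b"),     # 170 | 75  == 235
    df.a.bitwiseXOR(df.b).alias("a_xor_b"),   # 170 ^ 75  == 225
    F.bitwiseNOT(df.b).alias("not_b"),        # ~75       == -76
).show()
```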
+# + +from pyspark.testing.sqlutils import ReusedSQLTestCase + + +class ConfTests(ReusedSQLTestCase): + + def test_conf(self): + spark = self.spark + spark.conf.set("bogo", "sipeo") + self.assertEqual(spark.conf.get("bogo"), "sipeo") + spark.conf.set("bogo", "ta") + self.assertEqual(spark.conf.get("bogo"), "ta") + self.assertEqual(spark.conf.get("bogo", "not.read"), "ta") + self.assertEqual(spark.conf.get("not.set", "ta"), "ta") + self.assertRaisesRegexp(Exception, "not.set", lambda: spark.conf.get("not.set")) + spark.conf.unset("bogo") + self.assertEqual(spark.conf.get("bogo", "colombia"), "colombia") + + self.assertEqual(spark.conf.get("hyukjin", None), None) + + # This returns 'STATIC' because it's the default value of + # 'spark.sql.sources.partitionOverwriteMode', and `defaultValue` in + # `spark.conf.get` is unset. + self.assertEqual(spark.conf.get("spark.sql.sources.partitionOverwriteMode"), "STATIC") + + # This returns None because 'spark.sql.sources.partitionOverwriteMode' is unset, but + # `defaultValue` in `spark.conf.get` is set to None. + self.assertEqual(spark.conf.get("spark.sql.sources.partitionOverwriteMode", None), None) + + +if __name__ == "__main__": + import unittest + from pyspark.sql.tests.test_conf import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_context.py b/python/pyspark/sql/tests/test_context.py new file mode 100644 index 0000000000000..918f4ad2d62f4 --- /dev/null +++ b/python/pyspark/sql/tests/test_context.py @@ -0,0 +1,264 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
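ConfTests above exercises the runtime configuration accessors, whose behaviour reduces to a few lines. A minimal sketch, assuming a local session (`bogo`/`sipeo` are just the placeholder keys the test uses):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[2]").getOrCreate()

spark.conf.set("bogo", "sipeo")
assert spark.conf.get("bogo") == "sipeo"
assert spark.conf.get("not.set", "fallback") == "fallback"   # default returned for unset keys
spark.conf.unset("bogo")
assert spark.conf.get("bogo", "colombia") == "colombia"      # unset key falls back to the default again
```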
+# +import os +import shutil +import sys +import tempfile +import unittest + +import py4j + +from pyspark import HiveContext, Row +from pyspark.sql.types import * +from pyspark.sql.window import Window +from pyspark.testing.utils import ReusedPySparkTestCase + + +class HiveContextSQLTests(ReusedPySparkTestCase): + + @classmethod + def setUpClass(cls): + ReusedPySparkTestCase.setUpClass() + cls.tempdir = tempfile.NamedTemporaryFile(delete=False) + cls.hive_available = True + try: + cls.sc._jvm.org.apache.hadoop.hive.conf.HiveConf() + except py4j.protocol.Py4JError: + cls.hive_available = False + except TypeError: + cls.hive_available = False + os.unlink(cls.tempdir.name) + if cls.hive_available: + cls.spark = HiveContext._createForTesting(cls.sc) + cls.testData = [Row(key=i, value=str(i)) for i in range(100)] + cls.df = cls.sc.parallelize(cls.testData).toDF() + + def setUp(self): + if not self.hive_available: + self.skipTest("Hive is not available.") + + @classmethod + def tearDownClass(cls): + ReusedPySparkTestCase.tearDownClass() + shutil.rmtree(cls.tempdir.name, ignore_errors=True) + + def test_save_and_load_table(self): + df = self.df + tmpPath = tempfile.mkdtemp() + shutil.rmtree(tmpPath) + df.write.saveAsTable("savedJsonTable", "json", "append", path=tmpPath) + actual = self.spark.createExternalTable("externalJsonTable", tmpPath, "json") + self.assertEqual(sorted(df.collect()), + sorted(self.spark.sql("SELECT * FROM savedJsonTable").collect())) + self.assertEqual(sorted(df.collect()), + sorted(self.spark.sql("SELECT * FROM externalJsonTable").collect())) + self.assertEqual(sorted(df.collect()), sorted(actual.collect())) + self.spark.sql("DROP TABLE externalJsonTable") + + df.write.saveAsTable("savedJsonTable", "json", "overwrite", path=tmpPath) + schema = StructType([StructField("value", StringType(), True)]) + actual = self.spark.createExternalTable("externalJsonTable", source="json", + schema=schema, path=tmpPath, + noUse="this options will not be used") + self.assertEqual(sorted(df.collect()), + sorted(self.spark.sql("SELECT * FROM savedJsonTable").collect())) + self.assertEqual(sorted(df.select("value").collect()), + sorted(self.spark.sql("SELECT * FROM externalJsonTable").collect())) + self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect())) + self.spark.sql("DROP TABLE savedJsonTable") + self.spark.sql("DROP TABLE externalJsonTable") + + defaultDataSourceName = self.spark.getConf("spark.sql.sources.default", + "org.apache.spark.sql.parquet") + self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json") + df.write.saveAsTable("savedJsonTable", path=tmpPath, mode="overwrite") + actual = self.spark.createExternalTable("externalJsonTable", path=tmpPath) + self.assertEqual(sorted(df.collect()), + sorted(self.spark.sql("SELECT * FROM savedJsonTable").collect())) + self.assertEqual(sorted(df.collect()), + sorted(self.spark.sql("SELECT * FROM externalJsonTable").collect())) + self.assertEqual(sorted(df.collect()), sorted(actual.collect())) + self.spark.sql("DROP TABLE savedJsonTable") + self.spark.sql("DROP TABLE externalJsonTable") + self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName) + + shutil.rmtree(tmpPath) + + def test_window_functions(self): + df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) + w = Window.partitionBy("value").orderBy("key") + from pyspark.sql import functions as F + sel = df.select(df.value, df.key, + F.max("key").over(w.rowsBetween(0, 1)), + 
F.min("key").over(w.rowsBetween(0, 1)), + F.count("key").over(w.rowsBetween(float('-inf'), float('inf'))), + F.row_number().over(w), + F.rank().over(w), + F.dense_rank().over(w), + F.ntile(2).over(w)) + rs = sorted(sel.collect()) + expected = [ + ("1", 1, 1, 1, 1, 1, 1, 1, 1), + ("2", 1, 1, 1, 3, 1, 1, 1, 1), + ("2", 1, 2, 1, 3, 2, 1, 1, 1), + ("2", 2, 2, 2, 3, 3, 3, 2, 2) + ] + for r, ex in zip(rs, expected): + self.assertEqual(tuple(r), ex[:len(r)]) + + def test_window_functions_without_partitionBy(self): + df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) + w = Window.orderBy("key", df.value) + from pyspark.sql import functions as F + sel = df.select(df.value, df.key, + F.max("key").over(w.rowsBetween(0, 1)), + F.min("key").over(w.rowsBetween(0, 1)), + F.count("key").over(w.rowsBetween(float('-inf'), float('inf'))), + F.row_number().over(w), + F.rank().over(w), + F.dense_rank().over(w), + F.ntile(2).over(w)) + rs = sorted(sel.collect()) + expected = [ + ("1", 1, 1, 1, 4, 1, 1, 1, 1), + ("2", 1, 1, 1, 4, 2, 2, 2, 1), + ("2", 1, 2, 1, 4, 3, 2, 2, 2), + ("2", 2, 2, 2, 4, 4, 4, 3, 2) + ] + for r, ex in zip(rs, expected): + self.assertEqual(tuple(r), ex[:len(r)]) + + def test_window_functions_cumulative_sum(self): + df = self.spark.createDataFrame([("one", 1), ("two", 2)], ["key", "value"]) + from pyspark.sql import functions as F + + # Test cumulative sum + sel = df.select( + df.key, + F.sum(df.value).over(Window.rowsBetween(Window.unboundedPreceding, 0))) + rs = sorted(sel.collect()) + expected = [("one", 1), ("two", 3)] + for r, ex in zip(rs, expected): + self.assertEqual(tuple(r), ex[:len(r)]) + + # Test boundary values less than JVM's Long.MinValue and make sure we don't overflow + sel = df.select( + df.key, + F.sum(df.value).over(Window.rowsBetween(Window.unboundedPreceding - 1, 0))) + rs = sorted(sel.collect()) + expected = [("one", 1), ("two", 3)] + for r, ex in zip(rs, expected): + self.assertEqual(tuple(r), ex[:len(r)]) + + # Test boundary values greater than JVM's Long.MaxValue and make sure we don't overflow + frame_end = Window.unboundedFollowing + 1 + sel = df.select( + df.key, + F.sum(df.value).over(Window.rowsBetween(Window.currentRow, frame_end))) + rs = sorted(sel.collect()) + expected = [("one", 3), ("two", 2)] + for r, ex in zip(rs, expected): + self.assertEqual(tuple(r), ex[:len(r)]) + + def test_collect_functions(self): + df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) + from pyspark.sql import functions + + self.assertEqual( + sorted(df.select(functions.collect_set(df.key).alias('r')).collect()[0].r), + [1, 2]) + self.assertEqual( + sorted(df.select(functions.collect_list(df.key).alias('r')).collect()[0].r), + [1, 1, 1, 2]) + self.assertEqual( + sorted(df.select(functions.collect_set(df.value).alias('r')).collect()[0].r), + ["1", "2"]) + self.assertEqual( + sorted(df.select(functions.collect_list(df.value).alias('r')).collect()[0].r), + ["1", "2", "2", "2"]) + + def test_limit_and_take(self): + df = self.spark.range(1, 1000, numPartitions=10) + + def assert_runs_only_one_job_stage_and_task(job_group_name, f): + tracker = self.sc.statusTracker() + self.sc.setJobGroup(job_group_name, description="") + f() + jobs = tracker.getJobIdsForGroup(job_group_name) + self.assertEqual(1, len(jobs)) + stages = tracker.getJobInfo(jobs[0]).stageIds + self.assertEqual(1, len(stages)) + self.assertEqual(1, tracker.getStageInfo(stages[0]).numTasks) + + # Regression test for SPARK-10731: take should 
delegate to Scala implementation + assert_runs_only_one_job_stage_and_task("take", lambda: df.take(1)) + # Regression test for SPARK-17514: limit(n).collect() should the perform same as take(n) + assert_runs_only_one_job_stage_and_task("collect_limit", lambda: df.limit(1).collect()) + + def test_datetime_functions(self): + from pyspark.sql import functions + from datetime import date + df = self.spark.range(1).selectExpr("'2017-01-22' as dateCol") + parse_result = df.select(functions.to_date(functions.col("dateCol"))).first() + self.assertEquals(date(2017, 1, 22), parse_result['to_date(`dateCol`)']) + + @unittest.skipIf(sys.version_info < (3, 3), "Unittest < 3.3 doesn't support mocking") + def test_unbounded_frames(self): + from unittest.mock import patch + from pyspark.sql import functions as F + from pyspark.sql import window + import importlib + + df = self.spark.range(0, 3) + + def rows_frame_match(): + return "ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" in df.select( + F.count("*").over(window.Window.rowsBetween(-sys.maxsize, sys.maxsize)) + ).columns[0] + + def range_frame_match(): + return "RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" in df.select( + F.count("*").over(window.Window.rangeBetween(-sys.maxsize, sys.maxsize)) + ).columns[0] + + with patch("sys.maxsize", 2 ** 31 - 1): + importlib.reload(window) + self.assertTrue(rows_frame_match()) + self.assertTrue(range_frame_match()) + + with patch("sys.maxsize", 2 ** 63 - 1): + importlib.reload(window) + self.assertTrue(rows_frame_match()) + self.assertTrue(range_frame_match()) + + with patch("sys.maxsize", 2 ** 127 - 1): + importlib.reload(window) + self.assertTrue(rows_frame_match()) + self.assertTrue(range_frame_match()) + + importlib.reload(window) + + +if __name__ == "__main__": + from pyspark.sql.tests.test_context import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py new file mode 100644 index 0000000000000..908d400e00092 --- /dev/null +++ b/python/pyspark/sql/tests/test_dataframe.py @@ -0,0 +1,738 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
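The HiveContext window-function tests above all follow the same shape: build a `WindowSpec`, then evaluate ranking and aggregate expressions `.over()` it. A minimal sketch with the same toy data, assuming a local SparkSession (HiveContext is not required for the window functions themselves):

```python
from pyspark.sql import SparkSession, Window
from pyspark.sql import functions as F

spark = SparkSession.builder.master("local[2]").getOrCreate()
df = spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])

w = Window.partitionBy("value").orderBy("key")
df.select(
    "value", "key",
    F.max("key").over(w.rowsBetween(0, 1)).alias("max_in_frame"),
    F.row_number().over(w).alias("row_number"),
    F.rank().over(w).alias("rank"),
    F.ntile(2).over(w).alias("ntile"),
).show()

# Cumulative sum over an unpartitioned frame, as in test_window_functions_cumulative_sum.
cumulative = Window.rowsBetween(Window.unboundedPreceding, Window.currentRow)
df.select("key", F.sum("key").over(cumulative).alias("running_sum")).show()
```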
+# + +import os +import pydoc +import time +import unittest + +from pyspark.sql import SparkSession, Row +from pyspark.sql.types import * +from pyspark.sql.utils import AnalysisException, IllegalArgumentException +from pyspark.testing.sqlutils import ReusedSQLTestCase, SQLTestUtils, have_pyarrow, have_pandas, \ + pandas_requirement_message, pyarrow_requirement_message +from pyspark.testing.utils import QuietTest + + +class DataFrameTests(ReusedSQLTestCase): + + def test_range(self): + self.assertEqual(self.spark.range(1, 1).count(), 0) + self.assertEqual(self.spark.range(1, 0, -1).count(), 1) + self.assertEqual(self.spark.range(0, 1 << 40, 1 << 39).count(), 2) + self.assertEqual(self.spark.range(-2).count(), 0) + self.assertEqual(self.spark.range(3).count(), 3) + + def test_duplicated_column_names(self): + df = self.spark.createDataFrame([(1, 2)], ["c", "c"]) + row = df.select('*').first() + self.assertEqual(1, row[0]) + self.assertEqual(2, row[1]) + self.assertEqual("Row(c=1, c=2)", str(row)) + # Cannot access columns + self.assertRaises(AnalysisException, lambda: df.select(df[0]).first()) + self.assertRaises(AnalysisException, lambda: df.select(df.c).first()) + self.assertRaises(AnalysisException, lambda: df.select(df["c"]).first()) + + def test_freqItems(self): + vals = [Row(a=1, b=-2.0) if i % 2 == 0 else Row(a=i, b=i * 1.0) for i in range(100)] + df = self.sc.parallelize(vals).toDF() + items = df.stat.freqItems(("a", "b"), 0.4).collect()[0] + self.assertTrue(1 in items[0]) + self.assertTrue(-2.0 in items[1]) + + def test_help_command(self): + # Regression test for SPARK-5464 + rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) + df = self.spark.read.json(rdd) + # render_doc() reproduces the help() exception without printing output + pydoc.render_doc(df) + pydoc.render_doc(df.foo) + pydoc.render_doc(df.take(1)) + + def test_dropna(self): + schema = StructType([ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + StructField("height", DoubleType(), True)]) + + # shouldn't drop a non-null row + self.assertEqual(self.spark.createDataFrame( + [(u'Alice', 50, 80.1)], schema).dropna().count(), + 1) + + # dropping rows with a single null value + self.assertEqual(self.spark.createDataFrame( + [(u'Alice', None, 80.1)], schema).dropna().count(), + 0) + self.assertEqual(self.spark.createDataFrame( + [(u'Alice', None, 80.1)], schema).dropna(how='any').count(), + 0) + + # if how = 'all', only drop rows if all values are null + self.assertEqual(self.spark.createDataFrame( + [(u'Alice', None, 80.1)], schema).dropna(how='all').count(), + 1) + self.assertEqual(self.spark.createDataFrame( + [(None, None, None)], schema).dropna(how='all').count(), + 0) + + # how and subset + self.assertEqual(self.spark.createDataFrame( + [(u'Alice', 50, None)], schema).dropna(how='any', subset=['name', 'age']).count(), + 1) + self.assertEqual(self.spark.createDataFrame( + [(u'Alice', None, None)], schema).dropna(how='any', subset=['name', 'age']).count(), + 0) + + # threshold + self.assertEqual(self.spark.createDataFrame( + [(u'Alice', None, 80.1)], schema).dropna(thresh=2).count(), + 1) + self.assertEqual(self.spark.createDataFrame( + [(u'Alice', None, None)], schema).dropna(thresh=2).count(), + 0) + + # threshold and subset + self.assertEqual(self.spark.createDataFrame( + [(u'Alice', 50, None)], schema).dropna(thresh=2, subset=['name', 'age']).count(), + 1) + self.assertEqual(self.spark.createDataFrame( + [(u'Alice', None, 180.9)], schema).dropna(thresh=2, subset=['name', 
'age']).count(), + 0) + + # thresh should take precedence over how + self.assertEqual(self.spark.createDataFrame( + [(u'Alice', 50, None)], schema).dropna( + how='any', thresh=2, subset=['name', 'age']).count(), + 1) + + def test_fillna(self): + schema = StructType([ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + StructField("height", DoubleType(), True), + StructField("spy", BooleanType(), True)]) + + # fillna shouldn't change non-null values + row = self.spark.createDataFrame([(u'Alice', 10, 80.1, True)], schema).fillna(50).first() + self.assertEqual(row.age, 10) + + # fillna with int + row = self.spark.createDataFrame([(u'Alice', None, None, None)], schema).fillna(50).first() + self.assertEqual(row.age, 50) + self.assertEqual(row.height, 50.0) + + # fillna with double + row = self.spark.createDataFrame( + [(u'Alice', None, None, None)], schema).fillna(50.1).first() + self.assertEqual(row.age, 50) + self.assertEqual(row.height, 50.1) + + # fillna with bool + row = self.spark.createDataFrame( + [(u'Alice', None, None, None)], schema).fillna(True).first() + self.assertEqual(row.age, None) + self.assertEqual(row.spy, True) + + # fillna with string + row = self.spark.createDataFrame([(None, None, None, None)], schema).fillna("hello").first() + self.assertEqual(row.name, u"hello") + self.assertEqual(row.age, None) + + # fillna with subset specified for numeric cols + row = self.spark.createDataFrame( + [(None, None, None, None)], schema).fillna(50, subset=['name', 'age']).first() + self.assertEqual(row.name, None) + self.assertEqual(row.age, 50) + self.assertEqual(row.height, None) + self.assertEqual(row.spy, None) + + # fillna with subset specified for string cols + row = self.spark.createDataFrame( + [(None, None, None, None)], schema).fillna("haha", subset=['name', 'age']).first() + self.assertEqual(row.name, "haha") + self.assertEqual(row.age, None) + self.assertEqual(row.height, None) + self.assertEqual(row.spy, None) + + # fillna with subset specified for bool cols + row = self.spark.createDataFrame( + [(None, None, None, None)], schema).fillna(True, subset=['name', 'spy']).first() + self.assertEqual(row.name, None) + self.assertEqual(row.age, None) + self.assertEqual(row.height, None) + self.assertEqual(row.spy, True) + + # fillna with dictionary for boolean types + row = self.spark.createDataFrame([Row(a=None), Row(a=True)]).fillna({"a": True}).first() + self.assertEqual(row.a, True) + + def test_repartitionByRange_dataframe(self): + schema = StructType([ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + StructField("height", DoubleType(), True)]) + + df1 = self.spark.createDataFrame( + [(u'Bob', 27, 66.0), (u'Alice', 10, 10.0), (u'Bob', 10, 66.0)], schema) + df2 = self.spark.createDataFrame( + [(u'Alice', 10, 10.0), (u'Bob', 10, 66.0), (u'Bob', 27, 66.0)], schema) + + # test repartitionByRange(numPartitions, *cols) + df3 = df1.repartitionByRange(2, "name", "age") + self.assertEqual(df3.rdd.getNumPartitions(), 2) + self.assertEqual(df3.rdd.first(), df2.rdd.first()) + self.assertEqual(df3.rdd.take(3), df2.rdd.take(3)) + + # test repartitionByRange(numPartitions, *cols) + df4 = df1.repartitionByRange(3, "name", "age") + self.assertEqual(df4.rdd.getNumPartitions(), 3) + self.assertEqual(df4.rdd.first(), df2.rdd.first()) + self.assertEqual(df4.rdd.take(3), df2.rdd.take(3)) + + # test repartitionByRange(*cols) + df5 = df1.repartitionByRange("name", "age") + self.assertEqual(df5.rdd.first(), df2.rdd.first()) + 
self.assertEqual(df5.rdd.take(3), df2.rdd.take(3)) + + def test_replace(self): + schema = StructType([ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + StructField("height", DoubleType(), True)]) + + # replace with int + row = self.spark.createDataFrame([(u'Alice', 10, 10.0)], schema).replace(10, 20).first() + self.assertEqual(row.age, 20) + self.assertEqual(row.height, 20.0) + + # replace with double + row = self.spark.createDataFrame( + [(u'Alice', 80, 80.0)], schema).replace(80.0, 82.1).first() + self.assertEqual(row.age, 82) + self.assertEqual(row.height, 82.1) + + # replace with string + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace(u'Alice', u'Ann').first() + self.assertEqual(row.name, u"Ann") + self.assertEqual(row.age, 10) + + # replace with subset specified by a string of a column name w/ actual change + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace(10, 20, subset='age').first() + self.assertEqual(row.age, 20) + + # replace with subset specified by a string of a column name w/o actual change + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace(10, 20, subset='height').first() + self.assertEqual(row.age, 10) + + # replace with subset specified with one column replaced, another column not in subset + # stays unchanged. + row = self.spark.createDataFrame( + [(u'Alice', 10, 10.0)], schema).replace(10, 20, subset=['name', 'age']).first() + self.assertEqual(row.name, u'Alice') + self.assertEqual(row.age, 20) + self.assertEqual(row.height, 10.0) + + # replace with subset specified but no column will be replaced + row = self.spark.createDataFrame( + [(u'Alice', 10, None)], schema).replace(10, 20, subset=['name', 'height']).first() + self.assertEqual(row.name, u'Alice') + self.assertEqual(row.age, 10) + self.assertEqual(row.height, None) + + # replace with lists + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace([u'Alice'], [u'Ann']).first() + self.assertTupleEqual(row, (u'Ann', 10, 80.1)) + + # replace with dict + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace({10: 11}).first() + self.assertTupleEqual(row, (u'Alice', 11, 80.1)) + + # test backward compatibility with dummy value + dummy_value = 1 + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace({'Alice': 'Bob'}, dummy_value).first() + self.assertTupleEqual(row, (u'Bob', 10, 80.1)) + + # test dict with mixed numerics + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace({10: -10, 80.1: 90.5}).first() + self.assertTupleEqual(row, (u'Alice', -10, 90.5)) + + # replace with tuples + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace((u'Alice', ), (u'Bob', )).first() + self.assertTupleEqual(row, (u'Bob', 10, 80.1)) + + # replace multiple columns + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.0)], schema).replace((10, 80.0), (20, 90)).first() + self.assertTupleEqual(row, (u'Alice', 20, 90.0)) + + # test for mixed numerics + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.0)], schema).replace((10, 80), (20, 90.5)).first() + self.assertTupleEqual(row, (u'Alice', 20, 90.5)) + + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.0)], schema).replace({10: 20, 80: 90.5}).first() + self.assertTupleEqual(row, (u'Alice', 20, 90.5)) + + # replace with boolean + row = (self + .spark.createDataFrame([(u'Alice', 10, 80.0)], schema) + .selectExpr("name = 'Bob'", 'age <= 15') + 
.replace(False, True).first()) + self.assertTupleEqual(row, (True, True)) + + # replace string with None and then drop None rows + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.0)], schema).replace(u'Alice', None).dropna() + self.assertEqual(row.count(), 0) + + # replace with number and None + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.0)], schema).replace([10, 80], [20, None]).first() + self.assertTupleEqual(row, (u'Alice', 20, None)) + + # should fail if subset is not list, tuple or None + with self.assertRaises(ValueError): + self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace({10: 11}, subset=1).first() + + # should fail if to_replace and value have different length + with self.assertRaises(ValueError): + self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace(["Alice", "Bob"], ["Eve"]).first() + + # should fail if when received unexpected type + with self.assertRaises(ValueError): + from datetime import datetime + self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace(datetime.now(), datetime.now()).first() + + # should fail if provided mixed type replacements + with self.assertRaises(ValueError): + self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace(["Alice", 10], ["Eve", 20]).first() + + with self.assertRaises(ValueError): + self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace({u"Alice": u"Bob", 10: 20}).first() + + with self.assertRaisesRegexp( + TypeError, + 'value argument is required when to_replace is not a dictionary.'): + self.spark.createDataFrame( + [(u'Alice', 10, 80.0)], schema).replace(["Alice", "Bob"]).first() + + def test_with_column_with_existing_name(self): + keys = self.df.withColumn("key", self.df.key).select("key").collect() + self.assertEqual([r.key for r in keys], list(range(100))) + + # regression test for SPARK-10417 + def test_column_iterator(self): + + def foo(): + for x in self.df.key: + break + + self.assertRaises(TypeError, foo) + + def test_generic_hints(self): + from pyspark.sql import DataFrame + + df1 = self.spark.range(10e10).toDF("id") + df2 = self.spark.range(10e10).toDF("id") + + self.assertIsInstance(df1.hint("broadcast"), DataFrame) + self.assertIsInstance(df1.hint("broadcast", []), DataFrame) + + # Dummy rules + self.assertIsInstance(df1.hint("broadcast", "foo", "bar"), DataFrame) + self.assertIsInstance(df1.hint("broadcast", ["foo", "bar"]), DataFrame) + + plan = df1.join(df2.hint("broadcast"), "id")._jdf.queryExecution().executedPlan() + self.assertEqual(1, plan.toString().count("BroadcastHashJoin")) + + def test_sample(self): + self.assertRaisesRegexp( + TypeError, + "should be a bool, float and number", + lambda: self.spark.range(1).sample()) + + self.assertRaises( + TypeError, + lambda: self.spark.range(1).sample("a")) + + self.assertRaises( + TypeError, + lambda: self.spark.range(1).sample(seed="abc")) + + self.assertRaises( + IllegalArgumentException, + lambda: self.spark.range(1).sample(-1.0)) + + def test_toDF_with_schema_string(self): + data = [Row(key=i, value=str(i)) for i in range(100)] + rdd = self.sc.parallelize(data, 5) + + df = rdd.toDF("key: int, value: string") + self.assertEqual(df.schema.simpleString(), "struct<key:int,value:string>") + self.assertEqual(df.collect(), data) + + # different but compatible field types can be used. + df = rdd.toDF("key: string, value: string") + self.assertEqual(df.schema.simpleString(), "struct<key:string,value:string>") + self.assertEqual(df.collect(), [Row(key=str(i), value=str(i)) for i in range(100)]) + + # field names can differ.
+ df = rdd.toDF(" a: int, b: string ") + self.assertEqual(df.schema.simpleString(), "struct<a:int,b:string>") + self.assertEqual(df.collect(), data) + + # number of fields must match. + self.assertRaisesRegexp(Exception, "Length of object", + lambda: rdd.toDF("key: int").collect()) + + # field types mismatch will cause exception at runtime. + self.assertRaisesRegexp(Exception, "FloatType can not accept", + lambda: rdd.toDF("key: float, value: string").collect()) + + # flat schema values will be wrapped into row. + df = rdd.map(lambda row: row.key).toDF("int") + self.assertEqual(df.schema.simpleString(), "struct<value:int>") + self.assertEqual(df.collect(), [Row(key=i) for i in range(100)]) + + # users can use DataType directly instead of data type string. + df = rdd.map(lambda row: row.key).toDF(IntegerType()) + self.assertEqual(df.schema.simpleString(), "struct<value:int>") + self.assertEqual(df.collect(), [Row(key=i) for i in range(100)]) + + def test_join_without_on(self): + df1 = self.spark.range(1).toDF("a") + df2 = self.spark.range(1).toDF("b") + + with self.sql_conf({"spark.sql.crossJoin.enabled": False}): + self.assertRaises(AnalysisException, lambda: df1.join(df2, how="inner").collect()) + + with self.sql_conf({"spark.sql.crossJoin.enabled": True}): + actual = df1.join(df2, how="inner").collect() + expected = [Row(a=0, b=0)] + self.assertEqual(actual, expected) + + # Regression test for invalid join methods when on is None, Spark-14761 + def test_invalid_join_method(self): + df1 = self.spark.createDataFrame([("Alice", 5), ("Bob", 8)], ["name", "age"]) + df2 = self.spark.createDataFrame([("Alice", 80), ("Bob", 90)], ["name", "height"]) + self.assertRaises(IllegalArgumentException, lambda: df1.join(df2, how="invalid-join-type")) + + # Cartesian products require cross join syntax + def test_require_cross(self): + + df1 = self.spark.createDataFrame([(1, "1")], ("key", "value")) + df2 = self.spark.createDataFrame([(1, "1")], ("key", "value")) + + # joins without conditions require cross join syntax + self.assertRaises(AnalysisException, lambda: df1.join(df2).collect()) + + # works with crossJoin + self.assertEqual(1, df1.crossJoin(df2).count()) + + def test_cache(self): + spark = self.spark + with self.tempView("tab1", "tab2"): + spark.createDataFrame([(2, 2), (3, 3)]).createOrReplaceTempView("tab1") + spark.createDataFrame([(2, 2), (3, 3)]).createOrReplaceTempView("tab2") + self.assertFalse(spark.catalog.isCached("tab1")) + self.assertFalse(spark.catalog.isCached("tab2")) + spark.catalog.cacheTable("tab1") + self.assertTrue(spark.catalog.isCached("tab1")) + self.assertFalse(spark.catalog.isCached("tab2")) + spark.catalog.cacheTable("tab2") + spark.catalog.uncacheTable("tab1") + self.assertFalse(spark.catalog.isCached("tab1")) + self.assertTrue(spark.catalog.isCached("tab2")) + spark.catalog.clearCache() + self.assertFalse(spark.catalog.isCached("tab1")) + self.assertFalse(spark.catalog.isCached("tab2")) + self.assertRaisesRegexp( + AnalysisException, + "does_not_exist", + lambda: spark.catalog.isCached("does_not_exist")) + self.assertRaisesRegexp( + AnalysisException, + "does_not_exist", + lambda: spark.catalog.cacheTable("does_not_exist")) + self.assertRaisesRegexp( + AnalysisException, + "does_not_exist", + lambda: spark.catalog.uncacheTable("does_not_exist")) + + def _to_pandas(self): + from datetime import datetime, date + schema = StructType().add("a", IntegerType()).add("b", StringType())\ + .add("c", BooleanType()).add("d", FloatType())\ + .add("dt", DateType()).add("ts", TimestampType()) + data = [ + (1, "foo",
True, 3.0, date(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)), + (2, "foo", True, 5.0, None, None), + (3, "bar", False, -1.0, date(2012, 3, 3), datetime(2012, 3, 3, 3, 3, 3)), + (4, "bar", False, 6.0, date(2100, 4, 4), datetime(2100, 4, 4, 4, 4, 4)), + ] + df = self.spark.createDataFrame(data, schema) + return df.toPandas() + + @unittest.skipIf(not have_pandas, pandas_requirement_message) + def test_to_pandas(self): + import numpy as np + pdf = self._to_pandas() + types = pdf.dtypes + self.assertEquals(types[0], np.int32) + self.assertEquals(types[1], np.object) + self.assertEquals(types[2], np.bool) + self.assertEquals(types[3], np.float32) + self.assertEquals(types[4], np.object) # datetime.date + self.assertEquals(types[5], 'datetime64[ns]') + + @unittest.skipIf(have_pandas, "Required Pandas was found.") + def test_to_pandas_required_pandas_not_found(self): + with QuietTest(self.sc): + with self.assertRaisesRegexp(ImportError, 'Pandas >= .* must be installed'): + self._to_pandas() + + @unittest.skipIf(not have_pandas, pandas_requirement_message) + def test_to_pandas_avoid_astype(self): + import numpy as np + schema = StructType().add("a", IntegerType()).add("b", StringType())\ + .add("c", IntegerType()) + data = [(1, "foo", 16777220), (None, "bar", None)] + df = self.spark.createDataFrame(data, schema) + types = df.toPandas().dtypes + self.assertEquals(types[0], np.float64) # doesn't convert to np.int32 due to NaN value. + self.assertEquals(types[1], np.object) + self.assertEquals(types[2], np.float64) + + def test_create_dataframe_from_array_of_long(self): + import array + data = [Row(longarray=array.array('l', [-9223372036854775808, 0, 9223372036854775807]))] + df = self.spark.createDataFrame(data) + self.assertEqual(df.first(), Row(longarray=[-9223372036854775808, 0, 9223372036854775807])) + + @unittest.skipIf(not have_pandas, pandas_requirement_message) + def test_create_dataframe_from_pandas_with_timestamp(self): + import pandas as pd + from datetime import datetime + pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)], + "d": [pd.Timestamp.now().date()]}, columns=["d", "ts"]) + # test types are inferred correctly without specifying schema + df = self.spark.createDataFrame(pdf) + self.assertTrue(isinstance(df.schema['ts'].dataType, TimestampType)) + self.assertTrue(isinstance(df.schema['d'].dataType, DateType)) + # test with schema will accept pdf as input + df = self.spark.createDataFrame(pdf, schema="d date, ts timestamp") + self.assertTrue(isinstance(df.schema['ts'].dataType, TimestampType)) + self.assertTrue(isinstance(df.schema['d'].dataType, DateType)) + + @unittest.skipIf(have_pandas, "Required Pandas was found.") + def test_create_dataframe_required_pandas_not_found(self): + with QuietTest(self.sc): + with self.assertRaisesRegexp( + ImportError, + "(Pandas >= .* must be installed|No module named '?pandas'?)"): + import pandas as pd + from datetime import datetime + pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)], + "d": [pd.Timestamp.now().date()]}) + self.spark.createDataFrame(pdf) + + # Regression test for SPARK-23360 + @unittest.skipIf(not have_pandas, pandas_requirement_message) + def test_create_dateframe_from_pandas_with_dst(self): + import pandas as pd + from datetime import datetime + + pdf = pd.DataFrame({'time': [datetime(2015, 10, 31, 22, 30)]}) + + df = self.spark.createDataFrame(pdf) + self.assertPandasEqual(pdf, df.toPandas()) + + orig_env_tz = os.environ.get('TZ', None) + try: + tz = 'America/Los_Angeles' + os.environ['TZ'] = tz + 
time.tzset() + with self.sql_conf({'spark.sql.session.timeZone': tz}): + df = self.spark.createDataFrame(pdf) + self.assertPandasEqual(pdf, df.toPandas()) + finally: + del os.environ['TZ'] + if orig_env_tz is not None: + os.environ['TZ'] = orig_env_tz + time.tzset() + + def test_repr_behaviors(self): + import re + pattern = re.compile(r'^ *\|', re.MULTILINE) + df = self.spark.createDataFrame([(1, "1"), (22222, "22222")], ("key", "value")) + + # test when eager evaluation is enabled and _repr_html_ will not be called + with self.sql_conf({"spark.sql.repl.eagerEval.enabled": True}): + expected1 = """+-----+-----+ + || key|value| + |+-----+-----+ + || 1| 1| + ||22222|22222| + |+-----+-----+ + |""" + self.assertEquals(re.sub(pattern, '', expected1), df.__repr__()) + with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}): + expected2 = """+---+-----+ + ||key|value| + |+---+-----+ + || 1| 1| + ||222| 222| + |+---+-----+ + |""" + self.assertEquals(re.sub(pattern, '', expected2), df.__repr__()) + with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}): + expected3 = """+---+-----+ + ||key|value| + |+---+-----+ + || 1| 1| + |+---+-----+ + |only showing top 1 row + |""" + self.assertEquals(re.sub(pattern, '', expected3), df.__repr__()) + + # test when eager evaluation is enabled and _repr_html_ will be called + with self.sql_conf({"spark.sql.repl.eagerEval.enabled": True}): + expected1 = """<table border='1'> + |<tr><th>key</th><th>value</th></tr> + |<tr><td>1</td><td>1</td></tr> + |<tr><td>22222</td><td>22222</td></tr> + |</table>
    + |""" + self.assertEquals(re.sub(pattern, '', expected1), df._repr_html_()) + with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}): + expected2 = """ + | + | + | + |
    keyvalue
    11
    222222
    + |""" + self.assertEquals(re.sub(pattern, '', expected2), df._repr_html_()) + with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}): + expected3 = """ + | + | + |
    keyvalue
    11
    + |only showing top 1 row + |""" + self.assertEquals(re.sub(pattern, '', expected3), df._repr_html_()) + + # test when eager evaluation is disabled and _repr_html_ will be called + with self.sql_conf({"spark.sql.repl.eagerEval.enabled": False}): + expected = "DataFrame[key: bigint, value: string]" + self.assertEquals(None, df._repr_html_()) + self.assertEquals(expected, df.__repr__()) + with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}): + self.assertEquals(None, df._repr_html_()) + self.assertEquals(expected, df.__repr__()) + with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}): + self.assertEquals(None, df._repr_html_()) + self.assertEquals(expected, df.__repr__()) + + +class QueryExecutionListenerTests(unittest.TestCase, SQLTestUtils): + # These tests are separate because it uses 'spark.sql.queryExecutionListeners' which is + # static and immutable. This can't be set or unset, for example, via `spark.conf`. + + @classmethod + def setUpClass(cls): + import glob + from pyspark.find_spark_home import _find_spark_home + + SPARK_HOME = _find_spark_home() + filename_pattern = ( + "sql/core/target/scala-*/test-classes/org/apache/spark/sql/" + "TestQueryExecutionListener.class") + cls.has_listener = bool(glob.glob(os.path.join(SPARK_HOME, filename_pattern))) + + if cls.has_listener: + # Note that 'spark.sql.queryExecutionListeners' is a static immutable configuration. + cls.spark = SparkSession.builder \ + .master("local[4]") \ + .appName(cls.__name__) \ + .config( + "spark.sql.queryExecutionListeners", + "org.apache.spark.sql.TestQueryExecutionListener") \ + .getOrCreate() + + def setUp(self): + if not self.has_listener: + raise self.skipTest( + "'org.apache.spark.sql.TestQueryExecutionListener' is not " + "available. Will skip the related tests.") + + @classmethod + def tearDownClass(cls): + if hasattr(cls, "spark"): + cls.spark.stop() + + def tearDown(self): + self.spark._jvm.OnSuccessCall.clear() + + def test_query_execution_listener_on_collect(self): + self.assertFalse( + self.spark._jvm.OnSuccessCall.isCalled(), + "The callback from the query execution listener should not be called before 'collect'") + self.spark.sql("SELECT * FROM range(1)").collect() + self.assertTrue( + self.spark._jvm.OnSuccessCall.isCalled(), + "The callback from the query execution listener should be called after 'collect'") + + @unittest.skipIf( + not have_pandas or not have_pyarrow, + pandas_requirement_message or pyarrow_requirement_message) + def test_query_execution_listener_on_collect_with_arrow(self): + with self.sql_conf({"spark.sql.execution.arrow.enabled": True}): + self.assertFalse( + self.spark._jvm.OnSuccessCall.isCalled(), + "The callback from the query execution listener should not be " + "called before 'toPandas'") + self.spark.sql("SELECT * FROM range(1)").toPandas() + self.assertTrue( + self.spark._jvm.OnSuccessCall.isCalled(), + "The callback from the query execution listener should be called after 'toPandas'") + + +if __name__ == "__main__": + from pyspark.sql.tests.test_dataframe import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_datasources.py b/python/pyspark/sql/tests/test_datasources.py new file mode 100644 index 0000000000000..5579620bc2be1 --- /dev/null +++ b/python/pyspark/sql/tests/test_datasources.py @@ -0,0 +1,171 @@ +# +# Licensed to the Apache Software Foundation (ASF) 
under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import shutil +import tempfile + +from pyspark.sql import Row +from pyspark.sql.types import * +from pyspark.testing.sqlutils import ReusedSQLTestCase + + +class DataSourcesTests(ReusedSQLTestCase): + + def test_linesep_text(self): + df = self.spark.read.text("python/test_support/sql/ages_newlines.csv", lineSep=",") + expected = [Row(value=u'Joe'), Row(value=u'20'), Row(value=u'"Hi'), + Row(value=u'\nI am Jeo"\nTom'), Row(value=u'30'), + Row(value=u'"My name is Tom"\nHyukjin'), Row(value=u'25'), + Row(value=u'"I am Hyukjin\n\nI love Spark!"\n')] + self.assertEqual(df.collect(), expected) + + tpath = tempfile.mkdtemp() + shutil.rmtree(tpath) + try: + df.write.text(tpath, lineSep="!") + expected = [Row(value=u'Joe!20!"Hi!'), Row(value=u'I am Jeo"'), + Row(value=u'Tom!30!"My name is Tom"'), + Row(value=u'Hyukjin!25!"I am Hyukjin'), + Row(value=u''), Row(value=u'I love Spark!"'), + Row(value=u'!')] + readback = self.spark.read.text(tpath) + self.assertEqual(readback.collect(), expected) + finally: + shutil.rmtree(tpath) + + def test_multiline_json(self): + people1 = self.spark.read.json("python/test_support/sql/people.json") + people_array = self.spark.read.json("python/test_support/sql/people_array.json", + multiLine=True) + self.assertEqual(people1.collect(), people_array.collect()) + + def test_encoding_json(self): + people_array = self.spark.read\ + .json("python/test_support/sql/people_array_utf16le.json", + multiLine=True, encoding="UTF-16LE") + expected = [Row(age=30, name=u'Andy'), Row(age=19, name=u'Justin')] + self.assertEqual(people_array.collect(), expected) + + def test_linesep_json(self): + df = self.spark.read.json("python/test_support/sql/people.json", lineSep=",") + expected = [Row(_corrupt_record=None, name=u'Michael'), + Row(_corrupt_record=u' "age":30}\n{"name":"Justin"', name=None), + Row(_corrupt_record=u' "age":19}\n', name=None)] + self.assertEqual(df.collect(), expected) + + tpath = tempfile.mkdtemp() + shutil.rmtree(tpath) + try: + df = self.spark.read.json("python/test_support/sql/people.json") + df.write.json(tpath, lineSep="!!") + readback = self.spark.read.json(tpath, lineSep="!!") + self.assertEqual(readback.collect(), df.collect()) + finally: + shutil.rmtree(tpath) + + def test_multiline_csv(self): + ages_newlines = self.spark.read.csv( + "python/test_support/sql/ages_newlines.csv", multiLine=True) + expected = [Row(_c0=u'Joe', _c1=u'20', _c2=u'Hi,\nI am Jeo'), + Row(_c0=u'Tom', _c1=u'30', _c2=u'My name is Tom'), + Row(_c0=u'Hyukjin', _c1=u'25', _c2=u'I am Hyukjin\n\nI love Spark!')] + self.assertEqual(ages_newlines.collect(), expected) + + def test_ignorewhitespace_csv(self): + tmpPath = tempfile.mkdtemp() + shutil.rmtree(tmpPath) + self.spark.createDataFrame([[" a", "b ", " c "]]).write.csv( + tmpPath, + 
ignoreLeadingWhiteSpace=False, + ignoreTrailingWhiteSpace=False) + + expected = [Row(value=u' a,b , c ')] + readback = self.spark.read.text(tmpPath) + self.assertEqual(readback.collect(), expected) + shutil.rmtree(tmpPath) + + def test_read_multiple_orc_file(self): + df = self.spark.read.orc(["python/test_support/sql/orc_partitioned/b=0/c=0", + "python/test_support/sql/orc_partitioned/b=1/c=1"]) + self.assertEqual(2, df.count()) + + def test_read_text_file_list(self): + df = self.spark.read.text(['python/test_support/sql/text-test.txt', + 'python/test_support/sql/text-test.txt']) + count = df.count() + self.assertEquals(count, 4) + + def test_json_sampling_ratio(self): + rdd = self.spark.sparkContext.range(0, 100, 1, 1) \ + .map(lambda x: '{"a":0.1}' if x == 1 else '{"a":%s}' % str(x)) + schema = self.spark.read.option('inferSchema', True) \ + .option('samplingRatio', 0.5) \ + .json(rdd).schema + self.assertEquals(schema, StructType([StructField("a", LongType(), True)])) + + def test_csv_sampling_ratio(self): + rdd = self.spark.sparkContext.range(0, 100, 1, 1) \ + .map(lambda x: '0.1' if x == 1 else str(x)) + schema = self.spark.read.option('inferSchema', True)\ + .csv(rdd, samplingRatio=0.5).schema + self.assertEquals(schema, StructType([StructField("_c0", IntegerType(), True)])) + + def test_checking_csv_header(self): + path = tempfile.mkdtemp() + shutil.rmtree(path) + try: + self.spark.createDataFrame([[1, 1000], [2000, 2]])\ + .toDF('f1', 'f2').write.option("header", "true").csv(path) + schema = StructType([ + StructField('f2', IntegerType(), nullable=True), + StructField('f1', IntegerType(), nullable=True)]) + df = self.spark.read.option('header', 'true').schema(schema)\ + .csv(path, enforceSchema=False) + self.assertRaisesRegexp( + Exception, + "CSV header does not conform to the schema", + lambda: df.collect()) + finally: + shutil.rmtree(path) + + def test_ignore_column_of_all_nulls(self): + path = tempfile.mkdtemp() + shutil.rmtree(path) + try: + df = self.spark.createDataFrame([["""{"a":null, "b":1, "c":3.0}"""], + ["""{"a":null, "b":null, "c":"string"}"""], + ["""{"a":null, "b":null, "c":null}"""]]) + df.write.text(path) + schema = StructType([ + StructField('b', LongType(), nullable=True), + StructField('c', StringType(), nullable=True)]) + readback = self.spark.read.json(path, dropFieldIfAllNull=True) + self.assertEquals(readback.schema, schema) + finally: + shutil.rmtree(path) + + +if __name__ == "__main__": + import unittest + from pyspark.sql.tests.test_datasources import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py new file mode 100644 index 0000000000000..fe6660272e323 --- /dev/null +++ b/python/pyspark/sql/tests/test_functions.py @@ -0,0 +1,279 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import datetime +import sys + +from pyspark.sql import Row +from pyspark.testing.sqlutils import ReusedSQLTestCase + + +class FunctionsTests(ReusedSQLTestCase): + + def test_explode(self): + from pyspark.sql.functions import explode, explode_outer, posexplode_outer + d = [ + Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"}), + Row(a=1, intlist=[], mapfield={}), + Row(a=1, intlist=None, mapfield=None), + ] + rdd = self.sc.parallelize(d) + data = self.spark.createDataFrame(rdd) + + result = data.select(explode(data.intlist).alias("a")).select("a").collect() + self.assertEqual(result[0][0], 1) + self.assertEqual(result[1][0], 2) + self.assertEqual(result[2][0], 3) + + result = data.select(explode(data.mapfield).alias("a", "b")).select("a", "b").collect() + self.assertEqual(result[0][0], "a") + self.assertEqual(result[0][1], "b") + + result = [tuple(x) for x in data.select(posexplode_outer("intlist")).collect()] + self.assertEqual(result, [(0, 1), (1, 2), (2, 3), (None, None), (None, None)]) + + result = [tuple(x) for x in data.select(posexplode_outer("mapfield")).collect()] + self.assertEqual(result, [(0, 'a', 'b'), (None, None, None), (None, None, None)]) + + result = [x[0] for x in data.select(explode_outer("intlist")).collect()] + self.assertEqual(result, [1, 2, 3, None, None]) + + result = [tuple(x) for x in data.select(explode_outer("mapfield")).collect()] + self.assertEqual(result, [('a', 'b'), (None, None), (None, None)]) + + def test_basic_functions(self): + rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) + df = self.spark.read.json(rdd) + df.count() + df.collect() + df.schema + + # cache and checkpoint + self.assertFalse(df.is_cached) + df.persist() + df.unpersist(True) + df.cache() + self.assertTrue(df.is_cached) + self.assertEqual(2, df.count()) + + with self.tempView("temp"): + df.createOrReplaceTempView("temp") + df = self.spark.sql("select foo from temp") + df.count() + df.collect() + + def test_corr(self): + import math + df = self.sc.parallelize([Row(a=i, b=math.sqrt(i)) for i in range(10)]).toDF() + corr = df.stat.corr(u"a", "b") + self.assertTrue(abs(corr - 0.95734012) < 1e-6) + + def test_sampleby(self): + df = self.sc.parallelize([Row(a=i, b=(i % 3)) for i in range(10)]).toDF() + sampled = df.stat.sampleBy(u"b", fractions={0: 0.5, 1: 0.5}, seed=0) + self.assertTrue(sampled.count() == 3) + + def test_cov(self): + df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF() + cov = df.stat.cov(u"a", "b") + self.assertTrue(abs(cov - 55.0 / 3) < 1e-6) + + def test_crosstab(self): + df = self.sc.parallelize([Row(a=i % 3, b=i % 2) for i in range(1, 7)]).toDF() + ct = df.stat.crosstab(u"a", "b").collect() + ct = sorted(ct, key=lambda x: x[0]) + for i, row in enumerate(ct): + self.assertEqual(row[0], str(i)) + self.assertTrue(row[1], 1) + self.assertTrue(row[2], 1) + + def test_math_functions(self): + df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF() + from pyspark.sql import functions + import math + + def get_values(l): + return [j[0] for j in l] + + def assert_close(a, b): + c = get_values(b) + diff = [abs(v - c[k]) < 
1e-6 for k, v in enumerate(a)] + return sum(diff) == len(a) + assert_close([math.cos(i) for i in range(10)], + df.select(functions.cos(df.a)).collect()) + assert_close([math.cos(i) for i in range(10)], + df.select(functions.cos("a")).collect()) + assert_close([math.sin(i) for i in range(10)], + df.select(functions.sin(df.a)).collect()) + assert_close([math.sin(i) for i in range(10)], + df.select(functions.sin(df['a'])).collect()) + assert_close([math.pow(i, 2 * i) for i in range(10)], + df.select(functions.pow(df.a, df.b)).collect()) + assert_close([math.pow(i, 2) for i in range(10)], + df.select(functions.pow(df.a, 2)).collect()) + assert_close([math.pow(i, 2) for i in range(10)], + df.select(functions.pow(df.a, 2.0)).collect()) + assert_close([math.hypot(i, 2 * i) for i in range(10)], + df.select(functions.hypot(df.a, df.b)).collect()) + + def test_rand_functions(self): + df = self.df + from pyspark.sql import functions + rnd = df.select('key', functions.rand()).collect() + for row in rnd: + assert row[1] >= 0.0 and row[1] <= 1.0, "got: %s" % row[1] + rndn = df.select('key', functions.randn(5)).collect() + for row in rndn: + assert row[1] >= -4.0 and row[1] <= 4.0, "got: %s" % row[1] + + # If the specified seed is 0, we should use it. + # https://issues.apache.org/jira/browse/SPARK-9691 + rnd1 = df.select('key', functions.rand(0)).collect() + rnd2 = df.select('key', functions.rand(0)).collect() + self.assertEqual(sorted(rnd1), sorted(rnd2)) + + rndn1 = df.select('key', functions.randn(0)).collect() + rndn2 = df.select('key', functions.randn(0)).collect() + self.assertEqual(sorted(rndn1), sorted(rndn2)) + + def test_string_functions(self): + from pyspark.sql.functions import col, lit + df = self.spark.createDataFrame([['nick']], schema=['name']) + self.assertRaisesRegexp( + TypeError, + "must be the same type", + lambda: df.select(col('name').substr(0, lit(1)))) + if sys.version_info.major == 2: + self.assertRaises( + TypeError, + lambda: df.select(col('name').substr(long(0), long(1)))) + + def test_array_contains_function(self): + from pyspark.sql.functions import array_contains + + df = self.spark.createDataFrame([(["1", "2", "3"],), ([],)], ['data']) + actual = df.select(array_contains(df.data, "1").alias('b')).collect() + self.assertEqual([Row(b=True), Row(b=False)], actual) + + def test_between_function(self): + df = self.sc.parallelize([ + Row(a=1, b=2, c=3), + Row(a=2, b=1, c=3), + Row(a=4, b=1, c=4)]).toDF() + self.assertEqual([Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)], + df.filter(df.a.between(df.b, df.c)).collect()) + + def test_dayofweek(self): + from pyspark.sql.functions import dayofweek + dt = datetime.datetime(2017, 11, 6) + df = self.spark.createDataFrame([Row(date=dt)]) + row = df.select(dayofweek(df.date)).first() + self.assertEqual(row[0], 2) + + def test_expr(self): + from pyspark.sql import functions + row = Row(a="length string", b=75) + df = self.spark.createDataFrame([row]) + result = df.select(functions.expr("length(a)")).collect()[0].asDict() + self.assertEqual(13, result["length(a)"]) + + # add test for SPARK-10577 (test broadcast join hint) + def test_functions_broadcast(self): + from pyspark.sql.functions import broadcast + + df1 = self.spark.createDataFrame([(1, "1"), (2, "2")], ("key", "value")) + df2 = self.spark.createDataFrame([(1, "1"), (2, "2")], ("key", "value")) + + # equijoin - should be converted into broadcast join + plan1 = df1.join(broadcast(df2), "key")._jdf.queryExecution().executedPlan() + self.assertEqual(1, 
plan1.toString().count("BroadcastHashJoin")) + + # no join key -- should not be a broadcast join + plan2 = df1.crossJoin(broadcast(df2))._jdf.queryExecution().executedPlan() + self.assertEqual(0, plan2.toString().count("BroadcastHashJoin")) + + # planner should not crash without a join + broadcast(df1)._jdf.queryExecution().executedPlan() + + def test_first_last_ignorenulls(self): + from pyspark.sql import functions + df = self.spark.range(0, 100) + df2 = df.select(functions.when(df.id % 3 == 0, None).otherwise(df.id).alias("id")) + df3 = df2.select(functions.first(df2.id, False).alias('a'), + functions.first(df2.id, True).alias('b'), + functions.last(df2.id, False).alias('c'), + functions.last(df2.id, True).alias('d')) + self.assertEqual([Row(a=None, b=1, c=None, d=98)], df3.collect()) + + def test_approxQuantile(self): + df = self.sc.parallelize([Row(a=i, b=i+10) for i in range(10)]).toDF() + for f in ["a", u"a"]: + aq = df.stat.approxQuantile(f, [0.1, 0.5, 0.9], 0.1) + self.assertTrue(isinstance(aq, list)) + self.assertEqual(len(aq), 3) + self.assertTrue(all(isinstance(q, float) for q in aq)) + aqs = df.stat.approxQuantile(["a", u"b"], [0.1, 0.5, 0.9], 0.1) + self.assertTrue(isinstance(aqs, list)) + self.assertEqual(len(aqs), 2) + self.assertTrue(isinstance(aqs[0], list)) + self.assertEqual(len(aqs[0]), 3) + self.assertTrue(all(isinstance(q, float) for q in aqs[0])) + self.assertTrue(isinstance(aqs[1], list)) + self.assertEqual(len(aqs[1]), 3) + self.assertTrue(all(isinstance(q, float) for q in aqs[1])) + aqt = df.stat.approxQuantile((u"a", "b"), [0.1, 0.5, 0.9], 0.1) + self.assertTrue(isinstance(aqt, list)) + self.assertEqual(len(aqt), 2) + self.assertTrue(isinstance(aqt[0], list)) + self.assertEqual(len(aqt[0]), 3) + self.assertTrue(all(isinstance(q, float) for q in aqt[0])) + self.assertTrue(isinstance(aqt[1], list)) + self.assertEqual(len(aqt[1]), 3) + self.assertTrue(all(isinstance(q, float) for q in aqt[1])) + self.assertRaises(ValueError, lambda: df.stat.approxQuantile(123, [0.1, 0.9], 0.1)) + self.assertRaises(ValueError, lambda: df.stat.approxQuantile(("a", 123), [0.1, 0.9], 0.1)) + self.assertRaises(ValueError, lambda: df.stat.approxQuantile(["a", 123], [0.1, 0.9], 0.1)) + + def test_sort_with_nulls_order(self): + from pyspark.sql import functions + + df = self.spark.createDataFrame( + [('Tom', 80), (None, 60), ('Alice', 50)], ["name", "height"]) + self.assertEquals( + df.select(df.name).orderBy(functions.asc_nulls_first('name')).collect(), + [Row(name=None), Row(name=u'Alice'), Row(name=u'Tom')]) + self.assertEquals( + df.select(df.name).orderBy(functions.asc_nulls_last('name')).collect(), + [Row(name=u'Alice'), Row(name=u'Tom'), Row(name=None)]) + self.assertEquals( + df.select(df.name).orderBy(functions.desc_nulls_first('name')).collect(), + [Row(name=None), Row(name=u'Tom'), Row(name=u'Alice')]) + self.assertEquals( + df.select(df.name).orderBy(functions.desc_nulls_last('name')).collect(), + [Row(name=u'Tom'), Row(name=u'Alice'), Row(name=None)]) + + +if __name__ == "__main__": + import unittest + from pyspark.sql.tests.test_functions import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_group.py b/python/pyspark/sql/tests/test_group.py new file mode 100644 index 0000000000000..6de1b8ea0b3ce --- /dev/null +++ b/python/pyspark/sql/tests/test_group.py @@ -0,0 +1,46 @@ +# +# Licensed to the 
Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark.sql import Row +from pyspark.testing.sqlutils import ReusedSQLTestCase + + +class GroupTests(ReusedSQLTestCase): + + def test_aggregator(self): + df = self.df + g = df.groupBy() + self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0])) + self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect()) + + from pyspark.sql import functions + self.assertEqual((0, u'99'), + tuple(g.agg(functions.first(df.key), functions.last(df.value)).first())) + self.assertTrue(95 < g.agg(functions.approx_count_distinct(df.key)).first()[0]) + self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0]) + + +if __name__ == "__main__": + import unittest + from pyspark.sql.tests.test_group import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_pandas_udf.py b/python/pyspark/sql/tests/test_pandas_udf.py new file mode 100644 index 0000000000000..c4b5478a7e893 --- /dev/null +++ b/python/pyspark/sql/tests/test_pandas_udf.py @@ -0,0 +1,217 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import unittest + +from pyspark.sql.types import * +from pyspark.sql.utils import ParseException +from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \ + pandas_requirement_message, pyarrow_requirement_message +from pyspark.testing.utils import QuietTest + + +@unittest.skipIf( + not have_pandas or not have_pyarrow, + pandas_requirement_message or pyarrow_requirement_message) +class PandasUDFTests(ReusedSQLTestCase): + + def test_pandas_udf_basic(self): + from pyspark.rdd import PythonEvalType + from pyspark.sql.functions import pandas_udf, PandasUDFType + + udf = pandas_udf(lambda x: x, DoubleType()) + self.assertEqual(udf.returnType, DoubleType()) + self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) + + udf = pandas_udf(lambda x: x, DoubleType(), PandasUDFType.SCALAR) + self.assertEqual(udf.returnType, DoubleType()) + self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) + + udf = pandas_udf(lambda x: x, 'double', PandasUDFType.SCALAR) + self.assertEqual(udf.returnType, DoubleType()) + self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) + + udf = pandas_udf(lambda x: x, StructType([StructField("v", DoubleType())]), + PandasUDFType.GROUPED_MAP) + self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())])) + self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) + + udf = pandas_udf(lambda x: x, 'v double', PandasUDFType.GROUPED_MAP) + self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())])) + self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) + + udf = pandas_udf(lambda x: x, 'v double', + functionType=PandasUDFType.GROUPED_MAP) + self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())])) + self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) + + udf = pandas_udf(lambda x: x, returnType='v double', + functionType=PandasUDFType.GROUPED_MAP) + self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())])) + self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) + + def test_pandas_udf_decorator(self): + from pyspark.rdd import PythonEvalType + from pyspark.sql.functions import pandas_udf, PandasUDFType + from pyspark.sql.types import StructType, StructField, DoubleType + + @pandas_udf(DoubleType()) + def foo(x): + return x + self.assertEqual(foo.returnType, DoubleType()) + self.assertEqual(foo.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) + + @pandas_udf(returnType=DoubleType()) + def foo(x): + return x + self.assertEqual(foo.returnType, DoubleType()) + self.assertEqual(foo.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) + + schema = StructType([StructField("v", DoubleType())]) + + @pandas_udf(schema, PandasUDFType.GROUPED_MAP) + def foo(x): + return x + self.assertEqual(foo.returnType, schema) + self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) + + @pandas_udf('v double', PandasUDFType.GROUPED_MAP) + def foo(x): + return x + self.assertEqual(foo.returnType, schema) + self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) + + @pandas_udf(schema, functionType=PandasUDFType.GROUPED_MAP) + def foo(x): + return x + self.assertEqual(foo.returnType, schema) + self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) + + @pandas_udf(returnType='double', functionType=PandasUDFType.SCALAR) + def foo(x): + return x + self.assertEqual(foo.returnType, DoubleType()) + self.assertEqual(foo.evalType, 
PythonEvalType.SQL_SCALAR_PANDAS_UDF) + + @pandas_udf(returnType=schema, functionType=PandasUDFType.GROUPED_MAP) + def foo(x): + return x + self.assertEqual(foo.returnType, schema) + self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) + + def test_udf_wrong_arg(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + with QuietTest(self.sc): + with self.assertRaises(ParseException): + @pandas_udf('blah') + def foo(x): + return x + with self.assertRaisesRegexp(ValueError, 'Invalid returnType.*None'): + @pandas_udf(functionType=PandasUDFType.SCALAR) + def foo(x): + return x + with self.assertRaisesRegexp(ValueError, 'Invalid functionType'): + @pandas_udf('double', 100) + def foo(x): + return x + + with self.assertRaisesRegexp(ValueError, '0-arg pandas_udfs.*not.*supported'): + pandas_udf(lambda: 1, LongType(), PandasUDFType.SCALAR) + with self.assertRaisesRegexp(ValueError, '0-arg pandas_udfs.*not.*supported'): + @pandas_udf(LongType(), PandasUDFType.SCALAR) + def zero_with_type(): + return 1 + + with self.assertRaisesRegexp(TypeError, 'Invalid returnType'): + @pandas_udf(returnType=PandasUDFType.GROUPED_MAP) + def foo(df): + return df + with self.assertRaisesRegexp(TypeError, 'Invalid returnType'): + @pandas_udf(returnType='double', functionType=PandasUDFType.GROUPED_MAP) + def foo(df): + return df + with self.assertRaisesRegexp(ValueError, 'Invalid function'): + @pandas_udf(returnType='k int, v double', functionType=PandasUDFType.GROUPED_MAP) + def foo(k, v, w): + return k + + def test_stopiteration_in_udf(self): + from pyspark.sql.functions import udf, pandas_udf, PandasUDFType + from py4j.protocol import Py4JJavaError + + def foo(x): + raise StopIteration() + + def foofoo(x, y): + raise StopIteration() + + exc_message = "Caught StopIteration thrown from user's code; failing the task" + df = self.spark.range(0, 100) + + # plain udf (test for SPARK-23754) + self.assertRaisesRegexp( + Py4JJavaError, + exc_message, + df.withColumn('v', udf(foo)('id')).collect + ) + + # pandas scalar udf + self.assertRaisesRegexp( + Py4JJavaError, + exc_message, + df.withColumn( + 'v', pandas_udf(foo, 'double', PandasUDFType.SCALAR)('id') + ).collect + ) + + # pandas grouped map + self.assertRaisesRegexp( + Py4JJavaError, + exc_message, + df.groupBy('id').apply( + pandas_udf(foo, df.schema, PandasUDFType.GROUPED_MAP) + ).collect + ) + + self.assertRaisesRegexp( + Py4JJavaError, + exc_message, + df.groupBy('id').apply( + pandas_udf(foofoo, df.schema, PandasUDFType.GROUPED_MAP) + ).collect + ) + + # pandas grouped agg + self.assertRaisesRegexp( + Py4JJavaError, + exc_message, + df.groupBy('id').agg( + pandas_udf(foo, 'double', PandasUDFType.GROUPED_AGG)('id') + ).collect + ) + + +if __name__ == "__main__": + from pyspark.sql.tests.test_pandas_udf import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py b/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py new file mode 100644 index 0000000000000..5383704434c85 --- /dev/null +++ b/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py @@ -0,0 +1,504 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest + +from pyspark.sql.types import * +from pyspark.sql.utils import AnalysisException +from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \ + pandas_requirement_message, pyarrow_requirement_message +from pyspark.testing.utils import QuietTest + + +@unittest.skipIf( + not have_pandas or not have_pyarrow, + pandas_requirement_message or pyarrow_requirement_message) +class GroupedAggPandasUDFTests(ReusedSQLTestCase): + + @property + def data(self): + from pyspark.sql.functions import array, explode, col, lit + return self.spark.range(10).toDF('id') \ + .withColumn("vs", array([lit(i * 1.0) + col('id') for i in range(20, 30)])) \ + .withColumn("v", explode(col('vs'))) \ + .drop('vs') \ + .withColumn('w', lit(1.0)) + + @property + def python_plus_one(self): + from pyspark.sql.functions import udf + + @udf('double') + def plus_one(v): + assert isinstance(v, (int, float)) + return v + 1 + return plus_one + + @property + def pandas_scalar_plus_two(self): + import pandas as pd + from pyspark.sql.functions import pandas_udf, PandasUDFType + + @pandas_udf('double', PandasUDFType.SCALAR) + def plus_two(v): + assert isinstance(v, pd.Series) + return v + 2 + return plus_two + + @property + def pandas_agg_mean_udf(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + @pandas_udf('double', PandasUDFType.GROUPED_AGG) + def avg(v): + return v.mean() + return avg + + @property + def pandas_agg_sum_udf(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + @pandas_udf('double', PandasUDFType.GROUPED_AGG) + def sum(v): + return v.sum() + return sum + + @property + def pandas_agg_weighted_mean_udf(self): + import numpy as np + from pyspark.sql.functions import pandas_udf, PandasUDFType + + @pandas_udf('double', PandasUDFType.GROUPED_AGG) + def weighted_mean(v, w): + return np.average(v, weights=w) + return weighted_mean + + def test_manual(self): + from pyspark.sql.functions import pandas_udf, array + + df = self.data + sum_udf = self.pandas_agg_sum_udf + mean_udf = self.pandas_agg_mean_udf + mean_arr_udf = pandas_udf( + self.pandas_agg_mean_udf.func, + ArrayType(self.pandas_agg_mean_udf.returnType), + self.pandas_agg_mean_udf.evalType) + + result1 = df.groupby('id').agg( + sum_udf(df.v), + mean_udf(df.v), + mean_arr_udf(array(df.v))).sort('id') + expected1 = self.spark.createDataFrame( + [[0, 245.0, 24.5, [24.5]], + [1, 255.0, 25.5, [25.5]], + [2, 265.0, 26.5, [26.5]], + [3, 275.0, 27.5, [27.5]], + [4, 285.0, 28.5, [28.5]], + [5, 295.0, 29.5, [29.5]], + [6, 305.0, 30.5, [30.5]], + [7, 315.0, 31.5, [31.5]], + [8, 325.0, 32.5, [32.5]], + [9, 335.0, 33.5, [33.5]]], + ['id', 'sum(v)', 'avg(v)', 'avg(array(v))']) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + + def test_basic(self): + from pyspark.sql.functions import col, lit, mean + + df = self.data + weighted_mean_udf = self.pandas_agg_weighted_mean_udf + + # Groupby one column and 
aggregate one UDF with literal + result1 = df.groupby('id').agg(weighted_mean_udf(df.v, lit(1.0))).sort('id') + expected1 = df.groupby('id').agg(mean(df.v).alias('weighted_mean(v, 1.0)')).sort('id') + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + + # Groupby one expression and aggregate one UDF with literal + result2 = df.groupby((col('id') + 1)).agg(weighted_mean_udf(df.v, lit(1.0)))\ + .sort(df.id + 1) + expected2 = df.groupby((col('id') + 1))\ + .agg(mean(df.v).alias('weighted_mean(v, 1.0)')).sort(df.id + 1) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + + # Groupby one column and aggregate one UDF without literal + result3 = df.groupby('id').agg(weighted_mean_udf(df.v, df.w)).sort('id') + expected3 = df.groupby('id').agg(mean(df.v).alias('weighted_mean(v, w)')).sort('id') + self.assertPandasEqual(expected3.toPandas(), result3.toPandas()) + + # Groupby one expression and aggregate one UDF without literal + result4 = df.groupby((col('id') + 1).alias('id'))\ + .agg(weighted_mean_udf(df.v, df.w))\ + .sort('id') + expected4 = df.groupby((col('id') + 1).alias('id'))\ + .agg(mean(df.v).alias('weighted_mean(v, w)'))\ + .sort('id') + self.assertPandasEqual(expected4.toPandas(), result4.toPandas()) + + def test_unsupported_types(self): + from pyspark.sql.types import DoubleType, MapType + from pyspark.sql.functions import pandas_udf, PandasUDFType + + with QuietTest(self.sc): + with self.assertRaisesRegexp(NotImplementedError, 'not supported'): + pandas_udf( + lambda x: x, + ArrayType(ArrayType(TimestampType())), + PandasUDFType.GROUPED_AGG) + + with QuietTest(self.sc): + with self.assertRaisesRegexp(NotImplementedError, 'not supported'): + @pandas_udf('mean double, std double', PandasUDFType.GROUPED_AGG) + def mean_and_std_udf(v): + return v.mean(), v.std() + + with QuietTest(self.sc): + with self.assertRaisesRegexp(NotImplementedError, 'not supported'): + @pandas_udf(MapType(DoubleType(), DoubleType()), PandasUDFType.GROUPED_AGG) + def mean_and_std_udf(v): + return {v.mean(): v.std()} + + def test_alias(self): + from pyspark.sql.functions import mean + + df = self.data + mean_udf = self.pandas_agg_mean_udf + + result1 = df.groupby('id').agg(mean_udf(df.v).alias('mean_alias')) + expected1 = df.groupby('id').agg(mean(df.v).alias('mean_alias')) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + + def test_mixed_sql(self): + """ + Test mixing group aggregate pandas UDF with sql expression. + """ + from pyspark.sql.functions import sum + + df = self.data + sum_udf = self.pandas_agg_sum_udf + + # Mix group aggregate pandas UDF with sql expression + result1 = (df.groupby('id') + .agg(sum_udf(df.v) + 1) + .sort('id')) + expected1 = (df.groupby('id') + .agg(sum(df.v) + 1) + .sort('id')) + + # Mix group aggregate pandas UDF with sql expression (order swapped) + result2 = (df.groupby('id') + .agg(sum_udf(df.v + 1)) + .sort('id')) + + expected2 = (df.groupby('id') + .agg(sum(df.v + 1)) + .sort('id')) + + # Wrap group aggregate pandas UDF with two sql expressions + result3 = (df.groupby('id') + .agg(sum_udf(df.v + 1) + 2) + .sort('id')) + expected3 = (df.groupby('id') + .agg(sum(df.v + 1) + 2) + .sort('id')) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + self.assertPandasEqual(expected3.toPandas(), result3.toPandas()) + + def test_mixed_udfs(self): + """ + Test mixing group aggregate pandas UDF with python UDF and scalar pandas UDF. 
+ """ + from pyspark.sql.functions import sum + + df = self.data + plus_one = self.python_plus_one + plus_two = self.pandas_scalar_plus_two + sum_udf = self.pandas_agg_sum_udf + + # Mix group aggregate pandas UDF and python UDF + result1 = (df.groupby('id') + .agg(plus_one(sum_udf(df.v))) + .sort('id')) + expected1 = (df.groupby('id') + .agg(plus_one(sum(df.v))) + .sort('id')) + + # Mix group aggregate pandas UDF and python UDF (order swapped) + result2 = (df.groupby('id') + .agg(sum_udf(plus_one(df.v))) + .sort('id')) + expected2 = (df.groupby('id') + .agg(sum(plus_one(df.v))) + .sort('id')) + + # Mix group aggregate pandas UDF and scalar pandas UDF + result3 = (df.groupby('id') + .agg(sum_udf(plus_two(df.v))) + .sort('id')) + expected3 = (df.groupby('id') + .agg(sum(plus_two(df.v))) + .sort('id')) + + # Mix group aggregate pandas UDF and scalar pandas UDF (order swapped) + result4 = (df.groupby('id') + .agg(plus_two(sum_udf(df.v))) + .sort('id')) + expected4 = (df.groupby('id') + .agg(plus_two(sum(df.v))) + .sort('id')) + + # Wrap group aggregate pandas UDF with two python UDFs and use python UDF in groupby + result5 = (df.groupby(plus_one(df.id)) + .agg(plus_one(sum_udf(plus_one(df.v)))) + .sort('plus_one(id)')) + expected5 = (df.groupby(plus_one(df.id)) + .agg(plus_one(sum(plus_one(df.v)))) + .sort('plus_one(id)')) + + # Wrap group aggregate pandas UDF with two scala pandas UDF and user scala pandas UDF in + # groupby + result6 = (df.groupby(plus_two(df.id)) + .agg(plus_two(sum_udf(plus_two(df.v)))) + .sort('plus_two(id)')) + expected6 = (df.groupby(plus_two(df.id)) + .agg(plus_two(sum(plus_two(df.v)))) + .sort('plus_two(id)')) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + self.assertPandasEqual(expected3.toPandas(), result3.toPandas()) + self.assertPandasEqual(expected4.toPandas(), result4.toPandas()) + self.assertPandasEqual(expected5.toPandas(), result5.toPandas()) + self.assertPandasEqual(expected6.toPandas(), result6.toPandas()) + + def test_multiple_udfs(self): + """ + Test multiple group aggregate pandas UDFs in one agg function. 
+ """ + from pyspark.sql.functions import sum, mean + + df = self.data + mean_udf = self.pandas_agg_mean_udf + sum_udf = self.pandas_agg_sum_udf + weighted_mean_udf = self.pandas_agg_weighted_mean_udf + + result1 = (df.groupBy('id') + .agg(mean_udf(df.v), + sum_udf(df.v), + weighted_mean_udf(df.v, df.w)) + .sort('id') + .toPandas()) + expected1 = (df.groupBy('id') + .agg(mean(df.v), + sum(df.v), + mean(df.v).alias('weighted_mean(v, w)')) + .sort('id') + .toPandas()) + + self.assertPandasEqual(expected1, result1) + + def test_complex_groupby(self): + from pyspark.sql.functions import sum + + df = self.data + sum_udf = self.pandas_agg_sum_udf + plus_one = self.python_plus_one + plus_two = self.pandas_scalar_plus_two + + # groupby one expression + result1 = df.groupby(df.v % 2).agg(sum_udf(df.v)) + expected1 = df.groupby(df.v % 2).agg(sum(df.v)) + + # empty groupby + result2 = df.groupby().agg(sum_udf(df.v)) + expected2 = df.groupby().agg(sum(df.v)) + + # groupby one column and one sql expression + result3 = df.groupby(df.id, df.v % 2).agg(sum_udf(df.v)).orderBy(df.id, df.v % 2) + expected3 = df.groupby(df.id, df.v % 2).agg(sum(df.v)).orderBy(df.id, df.v % 2) + + # groupby one python UDF + result4 = df.groupby(plus_one(df.id)).agg(sum_udf(df.v)) + expected4 = df.groupby(plus_one(df.id)).agg(sum(df.v)) + + # groupby one scalar pandas UDF + result5 = df.groupby(plus_two(df.id)).agg(sum_udf(df.v)) + expected5 = df.groupby(plus_two(df.id)).agg(sum(df.v)) + + # groupby one expression and one python UDF + result6 = df.groupby(df.v % 2, plus_one(df.id)).agg(sum_udf(df.v)) + expected6 = df.groupby(df.v % 2, plus_one(df.id)).agg(sum(df.v)) + + # groupby one expression and one scalar pandas UDF + result7 = df.groupby(df.v % 2, plus_two(df.id)).agg(sum_udf(df.v)).sort('sum(v)') + expected7 = df.groupby(df.v % 2, plus_two(df.id)).agg(sum(df.v)).sort('sum(v)') + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + self.assertPandasEqual(expected3.toPandas(), result3.toPandas()) + self.assertPandasEqual(expected4.toPandas(), result4.toPandas()) + self.assertPandasEqual(expected5.toPandas(), result5.toPandas()) + self.assertPandasEqual(expected6.toPandas(), result6.toPandas()) + self.assertPandasEqual(expected7.toPandas(), result7.toPandas()) + + def test_complex_expressions(self): + from pyspark.sql.functions import col, sum + + df = self.data + plus_one = self.python_plus_one + plus_two = self.pandas_scalar_plus_two + sum_udf = self.pandas_agg_sum_udf + + # Test complex expressions with sql expression, python UDF and + # group aggregate pandas UDF + result1 = (df.withColumn('v1', plus_one(df.v)) + .withColumn('v2', df.v + 2) + .groupby(df.id, df.v % 2) + .agg(sum_udf(col('v')), + sum_udf(col('v1') + 3), + sum_udf(col('v2')) + 5, + plus_one(sum_udf(col('v1'))), + sum_udf(plus_one(col('v2')))) + .sort('id') + .toPandas()) + + expected1 = (df.withColumn('v1', df.v + 1) + .withColumn('v2', df.v + 2) + .groupby(df.id, df.v % 2) + .agg(sum(col('v')), + sum(col('v1') + 3), + sum(col('v2')) + 5, + plus_one(sum(col('v1'))), + sum(plus_one(col('v2')))) + .sort('id') + .toPandas()) + + # Test complex expressions with sql expression, scala pandas UDF and + # group aggregate pandas UDF + result2 = (df.withColumn('v1', plus_one(df.v)) + .withColumn('v2', df.v + 2) + .groupby(df.id, df.v % 2) + .agg(sum_udf(col('v')), + sum_udf(col('v1') + 3), + sum_udf(col('v2')) + 5, + plus_two(sum_udf(col('v1'))), + sum_udf(plus_two(col('v2')))) + 
.sort('id') + .toPandas()) + + expected2 = (df.withColumn('v1', df.v + 1) + .withColumn('v2', df.v + 2) + .groupby(df.id, df.v % 2) + .agg(sum(col('v')), + sum(col('v1') + 3), + sum(col('v2')) + 5, + plus_two(sum(col('v1'))), + sum(plus_two(col('v2')))) + .sort('id') + .toPandas()) + + # Test sequential groupby aggregate + result3 = (df.groupby('id') + .agg(sum_udf(df.v).alias('v')) + .groupby('id') + .agg(sum_udf(col('v'))) + .sort('id') + .toPandas()) + + expected3 = (df.groupby('id') + .agg(sum(df.v).alias('v')) + .groupby('id') + .agg(sum(col('v'))) + .sort('id') + .toPandas()) + + self.assertPandasEqual(expected1, result1) + self.assertPandasEqual(expected2, result2) + self.assertPandasEqual(expected3, result3) + + def test_retain_group_columns(self): + from pyspark.sql.functions import sum + with self.sql_conf({"spark.sql.retainGroupColumns": False}): + df = self.data + sum_udf = self.pandas_agg_sum_udf + + result1 = df.groupby(df.id).agg(sum_udf(df.v)) + expected1 = df.groupby(df.id).agg(sum(df.v)) + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + + def test_array_type(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + df = self.data + + array_udf = pandas_udf(lambda x: [1.0, 2.0], 'array', PandasUDFType.GROUPED_AGG) + result1 = df.groupby('id').agg(array_udf(df['v']).alias('v2')) + self.assertEquals(result1.first()['v2'], [1.0, 2.0]) + + def test_invalid_args(self): + from pyspark.sql.functions import mean + + df = self.data + plus_one = self.python_plus_one + mean_udf = self.pandas_agg_mean_udf + + with QuietTest(self.sc): + with self.assertRaisesRegexp( + AnalysisException, + 'nor.*aggregate function'): + df.groupby(df.id).agg(plus_one(df.v)).collect() + + with QuietTest(self.sc): + with self.assertRaisesRegexp( + AnalysisException, + 'aggregate function.*argument.*aggregate function'): + df.groupby(df.id).agg(mean_udf(mean_udf(df.v))).collect() + + with QuietTest(self.sc): + with self.assertRaisesRegexp( + AnalysisException, + 'mixture.*aggregate function.*group aggregate pandas UDF'): + df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect() + + def test_register_vectorized_udf_basic(self): + from pyspark.sql.functions import pandas_udf + from pyspark.rdd import PythonEvalType + + sum_pandas_udf = pandas_udf( + lambda v: v.sum(), "integer", PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF) + + self.assertEqual(sum_pandas_udf.evalType, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF) + group_agg_pandas_udf = self.spark.udf.register("sum_pandas_udf", sum_pandas_udf) + self.assertEqual(group_agg_pandas_udf.evalType, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF) + q = "SELECT sum_pandas_udf(v1) FROM VALUES (3, 0), (2, 0), (1, 1) tbl(v1, v2) GROUP BY v2" + actual = sorted(map(lambda r: r[0], self.spark.sql(q).collect())) + expected = [1, 5] + self.assertEqual(actual, expected) + + +if __name__ == "__main__": + from pyspark.sql.tests.test_pandas_udf_grouped_agg import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py b/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py new file mode 100644 index 0000000000000..bfecc071386e9 --- /dev/null +++ b/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py @@ -0,0 +1,531 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import datetime +import unittest + +from pyspark.sql import Row +from pyspark.sql.types import * +from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \ + pandas_requirement_message, pyarrow_requirement_message +from pyspark.testing.utils import QuietTest + + +@unittest.skipIf( + not have_pandas or not have_pyarrow, + pandas_requirement_message or pyarrow_requirement_message) +class GroupedMapPandasUDFTests(ReusedSQLTestCase): + + @property + def data(self): + from pyspark.sql.functions import array, explode, col, lit + return self.spark.range(10).toDF('id') \ + .withColumn("vs", array([lit(i) for i in range(20, 30)])) \ + .withColumn("v", explode(col('vs'))).drop('vs') + + def test_supported_types(self): + from decimal import Decimal + from distutils.version import LooseVersion + import pyarrow as pa + from pyspark.sql.functions import pandas_udf, PandasUDFType + + values = [ + 1, 2, 3, + 4, 5, 1.1, + 2.2, Decimal(1.123), + [1, 2, 2], True, 'hello' + ] + output_fields = [ + ('id', IntegerType()), ('byte', ByteType()), ('short', ShortType()), + ('int', IntegerType()), ('long', LongType()), ('float', FloatType()), + ('double', DoubleType()), ('decim', DecimalType(10, 3)), + ('array', ArrayType(IntegerType())), ('bool', BooleanType()), ('str', StringType()) + ] + + # TODO: Add BinaryType to variables above once minimum pyarrow version is 0.10.0 + if LooseVersion(pa.__version__) >= LooseVersion("0.10.0"): + values.append(bytearray([0x01, 0x02])) + output_fields.append(('bin', BinaryType())) + + output_schema = StructType([StructField(*x) for x in output_fields]) + df = self.spark.createDataFrame([values], schema=output_schema) + + # Different forms of group map pandas UDF, results of these are the same + udf1 = pandas_udf( + lambda pdf: pdf.assign( + byte=pdf.byte * 2, + short=pdf.short * 2, + int=pdf.int * 2, + long=pdf.long * 2, + float=pdf.float * 2, + double=pdf.double * 2, + decim=pdf.decim * 2, + bool=False if pdf.bool else True, + str=pdf.str + 'there', + array=pdf.array, + ), + output_schema, + PandasUDFType.GROUPED_MAP + ) + + udf2 = pandas_udf( + lambda _, pdf: pdf.assign( + byte=pdf.byte * 2, + short=pdf.short * 2, + int=pdf.int * 2, + long=pdf.long * 2, + float=pdf.float * 2, + double=pdf.double * 2, + decim=pdf.decim * 2, + bool=False if pdf.bool else True, + str=pdf.str + 'there', + array=pdf.array, + ), + output_schema, + PandasUDFType.GROUPED_MAP + ) + + udf3 = pandas_udf( + lambda key, pdf: pdf.assign( + id=key[0], + byte=pdf.byte * 2, + short=pdf.short * 2, + int=pdf.int * 2, + long=pdf.long * 2, + float=pdf.float * 2, + double=pdf.double * 2, + decim=pdf.decim * 2, + bool=False if pdf.bool else True, + str=pdf.str + 'there', + array=pdf.array, + ), + output_schema, + PandasUDFType.GROUPED_MAP + ) + + result1 = df.groupby('id').apply(udf1).sort('id').toPandas() + 
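Aside (illustrative only, not part of this patch): udf1, udf2 and udf3 above exercise the two function signatures a GROUPED_MAP pandas UDF may take, f(pdf) and f(key, pdf), where key is a tuple of the grouping values. A minimal standalone sketch of the API under test, assuming only an existing SparkSession named spark:

    from pyspark.sql.functions import pandas_udf, PandasUDFType

    @pandas_udf('id long, v double', PandasUDFType.GROUPED_MAP)
    def subtract_mean(pdf):
        # pdf is a pandas.DataFrame holding one group; return a DataFrame matching the schema
        return pdf.assign(v=pdf.v - pdf.v.mean())

    df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0)], ['id', 'v'])
    df.groupby('id').apply(subtract_mean).show()

The grouped result is assembled from whatever DataFrame each invocation returns, which is why these tests compare against pandas' own groupby().apply() as the expected value.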
expected1 = df.toPandas().groupby('id').apply(udf1.func).reset_index(drop=True) + + result2 = df.groupby('id').apply(udf2).sort('id').toPandas() + expected2 = expected1 + + result3 = df.groupby('id').apply(udf3).sort('id').toPandas() + expected3 = expected1 + + self.assertPandasEqual(expected1, result1) + self.assertPandasEqual(expected2, result2) + self.assertPandasEqual(expected3, result3) + + def test_array_type_correct(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType, array, col + + df = self.data.withColumn("arr", array(col("id"))).repartition(1, "id") + + output_schema = StructType( + [StructField('id', LongType()), + StructField('v', IntegerType()), + StructField('arr', ArrayType(LongType()))]) + + udf = pandas_udf( + lambda pdf: pdf, + output_schema, + PandasUDFType.GROUPED_MAP + ) + + result = df.groupby('id').apply(udf).sort('id').toPandas() + expected = df.toPandas().groupby('id').apply(udf.func).reset_index(drop=True) + self.assertPandasEqual(expected, result) + + def test_register_grouped_map_udf(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + foo_udf = pandas_udf(lambda x: x, "id long", PandasUDFType.GROUPED_MAP) + with QuietTest(self.sc): + with self.assertRaisesRegexp( + ValueError, + 'f.*SQL_BATCHED_UDF.*SQL_SCALAR_PANDAS_UDF.*SQL_GROUPED_AGG_PANDAS_UDF.*'): + self.spark.catalog.registerFunction("foo_udf", foo_udf) + + def test_decorator(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + df = self.data + + @pandas_udf( + 'id long, v int, v1 double, v2 long', + PandasUDFType.GROUPED_MAP + ) + def foo(pdf): + return pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id) + + result = df.groupby('id').apply(foo).sort('id').toPandas() + expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True) + self.assertPandasEqual(expected, result) + + def test_coerce(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + df = self.data + + foo = pandas_udf( + lambda pdf: pdf, + 'id long, v double', + PandasUDFType.GROUPED_MAP + ) + + result = df.groupby('id').apply(foo).sort('id').toPandas() + expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True) + expected = expected.assign(v=expected.v.astype('float64')) + self.assertPandasEqual(expected, result) + + def test_complex_groupby(self): + from pyspark.sql.functions import pandas_udf, col, PandasUDFType + df = self.data + + @pandas_udf( + 'id long, v int, norm double', + PandasUDFType.GROUPED_MAP + ) + def normalize(pdf): + v = pdf.v + return pdf.assign(norm=(v - v.mean()) / v.std()) + + result = df.groupby(col('id') % 2 == 0).apply(normalize).sort('id', 'v').toPandas() + pdf = df.toPandas() + expected = pdf.groupby(pdf['id'] % 2 == 0).apply(normalize.func) + expected = expected.sort_values(['id', 'v']).reset_index(drop=True) + expected = expected.assign(norm=expected.norm.astype('float64')) + self.assertPandasEqual(expected, result) + + def test_empty_groupby(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + df = self.data + + @pandas_udf( + 'id long, v int, norm double', + PandasUDFType.GROUPED_MAP + ) + def normalize(pdf): + v = pdf.v + return pdf.assign(norm=(v - v.mean()) / v.std()) + + result = df.groupby().apply(normalize).sort('id', 'v').toPandas() + pdf = df.toPandas() + expected = normalize.func(pdf) + expected = expected.sort_values(['id', 'v']).reset_index(drop=True) + expected = expected.assign(norm=expected.norm.astype('float64')) + self.assertPandasEqual(expected, result) + + def 
test_datatype_string(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + df = self.data + + foo_udf = pandas_udf( + lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id), + 'id long, v int, v1 double, v2 long', + PandasUDFType.GROUPED_MAP + ) + + result = df.groupby('id').apply(foo_udf).sort('id').toPandas() + expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True) + self.assertPandasEqual(expected, result) + + def test_wrong_return_type(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + with QuietTest(self.sc): + with self.assertRaisesRegexp( + NotImplementedError, + 'Invalid returnType.*grouped map Pandas UDF.*MapType'): + pandas_udf( + lambda pdf: pdf, + 'id long, v map', + PandasUDFType.GROUPED_MAP) + + def test_wrong_args(self): + from pyspark.sql.functions import udf, pandas_udf, sum, PandasUDFType + df = self.data + + with QuietTest(self.sc): + with self.assertRaisesRegexp(ValueError, 'Invalid udf'): + df.groupby('id').apply(lambda x: x) + with self.assertRaisesRegexp(ValueError, 'Invalid udf'): + df.groupby('id').apply(udf(lambda x: x, DoubleType())) + with self.assertRaisesRegexp(ValueError, 'Invalid udf'): + df.groupby('id').apply(sum(df.v)) + with self.assertRaisesRegexp(ValueError, 'Invalid udf'): + df.groupby('id').apply(df.v + 1) + with self.assertRaisesRegexp(ValueError, 'Invalid function'): + df.groupby('id').apply( + pandas_udf(lambda: 1, StructType([StructField("d", DoubleType())]))) + with self.assertRaisesRegexp(ValueError, 'Invalid udf'): + df.groupby('id').apply(pandas_udf(lambda x, y: x, DoubleType())) + with self.assertRaisesRegexp(ValueError, 'Invalid udf.*GROUPED_MAP'): + df.groupby('id').apply( + pandas_udf(lambda x, y: x, DoubleType(), PandasUDFType.SCALAR)) + + def test_unsupported_types(self): + from distutils.version import LooseVersion + import pyarrow as pa + from pyspark.sql.functions import pandas_udf, PandasUDFType + + common_err_msg = 'Invalid returnType.*grouped map Pandas UDF.*' + unsupported_types = [ + StructField('map', MapType(StringType(), IntegerType())), + StructField('arr_ts', ArrayType(TimestampType())), + StructField('null', NullType()), + ] + + # TODO: Remove this if-statement once minimum pyarrow version is 0.10.0 + if LooseVersion(pa.__version__) < LooseVersion("0.10.0"): + unsupported_types.append(StructField('bin', BinaryType())) + + for unsupported_type in unsupported_types: + schema = StructType([StructField('id', LongType(), True), unsupported_type]) + with QuietTest(self.sc): + with self.assertRaisesRegexp(NotImplementedError, common_err_msg): + pandas_udf(lambda x: x, schema, PandasUDFType.GROUPED_MAP) + + # Regression test for SPARK-23314 + def test_timestamp_dst(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + # Daylight saving time for Los Angeles for 2015 is Sun, Nov 1 at 2:00 am + dt = [datetime.datetime(2015, 11, 1, 0, 30), + datetime.datetime(2015, 11, 1, 1, 30), + datetime.datetime(2015, 11, 1, 2, 30)] + df = self.spark.createDataFrame(dt, 'timestamp').toDF('time') + foo_udf = pandas_udf(lambda pdf: pdf, 'time timestamp', PandasUDFType.GROUPED_MAP) + result = df.groupby('time').apply(foo_udf).sort('time') + self.assertPandasEqual(df.toPandas(), result.toPandas()) + + def test_udf_with_key(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + df = self.data + pdf = df.toPandas() + + def foo1(key, pdf): + import numpy as np + assert type(key) == tuple + assert type(key[0]) == np.int64 + + return 
pdf.assign(v1=key[0], + v2=pdf.v * key[0], + v3=pdf.v * pdf.id, + v4=pdf.v * pdf.id.mean()) + + def foo2(key, pdf): + import numpy as np + assert type(key) == tuple + assert type(key[0]) == np.int64 + assert type(key[1]) == np.int32 + + return pdf.assign(v1=key[0], + v2=key[1], + v3=pdf.v * key[0], + v4=pdf.v + key[1]) + + def foo3(key, pdf): + assert type(key) == tuple + assert len(key) == 0 + return pdf.assign(v1=pdf.v * pdf.id) + + # v2 is int because numpy.int64 * pd.Series results in pd.Series + # v3 is long because pd.Series * pd.Series results in pd.Series + udf1 = pandas_udf( + foo1, + 'id long, v int, v1 long, v2 int, v3 long, v4 double', + PandasUDFType.GROUPED_MAP) + + udf2 = pandas_udf( + foo2, + 'id long, v int, v1 long, v2 int, v3 int, v4 int', + PandasUDFType.GROUPED_MAP) + + udf3 = pandas_udf( + foo3, + 'id long, v int, v1 long', + PandasUDFType.GROUPED_MAP) + + # Test groupby column + result1 = df.groupby('id').apply(udf1).sort('id', 'v').toPandas() + expected1 = pdf.groupby('id')\ + .apply(lambda x: udf1.func((x.id.iloc[0],), x))\ + .sort_values(['id', 'v']).reset_index(drop=True) + self.assertPandasEqual(expected1, result1) + + # Test groupby expression + result2 = df.groupby(df.id % 2).apply(udf1).sort('id', 'v').toPandas() + expected2 = pdf.groupby(pdf.id % 2)\ + .apply(lambda x: udf1.func((x.id.iloc[0] % 2,), x))\ + .sort_values(['id', 'v']).reset_index(drop=True) + self.assertPandasEqual(expected2, result2) + + # Test complex groupby + result3 = df.groupby(df.id, df.v % 2).apply(udf2).sort('id', 'v').toPandas() + expected3 = pdf.groupby([pdf.id, pdf.v % 2])\ + .apply(lambda x: udf2.func((x.id.iloc[0], (x.v % 2).iloc[0],), x))\ + .sort_values(['id', 'v']).reset_index(drop=True) + self.assertPandasEqual(expected3, result3) + + # Test empty groupby + result4 = df.groupby().apply(udf3).sort('id', 'v').toPandas() + expected4 = udf3.func((), pdf) + self.assertPandasEqual(expected4, result4) + + def test_column_order(self): + from collections import OrderedDict + import pandas as pd + from pyspark.sql.functions import pandas_udf, PandasUDFType + + # Helper function to set column names from a list + def rename_pdf(pdf, names): + pdf.rename(columns={old: new for old, new in + zip(pd_result.columns, names)}, inplace=True) + + df = self.data + grouped_df = df.groupby('id') + grouped_pdf = df.toPandas().groupby('id') + + # Function returns a pdf with required column names, but order could be arbitrary using dict + def change_col_order(pdf): + # Constructing a DataFrame from a dict should result in the same order, + # but use from_items to ensure the pdf column order is different than schema + return pd.DataFrame.from_items([ + ('id', pdf.id), + ('u', pdf.v * 2), + ('v', pdf.v)]) + + ordered_udf = pandas_udf( + change_col_order, + 'id long, v int, u int', + PandasUDFType.GROUPED_MAP + ) + + # The UDF result should assign columns by name from the pdf + result = grouped_df.apply(ordered_udf).sort('id', 'v')\ + .select('id', 'u', 'v').toPandas() + pd_result = grouped_pdf.apply(change_col_order) + expected = pd_result.sort_values(['id', 'v']).reset_index(drop=True) + self.assertPandasEqual(expected, result) + + # Function returns a pdf with positional columns, indexed by range + def range_col_order(pdf): + # Create a DataFrame with positional columns, fix types to long + return pd.DataFrame(list(zip(pdf.id, pdf.v * 3, pdf.v)), dtype='int64') + + range_udf = pandas_udf( + range_col_order, + 'id long, u long, v long', + PandasUDFType.GROUPED_MAP + ) + + # The UDF result uses positional 
columns from the pdf + result = grouped_df.apply(range_udf).sort('id', 'v') \ + .select('id', 'u', 'v').toPandas() + pd_result = grouped_pdf.apply(range_col_order) + rename_pdf(pd_result, ['id', 'u', 'v']) + expected = pd_result.sort_values(['id', 'v']).reset_index(drop=True) + self.assertPandasEqual(expected, result) + + # Function returns a pdf with columns indexed with integers + def int_index(pdf): + return pd.DataFrame(OrderedDict([(0, pdf.id), (1, pdf.v * 4), (2, pdf.v)])) + + int_index_udf = pandas_udf( + int_index, + 'id long, u int, v int', + PandasUDFType.GROUPED_MAP + ) + + # The UDF result should assign columns by position of integer index + result = grouped_df.apply(int_index_udf).sort('id', 'v') \ + .select('id', 'u', 'v').toPandas() + pd_result = grouped_pdf.apply(int_index) + rename_pdf(pd_result, ['id', 'u', 'v']) + expected = pd_result.sort_values(['id', 'v']).reset_index(drop=True) + self.assertPandasEqual(expected, result) + + @pandas_udf('id long, v int', PandasUDFType.GROUPED_MAP) + def column_name_typo(pdf): + return pd.DataFrame({'iid': pdf.id, 'v': pdf.v}) + + @pandas_udf('id long, v int', PandasUDFType.GROUPED_MAP) + def invalid_positional_types(pdf): + return pd.DataFrame([(u'a', 1.2)]) + + with QuietTest(self.sc): + with self.assertRaisesRegexp(Exception, "KeyError: 'id'"): + grouped_df.apply(column_name_typo).collect() + with self.assertRaisesRegexp(Exception, "No cast implemented"): + grouped_df.apply(invalid_positional_types).collect() + + def test_positional_assignment_conf(self): + import pandas as pd + from pyspark.sql.functions import pandas_udf, PandasUDFType + + with self.sql_conf({ + "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName": False}): + + @pandas_udf("a string, b float", PandasUDFType.GROUPED_MAP) + def foo(_): + return pd.DataFrame([('hi', 1)], columns=['x', 'y']) + + df = self.data + result = df.groupBy('id').apply(foo).select('a', 'b').collect() + for r in result: + self.assertEqual(r.a, 'hi') + self.assertEqual(r.b, 1) + + def test_self_join_with_pandas(self): + import pyspark.sql.functions as F + + @F.pandas_udf('key long, col string', F.PandasUDFType.GROUPED_MAP) + def dummy_pandas_udf(df): + return df[['key', 'col']] + + df = self.spark.createDataFrame([Row(key=1, col='A'), Row(key=1, col='B'), + Row(key=2, col='C')]) + df_with_pandas = df.groupBy('key').apply(dummy_pandas_udf) + + # this was throwing an AnalysisException before SPARK-24208 + res = df_with_pandas.alias('temp0').join(df_with_pandas.alias('temp1'), + F.col('temp0.key') == F.col('temp1.key')) + self.assertEquals(res.count(), 5) + + def test_mixed_scalar_udfs_followed_by_grouby_apply(self): + import pandas as pd + from pyspark.sql.functions import udf, pandas_udf, PandasUDFType + + df = self.spark.range(0, 10).toDF('v1') + df = df.withColumn('v2', udf(lambda x: x + 1, 'int')(df['v1'])) \ + .withColumn('v3', pandas_udf(lambda x: x + 2, 'int')(df['v1'])) + + result = df.groupby() \ + .apply(pandas_udf(lambda x: pd.DataFrame([x.sum().sum()]), + 'sum int', + PandasUDFType.GROUPED_MAP)) + + self.assertEquals(result.collect()[0]['sum'], 165) + + +if __name__ == "__main__": + from pyspark.sql.tests.test_pandas_udf_grouped_map import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/test_pandas_udf_scalar.py new file mode 100644 index 
0000000000000..b303398850394 --- /dev/null +++ b/python/pyspark/sql/tests/test_pandas_udf_scalar.py @@ -0,0 +1,808 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import datetime +import os +import shutil +import sys +import tempfile +import time +import unittest + +from pyspark.sql.types import Row +from pyspark.sql.types import * +from pyspark.sql.utils import AnalysisException +from pyspark.testing.sqlutils import ReusedSQLTestCase, test_compiled,\ + test_not_compiled_message, have_pandas, have_pyarrow, pandas_requirement_message, \ + pyarrow_requirement_message +from pyspark.testing.utils import QuietTest + + +@unittest.skipIf( + not have_pandas or not have_pyarrow, + pandas_requirement_message or pyarrow_requirement_message) +class ScalarPandasUDFTests(ReusedSQLTestCase): + + @classmethod + def setUpClass(cls): + ReusedSQLTestCase.setUpClass() + + # Synchronize default timezone between Python and Java + cls.tz_prev = os.environ.get("TZ", None) # save current tz if set + tz = "UTC" + os.environ["TZ"] = tz + time.tzset() + + cls.sc.environment["TZ"] = tz + cls.spark.conf.set("spark.sql.session.timeZone", tz) + + @classmethod + def tearDownClass(cls): + del os.environ["TZ"] + if cls.tz_prev is not None: + os.environ["TZ"] = cls.tz_prev + time.tzset() + ReusedSQLTestCase.tearDownClass() + + @property + def nondeterministic_vectorized_udf(self): + from pyspark.sql.functions import pandas_udf + + @pandas_udf('double') + def random_udf(v): + import pandas as pd + import numpy as np + return pd.Series(np.random.random(len(v))) + random_udf = random_udf.asNondeterministic() + return random_udf + + def test_pandas_udf_tokenize(self): + from pyspark.sql.functions import pandas_udf + tokenize = pandas_udf(lambda s: s.apply(lambda str: str.split(' ')), + ArrayType(StringType())) + self.assertEqual(tokenize.returnType, ArrayType(StringType())) + df = self.spark.createDataFrame([("hi boo",), ("bye boo",)], ["vals"]) + result = df.select(tokenize("vals").alias("hi")) + self.assertEqual([Row(hi=[u'hi', u'boo']), Row(hi=[u'bye', u'boo'])], result.collect()) + + def test_pandas_udf_nested_arrays(self): + from pyspark.sql.functions import pandas_udf + tokenize = pandas_udf(lambda s: s.apply(lambda str: [str.split(' ')]), + ArrayType(ArrayType(StringType()))) + self.assertEqual(tokenize.returnType, ArrayType(ArrayType(StringType()))) + df = self.spark.createDataFrame([("hi boo",), ("bye boo",)], ["vals"]) + result = df.select(tokenize("vals").alias("hi")) + self.assertEqual([Row(hi=[[u'hi', u'boo']]), Row(hi=[[u'bye', u'boo']])], result.collect()) + + def test_vectorized_udf_basic(self): + from pyspark.sql.functions import pandas_udf, col, array + df = self.spark.range(10).select( + col('id').cast('string').alias('str'), + col('id').cast('int').alias('int'), + 
col('id').alias('long'), + col('id').cast('float').alias('float'), + col('id').cast('double').alias('double'), + col('id').cast('decimal').alias('decimal'), + col('id').cast('boolean').alias('bool'), + array(col('id')).alias('array_long')) + f = lambda x: x + str_f = pandas_udf(f, StringType()) + int_f = pandas_udf(f, IntegerType()) + long_f = pandas_udf(f, LongType()) + float_f = pandas_udf(f, FloatType()) + double_f = pandas_udf(f, DoubleType()) + decimal_f = pandas_udf(f, DecimalType()) + bool_f = pandas_udf(f, BooleanType()) + array_long_f = pandas_udf(f, ArrayType(LongType())) + res = df.select(str_f(col('str')), int_f(col('int')), + long_f(col('long')), float_f(col('float')), + double_f(col('double')), decimal_f('decimal'), + bool_f(col('bool')), array_long_f('array_long')) + self.assertEquals(df.collect(), res.collect()) + + def test_register_nondeterministic_vectorized_udf_basic(self): + from pyspark.sql.functions import pandas_udf + from pyspark.rdd import PythonEvalType + import random + random_pandas_udf = pandas_udf( + lambda x: random.randint(6, 6) + x, IntegerType()).asNondeterministic() + self.assertEqual(random_pandas_udf.deterministic, False) + self.assertEqual(random_pandas_udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) + nondeterministic_pandas_udf = self.spark.catalog.registerFunction( + "randomPandasUDF", random_pandas_udf) + self.assertEqual(nondeterministic_pandas_udf.deterministic, False) + self.assertEqual(nondeterministic_pandas_udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) + [row] = self.spark.sql("SELECT randomPandasUDF(1)").collect() + self.assertEqual(row[0], 7) + + def test_vectorized_udf_null_boolean(self): + from pyspark.sql.functions import pandas_udf, col + data = [(True,), (True,), (None,), (False,)] + schema = StructType().add("bool", BooleanType()) + df = self.spark.createDataFrame(data, schema) + bool_f = pandas_udf(lambda x: x, BooleanType()) + res = df.select(bool_f(col('bool'))) + self.assertEquals(df.collect(), res.collect()) + + def test_vectorized_udf_null_byte(self): + from pyspark.sql.functions import pandas_udf, col + data = [(None,), (2,), (3,), (4,)] + schema = StructType().add("byte", ByteType()) + df = self.spark.createDataFrame(data, schema) + byte_f = pandas_udf(lambda x: x, ByteType()) + res = df.select(byte_f(col('byte'))) + self.assertEquals(df.collect(), res.collect()) + + def test_vectorized_udf_null_short(self): + from pyspark.sql.functions import pandas_udf, col + data = [(None,), (2,), (3,), (4,)] + schema = StructType().add("short", ShortType()) + df = self.spark.createDataFrame(data, schema) + short_f = pandas_udf(lambda x: x, ShortType()) + res = df.select(short_f(col('short'))) + self.assertEquals(df.collect(), res.collect()) + + def test_vectorized_udf_null_int(self): + from pyspark.sql.functions import pandas_udf, col + data = [(None,), (2,), (3,), (4,)] + schema = StructType().add("int", IntegerType()) + df = self.spark.createDataFrame(data, schema) + int_f = pandas_udf(lambda x: x, IntegerType()) + res = df.select(int_f(col('int'))) + self.assertEquals(df.collect(), res.collect()) + + def test_vectorized_udf_null_long(self): + from pyspark.sql.functions import pandas_udf, col + data = [(None,), (2,), (3,), (4,)] + schema = StructType().add("long", LongType()) + df = self.spark.createDataFrame(data, schema) + long_f = pandas_udf(lambda x: x, LongType()) + res = df.select(long_f(col('long'))) + self.assertEquals(df.collect(), res.collect()) + + def test_vectorized_udf_null_float(self): + from 
pyspark.sql.functions import pandas_udf, col + data = [(3.0,), (5.0,), (-1.0,), (None,)] + schema = StructType().add("float", FloatType()) + df = self.spark.createDataFrame(data, schema) + float_f = pandas_udf(lambda x: x, FloatType()) + res = df.select(float_f(col('float'))) + self.assertEquals(df.collect(), res.collect()) + + def test_vectorized_udf_null_double(self): + from pyspark.sql.functions import pandas_udf, col + data = [(3.0,), (5.0,), (-1.0,), (None,)] + schema = StructType().add("double", DoubleType()) + df = self.spark.createDataFrame(data, schema) + double_f = pandas_udf(lambda x: x, DoubleType()) + res = df.select(double_f(col('double'))) + self.assertEquals(df.collect(), res.collect()) + + def test_vectorized_udf_null_decimal(self): + from decimal import Decimal + from pyspark.sql.functions import pandas_udf, col + data = [(Decimal(3.0),), (Decimal(5.0),), (Decimal(-1.0),), (None,)] + schema = StructType().add("decimal", DecimalType(38, 18)) + df = self.spark.createDataFrame(data, schema) + decimal_f = pandas_udf(lambda x: x, DecimalType(38, 18)) + res = df.select(decimal_f(col('decimal'))) + self.assertEquals(df.collect(), res.collect()) + + def test_vectorized_udf_null_string(self): + from pyspark.sql.functions import pandas_udf, col + data = [("foo",), (None,), ("bar",), ("bar",)] + schema = StructType().add("str", StringType()) + df = self.spark.createDataFrame(data, schema) + str_f = pandas_udf(lambda x: x, StringType()) + res = df.select(str_f(col('str'))) + self.assertEquals(df.collect(), res.collect()) + + def test_vectorized_udf_string_in_udf(self): + from pyspark.sql.functions import pandas_udf, col + import pandas as pd + df = self.spark.range(10) + str_f = pandas_udf(lambda x: pd.Series(map(str, x)), StringType()) + actual = df.select(str_f(col('id'))) + expected = df.select(col('id').cast('string')) + self.assertEquals(expected.collect(), actual.collect()) + + def test_vectorized_udf_datatype_string(self): + from pyspark.sql.functions import pandas_udf, col + df = self.spark.range(10).select( + col('id').cast('string').alias('str'), + col('id').cast('int').alias('int'), + col('id').alias('long'), + col('id').cast('float').alias('float'), + col('id').cast('double').alias('double'), + col('id').cast('decimal').alias('decimal'), + col('id').cast('boolean').alias('bool')) + f = lambda x: x + str_f = pandas_udf(f, 'string') + int_f = pandas_udf(f, 'integer') + long_f = pandas_udf(f, 'long') + float_f = pandas_udf(f, 'float') + double_f = pandas_udf(f, 'double') + decimal_f = pandas_udf(f, 'decimal(38, 18)') + bool_f = pandas_udf(f, 'boolean') + res = df.select(str_f(col('str')), int_f(col('int')), + long_f(col('long')), float_f(col('float')), + double_f(col('double')), decimal_f('decimal'), + bool_f(col('bool'))) + self.assertEquals(df.collect(), res.collect()) + + def test_vectorized_udf_null_binary(self): + from distutils.version import LooseVersion + import pyarrow as pa + from pyspark.sql.functions import pandas_udf, col + if LooseVersion(pa.__version__) < LooseVersion("0.10.0"): + with QuietTest(self.sc): + with self.assertRaisesRegexp( + NotImplementedError, + 'Invalid returnType.*scalar Pandas UDF.*BinaryType'): + pandas_udf(lambda x: x, BinaryType()) + else: + data = [(bytearray(b"a"),), (None,), (bytearray(b"bb"),), (bytearray(b"ccc"),)] + schema = StructType().add("binary", BinaryType()) + df = self.spark.createDataFrame(data, schema) + str_f = pandas_udf(lambda x: x, BinaryType()) + res = df.select(str_f(col('binary'))) + self.assertEquals(df.collect(), 
res.collect()) + + def test_vectorized_udf_array_type(self): + from pyspark.sql.functions import pandas_udf, col + data = [([1, 2],), ([3, 4],)] + array_schema = StructType([StructField("array", ArrayType(IntegerType()))]) + df = self.spark.createDataFrame(data, schema=array_schema) + array_f = pandas_udf(lambda x: x, ArrayType(IntegerType())) + result = df.select(array_f(col('array'))) + self.assertEquals(df.collect(), result.collect()) + + def test_vectorized_udf_null_array(self): + from pyspark.sql.functions import pandas_udf, col + data = [([1, 2],), (None,), (None,), ([3, 4],), (None,)] + array_schema = StructType([StructField("array", ArrayType(IntegerType()))]) + df = self.spark.createDataFrame(data, schema=array_schema) + array_f = pandas_udf(lambda x: x, ArrayType(IntegerType())) + result = df.select(array_f(col('array'))) + self.assertEquals(df.collect(), result.collect()) + + def test_vectorized_udf_complex(self): + from pyspark.sql.functions import pandas_udf, col, expr + df = self.spark.range(10).select( + col('id').cast('int').alias('a'), + col('id').cast('int').alias('b'), + col('id').cast('double').alias('c')) + add = pandas_udf(lambda x, y: x + y, IntegerType()) + power2 = pandas_udf(lambda x: 2 ** x, IntegerType()) + mul = pandas_udf(lambda x, y: x * y, DoubleType()) + res = df.select(add(col('a'), col('b')), power2(col('a')), mul(col('b'), col('c'))) + expected = df.select(expr('a + b'), expr('power(2, a)'), expr('b * c')) + self.assertEquals(expected.collect(), res.collect()) + + def test_vectorized_udf_exception(self): + from pyspark.sql.functions import pandas_udf, col + df = self.spark.range(10) + raise_exception = pandas_udf(lambda x: x * (1 / 0), LongType()) + with QuietTest(self.sc): + with self.assertRaisesRegexp(Exception, 'division( or modulo)? 
by zero'): + df.select(raise_exception(col('id'))).collect() + + def test_vectorized_udf_invalid_length(self): + from pyspark.sql.functions import pandas_udf, col + import pandas as pd + df = self.spark.range(10) + raise_exception = pandas_udf(lambda _: pd.Series(1), LongType()) + with QuietTest(self.sc): + with self.assertRaisesRegexp( + Exception, + 'Result vector from pandas_udf was not the required length'): + df.select(raise_exception(col('id'))).collect() + + def test_vectorized_udf_chained(self): + from pyspark.sql.functions import pandas_udf, col + df = self.spark.range(10) + f = pandas_udf(lambda x: x + 1, LongType()) + g = pandas_udf(lambda x: x - 1, LongType()) + res = df.select(g(f(col('id')))) + self.assertEquals(df.collect(), res.collect()) + + def test_vectorized_udf_wrong_return_type(self): + from pyspark.sql.functions import pandas_udf + with QuietTest(self.sc): + with self.assertRaisesRegexp( + NotImplementedError, + 'Invalid returnType.*scalar Pandas UDF.*MapType'): + pandas_udf(lambda x: x * 1.0, MapType(LongType(), LongType())) + + def test_vectorized_udf_return_scalar(self): + from pyspark.sql.functions import pandas_udf, col + df = self.spark.range(10) + f = pandas_udf(lambda x: 1.0, DoubleType()) + with QuietTest(self.sc): + with self.assertRaisesRegexp(Exception, 'Return.*type.*Series'): + df.select(f(col('id'))).collect() + + def test_vectorized_udf_decorator(self): + from pyspark.sql.functions import pandas_udf, col + df = self.spark.range(10) + + @pandas_udf(returnType=LongType()) + def identity(x): + return x + res = df.select(identity(col('id'))) + self.assertEquals(df.collect(), res.collect()) + + def test_vectorized_udf_empty_partition(self): + from pyspark.sql.functions import pandas_udf, col + df = self.spark.createDataFrame(self.sc.parallelize([Row(id=1)], 2)) + f = pandas_udf(lambda x: x, LongType()) + res = df.select(f(col('id'))) + self.assertEquals(df.collect(), res.collect()) + + def test_vectorized_udf_varargs(self): + from pyspark.sql.functions import pandas_udf, col + df = self.spark.createDataFrame(self.sc.parallelize([Row(id=1)], 2)) + f = pandas_udf(lambda *v: v[0], LongType()) + res = df.select(f(col('id'))) + self.assertEquals(df.collect(), res.collect()) + + def test_vectorized_udf_unsupported_types(self): + from pyspark.sql.functions import pandas_udf + with QuietTest(self.sc): + with self.assertRaisesRegexp( + NotImplementedError, + 'Invalid returnType.*scalar Pandas UDF.*MapType'): + pandas_udf(lambda x: x, MapType(StringType(), IntegerType())) + + def test_vectorized_udf_dates(self): + from pyspark.sql.functions import pandas_udf, col + from datetime import date + schema = StructType().add("idx", LongType()).add("date", DateType()) + data = [(0, date(1969, 1, 1),), + (1, date(2012, 2, 2),), + (2, None,), + (3, date(2100, 4, 4),)] + df = self.spark.createDataFrame(data, schema=schema) + + date_copy = pandas_udf(lambda t: t, returnType=DateType()) + df = df.withColumn("date_copy", date_copy(col("date"))) + + @pandas_udf(returnType=StringType()) + def check_data(idx, date, date_copy): + import pandas as pd + msgs = [] + is_equal = date.isnull() + for i in range(len(idx)): + if (is_equal[i] and data[idx[i]][1] is None) or \ + date[i] == data[idx[i]][1]: + msgs.append(None) + else: + msgs.append( + "date values are not equal (date='%s': data[%d][1]='%s')" + % (date[i], idx[i], data[idx[i]][1])) + return pd.Series(msgs) + + result = df.withColumn("check_data", + check_data(col("idx"), col("date"), col("date_copy"))).collect() + + 
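Aside (illustrative only, not part of this patch): the scalar pandas UDFs exercised throughout this file all follow one pattern, a function that receives a pandas.Series per argument for a batch of rows and returns a Series of the same length. A rough standalone sketch, assuming an existing SparkSession named spark:

    from pyspark.sql.functions import pandas_udf, col

    @pandas_udf('double')
    def times_two(s):
        # s is a pandas.Series covering one Arrow batch; return a Series of equal length
        return s * 2.0

    spark.range(5).select(times_two(col('id')).alias('doubled')).show()

Returning anything other than a Series of the expected length is exactly what the invalid-length and return-scalar tests above assert against.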
self.assertEquals(len(data), len(result)) + for i in range(len(result)): + self.assertEquals(data[i][1], result[i][1]) # "date" col + self.assertEquals(data[i][1], result[i][2]) # "date_copy" col + self.assertIsNone(result[i][3]) # "check_data" col + + def test_vectorized_udf_timestamps(self): + from pyspark.sql.functions import pandas_udf, col + from datetime import datetime + schema = StructType([ + StructField("idx", LongType(), True), + StructField("timestamp", TimestampType(), True)]) + data = [(0, datetime(1969, 1, 1, 1, 1, 1)), + (1, datetime(2012, 2, 2, 2, 2, 2)), + (2, None), + (3, datetime(2100, 3, 3, 3, 3, 3))] + + df = self.spark.createDataFrame(data, schema=schema) + + # Check that a timestamp passed through a pandas_udf will not be altered by timezone calc + f_timestamp_copy = pandas_udf(lambda t: t, returnType=TimestampType()) + df = df.withColumn("timestamp_copy", f_timestamp_copy(col("timestamp"))) + + @pandas_udf(returnType=StringType()) + def check_data(idx, timestamp, timestamp_copy): + import pandas as pd + msgs = [] + is_equal = timestamp.isnull() # use this array to check values are equal + for i in range(len(idx)): + # Check that timestamps are as expected in the UDF + if (is_equal[i] and data[idx[i]][1] is None) or \ + timestamp[i].to_pydatetime() == data[idx[i]][1]: + msgs.append(None) + else: + msgs.append( + "timestamp values are not equal (timestamp='%s': data[%d][1]='%s')" + % (timestamp[i], idx[i], data[idx[i]][1])) + return pd.Series(msgs) + + result = df.withColumn("check_data", check_data(col("idx"), col("timestamp"), + col("timestamp_copy"))).collect() + # Check that collection values are correct + self.assertEquals(len(data), len(result)) + for i in range(len(result)): + self.assertEquals(data[i][1], result[i][1]) # "timestamp" col + self.assertEquals(data[i][1], result[i][2]) # "timestamp_copy" col + self.assertIsNone(result[i][3]) # "check_data" col + + def test_vectorized_udf_return_timestamp_tz(self): + from pyspark.sql.functions import pandas_udf, col + import pandas as pd + df = self.spark.range(10) + + @pandas_udf(returnType=TimestampType()) + def gen_timestamps(id): + ts = [pd.Timestamp(i, unit='D', tz='America/Los_Angeles') for i in id] + return pd.Series(ts) + + result = df.withColumn("ts", gen_timestamps(col("id"))).collect() + spark_ts_t = TimestampType() + for r in result: + i, ts = r + ts_tz = pd.Timestamp(i, unit='D', tz='America/Los_Angeles').to_pydatetime() + expected = spark_ts_t.fromInternal(spark_ts_t.toInternal(ts_tz)) + self.assertEquals(expected, ts) + + def test_vectorized_udf_check_config(self): + from pyspark.sql.functions import pandas_udf, col + import pandas as pd + with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 3}): + df = self.spark.range(10, numPartitions=1) + + @pandas_udf(returnType=LongType()) + def check_records_per_batch(x): + return pd.Series(x.size).repeat(x.size) + + result = df.select(check_records_per_batch(col("id"))).collect() + for (r,) in result: + self.assertTrue(r <= 3) + + def test_vectorized_udf_timestamps_respect_session_timezone(self): + from pyspark.sql.functions import pandas_udf, col + from datetime import datetime + import pandas as pd + schema = StructType([ + StructField("idx", LongType(), True), + StructField("timestamp", TimestampType(), True)]) + data = [(1, datetime(1969, 1, 1, 1, 1, 1)), + (2, datetime(2012, 2, 2, 2, 2, 2)), + (3, None), + (4, datetime(2100, 3, 3, 3, 3, 3))] + df = self.spark.createDataFrame(data, schema=schema) + + f_timestamp_copy = pandas_udf(lambda 
ts: ts, TimestampType())
+        internal_value = pandas_udf(
+            lambda ts: ts.apply(lambda ts: ts.value if ts is not pd.NaT else None), LongType())
+
+        timezone = "America/New_York"
+        with self.sql_conf({
+                "spark.sql.execution.pandas.respectSessionTimeZone": False,
+                "spark.sql.session.timeZone": timezone}):
+            df_la = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \
+                .withColumn("internal_value", internal_value(col("timestamp")))
+            result_la = df_la.select(col("idx"), col("internal_value")).collect()
+            # Correct result_la by adjusting 5 hours difference between UTC and New York
+            diff = 5 * 60 * 60 * 1000 * 1000 * 1000
+            result_la_corrected = \
+                df_la.select(col("idx"), col("tscopy"), col("internal_value") - diff).collect()
+
+        with self.sql_conf({
+                "spark.sql.execution.pandas.respectSessionTimeZone": True,
+                "spark.sql.session.timeZone": timezone}):
+            df_ny = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \
+                .withColumn("internal_value", internal_value(col("timestamp")))
+            result_ny = df_ny.select(col("idx"), col("tscopy"), col("internal_value")).collect()
+
+            self.assertNotEqual(result_ny, result_la)
+            self.assertEqual(result_ny, result_la_corrected)
+
+    def test_nondeterministic_vectorized_udf(self):
+        # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations
+        from pyspark.sql.functions import pandas_udf, col
+
+        @pandas_udf('double')
+        def plus_ten(v):
+            return v + 10
+        random_udf = self.nondeterministic_vectorized_udf
+
+        df = self.spark.range(10).withColumn('rand', random_udf(col('id')))
+        result1 = df.withColumn('plus_ten(rand)', plus_ten(df['rand'])).toPandas()
+
+        self.assertEqual(random_udf.deterministic, False)
+        self.assertTrue(result1['plus_ten(rand)'].equals(result1['rand'] + 10))
+
+    def test_nondeterministic_vectorized_udf_in_aggregate(self):
+        from pyspark.sql.functions import sum
+
+        df = self.spark.range(10)
+        random_udf = self.nondeterministic_vectorized_udf
+
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(AnalysisException, 'nondeterministic'):
+                df.groupby(df.id).agg(sum(random_udf(df.id))).collect()
+            with self.assertRaisesRegexp(AnalysisException, 'nondeterministic'):
+                df.agg(sum(random_udf(df.id))).collect()
+
+    def test_register_vectorized_udf_basic(self):
+        from pyspark.rdd import PythonEvalType
+        from pyspark.sql.functions import pandas_udf, col, expr
+        df = self.spark.range(10).select(
+            col('id').cast('int').alias('a'),
+            col('id').cast('int').alias('b'))
+        original_add = pandas_udf(lambda x, y: x + y, IntegerType())
+        self.assertEqual(original_add.deterministic, True)
+        self.assertEqual(original_add.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)
+        new_add = self.spark.catalog.registerFunction("add1", original_add)
+        res1 = df.select(new_add(col('a'), col('b')))
+        res2 = self.spark.sql(
+            "SELECT add1(t.a, t.b) FROM (SELECT id as a, id as b FROM range(10)) t")
+        expected = df.select(expr('a + b'))
+        self.assertEquals(expected.collect(), res1.collect())
+        self.assertEquals(expected.collect(), res2.collect())
+
+    # Regression test for SPARK-23314
+    def test_timestamp_dst(self):
+        from pyspark.sql.functions import pandas_udf
+        # Daylight saving time for Los Angeles for 2015 is Sun, Nov 1 at 2:00 am
+        dt = [datetime.datetime(2015, 11, 1, 0, 30),
+              datetime.datetime(2015, 11, 1, 1, 30),
+              datetime.datetime(2015, 11, 1, 2, 30)]
+        df = self.spark.createDataFrame(dt, 'timestamp').toDF('time')
+        foo_udf = pandas_udf(lambda x: x, 'timestamp')
+        result = df.withColumn('time', foo_udf(df.time))
+
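Aside (illustrative only, not part of this patch): test_register_vectorized_udf_basic above relies on a scalar pandas UDF being registrable for SQL use just like a row-at-a-time UDF; spark.catalog.registerFunction is the older alias for spark.udf.register. A rough sketch of that flow, assuming an existing SparkSession named spark:

    from pyspark.sql.functions import pandas_udf

    add_one = pandas_udf(lambda s: s + 1, 'long')
    spark.udf.register('add_one', add_one)
    spark.sql('SELECT add_one(id) FROM range(3)').show()

Registration preserves the UDF's evalType and deterministic flag, which is what the register tests in this file and in the grouped-aggregate suite rely on.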
self.assertEquals(df.collect(), result.collect()) + + @unittest.skipIf(sys.version_info[:2] < (3, 5), "Type hints are supported from Python 3.5.") + def test_type_annotation(self): + from pyspark.sql.functions import pandas_udf + # Regression test to check if type hints can be used. See SPARK-23569. + # Note that it throws an error during compilation in lower Python versions if 'exec' + # is not used. Also, note that we explicitly use another dictionary to avoid modifications + # in the current 'locals()'. + # + # Hyukjin: I think it's an ugly way to test issues about syntax specific in + # higher versions of Python, which we shouldn't encourage. This was the last resort + # I could come up with at that time. + _locals = {} + exec( + "import pandas as pd\ndef noop(col: pd.Series) -> pd.Series: return col", + _locals) + df = self.spark.range(1).select(pandas_udf(f=_locals['noop'], returnType='bigint')('id')) + self.assertEqual(df.first()[0], 0) + + def test_mixed_udf(self): + import pandas as pd + from pyspark.sql.functions import col, udf, pandas_udf + + df = self.spark.range(0, 1).toDF('v') + + # Test mixture of multiple UDFs and Pandas UDFs. + + @udf('int') + def f1(x): + assert type(x) == int + return x + 1 + + @pandas_udf('int') + def f2(x): + assert type(x) == pd.Series + return x + 10 + + @udf('int') + def f3(x): + assert type(x) == int + return x + 100 + + @pandas_udf('int') + def f4(x): + assert type(x) == pd.Series + return x + 1000 + + # Test single expression with chained UDFs + df_chained_1 = df.withColumn('f2_f1', f2(f1(df['v']))) + df_chained_2 = df.withColumn('f3_f2_f1', f3(f2(f1(df['v'])))) + df_chained_3 = df.withColumn('f4_f3_f2_f1', f4(f3(f2(f1(df['v']))))) + df_chained_4 = df.withColumn('f4_f2_f1', f4(f2(f1(df['v'])))) + df_chained_5 = df.withColumn('f4_f3_f1', f4(f3(f1(df['v'])))) + + expected_chained_1 = df.withColumn('f2_f1', df['v'] + 11) + expected_chained_2 = df.withColumn('f3_f2_f1', df['v'] + 111) + expected_chained_3 = df.withColumn('f4_f3_f2_f1', df['v'] + 1111) + expected_chained_4 = df.withColumn('f4_f2_f1', df['v'] + 1011) + expected_chained_5 = df.withColumn('f4_f3_f1', df['v'] + 1101) + + self.assertEquals(expected_chained_1.collect(), df_chained_1.collect()) + self.assertEquals(expected_chained_2.collect(), df_chained_2.collect()) + self.assertEquals(expected_chained_3.collect(), df_chained_3.collect()) + self.assertEquals(expected_chained_4.collect(), df_chained_4.collect()) + self.assertEquals(expected_chained_5.collect(), df_chained_5.collect()) + + # Test multiple mixed UDF expressions in a single projection + df_multi_1 = df \ + .withColumn('f1', f1(col('v'))) \ + .withColumn('f2', f2(col('v'))) \ + .withColumn('f3', f3(col('v'))) \ + .withColumn('f4', f4(col('v'))) \ + .withColumn('f2_f1', f2(col('f1'))) \ + .withColumn('f3_f1', f3(col('f1'))) \ + .withColumn('f4_f1', f4(col('f1'))) \ + .withColumn('f3_f2', f3(col('f2'))) \ + .withColumn('f4_f2', f4(col('f2'))) \ + .withColumn('f4_f3', f4(col('f3'))) \ + .withColumn('f3_f2_f1', f3(col('f2_f1'))) \ + .withColumn('f4_f2_f1', f4(col('f2_f1'))) \ + .withColumn('f4_f3_f1', f4(col('f3_f1'))) \ + .withColumn('f4_f3_f2', f4(col('f3_f2'))) \ + .withColumn('f4_f3_f2_f1', f4(col('f3_f2_f1'))) + + # Test mixed udfs in a single expression + df_multi_2 = df \ + .withColumn('f1', f1(col('v'))) \ + .withColumn('f2', f2(col('v'))) \ + .withColumn('f3', f3(col('v'))) \ + .withColumn('f4', f4(col('v'))) \ + .withColumn('f2_f1', f2(f1(col('v')))) \ + .withColumn('f3_f1', f3(f1(col('v')))) \ + .withColumn('f4_f1', 
f4(f1(col('v')))) \ + .withColumn('f3_f2', f3(f2(col('v')))) \ + .withColumn('f4_f2', f4(f2(col('v')))) \ + .withColumn('f4_f3', f4(f3(col('v')))) \ + .withColumn('f3_f2_f1', f3(f2(f1(col('v'))))) \ + .withColumn('f4_f2_f1', f4(f2(f1(col('v'))))) \ + .withColumn('f4_f3_f1', f4(f3(f1(col('v'))))) \ + .withColumn('f4_f3_f2', f4(f3(f2(col('v'))))) \ + .withColumn('f4_f3_f2_f1', f4(f3(f2(f1(col('v')))))) + + expected = df \ + .withColumn('f1', df['v'] + 1) \ + .withColumn('f2', df['v'] + 10) \ + .withColumn('f3', df['v'] + 100) \ + .withColumn('f4', df['v'] + 1000) \ + .withColumn('f2_f1', df['v'] + 11) \ + .withColumn('f3_f1', df['v'] + 101) \ + .withColumn('f4_f1', df['v'] + 1001) \ + .withColumn('f3_f2', df['v'] + 110) \ + .withColumn('f4_f2', df['v'] + 1010) \ + .withColumn('f4_f3', df['v'] + 1100) \ + .withColumn('f3_f2_f1', df['v'] + 111) \ + .withColumn('f4_f2_f1', df['v'] + 1011) \ + .withColumn('f4_f3_f1', df['v'] + 1101) \ + .withColumn('f4_f3_f2', df['v'] + 1110) \ + .withColumn('f4_f3_f2_f1', df['v'] + 1111) + + self.assertEquals(expected.collect(), df_multi_1.collect()) + self.assertEquals(expected.collect(), df_multi_2.collect()) + + def test_mixed_udf_and_sql(self): + import pandas as pd + from pyspark.sql import Column + from pyspark.sql.functions import udf, pandas_udf + + df = self.spark.range(0, 1).toDF('v') + + # Test mixture of UDFs, Pandas UDFs and SQL expression. + + @udf('int') + def f1(x): + assert type(x) == int + return x + 1 + + def f2(x): + assert type(x) == Column + return x + 10 + + @pandas_udf('int') + def f3(x): + assert type(x) == pd.Series + return x + 100 + + df1 = df.withColumn('f1', f1(df['v'])) \ + .withColumn('f2', f2(df['v'])) \ + .withColumn('f3', f3(df['v'])) \ + .withColumn('f1_f2', f1(f2(df['v']))) \ + .withColumn('f1_f3', f1(f3(df['v']))) \ + .withColumn('f2_f1', f2(f1(df['v']))) \ + .withColumn('f2_f3', f2(f3(df['v']))) \ + .withColumn('f3_f1', f3(f1(df['v']))) \ + .withColumn('f3_f2', f3(f2(df['v']))) \ + .withColumn('f1_f2_f3', f1(f2(f3(df['v'])))) \ + .withColumn('f1_f3_f2', f1(f3(f2(df['v'])))) \ + .withColumn('f2_f1_f3', f2(f1(f3(df['v'])))) \ + .withColumn('f2_f3_f1', f2(f3(f1(df['v'])))) \ + .withColumn('f3_f1_f2', f3(f1(f2(df['v'])))) \ + .withColumn('f3_f2_f1', f3(f2(f1(df['v'])))) + + expected = df.withColumn('f1', df['v'] + 1) \ + .withColumn('f2', df['v'] + 10) \ + .withColumn('f3', df['v'] + 100) \ + .withColumn('f1_f2', df['v'] + 11) \ + .withColumn('f1_f3', df['v'] + 101) \ + .withColumn('f2_f1', df['v'] + 11) \ + .withColumn('f2_f3', df['v'] + 110) \ + .withColumn('f3_f1', df['v'] + 101) \ + .withColumn('f3_f2', df['v'] + 110) \ + .withColumn('f1_f2_f3', df['v'] + 111) \ + .withColumn('f1_f3_f2', df['v'] + 111) \ + .withColumn('f2_f1_f3', df['v'] + 111) \ + .withColumn('f2_f3_f1', df['v'] + 111) \ + .withColumn('f3_f1_f2', df['v'] + 111) \ + .withColumn('f3_f2_f1', df['v'] + 111) + + self.assertEquals(expected.collect(), df1.collect()) + + # SPARK-24721 + @unittest.skipIf(not test_compiled, test_not_compiled_message) + def test_datasource_with_udf(self): + # Same as SQLTests.test_datasource_with_udf, but with Pandas UDF + # This needs to a separate test because Arrow dependency is optional + import pandas as pd + import numpy as np + from pyspark.sql.functions import pandas_udf, lit, col + + path = tempfile.mkdtemp() + shutil.rmtree(path) + + try: + self.spark.range(1).write.mode("overwrite").format('csv').save(path) + filesource_df = self.spark.read.option('inferSchema', True).csv(path).toDF('i') + datasource_df = self.spark.read 
\ + .format("org.apache.spark.sql.sources.SimpleScanSource") \ + .option('from', 0).option('to', 1).load().toDF('i') + datasource_v2_df = self.spark.read \ + .format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2") \ + .load().toDF('i', 'j') + + c1 = pandas_udf(lambda x: x + 1, 'int')(lit(1)) + c2 = pandas_udf(lambda x: x + 1, 'int')(col('i')) + + f1 = pandas_udf(lambda x: pd.Series(np.repeat(False, len(x))), 'boolean')(lit(1)) + f2 = pandas_udf(lambda x: pd.Series(np.repeat(False, len(x))), 'boolean')(col('i')) + + for df in [filesource_df, datasource_df, datasource_v2_df]: + result = df.withColumn('c', c1) + expected = df.withColumn('c', lit(2)) + self.assertEquals(expected.collect(), result.collect()) + + for df in [filesource_df, datasource_df, datasource_v2_df]: + result = df.withColumn('c', c2) + expected = df.withColumn('c', col('i') + 1) + self.assertEquals(expected.collect(), result.collect()) + + for df in [filesource_df, datasource_df, datasource_v2_df]: + for f in [f1, f2]: + result = df.filter(f) + self.assertEquals(0, result.count()) + finally: + shutil.rmtree(path) + + +if __name__ == "__main__": + from pyspark.sql.tests.test_pandas_udf_scalar import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_pandas_udf_window.py b/python/pyspark/sql/tests/test_pandas_udf_window.py new file mode 100644 index 0000000000000..f0e6d2696df62 --- /dev/null +++ b/python/pyspark/sql/tests/test_pandas_udf_window.py @@ -0,0 +1,263 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import unittest + +from pyspark.sql.utils import AnalysisException +from pyspark.sql.window import Window +from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \ + pandas_requirement_message, pyarrow_requirement_message +from pyspark.testing.utils import QuietTest + + +@unittest.skipIf( + not have_pandas or not have_pyarrow, + pandas_requirement_message or pyarrow_requirement_message) +class WindowPandasUDFTests(ReusedSQLTestCase): + @property + def data(self): + from pyspark.sql.functions import array, explode, col, lit + return self.spark.range(10).toDF('id') \ + .withColumn("vs", array([lit(i * 1.0) + col('id') for i in range(20, 30)])) \ + .withColumn("v", explode(col('vs'))) \ + .drop('vs') \ + .withColumn('w', lit(1.0)) + + @property + def python_plus_one(self): + from pyspark.sql.functions import udf + return udf(lambda v: v + 1, 'double') + + @property + def pandas_scalar_time_two(self): + from pyspark.sql.functions import pandas_udf + return pandas_udf(lambda v: v * 2, 'double') + + @property + def pandas_agg_mean_udf(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + @pandas_udf('double', PandasUDFType.GROUPED_AGG) + def avg(v): + return v.mean() + return avg + + @property + def pandas_agg_max_udf(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + @pandas_udf('double', PandasUDFType.GROUPED_AGG) + def max(v): + return v.max() + return max + + @property + def pandas_agg_min_udf(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + @pandas_udf('double', PandasUDFType.GROUPED_AGG) + def min(v): + return v.min() + return min + + @property + def unbounded_window(self): + return Window.partitionBy('id') \ + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) + + @property + def ordered_window(self): + return Window.partitionBy('id').orderBy('v') + + @property + def unpartitioned_window(self): + return Window.partitionBy() + + def test_simple(self): + from pyspark.sql.functions import mean + + df = self.data + w = self.unbounded_window + + mean_udf = self.pandas_agg_mean_udf + + result1 = df.withColumn('mean_v', mean_udf(df['v']).over(w)) + expected1 = df.withColumn('mean_v', mean(df['v']).over(w)) + + result2 = df.select(mean_udf(df['v']).over(w)) + expected2 = df.select(mean(df['v']).over(w)) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + + def test_multiple_udfs(self): + from pyspark.sql.functions import max, min, mean + + df = self.data + w = self.unbounded_window + + result1 = df.withColumn('mean_v', self.pandas_agg_mean_udf(df['v']).over(w)) \ + .withColumn('max_v', self.pandas_agg_max_udf(df['v']).over(w)) \ + .withColumn('min_w', self.pandas_agg_min_udf(df['w']).over(w)) + + expected1 = df.withColumn('mean_v', mean(df['v']).over(w)) \ + .withColumn('max_v', max(df['v']).over(w)) \ + .withColumn('min_w', min(df['w']).over(w)) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + + def test_replace_existing(self): + from pyspark.sql.functions import mean + + df = self.data + w = self.unbounded_window + + result1 = df.withColumn('v', self.pandas_agg_mean_udf(df['v']).over(w)) + expected1 = df.withColumn('v', mean(df['v']).over(w)) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + + def test_mixed_sql(self): + from pyspark.sql.functions import mean + + df = self.data + w = self.unbounded_window + mean_udf = self.pandas_agg_mean_udf + + result1 
= df.withColumn('v', mean_udf(df['v'] * 2).over(w) + 1) + expected1 = df.withColumn('v', mean(df['v'] * 2).over(w) + 1) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + + def test_mixed_udf(self): + from pyspark.sql.functions import mean + + df = self.data + w = self.unbounded_window + + plus_one = self.python_plus_one + time_two = self.pandas_scalar_time_two + mean_udf = self.pandas_agg_mean_udf + + result1 = df.withColumn( + 'v2', + plus_one(mean_udf(plus_one(df['v'])).over(w))) + expected1 = df.withColumn( + 'v2', + plus_one(mean(plus_one(df['v'])).over(w))) + + result2 = df.withColumn( + 'v2', + time_two(mean_udf(time_two(df['v'])).over(w))) + expected2 = df.withColumn( + 'v2', + time_two(mean(time_two(df['v'])).over(w))) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + + def test_without_partitionBy(self): + from pyspark.sql.functions import mean + + df = self.data + w = self.unpartitioned_window + mean_udf = self.pandas_agg_mean_udf + + result1 = df.withColumn('v2', mean_udf(df['v']).over(w)) + expected1 = df.withColumn('v2', mean(df['v']).over(w)) + + result2 = df.select(mean_udf(df['v']).over(w)) + expected2 = df.select(mean(df['v']).over(w)) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + + def test_mixed_sql_and_udf(self): + from pyspark.sql.functions import max, min, rank, col + + df = self.data + w = self.unbounded_window + ow = self.ordered_window + max_udf = self.pandas_agg_max_udf + min_udf = self.pandas_agg_min_udf + + result1 = df.withColumn('v_diff', max_udf(df['v']).over(w) - min_udf(df['v']).over(w)) + expected1 = df.withColumn('v_diff', max(df['v']).over(w) - min(df['v']).over(w)) + + # Test mixing sql window function and window udf in the same expression + result2 = df.withColumn('v_diff', max_udf(df['v']).over(w) - min(df['v']).over(w)) + expected2 = expected1 + + # Test chaining sql aggregate function and udf + result3 = df.withColumn('max_v', max_udf(df['v']).over(w)) \ + .withColumn('min_v', min(df['v']).over(w)) \ + .withColumn('v_diff', col('max_v') - col('min_v')) \ + .drop('max_v', 'min_v') + expected3 = expected1 + + # Test mixing sql window function and udf + result4 = df.withColumn('max_v', max_udf(df['v']).over(w)) \ + .withColumn('rank', rank().over(ow)) + expected4 = df.withColumn('max_v', max(df['v']).over(w)) \ + .withColumn('rank', rank().over(ow)) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + self.assertPandasEqual(expected3.toPandas(), result3.toPandas()) + self.assertPandasEqual(expected4.toPandas(), result4.toPandas()) + + def test_array_type(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + df = self.data + w = self.unbounded_window + + array_udf = pandas_udf(lambda x: [1.0, 2.0], 'array<double>', PandasUDFType.GROUPED_AGG) + result1 = df.withColumn('v2', array_udf(df['v']).over(w)) + self.assertEquals(result1.first()['v2'], [1.0, 2.0]) + + def test_invalid_args(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + df = self.data + w = self.unbounded_window + ow = self.ordered_window + mean_udf = self.pandas_agg_mean_udf + + with QuietTest(self.sc): + with self.assertRaisesRegexp( + AnalysisException, + '.*not supported within a window function'): + foo_udf = pandas_udf(lambda x: x, 'v double', PandasUDFType.GROUPED_MAP) +
df.withColumn('v2', foo_udf(df['v']).over(w)) + + with QuietTest(self.sc): + with self.assertRaisesRegexp( + AnalysisException, + '.*Only unbounded window frame is supported.*'): + df.withColumn('mean_v', mean_udf(df['v']).over(ow)) + + +if __name__ == "__main__": + from pyspark.sql.tests.test_pandas_udf_window import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_readwriter.py b/python/pyspark/sql/tests/test_readwriter.py new file mode 100644 index 0000000000000..2f8712d7631f5 --- /dev/null +++ b/python/pyspark/sql/tests/test_readwriter.py @@ -0,0 +1,154 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import shutil +import tempfile + +from pyspark.sql.types import * +from pyspark.testing.sqlutils import ReusedSQLTestCase + + +class ReadwriterTests(ReusedSQLTestCase): + + def test_save_and_load(self): + df = self.df + tmpPath = tempfile.mkdtemp() + shutil.rmtree(tmpPath) + df.write.json(tmpPath) + actual = self.spark.read.json(tmpPath) + self.assertEqual(sorted(df.collect()), sorted(actual.collect())) + + schema = StructType([StructField("value", StringType(), True)]) + actual = self.spark.read.json(tmpPath, schema) + self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect())) + + df.write.json(tmpPath, "overwrite") + actual = self.spark.read.json(tmpPath) + self.assertEqual(sorted(df.collect()), sorted(actual.collect())) + + df.write.save(format="json", mode="overwrite", path=tmpPath, + noUse="this options will not be used in save.") + actual = self.spark.read.load(format="json", path=tmpPath, + noUse="this options will not be used in load.") + self.assertEqual(sorted(df.collect()), sorted(actual.collect())) + + defaultDataSourceName = self.spark.conf.get("spark.sql.sources.default", + "org.apache.spark.sql.parquet") + self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json") + actual = self.spark.read.load(path=tmpPath) + self.assertEqual(sorted(df.collect()), sorted(actual.collect())) + self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName) + + csvpath = os.path.join(tempfile.mkdtemp(), 'data') + df.write.option('quote', None).format('csv').save(csvpath) + + shutil.rmtree(tmpPath) + + def test_save_and_load_builder(self): + df = self.df + tmpPath = tempfile.mkdtemp() + shutil.rmtree(tmpPath) + df.write.json(tmpPath) + actual = self.spark.read.json(tmpPath) + self.assertEqual(sorted(df.collect()), sorted(actual.collect())) + + schema = StructType([StructField("value", StringType(), True)]) + actual = self.spark.read.json(tmpPath, schema) + self.assertEqual(sorted(df.select("value").collect()), 
sorted(actual.collect())) + + df.write.mode("overwrite").json(tmpPath) + actual = self.spark.read.json(tmpPath) + self.assertEqual(sorted(df.collect()), sorted(actual.collect())) + + df.write.mode("overwrite").options(noUse="this options will not be used in save.")\ + .option("noUse", "this option will not be used in save.")\ + .format("json").save(path=tmpPath) + actual =\ + self.spark.read.format("json")\ + .load(path=tmpPath, noUse="this options will not be used in load.") + self.assertEqual(sorted(df.collect()), sorted(actual.collect())) + + defaultDataSourceName = self.spark.conf.get("spark.sql.sources.default", + "org.apache.spark.sql.parquet") + self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json") + actual = self.spark.read.load(path=tmpPath) + self.assertEqual(sorted(df.collect()), sorted(actual.collect())) + self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName) + + shutil.rmtree(tmpPath) + + def test_bucketed_write(self): + data = [ + (1, "foo", 3.0), (2, "foo", 5.0), + (3, "bar", -1.0), (4, "bar", 6.0), + ] + df = self.spark.createDataFrame(data, ["x", "y", "z"]) + + def count_bucketed_cols(names, table="pyspark_bucket"): + """Given a sequence of column names and a table name + query the catalog and return number o columns which are + used for bucketing + """ + cols = self.spark.catalog.listColumns(table) + num = len([c for c in cols if c.name in names and c.isBucket]) + return num + + with self.table("pyspark_bucket"): + # Test write with one bucketing column + df.write.bucketBy(3, "x").mode("overwrite").saveAsTable("pyspark_bucket") + self.assertEqual(count_bucketed_cols(["x"]), 1) + self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect())) + + # Test write two bucketing columns + df.write.bucketBy(3, "x", "y").mode("overwrite").saveAsTable("pyspark_bucket") + self.assertEqual(count_bucketed_cols(["x", "y"]), 2) + self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect())) + + # Test write with bucket and sort + df.write.bucketBy(2, "x").sortBy("z").mode("overwrite").saveAsTable("pyspark_bucket") + self.assertEqual(count_bucketed_cols(["x"]), 1) + self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect())) + + # Test write with a list of columns + df.write.bucketBy(3, ["x", "y"]).mode("overwrite").saveAsTable("pyspark_bucket") + self.assertEqual(count_bucketed_cols(["x", "y"]), 2) + self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect())) + + # Test write with bucket and sort with a list of columns + (df.write.bucketBy(2, "x") + .sortBy(["y", "z"]) + .mode("overwrite").saveAsTable("pyspark_bucket")) + self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect())) + + # Test write with bucket and sort with multiple columns + (df.write.bucketBy(2, "x") + .sortBy("y", "z") + .mode("overwrite").saveAsTable("pyspark_bucket")) + self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect())) + + +if __name__ == "__main__": + import unittest + from pyspark.sql.tests.test_readwriter import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_serde.py b/python/pyspark/sql/tests/test_serde.py new file mode 100644 index 0000000000000..8707f46b6a25a --- /dev/null +++ b/python/pyspark/sql/tests/test_serde.py @@ -0,0 +1,139 @@ +# +# Licensed to 
the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import datetime +import shutil +import tempfile +import time + +from pyspark.sql import Row +from pyspark.sql.functions import lit +from pyspark.sql.types import * +from pyspark.testing.sqlutils import ReusedSQLTestCase, UTCOffsetTimezone + + +class SerdeTests(ReusedSQLTestCase): + + def test_serialize_nested_array_and_map(self): + d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})] + rdd = self.sc.parallelize(d) + df = self.spark.createDataFrame(rdd) + row = df.head() + self.assertEqual(1, len(row.l)) + self.assertEqual(1, row.l[0].a) + self.assertEqual("2", row.d["key"].d) + + l = df.rdd.map(lambda x: x.l).first() + self.assertEqual(1, len(l)) + self.assertEqual('s', l[0].b) + + d = df.rdd.map(lambda x: x.d).first() + self.assertEqual(1, len(d)) + self.assertEqual(1.0, d["key"].c) + + row = df.rdd.map(lambda x: x.d["key"]).first() + self.assertEqual(1.0, row.c) + self.assertEqual("2", row.d) + + def test_select_null_literal(self): + df = self.spark.sql("select null as col") + self.assertEqual(Row(col=None), df.first()) + + def test_struct_in_map(self): + d = [Row(m={Row(i=1): Row(s="")})] + df = self.sc.parallelize(d).toDF() + k, v = list(df.head().m.items())[0] + self.assertEqual(1, k.i) + self.assertEqual("", v.s) + + def test_filter_with_datetime(self): + time = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000) + date = time.date() + row = Row(date=date, time=time) + df = self.spark.createDataFrame([row]) + self.assertEqual(1, df.filter(df.date == date).count()) + self.assertEqual(1, df.filter(df.time == time).count()) + self.assertEqual(0, df.filter(df.date > date).count()) + self.assertEqual(0, df.filter(df.time > time).count()) + + def test_filter_with_datetime_timezone(self): + dt1 = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000, tzinfo=UTCOffsetTimezone(0)) + dt2 = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000, tzinfo=UTCOffsetTimezone(1)) + row = Row(date=dt1) + df = self.spark.createDataFrame([row]) + self.assertEqual(0, df.filter(df.date == dt2).count()) + self.assertEqual(1, df.filter(df.date > dt2).count()) + self.assertEqual(0, df.filter(df.date < dt2).count()) + + def test_time_with_timezone(self): + day = datetime.date.today() + now = datetime.datetime.now() + ts = time.mktime(now.timetuple()) + # class in __main__ is not serializable + from pyspark.testing.sqlutils import UTCOffsetTimezone + utc = UTCOffsetTimezone() + utcnow = datetime.datetime.utcfromtimestamp(ts) # without microseconds + # add microseconds to utcnow (keeping year,month,day,hour,minute,second) + utcnow = datetime.datetime(*(utcnow.timetuple()[:6] + (now.microsecond, utc))) + df = self.spark.createDataFrame([(day, now, utcnow)]) + day1, now1, utcnow1 = df.first() + self.assertEqual(day1, day) + self.assertEqual(now, now1) + 
self.assertEqual(now, utcnow1) + + # regression test for SPARK-19561 + def test_datetime_at_epoch(self): + epoch = datetime.datetime.fromtimestamp(0) + df = self.spark.createDataFrame([Row(date=epoch)]) + first = df.select('date', lit(epoch).alias('lit_date')).first() + self.assertEqual(first['date'], epoch) + self.assertEqual(first['lit_date'], epoch) + + def test_decimal(self): + from decimal import Decimal + schema = StructType([StructField("decimal", DecimalType(10, 5))]) + df = self.spark.createDataFrame([(Decimal("3.14159"),)], schema) + row = df.select(df.decimal + 1).first() + self.assertEqual(row[0], Decimal("4.14159")) + tmpPath = tempfile.mkdtemp() + shutil.rmtree(tmpPath) + df.write.parquet(tmpPath) + df2 = self.spark.read.parquet(tmpPath) + row = df2.first() + self.assertEqual(row[0], Decimal("3.14159")) + + def test_BinaryType_serialization(self): + # Pyrolite version <= 4.9 could not serialize BinaryType with Python3 SPARK-17808 + # The empty bytearray is test for SPARK-21534. + schema = StructType([StructField('mybytes', BinaryType())]) + data = [[bytearray(b'here is my data')], + [bytearray(b'and here is some more')], + [bytearray(b'')]] + df = self.spark.createDataFrame(data, schema=schema) + df.collect() + + +if __name__ == "__main__": + import unittest + from pyspark.sql.tests.test_serde import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_session.py b/python/pyspark/sql/tests/test_session.py new file mode 100644 index 0000000000000..c6b9e0b2ca554 --- /dev/null +++ b/python/pyspark/sql/tests/test_session.py @@ -0,0 +1,321 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import unittest + +from pyspark import SparkConf, SparkContext +from pyspark.sql import SparkSession, SQLContext, Row +from pyspark.testing.sqlutils import ReusedSQLTestCase +from pyspark.testing.utils import PySparkTestCase + + +class SparkSessionTests(ReusedSQLTestCase): + def test_sqlcontext_reuses_sparksession(self): + sqlContext1 = SQLContext(self.sc) + sqlContext2 = SQLContext(self.sc) + self.assertTrue(sqlContext1.sparkSession is sqlContext2.sparkSession) + + +class SparkSessionTests1(ReusedSQLTestCase): + + # We can't include this test into SQLTests because we will stop class's SparkContext and cause + # other tests failed. 
+ def test_sparksession_with_stopped_sparkcontext(self): + self.sc.stop() + sc = SparkContext('local[4]', self.sc.appName) + spark = SparkSession.builder.getOrCreate() + try: + df = spark.createDataFrame([(1, 2)], ["c", "c"]) + df.collect() + finally: + spark.stop() + sc.stop() + + +class SparkSessionTests2(PySparkTestCase): + + # This test is separate because it's closely related with session's start and stop. + # See SPARK-23228. + def test_set_jvm_default_session(self): + spark = SparkSession.builder.getOrCreate() + try: + self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isDefined()) + finally: + spark.stop() + self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isEmpty()) + + def test_jvm_default_session_already_set(self): + # Here, we assume there is the default session already set in JVM. + jsession = self.sc._jvm.SparkSession(self.sc._jsc.sc()) + self.sc._jvm.SparkSession.setDefaultSession(jsession) + + spark = SparkSession.builder.getOrCreate() + try: + self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isDefined()) + # The session should be the same with the exiting one. + self.assertTrue(jsession.equals(spark._jvm.SparkSession.getDefaultSession().get())) + finally: + spark.stop() + + +class SparkSessionTests3(unittest.TestCase): + + def test_active_session(self): + spark = SparkSession.builder \ + .master("local") \ + .getOrCreate() + try: + activeSession = SparkSession.getActiveSession() + df = activeSession.createDataFrame([(1, 'Alice')], ['age', 'name']) + self.assertEqual(df.collect(), [Row(age=1, name=u'Alice')]) + finally: + spark.stop() + + def test_get_active_session_when_no_active_session(self): + active = SparkSession.getActiveSession() + self.assertEqual(active, None) + spark = SparkSession.builder \ + .master("local") \ + .getOrCreate() + active = SparkSession.getActiveSession() + self.assertEqual(active, spark) + spark.stop() + active = SparkSession.getActiveSession() + self.assertEqual(active, None) + + def test_SparkSession(self): + spark = SparkSession.builder \ + .master("local") \ + .config("some-config", "v2") \ + .getOrCreate() + try: + self.assertEqual(spark.conf.get("some-config"), "v2") + self.assertEqual(spark.sparkContext._conf.get("some-config"), "v2") + self.assertEqual(spark.version, spark.sparkContext.version) + spark.sql("CREATE DATABASE test_db") + spark.catalog.setCurrentDatabase("test_db") + self.assertEqual(spark.catalog.currentDatabase(), "test_db") + spark.sql("CREATE TABLE table1 (name STRING, age INT) USING parquet") + self.assertEqual(spark.table("table1").columns, ['name', 'age']) + self.assertEqual(spark.range(3).count(), 3) + finally: + spark.stop() + + def test_global_default_session(self): + spark = SparkSession.builder \ + .master("local") \ + .getOrCreate() + try: + self.assertEqual(SparkSession.builder.getOrCreate(), spark) + finally: + spark.stop() + + def test_default_and_active_session(self): + spark = SparkSession.builder \ + .master("local") \ + .getOrCreate() + activeSession = spark._jvm.SparkSession.getActiveSession() + defaultSession = spark._jvm.SparkSession.getDefaultSession() + try: + self.assertEqual(activeSession, defaultSession) + finally: + spark.stop() + + def test_config_option_propagated_to_existing_session(self): + session1 = SparkSession.builder \ + .master("local") \ + .config("spark-config1", "a") \ + .getOrCreate() + self.assertEqual(session1.conf.get("spark-config1"), "a") + session2 = SparkSession.builder \ + .config("spark-config1", "b") \ + .getOrCreate() + try: + 
self.assertEqual(session1, session2) + self.assertEqual(session1.conf.get("spark-config1"), "b") + finally: + session1.stop() + + def test_new_session(self): + session = SparkSession.builder \ + .master("local") \ + .getOrCreate() + newSession = session.newSession() + try: + self.assertNotEqual(session, newSession) + finally: + session.stop() + newSession.stop() + + def test_create_new_session_if_old_session_stopped(self): + session = SparkSession.builder \ + .master("local") \ + .getOrCreate() + session.stop() + newSession = SparkSession.builder \ + .master("local") \ + .getOrCreate() + try: + self.assertNotEqual(session, newSession) + finally: + newSession.stop() + + def test_active_session_with_None_and_not_None_context(self): + from pyspark.context import SparkContext + from pyspark.conf import SparkConf + sc = None + session = None + try: + sc = SparkContext._active_spark_context + self.assertEqual(sc, None) + activeSession = SparkSession.getActiveSession() + self.assertEqual(activeSession, None) + sparkConf = SparkConf() + sc = SparkContext.getOrCreate(sparkConf) + activeSession = sc._jvm.SparkSession.getActiveSession() + self.assertFalse(activeSession.isDefined()) + session = SparkSession(sc) + activeSession = sc._jvm.SparkSession.getActiveSession() + self.assertTrue(activeSession.isDefined()) + activeSession2 = SparkSession.getActiveSession() + self.assertNotEqual(activeSession2, None) + finally: + if session is not None: + session.stop() + if sc is not None: + sc.stop() + + +class SparkSessionTests4(ReusedSQLTestCase): + + def test_get_active_session_after_create_dataframe(self): + session2 = None + try: + activeSession1 = SparkSession.getActiveSession() + session1 = self.spark + self.assertEqual(session1, activeSession1) + session2 = self.spark.newSession() + activeSession2 = SparkSession.getActiveSession() + self.assertEqual(session1, activeSession2) + self.assertNotEqual(session2, activeSession2) + session2.createDataFrame([(1, 'Alice')], ['age', 'name']) + activeSession3 = SparkSession.getActiveSession() + self.assertEqual(session2, activeSession3) + session1.createDataFrame([(1, 'Alice')], ['age', 'name']) + activeSession4 = SparkSession.getActiveSession() + self.assertEqual(session1, activeSession4) + finally: + if session2 is not None: + session2.stop() + + +class SparkSessionBuilderTests(unittest.TestCase): + + def test_create_spark_context_first_then_spark_session(self): + sc = None + session = None + try: + conf = SparkConf().set("key1", "value1") + sc = SparkContext('local[4]', "SessionBuilderTests", conf=conf) + session = SparkSession.builder.config("key2", "value2").getOrCreate() + + self.assertEqual(session.conf.get("key1"), "value1") + self.assertEqual(session.conf.get("key2"), "value2") + self.assertEqual(session.sparkContext, sc) + + self.assertFalse(sc.getConf().contains("key2")) + self.assertEqual(sc.getConf().get("key1"), "value1") + finally: + if session is not None: + session.stop() + if sc is not None: + sc.stop() + + def test_another_spark_session(self): + session1 = None + session2 = None + try: + session1 = SparkSession.builder.config("key1", "value1").getOrCreate() + session2 = SparkSession.builder.config("key2", "value2").getOrCreate() + + self.assertEqual(session1.conf.get("key1"), "value1") + self.assertEqual(session2.conf.get("key1"), "value1") + self.assertEqual(session1.conf.get("key2"), "value2") + self.assertEqual(session2.conf.get("key2"), "value2") + self.assertEqual(session1.sparkContext, session2.sparkContext) + + 
self.assertEqual(session1.sparkContext.getConf().get("key1"), "value1") + self.assertFalse(session1.sparkContext.getConf().contains("key2")) + finally: + if session1 is not None: + session1.stop() + if session2 is not None: + session2.stop() + + +class SparkExtensionsTest(unittest.TestCase): + # These tests are separate because it uses 'spark.sql.extensions' which is + # static and immutable. This can't be set or unset, for example, via `spark.conf`. + + @classmethod + def setUpClass(cls): + import glob + from pyspark.find_spark_home import _find_spark_home + + SPARK_HOME = _find_spark_home() + filename_pattern = ( + "sql/core/target/scala-*/test-classes/org/apache/spark/sql/" + "SparkSessionExtensionSuite.class") + if not glob.glob(os.path.join(SPARK_HOME, filename_pattern)): + raise unittest.SkipTest( + "'org.apache.spark.sql.SparkSessionExtensionSuite' is not " + "available. Will skip the related tests.") + + # Note that 'spark.sql.extensions' is a static immutable configuration. + cls.spark = SparkSession.builder \ + .master("local[4]") \ + .appName(cls.__name__) \ + .config( + "spark.sql.extensions", + "org.apache.spark.sql.MyExtensions") \ + .getOrCreate() + + @classmethod + def tearDownClass(cls): + cls.spark.stop() + + def test_use_custom_class_for_extensions(self): + self.assertTrue( + self.spark._jsparkSession.sessionState().planner().strategies().contains( + self.spark._jvm.org.apache.spark.sql.MySparkStrategy(self.spark._jsparkSession)), + "MySparkStrategy not found in active planner strategies") + self.assertTrue( + self.spark._jsparkSession.sessionState().analyzer().extendedResolutionRules().contains( + self.spark._jvm.org.apache.spark.sql.MyRule(self.spark._jsparkSession)), + "MyRule not found in extended resolution rules") + + +if __name__ == "__main__": + from pyspark.sql.tests.test_session import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_streaming.py b/python/pyspark/sql/tests/test_streaming.py new file mode 100644 index 0000000000000..4b71759f74a55 --- /dev/null +++ b/python/pyspark/sql/tests/test_streaming.py @@ -0,0 +1,567 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import os +import shutil +import tempfile +import time + +from pyspark.sql.functions import lit +from pyspark.sql.types import * +from pyspark.testing.sqlutils import ReusedSQLTestCase + + +class StreamingTests(ReusedSQLTestCase): + + def test_stream_trigger(self): + df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') + + # Should take at least one arg + try: + df.writeStream.trigger() + except ValueError: + pass + + # Should not take multiple args + try: + df.writeStream.trigger(once=True, processingTime='5 seconds') + except ValueError: + pass + + # Should not take multiple args + try: + df.writeStream.trigger(processingTime='5 seconds', continuous='1 second') + except ValueError: + pass + + # Should take only keyword args + try: + df.writeStream.trigger('5 seconds') + self.fail("Should have thrown an exception") + except TypeError: + pass + + def test_stream_read_options(self): + schema = StructType([StructField("data", StringType(), False)]) + df = self.spark.readStream\ + .format('text')\ + .option('path', 'python/test_support/sql/streaming')\ + .schema(schema)\ + .load() + self.assertTrue(df.isStreaming) + self.assertEqual(df.schema.simpleString(), "struct<data:string>") + + def test_stream_read_options_overwrite(self): + bad_schema = StructType([StructField("test", IntegerType(), False)]) + schema = StructType([StructField("data", StringType(), False)]) + df = self.spark.readStream.format('csv').option('path', 'python/test_support/sql/fake') \ + .schema(bad_schema)\ + .load(path='python/test_support/sql/streaming', schema=schema, format='text') + self.assertTrue(df.isStreaming) + self.assertEqual(df.schema.simpleString(), "struct<data:string>") + + def test_stream_save_options(self): + df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') \ + .withColumn('id', lit(1)) + for q in self.spark._wrapped.streams.active: + q.stop() + tmpPath = tempfile.mkdtemp() + shutil.rmtree(tmpPath) + self.assertTrue(df.isStreaming) + out = os.path.join(tmpPath, 'out') + chk = os.path.join(tmpPath, 'chk') + q = df.writeStream.option('checkpointLocation', chk).queryName('this_query') \ + .format('parquet').partitionBy('id').outputMode('append').option('path', out).start() + try: + self.assertEqual(q.name, 'this_query') + self.assertTrue(q.isActive) + q.processAllAvailable() + output_files = [] + for _, _, files in os.walk(out): + output_files.extend([f for f in files if not f.startswith('.')]) + self.assertTrue(len(output_files) > 0) + self.assertTrue(len(os.listdir(chk)) > 0) + finally: + q.stop() + shutil.rmtree(tmpPath) + + def test_stream_save_options_overwrite(self): + df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') + for q in self.spark._wrapped.streams.active: + q.stop() + tmpPath = tempfile.mkdtemp() + shutil.rmtree(tmpPath) + self.assertTrue(df.isStreaming) + out = os.path.join(tmpPath, 'out') + chk = os.path.join(tmpPath, 'chk') + fake1 = os.path.join(tmpPath, 'fake1') + fake2 = os.path.join(tmpPath, 'fake2') + q = df.writeStream.option('checkpointLocation', fake1)\ + .format('memory').option('path', fake2) \ + .queryName('fake_query').outputMode('append') \ + .start(path=out, format='parquet', queryName='this_query', checkpointLocation=chk) + + try: + self.assertEqual(q.name, 'this_query') + self.assertTrue(q.isActive) + q.processAllAvailable() + output_files = [] + for _, _, files in os.walk(out): + output_files.extend([f for f in files if not f.startswith('.')]) + self.assertTrue(len(output_files) > 0) +
self.assertTrue(len(os.listdir(chk)) > 0) + self.assertFalse(os.path.isdir(fake1)) # should not have been created + self.assertFalse(os.path.isdir(fake2)) # should not have been created + finally: + q.stop() + shutil.rmtree(tmpPath) + + def test_stream_status_and_progress(self): + df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') + for q in self.spark._wrapped.streams.active: + q.stop() + tmpPath = tempfile.mkdtemp() + shutil.rmtree(tmpPath) + self.assertTrue(df.isStreaming) + out = os.path.join(tmpPath, 'out') + chk = os.path.join(tmpPath, 'chk') + + def func(x): + time.sleep(1) + return x + + from pyspark.sql.functions import col, udf + sleep_udf = udf(func) + + # Use "sleep_udf" to delay the progress update so that we can test `lastProgress` when there + # were no updates. + q = df.select(sleep_udf(col("value")).alias('value')).writeStream \ + .start(path=out, format='parquet', queryName='this_query', checkpointLocation=chk) + try: + # "lastProgress" will return None in most cases. However, as it may be flaky when + # Jenkins is very slow, we don't assert it. If there is something wrong, "lastProgress" + # may throw error with a high chance and make this test flaky, so we should still be + # able to detect broken codes. + q.lastProgress + + q.processAllAvailable() + lastProgress = q.lastProgress + recentProgress = q.recentProgress + status = q.status + self.assertEqual(lastProgress['name'], q.name) + self.assertEqual(lastProgress['id'], q.id) + self.assertTrue(any(p == lastProgress for p in recentProgress)) + self.assertTrue( + "message" in status and + "isDataAvailable" in status and + "isTriggerActive" in status) + finally: + q.stop() + shutil.rmtree(tmpPath) + + def test_stream_await_termination(self): + df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') + for q in self.spark._wrapped.streams.active: + q.stop() + tmpPath = tempfile.mkdtemp() + shutil.rmtree(tmpPath) + self.assertTrue(df.isStreaming) + out = os.path.join(tmpPath, 'out') + chk = os.path.join(tmpPath, 'chk') + q = df.writeStream\ + .start(path=out, format='parquet', queryName='this_query', checkpointLocation=chk) + try: + self.assertTrue(q.isActive) + try: + q.awaitTermination("hello") + self.fail("Expected a value exception") + except ValueError: + pass + now = time.time() + # test should take at least 2 seconds + res = q.awaitTermination(2.6) + duration = time.time() - now + self.assertTrue(duration >= 2) + self.assertFalse(res) + finally: + q.stop() + shutil.rmtree(tmpPath) + + def test_stream_exception(self): + sdf = self.spark.readStream.format('text').load('python/test_support/sql/streaming') + sq = sdf.writeStream.format('memory').queryName('query_explain').start() + try: + sq.processAllAvailable() + self.assertEqual(sq.exception(), None) + finally: + sq.stop() + + from pyspark.sql.functions import col, udf + from pyspark.sql.utils import StreamingQueryException + bad_udf = udf(lambda x: 1 / 0) + sq = sdf.select(bad_udf(col("value")))\ + .writeStream\ + .format('memory')\ + .queryName('this_query')\ + .start() + try: + # Process some data to fail the query + sq.processAllAvailable() + self.fail("bad udf should fail the query") + except StreamingQueryException as e: + # This is expected + self.assertTrue("ZeroDivisionError" in e.desc) + finally: + sq.stop() + self.assertTrue(type(sq.exception()) is StreamingQueryException) + self.assertTrue("ZeroDivisionError" in sq.exception().desc) + + def test_query_manager_await_termination(self): + df = 
self.spark.readStream.format('text').load('python/test_support/sql/streaming') + for q in self.spark._wrapped.streams.active: + q.stop() + tmpPath = tempfile.mkdtemp() + shutil.rmtree(tmpPath) + self.assertTrue(df.isStreaming) + out = os.path.join(tmpPath, 'out') + chk = os.path.join(tmpPath, 'chk') + q = df.writeStream\ + .start(path=out, format='parquet', queryName='this_query', checkpointLocation=chk) + try: + self.assertTrue(q.isActive) + try: + self.spark._wrapped.streams.awaitAnyTermination("hello") + self.fail("Expected a value exception") + except ValueError: + pass + now = time.time() + # test should take at least 2 seconds + res = self.spark._wrapped.streams.awaitAnyTermination(2.6) + duration = time.time() - now + self.assertTrue(duration >= 2) + self.assertFalse(res) + finally: + q.stop() + shutil.rmtree(tmpPath) + + class ForeachWriterTester: + + def __init__(self, spark): + self.spark = spark + + def write_open_event(self, partitionId, epochId): + self._write_event( + self.open_events_dir, + {'partition': partitionId, 'epoch': epochId}) + + def write_process_event(self, row): + self._write_event(self.process_events_dir, {'value': 'text'}) + + def write_close_event(self, error): + self._write_event(self.close_events_dir, {'error': str(error)}) + + def write_input_file(self): + self._write_event(self.input_dir, "text") + + def open_events(self): + return self._read_events(self.open_events_dir, 'partition INT, epoch INT') + + def process_events(self): + return self._read_events(self.process_events_dir, 'value STRING') + + def close_events(self): + return self._read_events(self.close_events_dir, 'error STRING') + + def run_streaming_query_on_writer(self, writer, num_files): + self._reset() + try: + sdf = self.spark.readStream.format('text').load(self.input_dir) + sq = sdf.writeStream.foreach(writer).start() + for i in range(num_files): + self.write_input_file() + sq.processAllAvailable() + finally: + self.stop_all() + + def assert_invalid_writer(self, writer, msg=None): + self._reset() + try: + sdf = self.spark.readStream.format('text').load(self.input_dir) + sq = sdf.writeStream.foreach(writer).start() + self.write_input_file() + sq.processAllAvailable() + self.fail("invalid writer %s did not fail the query" % str(writer)) # not expected + except Exception as e: + if msg: + assert msg in str(e), "%s not in %s" % (msg, str(e)) + + finally: + self.stop_all() + + def stop_all(self): + for q in self.spark._wrapped.streams.active: + q.stop() + + def _reset(self): + self.input_dir = tempfile.mkdtemp() + self.open_events_dir = tempfile.mkdtemp() + self.process_events_dir = tempfile.mkdtemp() + self.close_events_dir = tempfile.mkdtemp() + + def _read_events(self, dir, json): + rows = self.spark.read.schema(json).json(dir).collect() + dicts = [row.asDict() for row in rows] + return dicts + + def _write_event(self, dir, event): + import uuid + with open(os.path.join(dir, str(uuid.uuid4())), 'w') as f: + f.write("%s\n" % str(event)) + + def __getstate__(self): + return (self.open_events_dir, self.process_events_dir, self.close_events_dir) + + def __setstate__(self, state): + self.open_events_dir, self.process_events_dir, self.close_events_dir = state + + # Those foreach tests are failed in Python 3.6 and macOS High Sierra by defined rules + # at http://sealiesoftware.com/blog/archive/2017/6/5/Objective-C_and_fork_in_macOS_1013.html + # To work around this, OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES. 
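# Editor's note, illustrative only (not part of the patch above): the workaround
# referenced in the preceding comment amounts to exporting the variable before any
# Python worker forks, for example at the top of a test runner or conftest. A minimal
# sketch, assuming it executes before the SparkSession is created:
import os
os.environ.setdefault("OBJC_DISABLE_INITIALIZE_FORK_SAFETY", "YES")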
+ def test_streaming_foreach_with_simple_function(self): + tester = self.ForeachWriterTester(self.spark) + + def foreach_func(row): + tester.write_process_event(row) + + tester.run_streaming_query_on_writer(foreach_func, 2) + self.assertEqual(len(tester.process_events()), 2) + + def test_streaming_foreach_with_basic_open_process_close(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def open(self, partitionId, epochId): + tester.write_open_event(partitionId, epochId) + return True + + def process(self, row): + tester.write_process_event(row) + + def close(self, error): + tester.write_close_event(error) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + + open_events = tester.open_events() + self.assertEqual(len(open_events), 2) + self.assertSetEqual(set([e['epoch'] for e in open_events]), {0, 1}) + + self.assertEqual(len(tester.process_events()), 2) + + close_events = tester.close_events() + self.assertEqual(len(close_events), 2) + self.assertSetEqual(set([e['error'] for e in close_events]), {'None'}) + + def test_streaming_foreach_with_open_returning_false(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def open(self, partition_id, epoch_id): + tester.write_open_event(partition_id, epoch_id) + return False + + def process(self, row): + tester.write_process_event(row) + + def close(self, error): + tester.write_close_event(error) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + + self.assertEqual(len(tester.open_events()), 2) + + self.assertEqual(len(tester.process_events()), 0) # no row was processed + + close_events = tester.close_events() + self.assertEqual(len(close_events), 2) + self.assertSetEqual(set([e['error'] for e in close_events]), {'None'}) + + def test_streaming_foreach_without_open_method(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def process(self, row): + tester.write_process_event(row) + + def close(self, error): + tester.write_close_event(error) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + self.assertEqual(len(tester.open_events()), 0) # no open events + self.assertEqual(len(tester.process_events()), 2) + self.assertEqual(len(tester.close_events()), 2) + + def test_streaming_foreach_without_close_method(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def open(self, partition_id, epoch_id): + tester.write_open_event(partition_id, epoch_id) + return True + + def process(self, row): + tester.write_process_event(row) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + self.assertEqual(len(tester.open_events()), 2) # no open events + self.assertEqual(len(tester.process_events()), 2) + self.assertEqual(len(tester.close_events()), 0) + + def test_streaming_foreach_without_open_and_close_methods(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def process(self, row): + tester.write_process_event(row) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + self.assertEqual(len(tester.open_events()), 0) # no open events + self.assertEqual(len(tester.process_events()), 2) + self.assertEqual(len(tester.close_events()), 0) + + def test_streaming_foreach_with_process_throwing_error(self): + from pyspark.sql.utils import StreamingQueryException + + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def process(self, row): + raise Exception("test error") + + def close(self, error): + tester.write_close_event(error) + + try: + 
tester.run_streaming_query_on_writer(ForeachWriter(), 1) + self.fail("bad writer did not fail the query") # this is not expected + except StreamingQueryException as e: + # TODO: Verify whether original error message is inside the exception + pass + + self.assertEqual(len(tester.process_events()), 0) # no row was processed + close_events = tester.close_events() + self.assertEqual(len(close_events), 1) + # TODO: Verify whether original error message is inside the exception + + def test_streaming_foreach_with_invalid_writers(self): + + tester = self.ForeachWriterTester(self.spark) + + def func_with_iterator_input(iter): + for x in iter: + print(x) + + tester.assert_invalid_writer(func_with_iterator_input) + + class WriterWithoutProcess: + def open(self, partition): + pass + + tester.assert_invalid_writer(WriterWithoutProcess(), "does not have a 'process'") + + class WriterWithNonCallableProcess(): + process = True + + tester.assert_invalid_writer(WriterWithNonCallableProcess(), + "'process' in provided object is not callable") + + class WriterWithNoParamProcess(): + def process(self): + pass + + tester.assert_invalid_writer(WriterWithNoParamProcess()) + + # Abstract class for tests below + class WithProcess(): + def process(self, row): + pass + + class WriterWithNonCallableOpen(WithProcess): + open = True + + tester.assert_invalid_writer(WriterWithNonCallableOpen(), + "'open' in provided object is not callable") + + class WriterWithNoParamOpen(WithProcess): + def open(self): + pass + + tester.assert_invalid_writer(WriterWithNoParamOpen()) + + class WriterWithNonCallableClose(WithProcess): + close = True + + tester.assert_invalid_writer(WriterWithNonCallableClose(), + "'close' in provided object is not callable") + + def test_streaming_foreachBatch(self): + q = None + collected = dict() + + def collectBatch(batch_df, batch_id): + collected[batch_id] = batch_df.collect() + + try: + df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') + q = df.writeStream.foreachBatch(collectBatch).start() + q.processAllAvailable() + self.assertTrue(0 in collected) + self.assertTrue(len(collected[0]), 2) + finally: + if q: + q.stop() + + def test_streaming_foreachBatch_propagates_python_errors(self): + from pyspark.sql.utils import StreamingQueryException + + q = None + + def collectBatch(df, id): + raise Exception("this should fail the query") + + try: + df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') + q = df.writeStream.foreachBatch(collectBatch).start() + q.processAllAvailable() + self.fail("Expected a failure") + except StreamingQueryException as e: + self.assertTrue("this should fail" in str(e)) + finally: + if q: + q.stop() + + +if __name__ == "__main__": + import unittest + from pyspark.sql.tests.test_streaming import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py new file mode 100644 index 0000000000000..fb673f2a385ef --- /dev/null +++ b/python/pyspark/sql/tests/test_types.py @@ -0,0 +1,945 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import array +import ctypes +import datetime +import os +import pickle +import sys +import unittest + +from pyspark.sql import Row +from pyspark.sql.functions import UserDefinedFunction +from pyspark.sql.types import * +from pyspark.sql.types import _array_signed_int_typecode_ctype_mappings, _array_type_mappings, \ + _array_unsigned_int_typecode_ctype_mappings, _infer_type, _make_type_verifier, _merge_type +from pyspark.testing.sqlutils import ReusedSQLTestCase, ExamplePointUDT, PythonOnlyUDT, \ + ExamplePoint, PythonOnlyPoint, MyObject + + +class TypesTests(ReusedSQLTestCase): + + def test_apply_schema_to_row(self): + df = self.spark.read.json(self.sc.parallelize(["""{"a":2}"""])) + df2 = self.spark.createDataFrame(df.rdd.map(lambda x: x), df.schema) + self.assertEqual(df.collect(), df2.collect()) + + rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x)) + df3 = self.spark.createDataFrame(rdd, df.schema) + self.assertEqual(10, df3.count()) + + def test_infer_schema_to_local(self): + input = [{"a": 1}, {"b": "coffee"}] + rdd = self.sc.parallelize(input) + df = self.spark.createDataFrame(input) + df2 = self.spark.createDataFrame(rdd, samplingRatio=1.0) + self.assertEqual(df.schema, df2.schema) + + rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x, b=None)) + df3 = self.spark.createDataFrame(rdd, df.schema) + self.assertEqual(10, df3.count()) + + def test_apply_schema_to_dict_and_rows(self): + schema = StructType().add("b", StringType()).add("a", IntegerType()) + input = [{"a": 1}, {"b": "coffee"}] + rdd = self.sc.parallelize(input) + for verify in [False, True]: + df = self.spark.createDataFrame(input, schema, verifySchema=verify) + df2 = self.spark.createDataFrame(rdd, schema, verifySchema=verify) + self.assertEqual(df.schema, df2.schema) + + rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x, b=None)) + df3 = self.spark.createDataFrame(rdd, schema, verifySchema=verify) + self.assertEqual(10, df3.count()) + input = [Row(a=x, b=str(x)) for x in range(10)] + df4 = self.spark.createDataFrame(input, schema, verifySchema=verify) + self.assertEqual(10, df4.count()) + + def test_create_dataframe_schema_mismatch(self): + input = [Row(a=1)] + rdd = self.sc.parallelize(range(3)).map(lambda i: Row(a=i)) + schema = StructType([StructField("a", IntegerType()), StructField("b", StringType())]) + df = self.spark.createDataFrame(rdd, schema) + self.assertRaises(Exception, lambda: df.show()) + + def test_infer_schema(self): + d = [Row(l=[], d={}, s=None), + Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")] + rdd = self.sc.parallelize(d) + df = self.spark.createDataFrame(rdd) + self.assertEqual([], df.rdd.map(lambda r: r.l).first()) + self.assertEqual([None, ""], df.rdd.map(lambda r: r.s).collect()) + + with self.tempView("test"): + df.createOrReplaceTempView("test") + result = self.spark.sql("SELECT l[0].a from test where d['key'].d = '2'") + self.assertEqual(1, result.head()[0]) + + df2 = 
self.spark.createDataFrame(rdd, samplingRatio=1.0) + self.assertEqual(df.schema, df2.schema) + self.assertEqual({}, df2.rdd.map(lambda r: r.d).first()) + self.assertEqual([None, ""], df2.rdd.map(lambda r: r.s).collect()) + + with self.tempView("test2"): + df2.createOrReplaceTempView("test2") + result = self.spark.sql("SELECT l[0].a from test2 where d['key'].d = '2'") + self.assertEqual(1, result.head()[0]) + + def test_infer_schema_specification(self): + from decimal import Decimal + + class A(object): + def __init__(self): + self.a = 1 + + data = [ + True, + 1, + "a", + u"a", + datetime.date(1970, 1, 1), + datetime.datetime(1970, 1, 1, 0, 0), + 1.0, + array.array("d", [1]), + [1], + (1, ), + {"a": 1}, + bytearray(1), + Decimal(1), + Row(a=1), + Row("a")(1), + A(), + ] + + df = self.spark.createDataFrame([data]) + actual = list(map(lambda x: x.dataType.simpleString(), df.schema)) + expected = [ + 'boolean', + 'bigint', + 'string', + 'string', + 'date', + 'timestamp', + 'double', + 'array<double>', + 'array<bigint>', + 'struct<_1:bigint>', + 'map<string,bigint>', + 'binary', + 'decimal(38,18)', + 'struct<a:bigint>', + 'struct<a:bigint>', + 'struct<a:bigint>', + ] + self.assertEqual(actual, expected) + + actual = list(df.first()) + expected = [ + True, + 1, + 'a', + u"a", + datetime.date(1970, 1, 1), + datetime.datetime(1970, 1, 1, 0, 0), + 1.0, + [1.0], + [1], + Row(_1=1), + {"a": 1}, + bytearray(b'\x00'), + Decimal('1.000000000000000000'), + Row(a=1), + Row(a=1), + Row(a=1), + ] + self.assertEqual(actual, expected) + + def test_infer_schema_not_enough_names(self): + df = self.spark.createDataFrame([["a", "b"]], ["col1"]) + self.assertEqual(df.columns, ['col1', '_2']) + + def test_infer_schema_fails(self): + with self.assertRaisesRegexp(TypeError, 'field a'): + self.spark.createDataFrame(self.spark.sparkContext.parallelize([[1, 1], ["x", 1]]), + schema=["a", "b"], samplingRatio=0.99) + + def test_infer_nested_schema(self): + NestedRow = Row("f1", "f2") + nestedRdd1 = self.sc.parallelize([NestedRow([1, 2], {"row1": 1.0}), + NestedRow([2, 3], {"row2": 2.0})]) + df = self.spark.createDataFrame(nestedRdd1) + self.assertEqual(Row(f1=[1, 2], f2={u'row1': 1.0}), df.collect()[0]) + + nestedRdd2 = self.sc.parallelize([NestedRow([[1, 2], [2, 3]], [1, 2]), + NestedRow([[2, 3], [3, 4]], [2, 3])]) + df = self.spark.createDataFrame(nestedRdd2) + self.assertEqual(Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), df.collect()[0]) + + from collections import namedtuple + CustomRow = namedtuple('CustomRow', 'field1 field2') + rdd = self.sc.parallelize([CustomRow(field1=1, field2="row1"), + CustomRow(field1=2, field2="row2"), + CustomRow(field1=3, field2="row3")]) + df = self.spark.createDataFrame(rdd) + self.assertEqual(Row(field1=1, field2=u'row1'), df.first()) + + def test_create_dataframe_from_dict_respects_schema(self): + df = self.spark.createDataFrame([{'a': 1}], ["b"]) + self.assertEqual(df.columns, ['b']) + + def test_create_dataframe_from_objects(self): + data = [MyObject(1, "1"), MyObject(2, "2")] + df = self.spark.createDataFrame(data) + self.assertEqual(df.dtypes, [("key", "bigint"), ("value", "string")]) + self.assertEqual(df.first(), Row(key=1, value="1")) + + def test_apply_schema(self): + from datetime import date, datetime + rdd = self.sc.parallelize([(127, -128, -32768, 32767, 2147483647, 1.0, + date(2010, 1, 1), datetime(2010, 1, 1, 1, 1, 1), + {"a": 1}, (2,), [1, 2, 3], None)]) + schema = StructType([ + StructField("byte1", ByteType(), False), + StructField("byte2", ByteType(), False), + StructField("short1", ShortType(), False), + StructField("short2",
ShortType(), False), + StructField("int1", IntegerType(), False), + StructField("float1", FloatType(), False), + StructField("date1", DateType(), False), + StructField("time1", TimestampType(), False), + StructField("map1", MapType(StringType(), IntegerType(), False), False), + StructField("struct1", StructType([StructField("b", ShortType(), False)]), False), + StructField("list1", ArrayType(ByteType(), False), False), + StructField("null1", DoubleType(), True)]) + df = self.spark.createDataFrame(rdd, schema) + results = df.rdd.map(lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int1, x.float1, + x.date1, x.time1, x.map1["a"], x.struct1.b, x.list1, x.null1)) + r = (127, -128, -32768, 32767, 2147483647, 1.0, date(2010, 1, 1), + datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None) + self.assertEqual(r, results.first()) + + with self.tempView("table2"): + df.createOrReplaceTempView("table2") + r = self.spark.sql("SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " + + "short1 + 1 AS short1, short2 - 1 AS short2, int1 - 1 AS int1, " + + "float1 + 1.5 as float1 FROM table2").first() + + self.assertEqual((126, -127, -32767, 32766, 2147483646, 2.5), tuple(r)) + + def test_convert_row_to_dict(self): + row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}) + self.assertEqual(1, row.asDict()['l'][0].a) + df = self.sc.parallelize([row]).toDF() + + with self.tempView("test"): + df.createOrReplaceTempView("test") + row = self.spark.sql("select l, d from test").head() + self.assertEqual(1, row.asDict()["l"][0].a) + self.assertEqual(1.0, row.asDict()['d']['key'].c) + + def test_udt(self): + from pyspark.sql.types import _parse_datatype_json_string, _infer_type, _make_type_verifier + + def check_datatype(datatype): + pickled = pickle.loads(pickle.dumps(datatype)) + assert datatype == pickled + scala_datatype = self.spark._jsparkSession.parseDataType(datatype.json()) + python_datatype = _parse_datatype_json_string(scala_datatype.json()) + assert datatype == python_datatype + + check_datatype(ExamplePointUDT()) + structtype_with_udt = StructType([StructField("label", DoubleType(), False), + StructField("point", ExamplePointUDT(), False)]) + check_datatype(structtype_with_udt) + p = ExamplePoint(1.0, 2.0) + self.assertEqual(_infer_type(p), ExamplePointUDT()) + _make_type_verifier(ExamplePointUDT())(ExamplePoint(1.0, 2.0)) + self.assertRaises(ValueError, lambda: _make_type_verifier(ExamplePointUDT())([1.0, 2.0])) + + check_datatype(PythonOnlyUDT()) + structtype_with_udt = StructType([StructField("label", DoubleType(), False), + StructField("point", PythonOnlyUDT(), False)]) + check_datatype(structtype_with_udt) + p = PythonOnlyPoint(1.0, 2.0) + self.assertEqual(_infer_type(p), PythonOnlyUDT()) + _make_type_verifier(PythonOnlyUDT())(PythonOnlyPoint(1.0, 2.0)) + self.assertRaises( + ValueError, + lambda: _make_type_verifier(PythonOnlyUDT())([1.0, 2.0])) + + def test_simple_udt_in_df(self): + schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT()) + df = self.spark.createDataFrame( + [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)], + schema=schema) + df.collect() + + def test_nested_udt_in_df(self): + schema = StructType().add("key", LongType()).add("val", ArrayType(PythonOnlyUDT())) + df = self.spark.createDataFrame( + [(i % 3, [PythonOnlyPoint(float(i), float(i))]) for i in range(10)], + schema=schema) + df.collect() + + schema = StructType().add("key", LongType()).add("val", + MapType(LongType(), PythonOnlyUDT())) + df = self.spark.createDataFrame( + [(i % 3, {i % 3: 
PythonOnlyPoint(float(i + 1), float(i + 1))}) for i in range(10)], + schema=schema) + df.collect() + + def test_complex_nested_udt_in_df(self): + from pyspark.sql.functions import udf + + schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT()) + df = self.spark.createDataFrame( + [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)], + schema=schema) + df.collect() + + gd = df.groupby("key").agg({"val": "collect_list"}) + gd.collect() + udf = udf(lambda k, v: [(k, v[0])], ArrayType(df.schema)) + gd.select(udf(*gd)).collect() + + def test_udt_with_none(self): + df = self.spark.range(0, 10, 1, 1) + + def myudf(x): + if x > 0: + return PythonOnlyPoint(float(x), float(x)) + + self.spark.catalog.registerFunction("udf", myudf, PythonOnlyUDT()) + rows = [r[0] for r in df.selectExpr("udf(id)").take(2)] + self.assertEqual(rows, [None, PythonOnlyPoint(1, 1)]) + + def test_infer_schema_with_udt(self): + row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) + df = self.spark.createDataFrame([row]) + schema = df.schema + field = [f for f in schema.fields if f.name == "point"][0] + self.assertEqual(type(field.dataType), ExamplePointUDT) + + with self.tempView("labeled_point"): + df.createOrReplaceTempView("labeled_point") + point = self.spark.sql("SELECT point FROM labeled_point").head().point + self.assertEqual(point, ExamplePoint(1.0, 2.0)) + + row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0)) + df = self.spark.createDataFrame([row]) + schema = df.schema + field = [f for f in schema.fields if f.name == "point"][0] + self.assertEqual(type(field.dataType), PythonOnlyUDT) + + with self.tempView("labeled_point"): + df.createOrReplaceTempView("labeled_point") + point = self.spark.sql("SELECT point FROM labeled_point").head().point + self.assertEqual(point, PythonOnlyPoint(1.0, 2.0)) + + def test_apply_schema_with_udt(self): + row = (1.0, ExamplePoint(1.0, 2.0)) + schema = StructType([StructField("label", DoubleType(), False), + StructField("point", ExamplePointUDT(), False)]) + df = self.spark.createDataFrame([row], schema) + point = df.head().point + self.assertEqual(point, ExamplePoint(1.0, 2.0)) + + row = (1.0, PythonOnlyPoint(1.0, 2.0)) + schema = StructType([StructField("label", DoubleType(), False), + StructField("point", PythonOnlyUDT(), False)]) + df = self.spark.createDataFrame([row], schema) + point = df.head().point + self.assertEqual(point, PythonOnlyPoint(1.0, 2.0)) + + def test_udf_with_udt(self): + row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) + df = self.spark.createDataFrame([row]) + self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first()) + udf = UserDefinedFunction(lambda p: p.y, DoubleType()) + self.assertEqual(2.0, df.select(udf(df.point)).first()[0]) + udf2 = UserDefinedFunction(lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT()) + self.assertEqual(ExamplePoint(2.0, 3.0), df.select(udf2(df.point)).first()[0]) + + row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0)) + df = self.spark.createDataFrame([row]) + self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first()) + udf = UserDefinedFunction(lambda p: p.y, DoubleType()) + self.assertEqual(2.0, df.select(udf(df.point)).first()[0]) + udf2 = UserDefinedFunction(lambda p: PythonOnlyPoint(p.x + 1, p.y + 1), PythonOnlyUDT()) + self.assertEqual(PythonOnlyPoint(2.0, 3.0), df.select(udf2(df.point)).first()[0]) + + def test_parquet_with_udt(self): + row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) + df0 = self.spark.createDataFrame([row]) + output_dir = os.path.join(self.tempdir.name, 
"labeled_point") + df0.write.parquet(output_dir) + df1 = self.spark.read.parquet(output_dir) + point = df1.head().point + self.assertEqual(point, ExamplePoint(1.0, 2.0)) + + row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0)) + df0 = self.spark.createDataFrame([row]) + df0.write.parquet(output_dir, mode='overwrite') + df1 = self.spark.read.parquet(output_dir) + point = df1.head().point + self.assertEqual(point, PythonOnlyPoint(1.0, 2.0)) + + def test_union_with_udt(self): + row1 = (1.0, ExamplePoint(1.0, 2.0)) + row2 = (2.0, ExamplePoint(3.0, 4.0)) + schema = StructType([StructField("label", DoubleType(), False), + StructField("point", ExamplePointUDT(), False)]) + df1 = self.spark.createDataFrame([row1], schema) + df2 = self.spark.createDataFrame([row2], schema) + + result = df1.union(df2).orderBy("label").collect() + self.assertEqual( + result, + [ + Row(label=1.0, point=ExamplePoint(1.0, 2.0)), + Row(label=2.0, point=ExamplePoint(3.0, 4.0)) + ] + ) + + def test_cast_to_string_with_udt(self): + from pyspark.sql.functions import col + row = (ExamplePoint(1.0, 2.0), PythonOnlyPoint(3.0, 4.0)) + schema = StructType([StructField("point", ExamplePointUDT(), False), + StructField("pypoint", PythonOnlyUDT(), False)]) + df = self.spark.createDataFrame([row], schema) + + result = df.select(col('point').cast('string'), col('pypoint').cast('string')).head() + self.assertEqual(result, Row(point=u'(1.0, 2.0)', pypoint=u'[3.0, 4.0]')) + + def test_struct_type(self): + struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) + struct2 = StructType([StructField("f1", StringType(), True), + StructField("f2", StringType(), True, None)]) + self.assertEqual(struct1.fieldNames(), struct2.names) + self.assertEqual(struct1, struct2) + + struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) + struct2 = StructType([StructField("f1", StringType(), True)]) + self.assertNotEqual(struct1.fieldNames(), struct2.names) + self.assertNotEqual(struct1, struct2) + + struct1 = (StructType().add(StructField("f1", StringType(), True)) + .add(StructField("f2", StringType(), True, None))) + struct2 = StructType([StructField("f1", StringType(), True), + StructField("f2", StringType(), True, None)]) + self.assertEqual(struct1.fieldNames(), struct2.names) + self.assertEqual(struct1, struct2) + + struct1 = (StructType().add(StructField("f1", StringType(), True)) + .add(StructField("f2", StringType(), True, None))) + struct2 = StructType([StructField("f1", StringType(), True)]) + self.assertNotEqual(struct1.fieldNames(), struct2.names) + self.assertNotEqual(struct1, struct2) + + # Catch exception raised during improper construction + self.assertRaises(ValueError, lambda: StructType().add("name")) + + struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) + for field in struct1: + self.assertIsInstance(field, StructField) + + struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) + self.assertEqual(len(struct1), 2) + + struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) + self.assertIs(struct1["f1"], struct1.fields[0]) + self.assertIs(struct1[0], struct1.fields[0]) + self.assertEqual(struct1[0:1], StructType(struct1.fields[0:1])) + self.assertRaises(KeyError, lambda: struct1["f9"]) + self.assertRaises(IndexError, lambda: struct1[9]) + self.assertRaises(TypeError, lambda: struct1[9.9]) + + def test_parse_datatype_string(self): + from pyspark.sql.types import 
_all_atomic_types, _parse_datatype_string + for k, t in _all_atomic_types.items(): + if t != NullType: + self.assertEqual(t(), _parse_datatype_string(k)) + self.assertEqual(IntegerType(), _parse_datatype_string("int")) + self.assertEqual(DecimalType(1, 1), _parse_datatype_string("decimal(1 ,1)")) + self.assertEqual(DecimalType(10, 1), _parse_datatype_string("decimal( 10,1 )")) + self.assertEqual(DecimalType(11, 1), _parse_datatype_string("decimal(11,1)")) + self.assertEqual( + ArrayType(IntegerType()), + _parse_datatype_string("array")) + self.assertEqual( + MapType(IntegerType(), DoubleType()), + _parse_datatype_string("map< int, double >")) + self.assertEqual( + StructType([StructField("a", IntegerType()), StructField("c", DoubleType())]), + _parse_datatype_string("struct")) + self.assertEqual( + StructType([StructField("a", IntegerType()), StructField("c", DoubleType())]), + _parse_datatype_string("a:int, c:double")) + self.assertEqual( + StructType([StructField("a", IntegerType()), StructField("c", DoubleType())]), + _parse_datatype_string("a INT, c DOUBLE")) + + def test_metadata_null(self): + schema = StructType([StructField("f1", StringType(), True, None), + StructField("f2", StringType(), True, {'a': None})]) + rdd = self.sc.parallelize([["a", "b"], ["c", "d"]]) + self.spark.createDataFrame(rdd, schema) + + def test_access_nested_types(self): + df = self.sc.parallelize([Row(l=[1], r=Row(a=1, b="b"), d={"k": "v"})]).toDF() + self.assertEqual(1, df.select(df.l[0]).first()[0]) + self.assertEqual(1, df.select(df.l.getItem(0)).first()[0]) + self.assertEqual(1, df.select(df.r.a).first()[0]) + self.assertEqual("b", df.select(df.r.getField("b")).first()[0]) + self.assertEqual("v", df.select(df.d["k"]).first()[0]) + self.assertEqual("v", df.select(df.d.getItem("k")).first()[0]) + + def test_infer_long_type(self): + longrow = [Row(f1='a', f2=100000000000000)] + df = self.sc.parallelize(longrow).toDF() + self.assertEqual(df.schema.fields[1].dataType, LongType()) + + # this saving as Parquet caused issues as well. 
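# Illustrative sketch, not part of this patch: the helpers used in this file can
# be called directly; any Python int is inferred as LongType, and a DDL-formatted
# string parses to the matching StructType.
#
#     from pyspark.sql.types import _infer_type, _parse_datatype_string
#     _infer_type(2 ** 40)                       # LongType()
#     _parse_datatype_string("a INT, c DOUBLE")  # StructType with fields a and c
#
# The Parquet round-trip below exercises the same inferred LongType end-to-end.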
+ output_dir = os.path.join(self.tempdir.name, "infer_long_type") + df.write.parquet(output_dir) + df1 = self.spark.read.parquet(output_dir) + self.assertEqual('a', df1.first().f1) + self.assertEqual(100000000000000, df1.first().f2) + + self.assertEqual(_infer_type(1), LongType()) + self.assertEqual(_infer_type(2**10), LongType()) + self.assertEqual(_infer_type(2**20), LongType()) + self.assertEqual(_infer_type(2**31 - 1), LongType()) + self.assertEqual(_infer_type(2**31), LongType()) + self.assertEqual(_infer_type(2**61), LongType()) + self.assertEqual(_infer_type(2**71), LongType()) + + def test_merge_type(self): + self.assertEqual(_merge_type(LongType(), NullType()), LongType()) + self.assertEqual(_merge_type(NullType(), LongType()), LongType()) + + self.assertEqual(_merge_type(LongType(), LongType()), LongType()) + + self.assertEqual(_merge_type( + ArrayType(LongType()), + ArrayType(LongType()) + ), ArrayType(LongType())) + with self.assertRaisesRegexp(TypeError, 'element in array'): + _merge_type(ArrayType(LongType()), ArrayType(DoubleType())) + + self.assertEqual(_merge_type( + MapType(StringType(), LongType()), + MapType(StringType(), LongType()) + ), MapType(StringType(), LongType())) + with self.assertRaisesRegexp(TypeError, 'key of map'): + _merge_type( + MapType(StringType(), LongType()), + MapType(DoubleType(), LongType())) + with self.assertRaisesRegexp(TypeError, 'value of map'): + _merge_type( + MapType(StringType(), LongType()), + MapType(StringType(), DoubleType())) + + self.assertEqual(_merge_type( + StructType([StructField("f1", LongType()), StructField("f2", StringType())]), + StructType([StructField("f1", LongType()), StructField("f2", StringType())]) + ), StructType([StructField("f1", LongType()), StructField("f2", StringType())])) + with self.assertRaisesRegexp(TypeError, 'field f1'): + _merge_type( + StructType([StructField("f1", LongType()), StructField("f2", StringType())]), + StructType([StructField("f1", DoubleType()), StructField("f2", StringType())])) + + self.assertEqual(_merge_type( + StructType([StructField("f1", StructType([StructField("f2", LongType())]))]), + StructType([StructField("f1", StructType([StructField("f2", LongType())]))]) + ), StructType([StructField("f1", StructType([StructField("f2", LongType())]))])) + with self.assertRaisesRegexp(TypeError, 'field f2 in field f1'): + _merge_type( + StructType([StructField("f1", StructType([StructField("f2", LongType())]))]), + StructType([StructField("f1", StructType([StructField("f2", StringType())]))])) + + self.assertEqual(_merge_type( + StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())]), + StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())]) + ), StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())])) + with self.assertRaisesRegexp(TypeError, 'element in array field f1'): + _merge_type( + StructType([ + StructField("f1", ArrayType(LongType())), + StructField("f2", StringType())]), + StructType([ + StructField("f1", ArrayType(DoubleType())), + StructField("f2", StringType())])) + + self.assertEqual(_merge_type( + StructType([ + StructField("f1", MapType(StringType(), LongType())), + StructField("f2", StringType())]), + StructType([ + StructField("f1", MapType(StringType(), LongType())), + StructField("f2", StringType())]) + ), StructType([ + StructField("f1", MapType(StringType(), LongType())), + StructField("f2", StringType())])) + with self.assertRaisesRegexp(TypeError, 'value of map field f1'): 
+ _merge_type( + StructType([ + StructField("f1", MapType(StringType(), LongType())), + StructField("f2", StringType())]), + StructType([ + StructField("f1", MapType(StringType(), DoubleType())), + StructField("f2", StringType())])) + + self.assertEqual(_merge_type( + StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]), + StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]) + ), StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))])) + with self.assertRaisesRegexp(TypeError, 'key of map element in array field f1'): + _merge_type( + StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]), + StructType([StructField("f1", ArrayType(MapType(DoubleType(), LongType())))]) + ) + + # test for SPARK-16542 + def test_array_types(self): + # This test need to make sure that the Scala type selected is at least + # as large as the python's types. This is necessary because python's + # array types depend on C implementation on the machine. Therefore there + # is no machine independent correspondence between python's array types + # and Scala types. + # See: https://docs.python.org/2/library/array.html + + def assertCollectSuccess(typecode, value): + row = Row(myarray=array.array(typecode, [value])) + df = self.spark.createDataFrame([row]) + self.assertEqual(df.first()["myarray"][0], value) + + # supported string types + # + # String types in python's array are "u" for Py_UNICODE and "c" for char. + # "u" will be removed in python 4, and "c" is not supported in python 3. + supported_string_types = [] + if sys.version_info[0] < 4: + supported_string_types += ['u'] + # test unicode + assertCollectSuccess('u', u'a') + if sys.version_info[0] < 3: + supported_string_types += ['c'] + # test string + assertCollectSuccess('c', 'a') + + # supported float and double + # + # Test max, min, and precision for float and double, assuming IEEE 754 + # floating-point format. + supported_fractional_types = ['f', 'd'] + assertCollectSuccess('f', ctypes.c_float(1e+38).value) + assertCollectSuccess('f', ctypes.c_float(1e-38).value) + assertCollectSuccess('f', ctypes.c_float(1.123456).value) + assertCollectSuccess('d', sys.float_info.max) + assertCollectSuccess('d', sys.float_info.min) + assertCollectSuccess('d', sys.float_info.epsilon) + + # supported signed int types + # + # The size of C types changes with implementation, we need to make sure + # that there is no overflow error on the platform running this test. + supported_signed_int_types = list( + set(_array_signed_int_typecode_ctype_mappings.keys()) + .intersection(set(_array_type_mappings.keys()))) + for t in supported_signed_int_types: + ctype = _array_signed_int_typecode_ctype_mappings[t] + max_val = 2 ** (ctypes.sizeof(ctype) * 8 - 1) + assertCollectSuccess(t, max_val - 1) + assertCollectSuccess(t, -max_val) + + # supported unsigned int types + # + # JVM does not have unsigned types. We need to be very careful to make + # sure that there is no overflow error. + supported_unsigned_int_types = list( + set(_array_unsigned_int_typecode_ctype_mappings.keys()) + .intersection(set(_array_type_mappings.keys()))) + for t in supported_unsigned_int_types: + ctype = _array_unsigned_int_typecode_ctype_mappings[t] + assertCollectSuccess(t, 2 ** (ctypes.sizeof(ctype) * 8) - 1) + + # all supported types + # + # Make sure the types tested above: + # 1. are all supported types + # 2. 
cover all supported types + supported_types = (supported_string_types + + supported_fractional_types + + supported_signed_int_types + + supported_unsigned_int_types) + self.assertEqual(set(supported_types), set(_array_type_mappings.keys())) + + # all unsupported types + # + # Keys in _array_type_mappings is a complete list of all supported types, + # and types not in _array_type_mappings are considered unsupported. + # `array.typecodes` are not supported in python 2. + if sys.version_info[0] < 3: + all_types = set(['c', 'b', 'B', 'u', 'h', 'H', 'i', 'I', 'l', 'L', 'f', 'd']) + else: + all_types = set(array.typecodes) + unsupported_types = all_types - set(supported_types) + # test unsupported types + for t in unsupported_types: + with self.assertRaises(TypeError): + a = array.array(t) + self.spark.createDataFrame([Row(myarray=a)]).collect() + + +class DataTypeTests(unittest.TestCase): + # regression test for SPARK-6055 + def test_data_type_eq(self): + lt = LongType() + lt2 = pickle.loads(pickle.dumps(LongType())) + self.assertEqual(lt, lt2) + + # regression test for SPARK-7978 + def test_decimal_type(self): + t1 = DecimalType() + t2 = DecimalType(10, 2) + self.assertTrue(t2 is not t1) + self.assertNotEqual(t1, t2) + t3 = DecimalType(8) + self.assertNotEqual(t2, t3) + + # regression test for SPARK-10392 + def test_datetype_equal_zero(self): + dt = DateType() + self.assertEqual(dt.fromInternal(0), datetime.date(1970, 1, 1)) + + # regression test for SPARK-17035 + def test_timestamp_microsecond(self): + tst = TimestampType() + self.assertEqual(tst.toInternal(datetime.datetime.max) % 1000000, 999999) + + def test_empty_row(self): + row = Row() + self.assertEqual(len(row), 0) + + def test_struct_field_type_name(self): + struct_field = StructField("a", IntegerType()) + self.assertRaises(TypeError, struct_field.typeName) + + def test_invalid_create_row(self): + row_class = Row("c1", "c2") + self.assertRaises(ValueError, lambda: row_class(1, 2, 3)) + + +class DataTypeVerificationTests(unittest.TestCase): + + def test_verify_type_exception_msg(self): + self.assertRaisesRegexp( + ValueError, + "test_name", + lambda: _make_type_verifier(StringType(), nullable=False, name="test_name")(None)) + + schema = StructType([StructField('a', StructType([StructField('b', IntegerType())]))]) + self.assertRaisesRegexp( + TypeError, + "field b in field a", + lambda: _make_type_verifier(schema)([["data"]])) + + def test_verify_type_ok_nullable(self): + obj = None + types = [IntegerType(), FloatType(), StringType(), StructType([])] + for data_type in types: + try: + _make_type_verifier(data_type, nullable=True)(obj) + except Exception: + self.fail("verify_type(%s, %s, nullable=True)" % (obj, data_type)) + + def test_verify_type_not_nullable(self): + import array + import datetime + import decimal + + schema = StructType([ + StructField('s', StringType(), nullable=False), + StructField('i', IntegerType(), nullable=True)]) + + class MyObj: + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + + # obj, data_type + success_spec = [ + # String + ("", StringType()), + (u"", StringType()), + (1, StringType()), + (1.0, StringType()), + ([], StringType()), + ({}, StringType()), + + # UDT + (ExamplePoint(1.0, 2.0), ExamplePointUDT()), + + # Boolean + (True, BooleanType()), + + # Byte + (-(2**7), ByteType()), + (2**7 - 1, ByteType()), + + # Short + (-(2**15), ShortType()), + (2**15 - 1, ShortType()), + + # Integer + (-(2**31), IntegerType()), + (2**31 - 1, IntegerType()), + + # Long + (2**64, 
LongType()), + + # Float & Double + (1.0, FloatType()), + (1.0, DoubleType()), + + # Decimal + (decimal.Decimal("1.0"), DecimalType()), + + # Binary + (bytearray([1, 2]), BinaryType()), + + # Date/Timestamp + (datetime.date(2000, 1, 2), DateType()), + (datetime.datetime(2000, 1, 2, 3, 4), DateType()), + (datetime.datetime(2000, 1, 2, 3, 4), TimestampType()), + + # Array + ([], ArrayType(IntegerType())), + (["1", None], ArrayType(StringType(), containsNull=True)), + ([1, 2], ArrayType(IntegerType())), + ((1, 2), ArrayType(IntegerType())), + (array.array('h', [1, 2]), ArrayType(IntegerType())), + + # Map + ({}, MapType(StringType(), IntegerType())), + ({"a": 1}, MapType(StringType(), IntegerType())), + ({"a": None}, MapType(StringType(), IntegerType(), valueContainsNull=True)), + + # Struct + ({"s": "a", "i": 1}, schema), + ({"s": "a", "i": None}, schema), + ({"s": "a"}, schema), + ({"s": "a", "f": 1.0}, schema), + (Row(s="a", i=1), schema), + (Row(s="a", i=None), schema), + (Row(s="a", i=1, f=1.0), schema), + (["a", 1], schema), + (["a", None], schema), + (("a", 1), schema), + (MyObj(s="a", i=1), schema), + (MyObj(s="a", i=None), schema), + (MyObj(s="a"), schema), + ] + + # obj, data_type, exception class + failure_spec = [ + # String (match anything but None) + (None, StringType(), ValueError), + + # UDT + (ExamplePoint(1.0, 2.0), PythonOnlyUDT(), ValueError), + + # Boolean + (1, BooleanType(), TypeError), + ("True", BooleanType(), TypeError), + ([1], BooleanType(), TypeError), + + # Byte + (-(2**7) - 1, ByteType(), ValueError), + (2**7, ByteType(), ValueError), + ("1", ByteType(), TypeError), + (1.0, ByteType(), TypeError), + + # Short + (-(2**15) - 1, ShortType(), ValueError), + (2**15, ShortType(), ValueError), + + # Integer + (-(2**31) - 1, IntegerType(), ValueError), + (2**31, IntegerType(), ValueError), + + # Float & Double + (1, FloatType(), TypeError), + (1, DoubleType(), TypeError), + + # Decimal + (1.0, DecimalType(), TypeError), + (1, DecimalType(), TypeError), + ("1.0", DecimalType(), TypeError), + + # Binary + (1, BinaryType(), TypeError), + + # Date/Timestamp + ("2000-01-02", DateType(), TypeError), + (946811040, TimestampType(), TypeError), + + # Array + (["1", None], ArrayType(StringType(), containsNull=False), ValueError), + ([1, "2"], ArrayType(IntegerType()), TypeError), + + # Map + ({"a": 1}, MapType(IntegerType(), IntegerType()), TypeError), + ({"a": "1"}, MapType(StringType(), IntegerType()), TypeError), + ({"a": None}, MapType(StringType(), IntegerType(), valueContainsNull=False), + ValueError), + + # Struct + ({"s": "a", "i": "1"}, schema, TypeError), + (Row(s="a"), schema, ValueError), # Row can't have missing field + (Row(s="a", i="1"), schema, TypeError), + (["a"], schema, ValueError), + (["a", "1"], schema, TypeError), + (MyObj(s="a", i="1"), schema, TypeError), + (MyObj(s=None, i="1"), schema, ValueError), + ] + + # Check success cases + for obj, data_type in success_spec: + try: + _make_type_verifier(data_type, nullable=False)(obj) + except Exception: + self.fail("verify_type(%s, %s, nullable=False)" % (obj, data_type)) + + # Check failure cases + for obj, data_type, exp in failure_spec: + msg = "verify_type(%s, %s, nullable=False) == %s" % (obj, data_type, exp) + with self.assertRaises(exp, msg=msg): + _make_type_verifier(data_type, nullable=False)(obj) + + +if __name__ == "__main__": + from pyspark.sql.tests.test_types import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner 
= None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py new file mode 100644 index 0000000000000..ed298f724d551 --- /dev/null +++ b/python/pyspark/sql/tests/test_udf.py @@ -0,0 +1,667 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import functools +import pydoc +import shutil +import tempfile +import unittest + +from pyspark import SparkContext +from pyspark.sql import SparkSession, Column, Row +from pyspark.sql.functions import UserDefinedFunction +from pyspark.sql.types import * +from pyspark.sql.utils import AnalysisException +from pyspark.testing.sqlutils import ReusedSQLTestCase, test_compiled, test_not_compiled_message +from pyspark.testing.utils import QuietTest + + +class UDFTests(ReusedSQLTestCase): + + def test_udf_with_callable(self): + d = [Row(number=i, squared=i**2) for i in range(10)] + rdd = self.sc.parallelize(d) + data = self.spark.createDataFrame(rdd) + + class PlusFour: + def __call__(self, col): + if col is not None: + return col + 4 + + call = PlusFour() + pudf = UserDefinedFunction(call, LongType()) + res = data.select(pudf(data['number']).alias('plus_four')) + self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85) + + def test_udf_with_partial_function(self): + d = [Row(number=i, squared=i**2) for i in range(10)] + rdd = self.sc.parallelize(d) + data = self.spark.createDataFrame(rdd) + + def some_func(col, param): + if col is not None: + return col + param + + pfunc = functools.partial(some_func, param=4) + pudf = UserDefinedFunction(pfunc, LongType()) + res = data.select(pudf(data['number']).alias('plus_four')) + self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85) + + def test_udf(self): + self.spark.catalog.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType()) + [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect() + self.assertEqual(row[0], 5) + + # This is to check if a deprecated 'SQLContext.registerFunction' can call its alias. 
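# Illustrative sketch, not part of this patch: the non-deprecated registration
# path used later in this file goes through spark.udf.register and likewise
# returns a callable UDF.
#
#     strlen = self.spark.udf.register("strlen", lambda s: len(s), IntegerType())
#     [row] = self.spark.sql("SELECT strlen('test')").collect()   # row[0] == 4
#
# The deprecated SQLContext.registerFunction alias exercised below should behave
# identically.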
+ sqlContext = self.spark._wrapped + sqlContext.registerFunction("oneArg", lambda x: len(x), IntegerType()) + [row] = sqlContext.sql("SELECT oneArg('test')").collect() + self.assertEqual(row[0], 4) + + def test_udf2(self): + with self.tempView("test"): + self.spark.catalog.registerFunction("strlen", lambda string: len(string), IntegerType()) + self.spark.createDataFrame(self.sc.parallelize([Row(a="test")]))\ + .createOrReplaceTempView("test") + [res] = self.spark.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect() + self.assertEqual(4, res[0]) + + def test_udf3(self): + two_args = self.spark.catalog.registerFunction( + "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y)) + self.assertEqual(two_args.deterministic, True) + [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect() + self.assertEqual(row[0], u'5') + + def test_udf_registration_return_type_none(self): + two_args = self.spark.catalog.registerFunction( + "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y, "integer"), None) + self.assertEqual(two_args.deterministic, True) + [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect() + self.assertEqual(row[0], 5) + + def test_udf_registration_return_type_not_none(self): + with QuietTest(self.sc): + with self.assertRaisesRegexp(TypeError, "Invalid returnType"): + self.spark.catalog.registerFunction( + "f", UserDefinedFunction(lambda x, y: len(x) + y, StringType()), StringType()) + + def test_nondeterministic_udf(self): + # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations + from pyspark.sql.functions import udf + import random + udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic() + self.assertEqual(udf_random_col.deterministic, False) + df = self.spark.createDataFrame([Row(1)]).select(udf_random_col().alias('RAND')) + udf_add_ten = udf(lambda rand: rand + 10, IntegerType()) + [row] = df.withColumn('RAND_PLUS_TEN', udf_add_ten('RAND')).collect() + self.assertEqual(row[0] + 10, row[1]) + + def test_nondeterministic_udf2(self): + import random + from pyspark.sql.functions import udf + random_udf = udf(lambda: random.randint(6, 6), IntegerType()).asNondeterministic() + self.assertEqual(random_udf.deterministic, False) + random_udf1 = self.spark.catalog.registerFunction("randInt", random_udf) + self.assertEqual(random_udf1.deterministic, False) + [row] = self.spark.sql("SELECT randInt()").collect() + self.assertEqual(row[0], 6) + [row] = self.spark.range(1).select(random_udf1()).collect() + self.assertEqual(row[0], 6) + [row] = self.spark.range(1).select(random_udf()).collect() + self.assertEqual(row[0], 6) + # render_doc() reproduces the help() exception without printing output + pydoc.render_doc(udf(lambda: random.randint(6, 6), IntegerType())) + pydoc.render_doc(random_udf) + pydoc.render_doc(random_udf1) + pydoc.render_doc(udf(lambda x: x).asNondeterministic) + + def test_nondeterministic_udf3(self): + # regression test for SPARK-23233 + from pyspark.sql.functions import udf + f = udf(lambda x: x) + # Here we cache the JVM UDF instance. + self.spark.range(1).select(f("id")) + # This should reset the cache to set the deterministic status correctly. + f = f.asNondeterministic() + # Check the deterministic status of udf. 
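# Illustrative sketch, not part of this patch: the flag is also exposed directly
# on the Python UDF object, without going through the JVM plan.
#
#     g = udf(lambda: 1, IntegerType()).asNondeterministic()
#     g.deterministic   # False
#
# The logical-plan inspection below verifies the same flag end-to-end.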
+ df = self.spark.range(1).select(f("id")) + deterministic = df._jdf.logicalPlan().projectList().head().deterministic() + self.assertFalse(deterministic) + + def test_nondeterministic_udf_in_aggregate(self): + from pyspark.sql.functions import udf, sum + import random + udf_random_col = udf(lambda: int(100 * random.random()), 'int').asNondeterministic() + df = self.spark.range(10) + + with QuietTest(self.sc): + with self.assertRaisesRegexp(AnalysisException, "nondeterministic"): + df.groupby('id').agg(sum(udf_random_col())).collect() + with self.assertRaisesRegexp(AnalysisException, "nondeterministic"): + df.agg(sum(udf_random_col())).collect() + + def test_chained_udf(self): + self.spark.catalog.registerFunction("double", lambda x: x + x, IntegerType()) + [row] = self.spark.sql("SELECT double(1)").collect() + self.assertEqual(row[0], 2) + [row] = self.spark.sql("SELECT double(double(1))").collect() + self.assertEqual(row[0], 4) + [row] = self.spark.sql("SELECT double(double(1) + 1)").collect() + self.assertEqual(row[0], 6) + + def test_single_udf_with_repeated_argument(self): + # regression test for SPARK-20685 + self.spark.catalog.registerFunction("add", lambda x, y: x + y, IntegerType()) + row = self.spark.sql("SELECT add(1, 1)").first() + self.assertEqual(tuple(row), (2, )) + + def test_multiple_udfs(self): + self.spark.catalog.registerFunction("double", lambda x: x * 2, IntegerType()) + [row] = self.spark.sql("SELECT double(1), double(2)").collect() + self.assertEqual(tuple(row), (2, 4)) + [row] = self.spark.sql("SELECT double(double(1)), double(double(2) + 2)").collect() + self.assertEqual(tuple(row), (4, 12)) + self.spark.catalog.registerFunction("add", lambda x, y: x + y, IntegerType()) + [row] = self.spark.sql("SELECT double(add(1, 2)), add(double(2), 1)").collect() + self.assertEqual(tuple(row), (6, 5)) + + def test_udf_in_filter_on_top_of_outer_join(self): + from pyspark.sql.functions import udf + left = self.spark.createDataFrame([Row(a=1)]) + right = self.spark.createDataFrame([Row(a=1)]) + df = left.join(right, on='a', how='left_outer') + df = df.withColumn('b', udf(lambda x: 'x')(df.a)) + self.assertEqual(df.filter('b = "x"').collect(), [Row(a=1, b='x')]) + + def test_udf_in_filter_on_top_of_join(self): + # regression test for SPARK-18589 + from pyspark.sql.functions import udf + left = self.spark.createDataFrame([Row(a=1)]) + right = self.spark.createDataFrame([Row(b=1)]) + f = udf(lambda a, b: a == b, BooleanType()) + df = left.crossJoin(right).filter(f("a", "b")) + self.assertEqual(df.collect(), [Row(a=1, b=1)]) + + def test_udf_in_join_condition(self): + # regression test for SPARK-25314 + from pyspark.sql.functions import udf + left = self.spark.createDataFrame([Row(a=1)]) + right = self.spark.createDataFrame([Row(b=1)]) + f = udf(lambda a, b: a == b, BooleanType()) + df = left.join(right, f("a", "b")) + with self.assertRaisesRegexp(AnalysisException, 'Detected implicit cartesian product'): + df.collect() + with self.sql_conf({"spark.sql.crossJoin.enabled": True}): + self.assertEqual(df.collect(), [Row(a=1, b=1)]) + + def test_udf_in_left_outer_join_condition(self): + # regression test for SPARK-26147 + from pyspark.sql.functions import udf, col + left = self.spark.createDataFrame([Row(a=1)]) + right = self.spark.createDataFrame([Row(b=1)]) + f = udf(lambda a: str(a), StringType()) + # The join condition can't be pushed down, as it refers to attributes from both sides. + # The Python UDF only refer to attributes from one side, so it's evaluable. 
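# Illustrative sketch, not part of this patch: a plain column predicate, with no
# Python UDF involved, can be evaluated by the join directly:
#
#     left.join(right, left.a == right.b, how="left_outer").collect()   # [Row(a=1, b=1)]
#
# The UDF-based condition built below is the SPARK-26147 case; note that the
# test collects under spark.sql.crossJoin.enabled.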
+ df = left.join(right, f("a") == col("b").cast("string"), how="left_outer") + with self.sql_conf({"spark.sql.crossJoin.enabled": True}): + self.assertEqual(df.collect(), [Row(a=1, b=1)]) + + def test_udf_in_left_semi_join_condition(self): + # regression test for SPARK-25314 + from pyspark.sql.functions import udf + left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)]) + right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1)]) + f = udf(lambda a, b: a == b, BooleanType()) + df = left.join(right, f("a", "b"), "leftsemi") + with self.assertRaisesRegexp(AnalysisException, 'Detected implicit cartesian product'): + df.collect() + with self.sql_conf({"spark.sql.crossJoin.enabled": True}): + self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1)]) + + def test_udf_and_common_filter_in_join_condition(self): + # regression test for SPARK-25314 + # test the complex scenario with both udf and common filter + from pyspark.sql.functions import udf + left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)]) + right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)]) + f = udf(lambda a, b: a == b, BooleanType()) + df = left.join(right, [f("a", "b"), left.a1 == right.b1]) + # do not need spark.sql.crossJoin.enabled=true for udf is not the only join condition. + self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1, b=1, b1=1, b2=1)]) + + def test_udf_and_common_filter_in_left_semi_join_condition(self): + # regression test for SPARK-25314 + # test the complex scenario with both udf and common filter + from pyspark.sql.functions import udf + left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)]) + right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)]) + f = udf(lambda a, b: a == b, BooleanType()) + df = left.join(right, [f("a", "b"), left.a1 == right.b1], "left_semi") + # do not need spark.sql.crossJoin.enabled=true for udf is not the only join condition. + self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1)]) + + def test_udf_not_supported_in_join_condition(self): + # regression test for SPARK-25314 + # test python udf is not supported in join type besides left_semi and inner join. + from pyspark.sql.functions import udf + left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)]) + right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)]) + f = udf(lambda a, b: a == b, BooleanType()) + + def runWithJoinType(join_type, type_string): + with self.assertRaisesRegexp( + AnalysisException, + 'Using PythonUDF.*%s is not supported.' 
% type_string): + left.join(right, [f("a", "b"), left.a1 == right.b1], join_type).collect() + runWithJoinType("full", "FullOuter") + runWithJoinType("left", "LeftOuter") + runWithJoinType("right", "RightOuter") + runWithJoinType("leftanti", "LeftAnti") + + def test_udf_without_arguments(self): + self.spark.catalog.registerFunction("foo", lambda: "bar") + [row] = self.spark.sql("SELECT foo()").collect() + self.assertEqual(row[0], "bar") + + def test_udf_with_array_type(self): + with self.tempView("test"): + d = [Row(l=list(range(3)), d={"key": list(range(5))})] + rdd = self.sc.parallelize(d) + self.spark.createDataFrame(rdd).createOrReplaceTempView("test") + self.spark.catalog.registerFunction( + "copylist", lambda l: list(l), ArrayType(IntegerType())) + self.spark.catalog.registerFunction("maplen", lambda d: len(d), IntegerType()) + [(l1, l2)] = self.spark.sql("select copylist(l), maplen(d) from test").collect() + self.assertEqual(list(range(3)), l1) + self.assertEqual(1, l2) + + def test_broadcast_in_udf(self): + bar = {"a": "aa", "b": "bb", "c": "abc"} + foo = self.sc.broadcast(bar) + self.spark.catalog.registerFunction("MYUDF", lambda x: foo.value[x] if x else '') + [res] = self.spark.sql("SELECT MYUDF('c')").collect() + self.assertEqual("abc", res[0]) + [res] = self.spark.sql("SELECT MYUDF('')").collect() + self.assertEqual("", res[0]) + + def test_udf_with_filter_function(self): + df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) + from pyspark.sql.functions import udf, col + from pyspark.sql.types import BooleanType + + my_filter = udf(lambda a: a < 2, BooleanType()) + sel = df.select(col("key"), col("value")).filter((my_filter(col("key"))) & (df.value < "2")) + self.assertEqual(sel.collect(), [Row(key=1, value='1')]) + + def test_udf_with_aggregate_function(self): + df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) + from pyspark.sql.functions import udf, col, sum + from pyspark.sql.types import BooleanType + + my_filter = udf(lambda a: a == 1, BooleanType()) + sel = df.select(col("key")).distinct().filter(my_filter(col("key"))) + self.assertEqual(sel.collect(), [Row(key=1)]) + + my_copy = udf(lambda x: x, IntegerType()) + my_add = udf(lambda a, b: int(a + b), IntegerType()) + my_strlen = udf(lambda x: len(x), IntegerType()) + sel = df.groupBy(my_copy(col("key")).alias("k"))\ + .agg(sum(my_strlen(col("value"))).alias("s"))\ + .select(my_add(col("k"), col("s")).alias("t")) + self.assertEqual(sel.collect(), [Row(t=4), Row(t=3)]) + + def test_udf_in_generate(self): + from pyspark.sql.functions import udf, explode + df = self.spark.range(5) + f = udf(lambda x: list(range(x)), ArrayType(LongType())) + row = df.select(explode(f(*df))).groupBy().sum().first() + self.assertEqual(row[0], 10) + + df = self.spark.range(3) + res = df.select("id", explode(f(df.id))).collect() + self.assertEqual(res[0][0], 1) + self.assertEqual(res[0][1], 0) + self.assertEqual(res[1][0], 2) + self.assertEqual(res[1][1], 0) + self.assertEqual(res[2][0], 2) + self.assertEqual(res[2][1], 1) + + range_udf = udf(lambda value: list(range(value - 1, value + 1)), ArrayType(IntegerType())) + res = df.select("id", explode(range_udf(df.id))).collect() + self.assertEqual(res[0][0], 0) + self.assertEqual(res[0][1], -1) + self.assertEqual(res[1][0], 0) + self.assertEqual(res[1][1], 0) + self.assertEqual(res[2][0], 1) + self.assertEqual(res[2][1], 0) + self.assertEqual(res[3][0], 1) + self.assertEqual(res[3][1], 1) + + def 
test_udf_with_order_by_and_limit(self): + from pyspark.sql.functions import udf + my_copy = udf(lambda x: x, IntegerType()) + df = self.spark.range(10).orderBy("id") + res = df.select(df.id, my_copy(df.id).alias("copy")).limit(1) + res.explain(True) + self.assertEqual(res.collect(), [Row(id=0, copy=0)]) + + def test_udf_registration_returns_udf(self): + df = self.spark.range(10) + add_three = self.spark.udf.register("add_three", lambda x: x + 3, IntegerType()) + + self.assertListEqual( + df.selectExpr("add_three(id) AS plus_three").collect(), + df.select(add_three("id").alias("plus_three")).collect() + ) + + # This is to check if a 'SQLContext.udf' can call its alias. + sqlContext = self.spark._wrapped + add_four = sqlContext.udf.register("add_four", lambda x: x + 4, IntegerType()) + + self.assertListEqual( + df.selectExpr("add_four(id) AS plus_four").collect(), + df.select(add_four("id").alias("plus_four")).collect() + ) + + def test_non_existed_udf(self): + spark = self.spark + self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf", + lambda: spark.udf.registerJavaFunction("udf1", "non_existed_udf")) + + # This is to check if a deprecated 'SQLContext.registerJavaFunction' can call its alias. + sqlContext = spark._wrapped + self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf", + lambda: sqlContext.registerJavaFunction("udf1", "non_existed_udf")) + + def test_non_existed_udaf(self): + spark = self.spark + self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udaf", + lambda: spark.udf.registerJavaUDAF("udaf1", "non_existed_udaf")) + + def test_udf_with_input_file_name(self): + from pyspark.sql.functions import udf, input_file_name + sourceFile = udf(lambda path: path, StringType()) + filePath = "python/test_support/sql/people1.json" + row = self.spark.read.json(filePath).select(sourceFile(input_file_name())).first() + self.assertTrue(row[0].find("people1.json") != -1) + + def test_udf_with_input_file_name_for_hadooprdd(self): + from pyspark.sql.functions import udf, input_file_name + + def filename(path): + return path + + sameText = udf(filename, StringType()) + + rdd = self.sc.textFile('python/test_support/sql/people.json') + df = self.spark.read.json(rdd).select(input_file_name().alias('file')) + row = df.select(sameText(df['file'])).first() + self.assertTrue(row[0].find("people.json") != -1) + + rdd2 = self.sc.newAPIHadoopFile( + 'python/test_support/sql/people.json', + 'org.apache.hadoop.mapreduce.lib.input.TextInputFormat', + 'org.apache.hadoop.io.LongWritable', + 'org.apache.hadoop.io.Text') + + df2 = self.spark.read.json(rdd2).select(input_file_name().alias('file')) + row2 = df2.select(sameText(df2['file'])).first() + self.assertTrue(row2[0].find("people.json") != -1) + + def test_udf_defers_judf_initialization(self): + # This is separate of UDFInitializationTests + # to avoid context initialization + # when udf is called + + from pyspark.sql.functions import UserDefinedFunction + + f = UserDefinedFunction(lambda x: x, StringType()) + + self.assertIsNone( + f._judf_placeholder, + "judf should not be initialized before the first call." + ) + + self.assertIsInstance(f("foo"), Column, "UDF call should return a Column.") + + self.assertIsNotNone( + f._judf_placeholder, + "judf should be initialized after UDF has been called." 
+ ) + + def test_udf_with_string_return_type(self): + from pyspark.sql.functions import UserDefinedFunction + + add_one = UserDefinedFunction(lambda x: x + 1, "integer") + make_pair = UserDefinedFunction(lambda x: (-x, x), "struct") + make_array = UserDefinedFunction( + lambda x: [float(x) for x in range(x, x + 3)], "array") + + expected = (2, Row(x=-1, y=1), [1.0, 2.0, 3.0]) + actual = (self.spark.range(1, 2).toDF("x") + .select(add_one("x"), make_pair("x"), make_array("x")) + .first()) + + self.assertTupleEqual(expected, actual) + + def test_udf_shouldnt_accept_noncallable_object(self): + from pyspark.sql.functions import UserDefinedFunction + + non_callable = None + self.assertRaises(TypeError, UserDefinedFunction, non_callable, StringType()) + + def test_udf_with_decorator(self): + from pyspark.sql.functions import lit, udf + from pyspark.sql.types import IntegerType, DoubleType + + @udf(IntegerType()) + def add_one(x): + if x is not None: + return x + 1 + + @udf(returnType=DoubleType()) + def add_two(x): + if x is not None: + return float(x + 2) + + @udf + def to_upper(x): + if x is not None: + return x.upper() + + @udf() + def to_lower(x): + if x is not None: + return x.lower() + + @udf + def substr(x, start, end): + if x is not None: + return x[start:end] + + @udf("long") + def trunc(x): + return int(x) + + @udf(returnType="double") + def as_double(x): + return float(x) + + df = ( + self.spark + .createDataFrame( + [(1, "Foo", "foobar", 3.0)], ("one", "Foo", "foobar", "float")) + .select( + add_one("one"), add_two("one"), + to_upper("Foo"), to_lower("Foo"), + substr("foobar", lit(0), lit(3)), + trunc("float"), as_double("one"))) + + self.assertListEqual( + [tpe for _, tpe in df.dtypes], + ["int", "double", "string", "string", "string", "bigint", "double"] + ) + + self.assertListEqual( + list(df.first()), + [2, 3.0, "FOO", "foo", "foo", 3, 1.0] + ) + + def test_udf_wrapper(self): + from pyspark.sql.functions import udf + from pyspark.sql.types import IntegerType + + def f(x): + """Identity""" + return x + + return_type = IntegerType() + f_ = udf(f, return_type) + + self.assertTrue(f.__doc__ in f_.__doc__) + self.assertEqual(f, f_.func) + self.assertEqual(return_type, f_.returnType) + + class F(object): + """Identity""" + def __call__(self, x): + return x + + f = F() + return_type = IntegerType() + f_ = udf(f, return_type) + + self.assertTrue(f.__doc__ in f_.__doc__) + self.assertEqual(f, f_.func) + self.assertEqual(return_type, f_.returnType) + + f = functools.partial(f, x=1) + return_type = IntegerType() + f_ = udf(f, return_type) + + self.assertTrue(f.__doc__ in f_.__doc__) + self.assertEqual(f, f_.func) + self.assertEqual(return_type, f_.returnType) + + def test_nonparam_udf_with_aggregate(self): + import pyspark.sql.functions as f + + df = self.spark.createDataFrame([(1, 2), (1, 2)]) + f_udf = f.udf(lambda: "const_str") + rows = df.distinct().withColumn("a", f_udf()).collect() + self.assertEqual(rows, [Row(_1=1, _2=2, a=u'const_str')]) + + # SPARK-24721 + @unittest.skipIf(not test_compiled, test_not_compiled_message) + def test_datasource_with_udf(self): + from pyspark.sql.functions import udf, lit, col + + path = tempfile.mkdtemp() + shutil.rmtree(path) + + try: + self.spark.range(1).write.mode("overwrite").format('csv').save(path) + filesource_df = self.spark.read.option('inferSchema', True).csv(path).toDF('i') + datasource_df = self.spark.read \ + .format("org.apache.spark.sql.sources.SimpleScanSource") \ + .option('from', 0).option('to', 1).load().toDF('i') + datasource_v2_df 
= self.spark.read \ + .format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2") \ + .load().toDF('i', 'j') + + c1 = udf(lambda x: x + 1, 'int')(lit(1)) + c2 = udf(lambda x: x + 1, 'int')(col('i')) + + f1 = udf(lambda x: False, 'boolean')(lit(1)) + f2 = udf(lambda x: False, 'boolean')(col('i')) + + for df in [filesource_df, datasource_df, datasource_v2_df]: + result = df.withColumn('c', c1) + expected = df.withColumn('c', lit(2)) + self.assertEquals(expected.collect(), result.collect()) + + for df in [filesource_df, datasource_df, datasource_v2_df]: + result = df.withColumn('c', c2) + expected = df.withColumn('c', col('i') + 1) + self.assertEquals(expected.collect(), result.collect()) + + for df in [filesource_df, datasource_df, datasource_v2_df]: + for f in [f1, f2]: + result = df.filter(f) + self.assertEquals(0, result.count()) + finally: + shutil.rmtree(path) + + # SPARK-25591 + def test_same_accumulator_in_udfs(self): + from pyspark.sql.functions import udf + + data_schema = StructType([StructField("a", IntegerType(), True), + StructField("b", IntegerType(), True)]) + data = self.spark.createDataFrame([[1, 2]], schema=data_schema) + + test_accum = self.sc.accumulator(0) + + def first_udf(x): + test_accum.add(1) + return x + + def second_udf(x): + test_accum.add(100) + return x + + func_udf = udf(first_udf, IntegerType()) + func_udf2 = udf(second_udf, IntegerType()) + data = data.withColumn("out1", func_udf(data["a"])) + data = data.withColumn("out2", func_udf2(data["b"])) + data.collect() + self.assertEqual(test_accum.value, 101) + + +class UDFInitializationTests(unittest.TestCase): + def tearDown(self): + if SparkSession._instantiatedSession is not None: + SparkSession._instantiatedSession.stop() + + if SparkContext._active_spark_context is not None: + SparkContext._active_spark_context.stop() + + def test_udf_init_shouldnt_initialize_context(self): + from pyspark.sql.functions import UserDefinedFunction + + UserDefinedFunction(lambda x: x, StringType()) + + self.assertIsNone( + SparkContext._active_spark_context, + "SparkContext shouldn't be initialized when UserDefinedFunction is created." + ) + self.assertIsNone( + SparkSession._instantiatedSession, + "SparkSession shouldn't be initialized when UserDefinedFunction is created." + ) + + +if __name__ == "__main__": + from pyspark.sql.tests.test_udf import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_utils.py b/python/pyspark/sql/tests/test_utils.py new file mode 100644 index 0000000000000..5bb921da5c2f3 --- /dev/null +++ b/python/pyspark/sql/tests/test_utils.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark.sql.functions import sha2 +from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException +from pyspark.testing.sqlutils import ReusedSQLTestCase + + +class UtilsTests(ReusedSQLTestCase): + + def test_capture_analysis_exception(self): + self.assertRaises(AnalysisException, lambda: self.spark.sql("select abc")) + self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + b")) + + def test_capture_parse_exception(self): + self.assertRaises(ParseException, lambda: self.spark.sql("abc")) + + def test_capture_illegalargument_exception(self): + self.assertRaisesRegexp(IllegalArgumentException, "Setting negative mapred.reduce.tasks", + lambda: self.spark.sql("SET mapred.reduce.tasks=-1")) + df = self.spark.createDataFrame([(1, 2)], ["a", "b"]) + self.assertRaisesRegexp(IllegalArgumentException, "1024 is not in the permitted values", + lambda: df.select(sha2(df.a, 1024)).collect()) + try: + df.select(sha2(df.a, 1024)).collect() + except IllegalArgumentException as e: + self.assertRegexpMatches(e.desc, "1024 is not in the permitted values") + self.assertRegexpMatches(e.stackTrace, + "org.apache.spark.sql.functions") + + +if __name__ == "__main__": + import unittest + from pyspark.sql.tests.test_utils import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/streaming/tests/__init__.py b/python/pyspark/streaming/tests/__init__.py new file mode 100644 index 0000000000000..cce3acad34a49 --- /dev/null +++ b/python/pyspark/streaming/tests/__init__.py @@ -0,0 +1,16 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/python/pyspark/streaming/tests/test_context.py b/python/pyspark/streaming/tests/test_context.py new file mode 100644 index 0000000000000..b44121462a920 --- /dev/null +++ b/python/pyspark/streaming/tests/test_context.py @@ -0,0 +1,184 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +import struct +import tempfile +import time + +from pyspark.streaming import StreamingContext +from pyspark.testing.streamingutils import PySparkStreamingTestCase + + +class StreamingContextTests(PySparkStreamingTestCase): + + duration = 0.1 + setupCalled = False + + def _add_input_stream(self): + inputs = [range(1, x) for x in range(101)] + stream = self.ssc.queueStream(inputs) + self._collect(stream, 1, block=False) + + def test_stop_only_streaming_context(self): + self._add_input_stream() + self.ssc.start() + self.ssc.stop(False) + self.assertEqual(len(self.sc.parallelize(range(5), 5).glom().collect()), 5) + + def test_stop_multiple_times(self): + self._add_input_stream() + self.ssc.start() + self.ssc.stop(False) + self.ssc.stop(False) + + def test_queue_stream(self): + input = [list(range(i + 1)) for i in range(3)] + dstream = self.ssc.queueStream(input) + result = self._collect(dstream, 3) + self.assertEqual(input, result) + + def test_text_file_stream(self): + d = tempfile.mkdtemp() + self.ssc = StreamingContext(self.sc, self.duration) + dstream2 = self.ssc.textFileStream(d).map(int) + result = self._collect(dstream2, 2, block=False) + self.ssc.start() + for name in ('a', 'b'): + time.sleep(1) + with open(os.path.join(d, name), "w") as f: + f.writelines(["%d\n" % i for i in range(10)]) + self.wait_for(result, 2) + self.assertEqual([list(range(10)), list(range(10))], result) + + def test_binary_records_stream(self): + d = tempfile.mkdtemp() + self.ssc = StreamingContext(self.sc, self.duration) + dstream = self.ssc.binaryRecordsStream(d, 10).map( + lambda v: struct.unpack("10b", bytes(v))) + result = self._collect(dstream, 2, block=False) + self.ssc.start() + for name in ('a', 'b'): + time.sleep(1) + with open(os.path.join(d, name), "wb") as f: + f.write(bytearray(range(10))) + self.wait_for(result, 2) + self.assertEqual([list(range(10)), list(range(10))], [list(v[0]) for v in result]) + + def test_union(self): + input = [list(range(i + 1)) for i in range(3)] + dstream = self.ssc.queueStream(input) + dstream2 = self.ssc.queueStream(input) + dstream3 = self.ssc.union(dstream, dstream2) + result = self._collect(dstream3, 3) + expected = [i * 2 for i in input] + self.assertEqual(expected, result) + + def test_transform(self): + dstream1 = self.ssc.queueStream([[1]]) + dstream2 = self.ssc.queueStream([[2]]) + dstream3 = self.ssc.queueStream([[3]]) + + def func(rdds): + rdd1, rdd2, rdd3 = rdds + return rdd2.union(rdd3).union(rdd1) + + dstream = self.ssc.transform([dstream1, dstream2, dstream3], func) + + self.assertEqual([2, 3, 1], self._take(dstream, 3)) + + def test_transform_pairrdd(self): + # This regression test case is for SPARK-17756. 
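# Illustrative sketch, not part of this patch: transform() applies an arbitrary
# RDD-to-RDD function to each batch, so any pair-RDD operation is available:
#
#     self.ssc.queueStream([[1, 2]]).transform(lambda rdd: rdd.map(lambda x: (x, x)))
#
# The cartesian() call below is the specific pair-RDD case covered by SPARK-17756.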
+ dstream = self.ssc.queueStream( + [[1], [2], [3]]).transform(lambda rdd: rdd.cartesian(rdd)) + self.assertEqual([(1, 1), (2, 2), (3, 3)], self._take(dstream, 3)) + + def test_get_active(self): + self.assertEqual(StreamingContext.getActive(), None) + + # Verify that getActive() returns the active context + self.ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count()) + self.ssc.start() + self.assertEqual(StreamingContext.getActive(), self.ssc) + + # Verify that getActive() returns None + self.ssc.stop(False) + self.assertEqual(StreamingContext.getActive(), None) + + # Verify that if the Java context is stopped, then getActive() returns None + self.ssc = StreamingContext(self.sc, self.duration) + self.ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count()) + self.ssc.start() + self.assertEqual(StreamingContext.getActive(), self.ssc) + self.ssc._jssc.stop(False) + self.assertEqual(StreamingContext.getActive(), None) + + def test_get_active_or_create(self): + # Test StreamingContext.getActiveOrCreate() without checkpoint data + # See CheckpointTests for tests with checkpoint data + self.ssc = None + self.assertEqual(StreamingContext.getActive(), None) + + def setupFunc(): + ssc = StreamingContext(self.sc, self.duration) + ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count()) + self.setupCalled = True + return ssc + + # Verify that getActiveOrCreate() (w/o checkpoint) calls setupFunc when no context is active + self.setupCalled = False + self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc) + self.assertTrue(self.setupCalled) + + # Verify that getActiveOrCreate() returns active context and does not call the setupFunc + self.ssc.start() + self.setupCalled = False + self.assertEqual(StreamingContext.getActiveOrCreate(None, setupFunc), self.ssc) + self.assertFalse(self.setupCalled) + + # Verify that getActiveOrCreate() calls setupFunc after active context is stopped + self.ssc.stop(False) + self.setupCalled = False + self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc) + self.assertTrue(self.setupCalled) + + # Verify that if the Java context is stopped, then getActive() returns None + self.ssc = StreamingContext(self.sc, self.duration) + self.ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count()) + self.ssc.start() + self.assertEqual(StreamingContext.getActive(), self.ssc) + self.ssc._jssc.stop(False) + self.setupCalled = False + self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc) + self.assertTrue(self.setupCalled) + + def test_await_termination_or_timeout(self): + self._add_input_stream() + self.ssc.start() + self.assertFalse(self.ssc.awaitTerminationOrTimeout(0.001)) + self.ssc.stop(False) + self.assertTrue(self.ssc.awaitTerminationOrTimeout(0.001)) + + +if __name__ == "__main__": + import unittest + from pyspark.streaming.tests.test_context import * + + try: + import xmlrunner + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + except ImportError: + unittest.main(verbosity=2) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests/test_dstream.py similarity index 50% rename from python/pyspark/streaming/tests.py rename to python/pyspark/streaming/tests/test_dstream.py index 34e3291651eec..d14e346b7a688 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests/test_dstream.py @@ -14,151 +14,18 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# - -import glob -import os -import sys -from itertools import chain -import time import operator -import tempfile -import random -import struct +import os import shutil -import unishark +import tempfile +import time +import unittest from functools import reduce +from itertools import chain -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest - -if sys.version >= "3": - long = int - -from pyspark.context import SparkConf, SparkContext, RDD -from pyspark.storagelevel import StorageLevel -from pyspark.streaming.context import StreamingContext -from pyspark.streaming.kinesis import KinesisUtils, InitialPositionInStream -from pyspark.streaming.listener import StreamingListener - - -class PySparkStreamingTestCase(unittest.TestCase): - - timeout = 30 # seconds - duration = .5 - - @classmethod - def setUpClass(cls): - class_name = cls.__name__ - conf = SparkConf().set("spark.default.parallelism", 1) - cls.sc = SparkContext(appName=class_name, conf=conf) - cls.sc.setCheckpointDir(tempfile.mkdtemp()) - - @classmethod - def tearDownClass(cls): - cls.sc.stop() - # Clean up in the JVM just in case there has been some issues in Python API - try: - jSparkContextOption = SparkContext._jvm.SparkContext.get() - if jSparkContextOption.nonEmpty(): - jSparkContextOption.get().stop() - except: - pass - - def setUp(self): - self.ssc = StreamingContext(self.sc, self.duration) - - def tearDown(self): - if self.ssc is not None: - self.ssc.stop(False) - # Clean up in the JVM just in case there has been some issues in Python API - try: - jStreamingContextOption = StreamingContext._jvm.SparkContext.getActive() - if jStreamingContextOption.nonEmpty(): - jStreamingContextOption.get().stop(False) - except: - pass - - def wait_for(self, result, n): - start_time = time.time() - while len(result) < n and time.time() - start_time < self.timeout: - time.sleep(0.01) - if len(result) < n: - print("timeout after", self.timeout) - - def _take(self, dstream, n): - """ - Return the first `n` elements in the stream (will start and stop). - """ - results = [] - - def take(_, rdd): - if rdd and len(results) < n: - results.extend(rdd.take(n - len(results))) - - dstream.foreachRDD(take) - - self.ssc.start() - self.wait_for(results, n) - return results - - def _collect(self, dstream, n, block=True): - """ - Collect each RDDs into the returned list. - - :return: list, which will have the collected items. - """ - result = [] - - def get_output(_, rdd): - if rdd and len(result) < n: - r = rdd.collect() - if r: - result.append(r) - - dstream.foreachRDD(get_output) - - if not block: - return result - - self.ssc.start() - self.wait_for(result, n) - return result - - def _test_func(self, input, func, expected, sort=False, input2=None): - """ - @param input: dataset for the test. This should be list of lists. - @param func: wrapped function. This function should return PythonDStream object. - @param expected: expected output for this testcase. - """ - if not isinstance(input[0], RDD): - input = [self.sc.parallelize(d, 1) for d in input] - input_stream = self.ssc.queueStream(input) - if input2 and not isinstance(input2[0], RDD): - input2 = [self.sc.parallelize(d, 1) for d in input2] - input_stream2 = self.ssc.queueStream(input2) if input2 is not None else None - - # Apply test function to stream. 
- if input2: - stream = func(input_stream, input_stream2) - else: - stream = func(input_stream) - - result = self._collect(stream, len(expected)) - if sort: - self._sort_result_based_on_key(result) - self._sort_result_based_on_key(expected) - self.assertEqual(expected, result) - - def _sort_result_based_on_key(self, outputs): - """Sort the list based on first value.""" - for output in outputs: - output.sort(key=lambda x: x[0]) +from pyspark import SparkConf, SparkContext, RDD +from pyspark.streaming import StreamingContext +from pyspark.testing.streamingutils import PySparkStreamingTestCase class BasicOperationTests(PySparkStreamingTestCase): @@ -522,135 +389,6 @@ def failed_func(i): self.fail("a failed func should throw an error") -class StreamingListenerTests(PySparkStreamingTestCase): - - duration = .5 - - class BatchInfoCollector(StreamingListener): - - def __init__(self): - super(StreamingListener, self).__init__() - self.batchInfosCompleted = [] - self.batchInfosStarted = [] - self.batchInfosSubmitted = [] - self.streamingStartedTime = [] - - def onStreamingStarted(self, streamingStarted): - self.streamingStartedTime.append(streamingStarted.time) - - def onBatchSubmitted(self, batchSubmitted): - self.batchInfosSubmitted.append(batchSubmitted.batchInfo()) - - def onBatchStarted(self, batchStarted): - self.batchInfosStarted.append(batchStarted.batchInfo()) - - def onBatchCompleted(self, batchCompleted): - self.batchInfosCompleted.append(batchCompleted.batchInfo()) - - def test_batch_info_reports(self): - batch_collector = self.BatchInfoCollector() - self.ssc.addStreamingListener(batch_collector) - input = [[1], [2], [3], [4]] - - def func(dstream): - return dstream.map(int) - expected = [[1], [2], [3], [4]] - self._test_func(input, func, expected) - - batchInfosSubmitted = batch_collector.batchInfosSubmitted - batchInfosStarted = batch_collector.batchInfosStarted - batchInfosCompleted = batch_collector.batchInfosCompleted - streamingStartedTime = batch_collector.streamingStartedTime - - self.wait_for(batchInfosCompleted, 4) - - self.assertEqual(len(streamingStartedTime), 1) - - self.assertGreaterEqual(len(batchInfosSubmitted), 4) - for info in batchInfosSubmitted: - self.assertGreaterEqual(info.batchTime().milliseconds(), 0) - self.assertGreaterEqual(info.submissionTime(), 0) - - for streamId in info.streamIdToInputInfo(): - streamInputInfo = info.streamIdToInputInfo()[streamId] - self.assertGreaterEqual(streamInputInfo.inputStreamId(), 0) - self.assertGreaterEqual(streamInputInfo.numRecords, 0) - for key in streamInputInfo.metadata(): - self.assertIsNotNone(streamInputInfo.metadata()[key]) - self.assertIsNotNone(streamInputInfo.metadataDescription()) - - for outputOpId in info.outputOperationInfos(): - outputInfo = info.outputOperationInfos()[outputOpId] - self.assertGreaterEqual(outputInfo.batchTime().milliseconds(), 0) - self.assertGreaterEqual(outputInfo.id(), 0) - self.assertIsNotNone(outputInfo.name()) - self.assertIsNotNone(outputInfo.description()) - self.assertGreaterEqual(outputInfo.startTime(), -1) - self.assertGreaterEqual(outputInfo.endTime(), -1) - self.assertIsNone(outputInfo.failureReason()) - - self.assertEqual(info.schedulingDelay(), -1) - self.assertEqual(info.processingDelay(), -1) - self.assertEqual(info.totalDelay(), -1) - self.assertEqual(info.numRecords(), 0) - - self.assertGreaterEqual(len(batchInfosStarted), 4) - for info in batchInfosStarted: - self.assertGreaterEqual(info.batchTime().milliseconds(), 0) - self.assertGreaterEqual(info.submissionTime(), 0) - - 
for streamId in info.streamIdToInputInfo(): - streamInputInfo = info.streamIdToInputInfo()[streamId] - self.assertGreaterEqual(streamInputInfo.inputStreamId(), 0) - self.assertGreaterEqual(streamInputInfo.numRecords, 0) - for key in streamInputInfo.metadata(): - self.assertIsNotNone(streamInputInfo.metadata()[key]) - self.assertIsNotNone(streamInputInfo.metadataDescription()) - - for outputOpId in info.outputOperationInfos(): - outputInfo = info.outputOperationInfos()[outputOpId] - self.assertGreaterEqual(outputInfo.batchTime().milliseconds(), 0) - self.assertGreaterEqual(outputInfo.id(), 0) - self.assertIsNotNone(outputInfo.name()) - self.assertIsNotNone(outputInfo.description()) - self.assertGreaterEqual(outputInfo.startTime(), -1) - self.assertGreaterEqual(outputInfo.endTime(), -1) - self.assertIsNone(outputInfo.failureReason()) - - self.assertGreaterEqual(info.schedulingDelay(), 0) - self.assertEqual(info.processingDelay(), -1) - self.assertEqual(info.totalDelay(), -1) - self.assertEqual(info.numRecords(), 0) - - self.assertGreaterEqual(len(batchInfosCompleted), 4) - for info in batchInfosCompleted: - self.assertGreaterEqual(info.batchTime().milliseconds(), 0) - self.assertGreaterEqual(info.submissionTime(), 0) - - for streamId in info.streamIdToInputInfo(): - streamInputInfo = info.streamIdToInputInfo()[streamId] - self.assertGreaterEqual(streamInputInfo.inputStreamId(), 0) - self.assertGreaterEqual(streamInputInfo.numRecords, 0) - for key in streamInputInfo.metadata(): - self.assertIsNotNone(streamInputInfo.metadata()[key]) - self.assertIsNotNone(streamInputInfo.metadataDescription()) - - for outputOpId in info.outputOperationInfos(): - outputInfo = info.outputOperationInfos()[outputOpId] - self.assertGreaterEqual(outputInfo.batchTime().milliseconds(), 0) - self.assertGreaterEqual(outputInfo.id(), 0) - self.assertIsNotNone(outputInfo.name()) - self.assertIsNotNone(outputInfo.description()) - self.assertGreaterEqual(outputInfo.startTime(), 0) - self.assertGreaterEqual(outputInfo.endTime(), 0) - self.assertIsNone(outputInfo.failureReason()) - - self.assertGreaterEqual(info.schedulingDelay(), 0) - self.assertGreaterEqual(info.processingDelay(), 0) - self.assertGreaterEqual(info.totalDelay(), 0) - self.assertEqual(info.numRecords(), 0) - - class WindowFunctionTests(PySparkStreamingTestCase): timeout = 15 @@ -728,156 +466,6 @@ def func(dstream): self._test_func(input, func, expected) -class StreamingContextTests(PySparkStreamingTestCase): - - duration = 0.1 - setupCalled = False - - def _add_input_stream(self): - inputs = [range(1, x) for x in range(101)] - stream = self.ssc.queueStream(inputs) - self._collect(stream, 1, block=False) - - def test_stop_only_streaming_context(self): - self._add_input_stream() - self.ssc.start() - self.ssc.stop(False) - self.assertEqual(len(self.sc.parallelize(range(5), 5).glom().collect()), 5) - - def test_stop_multiple_times(self): - self._add_input_stream() - self.ssc.start() - self.ssc.stop(False) - self.ssc.stop(False) - - def test_queue_stream(self): - input = [list(range(i + 1)) for i in range(3)] - dstream = self.ssc.queueStream(input) - result = self._collect(dstream, 3) - self.assertEqual(input, result) - - def test_text_file_stream(self): - d = tempfile.mkdtemp() - self.ssc = StreamingContext(self.sc, self.duration) - dstream2 = self.ssc.textFileStream(d).map(int) - result = self._collect(dstream2, 2, block=False) - self.ssc.start() - for name in ('a', 'b'): - time.sleep(1) - with open(os.path.join(d, name), "w") as f: - f.writelines(["%d\n" % i 
for i in range(10)]) - self.wait_for(result, 2) - self.assertEqual([list(range(10)), list(range(10))], result) - - def test_binary_records_stream(self): - d = tempfile.mkdtemp() - self.ssc = StreamingContext(self.sc, self.duration) - dstream = self.ssc.binaryRecordsStream(d, 10).map( - lambda v: struct.unpack("10b", bytes(v))) - result = self._collect(dstream, 2, block=False) - self.ssc.start() - for name in ('a', 'b'): - time.sleep(1) - with open(os.path.join(d, name), "wb") as f: - f.write(bytearray(range(10))) - self.wait_for(result, 2) - self.assertEqual([list(range(10)), list(range(10))], [list(v[0]) for v in result]) - - def test_union(self): - input = [list(range(i + 1)) for i in range(3)] - dstream = self.ssc.queueStream(input) - dstream2 = self.ssc.queueStream(input) - dstream3 = self.ssc.union(dstream, dstream2) - result = self._collect(dstream3, 3) - expected = [i * 2 for i in input] - self.assertEqual(expected, result) - - def test_transform(self): - dstream1 = self.ssc.queueStream([[1]]) - dstream2 = self.ssc.queueStream([[2]]) - dstream3 = self.ssc.queueStream([[3]]) - - def func(rdds): - rdd1, rdd2, rdd3 = rdds - return rdd2.union(rdd3).union(rdd1) - - dstream = self.ssc.transform([dstream1, dstream2, dstream3], func) - - self.assertEqual([2, 3, 1], self._take(dstream, 3)) - - def test_transform_pairrdd(self): - # This regression test case is for SPARK-17756. - dstream = self.ssc.queueStream( - [[1], [2], [3]]).transform(lambda rdd: rdd.cartesian(rdd)) - self.assertEqual([(1, 1), (2, 2), (3, 3)], self._take(dstream, 3)) - - def test_get_active(self): - self.assertEqual(StreamingContext.getActive(), None) - - # Verify that getActive() returns the active context - self.ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count()) - self.ssc.start() - self.assertEqual(StreamingContext.getActive(), self.ssc) - - # Verify that getActive() returns None - self.ssc.stop(False) - self.assertEqual(StreamingContext.getActive(), None) - - # Verify that if the Java context is stopped, then getActive() returns None - self.ssc = StreamingContext(self.sc, self.duration) - self.ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count()) - self.ssc.start() - self.assertEqual(StreamingContext.getActive(), self.ssc) - self.ssc._jssc.stop(False) - self.assertEqual(StreamingContext.getActive(), None) - - def test_get_active_or_create(self): - # Test StreamingContext.getActiveOrCreate() without checkpoint data - # See CheckpointTests for tests with checkpoint data - self.ssc = None - self.assertEqual(StreamingContext.getActive(), None) - - def setupFunc(): - ssc = StreamingContext(self.sc, self.duration) - ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count()) - self.setupCalled = True - return ssc - - # Verify that getActiveOrCreate() (w/o checkpoint) calls setupFunc when no context is active - self.setupCalled = False - self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc) - self.assertTrue(self.setupCalled) - - # Verify that getActiveOrCreate() returns active context and does not call the setupFunc - self.ssc.start() - self.setupCalled = False - self.assertEqual(StreamingContext.getActiveOrCreate(None, setupFunc), self.ssc) - self.assertFalse(self.setupCalled) - - # Verify that getActiveOrCreate() calls setupFunc after active context is stopped - self.ssc.stop(False) - self.setupCalled = False - self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc) - self.assertTrue(self.setupCalled) - - # Verify that if the Java context is stopped, then getActive() returns None - self.ssc = 
StreamingContext(self.sc, self.duration) - self.ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count()) - self.ssc.start() - self.assertEqual(StreamingContext.getActive(), self.ssc) - self.ssc._jssc.stop(False) - self.setupCalled = False - self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc) - self.assertTrue(self.setupCalled) - - def test_await_termination_or_timeout(self): - self._add_input_stream() - self.ssc.start() - self.assertFalse(self.ssc.awaitTerminationOrTimeout(0.001)) - self.ssc.stop(False) - self.assertTrue(self.ssc.awaitTerminationOrTimeout(0.001)) - - class CheckpointTests(unittest.TestCase): setupCalled = False @@ -1042,140 +630,11 @@ def check_output(n): self.ssc.stop(True, True) -class KinesisStreamTests(PySparkStreamingTestCase): - - def test_kinesis_stream_api(self): - # Don't start the StreamingContext because we cannot test it in Jenkins - kinesisStream1 = KinesisUtils.createStream( - self.ssc, "myAppNam", "mySparkStream", - "https://kinesis.us-west-2.amazonaws.com", "us-west-2", - InitialPositionInStream.LATEST, 2, StorageLevel.MEMORY_AND_DISK_2) - kinesisStream2 = KinesisUtils.createStream( - self.ssc, "myAppNam", "mySparkStream", - "https://kinesis.us-west-2.amazonaws.com", "us-west-2", - InitialPositionInStream.LATEST, 2, StorageLevel.MEMORY_AND_DISK_2, - "awsAccessKey", "awsSecretKey") - - def test_kinesis_stream(self): - if not are_kinesis_tests_enabled: - sys.stderr.write( - "Skipped test_kinesis_stream (enable by setting environment variable %s=1" - % kinesis_test_environ_var) - return - - import random - kinesisAppName = ("KinesisStreamTests-%d" % abs(random.randint(0, 10000000))) - kinesisTestUtils = self.ssc._jvm.org.apache.spark.streaming.kinesis.KinesisTestUtils(2) - try: - kinesisTestUtils.createStream() - aWSCredentials = kinesisTestUtils.getAWSCredentials() - stream = KinesisUtils.createStream( - self.ssc, kinesisAppName, kinesisTestUtils.streamName(), - kinesisTestUtils.endpointUrl(), kinesisTestUtils.regionName(), - InitialPositionInStream.LATEST, 10, StorageLevel.MEMORY_ONLY, - aWSCredentials.getAWSAccessKeyId(), aWSCredentials.getAWSSecretKey()) - - outputBuffer = [] - - def get_output(_, rdd): - for e in rdd.collect(): - outputBuffer.append(e) - - stream.foreachRDD(get_output) - self.ssc.start() - - testData = [i for i in range(1, 11)] - expectedOutput = set([str(i) for i in testData]) - start_time = time.time() - while time.time() - start_time < 120: - kinesisTestUtils.pushData(testData) - if expectedOutput == set(outputBuffer): - break - time.sleep(10) - self.assertEqual(expectedOutput, set(outputBuffer)) - except: - import traceback - traceback.print_exc() - raise - finally: - self.ssc.stop(False) - kinesisTestUtils.deleteStream() - kinesisTestUtils.deleteDynamoDBTable(kinesisAppName) - - -# Search jar in the project dir using the jar name_prefix for both sbt build and maven build because -# the artifact jars are in different directories. 
-def search_jar(dir, name_prefix): - # We should ignore the following jars - ignored_jar_suffixes = ("javadoc.jar", "sources.jar", "test-sources.jar", "tests.jar") - jars = (glob.glob(os.path.join(dir, "target/scala-*/" + name_prefix + "-*.jar")) + # sbt build - glob.glob(os.path.join(dir, "target/" + name_prefix + "_*.jar"))) # maven build - return [jar for jar in jars if not jar.endswith(ignored_jar_suffixes)] - - -def _kinesis_asl_assembly_dir(): - SPARK_HOME = os.environ["SPARK_HOME"] - return os.path.join(SPARK_HOME, "external/kinesis-asl-assembly") - - -def search_kinesis_asl_assembly_jar(): - jars = search_jar(_kinesis_asl_assembly_dir(), "spark-streaming-kinesis-asl-assembly") - if not jars: - return None - elif len(jars) > 1: - raise Exception(("Found multiple Spark Streaming Kinesis ASL assembly JARs: %s; please " - "remove all but one") % (", ".join(jars))) - else: - return jars[0] - - -# Must be same as the variable and condition defined in KinesisTestUtils.scala and modules.py -kinesis_test_environ_var = "ENABLE_KINESIS_TESTS" -are_kinesis_tests_enabled = os.environ.get(kinesis_test_environ_var) == '1' - if __name__ == "__main__": - from pyspark.streaming.tests import * - kinesis_asl_assembly_jar = search_kinesis_asl_assembly_jar() - - if kinesis_asl_assembly_jar is None: - kinesis_jar_present = False - jars_args = "" - else: - kinesis_jar_present = True - jars_args = "--jars %s" % kinesis_asl_assembly_jar - - existing_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell") - os.environ["PYSPARK_SUBMIT_ARGS"] = " ".join([jars_args, existing_args]) - testcases = [BasicOperationTests, WindowFunctionTests, StreamingContextTests, CheckpointTests, - StreamingListenerTests] - - if kinesis_jar_present is True: - testcases.append(KinesisStreamTests) - elif are_kinesis_tests_enabled is False: - sys.stderr.write("Skipping all Kinesis Python tests as the optional Kinesis project was " - "not compiled into a JAR. To run these tests, " - "you need to build Spark with 'build/sbt -Pkinesis-asl assembly/package " - "streaming-kinesis-asl-assembly/assembly' or " - "'build/mvn -Pkinesis-asl package' before running this test.") - else: - raise Exception( - ("Failed to find Spark Streaming Kinesis assembly jar in %s. 
" - % _kinesis_asl_assembly_dir()) + - "You need to build Spark with 'build/sbt -Pkinesis-asl " - "assembly/package streaming-kinesis-asl-assembly/assembly'" - "or 'build/mvn -Pkinesis-asl package' before running this test.") - - sys.stderr.write("Running tests: %s \n" % (str(testcases))) - failed = False - for testcase in testcases: - sys.stderr.write("[Running %s]\n" % (testcase)) - tests = unittest.TestLoader().loadTestsFromTestCase(testcase) - runner = unishark.BufferedTestRunner( - verbosity=2, - reporters=[unishark.XUnitReporter('target/test-reports/pyspark.streaming_{}'.format( - os.path.basename(os.environ.get("PYSPARK_PYTHON", ""))))]) - - result = runner.run(tests) - if not result.wasSuccessful(): - failed = True - sys.exit(failed) + from pyspark.streaming.tests.test_dstream import * + + try: + import xmlrunner + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + except ImportError: + unittest.main(verbosity=2) diff --git a/python/pyspark/streaming/tests/test_kinesis.py b/python/pyspark/streaming/tests/test_kinesis.py new file mode 100644 index 0000000000000..d8a0b47f04097 --- /dev/null +++ b/python/pyspark/streaming/tests/test_kinesis.py @@ -0,0 +1,89 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import time +import unittest + +from pyspark import StorageLevel +from pyspark.streaming.kinesis import KinesisUtils, InitialPositionInStream +from pyspark.testing.streamingutils import should_test_kinesis, kinesis_requirement_message, \ + PySparkStreamingTestCase + + +@unittest.skipIf(not should_test_kinesis, kinesis_requirement_message) +class KinesisStreamTests(PySparkStreamingTestCase): + + def test_kinesis_stream_api(self): + # Don't start the StreamingContext because we cannot test it in Jenkins + KinesisUtils.createStream( + self.ssc, "myAppNam", "mySparkStream", + "https://kinesis.us-west-2.amazonaws.com", "us-west-2", + InitialPositionInStream.LATEST, 2, StorageLevel.MEMORY_AND_DISK_2) + KinesisUtils.createStream( + self.ssc, "myAppNam", "mySparkStream", + "https://kinesis.us-west-2.amazonaws.com", "us-west-2", + InitialPositionInStream.LATEST, 2, StorageLevel.MEMORY_AND_DISK_2, + "awsAccessKey", "awsSecretKey") + + def test_kinesis_stream(self): + import random + kinesisAppName = ("KinesisStreamTests-%d" % abs(random.randint(0, 10000000))) + kinesisTestUtils = self.ssc._jvm.org.apache.spark.streaming.kinesis.KinesisTestUtils(2) + try: + kinesisTestUtils.createStream() + aWSCredentials = kinesisTestUtils.getAWSCredentials() + stream = KinesisUtils.createStream( + self.ssc, kinesisAppName, kinesisTestUtils.streamName(), + kinesisTestUtils.endpointUrl(), kinesisTestUtils.regionName(), + InitialPositionInStream.LATEST, 10, StorageLevel.MEMORY_ONLY, + aWSCredentials.getAWSAccessKeyId(), aWSCredentials.getAWSSecretKey()) + + outputBuffer = [] + + def get_output(_, rdd): + for e in rdd.collect(): + outputBuffer.append(e) + + stream.foreachRDD(get_output) + self.ssc.start() + + testData = [i for i in range(1, 11)] + expectedOutput = set([str(i) for i in testData]) + start_time = time.time() + while time.time() - start_time < 120: + kinesisTestUtils.pushData(testData) + if expectedOutput == set(outputBuffer): + break + time.sleep(10) + self.assertEqual(expectedOutput, set(outputBuffer)) + except: + import traceback + traceback.print_exc() + raise + finally: + self.ssc.stop(False) + kinesisTestUtils.deleteStream() + kinesisTestUtils.deleteDynamoDBTable(kinesisAppName) + + +if __name__ == "__main__": + from pyspark.streaming.tests.test_kinesis import * + + try: + import xmlrunner + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + except ImportError: + unittest.main(verbosity=2) diff --git a/python/pyspark/streaming/tests/test_listener.py b/python/pyspark/streaming/tests/test_listener.py new file mode 100644 index 0000000000000..7c874b6b32500 --- /dev/null +++ b/python/pyspark/streaming/tests/test_listener.py @@ -0,0 +1,158 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
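test_kinesis.py above skips the whole class unless both the ENABLE_KINESIS_TESTS flag and the Kinesis ASL assembly jar are present. Any further Kinesis-dependent suite can reuse the same gate exported by pyspark.testing.streamingutils; the module below is a hypothetical sketch, and only the imported names come from this change.

```python
import unittest

from pyspark.testing.streamingutils import (
    PySparkStreamingTestCase,
    kinesis_requirement_message,
    should_test_kinesis,
)


@unittest.skipIf(not should_test_kinesis, kinesis_requirement_message)
class ExtraKinesisTests(PySparkStreamingTestCase):
    """Hypothetical follow-up suite; it runs only when the Kinesis prerequisites are met."""

    def test_streaming_context_available(self):
        # Real assertions against KinesisUtils.createStream would go here.
        self.assertIsNotNone(self.ssc)


if __name__ == "__main__":
    unittest.main(verbosity=2)
```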
+# +from pyspark.streaming import StreamingListener +from pyspark.testing.streamingutils import PySparkStreamingTestCase + + +class StreamingListenerTests(PySparkStreamingTestCase): + + duration = .5 + + class BatchInfoCollector(StreamingListener): + + def __init__(self): + super(StreamingListener, self).__init__() + self.batchInfosCompleted = [] + self.batchInfosStarted = [] + self.batchInfosSubmitted = [] + self.streamingStartedTime = [] + + def onStreamingStarted(self, streamingStarted): + self.streamingStartedTime.append(streamingStarted.time) + + def onBatchSubmitted(self, batchSubmitted): + self.batchInfosSubmitted.append(batchSubmitted.batchInfo()) + + def onBatchStarted(self, batchStarted): + self.batchInfosStarted.append(batchStarted.batchInfo()) + + def onBatchCompleted(self, batchCompleted): + self.batchInfosCompleted.append(batchCompleted.batchInfo()) + + def test_batch_info_reports(self): + batch_collector = self.BatchInfoCollector() + self.ssc.addStreamingListener(batch_collector) + input = [[1], [2], [3], [4]] + + def func(dstream): + return dstream.map(int) + expected = [[1], [2], [3], [4]] + self._test_func(input, func, expected) + + batchInfosSubmitted = batch_collector.batchInfosSubmitted + batchInfosStarted = batch_collector.batchInfosStarted + batchInfosCompleted = batch_collector.batchInfosCompleted + streamingStartedTime = batch_collector.streamingStartedTime + + self.wait_for(batchInfosCompleted, 4) + + self.assertEqual(len(streamingStartedTime), 1) + + self.assertGreaterEqual(len(batchInfosSubmitted), 4) + for info in batchInfosSubmitted: + self.assertGreaterEqual(info.batchTime().milliseconds(), 0) + self.assertGreaterEqual(info.submissionTime(), 0) + + for streamId in info.streamIdToInputInfo(): + streamInputInfo = info.streamIdToInputInfo()[streamId] + self.assertGreaterEqual(streamInputInfo.inputStreamId(), 0) + self.assertGreaterEqual(streamInputInfo.numRecords, 0) + for key in streamInputInfo.metadata(): + self.assertIsNotNone(streamInputInfo.metadata()[key]) + self.assertIsNotNone(streamInputInfo.metadataDescription()) + + for outputOpId in info.outputOperationInfos(): + outputInfo = info.outputOperationInfos()[outputOpId] + self.assertGreaterEqual(outputInfo.batchTime().milliseconds(), 0) + self.assertGreaterEqual(outputInfo.id(), 0) + self.assertIsNotNone(outputInfo.name()) + self.assertIsNotNone(outputInfo.description()) + self.assertGreaterEqual(outputInfo.startTime(), -1) + self.assertGreaterEqual(outputInfo.endTime(), -1) + self.assertIsNone(outputInfo.failureReason()) + + self.assertEqual(info.schedulingDelay(), -1) + self.assertEqual(info.processingDelay(), -1) + self.assertEqual(info.totalDelay(), -1) + self.assertEqual(info.numRecords(), 0) + + self.assertGreaterEqual(len(batchInfosStarted), 4) + for info in batchInfosStarted: + self.assertGreaterEqual(info.batchTime().milliseconds(), 0) + self.assertGreaterEqual(info.submissionTime(), 0) + + for streamId in info.streamIdToInputInfo(): + streamInputInfo = info.streamIdToInputInfo()[streamId] + self.assertGreaterEqual(streamInputInfo.inputStreamId(), 0) + self.assertGreaterEqual(streamInputInfo.numRecords, 0) + for key in streamInputInfo.metadata(): + self.assertIsNotNone(streamInputInfo.metadata()[key]) + self.assertIsNotNone(streamInputInfo.metadataDescription()) + + for outputOpId in info.outputOperationInfos(): + outputInfo = info.outputOperationInfos()[outputOpId] + self.assertGreaterEqual(outputInfo.batchTime().milliseconds(), 0) + self.assertGreaterEqual(outputInfo.id(), 0) + 
self.assertIsNotNone(outputInfo.name()) + self.assertIsNotNone(outputInfo.description()) + self.assertGreaterEqual(outputInfo.startTime(), -1) + self.assertGreaterEqual(outputInfo.endTime(), -1) + self.assertIsNone(outputInfo.failureReason()) + + self.assertGreaterEqual(info.schedulingDelay(), 0) + self.assertEqual(info.processingDelay(), -1) + self.assertEqual(info.totalDelay(), -1) + self.assertEqual(info.numRecords(), 0) + + self.assertGreaterEqual(len(batchInfosCompleted), 4) + for info in batchInfosCompleted: + self.assertGreaterEqual(info.batchTime().milliseconds(), 0) + self.assertGreaterEqual(info.submissionTime(), 0) + + for streamId in info.streamIdToInputInfo(): + streamInputInfo = info.streamIdToInputInfo()[streamId] + self.assertGreaterEqual(streamInputInfo.inputStreamId(), 0) + self.assertGreaterEqual(streamInputInfo.numRecords, 0) + for key in streamInputInfo.metadata(): + self.assertIsNotNone(streamInputInfo.metadata()[key]) + self.assertIsNotNone(streamInputInfo.metadataDescription()) + + for outputOpId in info.outputOperationInfos(): + outputInfo = info.outputOperationInfos()[outputOpId] + self.assertGreaterEqual(outputInfo.batchTime().milliseconds(), 0) + self.assertGreaterEqual(outputInfo.id(), 0) + self.assertIsNotNone(outputInfo.name()) + self.assertIsNotNone(outputInfo.description()) + self.assertGreaterEqual(outputInfo.startTime(), 0) + self.assertGreaterEqual(outputInfo.endTime(), 0) + self.assertIsNone(outputInfo.failureReason()) + + self.assertGreaterEqual(info.schedulingDelay(), 0) + self.assertGreaterEqual(info.processingDelay(), 0) + self.assertGreaterEqual(info.totalDelay(), 0) + self.assertEqual(info.numRecords(), 0) + + +if __name__ == "__main__": + import unittest + from pyspark.streaming.tests.test_listener import * + + try: + import xmlrunner + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) + except ImportError: + unittest.main(verbosity=2) diff --git a/python/pyspark/taskcontext.py b/python/pyspark/taskcontext.py index b61643eb0a16e..98b505c9046be 100644 --- a/python/pyspark/taskcontext.py +++ b/python/pyspark/taskcontext.py @@ -147,8 +147,8 @@ def __init__(self): @classmethod def _getOrCreate(cls): """Internal function to get or create global BarrierTaskContext.""" - if cls._taskContext is None: - cls._taskContext = BarrierTaskContext() + if not isinstance(cls._taskContext, BarrierTaskContext): + cls._taskContext = object.__new__(cls) return cls._taskContext @classmethod diff --git a/python/pyspark/test_serializers.py b/python/pyspark/test_serializers.py deleted file mode 100644 index 5b43729f9ebb1..0000000000000 --- a/python/pyspark/test_serializers.py +++ /dev/null @@ -1,90 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
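The listener suite above drives the same StreamingListener callbacks an application would register. As a hedged sketch (the listener class and app name below are illustrative, not part of this change), a driver collects batch metrics by subclassing StreamingListener and adding it to the context before start():

```python
from pyspark import SparkContext
from pyspark.streaming import StreamingContext, StreamingListener


class BatchTimingListener(StreamingListener):
    """Records the processing delay reported for each completed batch."""

    def __init__(self):
        super(BatchTimingListener, self).__init__()
        self.processing_delays = []

    def onBatchCompleted(self, batchCompleted):
        info = batchCompleted.batchInfo()
        self.processing_delays.append(info.processingDelay())


sc = SparkContext("local[2]", "listener-demo")
ssc = StreamingContext(sc, 1)
listener = BatchTimingListener()
ssc.addStreamingListener(listener)

ssc.queueStream([[1], [2], [3]]).foreachRDD(lambda rdd: rdd.count())
ssc.start()
ssc.awaitTerminationOrTimeout(5)
ssc.stop(stopSparkContext=True, stopGraceFully=False)
print(listener.processing_delays)
```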
-# - -import io -import math -import struct -import sys -import unittest - -try: - import xmlrunner -except ImportError: - xmlrunner = None - -from pyspark import serializers - - -def read_int(b): - return struct.unpack("!i", b)[0] - - -def write_int(i): - return struct.pack("!i", i) - - -class SerializersTest(unittest.TestCase): - - def test_chunked_stream(self): - original_bytes = bytearray(range(100)) - for data_length in [1, 10, 100]: - for buffer_length in [1, 2, 3, 5, 20, 99, 100, 101, 500]: - dest = ByteArrayOutput() - stream_out = serializers.ChunkedStream(dest, buffer_length) - stream_out.write(original_bytes[:data_length]) - stream_out.close() - num_chunks = int(math.ceil(float(data_length) / buffer_length)) - # length for each chunk, and a final -1 at the very end - exp_size = (num_chunks + 1) * 4 + data_length - self.assertEqual(len(dest.buffer), exp_size) - dest_pos = 0 - data_pos = 0 - for chunk_idx in range(num_chunks): - chunk_length = read_int(dest.buffer[dest_pos:(dest_pos + 4)]) - if chunk_idx == num_chunks - 1: - exp_length = data_length % buffer_length - if exp_length == 0: - exp_length = buffer_length - else: - exp_length = buffer_length - self.assertEqual(chunk_length, exp_length) - dest_pos += 4 - dest_chunk = dest.buffer[dest_pos:dest_pos + chunk_length] - orig_chunk = original_bytes[data_pos:data_pos + chunk_length] - self.assertEqual(dest_chunk, orig_chunk) - dest_pos += chunk_length - data_pos += chunk_length - # ends with a -1 - self.assertEqual(dest.buffer[-4:], write_int(-1)) - - -class ByteArrayOutput(object): - def __init__(self): - self.buffer = bytearray() - - def write(self, b): - self.buffer += b - - def close(self): - pass - -if __name__ == '__main__': - from pyspark.test_serializers import * - if xmlrunner: - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) - else: - unittest.main(verbosity=2) diff --git a/python/pyspark/testing/__init__.py b/python/pyspark/testing/__init__.py new file mode 100644 index 0000000000000..12bdf0d0175b6 --- /dev/null +++ b/python/pyspark/testing/__init__.py @@ -0,0 +1,16 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/python/pyspark/testing/mllibutils.py b/python/pyspark/testing/mllibutils.py new file mode 100644 index 0000000000000..25f1bba8d37ac --- /dev/null +++ b/python/pyspark/testing/mllibutils.py @@ -0,0 +1,35 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest + +from pyspark import SparkContext +from pyspark.serializers import PickleSerializer +from pyspark.sql import SparkSession + + +def make_serializer(): + return PickleSerializer() + + +class MLlibTestCase(unittest.TestCase): + def setUp(self): + self.sc = SparkContext('local[4]', "MLlib tests") + self.spark = SparkSession(self.sc) + + def tearDown(self): + self.spark.stop() diff --git a/python/pyspark/testing/mlutils.py b/python/pyspark/testing/mlutils.py new file mode 100644 index 0000000000000..12bf650a28ee1 --- /dev/null +++ b/python/pyspark/testing/mlutils.py @@ -0,0 +1,161 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
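pyspark/testing/mlutils.py, whose contents follow, centers on the check_params helper that compares a Python ML stage against its underlying Java stage (UIDs, param names, defaults). A hedged usage sketch is below; the test class is hypothetical, Binarizer is just a convenient built-in stage, and whether a given stage passes depends on the Spark build under test.

```python
from pyspark.ml.feature import Binarizer
from pyspark.testing.mlutils import SparkSessionTestCase, check_params


class ParamParityTests(SparkSessionTestCase):
    """Hypothetical example suite; SparkSessionTestCase provides self.sc and self.spark."""

    def test_binarizer_params_match_java(self):
        # check_params asserts that the Python wrapper and the Java stage agree
        # on UID, parameter names and default values.
        check_params(self, Binarizer(), check_params_exist=True)


if __name__ == "__main__":
    import unittest
    unittest.main(verbosity=2)
```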
+# + +import numpy as np + +from pyspark.ml import Estimator, Model, Transformer, UnaryTransformer +from pyspark.ml.param import Param, Params, TypeConverters +from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable +from pyspark.ml.wrapper import _java2py +from pyspark.sql import DataFrame, SparkSession +from pyspark.sql.types import DoubleType +from pyspark.testing.utils import ReusedPySparkTestCase as PySparkTestCase + + +def check_params(test_self, py_stage, check_params_exist=True): + """ + Checks common requirements for Params.params: + - set of params exist in Java and Python and are ordered by names + - param parent has the same UID as the object's UID + - default param value from Java matches value in Python + - optionally check if all params from Java also exist in Python + """ + py_stage_str = "%s %s" % (type(py_stage), py_stage) + if not hasattr(py_stage, "_to_java"): + return + java_stage = py_stage._to_java() + if java_stage is None: + return + test_self.assertEqual(py_stage.uid, java_stage.uid(), msg=py_stage_str) + if check_params_exist: + param_names = [p.name for p in py_stage.params] + java_params = list(java_stage.params()) + java_param_names = [jp.name() for jp in java_params] + test_self.assertEqual( + param_names, sorted(java_param_names), + "Param list in Python does not match Java for %s:\nJava = %s\nPython = %s" + % (py_stage_str, java_param_names, param_names)) + for p in py_stage.params: + test_self.assertEqual(p.parent, py_stage.uid) + java_param = java_stage.getParam(p.name) + py_has_default = py_stage.hasDefault(p) + java_has_default = java_stage.hasDefault(java_param) + test_self.assertEqual(py_has_default, java_has_default, + "Default value mismatch of param %s for Params %s" + % (p.name, str(py_stage))) + if py_has_default: + if p.name == "seed": + continue # Random seeds between Spark and PySpark are different + java_default = _java2py(test_self.sc, + java_stage.clear(java_param).getOrDefault(java_param)) + py_stage._clear(p) + py_default = py_stage.getOrDefault(p) + # equality test for NaN is always False + if isinstance(java_default, float) and np.isnan(java_default): + java_default = "NaN" + py_default = "NaN" if np.isnan(py_default) else "not NaN" + test_self.assertEqual( + java_default, py_default, + "Java default %s != python default %s of param %s for Params %s" + % (str(java_default), str(py_default), p.name, str(py_stage))) + + +class SparkSessionTestCase(PySparkTestCase): + @classmethod + def setUpClass(cls): + PySparkTestCase.setUpClass() + cls.spark = SparkSession(cls.sc) + + @classmethod + def tearDownClass(cls): + PySparkTestCase.tearDownClass() + cls.spark.stop() + + +class MockDataset(DataFrame): + + def __init__(self): + self.index = 0 + + +class HasFake(Params): + + def __init__(self): + super(HasFake, self).__init__() + self.fake = Param(self, "fake", "fake param") + + def getFake(self): + return self.getOrDefault(self.fake) + + +class MockTransformer(Transformer, HasFake): + + def __init__(self): + super(MockTransformer, self).__init__() + self.dataset_index = None + + def _transform(self, dataset): + self.dataset_index = dataset.index + dataset.index += 1 + return dataset + + +class MockUnaryTransformer(UnaryTransformer, DefaultParamsReadable, DefaultParamsWritable): + + shift = Param(Params._dummy(), "shift", "The amount by which to shift " + + "data in a DataFrame", + typeConverter=TypeConverters.toFloat) + + def __init__(self, shiftVal=1): + super(MockUnaryTransformer, self).__init__() + self._setDefault(shift=1) + 
self._set(shift=shiftVal) + + def getShift(self): + return self.getOrDefault(self.shift) + + def setShift(self, shift): + self._set(shift=shift) + + def createTransformFunc(self): + shiftVal = self.getShift() + return lambda x: x + shiftVal + + def outputDataType(self): + return DoubleType() + + def validateInputType(self, inputType): + if inputType != DoubleType(): + raise TypeError("Bad input type: {}. ".format(inputType) + + "Requires Double.") + + +class MockEstimator(Estimator, HasFake): + + def __init__(self): + super(MockEstimator, self).__init__() + self.dataset_index = None + + def _fit(self, dataset): + self.dataset_index = dataset.index + model = MockModel() + self._copyValues(model) + return model + + +class MockModel(MockTransformer, Model, HasFake): + pass diff --git a/python/pyspark/testing/sqlutils.py b/python/pyspark/testing/sqlutils.py new file mode 100644 index 0000000000000..afc40ccf4139d --- /dev/null +++ b/python/pyspark/testing/sqlutils.py @@ -0,0 +1,268 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import datetime +import os +import shutil +import tempfile +from contextlib import contextmanager + +from pyspark.sql import SparkSession +from pyspark.sql.types import ArrayType, DoubleType, UserDefinedType, Row +from pyspark.testing.utils import ReusedPySparkTestCase +from pyspark.util import _exception_message + + +pandas_requirement_message = None +try: + from pyspark.sql.utils import require_minimum_pandas_version + require_minimum_pandas_version() +except ImportError as e: + # If Pandas version requirement is not satisfied, skip related tests. + pandas_requirement_message = _exception_message(e) + +pyarrow_requirement_message = None +try: + from pyspark.sql.utils import require_minimum_pyarrow_version + require_minimum_pyarrow_version() +except ImportError as e: + # If Arrow version requirement is not satisfied, skip related tests. + pyarrow_requirement_message = _exception_message(e) + +test_not_compiled_message = None +try: + from pyspark.sql.utils import require_test_compiled + require_test_compiled() +except Exception as e: + test_not_compiled_message = _exception_message(e) + +have_pandas = pandas_requirement_message is None +have_pyarrow = pyarrow_requirement_message is None +test_compiled = test_not_compiled_message is None + + +class UTCOffsetTimezone(datetime.tzinfo): + """ + Specifies timezone in UTC offset + """ + + def __init__(self, offset=0): + self.ZERO = datetime.timedelta(hours=offset) + + def utcoffset(self, dt): + return self.ZERO + + def dst(self, dt): + return self.ZERO + + +class ExamplePointUDT(UserDefinedType): + """ + User-defined type (UDT) for ExamplePoint. 
+ """ + + @classmethod + def sqlType(self): + return ArrayType(DoubleType(), False) + + @classmethod + def module(cls): + return 'pyspark.sql.tests' + + @classmethod + def scalaUDT(cls): + return 'org.apache.spark.sql.test.ExamplePointUDT' + + def serialize(self, obj): + return [obj.x, obj.y] + + def deserialize(self, datum): + return ExamplePoint(datum[0], datum[1]) + + +class ExamplePoint: + """ + An example class to demonstrate UDT in Scala, Java, and Python. + """ + + __UDT__ = ExamplePointUDT() + + def __init__(self, x, y): + self.x = x + self.y = y + + def __repr__(self): + return "ExamplePoint(%s,%s)" % (self.x, self.y) + + def __str__(self): + return "(%s,%s)" % (self.x, self.y) + + def __eq__(self, other): + return isinstance(other, self.__class__) and \ + other.x == self.x and other.y == self.y + + +class PythonOnlyUDT(UserDefinedType): + """ + User-defined type (UDT) for ExamplePoint. + """ + + @classmethod + def sqlType(self): + return ArrayType(DoubleType(), False) + + @classmethod + def module(cls): + return '__main__' + + def serialize(self, obj): + return [obj.x, obj.y] + + def deserialize(self, datum): + return PythonOnlyPoint(datum[0], datum[1]) + + @staticmethod + def foo(): + pass + + @property + def props(self): + return {} + + +class PythonOnlyPoint(ExamplePoint): + """ + An example class to demonstrate UDT in only Python + """ + __UDT__ = PythonOnlyUDT() + + +class MyObject(object): + def __init__(self, key, value): + self.key = key + self.value = value + + +class SQLTestUtils(object): + """ + This util assumes the instance of this to have 'spark' attribute, having a spark session. + It is usually used with 'ReusedSQLTestCase' class but can be used if you feel sure the + the implementation of this class has 'spark' attribute. + """ + + @contextmanager + def sql_conf(self, pairs): + """ + A convenient context manager to test some configuration specific logic. This sets + `value` to the configuration `key` and then restores it back when it exits. + """ + assert isinstance(pairs, dict), "pairs should be a dictionary." + assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session." + + keys = pairs.keys() + new_values = pairs.values() + old_values = [self.spark.conf.get(key, None) for key in keys] + for key, new_value in zip(keys, new_values): + self.spark.conf.set(key, new_value) + try: + yield + finally: + for key, old_value in zip(keys, old_values): + if old_value is None: + self.spark.conf.unset(key) + else: + self.spark.conf.set(key, old_value) + + @contextmanager + def database(self, *databases): + """ + A convenient context manager to test with some specific databases. This drops the given + databases if it exists and sets current database to "default" when it exits. + """ + assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session." + + try: + yield + finally: + for db in databases: + self.spark.sql("DROP DATABASE IF EXISTS %s CASCADE" % db) + self.spark.catalog.setCurrentDatabase("default") + + @contextmanager + def table(self, *tables): + """ + A convenient context manager to test with some specific tables. This drops the given tables + if it exists. + """ + assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session." + + try: + yield + finally: + for t in tables: + self.spark.sql("DROP TABLE IF EXISTS %s" % t) + + @contextmanager + def tempView(self, *views): + """ + A convenient context manager to test with some specific views. This drops the given views + if it exists. 
+ """ + assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session." + + try: + yield + finally: + for v in views: + self.spark.catalog.dropTempView(v) + + @contextmanager + def function(self, *functions): + """ + A convenient context manager to test with some specific functions. This drops the given + functions if it exists. + """ + assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session." + + try: + yield + finally: + for f in functions: + self.spark.sql("DROP FUNCTION IF EXISTS %s" % f) + + +class ReusedSQLTestCase(ReusedPySparkTestCase, SQLTestUtils): + @classmethod + def setUpClass(cls): + super(ReusedSQLTestCase, cls).setUpClass() + cls.spark = SparkSession(cls.sc) + cls.tempdir = tempfile.NamedTemporaryFile(delete=False) + os.unlink(cls.tempdir.name) + cls.testData = [Row(key=i, value=str(i)) for i in range(100)] + cls.df = cls.spark.createDataFrame(cls.testData) + + @classmethod + def tearDownClass(cls): + super(ReusedSQLTestCase, cls).tearDownClass() + cls.spark.stop() + shutil.rmtree(cls.tempdir.name, ignore_errors=True) + + def assertPandasEqual(self, expected, result): + msg = ("DataFrames are not equal: " + + "\n\nExpected:\n%s\n%s" % (expected, expected.dtypes) + + "\n\nResult:\n%s\n%s" % (result, result.dtypes)) + self.assertTrue(expected.equals(result), msg=msg) diff --git a/python/pyspark/testing/streamingutils.py b/python/pyspark/testing/streamingutils.py new file mode 100644 index 0000000000000..85a2fa14b936c --- /dev/null +++ b/python/pyspark/testing/streamingutils.py @@ -0,0 +1,190 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import glob +import os +import tempfile +import time +import unittest + +from pyspark import SparkConf, SparkContext, RDD +from pyspark.streaming import StreamingContext + + +def search_kinesis_asl_assembly_jar(): + kinesis_asl_assembly_dir = os.path.join( + os.environ["SPARK_HOME"], "external/kinesis-asl-assembly") + + # We should ignore the following jars + ignored_jar_suffixes = ("javadoc.jar", "sources.jar", "test-sources.jar", "tests.jar") + + # Search jar in the project dir using the jar name_prefix for both sbt build and maven + # build because the artifact jars are in different directories. 
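The sqlutils.py helpers above (SQLTestUtils and ReusedSQLTestCase) are the base the split SQL suites build on. The sketch below is a hedged example of how they compose; the test class, view name and config value are illustrative only.

```python
from pyspark.sql import Row
from pyspark.testing.sqlutils import ReusedSQLTestCase


class ShufflePartitionsTests(ReusedSQLTestCase):
    """Hypothetical example suite; ReusedSQLTestCase supplies cls.spark and cls.df."""

    def test_group_by_with_few_shuffle_partitions(self):
        # sql_conf restores the previous value of the key when the block exits;
        # tempView drops the named view afterwards, even if an assertion fails.
        with self.sql_conf({"spark.sql.shuffle.partitions": "4"}):
            with self.tempView("people"):
                self.spark.createDataFrame(
                    [Row(name="Alice", age=2), Row(name="Bob", age=5)]
                ).createOrReplaceTempView("people")
                counted = self.spark.sql(
                    "SELECT name, count(*) AS c FROM people GROUP BY name").collect()
                self.assertEqual(2, len(counted))


if __name__ == "__main__":
    import unittest
    unittest.main(verbosity=2)
```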
+ name_prefix = "spark-streaming-kinesis-asl-assembly" + sbt_build = glob.glob(os.path.join( + kinesis_asl_assembly_dir, "target/scala-*/%s-*.jar" % name_prefix)) + maven_build = glob.glob(os.path.join( + kinesis_asl_assembly_dir, "target/%s_*.jar" % name_prefix)) + jar_paths = sbt_build + maven_build + jars = [jar for jar in jar_paths if not jar.endswith(ignored_jar_suffixes)] + + if not jars: + return None + elif len(jars) > 1: + raise Exception(("Found multiple Spark Streaming Kinesis ASL assembly JARs: %s; please " + "remove all but one") % (", ".join(jars))) + else: + return jars[0] + + +# Must be same as the variable and condition defined in KinesisTestUtils.scala and modules.py +kinesis_test_environ_var = "ENABLE_KINESIS_TESTS" +should_skip_kinesis_tests = not os.environ.get(kinesis_test_environ_var) == '1' + +if should_skip_kinesis_tests: + kinesis_requirement_message = ( + "Skipping all Kinesis Python tests as environmental variable 'ENABLE_KINESIS_TESTS' " + "was not set.") +else: + kinesis_asl_assembly_jar = search_kinesis_asl_assembly_jar() + if kinesis_asl_assembly_jar is None: + kinesis_requirement_message = ( + "Skipping all Kinesis Python tests as the optional Kinesis project was " + "not compiled into a JAR. To run these tests, " + "you need to build Spark with 'build/sbt -Pkinesis-asl assembly/package " + "streaming-kinesis-asl-assembly/assembly' or " + "'build/mvn -Pkinesis-asl package' before running this test.") + else: + existing_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell") + jars_args = "--jars %s" % kinesis_asl_assembly_jar + os.environ["PYSPARK_SUBMIT_ARGS"] = " ".join([jars_args, existing_args]) + kinesis_requirement_message = None + +should_test_kinesis = kinesis_requirement_message is None + + +class PySparkStreamingTestCase(unittest.TestCase): + + timeout = 30 # seconds + duration = .5 + + @classmethod + def setUpClass(cls): + class_name = cls.__name__ + conf = SparkConf().set("spark.default.parallelism", 1) + cls.sc = SparkContext(appName=class_name, conf=conf) + cls.sc.setCheckpointDir(tempfile.mkdtemp()) + + @classmethod + def tearDownClass(cls): + cls.sc.stop() + # Clean up in the JVM just in case there has been some issues in Python API + try: + jSparkContextOption = SparkContext._jvm.SparkContext.get() + if jSparkContextOption.nonEmpty(): + jSparkContextOption.get().stop() + except: + pass + + def setUp(self): + self.ssc = StreamingContext(self.sc, self.duration) + + def tearDown(self): + if self.ssc is not None: + self.ssc.stop(False) + # Clean up in the JVM just in case there has been some issues in Python API + try: + jStreamingContextOption = StreamingContext._jvm.SparkContext.getActive() + if jStreamingContextOption.nonEmpty(): + jStreamingContextOption.get().stop(False) + except: + pass + + def wait_for(self, result, n): + start_time = time.time() + while len(result) < n and time.time() - start_time < self.timeout: + time.sleep(0.01) + if len(result) < n: + print("timeout after", self.timeout) + + def _take(self, dstream, n): + """ + Return the first `n` elements in the stream (will start and stop). + """ + results = [] + + def take(_, rdd): + if rdd and len(results) < n: + results.extend(rdd.take(n - len(results))) + + dstream.foreachRDD(take) + + self.ssc.start() + self.wait_for(results, n) + return results + + def _collect(self, dstream, n, block=True): + """ + Collect each RDDs into the returned list. + + :return: list, which will have the collected items. 
+ """ + result = [] + + def get_output(_, rdd): + if rdd and len(result) < n: + r = rdd.collect() + if r: + result.append(r) + + dstream.foreachRDD(get_output) + + if not block: + return result + + self.ssc.start() + self.wait_for(result, n) + return result + + def _test_func(self, input, func, expected, sort=False, input2=None): + """ + @param input: dataset for the test. This should be list of lists. + @param func: wrapped function. This function should return PythonDStream object. + @param expected: expected output for this testcase. + """ + if not isinstance(input[0], RDD): + input = [self.sc.parallelize(d, 1) for d in input] + input_stream = self.ssc.queueStream(input) + if input2 and not isinstance(input2[0], RDD): + input2 = [self.sc.parallelize(d, 1) for d in input2] + input_stream2 = self.ssc.queueStream(input2) if input2 is not None else None + + # Apply test function to stream. + if input2: + stream = func(input_stream, input_stream2) + else: + stream = func(input_stream) + + result = self._collect(stream, len(expected)) + if sort: + self._sort_result_based_on_key(result) + self._sort_result_based_on_key(expected) + self.assertEqual(expected, result) + + def _sort_result_based_on_key(self, outputs): + """Sort the list based on first value.""" + for output in outputs: + output.sort(key=lambda x: x[0]) diff --git a/python/pyspark/testing/utils.py b/python/pyspark/testing/utils.py new file mode 100644 index 0000000000000..7df0acae026f3 --- /dev/null +++ b/python/pyspark/testing/utils.py @@ -0,0 +1,102 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
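streamingutils.py above centralizes the _take/_collect/_test_func helpers that the streaming suites previously carried inline. A new DStream test built on them looks like the hedged sketch below; the class name and input batches are illustrative.

```python
from pyspark.testing.streamingutils import PySparkStreamingTestCase


class ExtraDStreamTests(PySparkStreamingTestCase):
    """Hypothetical example suite reusing the shared streaming test harness."""

    def test_map_doubles_values(self):
        input = [[1, 2], [3], [4, 5, 6]]
        expected = [[2, 4], [6], [8, 10, 12]]

        def func(dstream):
            return dstream.map(lambda x: x * 2)

        # _test_func queues the input batches, applies `func` to the resulting
        # DStream and compares the collected output against `expected`.
        self._test_func(input, func, expected)


if __name__ == "__main__":
    import unittest
    unittest.main(verbosity=2)
```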
+# +import os +import struct +import sys +import unittest + +from pyspark import SparkContext, SparkConf + + +have_scipy = False +have_numpy = False +try: + import scipy.sparse + have_scipy = True +except: + # No SciPy, but that's okay, we'll skip those tests + pass +try: + import numpy as np + have_numpy = True +except: + # No NumPy, but that's okay, we'll skip those tests + pass + + +SPARK_HOME = os.environ["SPARK_HOME"] + + +def read_int(b): + return struct.unpack("!i", b)[0] + + +def write_int(i): + return struct.pack("!i", i) + + +class QuietTest(object): + def __init__(self, sc): + self.log4j = sc._jvm.org.apache.log4j + + def __enter__(self): + self.old_level = self.log4j.LogManager.getRootLogger().getLevel() + self.log4j.LogManager.getRootLogger().setLevel(self.log4j.Level.FATAL) + + def __exit__(self, exc_type, exc_val, exc_tb): + self.log4j.LogManager.getRootLogger().setLevel(self.old_level) + + +class PySparkTestCase(unittest.TestCase): + + def setUp(self): + self._old_sys_path = list(sys.path) + class_name = self.__class__.__name__ + self.sc = SparkContext('local[4]', class_name) + + def tearDown(self): + self.sc.stop() + sys.path = self._old_sys_path + + +class ReusedPySparkTestCase(unittest.TestCase): + + @classmethod + def conf(cls): + """ + Override this in subclasses to supply a more specific conf + """ + return SparkConf() + + @classmethod + def setUpClass(cls): + cls.sc = SparkContext('local[4]', cls.__name__, conf=cls.conf()) + + @classmethod + def tearDownClass(cls): + cls.sc.stop() + + +class ByteArrayOutput(object): + def __init__(self): + self.buffer = bytearray() + + def write(self, b): + self.buffer += b + + def close(self): + pass diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py deleted file mode 100644 index c15d443ebbba9..0000000000000 --- a/python/pyspark/tests.py +++ /dev/null @@ -1,2522 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -""" -Unit tests for PySpark; additional tests are implemented as doctests in -individual modules. 
-""" - -from array import array -from glob import glob -import os -import re -import shutil -import subprocess -import sys -import tempfile -import time -import zipfile -import random -import threading -import hashlib - -from py4j.protocol import Py4JJavaError -xmlrunner = None - -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest - if sys.version_info[0] >= 3: - xrange = range - basestring = str - -import unishark - -if sys.version >= "3": - from io import StringIO -else: - from StringIO import StringIO - - -from pyspark import keyword_only -from pyspark.conf import SparkConf -from pyspark.context import SparkContext -from pyspark.rdd import RDD -from pyspark.files import SparkFiles -from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \ - CloudPickleSerializer, CompressedSerializer, UTF8Deserializer, NoOpSerializer, \ - PairDeserializer, CartesianDeserializer, AutoBatchedSerializer, AutoSerializer, \ - FlattenedValuesSerializer -from pyspark.shuffle import Aggregator, ExternalMerger, ExternalSorter -from pyspark import shuffle -from pyspark.profiler import BasicProfiler -from pyspark.taskcontext import BarrierTaskContext, TaskContext - -_have_scipy = False -_have_numpy = False -try: - import scipy.sparse - _have_scipy = True -except: - # No SciPy, but that's okay, we'll skip those tests - pass -try: - import numpy as np - _have_numpy = True -except: - # No NumPy, but that's okay, we'll skip those tests - pass - - -SPARK_HOME = os.environ["SPARK_HOME"] - - -class MergerTests(unittest.TestCase): - - def setUp(self): - self.N = 1 << 12 - self.l = [i for i in xrange(self.N)] - self.data = list(zip(self.l, self.l)) - self.agg = Aggregator(lambda x: [x], - lambda x, y: x.append(y) or x, - lambda x, y: x.extend(y) or x) - - def test_small_dataset(self): - m = ExternalMerger(self.agg, 1000) - m.mergeValues(self.data) - self.assertEqual(m.spills, 0) - self.assertEqual(sum(sum(v) for k, v in m.items()), - sum(xrange(self.N))) - - m = ExternalMerger(self.agg, 1000) - m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), self.data)) - self.assertEqual(m.spills, 0) - self.assertEqual(sum(sum(v) for k, v in m.items()), - sum(xrange(self.N))) - - def test_medium_dataset(self): - m = ExternalMerger(self.agg, 20) - m.mergeValues(self.data) - self.assertTrue(m.spills >= 1) - self.assertEqual(sum(sum(v) for k, v in m.items()), - sum(xrange(self.N))) - - m = ExternalMerger(self.agg, 10) - m.mergeCombiners(map(lambda x_y2: (x_y2[0], [x_y2[1]]), self.data * 3)) - self.assertTrue(m.spills >= 1) - self.assertEqual(sum(sum(v) for k, v in m.items()), - sum(xrange(self.N)) * 3) - - def test_huge_dataset(self): - m = ExternalMerger(self.agg, 5, partitions=3) - m.mergeCombiners(map(lambda k_v: (k_v[0], [str(k_v[1])]), self.data * 10)) - self.assertTrue(m.spills >= 1) - self.assertEqual(sum(len(v) for k, v in m.items()), - self.N * 10) - m._cleanup() - - def test_group_by_key(self): - - def gen_data(N, step): - for i in range(1, N + 1, step): - for j in range(i): - yield (i, [j]) - - def gen_gs(N, step=1): - return shuffle.GroupByKey(gen_data(N, step)) - - self.assertEqual(1, len(list(gen_gs(1)))) - self.assertEqual(2, len(list(gen_gs(2)))) - self.assertEqual(100, len(list(gen_gs(100)))) - self.assertEqual(list(range(1, 101)), [k for k, _ in gen_gs(100)]) - self.assertTrue(all(list(range(k)) == list(vs) for 
k, vs in gen_gs(100))) - - for k, vs in gen_gs(50002, 10000): - self.assertEqual(k, len(vs)) - self.assertEqual(list(range(k)), list(vs)) - - ser = PickleSerializer() - l = ser.loads(ser.dumps(list(gen_gs(50002, 30000)))) - for k, vs in l: - self.assertEqual(k, len(vs)) - self.assertEqual(list(range(k)), list(vs)) - - def test_stopiteration_is_raised(self): - - def stopit(*args, **kwargs): - raise StopIteration() - - def legit_create_combiner(x): - return [x] - - def legit_merge_value(x, y): - return x.append(y) or x - - def legit_merge_combiners(x, y): - return x.extend(y) or x - - data = [(x % 2, x) for x in range(100)] - - # wrong create combiner - m = ExternalMerger(Aggregator(stopit, legit_merge_value, legit_merge_combiners), 20) - with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: - m.mergeValues(data) - - # wrong merge value - m = ExternalMerger(Aggregator(legit_create_combiner, stopit, legit_merge_combiners), 20) - with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: - m.mergeValues(data) - - # wrong merge combiners - m = ExternalMerger(Aggregator(legit_create_combiner, legit_merge_value, stopit), 20) - with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: - m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), data)) - - -class SorterTests(unittest.TestCase): - def test_in_memory_sort(self): - l = list(range(1024)) - random.shuffle(l) - sorter = ExternalSorter(1024) - self.assertEqual(sorted(l), list(sorter.sorted(l))) - self.assertEqual(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True))) - self.assertEqual(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x))) - self.assertEqual(sorted(l, key=lambda x: -x, reverse=True), - list(sorter.sorted(l, key=lambda x: -x, reverse=True))) - - def test_external_sort(self): - class CustomizedSorter(ExternalSorter): - def _next_limit(self): - return self.memory_limit - l = list(range(1024)) - random.shuffle(l) - sorter = CustomizedSorter(1) - self.assertEqual(sorted(l), list(sorter.sorted(l))) - self.assertGreater(shuffle.DiskBytesSpilled, 0) - last = shuffle.DiskBytesSpilled - self.assertEqual(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True))) - self.assertGreater(shuffle.DiskBytesSpilled, last) - last = shuffle.DiskBytesSpilled - self.assertEqual(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x))) - self.assertGreater(shuffle.DiskBytesSpilled, last) - last = shuffle.DiskBytesSpilled - self.assertEqual(sorted(l, key=lambda x: -x, reverse=True), - list(sorter.sorted(l, key=lambda x: -x, reverse=True))) - self.assertGreater(shuffle.DiskBytesSpilled, last) - - def test_external_sort_in_rdd(self): - conf = SparkConf().set("spark.python.worker.memory", "1m") - sc = SparkContext(conf=conf) - l = list(range(10240)) - random.shuffle(l) - rdd = sc.parallelize(l, 4) - self.assertEqual(sorted(l), rdd.sortBy(lambda x: x).collect()) - sc.stop() - - -class SerializationTestCase(unittest.TestCase): - - def test_namedtuple(self): - from collections import namedtuple - from pickle import dumps, loads - P = namedtuple("P", "x y") - p1 = P(1, 3) - p2 = loads(dumps(p1, 2)) - self.assertEqual(p1, p2) - - from pyspark.cloudpickle import dumps - P2 = loads(dumps(P)) - p3 = P2(1, 3) - self.assertEqual(p1, p3) - - def test_itemgetter(self): - from operator import itemgetter - ser = CloudPickleSerializer() - d = range(10) - getter = itemgetter(1) - getter2 = ser.loads(ser.dumps(getter)) - self.assertEqual(getter(d), getter2(d)) - - getter = itemgetter(0, 3) - getter2 = 
ser.loads(ser.dumps(getter)) - self.assertEqual(getter(d), getter2(d)) - - def test_function_module_name(self): - ser = CloudPickleSerializer() - func = lambda x: x - func2 = ser.loads(ser.dumps(func)) - self.assertEqual(func.__module__, func2.__module__) - - def test_attrgetter(self): - from operator import attrgetter - ser = CloudPickleSerializer() - - class C(object): - def __getattr__(self, item): - return item - d = C() - getter = attrgetter("a") - getter2 = ser.loads(ser.dumps(getter)) - self.assertEqual(getter(d), getter2(d)) - getter = attrgetter("a", "b") - getter2 = ser.loads(ser.dumps(getter)) - self.assertEqual(getter(d), getter2(d)) - - d.e = C() - getter = attrgetter("e.a") - getter2 = ser.loads(ser.dumps(getter)) - self.assertEqual(getter(d), getter2(d)) - getter = attrgetter("e.a", "e.b") - getter2 = ser.loads(ser.dumps(getter)) - self.assertEqual(getter(d), getter2(d)) - - # Regression test for SPARK-3415 - def test_pickling_file_handles(self): - # to be corrected with SPARK-11160 - if not xmlrunner: - ser = CloudPickleSerializer() - out1 = sys.stderr - out2 = ser.loads(ser.dumps(out1)) - self.assertEqual(out1, out2) - - def test_func_globals(self): - - class Unpicklable(object): - def __reduce__(self): - raise Exception("not picklable") - - global exit - exit = Unpicklable() - - ser = CloudPickleSerializer() - self.assertRaises(Exception, lambda: ser.dumps(exit)) - - def foo(): - sys.exit(0) - - self.assertTrue("exit" in foo.__code__.co_names) - ser.dumps(foo) - - def test_compressed_serializer(self): - ser = CompressedSerializer(PickleSerializer()) - try: - from StringIO import StringIO - except ImportError: - from io import BytesIO as StringIO - io = StringIO() - ser.dump_stream(["abc", u"123", range(5)], io) - io.seek(0) - self.assertEqual(["abc", u"123", range(5)], list(ser.load_stream(io))) - ser.dump_stream(range(1000), io) - io.seek(0) - self.assertEqual(["abc", u"123", range(5)] + list(range(1000)), list(ser.load_stream(io))) - io.close() - - def test_hash_serializer(self): - hash(NoOpSerializer()) - hash(UTF8Deserializer()) - hash(PickleSerializer()) - hash(MarshalSerializer()) - hash(AutoSerializer()) - hash(BatchedSerializer(PickleSerializer())) - hash(AutoBatchedSerializer(MarshalSerializer())) - hash(PairDeserializer(NoOpSerializer(), UTF8Deserializer())) - hash(CartesianDeserializer(NoOpSerializer(), UTF8Deserializer())) - hash(CompressedSerializer(PickleSerializer())) - hash(FlattenedValuesSerializer(PickleSerializer())) - - -class QuietTest(object): - def __init__(self, sc): - self.log4j = sc._jvm.org.apache.log4j - - def __enter__(self): - self.old_level = self.log4j.LogManager.getRootLogger().getLevel() - self.log4j.LogManager.getRootLogger().setLevel(self.log4j.Level.FATAL) - - def __exit__(self, exc_type, exc_val, exc_tb): - self.log4j.LogManager.getRootLogger().setLevel(self.old_level) - - -class PySparkTestCase(unittest.TestCase): - - def setUp(self): - self._old_sys_path = list(sys.path) - class_name = self.__class__.__name__ - self.sc = SparkContext('local[4]', class_name) - - def tearDown(self): - self.sc.stop() - sys.path = self._old_sys_path - - -class ReusedPySparkTestCase(unittest.TestCase): - - @classmethod - def conf(cls): - """ - Override this in subclasses to supply a more specific conf - """ - return SparkConf() - - @classmethod - def setUpClass(cls): - cls.sc = SparkContext('local[4]', cls.__name__, conf=cls.conf()) - - @classmethod - def tearDownClass(cls): - cls.sc.stop() - - -class CheckpointTests(ReusedPySparkTestCase): - - def 
setUp(self): - self.checkpointDir = tempfile.NamedTemporaryFile(delete=False) - os.unlink(self.checkpointDir.name) - self.sc.setCheckpointDir(self.checkpointDir.name) - - def tearDown(self): - shutil.rmtree(self.checkpointDir.name) - - def test_basic_checkpointing(self): - parCollection = self.sc.parallelize([1, 2, 3, 4]) - flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1)) - - self.assertFalse(flatMappedRDD.isCheckpointed()) - self.assertTrue(flatMappedRDD.getCheckpointFile() is None) - - flatMappedRDD.checkpoint() - result = flatMappedRDD.collect() - time.sleep(1) # 1 second - self.assertTrue(flatMappedRDD.isCheckpointed()) - self.assertEqual(flatMappedRDD.collect(), result) - self.assertEqual("file:" + self.checkpointDir.name, - os.path.dirname(os.path.dirname(flatMappedRDD.getCheckpointFile()))) - - def test_checkpoint_and_restore(self): - parCollection = self.sc.parallelize([1, 2, 3, 4]) - flatMappedRDD = parCollection.flatMap(lambda x: [x]) - - self.assertFalse(flatMappedRDD.isCheckpointed()) - self.assertTrue(flatMappedRDD.getCheckpointFile() is None) - - flatMappedRDD.checkpoint() - flatMappedRDD.count() # forces a checkpoint to be computed - time.sleep(1) # 1 second - - self.assertTrue(flatMappedRDD.getCheckpointFile() is not None) - recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile(), - flatMappedRDD._jrdd_deserializer) - self.assertEqual([1, 2, 3, 4], recovered.collect()) - - -class LocalCheckpointTests(ReusedPySparkTestCase): - - def test_basic_localcheckpointing(self): - parCollection = self.sc.parallelize([1, 2, 3, 4]) - flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1)) - - self.assertFalse(flatMappedRDD.isCheckpointed()) - self.assertFalse(flatMappedRDD.isLocallyCheckpointed()) - - flatMappedRDD.localCheckpoint() - result = flatMappedRDD.collect() - time.sleep(1) # 1 second - self.assertTrue(flatMappedRDD.isCheckpointed()) - self.assertTrue(flatMappedRDD.isLocallyCheckpointed()) - self.assertEqual(flatMappedRDD.collect(), result) - - -class AddFileTests(PySparkTestCase): - - def test_add_py_file(self): - # To ensure that we're actually testing addPyFile's effects, check that - # this job fails due to `userlibrary` not being on the Python path: - # disable logging in log4j temporarily - def func(x): - from userlibrary import UserClass - return UserClass().hello() - with QuietTest(self.sc): - self.assertRaises(Exception, self.sc.parallelize(range(2)).map(func).first) - - # Add the file, so the job should now succeed: - path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py") - self.sc.addPyFile(path) - res = self.sc.parallelize(range(2)).map(func).first() - self.assertEqual("Hello World!", res) - - def test_add_file_locally(self): - path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") - self.sc.addFile(path) - download_path = SparkFiles.get("hello.txt") - self.assertNotEqual(path, download_path) - with open(download_path) as test_file: - self.assertEqual("Hello World!\n", test_file.readline()) - - def test_add_file_recursively_locally(self): - path = os.path.join(SPARK_HOME, "python/test_support/hello") - self.sc.addFile(path, True) - download_path = SparkFiles.get("hello") - self.assertNotEqual(path, download_path) - with open(download_path + "/hello.txt") as test_file: - self.assertEqual("Hello World!\n", test_file.readline()) - with open(download_path + "/sub_hello/sub_hello.txt") as test_file: - self.assertEqual("Sub Hello World!\n", test_file.readline()) - - def test_add_py_file_locally(self): - 
# To ensure that we're actually testing addPyFile's effects, check that - # this fails due to `userlibrary` not being on the Python path: - def func(): - from userlibrary import UserClass - self.assertRaises(ImportError, func) - path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py") - self.sc.addPyFile(path) - from userlibrary import UserClass - self.assertEqual("Hello World!", UserClass().hello()) - - def test_add_egg_file_locally(self): - # To ensure that we're actually testing addPyFile's effects, check that - # this fails due to `userlibrary` not being on the Python path: - def func(): - from userlib import UserClass - self.assertRaises(ImportError, func) - path = os.path.join(SPARK_HOME, "python/test_support/userlib-0.1.zip") - self.sc.addPyFile(path) - from userlib import UserClass - self.assertEqual("Hello World from inside a package!", UserClass().hello()) - - def test_overwrite_system_module(self): - self.sc.addPyFile(os.path.join(SPARK_HOME, "python/test_support/SimpleHTTPServer.py")) - - import SimpleHTTPServer - self.assertEqual("My Server", SimpleHTTPServer.__name__) - - def func(x): - import SimpleHTTPServer - return SimpleHTTPServer.__name__ - - self.assertEqual(["My Server"], self.sc.parallelize(range(1)).map(func).collect()) - - -class TaskContextTests(PySparkTestCase): - - def setUp(self): - self._old_sys_path = list(sys.path) - class_name = self.__class__.__name__ - # Allow retries even though they are normally disabled in local mode - self.sc = SparkContext('local[4, 2]', class_name) - - def test_stage_id(self): - """Test the stage ids are available and incrementing as expected.""" - rdd = self.sc.parallelize(range(10)) - stage1 = rdd.map(lambda x: TaskContext.get().stageId()).take(1)[0] - stage2 = rdd.map(lambda x: TaskContext.get().stageId()).take(1)[0] - # Test using the constructor directly rather than the get() - stage3 = rdd.map(lambda x: TaskContext().stageId()).take(1)[0] - self.assertEqual(stage1 + 1, stage2) - self.assertEqual(stage1 + 2, stage3) - self.assertEqual(stage2 + 1, stage3) - - def test_partition_id(self): - """Test the partition id.""" - rdd1 = self.sc.parallelize(range(10), 1) - rdd2 = self.sc.parallelize(range(10), 2) - pids1 = rdd1.map(lambda x: TaskContext.get().partitionId()).collect() - pids2 = rdd2.map(lambda x: TaskContext.get().partitionId()).collect() - self.assertEqual(0, pids1[0]) - self.assertEqual(0, pids1[9]) - self.assertEqual(0, pids2[0]) - self.assertEqual(1, pids2[9]) - - def test_attempt_number(self): - """Verify the attempt numbers are correctly reported.""" - rdd = self.sc.parallelize(range(10)) - # Verify a simple job with no failures - attempt_numbers = rdd.map(lambda x: TaskContext.get().attemptNumber()).collect() - map(lambda attempt: self.assertEqual(0, attempt), attempt_numbers) - - def fail_on_first(x): - """Fail on the first attempt so we get a positive attempt number""" - tc = TaskContext.get() - attempt_number = tc.attemptNumber() - partition_id = tc.partitionId() - attempt_id = tc.taskAttemptId() - if attempt_number == 0 and partition_id == 0: - raise Exception("Failing on first attempt") - else: - return [x, partition_id, attempt_number, attempt_id] - result = rdd.map(fail_on_first).collect() - # We should re-submit the first partition to it but other partitions should be attempt 0 - self.assertEqual([0, 0, 1], result[0][0:3]) - self.assertEqual([9, 3, 0], result[9][0:3]) - first_partition = filter(lambda x: x[1] == 0, result) - map(lambda x: self.assertEqual(1, x[2]), first_partition) - 
other_partitions = filter(lambda x: x[1] != 0, result) - map(lambda x: self.assertEqual(0, x[2]), other_partitions) - # The task attempt id should be different - self.assertTrue(result[0][3] != result[9][3]) - - def test_tc_on_driver(self): - """Verify that getting the TaskContext on the driver returns None.""" - tc = TaskContext.get() - self.assertTrue(tc is None) - - def test_get_local_property(self): - """Verify that local properties set on the driver are available in TaskContext.""" - key = "testkey" - value = "testvalue" - self.sc.setLocalProperty(key, value) - try: - rdd = self.sc.parallelize(range(1), 1) - prop1 = rdd.map(lambda _: TaskContext.get().getLocalProperty(key)).collect()[0] - self.assertEqual(prop1, value) - prop2 = rdd.map(lambda _: TaskContext.get().getLocalProperty("otherkey")).collect()[0] - self.assertTrue(prop2 is None) - finally: - self.sc.setLocalProperty(key, None) - - def test_barrier(self): - """ - Verify that BarrierTaskContext.barrier() performs global sync among all barrier tasks - within a stage. - """ - rdd = self.sc.parallelize(range(10), 4) - - def f(iterator): - yield sum(iterator) - - def context_barrier(x): - tc = BarrierTaskContext.get() - time.sleep(random.randint(1, 10)) - tc.barrier() - return time.time() - - times = rdd.barrier().mapPartitions(f).map(context_barrier).collect() - self.assertTrue(max(times) - min(times) < 1) - - def test_barrier_infos(self): - """ - Verify that BarrierTaskContext.getTaskInfos() returns a list of all task infos in the - barrier stage. - """ - rdd = self.sc.parallelize(range(10), 4) - - def f(iterator): - yield sum(iterator) - - taskInfos = rdd.barrier().mapPartitions(f).map(lambda x: BarrierTaskContext.get() - .getTaskInfos()).collect() - self.assertTrue(len(taskInfos) == 4) - self.assertTrue(len(taskInfos[0]) == 4) - - -class RDDTests(ReusedPySparkTestCase): - - def test_range(self): - self.assertEqual(self.sc.range(1, 1).count(), 0) - self.assertEqual(self.sc.range(1, 0, -1).count(), 1) - self.assertEqual(self.sc.range(0, 1 << 40, 1 << 39).count(), 2) - - def test_id(self): - rdd = self.sc.parallelize(range(10)) - id = rdd.id() - self.assertEqual(id, rdd.id()) - rdd2 = rdd.map(str).filter(bool) - id2 = rdd2.id() - self.assertEqual(id + 1, id2) - self.assertEqual(id2, rdd2.id()) - - def test_empty_rdd(self): - rdd = self.sc.emptyRDD() - self.assertTrue(rdd.isEmpty()) - - def test_sum(self): - self.assertEqual(0, self.sc.emptyRDD().sum()) - self.assertEqual(6, self.sc.parallelize([1, 2, 3]).sum()) - - def test_to_localiterator(self): - from time import sleep - rdd = self.sc.parallelize([1, 2, 3]) - it = rdd.toLocalIterator() - sleep(5) - self.assertEqual([1, 2, 3], sorted(it)) - - rdd2 = rdd.repartition(1000) - it2 = rdd2.toLocalIterator() - sleep(5) - self.assertEqual([1, 2, 3], sorted(it2)) - - def test_save_as_textfile_with_unicode(self): - # Regression test for SPARK-970 - x = u"\u00A1Hola, mundo!" - data = self.sc.parallelize([x]) - tempFile = tempfile.NamedTemporaryFile(delete=True) - tempFile.close() - data.saveAsTextFile(tempFile.name) - raw_contents = b''.join(open(p, 'rb').read() - for p in glob(tempFile.name + "/part-0000*")) - self.assertEqual(x, raw_contents.strip().decode("utf-8")) - - def test_save_as_textfile_with_utf8(self): - x = u"\u00A1Hola, mundo!" 
- data = self.sc.parallelize([x.encode("utf-8")]) - tempFile = tempfile.NamedTemporaryFile(delete=True) - tempFile.close() - data.saveAsTextFile(tempFile.name) - raw_contents = b''.join(open(p, 'rb').read() - for p in glob(tempFile.name + "/part-0000*")) - self.assertEqual(x, raw_contents.strip().decode('utf8')) - - def test_transforming_cartesian_result(self): - # Regression test for SPARK-1034 - rdd1 = self.sc.parallelize([1, 2]) - rdd2 = self.sc.parallelize([3, 4]) - cart = rdd1.cartesian(rdd2) - result = cart.map(lambda x_y3: x_y3[0] + x_y3[1]).collect() - - def test_transforming_pickle_file(self): - # Regression test for SPARK-2601 - data = self.sc.parallelize([u"Hello", u"World!"]) - tempFile = tempfile.NamedTemporaryFile(delete=True) - tempFile.close() - data.saveAsPickleFile(tempFile.name) - pickled_file = self.sc.pickleFile(tempFile.name) - pickled_file.map(lambda x: x).collect() - - def test_cartesian_on_textfile(self): - # Regression test for - path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") - a = self.sc.textFile(path) - result = a.cartesian(a).collect() - (x, y) = result[0] - self.assertEqual(u"Hello World!", x.strip()) - self.assertEqual(u"Hello World!", y.strip()) - - def test_cartesian_chaining(self): - # Tests for SPARK-16589 - rdd = self.sc.parallelize(range(10), 2) - self.assertSetEqual( - set(rdd.cartesian(rdd).cartesian(rdd).collect()), - set([((x, y), z) for x in range(10) for y in range(10) for z in range(10)]) - ) - - self.assertSetEqual( - set(rdd.cartesian(rdd.cartesian(rdd)).collect()), - set([(x, (y, z)) for x in range(10) for y in range(10) for z in range(10)]) - ) - - self.assertSetEqual( - set(rdd.cartesian(rdd.zip(rdd)).collect()), - set([(x, (y, y)) for x in range(10) for y in range(10)]) - ) - - def test_zip_chaining(self): - # Tests for SPARK-21985 - rdd = self.sc.parallelize('abc', 2) - self.assertSetEqual( - set(rdd.zip(rdd).zip(rdd).collect()), - set([((x, x), x) for x in 'abc']) - ) - self.assertSetEqual( - set(rdd.zip(rdd.zip(rdd)).collect()), - set([(x, (x, x)) for x in 'abc']) - ) - - def test_deleting_input_files(self): - # Regression test for SPARK-1025 - tempFile = tempfile.NamedTemporaryFile(delete=False) - tempFile.write(b"Hello World!") - tempFile.close() - data = self.sc.textFile(tempFile.name) - filtered_data = data.filter(lambda x: True) - self.assertEqual(1, filtered_data.count()) - os.unlink(tempFile.name) - with QuietTest(self.sc): - self.assertRaises(Exception, lambda: filtered_data.count()) - - def test_sampling_default_seed(self): - # Test for SPARK-3995 (default seed setting) - data = self.sc.parallelize(xrange(1000), 1) - subset = data.takeSample(False, 10) - self.assertEqual(len(subset), 10) - - def test_aggregate_mutable_zero_value(self): - # Test for SPARK-9021; uses aggregate and treeAggregate to build dict - # representing a counter of ints - # NOTE: dict is used instead of collections.Counter for Python 2.6 - # compatibility - from collections import defaultdict - - # Show that single or multiple partitions work - data1 = self.sc.range(10, numSlices=1) - data2 = self.sc.range(10, numSlices=2) - - def seqOp(x, y): - x[y] += 1 - return x - - def comboOp(x, y): - for key, val in y.items(): - x[key] += val - return x - - counts1 = data1.aggregate(defaultdict(int), seqOp, comboOp) - counts2 = data2.aggregate(defaultdict(int), seqOp, comboOp) - counts3 = data1.treeAggregate(defaultdict(int), seqOp, comboOp, 2) - counts4 = data2.treeAggregate(defaultdict(int), seqOp, comboOp, 2) - - ground_truth = 
defaultdict(int, dict((i, 1) for i in range(10))) - self.assertEqual(counts1, ground_truth) - self.assertEqual(counts2, ground_truth) - self.assertEqual(counts3, ground_truth) - self.assertEqual(counts4, ground_truth) - - def test_aggregate_by_key_mutable_zero_value(self): - # Test for SPARK-9021; uses aggregateByKey to make a pair RDD that - # contains lists of all values for each key in the original RDD - - # list(range(...)) for Python 3.x compatibility (can't use * operator - # on a range object) - # list(zip(...)) for Python 3.x compatibility (want to parallelize a - # collection, not a zip object) - tuples = list(zip(list(range(10))*2, [1]*20)) - # Show that single or multiple partitions work - data1 = self.sc.parallelize(tuples, 1) - data2 = self.sc.parallelize(tuples, 2) - - def seqOp(x, y): - x.append(y) - return x - - def comboOp(x, y): - x.extend(y) - return x - - values1 = data1.aggregateByKey([], seqOp, comboOp).collect() - values2 = data2.aggregateByKey([], seqOp, comboOp).collect() - # Sort lists to ensure clean comparison with ground_truth - values1.sort() - values2.sort() - - ground_truth = [(i, [1]*2) for i in range(10)] - self.assertEqual(values1, ground_truth) - self.assertEqual(values2, ground_truth) - - def test_fold_mutable_zero_value(self): - # Test for SPARK-9021; uses fold to merge an RDD of dict counters into - # a single dict - # NOTE: dict is used instead of collections.Counter for Python 2.6 - # compatibility - from collections import defaultdict - - counts1 = defaultdict(int, dict((i, 1) for i in range(10))) - counts2 = defaultdict(int, dict((i, 1) for i in range(3, 8))) - counts3 = defaultdict(int, dict((i, 1) for i in range(4, 7))) - counts4 = defaultdict(int, dict((i, 1) for i in range(5, 6))) - all_counts = [counts1, counts2, counts3, counts4] - # Show that single or multiple partitions work - data1 = self.sc.parallelize(all_counts, 1) - data2 = self.sc.parallelize(all_counts, 2) - - def comboOp(x, y): - for key, val in y.items(): - x[key] += val - return x - - fold1 = data1.fold(defaultdict(int), comboOp) - fold2 = data2.fold(defaultdict(int), comboOp) - - ground_truth = defaultdict(int) - for counts in all_counts: - for key, val in counts.items(): - ground_truth[key] += val - self.assertEqual(fold1, ground_truth) - self.assertEqual(fold2, ground_truth) - - def test_fold_by_key_mutable_zero_value(self): - # Test for SPARK-9021; uses foldByKey to make a pair RDD that contains - # lists of all values for each key in the original RDD - - tuples = [(i, range(i)) for i in range(10)]*2 - # Show that single or multiple partitions work - data1 = self.sc.parallelize(tuples, 1) - data2 = self.sc.parallelize(tuples, 2) - - def comboOp(x, y): - x.extend(y) - return x - - values1 = data1.foldByKey([], comboOp).collect() - values2 = data2.foldByKey([], comboOp).collect() - # Sort lists to ensure clean comparison with ground_truth - values1.sort() - values2.sort() - - # list(range(...)) for Python 3.x compatibility - ground_truth = [(i, list(range(i))*2) for i in range(10)] - self.assertEqual(values1, ground_truth) - self.assertEqual(values2, ground_truth) - - def test_aggregate_by_key(self): - data = self.sc.parallelize([(1, 1), (1, 1), (3, 2), (5, 1), (5, 3)], 2) - - def seqOp(x, y): - x.add(y) - return x - - def combOp(x, y): - x |= y - return x - - sets = dict(data.aggregateByKey(set(), seqOp, combOp).collect()) - self.assertEqual(3, len(sets)) - self.assertEqual(set([1]), sets[1]) - self.assertEqual(set([2]), sets[3]) - self.assertEqual(set([1, 3]), sets[5]) - - 
def test_itemgetter(self): - rdd = self.sc.parallelize([range(10)]) - from operator import itemgetter - self.assertEqual([1], rdd.map(itemgetter(1)).collect()) - self.assertEqual([(2, 3)], rdd.map(itemgetter(2, 3)).collect()) - - def test_namedtuple_in_rdd(self): - from collections import namedtuple - Person = namedtuple("Person", "id firstName lastName") - jon = Person(1, "Jon", "Doe") - jane = Person(2, "Jane", "Doe") - theDoes = self.sc.parallelize([jon, jane]) - self.assertEqual([jon, jane], theDoes.collect()) - - def test_large_broadcast(self): - N = 10000 - data = [[float(i) for i in range(300)] for i in range(N)] - bdata = self.sc.broadcast(data) # 27MB - m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() - self.assertEqual(N, m) - - def test_unpersist(self): - N = 1000 - data = [[float(i) for i in range(300)] for i in range(N)] - bdata = self.sc.broadcast(data) # 3MB - bdata.unpersist() - m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() - self.assertEqual(N, m) - bdata.destroy() - try: - self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() - except Exception as e: - pass - else: - raise Exception("job should fail after destroy the broadcast") - - def test_multiple_broadcasts(self): - N = 1 << 21 - b1 = self.sc.broadcast(set(range(N))) # multiple blocks in JVM - r = list(range(1 << 15)) - random.shuffle(r) - s = str(r).encode() - checksum = hashlib.md5(s).hexdigest() - b2 = self.sc.broadcast(s) - r = list(set(self.sc.parallelize(range(10), 10).map( - lambda x: (len(b1.value), hashlib.md5(b2.value).hexdigest())).collect())) - self.assertEqual(1, len(r)) - size, csum = r[0] - self.assertEqual(N, size) - self.assertEqual(checksum, csum) - - random.shuffle(r) - s = str(r).encode() - checksum = hashlib.md5(s).hexdigest() - b2 = self.sc.broadcast(s) - r = list(set(self.sc.parallelize(range(10), 10).map( - lambda x: (len(b1.value), hashlib.md5(b2.value).hexdigest())).collect())) - self.assertEqual(1, len(r)) - size, csum = r[0] - self.assertEqual(N, size) - self.assertEqual(checksum, csum) - - def test_multithread_broadcast_pickle(self): - import threading - - b1 = self.sc.broadcast(list(range(3))) - b2 = self.sc.broadcast(list(range(3))) - - def f1(): - return b1.value - - def f2(): - return b2.value - - funcs_num_pickled = {f1: None, f2: None} - - def do_pickle(f, sc): - command = (f, None, sc.serializer, sc.serializer) - ser = CloudPickleSerializer() - ser.dumps(command) - - def process_vars(sc): - broadcast_vars = list(sc._pickled_broadcast_vars) - num_pickled = len(broadcast_vars) - sc._pickled_broadcast_vars.clear() - return num_pickled - - def run(f, sc): - do_pickle(f, sc) - funcs_num_pickled[f] = process_vars(sc) - - # pickle f1, adds b1 to sc._pickled_broadcast_vars in main thread local storage - do_pickle(f1, self.sc) - - # run all for f2, should only add/count/clear b2 from worker thread local storage - t = threading.Thread(target=run, args=(f2, self.sc)) - t.start() - t.join() - - # count number of vars pickled in main thread, only b1 should be counted and cleared - funcs_num_pickled[f1] = process_vars(self.sc) - - self.assertEqual(funcs_num_pickled[f1], 1) - self.assertEqual(funcs_num_pickled[f2], 1) - self.assertEqual(len(list(self.sc._pickled_broadcast_vars)), 0) - - def test_large_closure(self): - N = 200000 - data = [float(i) for i in xrange(N)] - rdd = self.sc.parallelize(range(1), 1).map(lambda x: len(data)) - self.assertEqual(N, rdd.first()) - # regression test for SPARK-6886 - self.assertEqual(1, 
rdd.map(lambda x: (x, 1)).groupByKey().count()) - - def test_zip_with_different_serializers(self): - a = self.sc.parallelize(range(5)) - b = self.sc.parallelize(range(100, 105)) - self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)]) - a = a._reserialize(BatchedSerializer(PickleSerializer(), 2)) - b = b._reserialize(MarshalSerializer()) - self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)]) - # regression test for SPARK-4841 - path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") - t = self.sc.textFile(path) - cnt = t.count() - self.assertEqual(cnt, t.zip(t).count()) - rdd = t.map(str) - self.assertEqual(cnt, t.zip(rdd).count()) - # regression test for bug in _reserializer() - self.assertEqual(cnt, t.zip(rdd).count()) - - def test_zip_with_different_object_sizes(self): - # regress test for SPARK-5973 - a = self.sc.parallelize(xrange(10000)).map(lambda i: '*' * i) - b = self.sc.parallelize(xrange(10000, 20000)).map(lambda i: '*' * i) - self.assertEqual(10000, a.zip(b).count()) - - def test_zip_with_different_number_of_items(self): - a = self.sc.parallelize(range(5), 2) - # different number of partitions - b = self.sc.parallelize(range(100, 106), 3) - self.assertRaises(ValueError, lambda: a.zip(b)) - with QuietTest(self.sc): - # different number of batched items in JVM - b = self.sc.parallelize(range(100, 104), 2) - self.assertRaises(Exception, lambda: a.zip(b).count()) - # different number of items in one pair - b = self.sc.parallelize(range(100, 106), 2) - self.assertRaises(Exception, lambda: a.zip(b).count()) - # same total number of items, but different distributions - a = self.sc.parallelize([2, 3], 2).flatMap(range) - b = self.sc.parallelize([3, 2], 2).flatMap(range) - self.assertEqual(a.count(), b.count()) - self.assertRaises(Exception, lambda: a.zip(b).count()) - - def test_count_approx_distinct(self): - rdd = self.sc.parallelize(xrange(1000)) - self.assertTrue(950 < rdd.countApproxDistinct(0.03) < 1050) - self.assertTrue(950 < rdd.map(float).countApproxDistinct(0.03) < 1050) - self.assertTrue(950 < rdd.map(str).countApproxDistinct(0.03) < 1050) - self.assertTrue(950 < rdd.map(lambda x: (x, -x)).countApproxDistinct(0.03) < 1050) - - rdd = self.sc.parallelize([i % 20 for i in range(1000)], 7) - self.assertTrue(18 < rdd.countApproxDistinct() < 22) - self.assertTrue(18 < rdd.map(float).countApproxDistinct() < 22) - self.assertTrue(18 < rdd.map(str).countApproxDistinct() < 22) - self.assertTrue(18 < rdd.map(lambda x: (x, -x)).countApproxDistinct() < 22) - - self.assertRaises(ValueError, lambda: rdd.countApproxDistinct(0.00000001)) - - def test_histogram(self): - # empty - rdd = self.sc.parallelize([]) - self.assertEqual([0], rdd.histogram([0, 10])[1]) - self.assertEqual([0, 0], rdd.histogram([0, 4, 10])[1]) - self.assertRaises(ValueError, lambda: rdd.histogram(1)) - - # out of range - rdd = self.sc.parallelize([10.01, -0.01]) - self.assertEqual([0], rdd.histogram([0, 10])[1]) - self.assertEqual([0, 0], rdd.histogram((0, 4, 10))[1]) - - # in range with one bucket - rdd = self.sc.parallelize(range(1, 5)) - self.assertEqual([4], rdd.histogram([0, 10])[1]) - self.assertEqual([3, 1], rdd.histogram([0, 4, 10])[1]) - - # in range with one bucket exact match - self.assertEqual([4], rdd.histogram([1, 4])[1]) - - # out of range with two buckets - rdd = self.sc.parallelize([10.01, -0.01]) - self.assertEqual([0, 0], rdd.histogram([0, 5, 10])[1]) - - # out of range with two uneven buckets - rdd = 
self.sc.parallelize([10.01, -0.01]) - self.assertEqual([0, 0], rdd.histogram([0, 4, 10])[1]) - - # in range with two buckets - rdd = self.sc.parallelize([1, 2, 3, 5, 6]) - self.assertEqual([3, 2], rdd.histogram([0, 5, 10])[1]) - - # in range with two bucket and None - rdd = self.sc.parallelize([1, 2, 3, 5, 6, None, float('nan')]) - self.assertEqual([3, 2], rdd.histogram([0, 5, 10])[1]) - - # in range with two uneven buckets - rdd = self.sc.parallelize([1, 2, 3, 5, 6]) - self.assertEqual([3, 2], rdd.histogram([0, 5, 11])[1]) - - # mixed range with two uneven buckets - rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.0, 11.01]) - self.assertEqual([4, 3], rdd.histogram([0, 5, 11])[1]) - - # mixed range with four uneven buckets - rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0, 200.0, 200.1]) - self.assertEqual([4, 2, 1, 3], rdd.histogram([0.0, 5.0, 11.0, 12.0, 200.0])[1]) - - # mixed range with uneven buckets and NaN - rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, - 199.0, 200.0, 200.1, None, float('nan')]) - self.assertEqual([4, 2, 1, 3], rdd.histogram([0.0, 5.0, 11.0, 12.0, 200.0])[1]) - - # out of range with infinite buckets - rdd = self.sc.parallelize([10.01, -0.01, float('nan'), float("inf")]) - self.assertEqual([1, 2], rdd.histogram([float('-inf'), 0, float('inf')])[1]) - - # invalid buckets - self.assertRaises(ValueError, lambda: rdd.histogram([])) - self.assertRaises(ValueError, lambda: rdd.histogram([1])) - self.assertRaises(ValueError, lambda: rdd.histogram(0)) - self.assertRaises(TypeError, lambda: rdd.histogram({})) - - # without buckets - rdd = self.sc.parallelize(range(1, 5)) - self.assertEqual(([1, 4], [4]), rdd.histogram(1)) - - # without buckets single element - rdd = self.sc.parallelize([1]) - self.assertEqual(([1, 1], [1]), rdd.histogram(1)) - - # without bucket no range - rdd = self.sc.parallelize([1] * 4) - self.assertEqual(([1, 1], [4]), rdd.histogram(1)) - - # without buckets basic two - rdd = self.sc.parallelize(range(1, 5)) - self.assertEqual(([1, 2.5, 4], [2, 2]), rdd.histogram(2)) - - # without buckets with more requested than elements - rdd = self.sc.parallelize([1, 2]) - buckets = [1 + 0.2 * i for i in range(6)] - hist = [1, 0, 0, 0, 1] - self.assertEqual((buckets, hist), rdd.histogram(5)) - - # invalid RDDs - rdd = self.sc.parallelize([1, float('inf')]) - self.assertRaises(ValueError, lambda: rdd.histogram(2)) - rdd = self.sc.parallelize([float('nan')]) - self.assertRaises(ValueError, lambda: rdd.histogram(2)) - - # string - rdd = self.sc.parallelize(["ab", "ac", "b", "bd", "ef"], 2) - self.assertEqual([2, 2], rdd.histogram(["a", "b", "c"])[1]) - self.assertEqual((["ab", "ef"], [5]), rdd.histogram(1)) - self.assertRaises(TypeError, lambda: rdd.histogram(2)) - - def test_repartitionAndSortWithinPartitions_asc(self): - rdd = self.sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)], 2) - - repartitioned = rdd.repartitionAndSortWithinPartitions(2, lambda key: key % 2, True) - partitions = repartitioned.glom().collect() - self.assertEqual(partitions[0], [(0, 5), (0, 8), (2, 6)]) - self.assertEqual(partitions[1], [(1, 3), (3, 8), (3, 8)]) - - def test_repartitionAndSortWithinPartitions_desc(self): - rdd = self.sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)], 2) - - repartitioned = rdd.repartitionAndSortWithinPartitions(2, lambda key: key % 2, False) - partitions = repartitioned.glom().collect() - self.assertEqual(partitions[0], [(2, 6), (0, 5), (0, 8)]) - self.assertEqual(partitions[1], 
[(3, 8), (3, 8), (1, 3)]) - - def test_repartition_no_skewed(self): - num_partitions = 20 - a = self.sc.parallelize(range(int(1000)), 2) - l = a.repartition(num_partitions).glom().map(len).collect() - zeros = len([x for x in l if x == 0]) - self.assertTrue(zeros == 0) - l = a.coalesce(num_partitions, True).glom().map(len).collect() - zeros = len([x for x in l if x == 0]) - self.assertTrue(zeros == 0) - - def test_repartition_on_textfile(self): - path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") - rdd = self.sc.textFile(path) - result = rdd.repartition(1).collect() - self.assertEqual(u"Hello World!", result[0]) - - def test_distinct(self): - rdd = self.sc.parallelize((1, 2, 3)*10, 10) - self.assertEqual(rdd.getNumPartitions(), 10) - self.assertEqual(rdd.distinct().count(), 3) - result = rdd.distinct(5) - self.assertEqual(result.getNumPartitions(), 5) - self.assertEqual(result.count(), 3) - - def test_external_group_by_key(self): - self.sc._conf.set("spark.python.worker.memory", "1m") - N = 200001 - kv = self.sc.parallelize(xrange(N)).map(lambda x: (x % 3, x)) - gkv = kv.groupByKey().cache() - self.assertEqual(3, gkv.count()) - filtered = gkv.filter(lambda kv: kv[0] == 1) - self.assertEqual(1, filtered.count()) - self.assertEqual([(1, N // 3)], filtered.mapValues(len).collect()) - self.assertEqual([(N // 3, N // 3)], - filtered.values().map(lambda x: (len(x), len(list(x)))).collect()) - result = filtered.collect()[0][1] - self.assertEqual(N // 3, len(result)) - self.assertTrue(isinstance(result.data, shuffle.ExternalListOfList)) - - def test_sort_on_empty_rdd(self): - self.assertEqual([], self.sc.parallelize(zip([], [])).sortByKey().collect()) - - def test_sample(self): - rdd = self.sc.parallelize(range(0, 100), 4) - wo = rdd.sample(False, 0.1, 2).collect() - wo_dup = rdd.sample(False, 0.1, 2).collect() - self.assertSetEqual(set(wo), set(wo_dup)) - wr = rdd.sample(True, 0.2, 5).collect() - wr_dup = rdd.sample(True, 0.2, 5).collect() - self.assertSetEqual(set(wr), set(wr_dup)) - wo_s10 = rdd.sample(False, 0.3, 10).collect() - wo_s20 = rdd.sample(False, 0.3, 20).collect() - self.assertNotEqual(set(wo_s10), set(wo_s20)) - wr_s11 = rdd.sample(True, 0.4, 11).collect() - wr_s21 = rdd.sample(True, 0.4, 21).collect() - self.assertNotEqual(set(wr_s11), set(wr_s21)) - - def test_null_in_rdd(self): - jrdd = self.sc._jvm.PythonUtils.generateRDDWithNull(self.sc._jsc) - rdd = RDD(jrdd, self.sc, UTF8Deserializer()) - self.assertEqual([u"a", None, u"b"], rdd.collect()) - rdd = RDD(jrdd, self.sc, NoOpSerializer()) - self.assertEqual([b"a", None, b"b"], rdd.collect()) - - def test_multiple_python_java_RDD_conversions(self): - # Regression test for SPARK-5361 - data = [ - (u'1', {u'director': u'David Lean'}), - (u'2', {u'director': u'Andrew Dominik'}) - ] - data_rdd = self.sc.parallelize(data) - data_java_rdd = data_rdd._to_java_object_rdd() - data_python_rdd = self.sc._jvm.SerDeUtil.javaToPython(data_java_rdd) - converted_rdd = RDD(data_python_rdd, self.sc) - self.assertEqual(2, converted_rdd.count()) - - # conversion between python and java RDD threw exceptions - data_java_rdd = converted_rdd._to_java_object_rdd() - data_python_rdd = self.sc._jvm.SerDeUtil.javaToPython(data_java_rdd) - converted_rdd = RDD(data_python_rdd, self.sc) - self.assertEqual(2, converted_rdd.count()) - - def test_narrow_dependency_in_join(self): - rdd = self.sc.parallelize(range(10)).map(lambda x: (x, x)) - parted = rdd.partitionBy(2) - self.assertEqual(2, parted.union(parted).getNumPartitions()) - 
self.assertEqual(rdd.getNumPartitions() + 2, parted.union(rdd).getNumPartitions()) - self.assertEqual(rdd.getNumPartitions() + 2, rdd.union(parted).getNumPartitions()) - - tracker = self.sc.statusTracker() - - self.sc.setJobGroup("test1", "test", True) - d = sorted(parted.join(parted).collect()) - self.assertEqual(10, len(d)) - self.assertEqual((0, (0, 0)), d[0]) - jobId = tracker.getJobIdsForGroup("test1")[0] - self.assertEqual(2, len(tracker.getJobInfo(jobId).stageIds)) - - self.sc.setJobGroup("test2", "test", True) - d = sorted(parted.join(rdd).collect()) - self.assertEqual(10, len(d)) - self.assertEqual((0, (0, 0)), d[0]) - jobId = tracker.getJobIdsForGroup("test2")[0] - self.assertEqual(3, len(tracker.getJobInfo(jobId).stageIds)) - - self.sc.setJobGroup("test3", "test", True) - d = sorted(parted.cogroup(parted).collect()) - self.assertEqual(10, len(d)) - self.assertEqual([[0], [0]], list(map(list, d[0][1]))) - jobId = tracker.getJobIdsForGroup("test3")[0] - self.assertEqual(2, len(tracker.getJobInfo(jobId).stageIds)) - - self.sc.setJobGroup("test4", "test", True) - d = sorted(parted.cogroup(rdd).collect()) - self.assertEqual(10, len(d)) - self.assertEqual([[0], [0]], list(map(list, d[0][1]))) - jobId = tracker.getJobIdsForGroup("test4")[0] - self.assertEqual(3, len(tracker.getJobInfo(jobId).stageIds)) - - # Regression test for SPARK-6294 - def test_take_on_jrdd(self): - rdd = self.sc.parallelize(xrange(1 << 20)).map(lambda x: str(x)) - rdd._jrdd.first() - - def test_sortByKey_uses_all_partitions_not_only_first_and_last(self): - # Regression test for SPARK-5969 - seq = [(i * 59 % 101, i) for i in range(101)] # unsorted sequence - rdd = self.sc.parallelize(seq) - for ascending in [True, False]: - sort = rdd.sortByKey(ascending=ascending, numPartitions=5) - self.assertEqual(sort.collect(), sorted(seq, reverse=not ascending)) - sizes = sort.glom().map(len).collect() - for size in sizes: - self.assertGreater(size, 0) - - def test_pipe_functions(self): - data = ['1', '2', '3'] - rdd = self.sc.parallelize(data) - with QuietTest(self.sc): - self.assertEqual([], rdd.pipe('cc').collect()) - self.assertRaises(Py4JJavaError, rdd.pipe('cc', checkCode=True).collect) - result = rdd.pipe('cat').collect() - result.sort() - for x, y in zip(data, result): - self.assertEqual(x, y) - self.assertRaises(Py4JJavaError, rdd.pipe('grep 4', checkCode=True).collect) - self.assertEqual([], rdd.pipe('grep 4').collect()) - - def test_pipe_unicode(self): - # Regression test for SPARK-20947 - data = [u'\u6d4b\u8bd5', '1'] - rdd = self.sc.parallelize(data) - result = rdd.pipe('cat').collect() - self.assertEqual(data, result) - - def test_stopiteration_in_user_code(self): - - def stopit(*x): - raise StopIteration() - - seq_rdd = self.sc.parallelize(range(10)) - keyed_rdd = self.sc.parallelize((x % 2, x) for x in range(10)) - msg = "Caught StopIteration thrown from user's code; failing the task" - - self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.map(stopit).collect) - self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.filter(stopit).collect) - self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit) - self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.reduce, stopit) - self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.fold, 0, stopit) - self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit) - self.assertRaisesRegexp(Py4JJavaError, msg, - seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect) - - # these methods call the user function both in the driver and in the executor - # the 
exception raised is different according to where the StopIteration happens - # RuntimeError is raised if in the driver - # Py4JJavaError is raised if in the executor (wraps the RuntimeError raised in the worker) - self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg, - keyed_rdd.reduceByKeyLocally, stopit) - self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg, - seq_rdd.aggregate, 0, stopit, lambda *x: 1) - self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg, - seq_rdd.aggregate, 0, lambda *x: 1, stopit) - - -class ProfilerTests(PySparkTestCase): - - def setUp(self): - self._old_sys_path = list(sys.path) - class_name = self.__class__.__name__ - conf = SparkConf().set("spark.python.profile", "true") - self.sc = SparkContext('local[4]', class_name, conf=conf) - - def test_profiler(self): - self.do_computation() - - profilers = self.sc.profiler_collector.profilers - self.assertEqual(1, len(profilers)) - id, profiler, _ = profilers[0] - stats = profiler.stats() - self.assertTrue(stats is not None) - width, stat_list = stats.get_print_list([]) - func_names = [func_name for fname, n, func_name in stat_list] - self.assertTrue("heavy_foo" in func_names) - - old_stdout = sys.stdout - sys.stdout = io = StringIO() - self.sc.show_profiles() - self.assertTrue("heavy_foo" in io.getvalue()) - sys.stdout = old_stdout - - d = tempfile.gettempdir() - self.sc.dump_profiles(d) - self.assertTrue("rdd_%d.pstats" % id in os.listdir(d)) - - def test_custom_profiler(self): - class TestCustomProfiler(BasicProfiler): - def show(self, id): - self.result = "Custom formatting" - - self.sc.profiler_collector.profiler_cls = TestCustomProfiler - - self.do_computation() - - profilers = self.sc.profiler_collector.profilers - self.assertEqual(1, len(profilers)) - _, profiler, _ = profilers[0] - self.assertTrue(isinstance(profiler, TestCustomProfiler)) - - self.sc.show_profiles() - self.assertEqual("Custom formatting", profiler.result) - - def do_computation(self): - def heavy_foo(x): - for i in range(1 << 18): - x = 1 - - rdd = self.sc.parallelize(range(100)) - rdd.foreach(heavy_foo) - - -class ProfilerTests2(unittest.TestCase): - def test_profiler_disabled(self): - sc = SparkContext(conf=SparkConf().set("spark.python.profile", "false")) - try: - self.assertRaisesRegexp( - RuntimeError, - "'spark.python.profile' configuration must be set", - lambda: sc.show_profiles()) - self.assertRaisesRegexp( - RuntimeError, - "'spark.python.profile' configuration must be set", - lambda: sc.dump_profiles("/tmp/abc")) - finally: - sc.stop() - - -class InputFormatTests(ReusedPySparkTestCase): - - @classmethod - def setUpClass(cls): - ReusedPySparkTestCase.setUpClass() - cls.tempdir = tempfile.NamedTemporaryFile(delete=False) - os.unlink(cls.tempdir.name) - cls.sc._jvm.WriteInputFormatTestDataGenerator.generateData(cls.tempdir.name, cls.sc._jsc) - - @classmethod - def tearDownClass(cls): - ReusedPySparkTestCase.tearDownClass() - shutil.rmtree(cls.tempdir.name) - - @unittest.skipIf(sys.version >= "3", "serialize array of byte") - def test_sequencefiles(self): - basepath = self.tempdir.name - ints = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfint/", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text").collect()) - ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] - self.assertEqual(ints, ei) - - doubles = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfdouble/", - "org.apache.hadoop.io.DoubleWritable", - "org.apache.hadoop.io.Text").collect()) - ed = [(1.0, u'aa'), 
(1.0, u'aa'), (2.0, u'aa'), (2.0, u'bb'), (2.0, u'bb'), (3.0, u'cc')] - self.assertEqual(doubles, ed) - - bytes = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfbytes/", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.BytesWritable").collect()) - ebs = [(1, bytearray('aa', 'utf-8')), - (1, bytearray('aa', 'utf-8')), - (2, bytearray('aa', 'utf-8')), - (2, bytearray('bb', 'utf-8')), - (2, bytearray('bb', 'utf-8')), - (3, bytearray('cc', 'utf-8'))] - self.assertEqual(bytes, ebs) - - text = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sftext/", - "org.apache.hadoop.io.Text", - "org.apache.hadoop.io.Text").collect()) - et = [(u'1', u'aa'), - (u'1', u'aa'), - (u'2', u'aa'), - (u'2', u'bb'), - (u'2', u'bb'), - (u'3', u'cc')] - self.assertEqual(text, et) - - bools = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfbool/", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.BooleanWritable").collect()) - eb = [(1, False), (1, True), (2, False), (2, False), (2, True), (3, True)] - self.assertEqual(bools, eb) - - nulls = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfnull/", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.BooleanWritable").collect()) - en = [(1, None), (1, None), (2, None), (2, None), (2, None), (3, None)] - self.assertEqual(nulls, en) - - maps = self.sc.sequenceFile(basepath + "/sftestdata/sfmap/", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.MapWritable").collect() - em = [(1, {}), - (1, {3.0: u'bb'}), - (2, {1.0: u'aa'}), - (2, {1.0: u'cc'}), - (3, {2.0: u'dd'})] - for v in maps: - self.assertTrue(v in em) - - # arrays get pickled to tuples by default - tuples = sorted(self.sc.sequenceFile( - basepath + "/sftestdata/sfarray/", - "org.apache.hadoop.io.IntWritable", - "org.apache.spark.api.python.DoubleArrayWritable").collect()) - et = [(1, ()), - (2, (3.0, 4.0, 5.0)), - (3, (4.0, 5.0, 6.0))] - self.assertEqual(tuples, et) - - # with custom converters, primitive arrays can stay as arrays - arrays = sorted(self.sc.sequenceFile( - basepath + "/sftestdata/sfarray/", - "org.apache.hadoop.io.IntWritable", - "org.apache.spark.api.python.DoubleArrayWritable", - valueConverter="org.apache.spark.api.python.WritableToDoubleArrayConverter").collect()) - ea = [(1, array('d')), - (2, array('d', [3.0, 4.0, 5.0])), - (3, array('d', [4.0, 5.0, 6.0]))] - self.assertEqual(arrays, ea) - - clazz = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfclass/", - "org.apache.hadoop.io.Text", - "org.apache.spark.api.python.TestWritable").collect()) - cname = u'org.apache.spark.api.python.TestWritable' - ec = [(u'1', {u'__class__': cname, u'double': 1.0, u'int': 1, u'str': u'test1'}), - (u'2', {u'__class__': cname, u'double': 2.3, u'int': 2, u'str': u'test2'}), - (u'3', {u'__class__': cname, u'double': 3.1, u'int': 3, u'str': u'test3'}), - (u'4', {u'__class__': cname, u'double': 4.2, u'int': 4, u'str': u'test4'}), - (u'5', {u'__class__': cname, u'double': 5.5, u'int': 5, u'str': u'test56'})] - self.assertEqual(clazz, ec) - - unbatched_clazz = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfclass/", - "org.apache.hadoop.io.Text", - "org.apache.spark.api.python.TestWritable", - ).collect()) - self.assertEqual(unbatched_clazz, ec) - - def test_oldhadoop(self): - basepath = self.tempdir.name - ints = sorted(self.sc.hadoopFile(basepath + "/sftestdata/sfint/", - "org.apache.hadoop.mapred.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text").collect()) - ei = [(1, u'aa'), (1, u'aa'), (2, 
u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] - self.assertEqual(ints, ei) - - hellopath = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") - oldconf = {"mapreduce.input.fileinputformat.inputdir": hellopath} - hello = self.sc.hadoopRDD("org.apache.hadoop.mapred.TextInputFormat", - "org.apache.hadoop.io.LongWritable", - "org.apache.hadoop.io.Text", - conf=oldconf).collect() - result = [(0, u'Hello World!')] - self.assertEqual(hello, result) - - def test_newhadoop(self): - basepath = self.tempdir.name - ints = sorted(self.sc.newAPIHadoopFile( - basepath + "/sftestdata/sfint/", - "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text").collect()) - ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] - self.assertEqual(ints, ei) - - hellopath = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") - newconf = {"mapreduce.input.fileinputformat.inputdir": hellopath} - hello = self.sc.newAPIHadoopRDD("org.apache.hadoop.mapreduce.lib.input.TextInputFormat", - "org.apache.hadoop.io.LongWritable", - "org.apache.hadoop.io.Text", - conf=newconf).collect() - result = [(0, u'Hello World!')] - self.assertEqual(hello, result) - - def test_newolderror(self): - basepath = self.tempdir.name - self.assertRaises(Exception, lambda: self.sc.hadoopFile( - basepath + "/sftestdata/sfint/", - "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text")) - - self.assertRaises(Exception, lambda: self.sc.newAPIHadoopFile( - basepath + "/sftestdata/sfint/", - "org.apache.hadoop.mapred.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text")) - - def test_bad_inputs(self): - basepath = self.tempdir.name - self.assertRaises(Exception, lambda: self.sc.sequenceFile( - basepath + "/sftestdata/sfint/", - "org.apache.hadoop.io.NotValidWritable", - "org.apache.hadoop.io.Text")) - self.assertRaises(Exception, lambda: self.sc.hadoopFile( - basepath + "/sftestdata/sfint/", - "org.apache.hadoop.mapred.NotValidInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text")) - self.assertRaises(Exception, lambda: self.sc.newAPIHadoopFile( - basepath + "/sftestdata/sfint/", - "org.apache.hadoop.mapreduce.lib.input.NotValidInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text")) - - def test_converters(self): - # use of custom converters - basepath = self.tempdir.name - maps = sorted(self.sc.sequenceFile( - basepath + "/sftestdata/sfmap/", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.MapWritable", - keyConverter="org.apache.spark.api.python.TestInputKeyConverter", - valueConverter="org.apache.spark.api.python.TestInputValueConverter").collect()) - em = [(u'\x01', []), - (u'\x01', [3.0]), - (u'\x02', [1.0]), - (u'\x02', [1.0]), - (u'\x03', [2.0])] - self.assertEqual(maps, em) - - def test_binary_files(self): - path = os.path.join(self.tempdir.name, "binaryfiles") - os.mkdir(path) - data = b"short binary data" - with open(os.path.join(path, "part-0000"), 'wb') as f: - f.write(data) - [(p, d)] = self.sc.binaryFiles(path).collect() - self.assertTrue(p.endswith("part-0000")) - self.assertEqual(d, data) - - def test_binary_records(self): - path = os.path.join(self.tempdir.name, "binaryrecords") - os.mkdir(path) - with open(os.path.join(path, "part-0000"), 'w') as f: - for i in range(100): - f.write('%04d' % i) - result = 
self.sc.binaryRecords(path, 4).map(int).collect() - self.assertEqual(list(range(100)), result) - - -class OutputFormatTests(ReusedPySparkTestCase): - - def setUp(self): - self.tempdir = tempfile.NamedTemporaryFile(delete=False) - os.unlink(self.tempdir.name) - - def tearDown(self): - shutil.rmtree(self.tempdir.name, ignore_errors=True) - - @unittest.skipIf(sys.version >= "3", "serialize array of byte") - def test_sequencefiles(self): - basepath = self.tempdir.name - ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] - self.sc.parallelize(ei).saveAsSequenceFile(basepath + "/sfint/") - ints = sorted(self.sc.sequenceFile(basepath + "/sfint/").collect()) - self.assertEqual(ints, ei) - - ed = [(1.0, u'aa'), (1.0, u'aa'), (2.0, u'aa'), (2.0, u'bb'), (2.0, u'bb'), (3.0, u'cc')] - self.sc.parallelize(ed).saveAsSequenceFile(basepath + "/sfdouble/") - doubles = sorted(self.sc.sequenceFile(basepath + "/sfdouble/").collect()) - self.assertEqual(doubles, ed) - - ebs = [(1, bytearray(b'\x00\x07spam\x08')), (2, bytearray(b'\x00\x07spam\x08'))] - self.sc.parallelize(ebs).saveAsSequenceFile(basepath + "/sfbytes/") - bytes = sorted(self.sc.sequenceFile(basepath + "/sfbytes/").collect()) - self.assertEqual(bytes, ebs) - - et = [(u'1', u'aa'), - (u'2', u'bb'), - (u'3', u'cc')] - self.sc.parallelize(et).saveAsSequenceFile(basepath + "/sftext/") - text = sorted(self.sc.sequenceFile(basepath + "/sftext/").collect()) - self.assertEqual(text, et) - - eb = [(1, False), (1, True), (2, False), (2, False), (2, True), (3, True)] - self.sc.parallelize(eb).saveAsSequenceFile(basepath + "/sfbool/") - bools = sorted(self.sc.sequenceFile(basepath + "/sfbool/").collect()) - self.assertEqual(bools, eb) - - en = [(1, None), (1, None), (2, None), (2, None), (2, None), (3, None)] - self.sc.parallelize(en).saveAsSequenceFile(basepath + "/sfnull/") - nulls = sorted(self.sc.sequenceFile(basepath + "/sfnull/").collect()) - self.assertEqual(nulls, en) - - em = [(1, {}), - (1, {3.0: u'bb'}), - (2, {1.0: u'aa'}), - (2, {1.0: u'cc'}), - (3, {2.0: u'dd'})] - self.sc.parallelize(em).saveAsSequenceFile(basepath + "/sfmap/") - maps = self.sc.sequenceFile(basepath + "/sfmap/").collect() - for v in maps: - self.assertTrue(v, em) - - def test_oldhadoop(self): - basepath = self.tempdir.name - dict_data = [(1, {}), - (1, {"row1": 1.0}), - (2, {"row2": 2.0})] - self.sc.parallelize(dict_data).saveAsHadoopFile( - basepath + "/oldhadoop/", - "org.apache.hadoop.mapred.SequenceFileOutputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.MapWritable") - result = self.sc.hadoopFile( - basepath + "/oldhadoop/", - "org.apache.hadoop.mapred.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.MapWritable").collect() - for v in result: - self.assertTrue(v, dict_data) - - conf = { - "mapred.output.format.class": "org.apache.hadoop.mapred.SequenceFileOutputFormat", - "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable", - "mapreduce.job.output.value.class": "org.apache.hadoop.io.MapWritable", - "mapreduce.output.fileoutputformat.outputdir": basepath + "/olddataset/" - } - self.sc.parallelize(dict_data).saveAsHadoopDataset(conf) - input_conf = {"mapreduce.input.fileinputformat.inputdir": basepath + "/olddataset/"} - result = self.sc.hadoopRDD( - "org.apache.hadoop.mapred.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.MapWritable", - conf=input_conf).collect() - for v in result: - self.assertTrue(v, dict_data) - - def 
test_newhadoop(self): - basepath = self.tempdir.name - data = [(1, ""), - (1, "a"), - (2, "bcdf")] - self.sc.parallelize(data).saveAsNewAPIHadoopFile( - basepath + "/newhadoop/", - "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text") - result = sorted(self.sc.newAPIHadoopFile( - basepath + "/newhadoop/", - "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text").collect()) - self.assertEqual(result, data) - - conf = { - "mapreduce.job.outputformat.class": - "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", - "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable", - "mapreduce.job.output.value.class": "org.apache.hadoop.io.Text", - "mapreduce.output.fileoutputformat.outputdir": basepath + "/newdataset/" - } - self.sc.parallelize(data).saveAsNewAPIHadoopDataset(conf) - input_conf = {"mapreduce.input.fileinputformat.inputdir": basepath + "/newdataset/"} - new_dataset = sorted(self.sc.newAPIHadoopRDD( - "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text", - conf=input_conf).collect()) - self.assertEqual(new_dataset, data) - - @unittest.skipIf(sys.version >= "3", "serialize of array") - def test_newhadoop_with_array(self): - basepath = self.tempdir.name - # use custom ArrayWritable types and converters to handle arrays - array_data = [(1, array('d')), - (1, array('d', [1.0, 2.0, 3.0])), - (2, array('d', [3.0, 4.0, 5.0]))] - self.sc.parallelize(array_data).saveAsNewAPIHadoopFile( - basepath + "/newhadoop/", - "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.spark.api.python.DoubleArrayWritable", - valueConverter="org.apache.spark.api.python.DoubleArrayToWritableConverter") - result = sorted(self.sc.newAPIHadoopFile( - basepath + "/newhadoop/", - "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.spark.api.python.DoubleArrayWritable", - valueConverter="org.apache.spark.api.python.WritableToDoubleArrayConverter").collect()) - self.assertEqual(result, array_data) - - conf = { - "mapreduce.job.outputformat.class": - "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", - "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable", - "mapreduce.job.output.value.class": "org.apache.spark.api.python.DoubleArrayWritable", - "mapreduce.output.fileoutputformat.outputdir": basepath + "/newdataset/" - } - self.sc.parallelize(array_data).saveAsNewAPIHadoopDataset( - conf, - valueConverter="org.apache.spark.api.python.DoubleArrayToWritableConverter") - input_conf = {"mapreduce.input.fileinputformat.inputdir": basepath + "/newdataset/"} - new_dataset = sorted(self.sc.newAPIHadoopRDD( - "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.spark.api.python.DoubleArrayWritable", - valueConverter="org.apache.spark.api.python.WritableToDoubleArrayConverter", - conf=input_conf).collect()) - self.assertEqual(new_dataset, array_data) - - def test_newolderror(self): - basepath = self.tempdir.name - rdd = self.sc.parallelize(range(1, 4)).map(lambda x: (x, "a" * x)) - self.assertRaises(Exception, lambda: rdd.saveAsHadoopFile( - basepath + "/newolderror/saveAsHadoopFile/", - 
"org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat")) - self.assertRaises(Exception, lambda: rdd.saveAsNewAPIHadoopFile( - basepath + "/newolderror/saveAsNewAPIHadoopFile/", - "org.apache.hadoop.mapred.SequenceFileOutputFormat")) - - def test_bad_inputs(self): - basepath = self.tempdir.name - rdd = self.sc.parallelize(range(1, 4)).map(lambda x: (x, "a" * x)) - self.assertRaises(Exception, lambda: rdd.saveAsHadoopFile( - basepath + "/badinputs/saveAsHadoopFile/", - "org.apache.hadoop.mapred.NotValidOutputFormat")) - self.assertRaises(Exception, lambda: rdd.saveAsNewAPIHadoopFile( - basepath + "/badinputs/saveAsNewAPIHadoopFile/", - "org.apache.hadoop.mapreduce.lib.output.NotValidOutputFormat")) - - def test_converters(self): - # use of custom converters - basepath = self.tempdir.name - data = [(1, {3.0: u'bb'}), - (2, {1.0: u'aa'}), - (3, {2.0: u'dd'})] - self.sc.parallelize(data).saveAsNewAPIHadoopFile( - basepath + "/converters/", - "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", - keyConverter="org.apache.spark.api.python.TestOutputKeyConverter", - valueConverter="org.apache.spark.api.python.TestOutputValueConverter") - converted = sorted(self.sc.sequenceFile(basepath + "/converters/").collect()) - expected = [(u'1', 3.0), - (u'2', 1.0), - (u'3', 2.0)] - self.assertEqual(converted, expected) - - def test_reserialization(self): - basepath = self.tempdir.name - x = range(1, 5) - y = range(1001, 1005) - data = list(zip(x, y)) - rdd = self.sc.parallelize(x).zip(self.sc.parallelize(y)) - rdd.saveAsSequenceFile(basepath + "/reserialize/sequence") - result1 = sorted(self.sc.sequenceFile(basepath + "/reserialize/sequence").collect()) - self.assertEqual(result1, data) - - rdd.saveAsHadoopFile( - basepath + "/reserialize/hadoop", - "org.apache.hadoop.mapred.SequenceFileOutputFormat") - result2 = sorted(self.sc.sequenceFile(basepath + "/reserialize/hadoop").collect()) - self.assertEqual(result2, data) - - rdd.saveAsNewAPIHadoopFile( - basepath + "/reserialize/newhadoop", - "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat") - result3 = sorted(self.sc.sequenceFile(basepath + "/reserialize/newhadoop").collect()) - self.assertEqual(result3, data) - - conf4 = { - "mapred.output.format.class": "org.apache.hadoop.mapred.SequenceFileOutputFormat", - "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable", - "mapreduce.job.output.value.class": "org.apache.hadoop.io.IntWritable", - "mapreduce.output.fileoutputformat.outputdir": basepath + "/reserialize/dataset"} - rdd.saveAsHadoopDataset(conf4) - result4 = sorted(self.sc.sequenceFile(basepath + "/reserialize/dataset").collect()) - self.assertEqual(result4, data) - - conf5 = {"mapreduce.job.outputformat.class": - "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", - "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable", - "mapreduce.job.output.value.class": "org.apache.hadoop.io.IntWritable", - "mapreduce.output.fileoutputformat.outputdir": basepath + "/reserialize/newdataset" - } - rdd.saveAsNewAPIHadoopDataset(conf5) - result5 = sorted(self.sc.sequenceFile(basepath + "/reserialize/newdataset").collect()) - self.assertEqual(result5, data) - - def test_malformed_RDD(self): - basepath = self.tempdir.name - # non-batch-serialized RDD[[(K, V)]] should be rejected - data = [[(1, "a")], [(2, "aa")], [(3, "aaa")]] - rdd = self.sc.parallelize(data, len(data)) - self.assertRaises(Exception, lambda: rdd.saveAsSequenceFile( - basepath + "/malformed/sequence")) - - -class 
DaemonTests(unittest.TestCase): - def connect(self, port): - from socket import socket, AF_INET, SOCK_STREAM - sock = socket(AF_INET, SOCK_STREAM) - sock.connect(('127.0.0.1', port)) - # send a split index of -1 to shutdown the worker - sock.send(b"\xFF\xFF\xFF\xFF") - sock.close() - return True - - def do_termination_test(self, terminator): - from subprocess import Popen, PIPE - from errno import ECONNREFUSED - - # start daemon - daemon_path = os.path.join(os.path.dirname(__file__), "daemon.py") - python_exec = sys.executable or os.environ.get("PYSPARK_PYTHON") - daemon = Popen([python_exec, daemon_path], stdin=PIPE, stdout=PIPE) - - # read the port number - port = read_int(daemon.stdout) - - # daemon should accept connections - self.assertTrue(self.connect(port)) - - # request shutdown - terminator(daemon) - time.sleep(1) - - # daemon should no longer accept connections - try: - self.connect(port) - except EnvironmentError as exception: - self.assertEqual(exception.errno, ECONNREFUSED) - else: - self.fail("Expected EnvironmentError to be raised") - - def test_termination_stdin(self): - """Ensure that daemon and workers terminate when stdin is closed.""" - self.do_termination_test(lambda daemon: daemon.stdin.close()) - - def test_termination_sigterm(self): - """Ensure that daemon and workers terminate on SIGTERM.""" - from signal import SIGTERM - self.do_termination_test(lambda daemon: os.kill(daemon.pid, SIGTERM)) - - -class WorkerTests(ReusedPySparkTestCase): - def test_cancel_task(self): - temp = tempfile.NamedTemporaryFile(delete=True) - temp.close() - path = temp.name - - def sleep(x): - import os - import time - with open(path, 'w') as f: - f.write("%d %d" % (os.getppid(), os.getpid())) - time.sleep(100) - - # start job in background thread - def run(): - try: - self.sc.parallelize(range(1), 1).foreach(sleep) - except Exception: - pass - import threading - t = threading.Thread(target=run) - t.daemon = True - t.start() - - daemon_pid, worker_pid = 0, 0 - while True: - if os.path.exists(path): - with open(path) as f: - data = f.read().split(' ') - daemon_pid, worker_pid = map(int, data) - break - time.sleep(0.1) - - # cancel jobs - self.sc.cancelAllJobs() - t.join() - - for i in range(50): - try: - os.kill(worker_pid, 0) - time.sleep(0.1) - except OSError: - break # worker was killed - else: - self.fail("worker has not been killed after 5 seconds") - - try: - os.kill(daemon_pid, 0) - except OSError: - self.fail("daemon had been killed") - - # run a normal job - rdd = self.sc.parallelize(xrange(100), 1) - self.assertEqual(100, rdd.map(str).count()) - - def test_after_exception(self): - def raise_exception(_): - raise Exception() - rdd = self.sc.parallelize(xrange(100), 1) - with QuietTest(self.sc): - self.assertRaises(Exception, lambda: rdd.foreach(raise_exception)) - self.assertEqual(100, rdd.map(str).count()) - - def test_after_jvm_exception(self): - tempFile = tempfile.NamedTemporaryFile(delete=False) - tempFile.write(b"Hello World!") - tempFile.close() - data = self.sc.textFile(tempFile.name, 1) - filtered_data = data.filter(lambda x: True) - self.assertEqual(1, filtered_data.count()) - os.unlink(tempFile.name) - with QuietTest(self.sc): - self.assertRaises(Exception, lambda: filtered_data.count()) - - rdd = self.sc.parallelize(xrange(100), 1) - self.assertEqual(100, rdd.map(str).count()) - - def test_accumulator_when_reuse_worker(self): - from pyspark.accumulators import INT_ACCUMULATOR_PARAM - acc1 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM) - self.sc.parallelize(xrange(100), 
20).foreach(lambda x: acc1.add(x)) - self.assertEqual(sum(range(100)), acc1.value) - - acc2 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM) - self.sc.parallelize(xrange(100), 20).foreach(lambda x: acc2.add(x)) - self.assertEqual(sum(range(100)), acc2.value) - self.assertEqual(sum(range(100)), acc1.value) - - def test_reuse_worker_after_take(self): - rdd = self.sc.parallelize(xrange(100000), 1) - self.assertEqual(0, rdd.first()) - - def count(): - try: - rdd.count() - except Exception: - pass - - t = threading.Thread(target=count) - t.daemon = True - t.start() - t.join(5) - self.assertTrue(not t.isAlive()) - self.assertEqual(100000, rdd.count()) - - def test_with_different_versions_of_python(self): - rdd = self.sc.parallelize(range(10)) - rdd.count() - version = self.sc.pythonVer - self.sc.pythonVer = "2.0" - try: - with QuietTest(self.sc): - self.assertRaises(Py4JJavaError, lambda: rdd.count()) - finally: - self.sc.pythonVer = version - - -class SparkSubmitTests(unittest.TestCase): - - def setUp(self): - self.programDir = tempfile.mkdtemp() - tmp_dir = tempfile.gettempdir() - self.sparkSubmit = [ - os.path.join(os.environ.get("SPARK_HOME"), "bin", "spark-submit"), - "--conf", "spark.driver.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir), - "--conf", "spark.executor.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir), - ] - - def tearDown(self): - shutil.rmtree(self.programDir) - - def createTempFile(self, name, content, dir=None): - """ - Create a temp file with the given name and content and return its path. - Strips leading spaces from content up to the first '|' in each line. - """ - pattern = re.compile(r'^ *\|', re.MULTILINE) - content = re.sub(pattern, '', content.strip()) - if dir is None: - path = os.path.join(self.programDir, name) - else: - os.makedirs(os.path.join(self.programDir, dir)) - path = os.path.join(self.programDir, dir, name) - with open(path, "w") as f: - f.write(content) - return path - - def createFileInZip(self, name, content, ext=".zip", dir=None, zip_name=None): - """ - Create a zip archive containing a file with the given content and return its path. - Strips leading spaces from content up to the first '|' in each line. 
- """ - pattern = re.compile(r'^ *\|', re.MULTILINE) - content = re.sub(pattern, '', content.strip()) - if dir is None: - path = os.path.join(self.programDir, name + ext) - else: - path = os.path.join(self.programDir, dir, zip_name + ext) - zip = zipfile.ZipFile(path, 'w') - zip.writestr(name, content) - zip.close() - return path - - def create_spark_package(self, artifact_name): - group_id, artifact_id, version = artifact_name.split(":") - self.createTempFile("%s-%s.pom" % (artifact_id, version), (""" - | - | - | 4.0.0 - | %s - | %s - | %s - | - """ % (group_id, artifact_id, version)).lstrip(), - os.path.join(group_id, artifact_id, version)) - self.createFileInZip("%s.py" % artifact_id, """ - |def myfunc(x): - | return x + 1 - """, ".jar", os.path.join(group_id, artifact_id, version), - "%s-%s" % (artifact_id, version)) - - def test_single_script(self): - """Submit and test a single script file""" - script = self.createTempFile("test.py", """ - |from pyspark import SparkContext - | - |sc = SparkContext() - |print(sc.parallelize([1, 2, 3]).map(lambda x: x * 2).collect()) - """) - proc = subprocess.Popen(self.sparkSubmit + [script], stdout=subprocess.PIPE) - out, err = proc.communicate() - self.assertEqual(0, proc.returncode) - self.assertIn("[2, 4, 6]", out.decode('utf-8')) - - def test_script_with_local_functions(self): - """Submit and test a single script file calling a global function""" - script = self.createTempFile("test.py", """ - |from pyspark import SparkContext - | - |def foo(x): - | return x * 3 - | - |sc = SparkContext() - |print(sc.parallelize([1, 2, 3]).map(foo).collect()) - """) - proc = subprocess.Popen(self.sparkSubmit + [script], stdout=subprocess.PIPE) - out, err = proc.communicate() - self.assertEqual(0, proc.returncode) - self.assertIn("[3, 6, 9]", out.decode('utf-8')) - - def test_module_dependency(self): - """Submit and test a script with a dependency on another module""" - script = self.createTempFile("test.py", """ - |from pyspark import SparkContext - |from mylib import myfunc - | - |sc = SparkContext() - |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) - """) - zip = self.createFileInZip("mylib.py", """ - |def myfunc(x): - | return x + 1 - """) - proc = subprocess.Popen(self.sparkSubmit + ["--py-files", zip, script], - stdout=subprocess.PIPE) - out, err = proc.communicate() - self.assertEqual(0, proc.returncode) - self.assertIn("[2, 3, 4]", out.decode('utf-8')) - - def test_module_dependency_on_cluster(self): - """Submit and test a script with a dependency on another module on a cluster""" - script = self.createTempFile("test.py", """ - |from pyspark import SparkContext - |from mylib import myfunc - | - |sc = SparkContext() - |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) - """) - zip = self.createFileInZip("mylib.py", """ - |def myfunc(x): - | return x + 1 - """) - proc = subprocess.Popen(self.sparkSubmit + ["--py-files", zip, "--master", - "local-cluster[1,1,1024]", script], - stdout=subprocess.PIPE) - out, err = proc.communicate() - self.assertEqual(0, proc.returncode) - self.assertIn("[2, 3, 4]", out.decode('utf-8')) - - def test_package_dependency(self): - """Submit and test a script with a dependency on a Spark Package""" - script = self.createTempFile("test.py", """ - |from pyspark import SparkContext - |from mylib import myfunc - | - |sc = SparkContext() - |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) - """) - self.create_spark_package("a:mylib:0.1") - proc = subprocess.Popen( - self.sparkSubmit + ["--packages", "a:mylib:0.1", 
"--repositories", - "file:" + self.programDir, script], - stdout=subprocess.PIPE) - out, err = proc.communicate() - self.assertEqual(0, proc.returncode) - self.assertIn("[2, 3, 4]", out.decode('utf-8')) - - def test_package_dependency_on_cluster(self): - """Submit and test a script with a dependency on a Spark Package on a cluster""" - script = self.createTempFile("test.py", """ - |from pyspark import SparkContext - |from mylib import myfunc - | - |sc = SparkContext() - |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) - """) - self.create_spark_package("a:mylib:0.1") - proc = subprocess.Popen( - self.sparkSubmit + ["--packages", "a:mylib:0.1", "--repositories", - "file:" + self.programDir, "--master", "local-cluster[1,1,1024]", - script], - stdout=subprocess.PIPE) - out, err = proc.communicate() - self.assertEqual(0, proc.returncode) - self.assertIn("[2, 3, 4]", out.decode('utf-8')) - - def test_single_script_on_cluster(self): - """Submit and test a single script on a cluster""" - script = self.createTempFile("test.py", """ - |from pyspark import SparkContext - | - |def foo(x): - | return x * 2 - | - |sc = SparkContext() - |print(sc.parallelize([1, 2, 3]).map(foo).collect()) - """) - # this will fail if you have different spark.executor.memory - # in conf/spark-defaults.conf - proc = subprocess.Popen( - self.sparkSubmit + ["--master", "local-cluster[1,1,1024]", script], - stdout=subprocess.PIPE) - out, err = proc.communicate() - self.assertEqual(0, proc.returncode) - self.assertIn("[2, 4, 6]", out.decode('utf-8')) - - def test_user_configuration(self): - """Make sure user configuration is respected (SPARK-19307)""" - script = self.createTempFile("test.py", """ - |from pyspark import SparkConf, SparkContext - | - |conf = SparkConf().set("spark.test_config", "1") - |sc = SparkContext(conf = conf) - |try: - | if sc._conf.get("spark.test_config") != "1": - | raise Exception("Cannot find spark.test_config in SparkContext's conf.") - |finally: - | sc.stop() - """) - proc = subprocess.Popen( - self.sparkSubmit + ["--master", "local", script], - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT) - out, err = proc.communicate() - self.assertEqual(0, proc.returncode, msg="Process failed with error:\n {0}".format(out)) - - def test_conda(self): - """Submit and test a single script file via conda""" - script = self.createTempFile("test.py", """ - |from pyspark import SparkContext - | - |sc = SparkContext() - |sc.addCondaPackages('numpy=1.14.0') - | - |# Ensure numpy is accessible on the driver - |import numpy - |arr = [1, 2, 3] - |def mul2(x): - | # Also ensure numpy accessible from executor - | assert numpy.version.version == "1.14.0" - | return x * 2 - |print(sc.parallelize(arr).map(mul2).collect()) - """) - props = self.createTempFile("properties", """ - |spark.conda.binaryPath {} - |spark.conda.channelUrls https://repo.continuum.io/pkgs/main - |spark.conda.bootstrapPackages python=3.5 - """.format(os.environ["CONDA_BIN"])) - env = dict(os.environ) - del env['PYSPARK_PYTHON'] - del env['PYSPARK_DRIVER_PYTHON'] - proc = subprocess.Popen(self.sparkSubmit + [ - "--properties-file", props, script], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - env=env) - out, err = proc.communicate() - if 0 != proc.returncode: - self.fail(("spark-submit was unsuccessful with error code {}\n\n" + - "stdout:\n{}\n\nstderr:\n{}").format(proc.returncode, out, err)) - self.assertIn("[2, 4, 6]", out.decode('utf-8')) - - -class ContextTests(unittest.TestCase): - - def test_failed_sparkcontext_creation(self): - 
# Regression test for SPARK-1550 - self.assertRaises(Exception, lambda: SparkContext("an-invalid-master-name")) - - def test_get_or_create(self): - with SparkContext.getOrCreate() as sc: - self.assertTrue(SparkContext.getOrCreate() is sc) - - def test_parallelize_eager_cleanup(self): - with SparkContext() as sc: - temp_files = os.listdir(sc._temp_dir) - rdd = sc.parallelize([0, 1, 2]) - post_parallalize_temp_files = os.listdir(sc._temp_dir) - self.assertEqual(temp_files, post_parallalize_temp_files) - - def test_set_conf(self): - # This is for an internal use case. When there is an existing SparkContext, - # SparkSession's builder needs to set configs into SparkContext's conf. - sc = SparkContext() - sc._conf.set("spark.test.SPARK16224", "SPARK16224") - self.assertEqual(sc._jsc.sc().conf().get("spark.test.SPARK16224"), "SPARK16224") - sc.stop() - - def test_stop(self): - sc = SparkContext() - self.assertNotEqual(SparkContext._active_spark_context, None) - sc.stop() - self.assertEqual(SparkContext._active_spark_context, None) - - def test_with(self): - with SparkContext() as sc: - self.assertNotEqual(SparkContext._active_spark_context, None) - self.assertEqual(SparkContext._active_spark_context, None) - - def test_with_exception(self): - try: - with SparkContext() as sc: - self.assertNotEqual(SparkContext._active_spark_context, None) - raise Exception() - except: - pass - self.assertEqual(SparkContext._active_spark_context, None) - - def test_with_stop(self): - with SparkContext() as sc: - self.assertNotEqual(SparkContext._active_spark_context, None) - sc.stop() - self.assertEqual(SparkContext._active_spark_context, None) - - def test_progress_api(self): - with SparkContext() as sc: - sc.setJobGroup('test_progress_api', '', True) - rdd = sc.parallelize(range(10)).map(lambda x: time.sleep(100)) - - def run(): - try: - rdd.count() - except Exception: - pass - t = threading.Thread(target=run) - t.daemon = True - t.start() - # wait for scheduler to start - time.sleep(1) - - tracker = sc.statusTracker() - jobIds = tracker.getJobIdsForGroup('test_progress_api') - self.assertEqual(1, len(jobIds)) - job = tracker.getJobInfo(jobIds[0]) - self.assertEqual(1, len(job.stageIds)) - stage = tracker.getStageInfo(job.stageIds[0]) - self.assertEqual(rdd.getNumPartitions(), stage.numTasks) - - sc.cancelAllJobs() - t.join() - # wait for event listener to update the status - time.sleep(1) - - job = tracker.getJobInfo(jobIds[0]) - self.assertEqual('FAILED', job.status) - self.assertEqual([], tracker.getActiveJobsIds()) - self.assertEqual([], tracker.getActiveStageIds()) - - sc.stop() - - def test_startTime(self): - with SparkContext() as sc: - self.assertGreater(sc.startTime, 0) - - -class ConfTests(unittest.TestCase): - def test_memory_conf(self): - memoryList = ["1T", "1G", "1M", "1024K"] - for memory in memoryList: - sc = SparkContext(conf=SparkConf().set("spark.python.worker.memory", memory)) - l = list(range(1024)) - random.shuffle(l) - rdd = sc.parallelize(l, 4) - self.assertEqual(sorted(l), rdd.sortBy(lambda x: x).collect()) - sc.stop() - - -class KeywordOnlyTests(unittest.TestCase): - class Wrapped(object): - @keyword_only - def set(self, x=None, y=None): - if "x" in self._input_kwargs: - self._x = self._input_kwargs["x"] - if "y" in self._input_kwargs: - self._y = self._input_kwargs["y"] - return x, y - - def test_keywords(self): - w = self.Wrapped() - x, y = w.set(y=1) - self.assertEqual(y, 1) - self.assertEqual(y, w._y) - self.assertIsNone(x) - self.assertFalse(hasattr(w, "_x")) - - def 
test_non_keywords(self): - w = self.Wrapped() - self.assertRaises(TypeError, lambda: w.set(0, y=1)) - - def test_kwarg_ownership(self): - # test _input_kwargs is owned by each class instance and not a shared static variable - class Setter(object): - @keyword_only - def set(self, x=None, other=None, other_x=None): - if "other" in self._input_kwargs: - self._input_kwargs["other"].set(x=self._input_kwargs["other_x"]) - self._x = self._input_kwargs["x"] - - a = Setter() - b = Setter() - a.set(x=1, other=b, other_x=2) - self.assertEqual(a._x, 1) - self.assertEqual(b._x, 2) - - -class UtilTests(PySparkTestCase): - def test_py4j_exception_message(self): - from pyspark.util import _exception_message - - with self.assertRaises(Py4JJavaError) as context: - # This attempts java.lang.String(null) which throws an NPE. - self.sc._jvm.java.lang.String(None) - - self.assertTrue('NullPointerException' in _exception_message(context.exception)) - - def test_parsing_version_string(self): - from pyspark.util import VersionUtils - self.assertRaises(ValueError, lambda: VersionUtils.majorMinorVersion("abced")) - - -@unittest.skipIf(not _have_scipy, "SciPy not installed") -class SciPyTests(PySparkTestCase): - - """General PySpark tests that depend on scipy """ - - def test_serialize(self): - from scipy.special import gammaln - x = range(1, 5) - expected = list(map(gammaln, x)) - observed = self.sc.parallelize(x).map(gammaln).collect() - self.assertEqual(expected, observed) - - -@unittest.skipIf(not _have_numpy, "NumPy not installed") -class NumPyTests(PySparkTestCase): - - """General PySpark tests that depend on numpy """ - - def test_statcounter_array(self): - x = self.sc.parallelize([np.array([1.0, 1.0]), np.array([2.0, 2.0]), np.array([3.0, 3.0])]) - s = x.stats() - self.assertSequenceEqual([2.0, 2.0], s.mean().tolist()) - self.assertSequenceEqual([1.0, 1.0], s.min().tolist()) - self.assertSequenceEqual([3.0, 3.0], s.max().tolist()) - self.assertSequenceEqual([1.0, 1.0], s.sampleStdev().tolist()) - - stats_dict = s.asDict() - self.assertEqual(3, stats_dict['count']) - self.assertSequenceEqual([2.0, 2.0], stats_dict['mean'].tolist()) - self.assertSequenceEqual([1.0, 1.0], stats_dict['min'].tolist()) - self.assertSequenceEqual([3.0, 3.0], stats_dict['max'].tolist()) - self.assertSequenceEqual([6.0, 6.0], stats_dict['sum'].tolist()) - self.assertSequenceEqual([1.0, 1.0], stats_dict['stdev'].tolist()) - self.assertSequenceEqual([1.0, 1.0], stats_dict['variance'].tolist()) - - stats_sample_dict = s.asDict(sample=True) - self.assertEqual(3, stats_dict['count']) - self.assertSequenceEqual([2.0, 2.0], stats_sample_dict['mean'].tolist()) - self.assertSequenceEqual([1.0, 1.0], stats_sample_dict['min'].tolist()) - self.assertSequenceEqual([3.0, 3.0], stats_sample_dict['max'].tolist()) - self.assertSequenceEqual([6.0, 6.0], stats_sample_dict['sum'].tolist()) - self.assertSequenceEqual( - [0.816496580927726, 0.816496580927726], stats_sample_dict['stdev'].tolist()) - self.assertSequenceEqual( - [0.6666666666666666, 0.6666666666666666], stats_sample_dict['variance'].tolist()) - - -if __name__ == "__main__": - from pyspark.tests import * - runner = unishark.BufferedTestRunner( - reporters=[unishark.XUnitReporter('target/test-reports/pyspark_{}'.format( - os.path.basename(os.environ.get("PYSPARK_PYTHON", ""))))]) - unittest.main(testRunner=runner, verbosity=2) diff --git a/python/pyspark/tests/__init__.py b/python/pyspark/tests/__init__.py new file mode 100644 index 0000000000000..12bdf0d0175b6 --- /dev/null +++ 
b/python/pyspark/tests/__init__.py @@ -0,0 +1,16 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/python/pyspark/tests/test_appsubmit.py b/python/pyspark/tests/test_appsubmit.py new file mode 100644 index 0000000000000..92bcb11561307 --- /dev/null +++ b/python/pyspark/tests/test_appsubmit.py @@ -0,0 +1,248 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import re +import shutil +import subprocess +import tempfile +import unittest +import zipfile + + +class SparkSubmitTests(unittest.TestCase): + + def setUp(self): + self.programDir = tempfile.mkdtemp() + tmp_dir = tempfile.gettempdir() + self.sparkSubmit = [ + os.path.join(os.environ.get("SPARK_HOME"), "bin", "spark-submit"), + "--conf", "spark.driver.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir), + "--conf", "spark.executor.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir), + ] + + def tearDown(self): + shutil.rmtree(self.programDir) + + def createTempFile(self, name, content, dir=None): + """ + Create a temp file with the given name and content and return its path. + Strips leading spaces from content up to the first '|' in each line. + """ + pattern = re.compile(r'^ *\|', re.MULTILINE) + content = re.sub(pattern, '', content.strip()) + if dir is None: + path = os.path.join(self.programDir, name) + else: + os.makedirs(os.path.join(self.programDir, dir)) + path = os.path.join(self.programDir, dir, name) + with open(path, "w") as f: + f.write(content) + return path + + def createFileInZip(self, name, content, ext=".zip", dir=None, zip_name=None): + """ + Create a zip archive containing a file with the given content and return its path. + Strips leading spaces from content up to the first '|' in each line. 
+ """ + pattern = re.compile(r'^ *\|', re.MULTILINE) + content = re.sub(pattern, '', content.strip()) + if dir is None: + path = os.path.join(self.programDir, name + ext) + else: + path = os.path.join(self.programDir, dir, zip_name + ext) + zip = zipfile.ZipFile(path, 'w') + zip.writestr(name, content) + zip.close() + return path + + def create_spark_package(self, artifact_name): + group_id, artifact_id, version = artifact_name.split(":") + self.createTempFile("%s-%s.pom" % (artifact_id, version), (""" + | + | + | 4.0.0 + | %s + | %s + | %s + | + """ % (group_id, artifact_id, version)).lstrip(), + os.path.join(group_id, artifact_id, version)) + self.createFileInZip("%s.py" % artifact_id, """ + |def myfunc(x): + | return x + 1 + """, ".jar", os.path.join(group_id, artifact_id, version), + "%s-%s" % (artifact_id, version)) + + def test_single_script(self): + """Submit and test a single script file""" + script = self.createTempFile("test.py", """ + |from pyspark import SparkContext + | + |sc = SparkContext() + |print(sc.parallelize([1, 2, 3]).map(lambda x: x * 2).collect()) + """) + proc = subprocess.Popen(self.sparkSubmit + [script], stdout=subprocess.PIPE) + out, err = proc.communicate() + self.assertEqual(0, proc.returncode) + self.assertIn("[2, 4, 6]", out.decode('utf-8')) + + def test_script_with_local_functions(self): + """Submit and test a single script file calling a global function""" + script = self.createTempFile("test.py", """ + |from pyspark import SparkContext + | + |def foo(x): + | return x * 3 + | + |sc = SparkContext() + |print(sc.parallelize([1, 2, 3]).map(foo).collect()) + """) + proc = subprocess.Popen(self.sparkSubmit + [script], stdout=subprocess.PIPE) + out, err = proc.communicate() + self.assertEqual(0, proc.returncode) + self.assertIn("[3, 6, 9]", out.decode('utf-8')) + + def test_module_dependency(self): + """Submit and test a script with a dependency on another module""" + script = self.createTempFile("test.py", """ + |from pyspark import SparkContext + |from mylib import myfunc + | + |sc = SparkContext() + |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) + """) + zip = self.createFileInZip("mylib.py", """ + |def myfunc(x): + | return x + 1 + """) + proc = subprocess.Popen(self.sparkSubmit + ["--py-files", zip, script], + stdout=subprocess.PIPE) + out, err = proc.communicate() + self.assertEqual(0, proc.returncode) + self.assertIn("[2, 3, 4]", out.decode('utf-8')) + + def test_module_dependency_on_cluster(self): + """Submit and test a script with a dependency on another module on a cluster""" + script = self.createTempFile("test.py", """ + |from pyspark import SparkContext + |from mylib import myfunc + | + |sc = SparkContext() + |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) + """) + zip = self.createFileInZip("mylib.py", """ + |def myfunc(x): + | return x + 1 + """) + proc = subprocess.Popen(self.sparkSubmit + ["--py-files", zip, "--master", + "local-cluster[1,1,1024]", script], + stdout=subprocess.PIPE) + out, err = proc.communicate() + self.assertEqual(0, proc.returncode) + self.assertIn("[2, 3, 4]", out.decode('utf-8')) + + def test_package_dependency(self): + """Submit and test a script with a dependency on a Spark Package""" + script = self.createTempFile("test.py", """ + |from pyspark import SparkContext + |from mylib import myfunc + | + |sc = SparkContext() + |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) + """) + self.create_spark_package("a:mylib:0.1") + proc = subprocess.Popen( + self.sparkSubmit + ["--packages", "a:mylib:0.1", 
"--repositories", + "file:" + self.programDir, script], + stdout=subprocess.PIPE) + out, err = proc.communicate() + self.assertEqual(0, proc.returncode) + self.assertIn("[2, 3, 4]", out.decode('utf-8')) + + def test_package_dependency_on_cluster(self): + """Submit and test a script with a dependency on a Spark Package on a cluster""" + script = self.createTempFile("test.py", """ + |from pyspark import SparkContext + |from mylib import myfunc + | + |sc = SparkContext() + |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) + """) + self.create_spark_package("a:mylib:0.1") + proc = subprocess.Popen( + self.sparkSubmit + ["--packages", "a:mylib:0.1", "--repositories", + "file:" + self.programDir, "--master", "local-cluster[1,1,1024]", + script], + stdout=subprocess.PIPE) + out, err = proc.communicate() + self.assertEqual(0, proc.returncode) + self.assertIn("[2, 3, 4]", out.decode('utf-8')) + + def test_single_script_on_cluster(self): + """Submit and test a single script on a cluster""" + script = self.createTempFile("test.py", """ + |from pyspark import SparkContext + | + |def foo(x): + | return x * 2 + | + |sc = SparkContext() + |print(sc.parallelize([1, 2, 3]).map(foo).collect()) + """) + # this will fail if you have different spark.executor.memory + # in conf/spark-defaults.conf + proc = subprocess.Popen( + self.sparkSubmit + ["--master", "local-cluster[1,1,1024]", script], + stdout=subprocess.PIPE) + out, err = proc.communicate() + self.assertEqual(0, proc.returncode) + self.assertIn("[2, 4, 6]", out.decode('utf-8')) + + def test_user_configuration(self): + """Make sure user configuration is respected (SPARK-19307)""" + script = self.createTempFile("test.py", """ + |from pyspark import SparkConf, SparkContext + | + |conf = SparkConf().set("spark.test_config", "1") + |sc = SparkContext(conf = conf) + |try: + | if sc._conf.get("spark.test_config") != "1": + | raise Exception("Cannot find spark.test_config in SparkContext's conf.") + |finally: + | sc.stop() + """) + proc = subprocess.Popen( + self.sparkSubmit + ["--master", "local", script], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT) + out, err = proc.communicate() + self.assertEqual(0, proc.returncode, msg="Process failed with error:\n {0}".format(out)) + + +if __name__ == "__main__": + from pyspark.tests.test_appsubmit import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/test_broadcast.py b/python/pyspark/tests/test_broadcast.py similarity index 91% rename from python/pyspark/test_broadcast.py rename to python/pyspark/tests/test_broadcast.py index a00329c18ad8f..a98626e8f4bc9 100644 --- a/python/pyspark/test_broadcast.py +++ b/python/pyspark/tests/test_broadcast.py @@ -14,20 +14,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# - import os import random import tempfile import unittest -try: - import xmlrunner -except ImportError: - xmlrunner = None - -from pyspark.broadcast import Broadcast -from pyspark.conf import SparkConf -from pyspark.context import SparkContext +from pyspark import SparkConf, SparkContext from pyspark.java_gateway import launch_gateway from pyspark.serializers import ChunkedStream @@ -118,9 +110,13 @@ def random_bytes(n): for buffer_length in [1, 2, 5, 8192]: self._test_chunked_stream(random_bytes(data_length), buffer_length) + if __name__ == '__main__': - from pyspark.test_broadcast import * - if xmlrunner: - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) - else: - unittest.main(verbosity=2) + from pyspark.tests.test_broadcast import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/tests/test_conf.py b/python/pyspark/tests/test_conf.py new file mode 100644 index 0000000000000..f5a9accc3fe6e --- /dev/null +++ b/python/pyspark/tests/test_conf.py @@ -0,0 +1,43 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import random +import unittest + +from pyspark import SparkContext, SparkConf + + +class ConfTests(unittest.TestCase): + def test_memory_conf(self): + memoryList = ["1T", "1G", "1M", "1024K"] + for memory in memoryList: + sc = SparkContext(conf=SparkConf().set("spark.python.worker.memory", memory)) + l = list(range(1024)) + random.shuffle(l) + rdd = sc.parallelize(l, 4) + self.assertEqual(sorted(l), rdd.sortBy(lambda x: x).collect()) + sc.stop() + + +if __name__ == "__main__": + from pyspark.tests.test_conf import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/tests/test_context.py b/python/pyspark/tests/test_context.py new file mode 100644 index 0000000000000..201baf420354d --- /dev/null +++ b/python/pyspark/tests/test_context.py @@ -0,0 +1,258 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +import shutil +import tempfile +import threading +import time +import unittest + +from pyspark import SparkFiles, SparkContext +from pyspark.testing.utils import ReusedPySparkTestCase, PySparkTestCase, QuietTest, SPARK_HOME + + +class CheckpointTests(ReusedPySparkTestCase): + + def setUp(self): + self.checkpointDir = tempfile.NamedTemporaryFile(delete=False) + os.unlink(self.checkpointDir.name) + self.sc.setCheckpointDir(self.checkpointDir.name) + + def tearDown(self): + shutil.rmtree(self.checkpointDir.name) + + def test_basic_checkpointing(self): + parCollection = self.sc.parallelize([1, 2, 3, 4]) + flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1)) + + self.assertFalse(flatMappedRDD.isCheckpointed()) + self.assertTrue(flatMappedRDD.getCheckpointFile() is None) + + flatMappedRDD.checkpoint() + result = flatMappedRDD.collect() + time.sleep(1) # 1 second + self.assertTrue(flatMappedRDD.isCheckpointed()) + self.assertEqual(flatMappedRDD.collect(), result) + self.assertEqual("file:" + self.checkpointDir.name, + os.path.dirname(os.path.dirname(flatMappedRDD.getCheckpointFile()))) + + def test_checkpoint_and_restore(self): + parCollection = self.sc.parallelize([1, 2, 3, 4]) + flatMappedRDD = parCollection.flatMap(lambda x: [x]) + + self.assertFalse(flatMappedRDD.isCheckpointed()) + self.assertTrue(flatMappedRDD.getCheckpointFile() is None) + + flatMappedRDD.checkpoint() + flatMappedRDD.count() # forces a checkpoint to be computed + time.sleep(1) # 1 second + + self.assertTrue(flatMappedRDD.getCheckpointFile() is not None) + recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile(), + flatMappedRDD._jrdd_deserializer) + self.assertEqual([1, 2, 3, 4], recovered.collect()) + + +class LocalCheckpointTests(ReusedPySparkTestCase): + + def test_basic_localcheckpointing(self): + parCollection = self.sc.parallelize([1, 2, 3, 4]) + flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1)) + + self.assertFalse(flatMappedRDD.isCheckpointed()) + self.assertFalse(flatMappedRDD.isLocallyCheckpointed()) + + flatMappedRDD.localCheckpoint() + result = flatMappedRDD.collect() + time.sleep(1) # 1 second + self.assertTrue(flatMappedRDD.isCheckpointed()) + self.assertTrue(flatMappedRDD.isLocallyCheckpointed()) + self.assertEqual(flatMappedRDD.collect(), result) + + +class AddFileTests(PySparkTestCase): + + def test_add_py_file(self): + # To ensure that we're actually testing addPyFile's effects, check that + # this job fails due to `userlibrary` not being on the Python path: + # disable logging in log4j temporarily + def func(x): + from userlibrary import UserClass + return UserClass().hello() + with QuietTest(self.sc): + self.assertRaises(Exception, self.sc.parallelize(range(2)).map(func).first) + + # Add the file, so the job should now succeed: + path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py") + self.sc.addPyFile(path) + res = self.sc.parallelize(range(2)).map(func).first() + self.assertEqual("Hello World!", res) + + def test_add_file_locally(self): + path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") + 
self.sc.addFile(path) + download_path = SparkFiles.get("hello.txt") + self.assertNotEqual(path, download_path) + with open(download_path) as test_file: + self.assertEqual("Hello World!\n", test_file.readline()) + + def test_add_file_recursively_locally(self): + path = os.path.join(SPARK_HOME, "python/test_support/hello") + self.sc.addFile(path, True) + download_path = SparkFiles.get("hello") + self.assertNotEqual(path, download_path) + with open(download_path + "/hello.txt") as test_file: + self.assertEqual("Hello World!\n", test_file.readline()) + with open(download_path + "/sub_hello/sub_hello.txt") as test_file: + self.assertEqual("Sub Hello World!\n", test_file.readline()) + + def test_add_py_file_locally(self): + # To ensure that we're actually testing addPyFile's effects, check that + # this fails due to `userlibrary` not being on the Python path: + def func(): + from userlibrary import UserClass + self.assertRaises(ImportError, func) + path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py") + self.sc.addPyFile(path) + from userlibrary import UserClass + self.assertEqual("Hello World!", UserClass().hello()) + + def test_add_egg_file_locally(self): + # To ensure that we're actually testing addPyFile's effects, check that + # this fails due to `userlibrary` not being on the Python path: + def func(): + from userlib import UserClass + self.assertRaises(ImportError, func) + path = os.path.join(SPARK_HOME, "python/test_support/userlib-0.1.zip") + self.sc.addPyFile(path) + from userlib import UserClass + self.assertEqual("Hello World from inside a package!", UserClass().hello()) + + def test_overwrite_system_module(self): + self.sc.addPyFile(os.path.join(SPARK_HOME, "python/test_support/SimpleHTTPServer.py")) + + import SimpleHTTPServer + self.assertEqual("My Server", SimpleHTTPServer.__name__) + + def func(x): + import SimpleHTTPServer + return SimpleHTTPServer.__name__ + + self.assertEqual(["My Server"], self.sc.parallelize(range(1)).map(func).collect()) + + +class ContextTests(unittest.TestCase): + + def test_failed_sparkcontext_creation(self): + # Regression test for SPARK-1550 + self.assertRaises(Exception, lambda: SparkContext("an-invalid-master-name")) + + def test_get_or_create(self): + with SparkContext.getOrCreate() as sc: + self.assertTrue(SparkContext.getOrCreate() is sc) + + def test_parallelize_eager_cleanup(self): + with SparkContext() as sc: + temp_files = os.listdir(sc._temp_dir) + rdd = sc.parallelize([0, 1, 2]) + post_parallalize_temp_files = os.listdir(sc._temp_dir) + self.assertEqual(temp_files, post_parallalize_temp_files) + + def test_set_conf(self): + # This is for an internal use case. When there is an existing SparkContext, + # SparkSession's builder needs to set configs into SparkContext's conf. 
+ sc = SparkContext() + sc._conf.set("spark.test.SPARK16224", "SPARK16224") + self.assertEqual(sc._jsc.sc().conf().get("spark.test.SPARK16224"), "SPARK16224") + sc.stop() + + def test_stop(self): + sc = SparkContext() + self.assertNotEqual(SparkContext._active_spark_context, None) + sc.stop() + self.assertEqual(SparkContext._active_spark_context, None) + + def test_with(self): + with SparkContext() as sc: + self.assertNotEqual(SparkContext._active_spark_context, None) + self.assertEqual(SparkContext._active_spark_context, None) + + def test_with_exception(self): + try: + with SparkContext() as sc: + self.assertNotEqual(SparkContext._active_spark_context, None) + raise Exception() + except: + pass + self.assertEqual(SparkContext._active_spark_context, None) + + def test_with_stop(self): + with SparkContext() as sc: + self.assertNotEqual(SparkContext._active_spark_context, None) + sc.stop() + self.assertEqual(SparkContext._active_spark_context, None) + + def test_progress_api(self): + with SparkContext() as sc: + sc.setJobGroup('test_progress_api', '', True) + rdd = sc.parallelize(range(10)).map(lambda x: time.sleep(100)) + + def run(): + try: + rdd.count() + except Exception: + pass + t = threading.Thread(target=run) + t.daemon = True + t.start() + # wait for scheduler to start + time.sleep(1) + + tracker = sc.statusTracker() + jobIds = tracker.getJobIdsForGroup('test_progress_api') + self.assertEqual(1, len(jobIds)) + job = tracker.getJobInfo(jobIds[0]) + self.assertEqual(1, len(job.stageIds)) + stage = tracker.getStageInfo(job.stageIds[0]) + self.assertEqual(rdd.getNumPartitions(), stage.numTasks) + + sc.cancelAllJobs() + t.join() + # wait for event listener to update the status + time.sleep(1) + + job = tracker.getJobInfo(jobIds[0]) + self.assertEqual('FAILED', job.status) + self.assertEqual([], tracker.getActiveJobsIds()) + self.assertEqual([], tracker.getActiveStageIds()) + + sc.stop() + + def test_startTime(self): + with SparkContext() as sc: + self.assertGreater(sc.startTime, 0) + + +if __name__ == "__main__": + from pyspark.tests.test_context import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/tests/test_daemon.py b/python/pyspark/tests/test_daemon.py new file mode 100644 index 0000000000000..fccd74fff1516 --- /dev/null +++ b/python/pyspark/tests/test_daemon.py @@ -0,0 +1,80 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import os +import sys +import time +import unittest + +from pyspark.serializers import read_int + + +class DaemonTests(unittest.TestCase): + def connect(self, port): + from socket import socket, AF_INET, SOCK_STREAM + sock = socket(AF_INET, SOCK_STREAM) + sock.connect(('127.0.0.1', port)) + # send a split index of -1 to shutdown the worker + sock.send(b"\xFF\xFF\xFF\xFF") + sock.close() + return True + + def do_termination_test(self, terminator): + from subprocess import Popen, PIPE + from errno import ECONNREFUSED + + # start daemon + daemon_path = os.path.join(os.path.dirname(__file__), "..", "daemon.py") + python_exec = sys.executable or os.environ.get("PYSPARK_PYTHON") + daemon = Popen([python_exec, daemon_path], stdin=PIPE, stdout=PIPE) + + # read the port number + port = read_int(daemon.stdout) + + # daemon should accept connections + self.assertTrue(self.connect(port)) + + # request shutdown + terminator(daemon) + time.sleep(1) + + # daemon should no longer accept connections + try: + self.connect(port) + except EnvironmentError as exception: + self.assertEqual(exception.errno, ECONNREFUSED) + else: + self.fail("Expected EnvironmentError to be raised") + + def test_termination_stdin(self): + """Ensure that daemon and workers terminate when stdin is closed.""" + self.do_termination_test(lambda daemon: daemon.stdin.close()) + + def test_termination_sigterm(self): + """Ensure that daemon and workers terminate on SIGTERM.""" + from signal import SIGTERM + self.do_termination_test(lambda daemon: os.kill(daemon.pid, SIGTERM)) + + +if __name__ == "__main__": + from pyspark.tests.test_daemon import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/tests/test_join.py b/python/pyspark/tests/test_join.py new file mode 100644 index 0000000000000..e97e695f8b20d --- /dev/null +++ b/python/pyspark/tests/test_join.py @@ -0,0 +1,69 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from pyspark.testing.utils import ReusedPySparkTestCase + + +class JoinTests(ReusedPySparkTestCase): + + def test_narrow_dependency_in_join(self): + rdd = self.sc.parallelize(range(10)).map(lambda x: (x, x)) + parted = rdd.partitionBy(2) + self.assertEqual(2, parted.union(parted).getNumPartitions()) + self.assertEqual(rdd.getNumPartitions() + 2, parted.union(rdd).getNumPartitions()) + self.assertEqual(rdd.getNumPartitions() + 2, rdd.union(parted).getNumPartitions()) + + tracker = self.sc.statusTracker() + + self.sc.setJobGroup("test1", "test", True) + d = sorted(parted.join(parted).collect()) + self.assertEqual(10, len(d)) + self.assertEqual((0, (0, 0)), d[0]) + jobId = tracker.getJobIdsForGroup("test1")[0] + self.assertEqual(2, len(tracker.getJobInfo(jobId).stageIds)) + + self.sc.setJobGroup("test2", "test", True) + d = sorted(parted.join(rdd).collect()) + self.assertEqual(10, len(d)) + self.assertEqual((0, (0, 0)), d[0]) + jobId = tracker.getJobIdsForGroup("test2")[0] + self.assertEqual(3, len(tracker.getJobInfo(jobId).stageIds)) + + self.sc.setJobGroup("test3", "test", True) + d = sorted(parted.cogroup(parted).collect()) + self.assertEqual(10, len(d)) + self.assertEqual([[0], [0]], list(map(list, d[0][1]))) + jobId = tracker.getJobIdsForGroup("test3")[0] + self.assertEqual(2, len(tracker.getJobInfo(jobId).stageIds)) + + self.sc.setJobGroup("test4", "test", True) + d = sorted(parted.cogroup(rdd).collect()) + self.assertEqual(10, len(d)) + self.assertEqual([[0], [0]], list(map(list, d[0][1]))) + jobId = tracker.getJobIdsForGroup("test4")[0] + self.assertEqual(3, len(tracker.getJobInfo(jobId).stageIds)) + + +if __name__ == "__main__": + import unittest + from pyspark.tests.test_join import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/tests/test_profiler.py b/python/pyspark/tests/test_profiler.py new file mode 100644 index 0000000000000..56cbcff01657c --- /dev/null +++ b/python/pyspark/tests/test_profiler.py @@ -0,0 +1,112 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import os +import sys +import tempfile +import unittest + +from pyspark import SparkConf, SparkContext, BasicProfiler +from pyspark.testing.utils import PySparkTestCase + +if sys.version >= "3": + from io import StringIO +else: + from StringIO import StringIO + + +class ProfilerTests(PySparkTestCase): + + def setUp(self): + self._old_sys_path = list(sys.path) + class_name = self.__class__.__name__ + conf = SparkConf().set("spark.python.profile", "true") + self.sc = SparkContext('local[4]', class_name, conf=conf) + + def test_profiler(self): + self.do_computation() + + profilers = self.sc.profiler_collector.profilers + self.assertEqual(1, len(profilers)) + id, profiler, _ = profilers[0] + stats = profiler.stats() + self.assertTrue(stats is not None) + width, stat_list = stats.get_print_list([]) + func_names = [func_name for fname, n, func_name in stat_list] + self.assertTrue("heavy_foo" in func_names) + + old_stdout = sys.stdout + sys.stdout = io = StringIO() + self.sc.show_profiles() + self.assertTrue("heavy_foo" in io.getvalue()) + sys.stdout = old_stdout + + d = tempfile.gettempdir() + self.sc.dump_profiles(d) + self.assertTrue("rdd_%d.pstats" % id in os.listdir(d)) + + def test_custom_profiler(self): + class TestCustomProfiler(BasicProfiler): + def show(self, id): + self.result = "Custom formatting" + + self.sc.profiler_collector.profiler_cls = TestCustomProfiler + + self.do_computation() + + profilers = self.sc.profiler_collector.profilers + self.assertEqual(1, len(profilers)) + _, profiler, _ = profilers[0] + self.assertTrue(isinstance(profiler, TestCustomProfiler)) + + self.sc.show_profiles() + self.assertEqual("Custom formatting", profiler.result) + + def do_computation(self): + def heavy_foo(x): + for i in range(1 << 18): + x = 1 + + rdd = self.sc.parallelize(range(100)) + rdd.foreach(heavy_foo) + + +class ProfilerTests2(unittest.TestCase): + def test_profiler_disabled(self): + sc = SparkContext(conf=SparkConf().set("spark.python.profile", "false")) + try: + self.assertRaisesRegexp( + RuntimeError, + "'spark.python.profile' configuration must be set", + lambda: sc.show_profiles()) + self.assertRaisesRegexp( + RuntimeError, + "'spark.python.profile' configuration must be set", + lambda: sc.dump_profiles("/tmp/abc")) + finally: + sc.stop() + + +if __name__ == "__main__": + from pyspark.tests.test_profiler import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/tests/test_rdd.py b/python/pyspark/tests/test_rdd.py new file mode 100644 index 0000000000000..b2a544b8de78a --- /dev/null +++ b/python/pyspark/tests/test_rdd.py @@ -0,0 +1,739 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# +import hashlib +import os +import random +import sys +import tempfile +from glob import glob + +from py4j.protocol import Py4JJavaError + +from pyspark import shuffle, RDD +from pyspark.serializers import CloudPickleSerializer, BatchedSerializer, PickleSerializer,\ + MarshalSerializer, UTF8Deserializer, NoOpSerializer +from pyspark.testing.utils import ReusedPySparkTestCase, SPARK_HOME, QuietTest + +if sys.version_info[0] >= 3: + xrange = range + + +class RDDTests(ReusedPySparkTestCase): + + def test_range(self): + self.assertEqual(self.sc.range(1, 1).count(), 0) + self.assertEqual(self.sc.range(1, 0, -1).count(), 1) + self.assertEqual(self.sc.range(0, 1 << 40, 1 << 39).count(), 2) + + def test_id(self): + rdd = self.sc.parallelize(range(10)) + id = rdd.id() + self.assertEqual(id, rdd.id()) + rdd2 = rdd.map(str).filter(bool) + id2 = rdd2.id() + self.assertEqual(id + 1, id2) + self.assertEqual(id2, rdd2.id()) + + def test_empty_rdd(self): + rdd = self.sc.emptyRDD() + self.assertTrue(rdd.isEmpty()) + + def test_sum(self): + self.assertEqual(0, self.sc.emptyRDD().sum()) + self.assertEqual(6, self.sc.parallelize([1, 2, 3]).sum()) + + def test_to_localiterator(self): + from time import sleep + rdd = self.sc.parallelize([1, 2, 3]) + it = rdd.toLocalIterator() + sleep(5) + self.assertEqual([1, 2, 3], sorted(it)) + + rdd2 = rdd.repartition(1000) + it2 = rdd2.toLocalIterator() + sleep(5) + self.assertEqual([1, 2, 3], sorted(it2)) + + def test_save_as_textfile_with_unicode(self): + # Regression test for SPARK-970 + x = u"\u00A1Hola, mundo!" + data = self.sc.parallelize([x]) + tempFile = tempfile.NamedTemporaryFile(delete=True) + tempFile.close() + data.saveAsTextFile(tempFile.name) + raw_contents = b''.join(open(p, 'rb').read() + for p in glob(tempFile.name + "/part-0000*")) + self.assertEqual(x, raw_contents.strip().decode("utf-8")) + + def test_save_as_textfile_with_utf8(self): + x = u"\u00A1Hola, mundo!" 
+ data = self.sc.parallelize([x.encode("utf-8")]) + tempFile = tempfile.NamedTemporaryFile(delete=True) + tempFile.close() + data.saveAsTextFile(tempFile.name) + raw_contents = b''.join(open(p, 'rb').read() + for p in glob(tempFile.name + "/part-0000*")) + self.assertEqual(x, raw_contents.strip().decode('utf8')) + + def test_transforming_cartesian_result(self): + # Regression test for SPARK-1034 + rdd1 = self.sc.parallelize([1, 2]) + rdd2 = self.sc.parallelize([3, 4]) + cart = rdd1.cartesian(rdd2) + result = cart.map(lambda x_y3: x_y3[0] + x_y3[1]).collect() + + def test_transforming_pickle_file(self): + # Regression test for SPARK-2601 + data = self.sc.parallelize([u"Hello", u"World!"]) + tempFile = tempfile.NamedTemporaryFile(delete=True) + tempFile.close() + data.saveAsPickleFile(tempFile.name) + pickled_file = self.sc.pickleFile(tempFile.name) + pickled_file.map(lambda x: x).collect() + + def test_cartesian_on_textfile(self): + # Regression test for + path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") + a = self.sc.textFile(path) + result = a.cartesian(a).collect() + (x, y) = result[0] + self.assertEqual(u"Hello World!", x.strip()) + self.assertEqual(u"Hello World!", y.strip()) + + def test_cartesian_chaining(self): + # Tests for SPARK-16589 + rdd = self.sc.parallelize(range(10), 2) + self.assertSetEqual( + set(rdd.cartesian(rdd).cartesian(rdd).collect()), + set([((x, y), z) for x in range(10) for y in range(10) for z in range(10)]) + ) + + self.assertSetEqual( + set(rdd.cartesian(rdd.cartesian(rdd)).collect()), + set([(x, (y, z)) for x in range(10) for y in range(10) for z in range(10)]) + ) + + self.assertSetEqual( + set(rdd.cartesian(rdd.zip(rdd)).collect()), + set([(x, (y, y)) for x in range(10) for y in range(10)]) + ) + + def test_zip_chaining(self): + # Tests for SPARK-21985 + rdd = self.sc.parallelize('abc', 2) + self.assertSetEqual( + set(rdd.zip(rdd).zip(rdd).collect()), + set([((x, x), x) for x in 'abc']) + ) + self.assertSetEqual( + set(rdd.zip(rdd.zip(rdd)).collect()), + set([(x, (x, x)) for x in 'abc']) + ) + + def test_deleting_input_files(self): + # Regression test for SPARK-1025 + tempFile = tempfile.NamedTemporaryFile(delete=False) + tempFile.write(b"Hello World!") + tempFile.close() + data = self.sc.textFile(tempFile.name) + filtered_data = data.filter(lambda x: True) + self.assertEqual(1, filtered_data.count()) + os.unlink(tempFile.name) + with QuietTest(self.sc): + self.assertRaises(Exception, lambda: filtered_data.count()) + + def test_sampling_default_seed(self): + # Test for SPARK-3995 (default seed setting) + data = self.sc.parallelize(xrange(1000), 1) + subset = data.takeSample(False, 10) + self.assertEqual(len(subset), 10) + + def test_aggregate_mutable_zero_value(self): + # Test for SPARK-9021; uses aggregate and treeAggregate to build dict + # representing a counter of ints + # NOTE: dict is used instead of collections.Counter for Python 2.6 + # compatibility + from collections import defaultdict + + # Show that single or multiple partitions work + data1 = self.sc.range(10, numSlices=1) + data2 = self.sc.range(10, numSlices=2) + + def seqOp(x, y): + x[y] += 1 + return x + + def comboOp(x, y): + for key, val in y.items(): + x[key] += val + return x + + counts1 = data1.aggregate(defaultdict(int), seqOp, comboOp) + counts2 = data2.aggregate(defaultdict(int), seqOp, comboOp) + counts3 = data1.treeAggregate(defaultdict(int), seqOp, comboOp, 2) + counts4 = data2.treeAggregate(defaultdict(int), seqOp, comboOp, 2) + + ground_truth = 
defaultdict(int, dict((i, 1) for i in range(10))) + self.assertEqual(counts1, ground_truth) + self.assertEqual(counts2, ground_truth) + self.assertEqual(counts3, ground_truth) + self.assertEqual(counts4, ground_truth) + + def test_aggregate_by_key_mutable_zero_value(self): + # Test for SPARK-9021; uses aggregateByKey to make a pair RDD that + # contains lists of all values for each key in the original RDD + + # list(range(...)) for Python 3.x compatibility (can't use * operator + # on a range object) + # list(zip(...)) for Python 3.x compatibility (want to parallelize a + # collection, not a zip object) + tuples = list(zip(list(range(10))*2, [1]*20)) + # Show that single or multiple partitions work + data1 = self.sc.parallelize(tuples, 1) + data2 = self.sc.parallelize(tuples, 2) + + def seqOp(x, y): + x.append(y) + return x + + def comboOp(x, y): + x.extend(y) + return x + + values1 = data1.aggregateByKey([], seqOp, comboOp).collect() + values2 = data2.aggregateByKey([], seqOp, comboOp).collect() + # Sort lists to ensure clean comparison with ground_truth + values1.sort() + values2.sort() + + ground_truth = [(i, [1]*2) for i in range(10)] + self.assertEqual(values1, ground_truth) + self.assertEqual(values2, ground_truth) + + def test_fold_mutable_zero_value(self): + # Test for SPARK-9021; uses fold to merge an RDD of dict counters into + # a single dict + # NOTE: dict is used instead of collections.Counter for Python 2.6 + # compatibility + from collections import defaultdict + + counts1 = defaultdict(int, dict((i, 1) for i in range(10))) + counts2 = defaultdict(int, dict((i, 1) for i in range(3, 8))) + counts3 = defaultdict(int, dict((i, 1) for i in range(4, 7))) + counts4 = defaultdict(int, dict((i, 1) for i in range(5, 6))) + all_counts = [counts1, counts2, counts3, counts4] + # Show that single or multiple partitions work + data1 = self.sc.parallelize(all_counts, 1) + data2 = self.sc.parallelize(all_counts, 2) + + def comboOp(x, y): + for key, val in y.items(): + x[key] += val + return x + + fold1 = data1.fold(defaultdict(int), comboOp) + fold2 = data2.fold(defaultdict(int), comboOp) + + ground_truth = defaultdict(int) + for counts in all_counts: + for key, val in counts.items(): + ground_truth[key] += val + self.assertEqual(fold1, ground_truth) + self.assertEqual(fold2, ground_truth) + + def test_fold_by_key_mutable_zero_value(self): + # Test for SPARK-9021; uses foldByKey to make a pair RDD that contains + # lists of all values for each key in the original RDD + + tuples = [(i, range(i)) for i in range(10)]*2 + # Show that single or multiple partitions work + data1 = self.sc.parallelize(tuples, 1) + data2 = self.sc.parallelize(tuples, 2) + + def comboOp(x, y): + x.extend(y) + return x + + values1 = data1.foldByKey([], comboOp).collect() + values2 = data2.foldByKey([], comboOp).collect() + # Sort lists to ensure clean comparison with ground_truth + values1.sort() + values2.sort() + + # list(range(...)) for Python 3.x compatibility + ground_truth = [(i, list(range(i))*2) for i in range(10)] + self.assertEqual(values1, ground_truth) + self.assertEqual(values2, ground_truth) + + def test_aggregate_by_key(self): + data = self.sc.parallelize([(1, 1), (1, 1), (3, 2), (5, 1), (5, 3)], 2) + + def seqOp(x, y): + x.add(y) + return x + + def combOp(x, y): + x |= y + return x + + sets = dict(data.aggregateByKey(set(), seqOp, combOp).collect()) + self.assertEqual(3, len(sets)) + self.assertEqual(set([1]), sets[1]) + self.assertEqual(set([2]), sets[3]) + self.assertEqual(set([1, 3]), sets[5]) + + 
def test_itemgetter(self): + rdd = self.sc.parallelize([range(10)]) + from operator import itemgetter + self.assertEqual([1], rdd.map(itemgetter(1)).collect()) + self.assertEqual([(2, 3)], rdd.map(itemgetter(2, 3)).collect()) + + def test_namedtuple_in_rdd(self): + from collections import namedtuple + Person = namedtuple("Person", "id firstName lastName") + jon = Person(1, "Jon", "Doe") + jane = Person(2, "Jane", "Doe") + theDoes = self.sc.parallelize([jon, jane]) + self.assertEqual([jon, jane], theDoes.collect()) + + def test_large_broadcast(self): + N = 10000 + data = [[float(i) for i in range(300)] for i in range(N)] + bdata = self.sc.broadcast(data) # 27MB + m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() + self.assertEqual(N, m) + + def test_unpersist(self): + N = 1000 + data = [[float(i) for i in range(300)] for i in range(N)] + bdata = self.sc.broadcast(data) # 3MB + bdata.unpersist() + m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() + self.assertEqual(N, m) + bdata.destroy() + try: + self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() + except Exception as e: + pass + else: + raise Exception("job should fail after destroy the broadcast") + + def test_multiple_broadcasts(self): + N = 1 << 21 + b1 = self.sc.broadcast(set(range(N))) # multiple blocks in JVM + r = list(range(1 << 15)) + random.shuffle(r) + s = str(r).encode() + checksum = hashlib.md5(s).hexdigest() + b2 = self.sc.broadcast(s) + r = list(set(self.sc.parallelize(range(10), 10).map( + lambda x: (len(b1.value), hashlib.md5(b2.value).hexdigest())).collect())) + self.assertEqual(1, len(r)) + size, csum = r[0] + self.assertEqual(N, size) + self.assertEqual(checksum, csum) + + random.shuffle(r) + s = str(r).encode() + checksum = hashlib.md5(s).hexdigest() + b2 = self.sc.broadcast(s) + r = list(set(self.sc.parallelize(range(10), 10).map( + lambda x: (len(b1.value), hashlib.md5(b2.value).hexdigest())).collect())) + self.assertEqual(1, len(r)) + size, csum = r[0] + self.assertEqual(N, size) + self.assertEqual(checksum, csum) + + def test_multithread_broadcast_pickle(self): + import threading + + b1 = self.sc.broadcast(list(range(3))) + b2 = self.sc.broadcast(list(range(3))) + + def f1(): + return b1.value + + def f2(): + return b2.value + + funcs_num_pickled = {f1: None, f2: None} + + def do_pickle(f, sc): + command = (f, None, sc.serializer, sc.serializer) + ser = CloudPickleSerializer() + ser.dumps(command) + + def process_vars(sc): + broadcast_vars = list(sc._pickled_broadcast_vars) + num_pickled = len(broadcast_vars) + sc._pickled_broadcast_vars.clear() + return num_pickled + + def run(f, sc): + do_pickle(f, sc) + funcs_num_pickled[f] = process_vars(sc) + + # pickle f1, adds b1 to sc._pickled_broadcast_vars in main thread local storage + do_pickle(f1, self.sc) + + # run all for f2, should only add/count/clear b2 from worker thread local storage + t = threading.Thread(target=run, args=(f2, self.sc)) + t.start() + t.join() + + # count number of vars pickled in main thread, only b1 should be counted and cleared + funcs_num_pickled[f1] = process_vars(self.sc) + + self.assertEqual(funcs_num_pickled[f1], 1) + self.assertEqual(funcs_num_pickled[f2], 1) + self.assertEqual(len(list(self.sc._pickled_broadcast_vars)), 0) + + def test_large_closure(self): + N = 200000 + data = [float(i) for i in xrange(N)] + rdd = self.sc.parallelize(range(1), 1).map(lambda x: len(data)) + self.assertEqual(N, rdd.first()) + # regression test for SPARK-6886 + self.assertEqual(1, 
rdd.map(lambda x: (x, 1)).groupByKey().count()) + + def test_zip_with_different_serializers(self): + a = self.sc.parallelize(range(5)) + b = self.sc.parallelize(range(100, 105)) + self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)]) + a = a._reserialize(BatchedSerializer(PickleSerializer(), 2)) + b = b._reserialize(MarshalSerializer()) + self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)]) + # regression test for SPARK-4841 + path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") + t = self.sc.textFile(path) + cnt = t.count() + self.assertEqual(cnt, t.zip(t).count()) + rdd = t.map(str) + self.assertEqual(cnt, t.zip(rdd).count()) + # regression test for bug in _reserializer() + self.assertEqual(cnt, t.zip(rdd).count()) + + def test_zip_with_different_object_sizes(self): + # regress test for SPARK-5973 + a = self.sc.parallelize(xrange(10000)).map(lambda i: '*' * i) + b = self.sc.parallelize(xrange(10000, 20000)).map(lambda i: '*' * i) + self.assertEqual(10000, a.zip(b).count()) + + def test_zip_with_different_number_of_items(self): + a = self.sc.parallelize(range(5), 2) + # different number of partitions + b = self.sc.parallelize(range(100, 106), 3) + self.assertRaises(ValueError, lambda: a.zip(b)) + with QuietTest(self.sc): + # different number of batched items in JVM + b = self.sc.parallelize(range(100, 104), 2) + self.assertRaises(Exception, lambda: a.zip(b).count()) + # different number of items in one pair + b = self.sc.parallelize(range(100, 106), 2) + self.assertRaises(Exception, lambda: a.zip(b).count()) + # same total number of items, but different distributions + a = self.sc.parallelize([2, 3], 2).flatMap(range) + b = self.sc.parallelize([3, 2], 2).flatMap(range) + self.assertEqual(a.count(), b.count()) + self.assertRaises(Exception, lambda: a.zip(b).count()) + + def test_count_approx_distinct(self): + rdd = self.sc.parallelize(xrange(1000)) + self.assertTrue(950 < rdd.countApproxDistinct(0.03) < 1050) + self.assertTrue(950 < rdd.map(float).countApproxDistinct(0.03) < 1050) + self.assertTrue(950 < rdd.map(str).countApproxDistinct(0.03) < 1050) + self.assertTrue(950 < rdd.map(lambda x: (x, -x)).countApproxDistinct(0.03) < 1050) + + rdd = self.sc.parallelize([i % 20 for i in range(1000)], 7) + self.assertTrue(18 < rdd.countApproxDistinct() < 22) + self.assertTrue(18 < rdd.map(float).countApproxDistinct() < 22) + self.assertTrue(18 < rdd.map(str).countApproxDistinct() < 22) + self.assertTrue(18 < rdd.map(lambda x: (x, -x)).countApproxDistinct() < 22) + + self.assertRaises(ValueError, lambda: rdd.countApproxDistinct(0.00000001)) + + def test_histogram(self): + # empty + rdd = self.sc.parallelize([]) + self.assertEqual([0], rdd.histogram([0, 10])[1]) + self.assertEqual([0, 0], rdd.histogram([0, 4, 10])[1]) + self.assertRaises(ValueError, lambda: rdd.histogram(1)) + + # out of range + rdd = self.sc.parallelize([10.01, -0.01]) + self.assertEqual([0], rdd.histogram([0, 10])[1]) + self.assertEqual([0, 0], rdd.histogram((0, 4, 10))[1]) + + # in range with one bucket + rdd = self.sc.parallelize(range(1, 5)) + self.assertEqual([4], rdd.histogram([0, 10])[1]) + self.assertEqual([3, 1], rdd.histogram([0, 4, 10])[1]) + + # in range with one bucket exact match + self.assertEqual([4], rdd.histogram([1, 4])[1]) + + # out of range with two buckets + rdd = self.sc.parallelize([10.01, -0.01]) + self.assertEqual([0, 0], rdd.histogram([0, 5, 10])[1]) + + # out of range with two uneven buckets + rdd = 
self.sc.parallelize([10.01, -0.01]) + self.assertEqual([0, 0], rdd.histogram([0, 4, 10])[1]) + + # in range with two buckets + rdd = self.sc.parallelize([1, 2, 3, 5, 6]) + self.assertEqual([3, 2], rdd.histogram([0, 5, 10])[1]) + + # in range with two bucket and None + rdd = self.sc.parallelize([1, 2, 3, 5, 6, None, float('nan')]) + self.assertEqual([3, 2], rdd.histogram([0, 5, 10])[1]) + + # in range with two uneven buckets + rdd = self.sc.parallelize([1, 2, 3, 5, 6]) + self.assertEqual([3, 2], rdd.histogram([0, 5, 11])[1]) + + # mixed range with two uneven buckets + rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.0, 11.01]) + self.assertEqual([4, 3], rdd.histogram([0, 5, 11])[1]) + + # mixed range with four uneven buckets + rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0, 200.0, 200.1]) + self.assertEqual([4, 2, 1, 3], rdd.histogram([0.0, 5.0, 11.0, 12.0, 200.0])[1]) + + # mixed range with uneven buckets and NaN + rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, + 199.0, 200.0, 200.1, None, float('nan')]) + self.assertEqual([4, 2, 1, 3], rdd.histogram([0.0, 5.0, 11.0, 12.0, 200.0])[1]) + + # out of range with infinite buckets + rdd = self.sc.parallelize([10.01, -0.01, float('nan'), float("inf")]) + self.assertEqual([1, 2], rdd.histogram([float('-inf'), 0, float('inf')])[1]) + + # invalid buckets + self.assertRaises(ValueError, lambda: rdd.histogram([])) + self.assertRaises(ValueError, lambda: rdd.histogram([1])) + self.assertRaises(ValueError, lambda: rdd.histogram(0)) + self.assertRaises(TypeError, lambda: rdd.histogram({})) + + # without buckets + rdd = self.sc.parallelize(range(1, 5)) + self.assertEqual(([1, 4], [4]), rdd.histogram(1)) + + # without buckets single element + rdd = self.sc.parallelize([1]) + self.assertEqual(([1, 1], [1]), rdd.histogram(1)) + + # without bucket no range + rdd = self.sc.parallelize([1] * 4) + self.assertEqual(([1, 1], [4]), rdd.histogram(1)) + + # without buckets basic two + rdd = self.sc.parallelize(range(1, 5)) + self.assertEqual(([1, 2.5, 4], [2, 2]), rdd.histogram(2)) + + # without buckets with more requested than elements + rdd = self.sc.parallelize([1, 2]) + buckets = [1 + 0.2 * i for i in range(6)] + hist = [1, 0, 0, 0, 1] + self.assertEqual((buckets, hist), rdd.histogram(5)) + + # invalid RDDs + rdd = self.sc.parallelize([1, float('inf')]) + self.assertRaises(ValueError, lambda: rdd.histogram(2)) + rdd = self.sc.parallelize([float('nan')]) + self.assertRaises(ValueError, lambda: rdd.histogram(2)) + + # string + rdd = self.sc.parallelize(["ab", "ac", "b", "bd", "ef"], 2) + self.assertEqual([2, 2], rdd.histogram(["a", "b", "c"])[1]) + self.assertEqual((["ab", "ef"], [5]), rdd.histogram(1)) + self.assertRaises(TypeError, lambda: rdd.histogram(2)) + + def test_repartitionAndSortWithinPartitions_asc(self): + rdd = self.sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)], 2) + + repartitioned = rdd.repartitionAndSortWithinPartitions(2, lambda key: key % 2, True) + partitions = repartitioned.glom().collect() + self.assertEqual(partitions[0], [(0, 5), (0, 8), (2, 6)]) + self.assertEqual(partitions[1], [(1, 3), (3, 8), (3, 8)]) + + def test_repartitionAndSortWithinPartitions_desc(self): + rdd = self.sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)], 2) + + repartitioned = rdd.repartitionAndSortWithinPartitions(2, lambda key: key % 2, False) + partitions = repartitioned.glom().collect() + self.assertEqual(partitions[0], [(2, 6), (0, 5), (0, 8)]) + self.assertEqual(partitions[1], 
[(3, 8), (3, 8), (1, 3)]) + + def test_repartition_no_skewed(self): + num_partitions = 20 + a = self.sc.parallelize(range(int(1000)), 2) + l = a.repartition(num_partitions).glom().map(len).collect() + zeros = len([x for x in l if x == 0]) + self.assertTrue(zeros == 0) + l = a.coalesce(num_partitions, True).glom().map(len).collect() + zeros = len([x for x in l if x == 0]) + self.assertTrue(zeros == 0) + + def test_repartition_on_textfile(self): + path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") + rdd = self.sc.textFile(path) + result = rdd.repartition(1).collect() + self.assertEqual(u"Hello World!", result[0]) + + def test_distinct(self): + rdd = self.sc.parallelize((1, 2, 3)*10, 10) + self.assertEqual(rdd.getNumPartitions(), 10) + self.assertEqual(rdd.distinct().count(), 3) + result = rdd.distinct(5) + self.assertEqual(result.getNumPartitions(), 5) + self.assertEqual(result.count(), 3) + + def test_external_group_by_key(self): + self.sc._conf.set("spark.python.worker.memory", "1m") + N = 200001 + kv = self.sc.parallelize(xrange(N)).map(lambda x: (x % 3, x)) + gkv = kv.groupByKey().cache() + self.assertEqual(3, gkv.count()) + filtered = gkv.filter(lambda kv: kv[0] == 1) + self.assertEqual(1, filtered.count()) + self.assertEqual([(1, N // 3)], filtered.mapValues(len).collect()) + self.assertEqual([(N // 3, N // 3)], + filtered.values().map(lambda x: (len(x), len(list(x)))).collect()) + result = filtered.collect()[0][1] + self.assertEqual(N // 3, len(result)) + self.assertTrue(isinstance(result.data, shuffle.ExternalListOfList)) + + def test_sort_on_empty_rdd(self): + self.assertEqual([], self.sc.parallelize(zip([], [])).sortByKey().collect()) + + def test_sample(self): + rdd = self.sc.parallelize(range(0, 100), 4) + wo = rdd.sample(False, 0.1, 2).collect() + wo_dup = rdd.sample(False, 0.1, 2).collect() + self.assertSetEqual(set(wo), set(wo_dup)) + wr = rdd.sample(True, 0.2, 5).collect() + wr_dup = rdd.sample(True, 0.2, 5).collect() + self.assertSetEqual(set(wr), set(wr_dup)) + wo_s10 = rdd.sample(False, 0.3, 10).collect() + wo_s20 = rdd.sample(False, 0.3, 20).collect() + self.assertNotEqual(set(wo_s10), set(wo_s20)) + wr_s11 = rdd.sample(True, 0.4, 11).collect() + wr_s21 = rdd.sample(True, 0.4, 21).collect() + self.assertNotEqual(set(wr_s11), set(wr_s21)) + + def test_null_in_rdd(self): + jrdd = self.sc._jvm.PythonUtils.generateRDDWithNull(self.sc._jsc) + rdd = RDD(jrdd, self.sc, UTF8Deserializer()) + self.assertEqual([u"a", None, u"b"], rdd.collect()) + rdd = RDD(jrdd, self.sc, NoOpSerializer()) + self.assertEqual([b"a", None, b"b"], rdd.collect()) + + def test_multiple_python_java_RDD_conversions(self): + # Regression test for SPARK-5361 + data = [ + (u'1', {u'director': u'David Lean'}), + (u'2', {u'director': u'Andrew Dominik'}) + ] + data_rdd = self.sc.parallelize(data) + data_java_rdd = data_rdd._to_java_object_rdd() + data_python_rdd = self.sc._jvm.SerDeUtil.javaToPython(data_java_rdd) + converted_rdd = RDD(data_python_rdd, self.sc) + self.assertEqual(2, converted_rdd.count()) + + # conversion between python and java RDD threw exceptions + data_java_rdd = converted_rdd._to_java_object_rdd() + data_python_rdd = self.sc._jvm.SerDeUtil.javaToPython(data_java_rdd) + converted_rdd = RDD(data_python_rdd, self.sc) + self.assertEqual(2, converted_rdd.count()) + + # Regression test for SPARK-6294 + def test_take_on_jrdd(self): + rdd = self.sc.parallelize(xrange(1 << 20)).map(lambda x: str(x)) + rdd._jrdd.first() + + def 
test_sortByKey_uses_all_partitions_not_only_first_and_last(self): + # Regression test for SPARK-5969 + seq = [(i * 59 % 101, i) for i in range(101)] # unsorted sequence + rdd = self.sc.parallelize(seq) + for ascending in [True, False]: + sort = rdd.sortByKey(ascending=ascending, numPartitions=5) + self.assertEqual(sort.collect(), sorted(seq, reverse=not ascending)) + sizes = sort.glom().map(len).collect() + for size in sizes: + self.assertGreater(size, 0) + + def test_pipe_functions(self): + data = ['1', '2', '3'] + rdd = self.sc.parallelize(data) + with QuietTest(self.sc): + self.assertEqual([], rdd.pipe('cc').collect()) + self.assertRaises(Py4JJavaError, rdd.pipe('cc', checkCode=True).collect) + result = rdd.pipe('cat').collect() + result.sort() + for x, y in zip(data, result): + self.assertEqual(x, y) + self.assertRaises(Py4JJavaError, rdd.pipe('grep 4', checkCode=True).collect) + self.assertEqual([], rdd.pipe('grep 4').collect()) + + def test_pipe_unicode(self): + # Regression test for SPARK-20947 + data = [u'\u6d4b\u8bd5', '1'] + rdd = self.sc.parallelize(data) + result = rdd.pipe('cat').collect() + self.assertEqual(data, result) + + def test_stopiteration_in_user_code(self): + + def stopit(*x): + raise StopIteration() + + seq_rdd = self.sc.parallelize(range(10)) + keyed_rdd = self.sc.parallelize((x % 2, x) for x in range(10)) + msg = "Caught StopIteration thrown from user's code; failing the task" + + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.map(stopit).collect) + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.filter(stopit).collect) + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit) + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.reduce, stopit) + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.fold, 0, stopit) + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit) + self.assertRaisesRegexp(Py4JJavaError, msg, + seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect) + + # these methods call the user function both in the driver and in the executor + # the exception raised is different according to where the StopIteration happens + # RuntimeError is raised if in the driver + # Py4JJavaError is raised if in the executor (wraps the RuntimeError raised in the worker) + self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg, + keyed_rdd.reduceByKeyLocally, stopit) + self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg, + seq_rdd.aggregate, 0, stopit, lambda *x: 1) + self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg, + seq_rdd.aggregate, 0, lambda *x: 1, stopit) + + +if __name__ == "__main__": + import unittest + from pyspark.tests.test_rdd import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/tests/test_readwrite.py b/python/pyspark/tests/test_readwrite.py new file mode 100644 index 0000000000000..e45f5b371f461 --- /dev/null +++ b/python/pyspark/tests/test_readwrite.py @@ -0,0 +1,499 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +import shutil +import sys +import tempfile +import unittest +from array import array + +from pyspark.testing.utils import ReusedPySparkTestCase, SPARK_HOME + + +class InputFormatTests(ReusedPySparkTestCase): + + @classmethod + def setUpClass(cls): + ReusedPySparkTestCase.setUpClass() + cls.tempdir = tempfile.NamedTemporaryFile(delete=False) + os.unlink(cls.tempdir.name) + cls.sc._jvm.WriteInputFormatTestDataGenerator.generateData(cls.tempdir.name, cls.sc._jsc) + + @classmethod + def tearDownClass(cls): + ReusedPySparkTestCase.tearDownClass() + shutil.rmtree(cls.tempdir.name) + + @unittest.skipIf(sys.version >= "3", "serialize array of byte") + def test_sequencefiles(self): + basepath = self.tempdir.name + ints = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfint/", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text").collect()) + ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] + self.assertEqual(ints, ei) + + doubles = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfdouble/", + "org.apache.hadoop.io.DoubleWritable", + "org.apache.hadoop.io.Text").collect()) + ed = [(1.0, u'aa'), (1.0, u'aa'), (2.0, u'aa'), (2.0, u'bb'), (2.0, u'bb'), (3.0, u'cc')] + self.assertEqual(doubles, ed) + + bytes = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfbytes/", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.BytesWritable").collect()) + ebs = [(1, bytearray('aa', 'utf-8')), + (1, bytearray('aa', 'utf-8')), + (2, bytearray('aa', 'utf-8')), + (2, bytearray('bb', 'utf-8')), + (2, bytearray('bb', 'utf-8')), + (3, bytearray('cc', 'utf-8'))] + self.assertEqual(bytes, ebs) + + text = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sftext/", + "org.apache.hadoop.io.Text", + "org.apache.hadoop.io.Text").collect()) + et = [(u'1', u'aa'), + (u'1', u'aa'), + (u'2', u'aa'), + (u'2', u'bb'), + (u'2', u'bb'), + (u'3', u'cc')] + self.assertEqual(text, et) + + bools = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfbool/", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.BooleanWritable").collect()) + eb = [(1, False), (1, True), (2, False), (2, False), (2, True), (3, True)] + self.assertEqual(bools, eb) + + nulls = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfnull/", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.BooleanWritable").collect()) + en = [(1, None), (1, None), (2, None), (2, None), (2, None), (3, None)] + self.assertEqual(nulls, en) + + maps = self.sc.sequenceFile(basepath + "/sftestdata/sfmap/", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.MapWritable").collect() + em = [(1, {}), + (1, {3.0: u'bb'}), + (2, {1.0: u'aa'}), + (2, {1.0: u'cc'}), + (3, {2.0: u'dd'})] + for v in maps: + self.assertTrue(v in em) + + # arrays get pickled to tuples by default + tuples = sorted(self.sc.sequenceFile( + basepath + "/sftestdata/sfarray/", + "org.apache.hadoop.io.IntWritable", + "org.apache.spark.api.python.DoubleArrayWritable").collect()) + et = [(1, ()), + (2, (3.0, 4.0, 5.0)), + (3, (4.0, 5.0, 6.0))] + self.assertEqual(tuples, et) + + # with custom 
converters, primitive arrays can stay as arrays + arrays = sorted(self.sc.sequenceFile( + basepath + "/sftestdata/sfarray/", + "org.apache.hadoop.io.IntWritable", + "org.apache.spark.api.python.DoubleArrayWritable", + valueConverter="org.apache.spark.api.python.WritableToDoubleArrayConverter").collect()) + ea = [(1, array('d')), + (2, array('d', [3.0, 4.0, 5.0])), + (3, array('d', [4.0, 5.0, 6.0]))] + self.assertEqual(arrays, ea) + + clazz = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfclass/", + "org.apache.hadoop.io.Text", + "org.apache.spark.api.python.TestWritable").collect()) + cname = u'org.apache.spark.api.python.TestWritable' + ec = [(u'1', {u'__class__': cname, u'double': 1.0, u'int': 1, u'str': u'test1'}), + (u'2', {u'__class__': cname, u'double': 2.3, u'int': 2, u'str': u'test2'}), + (u'3', {u'__class__': cname, u'double': 3.1, u'int': 3, u'str': u'test3'}), + (u'4', {u'__class__': cname, u'double': 4.2, u'int': 4, u'str': u'test4'}), + (u'5', {u'__class__': cname, u'double': 5.5, u'int': 5, u'str': u'test56'})] + self.assertEqual(clazz, ec) + + unbatched_clazz = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfclass/", + "org.apache.hadoop.io.Text", + "org.apache.spark.api.python.TestWritable", + ).collect()) + self.assertEqual(unbatched_clazz, ec) + + def test_oldhadoop(self): + basepath = self.tempdir.name + ints = sorted(self.sc.hadoopFile(basepath + "/sftestdata/sfint/", + "org.apache.hadoop.mapred.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text").collect()) + ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] + self.assertEqual(ints, ei) + + hellopath = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") + oldconf = {"mapreduce.input.fileinputformat.inputdir": hellopath} + hello = self.sc.hadoopRDD("org.apache.hadoop.mapred.TextInputFormat", + "org.apache.hadoop.io.LongWritable", + "org.apache.hadoop.io.Text", + conf=oldconf).collect() + result = [(0, u'Hello World!')] + self.assertEqual(hello, result) + + def test_newhadoop(self): + basepath = self.tempdir.name + ints = sorted(self.sc.newAPIHadoopFile( + basepath + "/sftestdata/sfint/", + "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text").collect()) + ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] + self.assertEqual(ints, ei) + + hellopath = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") + newconf = {"mapreduce.input.fileinputformat.inputdir": hellopath} + hello = self.sc.newAPIHadoopRDD("org.apache.hadoop.mapreduce.lib.input.TextInputFormat", + "org.apache.hadoop.io.LongWritable", + "org.apache.hadoop.io.Text", + conf=newconf).collect() + result = [(0, u'Hello World!')] + self.assertEqual(hello, result) + + def test_newolderror(self): + basepath = self.tempdir.name + self.assertRaises(Exception, lambda: self.sc.hadoopFile( + basepath + "/sftestdata/sfint/", + "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text")) + + self.assertRaises(Exception, lambda: self.sc.newAPIHadoopFile( + basepath + "/sftestdata/sfint/", + "org.apache.hadoop.mapred.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text")) + + def test_bad_inputs(self): + basepath = self.tempdir.name + self.assertRaises(Exception, lambda: self.sc.sequenceFile( + basepath + "/sftestdata/sfint/", + 
"org.apache.hadoop.io.NotValidWritable", + "org.apache.hadoop.io.Text")) + self.assertRaises(Exception, lambda: self.sc.hadoopFile( + basepath + "/sftestdata/sfint/", + "org.apache.hadoop.mapred.NotValidInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text")) + self.assertRaises(Exception, lambda: self.sc.newAPIHadoopFile( + basepath + "/sftestdata/sfint/", + "org.apache.hadoop.mapreduce.lib.input.NotValidInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text")) + + def test_converters(self): + # use of custom converters + basepath = self.tempdir.name + maps = sorted(self.sc.sequenceFile( + basepath + "/sftestdata/sfmap/", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.MapWritable", + keyConverter="org.apache.spark.api.python.TestInputKeyConverter", + valueConverter="org.apache.spark.api.python.TestInputValueConverter").collect()) + em = [(u'\x01', []), + (u'\x01', [3.0]), + (u'\x02', [1.0]), + (u'\x02', [1.0]), + (u'\x03', [2.0])] + self.assertEqual(maps, em) + + def test_binary_files(self): + path = os.path.join(self.tempdir.name, "binaryfiles") + os.mkdir(path) + data = b"short binary data" + with open(os.path.join(path, "part-0000"), 'wb') as f: + f.write(data) + [(p, d)] = self.sc.binaryFiles(path).collect() + self.assertTrue(p.endswith("part-0000")) + self.assertEqual(d, data) + + def test_binary_records(self): + path = os.path.join(self.tempdir.name, "binaryrecords") + os.mkdir(path) + with open(os.path.join(path, "part-0000"), 'w') as f: + for i in range(100): + f.write('%04d' % i) + result = self.sc.binaryRecords(path, 4).map(int).collect() + self.assertEqual(list(range(100)), result) + + +class OutputFormatTests(ReusedPySparkTestCase): + + def setUp(self): + self.tempdir = tempfile.NamedTemporaryFile(delete=False) + os.unlink(self.tempdir.name) + + def tearDown(self): + shutil.rmtree(self.tempdir.name, ignore_errors=True) + + @unittest.skipIf(sys.version >= "3", "serialize array of byte") + def test_sequencefiles(self): + basepath = self.tempdir.name + ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] + self.sc.parallelize(ei).saveAsSequenceFile(basepath + "/sfint/") + ints = sorted(self.sc.sequenceFile(basepath + "/sfint/").collect()) + self.assertEqual(ints, ei) + + ed = [(1.0, u'aa'), (1.0, u'aa'), (2.0, u'aa'), (2.0, u'bb'), (2.0, u'bb'), (3.0, u'cc')] + self.sc.parallelize(ed).saveAsSequenceFile(basepath + "/sfdouble/") + doubles = sorted(self.sc.sequenceFile(basepath + "/sfdouble/").collect()) + self.assertEqual(doubles, ed) + + ebs = [(1, bytearray(b'\x00\x07spam\x08')), (2, bytearray(b'\x00\x07spam\x08'))] + self.sc.parallelize(ebs).saveAsSequenceFile(basepath + "/sfbytes/") + bytes = sorted(self.sc.sequenceFile(basepath + "/sfbytes/").collect()) + self.assertEqual(bytes, ebs) + + et = [(u'1', u'aa'), + (u'2', u'bb'), + (u'3', u'cc')] + self.sc.parallelize(et).saveAsSequenceFile(basepath + "/sftext/") + text = sorted(self.sc.sequenceFile(basepath + "/sftext/").collect()) + self.assertEqual(text, et) + + eb = [(1, False), (1, True), (2, False), (2, False), (2, True), (3, True)] + self.sc.parallelize(eb).saveAsSequenceFile(basepath + "/sfbool/") + bools = sorted(self.sc.sequenceFile(basepath + "/sfbool/").collect()) + self.assertEqual(bools, eb) + + en = [(1, None), (1, None), (2, None), (2, None), (2, None), (3, None)] + self.sc.parallelize(en).saveAsSequenceFile(basepath + "/sfnull/") + nulls = sorted(self.sc.sequenceFile(basepath + "/sfnull/").collect()) + 
self.assertEqual(nulls, en)
+
+        em = [(1, {}),
+              (1, {3.0: u'bb'}),
+              (2, {1.0: u'aa'}),
+              (2, {1.0: u'cc'}),
+              (3, {2.0: u'dd'})]
+        self.sc.parallelize(em).saveAsSequenceFile(basepath + "/sfmap/")
+        maps = self.sc.sequenceFile(basepath + "/sfmap/").collect()
+        for v in maps:
+            self.assertTrue(v in em)
+
+    def test_oldhadoop(self):
+        basepath = self.tempdir.name
+        dict_data = [(1, {}),
+                     (1, {"row1": 1.0}),
+                     (2, {"row2": 2.0})]
+        self.sc.parallelize(dict_data).saveAsHadoopFile(
+            basepath + "/oldhadoop/",
+            "org.apache.hadoop.mapred.SequenceFileOutputFormat",
+            "org.apache.hadoop.io.IntWritable",
+            "org.apache.hadoop.io.MapWritable")
+        result = self.sc.hadoopFile(
+            basepath + "/oldhadoop/",
+            "org.apache.hadoop.mapred.SequenceFileInputFormat",
+            "org.apache.hadoop.io.IntWritable",
+            "org.apache.hadoop.io.MapWritable").collect()
+        for v in result:
+            self.assertTrue(v in dict_data)
+
+        conf = {
+            "mapred.output.format.class": "org.apache.hadoop.mapred.SequenceFileOutputFormat",
+            "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable",
+            "mapreduce.job.output.value.class": "org.apache.hadoop.io.MapWritable",
+            "mapreduce.output.fileoutputformat.outputdir": basepath + "/olddataset/"
+        }
+        self.sc.parallelize(dict_data).saveAsHadoopDataset(conf)
+        input_conf = {"mapreduce.input.fileinputformat.inputdir": basepath + "/olddataset/"}
+        result = self.sc.hadoopRDD(
+            "org.apache.hadoop.mapred.SequenceFileInputFormat",
+            "org.apache.hadoop.io.IntWritable",
+            "org.apache.hadoop.io.MapWritable",
+            conf=input_conf).collect()
+        for v in result:
+            self.assertTrue(v in dict_data)
+
+    def test_newhadoop(self):
+        basepath = self.tempdir.name
+        data = [(1, ""),
+                (1, "a"),
+                (2, "bcdf")]
+        self.sc.parallelize(data).saveAsNewAPIHadoopFile(
+            basepath + "/newhadoop/",
+            "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat",
+            "org.apache.hadoop.io.IntWritable",
+            "org.apache.hadoop.io.Text")
+        result = sorted(self.sc.newAPIHadoopFile(
+            basepath + "/newhadoop/",
+            "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat",
+            "org.apache.hadoop.io.IntWritable",
+            "org.apache.hadoop.io.Text").collect())
+        self.assertEqual(result, data)
+
+        conf = {
+            "mapreduce.job.outputformat.class":
+                "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat",
+            "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable",
+            "mapreduce.job.output.value.class": "org.apache.hadoop.io.Text",
+            "mapreduce.output.fileoutputformat.outputdir": basepath + "/newdataset/"
+        }
+        self.sc.parallelize(data).saveAsNewAPIHadoopDataset(conf)
+        input_conf = {"mapreduce.input.fileinputformat.inputdir": basepath + "/newdataset/"}
+        new_dataset = sorted(self.sc.newAPIHadoopRDD(
+            "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat",
+            "org.apache.hadoop.io.IntWritable",
+            "org.apache.hadoop.io.Text",
+            conf=input_conf).collect())
+        self.assertEqual(new_dataset, data)
+
+    @unittest.skipIf(sys.version >= "3", "serialize of array")
+    def test_newhadoop_with_array(self):
+        basepath = self.tempdir.name
+        # use custom ArrayWritable types and converters to handle arrays
+        array_data = [(1, array('d')),
+                      (1, array('d', [1.0, 2.0, 3.0])),
+                      (2, array('d', [3.0, 4.0, 5.0]))]
+        self.sc.parallelize(array_data).saveAsNewAPIHadoopFile(
+            basepath + "/newhadoop/",
+            "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat",
+            "org.apache.hadoop.io.IntWritable",
+            "org.apache.spark.api.python.DoubleArrayWritable",
+            valueConverter="org.apache.spark.api.python.DoubleArrayToWritableConverter")
+        result = 
sorted(self.sc.newAPIHadoopFile( + basepath + "/newhadoop/", + "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.spark.api.python.DoubleArrayWritable", + valueConverter="org.apache.spark.api.python.WritableToDoubleArrayConverter").collect()) + self.assertEqual(result, array_data) + + conf = { + "mapreduce.job.outputformat.class": + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", + "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable", + "mapreduce.job.output.value.class": "org.apache.spark.api.python.DoubleArrayWritable", + "mapreduce.output.fileoutputformat.outputdir": basepath + "/newdataset/" + } + self.sc.parallelize(array_data).saveAsNewAPIHadoopDataset( + conf, + valueConverter="org.apache.spark.api.python.DoubleArrayToWritableConverter") + input_conf = {"mapreduce.input.fileinputformat.inputdir": basepath + "/newdataset/"} + new_dataset = sorted(self.sc.newAPIHadoopRDD( + "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.spark.api.python.DoubleArrayWritable", + valueConverter="org.apache.spark.api.python.WritableToDoubleArrayConverter", + conf=input_conf).collect()) + self.assertEqual(new_dataset, array_data) + + def test_newolderror(self): + basepath = self.tempdir.name + rdd = self.sc.parallelize(range(1, 4)).map(lambda x: (x, "a" * x)) + self.assertRaises(Exception, lambda: rdd.saveAsHadoopFile( + basepath + "/newolderror/saveAsHadoopFile/", + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat")) + self.assertRaises(Exception, lambda: rdd.saveAsNewAPIHadoopFile( + basepath + "/newolderror/saveAsNewAPIHadoopFile/", + "org.apache.hadoop.mapred.SequenceFileOutputFormat")) + + def test_bad_inputs(self): + basepath = self.tempdir.name + rdd = self.sc.parallelize(range(1, 4)).map(lambda x: (x, "a" * x)) + self.assertRaises(Exception, lambda: rdd.saveAsHadoopFile( + basepath + "/badinputs/saveAsHadoopFile/", + "org.apache.hadoop.mapred.NotValidOutputFormat")) + self.assertRaises(Exception, lambda: rdd.saveAsNewAPIHadoopFile( + basepath + "/badinputs/saveAsNewAPIHadoopFile/", + "org.apache.hadoop.mapreduce.lib.output.NotValidOutputFormat")) + + def test_converters(self): + # use of custom converters + basepath = self.tempdir.name + data = [(1, {3.0: u'bb'}), + (2, {1.0: u'aa'}), + (3, {2.0: u'dd'})] + self.sc.parallelize(data).saveAsNewAPIHadoopFile( + basepath + "/converters/", + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", + keyConverter="org.apache.spark.api.python.TestOutputKeyConverter", + valueConverter="org.apache.spark.api.python.TestOutputValueConverter") + converted = sorted(self.sc.sequenceFile(basepath + "/converters/").collect()) + expected = [(u'1', 3.0), + (u'2', 1.0), + (u'3', 2.0)] + self.assertEqual(converted, expected) + + def test_reserialization(self): + basepath = self.tempdir.name + x = range(1, 5) + y = range(1001, 1005) + data = list(zip(x, y)) + rdd = self.sc.parallelize(x).zip(self.sc.parallelize(y)) + rdd.saveAsSequenceFile(basepath + "/reserialize/sequence") + result1 = sorted(self.sc.sequenceFile(basepath + "/reserialize/sequence").collect()) + self.assertEqual(result1, data) + + rdd.saveAsHadoopFile( + basepath + "/reserialize/hadoop", + "org.apache.hadoop.mapred.SequenceFileOutputFormat") + result2 = sorted(self.sc.sequenceFile(basepath + "/reserialize/hadoop").collect()) + self.assertEqual(result2, data) + + rdd.saveAsNewAPIHadoopFile( + basepath + 
"/reserialize/newhadoop", + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat") + result3 = sorted(self.sc.sequenceFile(basepath + "/reserialize/newhadoop").collect()) + self.assertEqual(result3, data) + + conf4 = { + "mapred.output.format.class": "org.apache.hadoop.mapred.SequenceFileOutputFormat", + "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable", + "mapreduce.job.output.value.class": "org.apache.hadoop.io.IntWritable", + "mapreduce.output.fileoutputformat.outputdir": basepath + "/reserialize/dataset"} + rdd.saveAsHadoopDataset(conf4) + result4 = sorted(self.sc.sequenceFile(basepath + "/reserialize/dataset").collect()) + self.assertEqual(result4, data) + + conf5 = {"mapreduce.job.outputformat.class": + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", + "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable", + "mapreduce.job.output.value.class": "org.apache.hadoop.io.IntWritable", + "mapreduce.output.fileoutputformat.outputdir": basepath + "/reserialize/newdataset" + } + rdd.saveAsNewAPIHadoopDataset(conf5) + result5 = sorted(self.sc.sequenceFile(basepath + "/reserialize/newdataset").collect()) + self.assertEqual(result5, data) + + def test_malformed_RDD(self): + basepath = self.tempdir.name + # non-batch-serialized RDD[[(K, V)]] should be rejected + data = [[(1, "a")], [(2, "aa")], [(3, "aaa")]] + rdd = self.sc.parallelize(data, len(data)) + self.assertRaises(Exception, lambda: rdd.saveAsSequenceFile( + basepath + "/malformed/sequence")) + + +if __name__ == "__main__": + from pyspark.tests.test_readwrite import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/tests/test_serializers.py b/python/pyspark/tests/test_serializers.py new file mode 100644 index 0000000000000..bce94062c8af7 --- /dev/null +++ b/python/pyspark/tests/test_serializers.py @@ -0,0 +1,237 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import math +import sys +import unittest + +from pyspark import serializers +from pyspark.serializers import * +from pyspark.serializers import CloudPickleSerializer, CompressedSerializer, \ + AutoBatchedSerializer, BatchedSerializer, AutoSerializer, NoOpSerializer, PairDeserializer, \ + FlattenedValuesSerializer, CartesianDeserializer +from pyspark.testing.utils import PySparkTestCase, read_int, write_int, ByteArrayOutput, \ + have_numpy, have_scipy + + +class SerializationTestCase(unittest.TestCase): + + def test_namedtuple(self): + from collections import namedtuple + from pickle import dumps, loads + P = namedtuple("P", "x y") + p1 = P(1, 3) + p2 = loads(dumps(p1, 2)) + self.assertEqual(p1, p2) + + from pyspark.cloudpickle import dumps + P2 = loads(dumps(P)) + p3 = P2(1, 3) + self.assertEqual(p1, p3) + + def test_itemgetter(self): + from operator import itemgetter + ser = CloudPickleSerializer() + d = range(10) + getter = itemgetter(1) + getter2 = ser.loads(ser.dumps(getter)) + self.assertEqual(getter(d), getter2(d)) + + getter = itemgetter(0, 3) + getter2 = ser.loads(ser.dumps(getter)) + self.assertEqual(getter(d), getter2(d)) + + def test_function_module_name(self): + ser = CloudPickleSerializer() + func = lambda x: x + func2 = ser.loads(ser.dumps(func)) + self.assertEqual(func.__module__, func2.__module__) + + def test_attrgetter(self): + from operator import attrgetter + ser = CloudPickleSerializer() + + class C(object): + def __getattr__(self, item): + return item + d = C() + getter = attrgetter("a") + getter2 = ser.loads(ser.dumps(getter)) + self.assertEqual(getter(d), getter2(d)) + getter = attrgetter("a", "b") + getter2 = ser.loads(ser.dumps(getter)) + self.assertEqual(getter(d), getter2(d)) + + d.e = C() + getter = attrgetter("e.a") + getter2 = ser.loads(ser.dumps(getter)) + self.assertEqual(getter(d), getter2(d)) + getter = attrgetter("e.a", "e.b") + getter2 = ser.loads(ser.dumps(getter)) + self.assertEqual(getter(d), getter2(d)) + + # Regression test for SPARK-3415 + def test_pickling_file_handles(self): + # to be corrected with SPARK-11160 + try: + import xmlrunner + except ImportError: + ser = CloudPickleSerializer() + out1 = sys.stderr + out2 = ser.loads(ser.dumps(out1)) + self.assertEqual(out1, out2) + + def test_func_globals(self): + + class Unpicklable(object): + def __reduce__(self): + raise Exception("not picklable") + + global exit + exit = Unpicklable() + + ser = CloudPickleSerializer() + self.assertRaises(Exception, lambda: ser.dumps(exit)) + + def foo(): + sys.exit(0) + + self.assertTrue("exit" in foo.__code__.co_names) + ser.dumps(foo) + + def test_compressed_serializer(self): + ser = CompressedSerializer(PickleSerializer()) + try: + from StringIO import StringIO + except ImportError: + from io import BytesIO as StringIO + io = StringIO() + ser.dump_stream(["abc", u"123", range(5)], io) + io.seek(0) + self.assertEqual(["abc", u"123", range(5)], list(ser.load_stream(io))) + ser.dump_stream(range(1000), io) + io.seek(0) + self.assertEqual(["abc", u"123", range(5)] + list(range(1000)), list(ser.load_stream(io))) + io.close() + + def test_hash_serializer(self): + hash(NoOpSerializer()) + hash(UTF8Deserializer()) + hash(PickleSerializer()) + hash(MarshalSerializer()) + hash(AutoSerializer()) + hash(BatchedSerializer(PickleSerializer())) + hash(AutoBatchedSerializer(MarshalSerializer())) + hash(PairDeserializer(NoOpSerializer(), UTF8Deserializer())) + hash(CartesianDeserializer(NoOpSerializer(), UTF8Deserializer())) + hash(CompressedSerializer(PickleSerializer())) 
+        hash(FlattenedValuesSerializer(PickleSerializer()))
+
+
+@unittest.skipIf(not have_scipy, "SciPy not installed")
+class SciPyTests(PySparkTestCase):
+
+    """General PySpark tests that depend on scipy """
+
+    def test_serialize(self):
+        from scipy.special import gammaln
+
+        x = range(1, 5)
+        expected = list(map(gammaln, x))
+        observed = self.sc.parallelize(x).map(gammaln).collect()
+        self.assertEqual(expected, observed)
+
+
+@unittest.skipIf(not have_numpy, "NumPy not installed")
+class NumPyTests(PySparkTestCase):
+
+    """General PySpark tests that depend on numpy """
+
+    def test_statcounter_array(self):
+        import numpy as np
+
+        x = self.sc.parallelize([np.array([1.0, 1.0]), np.array([2.0, 2.0]), np.array([3.0, 3.0])])
+        s = x.stats()
+        self.assertSequenceEqual([2.0, 2.0], s.mean().tolist())
+        self.assertSequenceEqual([1.0, 1.0], s.min().tolist())
+        self.assertSequenceEqual([3.0, 3.0], s.max().tolist())
+        self.assertSequenceEqual([1.0, 1.0], s.sampleStdev().tolist())
+
+        stats_dict = s.asDict()
+        self.assertEqual(3, stats_dict['count'])
+        self.assertSequenceEqual([2.0, 2.0], stats_dict['mean'].tolist())
+        self.assertSequenceEqual([1.0, 1.0], stats_dict['min'].tolist())
+        self.assertSequenceEqual([3.0, 3.0], stats_dict['max'].tolist())
+        self.assertSequenceEqual([6.0, 6.0], stats_dict['sum'].tolist())
+        self.assertSequenceEqual([1.0, 1.0], stats_dict['stdev'].tolist())
+        self.assertSequenceEqual([1.0, 1.0], stats_dict['variance'].tolist())
+
+        stats_sample_dict = s.asDict(sample=True)
+        self.assertEqual(3, stats_sample_dict['count'])
+        self.assertSequenceEqual([2.0, 2.0], stats_sample_dict['mean'].tolist())
+        self.assertSequenceEqual([1.0, 1.0], stats_sample_dict['min'].tolist())
+        self.assertSequenceEqual([3.0, 3.0], stats_sample_dict['max'].tolist())
+        self.assertSequenceEqual([6.0, 6.0], stats_sample_dict['sum'].tolist())
+        self.assertSequenceEqual(
+            [0.816496580927726, 0.816496580927726], stats_sample_dict['stdev'].tolist())
+        self.assertSequenceEqual(
+            [0.6666666666666666, 0.6666666666666666], stats_sample_dict['variance'].tolist())
+
+
+class SerializersTest(unittest.TestCase):
+
+    def test_chunked_stream(self):
+        original_bytes = bytearray(range(100))
+        for data_length in [1, 10, 100]:
+            for buffer_length in [1, 2, 3, 5, 20, 99, 100, 101, 500]:
+                dest = ByteArrayOutput()
+                stream_out = serializers.ChunkedStream(dest, buffer_length)
+                stream_out.write(original_bytes[:data_length])
+                stream_out.close()
+                num_chunks = int(math.ceil(float(data_length) / buffer_length))
+                # length for each chunk, and a final -1 at the very end
+                exp_size = (num_chunks + 1) * 4 + data_length
+                self.assertEqual(len(dest.buffer), exp_size)
+                dest_pos = 0
+                data_pos = 0
+                for chunk_idx in range(num_chunks):
+                    chunk_length = read_int(dest.buffer[dest_pos:(dest_pos + 4)])
+                    if chunk_idx == num_chunks - 1:
+                        exp_length = data_length % buffer_length
+                        if exp_length == 0:
+                            exp_length = buffer_length
+                    else:
+                        exp_length = buffer_length
+                    self.assertEqual(chunk_length, exp_length)
+                    dest_pos += 4
+                    dest_chunk = dest.buffer[dest_pos:dest_pos + chunk_length]
+                    orig_chunk = original_bytes[data_pos:data_pos + chunk_length]
+                    self.assertEqual(dest_chunk, orig_chunk)
+                    dest_pos += chunk_length
+                    data_pos += chunk_length
+                # ends with a -1
+                self.assertEqual(dest.buffer[-4:], write_int(-1))
+
+
+if __name__ == "__main__":
+    from pyspark.tests.test_serializers import *
+
+    try:
+        import xmlrunner
+        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+    except ImportError:
+        testRunner = None
+    
unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/tests/test_shuffle.py b/python/pyspark/tests/test_shuffle.py new file mode 100644 index 0000000000000..0489426061b75 --- /dev/null +++ b/python/pyspark/tests/test_shuffle.py @@ -0,0 +1,181 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import random +import sys +import unittest + +from py4j.protocol import Py4JJavaError + +from pyspark import shuffle, PickleSerializer, SparkConf, SparkContext +from pyspark.shuffle import Aggregator, ExternalMerger, ExternalSorter + +if sys.version_info[0] >= 3: + xrange = range + + +class MergerTests(unittest.TestCase): + + def setUp(self): + self.N = 1 << 12 + self.l = [i for i in xrange(self.N)] + self.data = list(zip(self.l, self.l)) + self.agg = Aggregator(lambda x: [x], + lambda x, y: x.append(y) or x, + lambda x, y: x.extend(y) or x) + + def test_small_dataset(self): + m = ExternalMerger(self.agg, 1000) + m.mergeValues(self.data) + self.assertEqual(m.spills, 0) + self.assertEqual(sum(sum(v) for k, v in m.items()), + sum(xrange(self.N))) + + m = ExternalMerger(self.agg, 1000) + m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), self.data)) + self.assertEqual(m.spills, 0) + self.assertEqual(sum(sum(v) for k, v in m.items()), + sum(xrange(self.N))) + + def test_medium_dataset(self): + m = ExternalMerger(self.agg, 20) + m.mergeValues(self.data) + self.assertTrue(m.spills >= 1) + self.assertEqual(sum(sum(v) for k, v in m.items()), + sum(xrange(self.N))) + + m = ExternalMerger(self.agg, 10) + m.mergeCombiners(map(lambda x_y2: (x_y2[0], [x_y2[1]]), self.data * 3)) + self.assertTrue(m.spills >= 1) + self.assertEqual(sum(sum(v) for k, v in m.items()), + sum(xrange(self.N)) * 3) + + def test_huge_dataset(self): + m = ExternalMerger(self.agg, 5, partitions=3) + m.mergeCombiners(map(lambda k_v: (k_v[0], [str(k_v[1])]), self.data * 10)) + self.assertTrue(m.spills >= 1) + self.assertEqual(sum(len(v) for k, v in m.items()), + self.N * 10) + m._cleanup() + + def test_group_by_key(self): + + def gen_data(N, step): + for i in range(1, N + 1, step): + for j in range(i): + yield (i, [j]) + + def gen_gs(N, step=1): + return shuffle.GroupByKey(gen_data(N, step)) + + self.assertEqual(1, len(list(gen_gs(1)))) + self.assertEqual(2, len(list(gen_gs(2)))) + self.assertEqual(100, len(list(gen_gs(100)))) + self.assertEqual(list(range(1, 101)), [k for k, _ in gen_gs(100)]) + self.assertTrue(all(list(range(k)) == list(vs) for k, vs in gen_gs(100))) + + for k, vs in gen_gs(50002, 10000): + self.assertEqual(k, len(vs)) + self.assertEqual(list(range(k)), list(vs)) + + ser = PickleSerializer() + l = ser.loads(ser.dumps(list(gen_gs(50002, 30000)))) + for k, vs in l: + self.assertEqual(k, len(vs)) + self.assertEqual(list(range(k)), list(vs)) + + def 
test_stopiteration_is_raised(self): + + def stopit(*args, **kwargs): + raise StopIteration() + + def legit_create_combiner(x): + return [x] + + def legit_merge_value(x, y): + return x.append(y) or x + + def legit_merge_combiners(x, y): + return x.extend(y) or x + + data = [(x % 2, x) for x in range(100)] + + # wrong create combiner + m = ExternalMerger(Aggregator(stopit, legit_merge_value, legit_merge_combiners), 20) + with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: + m.mergeValues(data) + + # wrong merge value + m = ExternalMerger(Aggregator(legit_create_combiner, stopit, legit_merge_combiners), 20) + with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: + m.mergeValues(data) + + # wrong merge combiners + m = ExternalMerger(Aggregator(legit_create_combiner, legit_merge_value, stopit), 20) + with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: + m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), data)) + + +class SorterTests(unittest.TestCase): + def test_in_memory_sort(self): + l = list(range(1024)) + random.shuffle(l) + sorter = ExternalSorter(1024) + self.assertEqual(sorted(l), list(sorter.sorted(l))) + self.assertEqual(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True))) + self.assertEqual(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x))) + self.assertEqual(sorted(l, key=lambda x: -x, reverse=True), + list(sorter.sorted(l, key=lambda x: -x, reverse=True))) + + def test_external_sort(self): + class CustomizedSorter(ExternalSorter): + def _next_limit(self): + return self.memory_limit + l = list(range(1024)) + random.shuffle(l) + sorter = CustomizedSorter(1) + self.assertEqual(sorted(l), list(sorter.sorted(l))) + self.assertGreater(shuffle.DiskBytesSpilled, 0) + last = shuffle.DiskBytesSpilled + self.assertEqual(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True))) + self.assertGreater(shuffle.DiskBytesSpilled, last) + last = shuffle.DiskBytesSpilled + self.assertEqual(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x))) + self.assertGreater(shuffle.DiskBytesSpilled, last) + last = shuffle.DiskBytesSpilled + self.assertEqual(sorted(l, key=lambda x: -x, reverse=True), + list(sorter.sorted(l, key=lambda x: -x, reverse=True))) + self.assertGreater(shuffle.DiskBytesSpilled, last) + + def test_external_sort_in_rdd(self): + conf = SparkConf().set("spark.python.worker.memory", "1m") + sc = SparkContext(conf=conf) + l = list(range(10240)) + random.shuffle(l) + rdd = sc.parallelize(l, 4) + self.assertEqual(sorted(l), rdd.sortBy(lambda x: x).collect()) + sc.stop() + + +if __name__ == "__main__": + from pyspark.tests.test_shuffle import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/tests/test_taskcontext.py b/python/pyspark/tests/test_taskcontext.py new file mode 100644 index 0000000000000..b3a967440a9b2 --- /dev/null +++ b/python/pyspark/tests/test_taskcontext.py @@ -0,0 +1,161 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import random +import sys +import time + +from pyspark import SparkContext, TaskContext, BarrierTaskContext +from pyspark.testing.utils import PySparkTestCase + + +class TaskContextTests(PySparkTestCase): + + def setUp(self): + self._old_sys_path = list(sys.path) + class_name = self.__class__.__name__ + # Allow retries even though they are normally disabled in local mode + self.sc = SparkContext('local[4, 2]', class_name) + + def test_stage_id(self): + """Test the stage ids are available and incrementing as expected.""" + rdd = self.sc.parallelize(range(10)) + stage1 = rdd.map(lambda x: TaskContext.get().stageId()).take(1)[0] + stage2 = rdd.map(lambda x: TaskContext.get().stageId()).take(1)[0] + # Test using the constructor directly rather than the get() + stage3 = rdd.map(lambda x: TaskContext().stageId()).take(1)[0] + self.assertEqual(stage1 + 1, stage2) + self.assertEqual(stage1 + 2, stage3) + self.assertEqual(stage2 + 1, stage3) + + def test_partition_id(self): + """Test the partition id.""" + rdd1 = self.sc.parallelize(range(10), 1) + rdd2 = self.sc.parallelize(range(10), 2) + pids1 = rdd1.map(lambda x: TaskContext.get().partitionId()).collect() + pids2 = rdd2.map(lambda x: TaskContext.get().partitionId()).collect() + self.assertEqual(0, pids1[0]) + self.assertEqual(0, pids1[9]) + self.assertEqual(0, pids2[0]) + self.assertEqual(1, pids2[9]) + + def test_attempt_number(self): + """Verify the attempt numbers are correctly reported.""" + rdd = self.sc.parallelize(range(10)) + # Verify a simple job with no failures + attempt_numbers = rdd.map(lambda x: TaskContext.get().attemptNumber()).collect() + map(lambda attempt: self.assertEqual(0, attempt), attempt_numbers) + + def fail_on_first(x): + """Fail on the first attempt so we get a positive attempt number""" + tc = TaskContext.get() + attempt_number = tc.attemptNumber() + partition_id = tc.partitionId() + attempt_id = tc.taskAttemptId() + if attempt_number == 0 and partition_id == 0: + raise Exception("Failing on first attempt") + else: + return [x, partition_id, attempt_number, attempt_id] + result = rdd.map(fail_on_first).collect() + # We should re-submit the first partition to it but other partitions should be attempt 0 + self.assertEqual([0, 0, 1], result[0][0:3]) + self.assertEqual([9, 3, 0], result[9][0:3]) + first_partition = filter(lambda x: x[1] == 0, result) + map(lambda x: self.assertEqual(1, x[2]), first_partition) + other_partitions = filter(lambda x: x[1] != 0, result) + map(lambda x: self.assertEqual(0, x[2]), other_partitions) + # The task attempt id should be different + self.assertTrue(result[0][3] != result[9][3]) + + def test_tc_on_driver(self): + """Verify that getting the TaskContext on the driver returns None.""" + tc = TaskContext.get() + self.assertTrue(tc is None) + + def test_get_local_property(self): + """Verify that local properties set on the driver are available in TaskContext.""" + key = "testkey" + value = "testvalue" + self.sc.setLocalProperty(key, value) + try: + rdd = self.sc.parallelize(range(1), 1) + prop1 = rdd.map(lambda _: TaskContext.get().getLocalProperty(key)).collect()[0] + 
self.assertEqual(prop1, value) + prop2 = rdd.map(lambda _: TaskContext.get().getLocalProperty("otherkey")).collect()[0] + self.assertTrue(prop2 is None) + finally: + self.sc.setLocalProperty(key, None) + + def test_barrier(self): + """ + Verify that BarrierTaskContext.barrier() performs global sync among all barrier tasks + within a stage. + """ + rdd = self.sc.parallelize(range(10), 4) + + def f(iterator): + yield sum(iterator) + + def context_barrier(x): + tc = BarrierTaskContext.get() + time.sleep(random.randint(1, 10)) + tc.barrier() + return time.time() + + times = rdd.barrier().mapPartitions(f).map(context_barrier).collect() + self.assertTrue(max(times) - min(times) < 1) + + def test_barrier_with_python_worker_reuse(self): + """ + Verify that BarrierTaskContext.barrier() with reused python worker. + """ + self.sc._conf.set("spark.python.work.reuse", "true") + rdd = self.sc.parallelize(range(4), 4) + # start a normal job first to start all worker + result = rdd.map(lambda x: x ** 2).collect() + self.assertEqual([0, 1, 4, 9], result) + # make sure `spark.python.work.reuse=true` + self.assertEqual(self.sc._conf.get("spark.python.work.reuse"), "true") + + # worker will be reused in this barrier job + self.test_barrier() + + def test_barrier_infos(self): + """ + Verify that BarrierTaskContext.getTaskInfos() returns a list of all task infos in the + barrier stage. + """ + rdd = self.sc.parallelize(range(10), 4) + + def f(iterator): + yield sum(iterator) + + taskInfos = rdd.barrier().mapPartitions(f).map(lambda x: BarrierTaskContext.get() + .getTaskInfos()).collect() + self.assertTrue(len(taskInfos) == 4) + self.assertTrue(len(taskInfos[0]) == 4) + + +if __name__ == "__main__": + import unittest + from pyspark.tests.test_taskcontext import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/tests/test_util.py b/python/pyspark/tests/test_util.py new file mode 100644 index 0000000000000..11cda8fd2f5cd --- /dev/null +++ b/python/pyspark/tests/test_util.py @@ -0,0 +1,86 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
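Note: the barrier-mode tests added above exercise BarrierTaskContext.barrier(), which blocks each task in a barrier stage until every task in the stage has reached the call. Below is a minimal usage sketch, not part of the patch; it assumes a throwaway local SparkContext, and the 1-second threshold simply mirrors the assertion used in test_barrier.

import time
from pyspark import SparkContext, BarrierTaskContext

sc = SparkContext("local[4]", "barrier-sketch")

def sync_then_stamp(iterator):
    ctx = BarrierTaskContext.get()
    ctx.barrier()                 # every task in the stage blocks here until all arrive
    yield (ctx.partitionId(), time.time())

stamps = sc.parallelize(range(8), 4).barrier().mapPartitions(sync_then_stamp).collect()
assert max(t for _, t in stamps) - min(t for _, t in stamps) < 1
sc.stop()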
+# +import unittest + +from py4j.protocol import Py4JJavaError + +from pyspark import keyword_only +from pyspark.testing.utils import PySparkTestCase + + +class KeywordOnlyTests(unittest.TestCase): + class Wrapped(object): + @keyword_only + def set(self, x=None, y=None): + if "x" in self._input_kwargs: + self._x = self._input_kwargs["x"] + if "y" in self._input_kwargs: + self._y = self._input_kwargs["y"] + return x, y + + def test_keywords(self): + w = self.Wrapped() + x, y = w.set(y=1) + self.assertEqual(y, 1) + self.assertEqual(y, w._y) + self.assertIsNone(x) + self.assertFalse(hasattr(w, "_x")) + + def test_non_keywords(self): + w = self.Wrapped() + self.assertRaises(TypeError, lambda: w.set(0, y=1)) + + def test_kwarg_ownership(self): + # test _input_kwargs is owned by each class instance and not a shared static variable + class Setter(object): + @keyword_only + def set(self, x=None, other=None, other_x=None): + if "other" in self._input_kwargs: + self._input_kwargs["other"].set(x=self._input_kwargs["other_x"]) + self._x = self._input_kwargs["x"] + + a = Setter() + b = Setter() + a.set(x=1, other=b, other_x=2) + self.assertEqual(a._x, 1) + self.assertEqual(b._x, 2) + + +class UtilTests(PySparkTestCase): + def test_py4j_exception_message(self): + from pyspark.util import _exception_message + + with self.assertRaises(Py4JJavaError) as context: + # This attempts java.lang.String(null) which throws an NPE. + self.sc._jvm.java.lang.String(None) + + self.assertTrue('NullPointerException' in _exception_message(context.exception)) + + def test_parsing_version_string(self): + from pyspark.util import VersionUtils + self.assertRaises(ValueError, lambda: VersionUtils.majorMinorVersion("abced")) + + +if __name__ == "__main__": + from pyspark.tests.test_util import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/tests/test_worker.py b/python/pyspark/tests/test_worker.py new file mode 100644 index 0000000000000..a33b77d983419 --- /dev/null +++ b/python/pyspark/tests/test_worker.py @@ -0,0 +1,157 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
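Note: the KeywordOnlyTests added above rely on pyspark's keyword_only decorator, which rejects positional arguments (other than self) and stashes the keyword arguments on the instance as _input_kwargs. A minimal sketch of that pattern follows; HasThreshold is a hypothetical example class for illustration only, not part of pyspark.

from pyspark import keyword_only

class HasThreshold(object):
    @keyword_only
    def setParams(self, threshold=None, label=None):
        # keyword_only records the passed keyword arguments in self._input_kwargs
        kwargs = self._input_kwargs
        if "threshold" in kwargs:
            self._threshold = kwargs["threshold"]
        if "label" in kwargs:
            self._label = kwargs["label"]
        return self

h = HasThreshold().setParams(threshold=0.5)
assert h._threshold == 0.5 and not hasattr(h, "_label")
# Passing a positional value, e.g. HasThreshold().setParams(0.5), raises TypeError.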
+# +import os +import sys +import tempfile +import threading +import time + +from py4j.protocol import Py4JJavaError + +from pyspark.testing.utils import ReusedPySparkTestCase, QuietTest + +if sys.version_info[0] >= 3: + xrange = range + + +class WorkerTests(ReusedPySparkTestCase): + def test_cancel_task(self): + temp = tempfile.NamedTemporaryFile(delete=True) + temp.close() + path = temp.name + + def sleep(x): + import os + import time + with open(path, 'w') as f: + f.write("%d %d" % (os.getppid(), os.getpid())) + time.sleep(100) + + # start job in background thread + def run(): + try: + self.sc.parallelize(range(1), 1).foreach(sleep) + except Exception: + pass + import threading + t = threading.Thread(target=run) + t.daemon = True + t.start() + + daemon_pid, worker_pid = 0, 0 + while True: + if os.path.exists(path): + with open(path) as f: + data = f.read().split(' ') + daemon_pid, worker_pid = map(int, data) + break + time.sleep(0.1) + + # cancel jobs + self.sc.cancelAllJobs() + t.join() + + for i in range(50): + try: + os.kill(worker_pid, 0) + time.sleep(0.1) + except OSError: + break # worker was killed + else: + self.fail("worker has not been killed after 5 seconds") + + try: + os.kill(daemon_pid, 0) + except OSError: + self.fail("daemon had been killed") + + # run a normal job + rdd = self.sc.parallelize(xrange(100), 1) + self.assertEqual(100, rdd.map(str).count()) + + def test_after_exception(self): + def raise_exception(_): + raise Exception() + rdd = self.sc.parallelize(xrange(100), 1) + with QuietTest(self.sc): + self.assertRaises(Exception, lambda: rdd.foreach(raise_exception)) + self.assertEqual(100, rdd.map(str).count()) + + def test_after_jvm_exception(self): + tempFile = tempfile.NamedTemporaryFile(delete=False) + tempFile.write(b"Hello World!") + tempFile.close() + data = self.sc.textFile(tempFile.name, 1) + filtered_data = data.filter(lambda x: True) + self.assertEqual(1, filtered_data.count()) + os.unlink(tempFile.name) + with QuietTest(self.sc): + self.assertRaises(Exception, lambda: filtered_data.count()) + + rdd = self.sc.parallelize(xrange(100), 1) + self.assertEqual(100, rdd.map(str).count()) + + def test_accumulator_when_reuse_worker(self): + from pyspark.accumulators import INT_ACCUMULATOR_PARAM + acc1 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM) + self.sc.parallelize(xrange(100), 20).foreach(lambda x: acc1.add(x)) + self.assertEqual(sum(range(100)), acc1.value) + + acc2 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM) + self.sc.parallelize(xrange(100), 20).foreach(lambda x: acc2.add(x)) + self.assertEqual(sum(range(100)), acc2.value) + self.assertEqual(sum(range(100)), acc1.value) + + def test_reuse_worker_after_take(self): + rdd = self.sc.parallelize(xrange(100000), 1) + self.assertEqual(0, rdd.first()) + + def count(): + try: + rdd.count() + except Exception: + pass + + t = threading.Thread(target=count) + t.daemon = True + t.start() + t.join(5) + self.assertTrue(not t.isAlive()) + self.assertEqual(100000, rdd.count()) + + def test_with_different_versions_of_python(self): + rdd = self.sc.parallelize(range(10)) + rdd.count() + version = self.sc.pythonVer + self.sc.pythonVer = "2.0" + try: + with QuietTest(self.sc): + self.assertRaises(Py4JJavaError, lambda: rdd.count()) + finally: + self.sc.pythonVer = version + + +if __name__ == "__main__": + import unittest + from pyspark.tests.test_worker import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + 
unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/run-tests.py b/python/run-tests.py index 9e6b7d780409a..be5a60a28def6 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -60,9 +60,7 @@ def print_red(text): LOGGER = logging.getLogger() # Find out where the assembly jars are located. -# Later, add back 2.12 to this list: -# for scala in ["2.11", "2.12"]: -for scala in ["2.11"]: +for scala in ["2.11", "2.12"]: build_dir = os.path.join(SPARK_HOME, "assembly", "target", "scala-" + scala) if os.path.isdir(build_dir): SPARK_DIST_CLASSPATH = os.path.join(build_dir, "jars", "*") @@ -263,8 +261,9 @@ def main(): for module in modules_to_test: if python_implementation not in module.blacklisted_python_implementations: for test_goal in module.python_test_goals: - if test_goal in ('pyspark.streaming.tests', 'pyspark.mllib.tests', - 'pyspark.tests', 'pyspark.sql.tests'): + heavy_tests = ['pyspark.streaming.tests', 'pyspark.mllib.tests', + 'pyspark.tests', 'pyspark.sql.tests', 'pyspark.ml.tests'] + if any(map(lambda prefix: test_goal.startswith(prefix), heavy_tests)): priority = 0 else: priority = 100 diff --git a/repl/pom.xml b/repl/pom.xml index fa015b69d45d4..c7de67e41ca94 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -20,12 +20,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../pom.xml - spark-repl_2.11 + spark-repl_2.12 jar Spark Project REPL http://spark.apache.org/ diff --git a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala index e5e2094368fb0..ac528ecb829b0 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala @@ -126,7 +126,7 @@ class ExecutorClassLoaderSuite test("child first") { val parentLoader = new URLClassLoader(urls2, null) val classLoader = new ExecutorClassLoader(new SparkConf(), null, url1, parentLoader, true) - val fakeClass = classLoader.loadClass("ReplFakeClass2").newInstance() + val fakeClass = classLoader.loadClass("ReplFakeClass2").getConstructor().newInstance() val fakeClassVersion = fakeClass.toString assert(fakeClassVersion === "1") } @@ -134,7 +134,7 @@ class ExecutorClassLoaderSuite test("parent first") { val parentLoader = new URLClassLoader(urls2, null) val classLoader = new ExecutorClassLoader(new SparkConf(), null, url1, parentLoader, false) - val fakeClass = classLoader.loadClass("ReplFakeClass1").newInstance() + val fakeClass = classLoader.loadClass("ReplFakeClass1").getConstructor().newInstance() val fakeClassVersion = fakeClass.toString assert(fakeClassVersion === "2") } @@ -142,7 +142,7 @@ class ExecutorClassLoaderSuite test("child first can fall back") { val parentLoader = new URLClassLoader(urls2, null) val classLoader = new ExecutorClassLoader(new SparkConf(), null, url1, parentLoader, true) - val fakeClass = classLoader.loadClass("ReplFakeClass3").newInstance() + val fakeClass = classLoader.loadClass("ReplFakeClass3").getConstructor().newInstance() val fakeClassVersion = fakeClass.toString assert(fakeClassVersion === "2") } @@ -151,7 +151,7 @@ class ExecutorClassLoaderSuite val parentLoader = new URLClassLoader(urls2, null) val classLoader = new ExecutorClassLoader(new SparkConf(), null, url1, parentLoader, true) intercept[java.lang.ClassNotFoundException] { - classLoader.loadClass("ReplFakeClassDoesNotExist").newInstance() + 
classLoader.loadClass("ReplFakeClassDoesNotExist").getConstructor().newInstance() } } @@ -202,11 +202,11 @@ class ExecutorClassLoaderSuite val classLoader = new ExecutorClassLoader(new SparkConf(), env, "spark://localhost:1234", getClass().getClassLoader(), false) - val fakeClass = classLoader.loadClass("ReplFakeClass2").newInstance() + val fakeClass = classLoader.loadClass("ReplFakeClass2").getConstructor().newInstance() val fakeClassVersion = fakeClass.toString assert(fakeClassVersion === "1") intercept[java.lang.ClassNotFoundException] { - classLoader.loadClass("ReplFakeClassDoesNotExist").newInstance() + classLoader.loadClass("ReplFakeClassDoesNotExist").getConstructor().newInstance() } } diff --git a/resource-managers/kubernetes/core/pom.xml b/resource-managers/kubernetes/core/pom.xml index b89ea383bf872..8d594ee8f1478 100644 --- a/resource-managers/kubernetes/core/pom.xml +++ b/resource-managers/kubernetes/core/pom.xml @@ -19,12 +19,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../../pom.xml - spark-kubernetes_2.11 + spark-kubernetes_2.12 jar Spark Project Kubernetes diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index 3be93fd6059d0..43219cb9406f6 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -321,6 +321,7 @@ private[spark] object Config extends Logging { val KUBERNETES_VOLUMES_PVC_TYPE = "persistentVolumeClaim" val KUBERNETES_VOLUMES_EMPTYDIR_TYPE = "emptyDir" val KUBERNETES_VOLUMES_MOUNT_PATH_KEY = "mount.path" + val KUBERNETES_VOLUMES_MOUNT_SUBPATH_KEY = "mount.subPath" val KUBERNETES_VOLUMES_MOUNT_READONLY_KEY = "mount.readOnly" val KUBERNETES_VOLUMES_OPTIONS_PATH_KEY = "options.path" val KUBERNETES_VOLUMES_OPTIONS_CLAIM_NAME_KEY = "options.claimName" diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala index b1762d1efe2ea..1a214fad96618 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala @@ -34,5 +34,6 @@ private[spark] case class KubernetesEmptyDirVolumeConf( private[spark] case class KubernetesVolumeSpec[T <: KubernetesVolumeSpecificConf]( volumeName: String, mountPath: String, + mountSubPath: String, mountReadOnly: Boolean, volumeConf: T) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala index 713df5fffc3a2..155326469235b 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala @@ -39,6 +39,7 @@ private[spark] object KubernetesVolumeUtils { getVolumeTypesAndNames(properties).map { case (volumeType, volumeName) => val pathKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_MOUNT_PATH_KEY" val readOnlyKey = 
s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_MOUNT_READONLY_KEY" + val subPathKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_MOUNT_SUBPATH_KEY" for { path <- properties.getTry(pathKey) @@ -46,6 +47,7 @@ private[spark] object KubernetesVolumeUtils { } yield KubernetesVolumeSpec( volumeName = volumeName, mountPath = path, + mountSubPath = properties.get(subPathKey).getOrElse(""), mountReadOnly = properties.get(readOnlyKey).exists(_.toBoolean), volumeConf = volumeConf ) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala index e60259c4a9b5a..1473a7d3ee7f6 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala @@ -51,6 +51,7 @@ private[spark] class MountVolumesFeatureStep( val volumeMount = new VolumeMountBuilder() .withMountPath(spec.mountPath) .withReadOnly(spec.mountReadOnly) + .withSubPath(spec.mountSubPath) .withName(spec.volumeName) .build() diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala index ad44dc63a90aa..3b5908f8f01d9 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala @@ -59,7 +59,7 @@ private[spark] class KubernetesDriverBuilder( providePodTemplateConfigMapStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] => PodTemplateConfigMapStep) = new PodTemplateConfigMapStep(_), - provideInitialPod: () => SparkPod = SparkPod.initialPod) { + provideInitialPod: () => SparkPod = () => SparkPod.initialPod()) { import KubernetesDriverBuilder._ diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala index e6b46308465d2..368d14bbc3cb5 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala @@ -57,7 +57,7 @@ private[spark] class KubernetesExecutorBuilder( KubernetesConf[KubernetesExecutorSpecificConf] => HadoopSparkUserExecutorFeatureStep) = new HadoopSparkUserExecutorFeatureStep(_), - provideInitialPod: () => SparkPod = SparkPod.initialPod) { + provideInitialPod: () => SparkPod = () => SparkPod.initialPod()) { def buildFromFeatures( kubernetesConf: KubernetesConf[KubernetesExecutorSpecificConf]): SparkPod = { diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala index d795d159773a8..de79a58a3a756 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala +++ 
b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala @@ -33,6 +33,18 @@ class KubernetesVolumeUtilsSuite extends SparkFunSuite { KubernetesHostPathVolumeConf("/hostPath")) } + test("Parses subPath correctly") { + val sparkConf = new SparkConf(false) + sparkConf.set("test.emptyDir.volumeName.mount.path", "/path") + sparkConf.set("test.emptyDir.volumeName.mount.readOnly", "true") + sparkConf.set("test.emptyDir.volumeName.mount.subPath", "subPath") + + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head.get + assert(volumeSpec.volumeName === "volumeName") + assert(volumeSpec.mountPath === "/path") + assert(volumeSpec.mountSubPath === "subPath") + } + test("Parses persistentVolumeClaim volumes correctly") { val sparkConf = new SparkConf(false) sparkConf.set("test.persistentVolumeClaim.volumeName.mount.path", "/path") diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala index 82b828c04df69..9b5e42ebc849d 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala @@ -44,6 +44,7 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { val volumeConf = KubernetesVolumeSpec( "testVolume", "/tmp", + "", false, KubernetesHostPathVolumeConf("/hostPath/tmp") ) @@ -63,6 +64,7 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { val volumeConf = KubernetesVolumeSpec( "testVolume", "/tmp", + "", true, KubernetesPVCVolumeConf("pvcClaim") ) @@ -84,6 +86,7 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { val volumeConf = KubernetesVolumeSpec( "testVolume", "/tmp", + "", false, KubernetesEmptyDirVolumeConf(Some("Memory"), Some("6G")) ) @@ -105,6 +108,7 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { val volumeConf = KubernetesVolumeSpec( "testVolume", "/tmp", + "", false, KubernetesEmptyDirVolumeConf(None, None) ) @@ -126,12 +130,14 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { val hpVolumeConf = KubernetesVolumeSpec( "hpVolume", "/tmp", + "", false, KubernetesHostPathVolumeConf("/hostPath/tmp") ) val pvcVolumeConf = KubernetesVolumeSpec( "checkpointVolume", "/checkpoints", + "", true, KubernetesPVCVolumeConf("pvcClaim") ) @@ -143,4 +149,77 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { assert(configuredPod.pod.getSpec.getVolumes.size() === 2) assert(configuredPod.container.getVolumeMounts.size() === 2) } + + test("Mounts subpath on emptyDir") { + val volumeConf = KubernetesVolumeSpec( + "testVolume", + "/tmp", + "foo", + false, + KubernetesEmptyDirVolumeConf(None, None) + ) + val kubernetesConf = emptyKubernetesConf.copy(roleVolumes = volumeConf :: Nil) + val step = new MountVolumesFeatureStep(kubernetesConf) + val configuredPod = step.configurePod(SparkPod.initialPod()) + + assert(configuredPod.pod.getSpec.getVolumes.size() === 1) + val emptyDirMount = configuredPod.container.getVolumeMounts.get(0) + assert(emptyDirMount.getMountPath === "/tmp") + assert(emptyDirMount.getName === "testVolume") + assert(emptyDirMount.getSubPath === "foo") + } + + test("Mounts subpath on persistentVolumeClaims") { + val volumeConf = KubernetesVolumeSpec( + 
"testVolume", + "/tmp", + "bar", + true, + KubernetesPVCVolumeConf("pvcClaim") + ) + val kubernetesConf = emptyKubernetesConf.copy(roleVolumes = volumeConf :: Nil) + val step = new MountVolumesFeatureStep(kubernetesConf) + val configuredPod = step.configurePod(SparkPod.initialPod()) + + assert(configuredPod.pod.getSpec.getVolumes.size() === 1) + val pvcClaim = configuredPod.pod.getSpec.getVolumes.get(0).getPersistentVolumeClaim + assert(pvcClaim.getClaimName === "pvcClaim") + assert(configuredPod.container.getVolumeMounts.size() === 1) + val pvcMount = configuredPod.container.getVolumeMounts.get(0) + assert(pvcMount.getMountPath === "/tmp") + assert(pvcMount.getName === "testVolume") + assert(pvcMount.getSubPath === "bar") + } + + test("Mounts multiple subpaths") { + val volumeConf = KubernetesEmptyDirVolumeConf(None, None) + val emptyDirSpec = KubernetesVolumeSpec( + "testEmptyDir", + "/tmp/foo", + "foo", + true, + KubernetesEmptyDirVolumeConf(None, None) + ) + val pvcSpec = KubernetesVolumeSpec( + "testPVC", + "/tmp/bar", + "bar", + true, + KubernetesEmptyDirVolumeConf(None, None) + ) + val kubernetesConf = emptyKubernetesConf.copy( + roleVolumes = emptyDirSpec :: pvcSpec :: Nil) + val step = new MountVolumesFeatureStep(kubernetesConf) + val configuredPod = step.configurePod(SparkPod.initialPod()) + + assert(configuredPod.pod.getSpec.getVolumes.size() === 2) + val mounts = configuredPod.container.getVolumeMounts + assert(mounts.size() === 2) + assert(mounts.get(0).getName === "testEmptyDir") + assert(mounts.get(0).getMountPath === "/tmp/foo") + assert(mounts.get(0).getSubPath === "foo") + assert(mounts.get(1).getName === "testPVC") + assert(mounts.get(1).getMountPath === "/tmp/bar") + assert(mounts.get(1).getSubPath === "bar") + } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala index 0cfb2127a5d26..24aebef8633ad 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala @@ -177,6 +177,42 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { val volumeSpec = KubernetesVolumeSpec( "volume", "/tmp", + "", + false, + KubernetesHostPathVolumeConf("/path")) + val conf = KubernetesConf( + new SparkConf(false), + KubernetesDriverSpecificConf( + JavaMainAppResource(None), + "test-app", + "main", + Seq.empty), + "prefix", + "appId", + None, + Map.empty, + Map.empty, + Map.empty, + Map.empty, + Map.empty, + volumeSpec :: Nil, + hadoopConfSpec = None) + validateStepTypesApplied( + builderUnderTest.buildFromFeatures(conf), + BASIC_STEP_TYPE, + CREDENTIALS_STEP_TYPE, + SERVICE_STEP_TYPE, + LOCAL_DIRS_STEP_TYPE, + LOCAL_FILES_STEP_TYPE, + MOUNT_VOLUMES_STEP_TYPE, + DRIVER_CMD_STEP_TYPE) + } + + test("Apply volumes step if a mount subpath is present.") { + val volumeSpec = KubernetesVolumeSpec( + "volume", + "/tmp", + "foo", false, KubernetesHostPathVolumeConf("/path")) val conf = KubernetesConf( diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala index 29007431d4b52..b3e75b6661926 100644 
--- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala @@ -135,6 +135,7 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { val volumeSpec = KubernetesVolumeSpec( "volume", "/tmp", + "", false, KubernetesHostPathVolumeConf("/checkpoint")) val conf = KubernetesConf( diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile index a5850dad21d06..787b8cc52021e 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile @@ -17,11 +17,6 @@ FROM openjdk:8-alpine -ARG spark_jars=jars -ARG example_jars=examples/jars -ARG img_path=kubernetes/dockerfiles -ARG k8s_tests=kubernetes/tests - # Before building the docker image, first build and make a Spark distribution following # the instructions in http://spark.apache.org/docs/latest/building-spark.html. # If this docker file is being used in the context of building your images from a Spark @@ -39,10 +34,13 @@ RUN set -ex && \ ln -sv /bin/bash /bin/sh && \ chgrp root /etc/passwd && chmod ug+rw /etc/passwd -COPY ${spark_jars} /opt/spark/jars +COPY jars /opt/spark/jars COPY bin /opt/spark/bin COPY sbin /opt/spark/sbin -COPY ${img_path}/spark/entrypoint.sh /opt/ +COPY kubernetes/dockerfiles/spark/entrypoint.sh /opt/ +COPY examples /opt/spark/examples +COPY kubernetes/tests /opt/spark/tests +COPY data /opt/spark/data ENV SPARK_HOME /opt/spark diff --git a/resource-managers/kubernetes/integration-tests/README.md b/resource-managers/kubernetes/integration-tests/README.md index 64f8e77597eba..73fc0581d64f5 100644 --- a/resource-managers/kubernetes/integration-tests/README.md +++ b/resource-managers/kubernetes/integration-tests/README.md @@ -107,7 +107,7 @@ properties to Maven. 
For example: mvn integration-test -am -pl :spark-kubernetes-integration-tests_2.11 \ -Pkubernetes -Pkubernetes-integration-tests \ - -Phadoop-2.7 -Dhadoop.version=2.7.3 \ + -Phadoop-2.7 -Dhadoop.version=2.7.4 \ -Dspark.kubernetes.test.sparkTgz=spark-3.0.0-SNAPSHOT-bin-example.tgz \ -Dspark.kubernetes.test.imageTag=sometag \ -Dspark.kubernetes.test.imageRepo=docker.io/somerepo \ diff --git a/resource-managers/kubernetes/integration-tests/pom.xml b/resource-managers/kubernetes/integration-tests/pom.xml index 301b6fe8eee56..17af0e03f2bbb 100644 --- a/resource-managers/kubernetes/integration-tests/pom.xml +++ b/resource-managers/kubernetes/integration-tests/pom.xml @@ -19,12 +19,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../../pom.xml - spark-kubernetes-integration-tests_2.11 + spark-kubernetes-integration-tests_2.12 1.3.0 1.4.0 diff --git a/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh b/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh index a4a9f5b7da131..36e30d7b2cffb 100755 --- a/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh +++ b/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh @@ -72,10 +72,16 @@ then IMAGE_TAG=$(uuidgen); cd $UNPACKED_SPARK_TGZ + # Build PySpark image + LANGUAGE_BINDING_BUILD_ARGS="-p $UNPACKED_SPARK_TGZ/kubernetes/dockerfiles/spark/bindings/python/Dockerfile" + + # Build SparkR image + LANGUAGE_BINDING_BUILD_ARGS="$LANGUAGE_BINDING_BUILD_ARGS -R $UNPACKED_SPARK_TGZ/kubernetes/dockerfiles/spark/bindings/R/Dockerfile" + case $DEPLOY_MODE in cloud) # Build images - $UNPACKED_SPARK_TGZ/bin/docker-image-tool.sh -r $IMAGE_REPO -t $IMAGE_TAG build + $UNPACKED_SPARK_TGZ/bin/docker-image-tool.sh -r $IMAGE_REPO -t $IMAGE_TAG $LANGUAGE_BINDING_BUILD_ARGS build # Push images appropriately if [[ $IMAGE_REPO == gcr.io* ]] ; @@ -89,13 +95,13 @@ then docker-for-desktop) # Only need to build as this will place it in our local Docker repo which is all # we need for Docker for Desktop to work so no need to also push - $UNPACKED_SPARK_TGZ/bin/docker-image-tool.sh -r $IMAGE_REPO -t $IMAGE_TAG build + $UNPACKED_SPARK_TGZ/bin/docker-image-tool.sh -r $IMAGE_REPO -t $IMAGE_TAG $LANGUAGE_BINDING_BUILD_ARGS build ;; minikube) # Only need to build and if we do this with the -m option for minikube we will # build the images directly using the minikube Docker daemon so no need to push - $UNPACKED_SPARK_TGZ/bin/docker-image-tool.sh -m -r $IMAGE_REPO -t $IMAGE_TAG build + $UNPACKED_SPARK_TGZ/bin/docker-image-tool.sh -m -r $IMAGE_REPO -t $IMAGE_TAG $LANGUAGE_BINDING_BUILD_ARGS build ;; *) echo "Unrecognized deploy mode $DEPLOY_MODE" && exit 1 diff --git a/resource-managers/mesos/pom.xml b/resource-managers/mesos/pom.xml index 9585bdfafdcf4..7b3aad4d6ce35 100644 --- a/resource-managers/mesos/pom.xml +++ b/resource-managers/mesos/pom.xml @@ -19,12 +19,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-mesos_2.11 + spark-mesos_2.12 jar Spark Project Mesos diff --git a/resource-managers/yarn/pom.xml b/resource-managers/yarn/pom.xml index e55b814be8465..d18df9955bb1f 100644 --- a/resource-managers/yarn/pom.xml +++ b/resource-managers/yarn/pom.xml @@ -19,12 +19,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-yarn_2.11 + spark-yarn_2.12 jar Spark Project YARN diff --git 
a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index ebdcf45603cea..9497530805c1a 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -20,7 +20,6 @@ package org.apache.spark.deploy.yarn import java.util.Collections import java.util.concurrent._ import java.util.concurrent.atomic.AtomicInteger -import java.util.regex.Pattern import scala.collection.JavaConverters._ import scala.collection.mutable @@ -598,13 +597,21 @@ private[yarn] class YarnAllocator( (false, s"Container ${containerId}${onHostStr} was preempted.") // Should probably still count memory exceeded exit codes towards task failures case VMEM_EXCEEDED_EXIT_CODE => - (true, memLimitExceededLogMessage( - completedContainer.getDiagnostics, - VMEM_EXCEEDED_PATTERN)) + val vmemExceededPattern = raw"$MEM_REGEX of $MEM_REGEX virtual memory used".r + val diag = vmemExceededPattern.findFirstIn(completedContainer.getDiagnostics) + .map(_.concat(".")).getOrElse("") + val message = "Container killed by YARN for exceeding virtual memory limits. " + + s"$diag Consider boosting ${EXECUTOR_MEMORY_OVERHEAD.key} or boosting " + + s"${YarnConfiguration.NM_VMEM_PMEM_RATIO} or disabling " + + s"${YarnConfiguration.NM_VMEM_CHECK_ENABLED} because of YARN-4714." + (true, message) case PMEM_EXCEEDED_EXIT_CODE => - (true, memLimitExceededLogMessage( - completedContainer.getDiagnostics, - PMEM_EXCEEDED_PATTERN)) + val pmemExceededPattern = raw"$MEM_REGEX of $MEM_REGEX physical memory used".r + val diag = pmemExceededPattern.findFirstIn(completedContainer.getDiagnostics) + .map(_.concat(".")).getOrElse("") + val message = "Container killed by YARN for exceeding physical memory limits. " + + s"$diag Consider boosting ${EXECUTOR_MEMORY_OVERHEAD.key}." + (true, message) case _ => // all the failures which not covered above, like: // disk failure, kill by app master or resource manager, ... @@ -735,18 +742,6 @@ private[yarn] class YarnAllocator( private object YarnAllocator { val MEM_REGEX = "[0-9.]+ [KMG]B" - val PMEM_EXCEEDED_PATTERN = - Pattern.compile(s"$MEM_REGEX of $MEM_REGEX physical memory used") - val VMEM_EXCEEDED_PATTERN = - Pattern.compile(s"$MEM_REGEX of $MEM_REGEX virtual memory used") val VMEM_EXCEEDED_EXIT_CODE = -103 val PMEM_EXCEEDED_EXIT_CODE = -104 - - def memLimitExceededLogMessage(diagnostics: String, pattern: Pattern): String = { - val matcher = pattern.matcher(diagnostics) - val diag = if (matcher.find()) " " + matcher.group() + "." else "" - s"Container killed by YARN for exceeding memory limits. $diag " + - "Consider boosting spark.yarn.executor.memoryOverhead or " + - "disabling yarn.nodemanager.vmem-check-enabled because of YARN-4714." 
- } } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerExtensionService.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerExtensionService.scala index 4ed285230ff81..7d15f0e2fbac8 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerExtensionService.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerExtensionService.scala @@ -107,7 +107,7 @@ private[spark] class SchedulerExtensionServices extends SchedulerExtensionServic services = sparkContext.conf.get(SCHEDULER_SERVICES).map { sClass => val instance = Utils.classForName(sClass) - .newInstance() + .getConstructor().newInstance() .asInstanceOf[SchedulerExtensionService] // bind this service instance.start(binding) diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index 35299166d9814..b61e7df4420ef 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -29,7 +29,6 @@ import org.mockito.Mockito._ import org.scalatest.{BeforeAndAfterEach, Matchers} import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} -import org.apache.spark.deploy.yarn.YarnAllocator._ import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ import org.apache.spark.deploy.yarn.config._ import org.apache.spark.rpc.RpcEndpointRef @@ -116,8 +115,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter } def createContainer(host: String, resource: Resource = containerResource): Container = { - // When YARN 2.6+ is required, avoid deprecation by using version with long second arg - val containerId = ContainerId.newInstance(appAttemptId, containerNum) + val containerId = ContainerId.newContainerId(appAttemptId, containerNum) containerNum += 1 val nodeId = NodeId.newInstance(host, 1000) Container.newInstance(containerId, nodeId, "", resource, RM_REQUEST_PRIORITY, null) @@ -377,17 +375,6 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter verify(mockAmClient).updateBlacklist(Seq[String]().asJava, Seq("hostA", "hostB").asJava) } - test("memory exceeded diagnostic regexes") { - val diagnostics = - "Container [pid=12465,containerID=container_1412887393566_0003_01_000002] is running " + - "beyond physical memory limits. Current usage: 2.1 MB of 2 GB physical memory used; " + - "5.8 GB of 4.2 GB virtual memory used. Killing container." 
- val vmemMsg = memLimitExceededLogMessage(diagnostics, VMEM_EXCEEDED_PATTERN) - val pmemMsg = memLimitExceededLogMessage(diagnostics, PMEM_EXCEEDED_PATTERN) - assert(vmemMsg.contains("5.8 GB of 4.2 GB virtual memory used.")) - assert(pmemMsg.contains("2.1 MB of 2 GB physical memory used.")) - } - test("window based failure executor counting") { sparkConf.set("spark.yarn.executor.failuresValidityInterval", "100s") val handler = createAllocator(4) diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 36a73e3362218..4892819ae9973 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -240,6 +240,17 @@ This file is divided into 3 sections: ]]> + + throw new \w+Error\( + + + JavaConversions diff --git a/spark-docker-image-generator/src/test/resources/ExpectedDockerfile b/spark-docker-image-generator/src/test/resources/ExpectedDockerfile index 2e0613cd2a826..31ec83d3db601 100644 --- a/spark-docker-image-generator/src/test/resources/ExpectedDockerfile +++ b/spark-docker-image-generator/src/test/resources/ExpectedDockerfile @@ -17,11 +17,6 @@ FROM fabric8/java-centos-openjdk8-jdk:latest -ARG spark_jars=jars -ARG example_jars=examples/jars -ARG img_path=kubernetes/dockerfiles -ARG k8s_tests=kubernetes/tests - # Before building the docker image, first build and make a Spark distribution following # the instructions in http://spark.apache.org/docs/latest/building-spark.html. # If this docker file is being used in the context of building your images from a Spark @@ -39,10 +34,13 @@ RUN set -ex && \ ln -sv /bin/bash /bin/sh && \ chgrp root /etc/passwd && chmod ug+rw /etc/passwd -COPY ${spark_jars} /opt/spark/jars +COPY jars /opt/spark/jars COPY bin /opt/spark/bin COPY sbin /opt/spark/sbin -COPY ${img_path}/spark/entrypoint.sh /opt/ +COPY kubernetes/dockerfiles/spark/entrypoint.sh /opt/ +COPY examples /opt/spark/examples +COPY kubernetes/tests /opt/spark/tests +COPY data /opt/spark/data ENV SPARK_HOME /opt/spark diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 16ecebf159c1f..20cc5d03fbe52 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-catalyst_2.11 + spark-catalyst_2.12 jar Spark Project Catalyst http://spark.apache.org/ diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index e2d34d1650ddc..5e732edb17baa 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -691,6 +691,7 @@ namedWindow windowSpec : name=identifier #windowRef + | '('name=identifier')' #windowRef | '(' ( CLUSTER BY partition+=expression (',' partition+=expression)* | ((PARTITION | DISTRIBUTE) BY partition+=expression (',' partition+=expression)*)? 
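Note: the SqlBase.g4 hunk above adds a second #windowRef alternative, so a named window can be referenced either bare or wrapped in parentheses. As I read the grammar change, a query like the following should now parse; this is a hedged sketch assuming an active SparkSession bound to spark and a registered view named sales with columns dept and amount, neither of which comes from the patch.

q = """
SELECT dept, amount,
       RANK() OVER w   AS rnk_bare,
       RANK() OVER (w) AS rnk_paren   -- parenthesized window reference, newly accepted
FROM sales
WINDOW w AS (PARTITION BY dept ORDER BY amount DESC)
"""
spark.sql(q).show()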
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/RowFactory.java b/sql/catalyst/src/main/java/org/apache/spark/sql/RowFactory.java index 2ce1fdcbf56ae..0258e66ffb6e5 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/RowFactory.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/RowFactory.java @@ -17,7 +17,7 @@ package org.apache.spark.sql; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; import org.apache.spark.sql.catalyst.expressions.GenericRow; /** @@ -25,7 +25,7 @@ * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable public class RowFactory { /** diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatch.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatch.java index 460513816dfd9..6344cf18c11b8 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatch.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatch.java @@ -20,6 +20,7 @@ import java.io.IOException; import org.apache.spark.memory.MemoryConsumer; +import org.apache.spark.memory.SparkOutOfMemoryError; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.memory.MemoryBlock; @@ -126,7 +127,7 @@ public final void close() { private boolean acquirePage(long requiredSize) { try { page = allocatePage(requiredSize); - } catch (OutOfMemoryError e) { + } catch (SparkOutOfMemoryError e) { logger.warn("Failed to allocate page ({} bytes).", requiredSize); return false; } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index 9002abdcfd474..d5f679fe23d48 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -334,17 +334,11 @@ public void setLong(int ordinal, long value) { } public void setFloat(int ordinal, float value) { - if (Float.isNaN(value)) { - value = Float.NaN; - } assertIndexIsValid(ordinal); Platform.putFloat(baseObject, getElementOffset(ordinal, 4), value); } public void setDouble(int ordinal, double value) { - if (Double.isNaN(value)) { - value = Double.NaN; - } assertIndexIsValid(ordinal); Platform.putDouble(baseObject, getElementOffset(ordinal, 8), value); } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index a76e6ef8c91c1..9bf9452855f5f 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -224,9 +224,6 @@ public void setLong(int ordinal, long value) { public void setDouble(int ordinal, double value) { assertIndexIsValid(ordinal); setNotNullAt(ordinal); - if (Double.isNaN(value)) { - value = Double.NaN; - } Platform.putDouble(baseObject, getFieldOffset(ordinal), value); } @@ -255,9 +252,6 @@ public void setByte(int ordinal, byte value) { public void setFloat(int ordinal, float value) { assertIndexIsValid(ordinal); setNotNullAt(ordinal); - if (Float.isNaN(value)) { - value = Float.NaN; - } Platform.putFloat(baseObject, getFieldOffset(ordinal), 
value); } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java index 2781655002000..95263a0da95a8 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java @@ -199,16 +199,10 @@ protected final void writeLong(long offset, long value) { } protected final void writeFloat(long offset, float value) { - if (Float.isNaN(value)) { - value = Float.NaN; - } Platform.putFloat(getBuffer(), offset, value); } protected final void writeDouble(long offset, double value) { - if (Double.isNaN(value)) { - value = Double.NaN; - } Platform.putDouble(getBuffer(), offset, value); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 1b2f5eee5ccdd..5395e4035e680 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -50,7 +50,7 @@ public final class UnsafeExternalRowSorter { private long numRowsInserted = 0; private final StructType schema; - private final PrefixComputer prefixComputer; + private final UnsafeExternalRowSorter.PrefixComputer prefixComputer; private final UnsafeExternalSorter sorter; public abstract static class PrefixComputer { @@ -74,7 +74,7 @@ public static UnsafeExternalRowSorter createWithRecordComparator( StructType schema, Supplier recordComparatorSupplier, PrefixComparator prefixComparator, - PrefixComputer prefixComputer, + UnsafeExternalRowSorter.PrefixComputer prefixComputer, long pageSizeBytes, boolean canUseRadixSort) throws IOException { return new UnsafeExternalRowSorter(schema, recordComparatorSupplier, prefixComparator, @@ -85,7 +85,7 @@ public static UnsafeExternalRowSorter create( StructType schema, Ordering ordering, PrefixComparator prefixComparator, - PrefixComputer prefixComputer, + UnsafeExternalRowSorter.PrefixComputer prefixComputer, long pageSizeBytes, boolean canUseRadixSort) throws IOException { Supplier recordComparatorSupplier = @@ -98,9 +98,9 @@ private UnsafeExternalRowSorter( StructType schema, Supplier recordComparatorSupplier, PrefixComparator prefixComparator, - PrefixComputer prefixComputer, + UnsafeExternalRowSorter.PrefixComputer prefixComputer, long pageSizeBytes, - boolean canUseRadixSort) throws IOException { + boolean canUseRadixSort) { this.schema = schema; this.prefixComputer = prefixComputer; final SparkEnv sparkEnv = SparkEnv.get(); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java index 5f1032d1229da..5f6a46f2b8e89 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java @@ -17,8 +17,8 @@ package org.apache.spark.sql.streaming; +import org.apache.spark.annotation.Evolving; import org.apache.spark.annotation.Experimental; -import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.catalyst.plans.logical.*; /** @@ -29,7 +29,7 @@ * @since 2.2.0 */ @Experimental -@InterfaceStability.Evolving +@Evolving 
public class GroupStateTimeout { /** diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java index 470c128ee6c3d..a3d72a1f5d49f 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.streaming; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.catalyst.streaming.InternalOutputModes; /** @@ -26,7 +26,7 @@ * * @since 2.0.0 */ -@InterfaceStability.Evolving +@Evolving public class OutputMode { /** diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java b/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java index 0f8570fe470bd..d786374f69e20 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java @@ -19,7 +19,7 @@ import java.util.*; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * To get/create specific data type, users should use singleton objects and factory methods @@ -27,7 +27,7 @@ * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable public class DataTypes { /** * Gets the StringType object. diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java b/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java index 1290614a3207d..a54398324fc66 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java @@ -20,7 +20,7 @@ import java.lang.annotation.*; import org.apache.spark.annotation.DeveloperApi; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; /** * ::DeveloperApi:: @@ -31,7 +31,7 @@ @DeveloperApi @Retention(RetentionPolicy.RUNTIME) @Target(ElementType.TYPE) -@InterfaceStability.Evolving +@Evolving public @interface SQLUserDefinedType { /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala index 50ee6cd4085ea..f5c87677ab9eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala @@ -17,16 +17,15 @@ package org.apache.spark.sql -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan - /** * Thrown when a query fails to analyze, usually because the query itself is invalid. 
* * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable class AnalysisException protected[sql] ( val message: String, val line: Option[Int] = None, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index 7b02317b8538f..9853a4fcc2f9d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -20,10 +20,9 @@ package org.apache.spark.sql import scala.annotation.implicitNotFound import scala.reflect.ClassTag -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.{Evolving, Experimental} import org.apache.spark.sql.types._ - /** * :: Experimental :: * Used to convert a JVM object of type `T` to and from the internal Spark SQL representation. @@ -67,7 +66,7 @@ import org.apache.spark.sql.types._ * @since 1.6.0 */ @Experimental -@InterfaceStability.Evolving +@Evolving @implicitNotFound("Unable to find encoder for type ${T}. An implicit Encoder[${T}] is needed to " + "store ${T} instances in a Dataset. Primitive types (Int, String, etc) and Product types (case " + "classes) are supported by importing spark.implicits._ Support for serializing other types " + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala index 8a30c81912fe9..42b865c027205 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala @@ -22,7 +22,7 @@ import java.lang.reflect.Modifier import scala.reflect.{classTag, ClassTag} import scala.reflect.runtime.universe.TypeTag -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.{Evolving, Experimental} import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions.{BoundReference, Cast} @@ -36,7 +36,7 @@ import org.apache.spark.sql.types._ * @since 1.6.0 */ @Experimental -@InterfaceStability.Evolving +@Evolving object Encoders { /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index 180c2d130074e..e12bf9616e2de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -20,14 +20,14 @@ package org.apache.spark.sql import scala.collection.JavaConverters._ import scala.util.hashing.MurmurHash3 -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.expressions.GenericRow import org.apache.spark.sql.types.StructType /** * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable object Row { /** * This method can be used to extract fields from a [[Row]] object in a pattern match. Example: @@ -124,7 +124,7 @@ object Row { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable trait Row extends Serializable { /** Number of elements in the Row. 
*/ def size: Int = length diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 8ef8b2be6939c..311060e5961cb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -73,10 +73,10 @@ object JavaTypeInference { : (DataType, Boolean) = { typeToken.getRawType match { case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) => - (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true) + (c.getAnnotation(classOf[SQLUserDefinedType]).udt().getConstructor().newInstance(), true) case c: Class[_] if UDTRegistration.exists(c.getName) => - val udt = UDTRegistration.getUDTFor(c.getName).get.newInstance() + val udt = UDTRegistration.getUDTFor(c.getName).get.getConstructor().newInstance() .asInstanceOf[UserDefinedType[_ >: Null]] (udt, true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/QueryPlanningTracker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/QueryPlanningTracker.scala new file mode 100644 index 0000000000000..244081cd160b6 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/QueryPlanningTracker.scala @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst + +import scala.collection.JavaConverters._ + +import org.apache.spark.util.BoundedPriorityQueue + + +/** + * A simple utility for tracking runtime and associated stats in query planning. + * + * There are two separate concepts we track: + * + * 1. Phases: These are broad scope phases in query planning, as listed below, i.e. analysis, + * optimizationm and physical planning (just planning). + * + * 2. Rules: These are the individual Catalyst rules that we track. In addition to time, we also + * track the number of invocations and effective invocations. + */ +object QueryPlanningTracker { + + // Define a list of common phases here. + val PARSING = "parsing" + val ANALYSIS = "analysis" + val OPTIMIZATION = "optimization" + val PLANNING = "planning" + + class RuleSummary( + var totalTimeNs: Long, var numInvocations: Long, var numEffectiveInvocations: Long) { + + def this() = this(totalTimeNs = 0, numInvocations = 0, numEffectiveInvocations = 0) + + override def toString: String = { + s"RuleSummary($totalTimeNs, $numInvocations, $numEffectiveInvocations)" + } + } + + /** + * A thread local variable to implicitly pass the tracker around. This assumes the query planner + * is single-threaded, and avoids passing the same tracker context in every function call. 
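The `newInstance()` rewrites in this and the following hunks apply one mechanical pattern: `Class.newInstance()` (deprecated since Java 9, in part because it propagates checked exceptions without compile-time checking) is replaced with an explicit no-argument constructor lookup. A hedged before/after sketch with a placeholder class name:

    // Placeholder for whatever Class[_] the reflection code has resolved (e.g. a UDT class).
    val cls: Class[_] = Class.forName("org.example.SomeUserDefinedType")   // hypothetical name

    // Before: val instance = cls.newInstance()              // deprecated
    // After: explicit zero-argument constructor call
    val instance = cls.getConstructor().newInstance()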
+ */ + private val localTracker = new ThreadLocal[QueryPlanningTracker]() { + override def initialValue: QueryPlanningTracker = null + } + + /** Returns the current tracker in scope, based on the thread local variable. */ + def get: Option[QueryPlanningTracker] = Option(localTracker.get()) + + /** Sets the current tracker for the execution of function f. We assume f is single-threaded. */ + def withTracker[T](tracker: QueryPlanningTracker)(f: => T): T = { + val originalTracker = localTracker.get() + localTracker.set(tracker) + try f finally { localTracker.set(originalTracker) } + } +} + + +class QueryPlanningTracker { + + import QueryPlanningTracker._ + + // Mapping from the name of a rule to a rule's summary. + // Use a Java HashMap for less overhead. + private val rulesMap = new java.util.HashMap[String, RuleSummary] + + // From a phase to time in ns. + private val phaseToTimeNs = new java.util.HashMap[String, Long] + + /** Measure the runtime of function f, and add it to the time for the specified phase. */ + def measureTime[T](phase: String)(f: => T): T = { + val startTime = System.nanoTime() + val ret = f + val timeTaken = System.nanoTime() - startTime + phaseToTimeNs.put(phase, phaseToTimeNs.getOrDefault(phase, 0) + timeTaken) + ret + } + + /** + * Record a specific invocation of a rule. + * + * @param rule name of the rule + * @param timeNs time taken to run this invocation + * @param effective whether the invocation has resulted in a plan change + */ + def recordRuleInvocation(rule: String, timeNs: Long, effective: Boolean): Unit = { + var s = rulesMap.get(rule) + if (s eq null) { + s = new RuleSummary + rulesMap.put(rule, s) + } + + s.totalTimeNs += timeNs + s.numInvocations += 1 + s.numEffectiveInvocations += (if (effective) 1 else 0) + } + + // ------------ reporting functions below ------------ + + def rules: Map[String, RuleSummary] = rulesMap.asScala.toMap + + def phases: Map[String, Long] = phaseToTimeNs.asScala.toMap + + /** + * Returns the top k most expensive rules (as measured by time). If k is larger than the rules + * seen so far, return all the rules. If there is no rule seen so far or k <= 0, return empty seq. + */ + def topRulesByTime(k: Int): Seq[(String, RuleSummary)] = { + if (k <= 0) { + Seq.empty + } else { + val orderingByTime: Ordering[(String, RuleSummary)] = Ordering.by(e => e._2.totalTimeNs) + val q = new BoundedPriorityQueue(k)(orderingByTime) + rulesMap.asScala.foreach(q.+=) + q.toSeq.sortBy(r => -r._2.totalTimeNs) + } + } + +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 912744eab6a3a..c8542d0f2f7de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -357,7 +357,8 @@ object ScalaReflection extends ScalaReflection { ) case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[SQLUserDefinedType]) => - val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() + val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt(). 
+ getConstructor().newInstance() val obj = NewInstance( udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), Nil, @@ -365,8 +366,8 @@ object ScalaReflection extends ScalaReflection { Invoke(obj, "deserialize", ObjectType(udt.userClass), path :: Nil) case t if UDTRegistration.exists(getClassNameFromType(t)) => - val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance() - .asInstanceOf[UserDefinedType[_]] + val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.getConstructor(). + newInstance().asInstanceOf[UserDefinedType[_]] val obj = NewInstance( udt.getClass, Nil, @@ -601,7 +602,7 @@ object ScalaReflection extends ScalaReflection { case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[SQLUserDefinedType]) => val udt = getClassFromType(t) - .getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() + .getAnnotation(classOf[SQLUserDefinedType]).udt().getConstructor().newInstance() val obj = NewInstance( udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), Nil, @@ -609,8 +610,8 @@ object ScalaReflection extends ScalaReflection { Invoke(obj, "serialize", udt, inputObject :: Nil) case t if UDTRegistration.exists(getClassNameFromType(t)) => - val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance() - .asInstanceOf[UserDefinedType[_]] + val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.getConstructor(). + newInstance().asInstanceOf[UserDefinedType[_]] val obj = NewInstance( udt.getClass, Nil, @@ -721,11 +722,12 @@ object ScalaReflection extends ScalaReflection { // Null type would wrongly match the first of them, which is Option as of now case t if t <:< definitions.NullTpe => Schema(NullType, nullable = true) case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[SQLUserDefinedType]) => - val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() + val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt(). + getConstructor().newInstance() Schema(udt, nullable = true) case t if UDTRegistration.exists(getClassNameFromType(t)) => - val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance() - .asInstanceOf[UserDefinedType[_]] + val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.getConstructor(). + newInstance().asInstanceOf[UserDefinedType[_]] Schema(udt, nullable = true) case t if t <:< localTypeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t @@ -786,12 +788,37 @@ object ScalaReflection extends ScalaReflection { } /** - * Finds an accessible constructor with compatible parameters. This is a more flexible search - * than the exact matching algorithm in `Class.getConstructor`. The first assignment-compatible - * matching constructor is returned. Otherwise, it returns `None`. + * Finds an accessible constructor with compatible parameters. This is a more flexible search than + * the exact matching algorithm in `Class.getConstructor`. The first assignment-compatible + * matching constructor is returned if it exists. Otherwise, we check for additional compatible + * constructors defined in the companion object as `apply` methods. Otherwise, it returns `None`. 
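A sketch of how the `QueryPlanningTracker` introduced above is meant to be driven (not taken verbatim from the patch): `analyzer` and `parsedPlan` are assumed to be in scope, and the two-argument `Analyzer.executeAndCheck(plan, tracker)` overload appears later in this diff.

    val tracker = new QueryPlanningTracker

    // Make the tracker visible to rule execution via the thread local, then time the phase.
    val analyzed = QueryPlanningTracker.withTracker(tracker) {
      tracker.measureTime(QueryPlanningTracker.ANALYSIS) {
        analyzer.executeAndCheck(parsedPlan, tracker)
      }
    }

    // Reporting side: per-phase times in nanoseconds and the most expensive rules.
    tracker.phases.foreach { case (phase, ns) => println(s"$phase: ${ns / 1e6} ms") }
    tracker.topRulesByTime(3).foreach { case (rule, summary) => println(s"$rule -> $summary") }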
*/ - def findConstructor(cls: Class[_], paramTypes: Seq[Class[_]]): Option[Constructor[_]] = { - Option(ConstructorUtils.getMatchingAccessibleConstructor(cls, paramTypes: _*)) + def findConstructor[T](cls: Class[T], paramTypes: Seq[Class[_]]): Option[Seq[AnyRef] => T] = { + Option(ConstructorUtils.getMatchingAccessibleConstructor(cls, paramTypes: _*)) match { + case Some(c) => Some(x => c.newInstance(x: _*).asInstanceOf[T]) + case None => + val companion = mirror.staticClass(cls.getName).companion + val moduleMirror = mirror.reflectModule(companion.asModule) + val applyMethods = companion.asTerm.typeSignature + .member(universe.TermName("apply")).asTerm.alternatives + applyMethods.find { method => + val params = method.typeSignature.paramLists.head + // Check that the needed params are the same length and of matching types + params.size == paramTypes.tail.size && + params.zip(paramTypes.tail).forall { case(ps, pc) => + ps.typeSignature.typeSymbol == mirror.classSymbol(pc) + } + }.map { applyMethodSymbol => + val expectedArgsCount = applyMethodSymbol.typeSignature.paramLists.head.size + val instanceMirror = mirror.reflect(moduleMirror.instance) + val method = instanceMirror.reflectMethod(applyMethodSymbol.asMethod) + (_args: Seq[AnyRef]) => { + // Drop the "outer" argument if it is provided + val args = if (_args.size == expectedArgsCount) _args else _args.tail + method.apply(args: _*).asInstanceOf[T] + } + } + } } /** @@ -971,8 +998,19 @@ trait ScalaReflection extends Logging { } } + /** + * If our type is a Scala trait it may have a companion object that + * only defines a constructor via `apply` method. + */ + private def getCompanionConstructor(tpe: Type): Symbol = { + tpe.typeSymbol.asClass.companion.asTerm.typeSignature.member(universe.TermName("apply")) + } + protected def constructParams(tpe: Type): Seq[Symbol] = { - val constructorSymbol = tpe.dealias.member(termNames.CONSTRUCTOR) + val constructorSymbol = tpe.member(termNames.CONSTRUCTOR) match { + case NoSymbol => getCompanionConstructor(tpe) + case sym => sym + } val params = if (constructorSymbol.isMethod) { constructorSymbol.asMethod.paramLists } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index c2d22c5e7ce60..b977fa07db5c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -102,16 +102,18 @@ class Analyzer( this(catalog, conf, conf.optimizerMaxIterations) } - def executeAndCheck(plan: LogicalPlan): LogicalPlan = AnalysisHelper.markInAnalyzer { - val analyzed = execute(plan) - try { - checkAnalysis(analyzed) - analyzed - } catch { - case e: AnalysisException => - val ae = new AnalysisException(e.message, e.line, e.startPosition, Option(analyzed)) - ae.setStackTrace(e.getStackTrace) - throw ae + def executeAndCheck(plan: LogicalPlan, tracker: QueryPlanningTracker): LogicalPlan = { + AnalysisHelper.markInAnalyzer { + val analyzed = executeAndTrack(plan, tracker) + try { + checkAnalysis(analyzed) + analyzed + } catch { + case e: AnalysisException => + val ae = new AnalysisException(e.message, e.line, e.startPosition, Option(analyzed)) + ae.setStackTrace(e.getStackTrace) + throw ae + } } } @@ -824,7 +826,8 @@ class Analyzer( } private def dedupAttr(attr: Attribute, attrMap: AttributeMap[Attribute]): Attribute = { - 
attrMap.get(attr).getOrElse(attr).withQualifier(attr.qualifier) + val exprId = attrMap.getOrElse(attr, attr).exprId + attr.withExprId(exprId) } /** @@ -870,7 +873,7 @@ class Analyzer( private def dedupOuterReferencesInSubquery( plan: LogicalPlan, attrMap: AttributeMap[Attribute]): LogicalPlan = { - plan resolveOperatorsDown { case currentFragment => + plan transformDown { case currentFragment => currentFragment transformExpressions { case OuterReference(a: Attribute) => OuterReference(dedupAttr(a, attrMap)) @@ -952,6 +955,12 @@ class Analyzer( // rule: ResolveDeserializer. case plan if containsDeserializer(plan.expressions) => plan + // SPARK-25942: Resolves aggregate expressions with `AppendColumns`'s children, instead of + // `AppendColumns`, because `AppendColumns`'s serializer might produce conflict attribute + // names leading to ambiguous references exception. + case a @ Aggregate(groupingExprs, aggExprs, appendColumns: AppendColumns) => + a.mapExpressions(resolve(_, appendColumns)) + case q: LogicalPlan => logTrace(s"Attempting to resolve ${q.simpleString}") q.mapExpressions(resolve(_, q)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 72ac80e0a0a18..133fa119b7aa6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -181,8 +181,9 @@ object TypeCoercion { } /** - * The method finds a common type for data types that differ only in nullable, containsNull - * and valueContainsNull flags. If the input types are too different, None is returned. + * The method finds a common type for data types that differ only in nullable flags, including + * `nullable`, `containsNull` of [[ArrayType]] and `valueContainsNull` of [[MapType]]. + * If the input types are different besides nullable flags, None is returned. 
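The `findConstructor` / `constructParams` changes above exist so that encoders can build types whose only public means of construction is an `apply` method on a companion object. A hypothetical shape of such a type (invented for illustration, not from the patch):

    // A trait with no public constructor; instances are only produced via the companion's apply.
    trait Wrapped { def value: Int }

    object Wrapped {
      private case class Impl(value: Int) extends Wrapped
      def apply(value: Int): Wrapped = Impl(value)   // the fallback "constructor" reflection can now use
    }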
*/ def findCommonTypeDifferentOnlyInNullFlags(t1: DataType, t2: DataType): Option[DataType] = { if (t1 == t2) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 857cf382b8f2c..36cad3cf74785 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -112,6 +112,7 @@ case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute with Un override def withQualifier(newQualifier: Seq[String]): UnresolvedAttribute = this override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute.quoted(newName) override def withMetadata(newMetadata: Metadata): Attribute = this + override def withExprId(newExprId: ExprId): UnresolvedAttribute = this override def toString: String = s"'$name" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index c11b444212946..b6771ec4dffe9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -1134,7 +1134,8 @@ class SessionCatalog( if (clsForUDAF.isAssignableFrom(clazz)) { val cls = Utils.classForName("org.apache.spark.sql.execution.aggregate.ScalaUDAF") val e = cls.getConstructor(classOf[Seq[Expression]], clsForUDAF, classOf[Int], classOf[Int]) - .newInstance(input, clazz.newInstance().asInstanceOf[Object], Int.box(1), Int.box(1)) + .newInstance(input, + clazz.getConstructor().newInstance().asInstanceOf[Object], Int.box(1), Int.box(1)) .asInstanceOf[ImplicitCastInputTypes] // Check input argument size diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala index cdaaa172e8367..94bdb72d675d4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala @@ -25,6 +25,7 @@ import org.apache.commons.lang3.time.FastDateFormat import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.internal.SQLConf class CSVOptions( @transient val parameters: CaseInsensitiveMap[String], @@ -33,11 +34,22 @@ class CSVOptions( defaultColumnNameOfCorruptRecord: String) extends Logging with Serializable { + def this( + parameters: Map[String, String], + columnPruning: Boolean, + defaultTimeZoneId: String) = { + this( + CaseInsensitiveMap(parameters), + columnPruning, + defaultTimeZoneId, + SQLConf.get.columnNameOfCorruptRecord) + } + def this( parameters: Map[String, String], columnPruning: Boolean, defaultTimeZoneId: String, - defaultColumnNameOfCorruptRecord: String = "") = { + defaultColumnNameOfCorruptRecord: String) = { this( CaseInsensitiveMap(parameters), columnPruning, @@ -131,13 +143,16 @@ class CSVOptions( val timeZone: TimeZone = DateTimeUtils.getTimeZone( parameters.getOrElse(DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId)) + // A language tag in IETF BCP 47 format + val locale: Locale = parameters.get("locale").map(Locale.forLanguageTag).getOrElse(Locale.US) + // Uses `FastDateFormat` which can be direct replacement for 
`SimpleDateFormat` and thread-safe. val dateFormat: FastDateFormat = - FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd"), Locale.US) + FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd"), locale) val timestampFormat: FastDateFormat = FastDateFormat.getInstance( - parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX"), timeZone, Locale.US) + parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX"), timeZone, locale) val multiLine = parameters.get("multiLine").map(_.toBoolean).getOrElse(false) @@ -177,6 +192,20 @@ class CSVOptions( */ val emptyValueInWrite = emptyValue.getOrElse("\"\"") + /** + * A string between two consecutive JSON records. + */ + val lineSeparator: Option[String] = parameters.get("lineSep").map { sep => + require(sep.nonEmpty, "'lineSep' cannot be an empty string.") + require(sep.length == 1, "'lineSep' can contain only 1 character.") + sep + } + + val lineSeparatorInRead: Option[Array[Byte]] = lineSeparator.map { lineSep => + lineSep.getBytes(charset) + } + val lineSeparatorInWrite: Option[String] = lineSeparator + def asWriterSettings: CsvWriterSettings = { val writerSettings = new CsvWriterSettings() val format = writerSettings.getFormat @@ -185,6 +214,8 @@ class CSVOptions( format.setQuoteEscape(escape) charToEscapeQuoteEscaping.foreach(format.setCharToEscapeQuoteEscaping) format.setComment(comment) + lineSeparatorInWrite.foreach(format.setLineSeparator) + writerSettings.setIgnoreLeadingWhitespaces(ignoreLeadingWhiteSpaceFlagInWrite) writerSettings.setIgnoreTrailingWhitespaces(ignoreTrailingWhiteSpaceFlagInWrite) writerSettings.setNullValue(nullValue) @@ -201,8 +232,10 @@ class CSVOptions( format.setDelimiter(delimiter) format.setQuote(quote) format.setQuoteEscape(escape) + lineSeparator.foreach(format.setLineSeparator) charToEscapeQuoteEscaping.foreach(format.setCharToEscapeQuoteEscaping) format.setComment(comment) + settings.setIgnoreLeadingWhitespaces(ignoreLeadingWhiteSpaceInRead) settings.setIgnoreTrailingWhitespaces(ignoreTrailingWhiteSpaceInRead) settings.setReadInputOnSeparateThread(false) @@ -212,7 +245,10 @@ class CSVOptions( settings.setEmptyValue(emptyValueInRead) settings.setMaxCharsPerColumn(maxCharsPerColumn) settings.setUnescapedQuoteHandling(UnescapedQuoteHandling.STOP_AT_DELIMITER) - settings.setLineSeparatorDetectionEnabled(multiLine == true) + settings.setLineSeparatorDetectionEnabled(lineSeparatorInRead.isEmpty && multiLine) + lineSeparatorInRead.foreach { _ => + settings.setNormalizeLineEndingsWithinQuotes(!multiLine) + } settings } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala index 46ed58ed92830..ed196935e357f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala @@ -271,11 +271,12 @@ private[sql] object UnivocityParser { def tokenizeStream( inputStream: InputStream, shouldDropHeader: Boolean, - tokenizer: CsvParser): Iterator[Array[String]] = { + tokenizer: CsvParser, + encoding: String): Iterator[Array[String]] = { val handleHeader: () => Unit = () => if (shouldDropHeader) tokenizer.parseNext - convertStream(inputStream, tokenizer, handleHeader)(tokens => tokens) + convertStream(inputStream, tokenizer, handleHeader, encoding)(tokens => tokens) } /** @@ -297,7 +298,7 @@ private[sql] object 
UnivocityParser { val handleHeader: () => Unit = () => headerChecker.checkHeaderColumnNames(tokenizer) - convertStream(inputStream, tokenizer, handleHeader) { tokens => + convertStream(inputStream, tokenizer, handleHeader, parser.options.charset) { tokens => safeParser.parse(tokens) }.flatten } @@ -305,9 +306,10 @@ private[sql] object UnivocityParser { private def convertStream[T]( inputStream: InputStream, tokenizer: CsvParser, - handleHeader: () => Unit)( + handleHeader: () => Unit, + encoding: String)( convert: Array[String] => T) = new Iterator[T] { - tokenizer.beginParsing(inputStream) + tokenizer.beginParsing(inputStream, encoding) // We can handle header here since here the stream is open. handleHeader() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 592520c59a761..589e215c55e44 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -49,15 +49,6 @@ object ExpressionEncoder { val mirror = ScalaReflection.mirror val tpe = typeTag[T].in(mirror).tpe - if (ScalaReflection.optionOfProductType(tpe)) { - throw new UnsupportedOperationException( - "Cannot create encoder for Option of Product type, because Product type is represented " + - "as a row, and the entire row can not be null in Spark SQL like normal databases. " + - "You can wrap your type with Tuple1 if you do want top level null Product objects, " + - "e.g. instead of creating `Dataset[Option[MyClass]]`, you can do something like " + - "`val ds: Dataset[Tuple1[MyClass]] = Seq(Tuple1(MyClass(...)), Tuple1(null)).toDS`") - } - val cls = mirror.runtimeClass(tpe) val serializer = ScalaReflection.serializerForType(tpe) val deserializer = ScalaReflection.deserializerForType(tpe) @@ -198,7 +189,7 @@ case class ExpressionEncoder[T]( val serializer: Seq[NamedExpression] = { val clsName = Utils.getSimpleName(clsTag.runtimeClass) - if (isSerializedAsStruct) { + if (isSerializedAsStructForTopLevel) { val nullSafeSerializer = objSerializer.transformUp { case r: BoundReference => // For input object of Product type, we can't encode it to row if it's null, as Spark SQL @@ -213,6 +204,9 @@ case class ExpressionEncoder[T]( } else { // For other input objects like primitive, array, map, etc., we construct a struct to wrap // the serializer which is a column of an row. + // + // Note: Because Spark SQL doesn't allow top-level row to be null, to encode + // top-level Option[Product] type, we make it as a top-level struct column. CreateNamedStruct(Literal("value") :: objSerializer :: Nil) } }.flatten @@ -226,7 +220,7 @@ case class ExpressionEncoder[T]( * `GetColumnByOrdinal` with corresponding ordinal. */ val deserializer: Expression = { - if (isSerializedAsStruct) { + if (isSerializedAsStructForTopLevel) { // We serialized this kind of objects to root-level row. The input of general deserializer // is a `GetColumnByOrdinal(0)` expression to extract first column of a row. We need to // transform attributes accessors. @@ -253,10 +247,21 @@ case class ExpressionEncoder[T]( }) /** - * Returns true if the type `T` is serialized as a struct. + * Returns true if the type `T` is serialized as a struct by `objSerializer`. 
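Viewed from the reader API, the `locale` and `lineSep` options added to `CSVOptions` a few hunks above would be supplied like this (a hedged sketch; `spark` is an active SparkSession and the path and formats are placeholders):

    val df = spark.read
      .option("locale", "de-DE")            // IETF BCP 47 language tag; Locale.US when absent
      .option("dateFormat", "d MMM yyyy")   // now parsed with the locale above
      .option("lineSep", "\u0001")          // must be exactly one character, per the new checks
      .option("header", "true")
      .csv("/path/to/input.csv")            // placeholder path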
*/ def isSerializedAsStruct: Boolean = objSerializer.dataType.isInstanceOf[StructType] + /** + * If the type `T` is serialized as a struct, when it is encoded to a Spark SQL row, fields in + * the struct are naturally mapped to top-level columns in a row. In other words, the serialized + * struct is flattened to row. But in case of the `T` is also an `Option` type, it can't be + * flattened to top-level row, because in Spark SQL top-level row can't be null. This method + * returns true if `T` is serialized as struct and is not `Option` type. + */ + def isSerializedAsStructForTopLevel: Boolean = { + isSerializedAsStruct && !classOf[Option[_]].isAssignableFrom(clsTag.runtimeClass) + } + // serializer expressions are used to encode an object to a row, while the object is usually an // intermediate value produced inside an operator, not from the output of the child operator. This // is quite different from normal expressions, and `AttributeReference` doesn't work here diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala index 040b56cc1caea..89e9071324eff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala @@ -67,4 +67,20 @@ object ExprUtils { case _ => throw new AnalysisException("Must use a map() function for options") } + + /** + * A convenient function for schema validation in datasources supporting + * `columnNameOfCorruptRecord` as an option. + */ + def verifyColumnNameOfCorruptRecord( + schema: StructType, + columnNameOfCorruptRecord: String): Unit = { + schema.getFieldIndex(columnNameOfCorruptRecord).foreach { corruptFieldIndex => + val f = schema(corruptFieldIndex) + if (f.dataType != StringType || !f.nullable) { + throw new AnalysisException( + "The field for corrupt records must be string type and nullable") + } + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 141fcffcb6fab..2ecec61adb0ac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -24,10 +24,11 @@ import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines the basic expression abstract classes in Catalyst. @@ -40,12 +41,28 @@ import org.apache.spark.util.Utils * "name(arguments...)", the concrete implementation must be a case class whose constructor * arguments are all Expressions types. See [[Substring]] for an example. 
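The guard removed from `ExpressionEncoder.apply` above, together with the new `isSerializedAsStructForTopLevel`, means a top-level `Option` of a Product type can now be encoded: it is kept as a single nullable struct column instead of being rejected. A minimal sketch (case class and data are invented; `spark` is an active SparkSession):

    case class Point(x: Int, y: Int)

    import spark.implicits._
    // Previously failed with UnsupportedOperationException; now None encodes as a null struct.
    val ds = Seq(Some(Point(1, 2)), None).toDS()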
* - * There are a few important traits: + * There are a few important traits or abstract classes: * * - [[Nondeterministic]]: an expression that is not deterministic. + * - [[Stateful]]: an expression that contains mutable state. For example, MonotonicallyIncreasingID + * and Rand. A stateful expression is always non-deterministic. * - [[Unevaluable]]: an expression that is not supposed to be evaluated. * - [[CodegenFallback]]: an expression that does not have code gen implemented and falls back to * interpreted mode. + * - [[NullIntolerant]]: an expression that is null intolerant (i.e. any null input will result in + * null output). + * - [[NonSQLExpression]]: a common base trait for the expressions that do not have SQL + * expressions like representation. For example, `ScalaUDF`, `ScalaUDAF`, + * and object `MapObjects` and `Invoke`. + * - [[UserDefinedExpression]]: a common base trait for user-defined functions, including + * UDF/UDAF/UDTF. + * - [[HigherOrderFunction]]: a common base trait for higher order functions that take one or more + * (lambda) functions and applies these to some objects. The function + * produces a number of variables which can be consumed by some lambda + * functions. + * - [[NamedExpression]]: An [[Expression]] that is named. + * - [[TimeZoneAwareExpression]]: A common base trait for time zone aware expressions. + * - [[SubqueryExpression]]: A base interface for expressions that contain a [[LogicalPlan]]. * * - [[LeafExpression]]: an expression that has no child. * - [[UnaryExpression]]: an expression that has one child. @@ -54,12 +71,20 @@ import org.apache.spark.util.Utils * - [[BinaryOperator]]: a special case of [[BinaryExpression]] that requires two children to have * the same output data type. * + * A few important traits used for type coercion rules: + * - [[ExpectsInputTypes]]: an expression that has the expected input types. This trait is typically + * used by operator expressions (e.g. [[Add]], [[Subtract]]) to define + * expected input types without any implicit casting. + * - [[ImplicitCastInputTypes]]: an expression that has the expected input types, which can be + * implicitly castable using [[TypeCoercion.ImplicitTypeCasts]]. + * - [[ComplexTypeMergingExpression]]: to resolve output types of the complex expressions + * (e.g., [[CaseWhen]]). */ abstract class Expression extends TreeNode[Expression] { /** * Returns true when an expression is a candidate for static evaluation before the query is - * executed. + * executed. A typical use case: [[org.apache.spark.sql.catalyst.optimizer.ConstantFolding]] * * The following conditions are used to determine suitability for constant folding: * - A [[Coalesce]] is foldable if all of its children are foldable @@ -72,7 +97,8 @@ abstract class Expression extends TreeNode[Expression] { /** * Returns true when the current expression always return the same result for fixed inputs from - * children. + * children. The non-deterministic expressions should not change in number and order. They should + * not be evaluated during the query planning. 
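As a small illustration of the `foldable` note added above (the cross-reference to `ConstantFolding`): an expression built purely from literals can be evaluated once during optimization. A sketch, with `df` an assumed DataFrame:

    import org.apache.spark.sql.functions.lit

    // Both children are literals, so the addition is foldable and ConstantFolding
    // can replace it with a single literal before execution.
    val withConst = df.select((lit(1) + lit(2)).as("three"))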
* * Note that this means that an expression should be considered as non-deterministic if: * - it relies on some mutable internal state, or @@ -237,7 +263,7 @@ abstract class Expression extends TreeNode[Expression] { override def simpleString: String = toString - override def toString: String = prettyName + Utils.truncatedString( + override def toString: String = prettyName + truncatedString( flatArguments.toSeq, "(", ", ", ")") /** @@ -252,8 +278,9 @@ abstract class Expression extends TreeNode[Expression] { /** - * An expression that cannot be evaluated. Some expressions don't live past analysis or optimization - * time (e.g. Star). This trait is used by those expressions. + * An expression that cannot be evaluated. These expressions don't live past analysis or + * optimization time (e.g. Star) and should not be evaluated during query planning and + * execution. */ trait Unevaluable extends Expression { @@ -724,9 +751,10 @@ abstract class TernaryExpression extends Expression { } /** - * A trait resolving nullable, containsNull, valueContainsNull flags of the output date type. - * This logic is usually utilized by expressions combining data from multiple child expressions - * of non-primitive types (e.g. [[CaseWhen]]). + * A trait used for resolving nullable flags, including `nullable`, `containsNull` of [[ArrayType]] + * and `valueContainsNull` of [[MapType]], containsNull, valueContainsNull flags of the output date + * type. This is usually utilized by the expressions (e.g. [[CaseWhen]]) that combine data from + * multiple child expressions of non-primitive types. */ trait ComplexTypeMergingExpression extends Expression { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index e1d16a2cd38b0..56c2ee6b53fe5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -128,12 +128,10 @@ case class AggregateExpression( override def nullable: Boolean = aggregateFunction.nullable override def references: AttributeSet = { - val childReferences = mode match { - case Partial | Complete => aggregateFunction.references.toSeq - case PartialMerge | Final => aggregateFunction.aggBufferAttributes + mode match { + case Partial | Complete => aggregateFunction.references + case PartialMerge | Final => AttributeSet(aggregateFunction.aggBufferAttributes) } - - AttributeSet(childReferences) } override def toString: String = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index b868a0f4fa284..7c8f7cd4315b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1305,7 +1305,7 @@ object CodeGenerator extends Logging { throw new CompileException(msg, e.getLocation) } - (evaluator.getClazz().newInstance().asInstanceOf[GeneratedClass], maxCodeSize) + (evaluator.getClazz().getConstructor().newInstance().asInstanceOf[GeneratedClass], maxCodeSize) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index 9a51be6ed5aeb..283fd2a6e9383 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -68,62 +68,55 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR genComparisons(ctx, ordering) } + /** + * Creates the variables for ordering based on the given order. + */ + private def createOrderKeys( + ctx: CodegenContext, + row: String, + ordering: Seq[SortOrder]): Seq[ExprCode] = { + ctx.INPUT_ROW = row + // to use INPUT_ROW we must make sure currentVars is null + ctx.currentVars = null + ordering.map(_.child.genCode(ctx)) + } + /** * Generates the code for ordering based on the given order. */ def genComparisons(ctx: CodegenContext, ordering: Seq[SortOrder]): String = { val oldInputRow = ctx.INPUT_ROW val oldCurrentVars = ctx.currentVars - val inputRow = "i" - ctx.INPUT_ROW = inputRow - // to use INPUT_ROW we must make sure currentVars is null - ctx.currentVars = null - - val comparisons = ordering.map { order => - val eval = order.child.genCode(ctx) - val asc = order.isAscending - val isNullA = ctx.freshName("isNullA") - val primitiveA = ctx.freshName("primitiveA") - val isNullB = ctx.freshName("isNullB") - val primitiveB = ctx.freshName("primitiveB") + val rowAKeys = createOrderKeys(ctx, "a", ordering) + val rowBKeys = createOrderKeys(ctx, "b", ordering) + val comparisons = rowAKeys.zip(rowBKeys).zipWithIndex.map { case ((l, r), i) => + val dt = ordering(i).child.dataType + val asc = ordering(i).isAscending + val nullOrdering = ordering(i).nullOrdering + val lRetValue = nullOrdering match { + case NullsFirst => "-1" + case NullsLast => "1" + } + val rRetValue = nullOrdering match { + case NullsFirst => "1" + case NullsLast => "-1" + } s""" - ${ctx.INPUT_ROW} = a; - boolean $isNullA; - ${CodeGenerator.javaType(order.child.dataType)} $primitiveA; - { - ${eval.code} - $isNullA = ${eval.isNull}; - $primitiveA = ${eval.value}; - } - ${ctx.INPUT_ROW} = b; - boolean $isNullB; - ${CodeGenerator.javaType(order.child.dataType)} $primitiveB; - { - ${eval.code} - $isNullB = ${eval.isNull}; - $primitiveB = ${eval.value}; - } - if ($isNullA && $isNullB) { - // Nothing - } else if ($isNullA) { - return ${ - order.nullOrdering match { - case NullsFirst => "-1" - case NullsLast => "1" - }}; - } else if ($isNullB) { - return ${ - order.nullOrdering match { - case NullsFirst => "1" - case NullsLast => "-1" - }}; - } else { - int comp = ${ctx.genComp(order.child.dataType, primitiveA, primitiveB)}; - if (comp != 0) { - return ${if (asc) "comp" else "-comp"}; - } - } - """ + |${l.code} + |${r.code} + |if (${l.isNull} && ${r.isNull}) { + | // Nothing + |} else if (${l.isNull}) { + | return $lRetValue; + |} else if (${r.isNull}) { + | return $rRetValue; + |} else { + | int comp = ${ctx.genComp(dt, l.value, r.value)}; + | if (comp != 0) { + | return ${if (asc) "comp" else "-comp"}; + | } + |} + """.stripMargin } val code = ctx.splitExpressions( @@ -133,30 +126,24 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR returnType = "int", makeSplitFunction = { body => s""" - InternalRow ${ctx.INPUT_ROW} = null; // Holds current row being evaluated. 
- $body - return 0; - """ + |$body + |return 0; + """.stripMargin }, foldFunctions = { funCalls => funCalls.zipWithIndex.map { case (funCall, i) => val comp = ctx.freshName("comp") s""" - int $comp = $funCall; - if ($comp != 0) { - return $comp; - } - """ + |int $comp = $funCall; + |if ($comp != 0) { + | return $comp; + |} + """.stripMargin }.mkString }) ctx.currentVars = oldCurrentVars ctx.INPUT_ROW = oldInputRow - // make sure INPUT_ROW is declared even if splitExpressions - // returns an inlined block - s""" - |InternalRow $inputRow = null; - |$code - """.stripMargin + code } protected def create(ordering: Seq[SortOrder]): BaseOrdering = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index b24d7486f3454..43116743e9952 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -350,7 +350,7 @@ case class MapValues(child: Expression) > SELECT _FUNC_(map(1, 'a', 2, 'b')); [{"key":1,"value":"a"},{"key":2,"value":"b"}] """, - since = "2.4.0") + since = "3.0.0") case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(MapType) @@ -521,13 +521,18 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpression { override def checkInputDataTypes(): TypeCheckResult = { - var funcName = s"function $prettyName" + val funcName = s"function $prettyName" if (children.exists(!_.dataType.isInstanceOf[MapType])) { TypeCheckResult.TypeCheckFailure( s"input to $funcName should all be of type map, but it's " + children.map(_.dataType.catalogString).mkString("[", ", ", "]")) } else { - TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), funcName) + val sameTypeCheck = TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), funcName) + if (sameTypeCheck.isFailure) { + sameTypeCheck + } else { + TypeUtils.checkForMapKeyType(dataType.keyType) + } } } @@ -740,7 +745,8 @@ case class MapFromEntries(child: Expression) extends UnaryExpression { @transient override lazy val dataType: MapType = dataTypeDetails.get._1 override def checkInputDataTypes(): TypeCheckResult = dataTypeDetails match { - case Some(_) => TypeCheckResult.TypeCheckSuccess + case Some((mapType, _, _)) => + TypeUtils.checkForMapKeyType(mapType.keyType) case None => TypeCheckResult.TypeCheckFailure(s"'${child.sql}' is of " + s"${child.dataType.catalogString} type. 
$prettyName accepts only arrays of pair structs.") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 0361372b6b732..6b77996789f1a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -161,11 +161,11 @@ case class CreateMap(children: Seq[Expression]) extends Expression { "The given values of function map should all be the same type, but they are " + values.map(_.dataType.catalogString).mkString("[", ", ", "]")) } else { - TypeCheckResult.TypeCheckSuccess + TypeUtils.checkForMapKeyType(dataType.keyType) } } - override def dataType: DataType = { + override def dataType: MapType = { MapType( keyType = TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(keys.map(_.dataType)) .getOrElse(StringType), @@ -224,6 +224,16 @@ case class MapFromArrays(left: Expression, right: Expression) override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, ArrayType) + override def checkInputDataTypes(): TypeCheckResult = { + val defaultCheck = super.checkInputDataTypes() + if (defaultCheck.isFailure) { + defaultCheck + } else { + val keyType = left.dataType.asInstanceOf[ArrayType].elementType + TypeUtils.checkForMapKeyType(keyType) + } + } + override def dataType: DataType = { MapType( keyType = left.dataType.asInstanceOf[ArrayType].elementType, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala index aff372b899f86..1e4e1c663c90e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala @@ -106,6 +106,10 @@ case class CsvToStructs( throw new AnalysisException(s"from_csv() doesn't support the ${mode.name} mode. " + s"Acceptable modes are ${PermissiveMode.name} and ${FailFastMode.name}.") } + ExprUtils.verifyColumnNameOfCorruptRecord( + nullableSchema, + parsedOptions.columnNameOfCorruptRecord) + val actualSchema = StructType(nullableSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)) val rawParser = new UnivocityParser(actualSchema, actualSchema, parsedOptions) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index b07d9466ba0d1..8b31021866220 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -264,13 +264,13 @@ case class ArrayTransform( * Filters entries in a map using the provided function. 
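The `TypeUtils.checkForMapKeyType` calls added to `CreateMap`, `MapFromArrays`, `MapConcat` and `MapFromEntries` centralize validation of the key type, presumably rejecting keys that are themselves maps, which Spark does not support. An assumed failure case (`df` and the exact behavior are assumptions, not verified here):

    import org.apache.spark.sql.functions.{lit, map}

    // Expected to fail the new checkInputDataTypes: the outer map's key is itself a map.
    val bad = df.select(map(map(lit(1), lit(2)), lit("a")))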
*/ @ExpressionDescription( -usage = "_FUNC_(expr, func) - Filters entries in a map using the function.", -examples = """ + usage = "_FUNC_(expr, func) - Filters entries in a map using the function.", + examples = """ Examples: > SELECT _FUNC_(map(1, 0, 2, 2, 3, -1), (k, v) -> k > v); {1:0,3:-1} """, -since = "2.4.0") + since = "3.0.0") case class MapFilter( argument: Expression, function: Expression) @@ -504,7 +504,7 @@ case class ArrayAggregate( > SELECT _FUNC_(map_from_arrays(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + v); {2:1,4:2,6:3} """, - since = "2.4.0") + since = "3.0.0") case class TransformKeys( argument: Expression, function: Expression) @@ -514,6 +514,10 @@ case class TransformKeys( override def dataType: DataType = MapType(function.dataType, valueType, valueContainsNull) + override def checkInputDataTypes(): TypeCheckResult = { + TypeUtils.checkForMapKeyType(function.dataType) + } + override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): TransformKeys = { copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil)) } @@ -554,7 +558,7 @@ case class TransformKeys( > SELECT _FUNC_(map_from_arrays(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + v); {1:2,2:4,3:6} """, - since = "2.4.0") + since = "3.0.0") case class TransformValues( argument: Expression, function: Expression) @@ -605,7 +609,7 @@ case class TransformValues( > SELECT _FUNC_(map(1, 'a', 2, 'b'), map(1, 'x', 2, 'y'), (k, v1, v2) -> concat(v1, v2)); {1:"ax",2:"by"} """, - since = "2.4.0") + since = "3.0.0") case class MapZipWith(left: Expression, right: Expression, function: Expression) extends HigherOrderFunction with CodegenFallback { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index eafcb6161036e..47304d835fdf8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -550,15 +550,23 @@ case class JsonToStructs( s"Input schema ${nullableSchema.catalogString} must be a struct, an array or a map.") } - // This converts parsed rows to the desired output by the given schema. @transient - lazy val converter = nullableSchema match { - case _: StructType => - (rows: Iterator[InternalRow]) => if (rows.hasNext) rows.next() else null - case _: ArrayType => - (rows: Iterator[InternalRow]) => if (rows.hasNext) rows.next().getArray(0) else null - case _: MapType => - (rows: Iterator[InternalRow]) => if (rows.hasNext) rows.next().getMap(0) else null + private lazy val castRow = nullableSchema match { + case _: StructType => (row: InternalRow) => row + case _: ArrayType => (row: InternalRow) => row.getArray(0) + case _: MapType => (row: InternalRow) => row.getMap(0) + } + + // This converts parsed rows to the desired output by the given schema. + private def convertRow(rows: Iterator[InternalRow]) = { + if (rows.hasNext) { + val result = rows.next() + // JSON's parser produces one record only. + assert(!rows.hasNext) + castRow(result) + } else { + throw new IllegalArgumentException("Expected one row from JSON parser.") + } } val nameOfCorruptRecord = SQLConf.get.getConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD) @@ -569,14 +577,17 @@ case class JsonToStructs( throw new IllegalArgumentException(s"from_json() doesn't support the ${mode.name} mode. 
" + s"Acceptable modes are ${PermissiveMode.name} and ${FailFastMode.name}.") } - val rawParser = new JacksonParser(nullableSchema, parsedOptions, allowArrayAsStructs = false) - val createParser = CreateJacksonParser.utf8String _ - - val parserSchema = nullableSchema match { - case s: StructType => s - case other => StructType(StructField("value", other) :: Nil) + val (parserSchema, actualSchema) = nullableSchema match { + case s: StructType => + ExprUtils.verifyColumnNameOfCorruptRecord(s, parsedOptions.columnNameOfCorruptRecord) + (s, StructType(s.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))) + case other => + (StructType(StructField("value", other) :: Nil), other) } + val rawParser = new JacksonParser(actualSchema, parsedOptions, allowArrayAsStructs = false) + val createParser = CreateJacksonParser.utf8String _ + new FailureSafeParser[UTF8String]( input => rawParser.parse(input, createParser, identity[UTF8String]), mode, @@ -591,7 +602,7 @@ case class JsonToStructs( copy(timeZoneId = Option(timeZoneId)) override def nullSafeEval(json: Any): Any = { - converter(parser.parse(json.asInstanceOf[UTF8String])) + convertRow(parser.parse(json.asInstanceOf[UTF8String])) } override def inputTypes: Seq[AbstractDataType] = StringType :: Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 584a2946bd564..02b48f9e30f2d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -115,6 +115,7 @@ abstract class Attribute extends LeafExpression with NamedExpression with NullIn def withQualifier(newQualifier: Seq[String]): Attribute def withName(newName: String): Attribute def withMetadata(newMetadata: Metadata): Attribute + def withExprId(newExprId: ExprId): Attribute override def toAttribute: Attribute = this def newInstance(): Attribute @@ -129,6 +130,9 @@ abstract class Attribute extends LeafExpression with NamedExpression with NullIn * Note that exprId and qualifiers are in a separate parameter list because * we only pattern match on child and name. * + * Note that when creating a new Alias, all the [[AttributeReference]] that refer to + * the original alias should be updated to the new one. + * * @param child The computation being performed * @param name The name to be associated with the result of computing [[child]]. 
* @param exprId A globally unique id used to check if an [[AttributeReference]] refers to this @@ -299,7 +303,7 @@ case class AttributeReference( } } - def withExprId(newExprId: ExprId): AttributeReference = { + override def withExprId(newExprId: ExprId): AttributeReference = { if (exprId == newExprId) { this } else { @@ -362,6 +366,8 @@ case class PrettyAttribute( throw new UnsupportedOperationException override def qualifier: Seq[String] = throw new UnsupportedOperationException override def exprId: ExprId = throw new UnsupportedOperationException + override def withExprId(newExprId: ExprId): Attribute = + throw new UnsupportedOperationException override def nullable: Boolean = true } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 4fd36a47cef52..59c897b6a53ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -462,12 +462,12 @@ case class NewInstance( val d = outerObj.getClass +: paramTypes val c = getConstructor(outerObj.getClass +: paramTypes) (args: Seq[AnyRef]) => { - c.newInstance(outerObj +: args: _*) + c(outerObj +: args) } }.getOrElse { val c = getConstructor(paramTypes) (args: Seq[AnyRef]) => { - c.newInstance(args: _*) + c(args) } } } @@ -486,10 +486,16 @@ case class NewInstance( ev.isNull = resultIsNull - val constructorCall = outer.map { gen => - s"${gen.value}.new ${cls.getSimpleName}($argString)" - }.getOrElse { - s"new $className($argString)" + val constructorCall = cls.getConstructors.size match { + // If there are no constructors, the `new` method will fail. In + // this case we can try to call the apply method constructor + // that might be defined on the companion object. + case 0 => s"$className$$.MODULE$$.apply($argString)" + case _ => outer.map { gen => + s"${gen.value}.new ${cls.getSimpleName}($argString)" + }.getOrElse { + s"new $className($argString)" + } } val code = code""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index 64152e04928d2..e10b8a327c01a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -76,16 +76,19 @@ private[sql] class JSONOptions( // Whether to ignore column of all null values or empty array/struct during schema inference val dropFieldIfAllNull = parameters.get("dropFieldIfAllNull").map(_.toBoolean).getOrElse(false) + // A language tag in IETF BCP 47 format + val locale: Locale = parameters.get("locale").map(Locale.forLanguageTag).getOrElse(Locale.US) + val timeZone: TimeZone = DateTimeUtils.getTimeZone( parameters.getOrElse(DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId)) // Uses `FastDateFormat` which can be direct replacement for `SimpleDateFormat` and thread-safe. 
val dateFormat: FastDateFormat = - FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd"), Locale.US) + FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd"), locale) val timestampFormat: FastDateFormat = FastDateFormat.getInstance( - parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX"), timeZone, Locale.US) + parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX"), timeZone, locale) val multiLine = parameters.get("multiLine").map(_.toBoolean).getOrElse(false) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index 57c7f2faf3107..773ff5a7a4013 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -399,7 +399,7 @@ class JacksonParser( // a null first token is equivalent to testing for input.trim.isEmpty // but it works on any token stream and not just strings parser.nextToken() match { - case null => Nil + case null => throw new RuntimeException("Not found any JSON token") case _ => rootConverter.apply(parser) match { case null => throw new RuntimeException("Root converter returned null") case rows => rows diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index a50a94ea2765f..30d6f6b5db470 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -84,7 +84,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) SimplifyConditionals, RemoveDispensableExpressions, SimplifyBinaryComparison, - ReplaceNullWithFalse, + ReplaceNullWithFalseInPredicate, PruneFilters, EliminateSorts, SimplifyCasts, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala new file mode 100644 index 0000000000000..72a60f692ac78 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, CaseWhen, Expression, If} +import org.apache.spark.sql.catalyst.expressions.{LambdaFunction, Literal, MapFilter, Or} +import org.apache.spark.sql.catalyst.expressions.Literal.FalseLiteral +import org.apache.spark.sql.catalyst.plans.logical.{Filter, Join, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.types.BooleanType +import org.apache.spark.util.Utils + + +/** + * A rule that replaces `Literal(null, BooleanType)` with `FalseLiteral`, if possible, in the search + * condition of the WHERE/HAVING/ON(JOIN) clauses, which contain an implicit Boolean operator + * "(search condition) = TRUE". The replacement is only valid when `Literal(null, BooleanType)` is + * semantically equivalent to `FalseLiteral` when evaluating the whole search condition. + * + * Please note that FALSE and NULL are not exchangeable in most cases, when the search condition + * contains NOT and NULL-tolerant expressions. Thus, the rule is very conservative and applicable + * in very limited cases. + * + * For example, `Filter(Literal(null, BooleanType))` is equal to `Filter(FalseLiteral)`. + * + * Another example containing branches is `Filter(If(cond, FalseLiteral, Literal(null, _)))`; + * this can be optimized to `Filter(If(cond, FalseLiteral, FalseLiteral))`, and eventually + * `Filter(FalseLiteral)`. + * + * Moreover, this rule also transforms predicates in all [[If]] expressions as well as branch + * conditions in all [[CaseWhen]] expressions, even if they are not part of the search conditions. + * + * For example, `Project(If(And(cond, Literal(null)), Literal(1), Literal(2)))` can be simplified + * into `Project(Literal(2))`. + */ +object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] { + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case f @ Filter(cond, _) => f.copy(condition = replaceNullWithFalse(cond)) + case j @ Join(_, _, _, Some(cond)) => j.copy(condition = Some(replaceNullWithFalse(cond))) + case p: LogicalPlan => p transformExpressions { + case i @ If(pred, _, _) => i.copy(predicate = replaceNullWithFalse(pred)) + case cw @ CaseWhen(branches, _) => + val newBranches = branches.map { case (cond, value) => + replaceNullWithFalse(cond) -> value + } + cw.copy(branches = newBranches) + case af @ ArrayFilter(_, lf @ LambdaFunction(func, _, _)) => + val newLambda = lf.copy(function = replaceNullWithFalse(func)) + af.copy(function = newLambda) + case ae @ ArrayExists(_, lf @ LambdaFunction(func, _, _)) => + val newLambda = lf.copy(function = replaceNullWithFalse(func)) + ae.copy(function = newLambda) + case mf @ MapFilter(_, lf @ LambdaFunction(func, _, _)) => + val newLambda = lf.copy(function = replaceNullWithFalse(func)) + mf.copy(function = newLambda) + } + } + + /** + * Recursively traverse the Boolean-type expression to replace + * `Literal(null, BooleanType)` with `FalseLiteral`, if possible. + * + * Note that `transformExpressionsDown` can not be used here as we must stop as soon as we hit + * an expression that is not [[CaseWhen]], [[If]], [[And]], [[Or]] or + * `Literal(null, BooleanType)`. 
+ */ + private def replaceNullWithFalse(e: Expression): Expression = e match { + case Literal(null, BooleanType) => + FalseLiteral + case And(left, right) => + And(replaceNullWithFalse(left), replaceNullWithFalse(right)) + case Or(left, right) => + Or(replaceNullWithFalse(left), replaceNullWithFalse(right)) + case cw: CaseWhen if cw.dataType == BooleanType => + val newBranches = cw.branches.map { case (cond, value) => + replaceNullWithFalse(cond) -> replaceNullWithFalse(value) + } + val newElseValue = cw.elseValue.map(replaceNullWithFalse) + CaseWhen(newBranches, newElseValue) + case i @ If(pred, trueVal, falseVal) if i.dataType == BooleanType => + If(replaceNullWithFalse(pred), replaceNullWithFalse(trueVal), replaceNullWithFalse(falseVal)) + case e if e.dataType == BooleanType => + e + case e => + val message = "Expected a Boolean type expression in replaceNullWithFalse, " + + s"but got the type `${e.dataType.catalogString}` in `${e.sql}`." + if (Utils.isTesting) { + throw new IllegalArgumentException(message) + } else { + logWarning(message) + e + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 2b29b49d00ab9..468a950fb1087 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -736,60 +736,3 @@ object CombineConcats extends Rule[LogicalPlan] { flattenConcats(concat) } } - -/** - * A rule that replaces `Literal(null, _)` with `FalseLiteral` for further optimizations. - * - * This rule applies to conditions in [[Filter]] and [[Join]]. Moreover, it transforms predicates - * in all [[If]] expressions as well as branch conditions in all [[CaseWhen]] expressions. - * - * For example, `Filter(Literal(null, _))` is equal to `Filter(FalseLiteral)`. - * - * Another example containing branches is `Filter(If(cond, FalseLiteral, Literal(null, _)))`; - * this can be optimized to `Filter(If(cond, FalseLiteral, FalseLiteral))`, and eventually - * `Filter(FalseLiteral)`. - * - * As this rule is not limited to conditions in [[Filter]] and [[Join]], arbitrary plans can - * benefit from it. For example, `Project(If(And(cond, Literal(null)), Literal(1), Literal(2)))` - * can be simplified into `Project(Literal(2))`. - * - * As a result, many unnecessary computations can be removed in the query optimization phase. - */ -object ReplaceNullWithFalse extends Rule[LogicalPlan] { - - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case f @ Filter(cond, _) => f.copy(condition = replaceNullWithFalse(cond)) - case j @ Join(_, _, _, Some(cond)) => j.copy(condition = Some(replaceNullWithFalse(cond))) - case p: LogicalPlan => p transformExpressions { - case i @ If(pred, _, _) => i.copy(predicate = replaceNullWithFalse(pred)) - case cw @ CaseWhen(branches, _) => - val newBranches = branches.map { case (cond, value) => - replaceNullWithFalse(cond) -> value - } - cw.copy(branches = newBranches) - } - } - - /** - * Recursively replaces `Literal(null, _)` with `FalseLiteral`. - * - * Note that `transformExpressionsDown` can not be used here as we must stop as soon as we hit - * an expression that is not [[CaseWhen]], [[If]], [[And]], [[Or]] or `Literal(null, _)`. 
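To see why NULL can be traded for FALSE at the top of a search condition, here is a small self-contained model of three-valued boolean evaluation (a toy sketch, not Catalyst code; all names are invented):

    // Toy three-valued logic: an expression evaluates to Some(true), Some(false), or None (NULL).
    sealed trait Expr
    case object NullLit extends Expr
    case class BoolLit(b: Boolean) extends Expr
    case class AndExpr(l: Expr, r: Expr) extends Expr
    case class OrExpr(l: Expr, r: Expr) extends Expr

    def eval(e: Expr): Option[Boolean] = e match {
      case NullLit        => None
      case BoolLit(b)     => Some(b)
      case AndExpr(l, r)  => (eval(l), eval(r)) match {
        case (Some(false), _) | (_, Some(false)) => Some(false)
        case (Some(true), Some(true))            => Some(true)
        case _                                   => None
      }
      case OrExpr(l, r)   => (eval(l), eval(r)) match {
        case (Some(true), _) | (_, Some(true)) => Some(true)
        case (Some(false), Some(false))        => Some(false)
        case _                                 => None
      }
    }

    // A WHERE clause keeps a row only when the condition is Some(true), so NULL and FALSE
    // are indistinguishable at the top of the search condition.
    def keepsRow(cond: Expr): Boolean = eval(cond).contains(true)
    assert(keepsRow(NullLit) == keepsRow(BoolLit(false)))
    assert(keepsRow(AndExpr(BoolLit(true), NullLit)) == keepsRow(AndExpr(BoolLit(true), BoolLit(false))))

Because `keepsRow` cannot tell None from Some(false), rewriting the literal is safe for Filter and Join conditions, but not under arbitrary NOT or other null-tolerant parents, which is why the rule recurses only through And, Or, If and CaseWhen.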
- */ - private def replaceNullWithFalse(e: Expression): Expression = e match { - case cw: CaseWhen if cw.dataType == BooleanType => - val newBranches = cw.branches.map { case (cond, value) => - replaceNullWithFalse(cond) -> replaceNullWithFalse(value) - } - val newElseValue = cw.elseValue.map(replaceNullWithFalse) - CaseWhen(newBranches, newElseValue) - case i @ If(pred, trueVal, falseVal) if i.dataType == BooleanType => - If(replaceNullWithFalse(pred), replaceNullWithFalse(trueVal), replaceNullWithFalse(falseVal)) - case And(left, right) => - And(replaceNullWithFalse(left), replaceNullWithFalse(right)) - case Or(left, right) => - Or(replaceNullWithFalse(left), replaceNullWithFalse(right)) - case Literal(null, _) => FalseLiteral - case _ => e - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala index 7149edee0173e..6ebb194d71c2e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala @@ -155,19 +155,20 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper { } /** - * PythonUDF in join condition can not be evaluated, this rule will detect the PythonUDF - * and pull them out from join condition. For python udf accessing attributes from only one side, - * they are pushed down by operation push down rules. If not (e.g. user disables filter push - * down rules), we need to pull them out in this rule too. + * PythonUDF in join condition can't be evaluated if it refers to attributes from both join sides. + * See `ExtractPythonUDFs` for details. This rule will detect un-evaluable PythonUDF and pull them + * out from join condition. */ object PullOutPythonUDFInJoinCondition extends Rule[LogicalPlan] with PredicateHelper { - def hasPythonUDF(expression: Expression): Boolean = { - expression.collectFirst { case udf: PythonUDF => udf }.isDefined + + private def hasUnevaluablePythonUDF(expr: Expression, j: Join): Boolean = { + expr.find { e => + PythonUDF.isScalarPythonUDF(e) && !canEvaluate(e, j.left) && !canEvaluate(e, j.right) + }.isDefined } override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case j @ Join(_, _, joinType, condition) - if condition.isDefined && hasPythonUDF(condition.get) => + case j @ Join(_, _, joinType, Some(cond)) if hasUnevaluablePythonUDF(cond, j) => if (!joinType.isInstanceOf[InnerLike] && joinType != LeftSemi) { // The current strategy only support InnerLike and LeftSemi join because for other type, // it breaks SQL semantic if we run the join condition as a filter after join. If we pass @@ -179,10 +180,9 @@ object PullOutPythonUDFInJoinCondition extends Rule[LogicalPlan] with PredicateH } // If condition expression contains python udf, it will be moved out from // the new join conditions. 
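The partition step below can be pictured with a tiny stand-in model (the `Conjunct` class and all names are hypothetical, not Catalyst types): a conjunct is pulled out of the join only when it holds a scalar Python UDF that needs columns from both sides.

    // Hypothetical stand-in for one conjunct of a join condition.
    case class Conjunct(refs: Set[String], hasScalarPythonUdf: Boolean)

    val leftCols  = Set("l1", "l2")
    val rightCols = Set("r1", "r2")

    // Mirrors the shape of hasUnevaluablePythonUDF above: a Python UDF that cannot be
    // evaluated against either side alone.
    def unevaluablePythonUdf(c: Conjunct): Boolean =
      c.hasScalarPythonUdf && !c.refs.subsetOf(leftCols) && !c.refs.subsetOf(rightCols)

    val conjuncts = Seq(
      Conjunct(Set("l1", "r1"), hasScalarPythonUdf = false),  // plain join predicate: stays
      Conjunct(Set("l2", "r2"), hasScalarPythonUdf = true),   // cross-side UDF: pulled out
      Conjunct(Set("l1"),       hasScalarPythonUdf = true))   // single-side UDF: stays

    val (pulledOut, keptInJoin) = conjuncts.partition(unevaluablePythonUdf)
    // pulledOut becomes a Filter evaluated after the join; keptInJoin stays in the condition.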
- val (udf, rest) = - splitConjunctivePredicates(condition.get).partition(hasPythonUDF) + val (udf, rest) = splitConjunctivePredicates(cond).partition(hasUnevaluablePythonUDF(_, j)) val newCondition = if (rest.isEmpty) { - logWarning(s"The join condition:$condition of the join plan contains PythonUDF only," + + logWarning(s"The join condition:$cond of the join plan contains PythonUDF only," + s" it will be moved out and the join plan will be turned to cross join.") None } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index f09c5ceefed13..a26ec4eed8648 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -17,15 +17,15 @@ package org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.{AliasIdentifier} +import org.apache.spark.sql.catalyst.AliasIdentifier import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelation} import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning} +import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils import org.apache.spark.util.random.RandomSampler /** @@ -485,7 +485,7 @@ case class With(child: LogicalPlan, cteRelations: Seq[(String, SubqueryAlias)]) override def output: Seq[Attribute] = child.output override def simpleString: String = { - val cteAliases = Utils.truncatedString(cteRelations.map(_._1), "[", ", ", "]") + val cteAliases = truncatedString(cteRelations.map(_._1), "[", ", ", "]") s"CTE $cteAliases" } @@ -575,6 +575,18 @@ case class Range( } } +/** + * This is a Group by operator with the aggregate functions and projections. + * + * @param groupingExpressions expressions for grouping keys + * @param aggregateExpressions expressions for a project list, which could contain + * [[AggregateFunction]]s. + * + * Note: Currently, aggregateExpressions is the project list of this Group by operator. Before + * separating projection from grouping and aggregate, we should avoid expression-level optimization + * on aggregateExpressions, which could reference an expression in groupingExpressions. 
+ * For example, see the rule [[org.apache.spark.sql.catalyst.optimizer.SimplifyExtractValueOps]] + */ case class Aggregate( groupingExpressions: Seq[Expression], aggregateExpressions: Seq[NamedExpression], diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index e991a2dc7462f..cf6ff4f986399 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.rules import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.QueryPlanningTracker import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.util.sideBySide @@ -66,6 +67,17 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { */ protected def isPlanIntegral(plan: TreeType): Boolean = true + /** + * Executes the batches of rules defined by the subclass, and also tracks timing info for each + * rule using the provided tracker. + * @see [[execute]] + */ + def executeAndTrack(plan: TreeType, tracker: QueryPlanningTracker): TreeType = { + QueryPlanningTracker.withTracker(tracker) { + execute(plan) + } + } + /** * Executes the batches of rules defined by the subclass. The batches are executed serially * using the defined execution strategy. Within each batch, rules are also executed serially. @@ -74,6 +86,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { var curPlan = plan val queryExecutionMetrics = RuleExecutor.queryExecutionMeter val planChangeLogger = new PlanChangeLogger() + val tracker: Option[QueryPlanningTracker] = QueryPlanningTracker.get batches.foreach { batch => val batchStartPlan = curPlan @@ -88,8 +101,9 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { val startTime = System.nanoTime() val result = rule(plan) val runTime = System.nanoTime() - startTime + val effective = !result.fastEquals(plan) - if (!result.fastEquals(plan)) { + if (effective) { queryExecutionMetrics.incNumEffectiveExecution(rule.ruleName) queryExecutionMetrics.incTimeEffectiveExecutionBy(rule.ruleName, runTime) planChangeLogger.log(rule.ruleName, plan, result) @@ -97,6 +111,9 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { queryExecutionMetrics.incExecutionTimeBy(rule.ruleName, runTime) queryExecutionMetrics.incNumExecution(rule.ruleName) + // Record timing information using QueryPlanningTracker + tracker.foreach(_.recordRuleInvocation(rule.ruleName, runTime, effective)) + // Run the structural integrity checker against the plan after each rule. 
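A sketch of how the tracking added here might be consumed. `analyzer` stands for any `RuleExecutor` (for example the session's `Analyzer`) and `parsedPlan` for an unresolved plan; both are placeholders, and the shape of the returned statistics follows the `QueryPlanningTrackerSuite` added later in this diff:

    import org.apache.spark.sql.catalyst.QueryPlanningTracker

    val tracker = new QueryPlanningTracker
    // executeAndTrack registers the tracker so that every rule invocation inside execute()
    // is timed and recorded via recordRuleInvocation.
    val analyzed = analyzer.executeAndTrack(parsedPlan, tracker)

    // Report the rules that dominated planning time.
    tracker.topRulesByTime(3).foreach { case (ruleName, _) =>
      val stats = tracker.rules(ruleName)
      println(s"$ruleName: ${stats.totalTimeNs} ns, " +
        s"${stats.numEffectiveInvocations}/${stats.numInvocations} effective invocations")
    }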
if (!isPlanIntegral(result)) { val message = s"After applying rule ${rule.ruleName} in batch ${batch.name}, " + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 566b14d7d0e19..624680eb8ec0a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -17,11 +17,13 @@ package org.apache.spark.sql.catalyst.trees +import java.io.Writer import java.util.UUID import scala.collection.Map import scala.reflect.ClassTag +import org.apache.commons.io.output.StringBuilderWriter import org.apache.commons.lang3.ClassUtils import org.json4s.JsonAST._ import org.json4s.JsonDSL._ @@ -35,9 +37,9 @@ import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, Partitioning} +import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.Utils /** Used by [[TreeNode.getNodeNumbered]] when traversing the tree for a given number */ private class MutableInt(var i: Int) @@ -440,10 +442,10 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { case tn: TreeNode[_] => tn.simpleString :: Nil case seq: Seq[Any] if seq.toSet.subsetOf(allChildren.asInstanceOf[Set[Any]]) => Nil case iter: Iterable[_] if iter.isEmpty => Nil - case seq: Seq[_] => Utils.truncatedString(seq, "[", ", ", "]") :: Nil - case set: Set[_] => Utils.truncatedString(set.toSeq, "{", ", ", "}") :: Nil + case seq: Seq[_] => truncatedString(seq, "[", ", ", "]") :: Nil + case set: Set[_] => truncatedString(set.toSeq, "{", ", ", "}") :: Nil case array: Array[_] if array.isEmpty => Nil - case array: Array[_] => Utils.truncatedString(array, "[", ", ", "]") :: Nil + case array: Array[_] => truncatedString(array, "[", ", ", "]") :: Nil case null => Nil case None => Nil case Some(null) => Nil @@ -471,7 +473,20 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { def treeString: String = treeString(verbose = true) def treeString(verbose: Boolean, addSuffix: Boolean = false): String = { - generateTreeString(0, Nil, new StringBuilder, verbose = verbose, addSuffix = addSuffix).toString + val writer = new StringBuilderWriter() + try { + treeString(writer, verbose, addSuffix) + writer.toString + } finally { + writer.close() + } + } + + def treeString( + writer: Writer, + verbose: Boolean, + addSuffix: Boolean): Unit = { + generateTreeString(0, Nil, writer, verbose, "", addSuffix) } /** @@ -523,7 +538,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { protected def innerChildren: Seq[TreeNode[_]] = Seq.empty /** - * Appends the string representation of this node and its children to the given StringBuilder. + * Appends the string representation of this node and its children to the given Writer. * * The `i`-th element in `lastChildren` indicates whether the ancestor of the current node at * depth `i + 1` is the last child of its own parent node. 
The depth of the root node is 0, and @@ -534,16 +549,16 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { def generateTreeString( depth: Int, lastChildren: Seq[Boolean], - builder: StringBuilder, + writer: Writer, verbose: Boolean, prefix: String = "", - addSuffix: Boolean = false): StringBuilder = { + addSuffix: Boolean = false): Unit = { if (depth > 0) { lastChildren.init.foreach { isLast => - builder.append(if (isLast) " " else ": ") + writer.write(if (isLast) " " else ": ") } - builder.append(if (lastChildren.last) "+- " else ":- ") + writer.write(if (lastChildren.last) "+- " else ":- ") } val str = if (verbose) { @@ -551,27 +566,25 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { } else { simpleString } - builder.append(prefix) - builder.append(str) - builder.append("\n") + writer.write(prefix) + writer.write(str) + writer.write("\n") if (innerChildren.nonEmpty) { innerChildren.init.foreach(_.generateTreeString( - depth + 2, lastChildren :+ children.isEmpty :+ false, builder, verbose, + depth + 2, lastChildren :+ children.isEmpty :+ false, writer, verbose, addSuffix = addSuffix)) innerChildren.last.generateTreeString( - depth + 2, lastChildren :+ children.isEmpty :+ true, builder, verbose, + depth + 2, lastChildren :+ children.isEmpty :+ true, writer, verbose, addSuffix = addSuffix) } if (children.nonEmpty) { children.init.foreach(_.generateTreeString( - depth + 1, lastChildren :+ false, builder, verbose, prefix, addSuffix)) + depth + 1, lastChildren :+ false, writer, verbose, prefix, addSuffix)) children.last.generateTreeString( - depth + 1, lastChildren :+ true, builder, verbose, prefix, addSuffix) + depth + 1, lastChildren :+ true, writer, verbose, prefix, addSuffix) } - - builder } /** @@ -653,7 +666,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { t.forall(_.isInstanceOf[Partitioning]) || t.forall(_.isInstanceOf[DataType]) => JArray(t.map(parseToJson).toList) case t: Seq[_] if t.length > 0 && t.head.isInstanceOf[String] => - JString(Utils.truncatedString(t, "[", ", ", "]")) + JString(truncatedString(t, "[", ", ", "]")) case t: Seq[_] => JNull case m: Map[_, _] => JNull // if it's a scala object, we can simply keep the full class path. 
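Since `generateTreeString` now targets a `java.io.Writer`, a very large plan can be streamed instead of accumulated in a single `StringBuilder`. A sketch, where `plan` is a placeholder for any `TreeNode` and the output path is arbitrary:

    import java.io.{BufferedWriter, FileWriter}

    val out = new BufferedWriter(new FileWriter("/tmp/plan.txt"))
    try {
      // Renders the tree directly into the writer rather than materializing the
      // whole string in memory first.
      plan.treeString(out, verbose = true, addSuffix = false)
    } finally {
      out.close()
    }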
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/HyperLogLogPlusPlusHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/HyperLogLogPlusPlusHelper.scala index 9bacd3b925be3..ea619c6a7666c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/HyperLogLogPlusPlusHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/HyperLogLogPlusPlusHelper.scala @@ -199,7 +199,7 @@ class HyperLogLogPlusPlusHelper(relativeSD: Double) extends Serializable { var shift = 0 while (idx < m && i < REGISTERS_PER_WORD) { val Midx = (word >>> shift) & REGISTER_WORD_MASK - zInverse += 1.0 / (1 << Midx) + zInverse += 1.0 / (1L << Midx) if (Midx == 0) { V += 1.0d } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala index 3190e511e2cb5..2a03f85ab594b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.util.QuantileSummaries.Stats * Helper class to compute approximate quantile summary. * This implementation is based on the algorithm proposed in the paper: * "Space-efficient Online Computation of Quantile Summaries" by Greenwald, Michael - * and Khanna, Sanjeev. (http://dx.doi.org/10.1145/375663.375670) + * and Khanna, Sanjeev. (https://doi.org/10.1145/375663.375670) * * In order to optimize for speed, it maintains an internal buffer of the last seen samples, * and only inserts them after crossing a certain size threshold. This guarantees a near-constant diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index 76218b459ef0d..2a71fdb7592bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -46,12 +46,20 @@ object TypeUtils { if (TypeCoercion.haveSameType(types)) { TypeCheckResult.TypeCheckSuccess } else { - return TypeCheckResult.TypeCheckFailure( + TypeCheckResult.TypeCheckFailure( s"input to $caller should all be the same type, but it's " + types.map(_.catalogString).mkString("[", ", ", "]")) } } + def checkForMapKeyType(keyType: DataType): TypeCheckResult = { + if (keyType.existsRecursively(_.isInstanceOf[MapType])) { + TypeCheckResult.TypeCheckFailure("The key of map cannot be/contain map.") + } else { + TypeCheckResult.TypeCheckSuccess + } + } + def getNumeric(t: DataType): Numeric[Any] = t.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala index 0978e92dd4f72..277584b20dcd2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala @@ -19,13 +19,16 @@ package org.apache.spark.sql.catalyst import java.io._ import java.nio.charset.StandardCharsets +import java.util.concurrent.atomic.AtomicBoolean +import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.internal.SQLConf import 
org.apache.spark.sql.types.{NumericType, StringType} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils -package object util { +package object util extends Logging { /** Silences output to stderr or stdout for the duration of f */ def quietly[A](f: => A): A = { @@ -167,6 +170,38 @@ package object util { builder.toString() } + /** Whether we have warned about plan string truncation yet. */ + private val truncationWarningPrinted = new AtomicBoolean(false) + + /** + * Format a sequence with semantics similar to calling .mkString(). Any elements beyond + * maxNumToStringFields will be dropped and replaced by a "... N more fields" placeholder. + * + * @return the trimmed and formatted string. + */ + def truncatedString[T]( + seq: Seq[T], + start: String, + sep: String, + end: String, + maxNumFields: Int = SQLConf.get.maxToStringFields): String = { + if (seq.length > maxNumFields) { + if (truncationWarningPrinted.compareAndSet(false, true)) { + logWarning( + "Truncated the string representation of a plan since it was too large. This " + + s"behavior can be adjusted by setting '${SQLConf.MAX_TO_STRING_FIELDS.key}'.") + } + val numFields = math.max(0, maxNumFields - 1) + seq.take(numFields).mkString( + start, sep, sep + "... " + (seq.length - numFields) + " more fields" + end) + } else { + seq.mkString(start, sep, end) + } + } + + /** Shorthand for calling truncatedString() without start or end strings. */ + def truncatedString[T](seq: Seq[T], sep: String): String = truncatedString(seq, "", sep, "") + /* FIX ME implicit class debugLogging(a: Any) { def debugLogging() { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index cf67ac1ff813c..be4496c87293e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1611,6 +1611,22 @@ object SQLConf { .booleanConf .createWithDefault(false) + val NAME_NON_STRUCT_GROUPING_KEY_AS_VALUE = + buildConf("spark.sql.legacy.dataset.nameNonStructGroupingKeyAsValue") + .internal() + .doc("When set to true, the key attribute resulted from running `Dataset.groupByKey` " + + "for non-struct key type, will be named as `value`, following the behavior of Spark " + + "version 2.4 and earlier.") + .booleanConf + .createWithDefault(false) + + val MAX_TO_STRING_FIELDS = buildConf("spark.sql.debug.maxToStringFields") + .doc("Maximum number of fields of sequence-like entries can be converted to strings " + + "in debug output. Any elements beyond the limit will be dropped and replaced by a" + + """ "... 
N more fields" placeholder.""") + .intConf + .createWithDefault(25) + val MAX_REPEATED_ALIAS_SIZE = buildConf("spark.sql.maxRepeatedAliasSize") .internal() @@ -2047,6 +2063,11 @@ class SQLConf extends Serializable with Logging { def integralDivideReturnLong: Boolean = getConf(SQLConf.LEGACY_INTEGRALDIVIDE_RETURN_LONG) + def nameNonStructGroupingKeyAsValue: Boolean = + getConf(SQLConf.NAME_NON_STRUCT_GROUPING_KEY_AS_VALUE) + + def maxToStringFields: Int = getConf(SQLConf.MAX_TO_STRING_FIELDS) + def maxRepeatedAliasSize: Int = getConf(SQLConf.MAX_REPEATED_ALIAS_SIZE) /** ********************** SQLConf functionality methods ************ */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala index c43cc748655e8..5367ce2af8e9f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.types import scala.reflect.runtime.universe.TypeTag -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.expressions.Expression /** @@ -134,7 +134,7 @@ object AtomicType { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable abstract class NumericType extends AtomicType { // Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for // implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala index 58c75b5dc7a35..7465569868f07 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala @@ -21,7 +21,7 @@ import scala.math.Ordering import org.json4s.JsonDSL._ -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.util.ArrayData /** @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.util.ArrayData * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable object ArrayType extends AbstractDataType { /** * Construct a [[ArrayType]] object with the given element type. The `containsNull` is true. @@ -60,7 +60,7 @@ object ArrayType extends AbstractDataType { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataType { /** No-arg constructor for kryo. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala index 032d6b54aeb79..cc8b3e6e399a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala @@ -20,15 +20,14 @@ package org.apache.spark.sql.types import scala.math.Ordering import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.util.TypeUtils - /** * The data type representing `Array[Byte]` values. * Please use the singleton `DataTypes.BinaryType`. 
*/ -@InterfaceStability.Stable +@Stable class BinaryType private() extends AtomicType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "BinaryType$" in byte code. @@ -55,5 +54,5 @@ class BinaryType private() extends AtomicType { /** * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case object BinaryType extends BinaryType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BooleanType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BooleanType.scala index 63f354d2243cf..5e3de71caa37e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BooleanType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BooleanType.scala @@ -20,15 +20,14 @@ package org.apache.spark.sql.types import scala.math.Ordering import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.InterfaceStability - +import org.apache.spark.annotation.Stable /** * The data type representing `Boolean` values. Please use the singleton `DataTypes.BooleanType`. * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable class BooleanType private() extends AtomicType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "BooleanType$" in byte code. @@ -48,5 +47,5 @@ class BooleanType private() extends AtomicType { /** * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case object BooleanType extends BooleanType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala index 5854c3f5ba116..9d400eefc0f8d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala @@ -20,14 +20,14 @@ package org.apache.spark.sql.types import scala.math.{Integral, Numeric, Ordering} import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable /** * The data type representing `Byte` values. Please use the singleton `DataTypes.ByteType`. * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable class ByteType private() extends IntegralType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "ByteType$" in byte code. @@ -52,5 +52,5 @@ class ByteType private() extends IntegralType { /** * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case object ByteType extends ByteType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala index 2342036a57460..8e297874a0d62 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.types -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable /** * The data type representing calendar time intervals. 
The calendar time interval is stored @@ -29,7 +29,7 @@ import org.apache.spark.annotation.InterfaceStability * * @since 1.5.0 */ -@InterfaceStability.Stable +@Stable class CalendarIntervalType private() extends DataType { override def defaultSize: Int = 16 @@ -40,5 +40,5 @@ class CalendarIntervalType private() extends DataType { /** * @since 1.5.0 */ -@InterfaceStability.Stable +@Stable case object CalendarIntervalType extends CalendarIntervalType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index e53628d11ccf3..c58f7a2397374 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -26,7 +26,7 @@ import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.expressions.{Cast, Expression} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser @@ -38,7 +38,7 @@ import org.apache.spark.util.Utils * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable abstract class DataType extends AbstractDataType { /** * Enables matching against DataType for expressions: @@ -111,7 +111,7 @@ abstract class DataType extends AbstractDataType { /** * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable object DataType { private val FIXED_DECIMAL = """decimal\(\s*(\d+)\s*,\s*(\-?\d+)\s*\)""".r @@ -180,7 +180,7 @@ object DataType { ("pyClass", _), ("sqlType", _), ("type", JString("udt"))) => - Utils.classForName(udtClass).newInstance().asInstanceOf[UserDefinedType[_]] + Utils.classForName(udtClass).getConstructor().newInstance().asInstanceOf[UserDefinedType[_]] // Python UDT case JSortedObject( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala index 9e70dd486a125..7491014b22dab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.types import scala.math.Ordering import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable /** * A date type, supporting "0001-01-01" through "9999-12-31". @@ -31,7 +31,7 @@ import org.apache.spark.annotation.InterfaceStability * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable class DateType private() extends AtomicType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "DateType$" in byte code. 
@@ -53,5 +53,5 @@ class DateType private() extends AtomicType { /** * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case object DateType extends DateType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 9eed2eb202045..0192059a3a39f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.types import java.lang.{Long => JLong} import java.math.{BigInteger, MathContext, RoundingMode} -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Unstable import org.apache.spark.sql.AnalysisException /** @@ -31,7 +31,7 @@ import org.apache.spark.sql.AnalysisException * - If decimalVal is set, it represents the whole decimal value * - Otherwise, the decimal value is longVal / (10 ** _scale) */ -@InterfaceStability.Unstable +@Unstable final class Decimal extends Ordered[Decimal] with Serializable { import org.apache.spark.sql.types.Decimal._ @@ -185,9 +185,21 @@ final class Decimal extends Ordered[Decimal] with Serializable { } } - def toScalaBigInt: BigInt = BigInt(toLong) + def toScalaBigInt: BigInt = { + if (decimalVal.ne(null)) { + decimalVal.toBigInt() + } else { + BigInt(toLong) + } + } - def toJavaBigInteger: java.math.BigInteger = java.math.BigInteger.valueOf(toLong) + def toJavaBigInteger: java.math.BigInteger = { + if (decimalVal.ne(null)) { + decimalVal.underlying().toBigInteger() + } else { + java.math.BigInteger.valueOf(toLong) + } + } def toUnscaledLong: Long = { if (decimalVal.ne(null)) { @@ -407,7 +419,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { } } -@InterfaceStability.Unstable +@Unstable object Decimal { val ROUND_HALF_UP = BigDecimal.RoundingMode.HALF_UP val ROUND_HALF_EVEN = BigDecimal.RoundingMode.HALF_EVEN diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index 15004e4b9667d..25eddaf06a780 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -21,11 +21,10 @@ import java.util.Locale import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.{Expression, Literal} - /** * The data type representing `java.math.BigDecimal` values. 
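The `toScalaBigInt` / `toJavaBigInteger` change above matters once the value no longer fits in a Long; a quick check, assuming the usual `Decimal(BigDecimal)` factory (the literal is arbitrary):

    import org.apache.spark.sql.types.Decimal

    // Larger than Long.MaxValue, so the old detour through toLong could not represent it.
    val d = Decimal(BigDecimal("123456789012345678901234567890"))
    println(d.toScalaBigInt)     // 123456789012345678901234567890
    println(d.toJavaBigInteger)  // the same value as a java.math.BigInteger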
* A Decimal that must have fixed precision (the maximum number of digits) and scale (the number @@ -39,7 +38,7 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, Literal} * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case class DecimalType(precision: Int, scale: Int) extends FractionalType { if (scale > precision) { @@ -110,7 +109,7 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable object DecimalType extends AbstractDataType { import scala.math.min diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala index a5c79ff01ca06..afd3353397019 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala @@ -21,7 +21,7 @@ import scala.math.{Fractional, Numeric, Ordering} import scala.math.Numeric.DoubleAsIfIntegral import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.util.Utils /** @@ -29,7 +29,7 @@ import org.apache.spark.util.Utils * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable class DoubleType private() extends FractionalType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "DoubleType$" in byte code. @@ -54,5 +54,5 @@ class DoubleType private() extends FractionalType { /** * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case object DoubleType extends DoubleType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala index 352147ec936c9..6d98987304081 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala @@ -21,7 +21,7 @@ import scala.math.{Fractional, Numeric, Ordering} import scala.math.Numeric.FloatAsIfIntegral import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.util.Utils /** @@ -29,7 +29,7 @@ import org.apache.spark.util.Utils * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable class FloatType private() extends FractionalType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "FloatType$" in byte code. @@ -55,5 +55,5 @@ class FloatType private() extends FractionalType { /** * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case object FloatType extends FloatType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala index a85e3729188d9..0755202d20df1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala @@ -20,14 +20,14 @@ package org.apache.spark.sql.types import scala.math.{Integral, Numeric, Ordering} import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable /** * The data type representing `Int` values. 
Please use the singleton `DataTypes.IntegerType`. * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable class IntegerType private() extends IntegralType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "IntegerType$" in byte code. @@ -51,5 +51,5 @@ class IntegerType private() extends IntegralType { /** * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case object IntegerType extends IntegerType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala index 0997028fc1057..3c49c721fdc88 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala @@ -20,14 +20,14 @@ package org.apache.spark.sql.types import scala.math.{Integral, Numeric, Ordering} import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable /** * The data type representing `Long` values. Please use the singleton `DataTypes.LongType`. * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable class LongType private() extends IntegralType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "LongType$" in byte code. @@ -51,5 +51,5 @@ class LongType private() extends IntegralType { /** * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case object LongType extends LongType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala index 594e155268bf6..29b9ffc0c3549 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.types import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable /** * The data type for Maps. Keys in a map are not allowed to have `null` values. @@ -31,7 +31,7 @@ import org.apache.spark.annotation.InterfaceStability * @param valueType The data type of map values. * @param valueContainsNull Indicates if map values have `null` values. 
*/ -@InterfaceStability.Stable +@Stable case class MapType( keyType: DataType, valueType: DataType, @@ -78,7 +78,7 @@ case class MapType( /** * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable object MapType extends AbstractDataType { override private[sql] def defaultConcreteType: DataType = apply(NullType, NullType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala index 7c15dc0de4b6b..4979aced145c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala @@ -22,7 +22,7 @@ import scala.collection.mutable import org.json4s._ import org.json4s.jackson.JsonMethods._ -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable /** @@ -37,7 +37,7 @@ import org.apache.spark.annotation.InterfaceStability * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable sealed class Metadata private[types] (private[types] val map: Map[String, Any]) extends Serializable { @@ -117,7 +117,7 @@ sealed class Metadata private[types] (private[types] val map: Map[String, Any]) /** * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable object Metadata { private[this] val _empty = new Metadata(Map.empty) @@ -228,7 +228,7 @@ object Metadata { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable class MetadataBuilder { private val map: mutable.Map[String, Any] = mutable.Map.empty diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala index 494225b47a270..14097a5280d50 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala @@ -17,15 +17,14 @@ package org.apache.spark.sql.types -import org.apache.spark.annotation.InterfaceStability - +import org.apache.spark.annotation.Stable /** * The data type representing `NULL` values. Please use the singleton `DataTypes.NullType`. * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable class NullType private() extends DataType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "NullType$" in byte code. @@ -38,5 +37,5 @@ class NullType private() extends DataType { /** * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case object NullType extends NullType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala index 203e85e1c99bd..6756b209f432e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala @@ -19,9 +19,9 @@ package org.apache.spark.sql.types import scala.language.existentials -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Evolving -@InterfaceStability.Evolving +@Evolving object ObjectType extends AbstractDataType { override private[sql] def defaultConcreteType: DataType = throw new UnsupportedOperationException( @@ -38,7 +38,7 @@ object ObjectType extends AbstractDataType { /** * Represents a JVM object that is passing through Spark SQL expression evaluation. 
*/ -@InterfaceStability.Evolving +@Evolving case class ObjectType(cls: Class[_]) extends DataType { override def defaultSize: Int = 4096 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala index ee655c338b59f..9b5ddfef1ccf5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala @@ -20,14 +20,14 @@ package org.apache.spark.sql.types import scala.math.{Integral, Numeric, Ordering} import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable /** * The data type representing `Short` values. Please use the singleton `DataTypes.ShortType`. * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable class ShortType private() extends IntegralType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "ShortType$" in byte code. @@ -51,5 +51,5 @@ class ShortType private() extends IntegralType { /** * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case object ShortType extends ShortType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala index 59b124cda7d14..8ce1cd078e312 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.types import scala.math.Ordering import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.unsafe.types.UTF8String /** @@ -28,7 +28,7 @@ import org.apache.spark.unsafe.types.UTF8String * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable class StringType private() extends AtomicType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "StringType$" in byte code. 
@@ -48,6 +48,6 @@ class StringType private() extends AtomicType { /** * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case object StringType extends StringType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala index 35f9970a0aaec..6f6b561d67d49 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.types import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.util.{escapeSingleQuotedString, quoteIdentifier} /** @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.util.{escapeSingleQuotedString, quoteIdenti * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case class StructField( name: String, dataType: DataType, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 06289b1483203..6e8bbde7787a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -24,10 +24,10 @@ import scala.util.control.NonFatal import org.json4s.JsonDSL._ import org.apache.spark.SparkException -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, InterpretedOrdering} import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, LegacyTypeStringParser} -import org.apache.spark.sql.catalyst.util.{escapeSingleQuotedString, quoteIdentifier} +import org.apache.spark.sql.catalyst.util.{quoteIdentifier, truncatedString} import org.apache.spark.util.Utils /** @@ -95,7 +95,7 @@ import org.apache.spark.util.Utils * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case class StructType(fields: Array[StructField]) extends DataType with Seq[StructField] { /** No-arg constructor for kryo. 
*/ @@ -346,7 +346,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru override def simpleString: String = { val fieldTypes = fields.view.map(field => s"${field.name}:${field.dataType.simpleString}") - Utils.truncatedString(fieldTypes, "struct<", ",", ">") + truncatedString(fieldTypes, "struct<", ",", ">") } override def catalogString: String = { @@ -422,7 +422,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru /** * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable object StructType extends AbstractDataType { override private[sql] def defaultConcreteType: DataType = new StructType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala index fdb91e0499920..a20f155418f8a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.types import scala.math.Ordering import scala.reflect.runtime.universe.typeTag -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable /** * The data type representing `java.sql.Timestamp` values. @@ -28,7 +28,7 @@ import org.apache.spark.annotation.InterfaceStability * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable class TimestampType private() extends AtomicType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "TimestampType$" in byte code. @@ -50,5 +50,5 @@ class TimestampType private() extends AtomicType { /** * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case object TimestampType extends TimestampType diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/QueryPlanningTrackerSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/QueryPlanningTrackerSuite.scala new file mode 100644 index 0000000000000..120b284a77854 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/QueryPlanningTrackerSuite.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst + +import org.apache.spark.SparkFunSuite + +class QueryPlanningTrackerSuite extends SparkFunSuite { + + test("phases") { + val t = new QueryPlanningTracker + t.measureTime("p1") { + Thread.sleep(1) + } + + assert(t.phases("p1") > 0) + assert(!t.phases.contains("p2")) + + val old = t.phases("p1") + + t.measureTime("p1") { + Thread.sleep(1) + } + assert(t.phases("p1") > old) + } + + test("rules") { + val t = new QueryPlanningTracker + t.recordRuleInvocation("r1", 1, effective = false) + t.recordRuleInvocation("r2", 2, effective = true) + t.recordRuleInvocation("r3", 1, effective = false) + t.recordRuleInvocation("r3", 2, effective = true) + + val rules = t.rules + + assert(rules("r1").totalTimeNs == 1) + assert(rules("r1").numInvocations == 1) + assert(rules("r1").numEffectiveInvocations == 0) + + assert(rules("r2").totalTimeNs == 2) + assert(rules("r2").numInvocations == 1) + assert(rules("r2").numEffectiveInvocations == 1) + + assert(rules("r3").totalTimeNs == 3) + assert(rules("r3").numInvocations == 2) + assert(rules("r3").numEffectiveInvocations == 1) + } + + test("topRulesByTime") { + val t = new QueryPlanningTracker + + // Return empty seq when k = 0 + assert(t.topRulesByTime(0) == Seq.empty) + assert(t.topRulesByTime(1) == Seq.empty) + + t.recordRuleInvocation("r2", 2, effective = true) + t.recordRuleInvocation("r4", 4, effective = true) + t.recordRuleInvocation("r1", 1, effective = false) + t.recordRuleInvocation("r3", 3, effective = false) + + // k <= total size + assert(t.topRulesByTime(0) == Seq.empty) + val top = t.topRulesByTime(2) + assert(top.size == 2) + assert(top(0)._1 == "r4") + assert(top(1)._1 == "r3") + + // k > total size + assert(t.topRulesByTime(10).size == 4) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index d98589db323cc..80824cc2a7f21 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -109,6 +109,30 @@ object TestingUDT { } } +/** An example derived from Twitter/Scrooge codegen for thrift */ +object ScroogeLikeExample { + def apply(x: Int): ScroogeLikeExample = new Immutable(x) + + def unapply(_item: ScroogeLikeExample): Option[Int] = Some(_item.x) + + class Immutable(val x: Int) extends ScroogeLikeExample +} + +trait ScroogeLikeExample extends Product1[Int] with Serializable { + import ScroogeLikeExample._ + + def x: Int + + def _1: Int = x + + override def canEqual(other: Any): Boolean = other.isInstanceOf[ScroogeLikeExample] + + override def equals(other: Any): Boolean = + canEqual(other) && + this.x == other.asInstanceOf[ScroogeLikeExample].x + + override def hashCode: Int = x +} class ScalaReflectionSuite extends SparkFunSuite { import org.apache.spark.sql.catalyst.ScalaReflection._ @@ -362,4 +386,11 @@ class ScalaReflectionSuite extends SparkFunSuite { assert(numberOfCheckedArguments(deserializerFor[(java.lang.Double, Int)]) == 1) assert(numberOfCheckedArguments(deserializerFor[(java.lang.Integer, java.lang.Integer)]) == 0) } + + test("SPARK-8288: schemaFor works for a class with only a companion object constructor") { + val schema = schemaFor[ScroogeLikeExample] + assert(schema === Schema( + StructType(Seq( + StructField("x", IntegerType, nullable = false))), nullable = true)) + } } diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 94778840d706b..117e96175e92a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.analysis -import scala.beans.{BeanInfo, BeanProperty} - import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ @@ -30,8 +28,9 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData} import org.apache.spark.sql.types._ -@BeanInfo -private[sql] case class GroupableData(@BeanProperty data: Int) +private[sql] case class GroupableData(data: Int) { + def getData: Int = data +} private[sql] class GroupableUDT extends UserDefinedType[GroupableData] { @@ -50,8 +49,9 @@ private[sql] class GroupableUDT extends UserDefinedType[GroupableData] { private[spark] override def asNullable: GroupableUDT = this } -@BeanInfo -private[sql] case class UngroupableData(@BeanProperty data: Map[Int, Int]) +private[sql] case class UngroupableData(data: Map[Int, Int]) { + def getData: Map[Int, Int] = data +} private[sql] class UngroupableUDT extends UserDefinedType[UngroupableData] { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index 3d7c91870133b..fab1b776a3c72 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -21,6 +21,7 @@ import java.net.URI import java.util.Locale import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.QueryPlanningTracker import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ @@ -54,7 +55,7 @@ trait AnalysisTest extends PlanTest { expectedPlan: LogicalPlan, caseSensitive: Boolean = true): Unit = { val analyzer = getAnalyzer(caseSensitive) - val actualPlan = analyzer.executeAndCheck(inputPlan) + val actualPlan = analyzer.executeAndCheck(inputPlan, new QueryPlanningTracker) comparePlans(actualPlan, expectedPlan) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala index 8da4d7e3aa372..aa5eda8e5ba87 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import java.util.TimeZone +import org.apache.spark.sql.catalyst.QueryPlanningTracker import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -109,7 +110,7 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest { 
Seq(UnresolvedAlias(Multiply(unresolved_a, Literal(2))), unresolved_b, UnresolvedAlias(count(unresolved_c)))) - val resultPlan = getAnalyzer(true).executeAndCheck(originalPlan2) + val resultPlan = getAnalyzer(true).executeAndCheck(originalPlan2, new QueryPlanningTracker) val gExpressions = resultPlan.asInstanceOf[Aggregate].groupingExpressions assert(gExpressions.size == 3) val firstGroupingExprAttrName = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolvedUuidExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolvedUuidExpressionsSuite.scala index fe57c199b8744..64bd07534b19b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolvedUuidExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolvedUuidExpressionsSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.analysis +import org.apache.spark.sql.catalyst.QueryPlanningTracker import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ @@ -34,6 +35,7 @@ class ResolvedUuidExpressionsSuite extends AnalysisTest { private lazy val uuid3 = Uuid().as('_uuid3) private lazy val uuid1Ref = uuid1.toAttribute + private val tracker = new QueryPlanningTracker private val analyzer = getAnalyzer(caseSensitive = true) private def getUuidExpressions(plan: LogicalPlan): Seq[Uuid] = { @@ -47,7 +49,7 @@ class ResolvedUuidExpressionsSuite extends AnalysisTest { test("analyzed plan sets random seed for Uuid expression") { val plan = r.select(a, uuid1) - val resolvedPlan = analyzer.executeAndCheck(plan) + val resolvedPlan = analyzer.executeAndCheck(plan, tracker) getUuidExpressions(resolvedPlan).foreach { u => assert(u.resolved) assert(u.randomSeed.isDefined) @@ -56,14 +58,14 @@ class ResolvedUuidExpressionsSuite extends AnalysisTest { test("Uuid expressions should have different random seeds") { val plan = r.select(a, uuid1).groupBy(uuid1Ref)(uuid2, uuid3) - val resolvedPlan = analyzer.executeAndCheck(plan) + val resolvedPlan = analyzer.executeAndCheck(plan, tracker) assert(getUuidExpressions(resolvedPlan).map(_.randomSeed.get).distinct.length == 3) } test("Different analyzed plans should have different random seeds in Uuids") { val plan = r.select(a, uuid1).groupBy(uuid1Ref)(uuid2, uuid3) - val resolvedPlan1 = analyzer.executeAndCheck(plan) - val resolvedPlan2 = analyzer.executeAndCheck(plan) + val resolvedPlan1 = analyzer.executeAndCheck(plan, tracker) + val resolvedPlan2 = analyzer.executeAndCheck(plan, tracker) val uuids1 = getUuidExpressions(resolvedPlan1) val uuids2 = getUuidExpressions(resolvedPlan2) assert(uuids1.distinct.length == 3) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index e9b100b3b30db..be8fd90c4c52a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -128,13 +128,13 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes encodeDecodeTest(-3.7f, "primitive float") encodeDecodeTest(-3.7, "primitive double") - encodeDecodeTest(new java.lang.Boolean(false), "boxed boolean") - encodeDecodeTest(new java.lang.Byte(-3.toByte), "boxed byte") 
- encodeDecodeTest(new java.lang.Short(-3.toShort), "boxed short") - encodeDecodeTest(new java.lang.Integer(-3), "boxed int") - encodeDecodeTest(new java.lang.Long(-3L), "boxed long") - encodeDecodeTest(new java.lang.Float(-3.7f), "boxed float") - encodeDecodeTest(new java.lang.Double(-3.7), "boxed double") + encodeDecodeTest(java.lang.Boolean.FALSE, "boxed boolean") + encodeDecodeTest(java.lang.Byte.valueOf(-3: Byte), "boxed byte") + encodeDecodeTest(java.lang.Short.valueOf(-3: Short), "boxed short") + encodeDecodeTest(java.lang.Integer.valueOf(-3), "boxed int") + encodeDecodeTest(java.lang.Long.valueOf(-3L), "boxed long") + encodeDecodeTest(java.lang.Float.valueOf(-3.7f), "boxed float") + encodeDecodeTest(java.lang.Double.valueOf(-3.7), "boxed double") encodeDecodeTest(BigDecimal("32131413.211321313"), "scala decimal") encodeDecodeTest(new java.math.BigDecimal("231341.23123"), "java decimal") @@ -224,7 +224,7 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes productTest( RepeatedData( Seq(1, 2), - Seq(new Integer(1), null, new Integer(2)), + Seq(Integer.valueOf(1), null, Integer.valueOf(2)), Map(1 -> 2L), Map(1 -> null), PrimitiveData(1, 1, 1, 1, 1, 1, true))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 2e0adbb465008..d2edb2f24688d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -25,6 +25,7 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.util.DateTimeTestUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -108,32 +109,28 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper } test("Map Concat") { - val m0 = Literal.create(Map("a" -> "1", "b" -> "2"), MapType(StringType, StringType, + val m0 = Literal.create(create_map("a" -> "1", "b" -> "2"), MapType(StringType, StringType, valueContainsNull = false)) - val m1 = Literal.create(Map("c" -> "3", "a" -> "4"), MapType(StringType, StringType, + val m1 = Literal.create(create_map("c" -> "3", "a" -> "4"), MapType(StringType, StringType, valueContainsNull = false)) - val m2 = Literal.create(Map("d" -> "4", "e" -> "5"), MapType(StringType, StringType)) - val m3 = Literal.create(Map("a" -> "1", "b" -> "2"), MapType(StringType, StringType)) - val m4 = Literal.create(Map("a" -> null, "c" -> "3"), MapType(StringType, StringType)) - val m5 = Literal.create(Map("a" -> 1, "b" -> 2), MapType(StringType, IntegerType)) - val m6 = Literal.create(Map("a" -> null, "c" -> 3), MapType(StringType, IntegerType)) - val m7 = Literal.create(Map(List(1, 2) -> 1, List(3, 4) -> 2), + val m2 = Literal.create(create_map("d" -> "4", "e" -> "5"), MapType(StringType, StringType)) + val m3 = Literal.create(create_map("a" -> "1", "b" -> "2"), MapType(StringType, StringType)) + val m4 = Literal.create(create_map("a" -> null, "c" -> "3"), MapType(StringType, StringType)) + val m5 = Literal.create(create_map("a" -> 1, "b" -> 2), MapType(StringType, IntegerType)) + val m6 = Literal.create(create_map("a" -> null, "c" -> 3), 
MapType(StringType, IntegerType)) + val m7 = Literal.create(create_map(List(1, 2) -> 1, List(3, 4) -> 2), MapType(ArrayType(IntegerType), IntegerType)) - val m8 = Literal.create(Map(List(5, 6) -> 3, List(1, 2) -> 4), + val m8 = Literal.create(create_map(List(5, 6) -> 3, List(1, 2) -> 4), MapType(ArrayType(IntegerType), IntegerType)) - val m9 = Literal.create(Map(Map(1 -> 2, 3 -> 4) -> 1, Map(5 -> 6, 7 -> 8) -> 2), - MapType(MapType(IntegerType, IntegerType), IntegerType)) - val m10 = Literal.create(Map(Map(9 -> 10, 11 -> 12) -> 3, Map(1 -> 2, 3 -> 4) -> 4), - MapType(MapType(IntegerType, IntegerType), IntegerType)) - val m11 = Literal.create(Map(1 -> "1", 2 -> "2"), MapType(IntegerType, StringType, + val m9 = Literal.create(create_map(1 -> "1", 2 -> "2"), MapType(IntegerType, StringType, valueContainsNull = false)) - val m12 = Literal.create(Map(3 -> "3", 4 -> "4"), MapType(IntegerType, StringType, + val m10 = Literal.create(create_map(3 -> "3", 4 -> "4"), MapType(IntegerType, StringType, valueContainsNull = false)) - val m13 = Literal.create(Map(1 -> 2, 3 -> 4), + val m11 = Literal.create(create_map(1 -> 2, 3 -> 4), MapType(IntegerType, IntegerType, valueContainsNull = false)) - val m14 = Literal.create(Map(5 -> 6), + val m12 = Literal.create(create_map(5 -> 6), MapType(IntegerType, IntegerType, valueContainsNull = false)) - val m15 = Literal.create(Map(7 -> null), + val m13 = Literal.create(create_map(7 -> null), MapType(IntegerType, IntegerType, valueContainsNull = true)) val mNull = Literal.create(null, MapType(StringType, StringType)) @@ -147,7 +144,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper // maps with no overlap checkEvaluation(MapConcat(Seq(m0, m2)), - Map("a" -> "1", "b" -> "2", "d" -> "4", "e" -> "5")) + create_map("a" -> "1", "b" -> "2", "d" -> "4", "e" -> "5")) // 3 maps checkEvaluation(MapConcat(Seq(m0, m1, m2)), @@ -174,7 +171,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper ) // keys that are primitive - checkEvaluation(MapConcat(Seq(m11, m12)), + checkEvaluation(MapConcat(Seq(m9, m10)), ( Array(1, 2, 3, 4), // keys Array("1", "2", "3", "4") // values @@ -189,20 +186,11 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper ) ) - // keys that are maps, with overlap - checkEvaluation(MapConcat(Seq(m9, m10)), - ( - Array(Map(1 -> 2, 3 -> 4), Map(5 -> 6, 7 -> 8), Map(9 -> 10, 11 -> 12), - Map(1 -> 2, 3 -> 4)), // keys - Array(1, 2, 3, 4) // values - ) - ) - // both keys and value are primitive and valueContainsNull = false - checkEvaluation(MapConcat(Seq(m13, m14)), Map(1 -> 2, 3 -> 4, 5 -> 6)) + checkEvaluation(MapConcat(Seq(m11, m12)), create_map(1 -> 2, 3 -> 4, 5 -> 6)) // both keys and value are primitive and valueContainsNull = true - checkEvaluation(MapConcat(Seq(m13, m15)), Map(1 -> 2, 3 -> 4, 7 -> null)) + checkEvaluation(MapConcat(Seq(m11, m13)), create_map(1 -> 2, 3 -> 4, 7 -> null)) // null map checkEvaluation(MapConcat(Seq(m0, mNull)), null) @@ -211,7 +199,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(MapConcat(Seq(mNull)), null) // single map - checkEvaluation(MapConcat(Seq(m0)), Map("a" -> "1", "b" -> "2")) + checkEvaluation(MapConcat(Seq(m0)), create_map("a" -> "1", "b" -> "2")) // no map checkEvaluation(MapConcat(Seq.empty), Map.empty) @@ -245,12 +233,12 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper assert(MapConcat(Seq(m1, mNull)).nullable) val mapConcat = 
MapConcat(Seq( - Literal.create(Map(Seq(1, 2) -> Seq("a", "b")), + Literal.create(create_map(Seq(1, 2) -> Seq("a", "b")), MapType( ArrayType(IntegerType, containsNull = false), ArrayType(StringType, containsNull = false), valueContainsNull = false)), - Literal.create(Map(Seq(3, 4, null) -> Seq("c", "d", null), Seq(6) -> null), + Literal.create(create_map(Seq(3, 4, null) -> Seq("c", "d", null), Seq(6) -> null), MapType( ArrayType(IntegerType, containsNull = true), ArrayType(StringType, containsNull = true), @@ -264,6 +252,18 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper Seq(1, 2) -> Seq("a", "b"), Seq(3, 4, null) -> Seq("c", "d", null), Seq(6) -> null)) + + // map key can't be map + val mapOfMap = Literal.create(Map(Map(1 -> 2, 3 -> 4) -> 1, Map(5 -> 6, 7 -> 8) -> 2), + MapType(MapType(IntegerType, IntegerType), IntegerType)) + val mapOfMap2 = Literal.create(Map(Map(9 -> 10, 11 -> 12) -> 3, Map(1 -> 2, 3 -> 4) -> 4), + MapType(MapType(IntegerType, IntegerType), IntegerType)) + val map = MapConcat(Seq(mapOfMap, mapOfMap2)) + map.checkInputDataTypes() match { + case TypeCheckResult.TypeCheckSuccess => fail("should not allow map as map key") + case TypeCheckResult.TypeCheckFailure(msg) => + assert(msg.contains("The key of map cannot be/contain map")) + } } test("MapFromEntries") { @@ -274,20 +274,20 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper StructField("b", valueType))), true) } - def r(values: Any*): InternalRow = create_row(values: _*) + def row(values: Any*): InternalRow = create_row(values: _*) // Primitive-type keys and values val aiType = arrayType(IntegerType, IntegerType) - val ai0 = Literal.create(Seq(r(1, 10), r(2, 20), r(3, 20)), aiType) - val ai1 = Literal.create(Seq(r(1, null), r(2, 20), r(3, null)), aiType) + val ai0 = Literal.create(Seq(row(1, 10), row(2, 20), row(3, 20)), aiType) + val ai1 = Literal.create(Seq(row(1, null), row(2, 20), row(3, null)), aiType) val ai2 = Literal.create(Seq.empty, aiType) val ai3 = Literal.create(null, aiType) - val ai4 = Literal.create(Seq(r(1, 10), r(1, 20)), aiType) - val ai5 = Literal.create(Seq(r(1, 10), r(null, 20)), aiType) - val ai6 = Literal.create(Seq(null, r(2, 20), null), aiType) + val ai4 = Literal.create(Seq(row(1, 10), row(1, 20)), aiType) + val ai5 = Literal.create(Seq(row(1, 10), row(null, 20)), aiType) + val ai6 = Literal.create(Seq(null, row(2, 20), null), aiType) - checkEvaluation(MapFromEntries(ai0), Map(1 -> 10, 2 -> 20, 3 -> 20)) - checkEvaluation(MapFromEntries(ai1), Map(1 -> null, 2 -> 20, 3 -> null)) + checkEvaluation(MapFromEntries(ai0), create_map(1 -> 10, 2 -> 20, 3 -> 20)) + checkEvaluation(MapFromEntries(ai1), create_map(1 -> null, 2 -> 20, 3 -> null)) checkEvaluation(MapFromEntries(ai2), Map.empty) checkEvaluation(MapFromEntries(ai3), null) checkEvaluation(MapKeys(MapFromEntries(ai4)), Seq(1, 1)) @@ -298,23 +298,36 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper // Non-primitive-type keys and values val asType = arrayType(StringType, StringType) - val as0 = Literal.create(Seq(r("a", "aa"), r("b", "bb"), r("c", "bb")), asType) - val as1 = Literal.create(Seq(r("a", null), r("b", "bb"), r("c", null)), asType) + val as0 = Literal.create(Seq(row("a", "aa"), row("b", "bb"), row("c", "bb")), asType) + val as1 = Literal.create(Seq(row("a", null), row("b", "bb"), row("c", null)), asType) val as2 = Literal.create(Seq.empty, asType) val as3 = Literal.create(null, asType) - val as4 = Literal.create(Seq(r("a", 
"aa"), r("a", "bb")), asType) - val as5 = Literal.create(Seq(r("a", "aa"), r(null, "bb")), asType) - val as6 = Literal.create(Seq(null, r("b", "bb"), null), asType) + val as4 = Literal.create(Seq(row("a", "aa"), row("a", "bb")), asType) + val as5 = Literal.create(Seq(row("a", "aa"), row(null, "bb")), asType) + val as6 = Literal.create(Seq(null, row("b", "bb"), null), asType) - checkEvaluation(MapFromEntries(as0), Map("a" -> "aa", "b" -> "bb", "c" -> "bb")) - checkEvaluation(MapFromEntries(as1), Map("a" -> null, "b" -> "bb", "c" -> null)) + checkEvaluation(MapFromEntries(as0), create_map("a" -> "aa", "b" -> "bb", "c" -> "bb")) + checkEvaluation(MapFromEntries(as1), create_map("a" -> null, "b" -> "bb", "c" -> null)) checkEvaluation(MapFromEntries(as2), Map.empty) checkEvaluation(MapFromEntries(as3), null) checkEvaluation(MapKeys(MapFromEntries(as4)), Seq("a", "a")) + checkEvaluation(MapFromEntries(as6), null) + + // Map key can't be null checkExceptionInExpression[RuntimeException]( MapFromEntries(as5), "The first field from a struct (key) can't be null.") - checkEvaluation(MapFromEntries(as6), null) + + // map key can't be map + val structOfMap = row(create_map(1 -> 1), 1) + val map = MapFromEntries(Literal.create( + Seq(structOfMap), + arrayType(keyType = MapType(IntegerType, IntegerType), valueType = IntegerType))) + map.checkInputDataTypes() match { + case TypeCheckResult.TypeCheckSuccess => fail("should not allow map as map key") + case TypeCheckResult.TypeCheckFailure(msg) => + assert(msg.contains("The key of map cannot be/contain map")) + } } test("Sort Array") { @@ -1645,6 +1658,19 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper assert(ArrayExcept(a24, a22).dataType.asInstanceOf[ArrayType].containsNull === true) } + test("Array Except - null handling") { + val empty = Literal.create(Seq.empty[Int], ArrayType(IntegerType, containsNull = false)) + val oneNull = Literal.create(Seq(null), ArrayType(IntegerType)) + val twoNulls = Literal.create(Seq(null, null), ArrayType(IntegerType)) + + checkEvaluation(ArrayExcept(oneNull, oneNull), Seq.empty) + checkEvaluation(ArrayExcept(twoNulls, twoNulls), Seq.empty) + checkEvaluation(ArrayExcept(twoNulls, oneNull), Seq.empty) + checkEvaluation(ArrayExcept(empty, oneNull), Seq.empty) + checkEvaluation(ArrayExcept(oneNull, empty), Seq(null)) + checkEvaluation(ArrayExcept(twoNulls, empty), Seq(null)) + } + test("Array Intersect") { val a00 = Literal.create(Seq(1, 2, 4), ArrayType(IntegerType, false)) val a01 = Literal.create(Seq(4, 2), ArrayType(IntegerType, false)) @@ -1756,4 +1782,17 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper assert(ArrayIntersect(a20, a21).dataType.asInstanceOf[ArrayType].containsNull === false) assert(ArrayIntersect(a23, a24).dataType.asInstanceOf[ArrayType].containsNull === true) } + + test("Array Intersect - null handling") { + val empty = Literal.create(Seq.empty[Int], ArrayType(IntegerType, containsNull = false)) + val oneNull = Literal.create(Seq(null), ArrayType(IntegerType)) + val twoNulls = Literal.create(Seq(null, null), ArrayType(IntegerType)) + + checkEvaluation(ArrayIntersect(oneNull, oneNull), Seq(null)) + checkEvaluation(ArrayIntersect(twoNulls, twoNulls), Seq(null)) + checkEvaluation(ArrayIntersect(twoNulls, oneNull), Seq(null)) + checkEvaluation(ArrayIntersect(oneNull, twoNulls), Seq(null)) + checkEvaluation(ArrayIntersect(empty, oneNull), Seq.empty) + checkEvaluation(ArrayIntersect(oneNull, empty), Seq.empty) + } } diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 77aaf55480ec2..d95f42e04e37c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.types._ @@ -158,40 +158,32 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { keys.zip(values).flatMap { case (k, v) => Seq(k, v) } } - def createMap(keys: Seq[Any], values: Seq[Any]): Map[Any, Any] = { - // catalyst map is order-sensitive, so we create ListMap here to preserve the elements order. - scala.collection.immutable.ListMap(keys.zip(values): _*) - } - val intSeq = Seq(5, 10, 15, 20, 25) val longSeq = intSeq.map(_.toLong) val strSeq = intSeq.map(_.toString) + checkEvaluation(CreateMap(Nil), Map.empty) checkEvaluation( CreateMap(interlace(intSeq.map(Literal(_)), longSeq.map(Literal(_)))), - createMap(intSeq, longSeq)) + create_map(intSeq, longSeq)) checkEvaluation( CreateMap(interlace(strSeq.map(Literal(_)), longSeq.map(Literal(_)))), - createMap(strSeq, longSeq)) + create_map(strSeq, longSeq)) checkEvaluation( CreateMap(interlace(longSeq.map(Literal(_)), strSeq.map(Literal(_)))), - createMap(longSeq, strSeq)) + create_map(longSeq, strSeq)) val strWithNull = strSeq.drop(1).map(Literal(_)) :+ Literal.create(null, StringType) checkEvaluation( CreateMap(interlace(intSeq.map(Literal(_)), strWithNull)), - createMap(intSeq, strWithNull.map(_.value))) - intercept[RuntimeException] { - checkEvaluationWithoutCodegen( - CreateMap(interlace(strWithNull, intSeq.map(Literal(_)))), - null, null) - } - intercept[RuntimeException] { - checkEvaluationWithUnsafeProjection( - CreateMap(interlace(strWithNull, intSeq.map(Literal(_)))), - null, null) - } + create_map(intSeq, strWithNull.map(_.value))) + // Map key can't be null + checkExceptionInExpression[RuntimeException]( + CreateMap(interlace(strWithNull, intSeq.map(Literal(_)))), + "Cannot use null as map key") + + // ArrayType map key and value val map = CreateMap(Seq( Literal.create(intSeq, ArrayType(IntegerType, containsNull = false)), Literal.create(strSeq, ArrayType(StringType, containsNull = false)), @@ -202,15 +194,21 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { ArrayType(IntegerType, containsNull = true), ArrayType(StringType, containsNull = true), valueContainsNull = false)) - checkEvaluation(map, createMap(Seq(intSeq, intSeq :+ null), Seq(strSeq, strSeq :+ null))) + checkEvaluation(map, create_map(intSeq -> strSeq, (intSeq :+ null) -> (strSeq :+ null))) + + // map key can't be map + val map2 = CreateMap(Seq( + Literal.create(create_map(1 -> 1), MapType(IntegerType, IntegerType)), + Literal(1) + )) + map2.checkInputDataTypes() match { + case TypeCheckResult.TypeCheckSuccess => fail("should not allow map as map key") + case TypeCheckResult.TypeCheckFailure(msg) => + assert(msg.contains("The key of map cannot be/contain map")) + } } 
test("MapFromArrays") { - def createMap(keys: Seq[Any], values: Seq[Any]): Map[Any, Any] = { - // catalyst map is order-sensitive, so we create ListMap here to preserve the elements order. - scala.collection.immutable.ListMap(keys.zip(values): _*) - } - val intSeq = Seq(5, 10, 15, 20, 25) val longSeq = intSeq.map(_.toLong) val strSeq = intSeq.map(_.toString) @@ -228,24 +226,33 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { val nullArray = Literal.create(null, ArrayType(StringType, false)) - checkEvaluation(MapFromArrays(intArray, longArray), createMap(intSeq, longSeq)) - checkEvaluation(MapFromArrays(intArray, strArray), createMap(intSeq, strSeq)) - checkEvaluation(MapFromArrays(integerArray, strArray), createMap(integerSeq, strSeq)) + checkEvaluation(MapFromArrays(intArray, longArray), create_map(intSeq, longSeq)) + checkEvaluation(MapFromArrays(intArray, strArray), create_map(intSeq, strSeq)) + checkEvaluation(MapFromArrays(integerArray, strArray), create_map(integerSeq, strSeq)) checkEvaluation( - MapFromArrays(strArray, intWithNullArray), createMap(strSeq, intWithNullSeq)) + MapFromArrays(strArray, intWithNullArray), create_map(strSeq, intWithNullSeq)) checkEvaluation( - MapFromArrays(strArray, longWithNullArray), createMap(strSeq, longWithNullSeq)) + MapFromArrays(strArray, longWithNullArray), create_map(strSeq, longWithNullSeq)) checkEvaluation( - MapFromArrays(strArray, longWithNullArray), createMap(strSeq, longWithNullSeq)) + MapFromArrays(strArray, longWithNullArray), create_map(strSeq, longWithNullSeq)) checkEvaluation(MapFromArrays(nullArray, nullArray), null) - intercept[RuntimeException] { - checkEvaluation(MapFromArrays(intWithNullArray, strArray), null) - } - intercept[RuntimeException] { - checkEvaluation( - MapFromArrays(intArray, Literal.create(Seq(1), ArrayType(IntegerType))), null) + // Map key can't be null + checkExceptionInExpression[RuntimeException]( + MapFromArrays(intWithNullArray, strArray), + "Cannot use null as map key") + + // map key can't be map + val arrayOfMap = Seq(create_map(1 -> "a", 2 -> "b")) + val map = MapFromArrays( + Literal.create(arrayOfMap, ArrayType(MapType(IntegerType, StringType))), + Literal.create(Seq(1), ArrayType(IntegerType)) + ) + map.checkInputDataTypes() match { + case TypeCheckResult.TypeCheckSuccess => fail("should not allow map as map key") + case TypeCheckResult.TypeCheckFailure(msg) => + assert(msg.contains("The key of map cannot be/contain map")) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala index d006197bd5678..98c93a4946f4f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala @@ -17,11 +17,13 @@ package org.apache.spark.sql.catalyst.expressions -import java.util.Calendar +import java.text.SimpleDateFormat +import java.util.{Calendar, Locale} import org.scalatest.exceptions.TestFailedException import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.plans.PlanTestBase import org.apache.spark.sql.catalyst.util._ @@ -209,4 +211,30 @@ class CsvExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with P "2015-12-31T16:00:00" ) } + + test("parse date with 
locale") { + Seq("en-US", "ru-RU").foreach { langTag => + val locale = Locale.forLanguageTag(langTag) + val date = new SimpleDateFormat("yyyy-MM-dd").parse("2018-11-05") + val schema = new StructType().add("d", DateType) + val dateFormat = "MMM yyyy" + val sdf = new SimpleDateFormat(dateFormat, locale) + val dateStr = sdf.format(date) + val options = Map("dateFormat" -> dateFormat, "locale" -> langTag) + + checkEvaluation( + CsvToStructs(schema, options, Literal.create(dateStr), gmtId), + InternalRow(17836)) // number of days from 1970-01-01 + } + } + + test("verify corrupt column") { + checkExceptionInExpression[AnalysisException]( + CsvToStructs( + schema = StructType.fromDDL("i int, _unparsed boolean"), + options = Map("columnNameOfCorruptRecord" -> "_unparsed"), + child = Literal.create("a"), + timeZoneId = gmtId), + expectedErrMsg = "The field for corrupt records must be string type and nullable") + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index da18475276a13..eb33325d0b31a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer import org.apache.spark.sql.catalyst.plans.PlanTestBase import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} -import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, MapData} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -48,6 +48,25 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa InternalRow.fromSeq(values.map(CatalystTypeConverters.convertToCatalyst)) } + // Currently MapData just stores the key and value arrays. Its equality is not well implemented, + // as the order of the map entries should not matter for equality. This method creates MapData + // with the entries ordering preserved, so that we can deterministically test expressions with + // map input/output. 
+ protected def create_map(entries: (_, _)*): ArrayBasedMapData = { + create_map(entries.map(_._1), entries.map(_._2)) + } + + protected def create_map(keys: Seq[_], values: Seq[_]): ArrayBasedMapData = { + assert(keys.length == values.length) + val keyArray = CatalystTypeConverters + .convertToCatalyst(keys) + .asInstanceOf[ArrayData] + val valueArray = CatalystTypeConverters + .convertToCatalyst(values) + .asInstanceOf[ArrayData] + new ArrayBasedMapData(keyArray, valueArray) + } + private def prepareEvaluation(expression: Expression): Expression = { val serializer = new JavaSerializer(new SparkConf()).newInstance val resolver = ResolveTimeZone(new SQLConf) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala index e13f4d98295be..66bf18af95799 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.util.ArrayBasedMapData import org.apache.spark.sql.types._ @@ -310,13 +311,13 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper test("TransformKeys") { val ai0 = Literal.create( - Map(1 -> 1, 2 -> 2, 3 -> 3, 4 -> 4), + create_map(1 -> 1, 2 -> 2, 3 -> 3, 4 -> 4), MapType(IntegerType, IntegerType, valueContainsNull = false)) val ai1 = Literal.create( Map.empty[Int, Int], MapType(IntegerType, IntegerType, valueContainsNull = true)) val ai2 = Literal.create( - Map(1 -> 1, 2 -> null, 3 -> 3), + create_map(1 -> 1, 2 -> null, 3 -> 3), MapType(IntegerType, IntegerType, valueContainsNull = true)) val ai3 = Literal.create(null, MapType(IntegerType, IntegerType, valueContainsNull = false)) @@ -324,26 +325,27 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper val plusValue: (Expression, Expression) => Expression = (k, v) => k + v val modKey: (Expression, Expression) => Expression = (k, v) => k % 3 - checkEvaluation(transformKeys(ai0, plusOne), Map(2 -> 1, 3 -> 2, 4 -> 3, 5 -> 4)) - checkEvaluation(transformKeys(ai0, plusValue), Map(2 -> 1, 4 -> 2, 6 -> 3, 8 -> 4)) + checkEvaluation(transformKeys(ai0, plusOne), create_map(2 -> 1, 3 -> 2, 4 -> 3, 5 -> 4)) + checkEvaluation(transformKeys(ai0, plusValue), create_map(2 -> 1, 4 -> 2, 6 -> 3, 8 -> 4)) checkEvaluation( - transformKeys(transformKeys(ai0, plusOne), plusValue), Map(3 -> 1, 5 -> 2, 7 -> 3, 9 -> 4)) + transformKeys(transformKeys(ai0, plusOne), plusValue), + create_map(3 -> 1, 5 -> 2, 7 -> 3, 9 -> 4)) checkEvaluation(transformKeys(ai0, modKey), ArrayBasedMapData(Array(1, 2, 0, 1), Array(1, 2, 3, 4))) checkEvaluation(transformKeys(ai1, plusOne), Map.empty[Int, Int]) checkEvaluation(transformKeys(ai1, plusOne), Map.empty[Int, Int]) checkEvaluation( transformKeys(transformKeys(ai1, plusOne), plusValue), Map.empty[Int, Int]) - checkEvaluation(transformKeys(ai2, plusOne), Map(2 -> 1, 3 -> null, 4 -> 3)) + checkEvaluation(transformKeys(ai2, plusOne), create_map(2 -> 1, 3 -> null, 4 -> 3)) checkEvaluation( - transformKeys(transformKeys(ai2, plusOne), plusOne), Map(3 -> 1, 4 -> null, 5 -> 3)) + transformKeys(transformKeys(ai2, plusOne), plusOne), create_map(3 -> 1, 4 
-> null, 5 -> 3)) checkEvaluation(transformKeys(ai3, plusOne), null) val as0 = Literal.create( - Map("a" -> "xy", "bb" -> "yz", "ccc" -> "zx"), + create_map("a" -> "xy", "bb" -> "yz", "ccc" -> "zx"), MapType(StringType, StringType, valueContainsNull = false)) val as1 = Literal.create( - Map("a" -> "xy", "bb" -> "yz", "ccc" -> null), + create_map("a" -> "xy", "bb" -> "yz", "ccc" -> null), MapType(StringType, StringType, valueContainsNull = true)) val as2 = Literal.create(null, MapType(StringType, StringType, valueContainsNull = false)) @@ -355,26 +357,35 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper (k, v) => Length(k) + 1 checkEvaluation( - transformKeys(as0, concatValue), Map("axy" -> "xy", "bbyz" -> "yz", "ccczx" -> "zx")) + transformKeys(as0, concatValue), create_map("axy" -> "xy", "bbyz" -> "yz", "ccczx" -> "zx")) checkEvaluation( transformKeys(transformKeys(as0, concatValue), concatValue), - Map("axyxy" -> "xy", "bbyzyz" -> "yz", "ccczxzx" -> "zx")) + create_map("axyxy" -> "xy", "bbyzyz" -> "yz", "ccczxzx" -> "zx")) checkEvaluation(transformKeys(as3, concatValue), Map.empty[String, String]) checkEvaluation( transformKeys(transformKeys(as3, concatValue), convertKeyToKeyLength), Map.empty[Int, String]) checkEvaluation(transformKeys(as0, convertKeyToKeyLength), - Map(2 -> "xy", 3 -> "yz", 4 -> "zx")) + create_map(2 -> "xy", 3 -> "yz", 4 -> "zx")) checkEvaluation(transformKeys(as1, convertKeyToKeyLength), - Map(2 -> "xy", 3 -> "yz", 4 -> null)) + create_map(2 -> "xy", 3 -> "yz", 4 -> null)) checkEvaluation(transformKeys(as2, convertKeyToKeyLength), null) checkEvaluation(transformKeys(as3, convertKeyToKeyLength), Map.empty[Int, String]) val ax0 = Literal.create( - Map(1 -> "x", 2 -> "y", 3 -> "z"), + create_map(1 -> "x", 2 -> "y", 3 -> "z"), MapType(IntegerType, StringType, valueContainsNull = false)) - checkEvaluation(transformKeys(ax0, plusOne), Map(2 -> "x", 3 -> "y", 4 -> "z")) + checkEvaluation(transformKeys(ax0, plusOne), create_map(2 -> "x", 3 -> "y", 4 -> "z")) + + // map key can't be map + val makeMap: (Expression, Expression) => Expression = (k, v) => CreateMap(Seq(k, v)) + val map = transformKeys(ai0, makeMap) + map.checkInputDataTypes() match { + case TypeCheckResult.TypeCheckSuccess => fail("should not allow map as map key") + case TypeCheckResult.TypeCheckFailure(msg) => + assert(msg.contains("The key of map cannot be/contain map")) + } } test("TransformValues") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index 304642161146b..9b89a27c23770 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -17,11 +17,13 @@ package org.apache.spark.sql.catalyst.expressions -import java.util.Calendar +import java.text.SimpleDateFormat +import java.util.{Calendar, Locale} import org.scalatest.exceptions.TestFailedException import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.plans.PlanTestBase @@ -546,7 +548,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with val schema = StructType(StructField("a", 
IntegerType) :: Nil) checkEvaluation( JsonToStructs(schema, Map.empty, Literal.create(" ", StringType), gmtId), - null + InternalRow(null) ) } @@ -737,4 +739,30 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with CreateMap(Seq(Literal.create("allowNumericLeadingZeros"), Literal.create("true")))), "struct") } + + test("parse date with locale") { + Seq("en-US", "ru-RU").foreach { langTag => + val locale = Locale.forLanguageTag(langTag) + val date = new SimpleDateFormat("yyyy-MM-dd").parse("2018-11-05") + val schema = new StructType().add("d", DateType) + val dateFormat = "MMM yyyy" + val sdf = new SimpleDateFormat(dateFormat, locale) + val dateStr = s"""{"d":"${sdf.format(date)}"}""" + val options = Map("dateFormat" -> dateFormat, "locale" -> langTag) + + checkEvaluation( + JsonToStructs(schema, options, Literal.create(dateStr), gmtId), + InternalRow(17836)) // number of days from 1970-01-01 + } + } + + test("verify corrupt column") { + checkExceptionInExpression[AnalysisException]( + JsonToStructs( + schema = StructType.fromDDL("i int, _unparsed boolean"), + options = Map("columnNameOfCorruptRecord" -> "_unparsed"), + child = Literal.create("""{"i":"a"}"""), + timeZoneId = gmtId), + expectedErrMsg = "The field for corrupt records must be string type and nullable") + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index d145fd0aaba47..436675bf50353 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, JavaTypeInference, ScalaReflection} +import org.apache.spark.sql.catalyst.ScroogeLikeExample import org.apache.spark.sql.catalyst.analysis.{ResolveTimeZone, SimpleAnalyzer, UnresolvedDeserializer} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.encoders._ @@ -307,7 +308,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val conf = new SparkConf() Seq(true, false).foreach { useKryo => val serializer = if (useKryo) new KryoSerializer(conf) else new JavaSerializer(conf) - val expected = serializer.newInstance().serialize(new Integer(1)).array() + val expected = serializer.newInstance().serialize(Integer.valueOf(1)).array() val encodeUsingSerializer = EncodeUsingSerializer(inputObject, useKryo) checkEvaluation(encodeUsingSerializer, expected, InternalRow.fromSeq(Seq(1))) checkEvaluation(encodeUsingSerializer, null, InternalRow.fromSeq(Seq(null))) @@ -384,9 +385,9 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val conf = new SparkConf() Seq(true, false).foreach { useKryo => val serializer = if (useKryo) new KryoSerializer(conf) else new JavaSerializer(conf) - val input = serializer.newInstance().serialize(new Integer(1)).array() + val input = serializer.newInstance().serialize(Integer.valueOf(1)).array() val decodeUsingSerializer = DecodeUsingSerializer(inputObject, ClassTag(cls), useKryo) - checkEvaluation(decodeUsingSerializer, new Integer(1), 
InternalRow.fromSeq(Seq(input))) + checkEvaluation(decodeUsingSerializer, Integer.valueOf(1), InternalRow.fromSeq(Seq(input))) checkEvaluation(decodeUsingSerializer, null, InternalRow.fromSeq(Seq(null))) } } @@ -410,6 +411,15 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { dataType = ObjectType(classOf[outerObj.Inner]), outerPointer = Some(() => outerObj)) checkObjectExprEvaluation(newInst2, new outerObj.Inner(1)) + + // SPARK-8288: A class with only a companion object constructor + val newInst3 = NewInstance( + cls = classOf[ScroogeLikeExample], + arguments = Literal(1) :: Nil, + propagateNull = false, + dataType = ObjectType(classOf[ScroogeLikeExample]), + outerPointer = Some(() => outerObj)) + checkObjectExprEvaluation(newInst3, ScroogeLikeExample(1)) } test("LambdaVariable should support interpreted execution") { @@ -575,7 +585,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { // NULL key test val scalaMapHasNullKey = scala.collection.Map[java.lang.Integer, String]( - null.asInstanceOf[java.lang.Integer] -> "v0", new java.lang.Integer(1) -> "v1") + null.asInstanceOf[java.lang.Integer] -> "v0", java.lang.Integer.valueOf(1) -> "v1") val javaMapHasNullKey = new java.util.HashMap[java.lang.Integer, java.lang.String]() { { put(null, "v0") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AggregateExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AggregateExpressionSuite.scala new file mode 100644 index 0000000000000..8e9c9972071ad --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AggregateExpressionSuite.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.{Add, AttributeSet} + +class AggregateExpressionSuite extends SparkFunSuite { + + test("test references from unresolved aggregate functions") { + val x = UnresolvedAttribute("x") + val y = UnresolvedAttribute("y") + val actual = AggregateExpression(Sum(Add(x, y)), mode = Complete, isDistinct = false).references + val expected = AttributeSet(x :: y :: Nil) + assert(expected == actual, s"Expected: $expected. 
Actual: $actual") + } + +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala index 294fce8e9a10f..63c7b42978025 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala @@ -41,9 +41,9 @@ class PercentileSuite extends SparkFunSuite { val buffer = new OpenHashMap[AnyRef, Long]() assert(compareEquals(agg.deserialize(agg.serialize(buffer)), buffer)) - // Check non-empty buffer serializa and deserialize. + // Check non-empty buffer serialize and deserialize. data.foreach { key => - buffer.changeValue(new Integer(key), 1L, _ + 1L) + buffer.changeValue(Integer.valueOf(key), 1L, _ + 1L) } assert(compareEquals(agg.deserialize(agg.serialize(buffer)), buffer)) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutPythonUDFInJoinConditionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutPythonUDFInJoinConditionSuite.scala new file mode 100644 index 0000000000000..3f1c91df7f2e9 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutPythonUDFInJoinConditionSuite.scala @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.api.python.PythonEvalType +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.PythonUDF +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.internal.SQLConf._ +import org.apache.spark.sql.types.{BooleanType, IntegerType} + +class PullOutPythonUDFInJoinConditionSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Extract PythonUDF From JoinCondition", Once, + PullOutPythonUDFInJoinCondition) :: + Batch("Check Cartesian Products", Once, + CheckCartesianProducts) :: Nil + } + + val attrA = 'a.int + val attrB = 'b.int + val attrC = 'c.int + val attrD = 'd.int + + val testRelationLeft = LocalRelation(attrA, attrB) + val testRelationRight = LocalRelation(attrC, attrD) + + // This join condition refers to attributes from 2 tables, but the PythonUDF inside it only + // refer to attributes from one side. 
+ val evaluableJoinCond = { + val pythonUDF = PythonUDF("evaluable", null, + IntegerType, + Seq(attrA), + PythonEvalType.SQL_BATCHED_UDF, + udfDeterministic = true) + pythonUDF === attrC + } + + // This join condition is a PythonUDF which refers to attributes from 2 tables. + val unevaluableJoinCond = PythonUDF("unevaluable", null, + BooleanType, + Seq(attrA, attrC), + PythonEvalType.SQL_BATCHED_UDF, + udfDeterministic = true) + + val unsupportedJoinTypes = Seq(LeftOuter, RightOuter, FullOuter, LeftAnti) + + private def comparePlanWithCrossJoinEnable(query: LogicalPlan, expected: LogicalPlan): Unit = { + // AnalysisException thrown by CheckCartesianProducts while spark.sql.crossJoin.enabled=false + val exception = intercept[AnalysisException] { + Optimize.execute(query.analyze) + } + assert(exception.message.startsWith("Detected implicit cartesian product")) + + // pull out the python udf while set spark.sql.crossJoin.enabled=true + withSQLConf(CROSS_JOINS_ENABLED.key -> "true") { + val optimized = Optimize.execute(query.analyze) + comparePlans(optimized, expected) + } + } + + test("inner join condition with python udf") { + val query1 = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = Some(unevaluableJoinCond)) + val expected1 = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = None).where(unevaluableJoinCond).analyze + comparePlanWithCrossJoinEnable(query1, expected1) + + // evaluable PythonUDF will not be touched + val query2 = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = Some(evaluableJoinCond)) + comparePlans(Optimize.execute(query2), query2) + } + + test("left semi join condition with python udf") { + val query1 = testRelationLeft.join( + testRelationRight, + joinType = LeftSemi, + condition = Some(unevaluableJoinCond)) + val expected1 = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = None).where(unevaluableJoinCond).select('a, 'b).analyze + comparePlanWithCrossJoinEnable(query1, expected1) + + // evaluable PythonUDF will not be touched + val query2 = testRelationLeft.join( + testRelationRight, + joinType = LeftSemi, + condition = Some(evaluableJoinCond)) + comparePlans(Optimize.execute(query2), query2) + } + + test("unevaluable python udf and common condition") { + val query = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = Some(unevaluableJoinCond && 'a.attr === 'c.attr)) + val expected = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = Some('a.attr === 'c.attr)).where(unevaluableJoinCond).analyze + val optimized = Optimize.execute(query.analyze) + comparePlans(optimized, expected) + } + + test("unevaluable python udf or common condition") { + val query = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = Some(unevaluableJoinCond || 'a.attr === 'c.attr)) + val expected = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = None).where(unevaluableJoinCond || 'a.attr === 'c.attr).analyze + comparePlanWithCrossJoinEnable(query, expected) + } + + test("pull out whole complex condition with multiple unevaluable python udf") { + val pythonUDF1 = PythonUDF("pythonUDF1", null, + BooleanType, + Seq(attrA, attrC), + PythonEvalType.SQL_BATCHED_UDF, + udfDeterministic = true) + val condition = (unevaluableJoinCond || 'a.attr === 'c.attr) && pythonUDF1 + + val query = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = Some(condition)) + val 
expected = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = None).where(condition).analyze + comparePlanWithCrossJoinEnable(query, expected) + } + + test("partial pull out complex condition with multiple unevaluable python udf") { + val pythonUDF1 = PythonUDF("pythonUDF1", null, + BooleanType, + Seq(attrA, attrC), + PythonEvalType.SQL_BATCHED_UDF, + udfDeterministic = true) + val condition = (unevaluableJoinCond || pythonUDF1) && 'a.attr === 'c.attr + + val query = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = Some(condition)) + val expected = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = Some('a.attr === 'c.attr)).where(unevaluableJoinCond || pythonUDF1).analyze + val optimized = Optimize.execute(query.analyze) + comparePlans(optimized, expected) + } + + test("pull out unevaluable python udf when it's mixed with evaluable one") { + val query = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = Some(evaluableJoinCond && unevaluableJoinCond)) + val expected = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = Some(evaluableJoinCond)).where(unevaluableJoinCond).analyze + val optimized = Optimize.execute(query.analyze) + comparePlans(optimized, expected) + } + + test("throw an exception for not support join type") { + for (joinType <- unsupportedJoinTypes) { + val e = intercept[AnalysisException] { + val query = testRelationLeft.join( + testRelationRight, + joinType, + condition = Some(unevaluableJoinCond)) + Optimize.execute(query.analyze) + } + assert(e.message.contentEquals( + s"Using PythonUDF in join condition of join type $joinType is not supported.")) + + val query2 = testRelationLeft.join( + testRelationRight, + joinType, + condition = Some(evaluableJoinCond)) + comparePlans(Optimize.execute(query2), query2) + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala similarity index 84% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala index c6b5d0ec96776..ee0d04da3e46c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala @@ -20,14 +20,14 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{And, CaseWhen, Expression, GreaterThan, If, Literal, Or} +import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, Expression, GreaterThan, If, LambdaFunction, Literal, MapFilter, NamedExpression, Or} import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.types.{BooleanType, IntegerType} -class 
ReplaceNullWithFalseSuite extends PlanTest { +class ReplaceNullWithFalseInPredicateSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = @@ -36,15 +36,23 @@ class ReplaceNullWithFalseSuite extends PlanTest { ConstantFolding, BooleanSimplification, SimplifyConditionals, - ReplaceNullWithFalse) :: Nil + ReplaceNullWithFalseInPredicate) :: Nil } - private val testRelation = LocalRelation('i.int, 'b.boolean) + private val testRelation = + LocalRelation('i.int, 'b.boolean, 'a.array(IntegerType), 'm.map(IntegerType, IntegerType)) private val anotherTestRelation = LocalRelation('d.int) test("replace null inside filter and join conditions") { - testFilter(originalCond = Literal(null), expectedCond = FalseLiteral) - testJoin(originalCond = Literal(null), expectedCond = FalseLiteral) + testFilter(originalCond = Literal(null, BooleanType), expectedCond = FalseLiteral) + testJoin(originalCond = Literal(null, BooleanType), expectedCond = FalseLiteral) + } + + test("Not expected type - replaceNullWithFalse") { + val e = intercept[IllegalArgumentException] { + testFilter(originalCond = Literal(null, IntegerType), expectedCond = FalseLiteral) + }.getMessage + assert(e.contains("but got the type `int` in `CAST(NULL AS INT)")) } test("replace null in branches of If") { @@ -298,6 +306,26 @@ class ReplaceNullWithFalseSuite extends PlanTest { testProjection(originalExpr = column, expectedExpr = column) } + test("replace nulls in lambda function of ArrayFilter") { + testHigherOrderFunc('a, ArrayFilter, Seq('e)) + } + + test("replace nulls in lambda function of ArrayExists") { + testHigherOrderFunc('a, ArrayExists, Seq('e)) + } + + test("replace nulls in lambda function of MapFilter") { + testHigherOrderFunc('m, MapFilter, Seq('k, 'v)) + } + + test("inability to replace nulls in arbitrary higher-order function") { + val lambdaFunc = LambdaFunction( + function = If('e > 0, Literal(null, BooleanType), TrueLiteral), + arguments = Seq[NamedExpression]('e)) + val column = ArrayTransform('a, lambdaFunc) + testProjection(originalExpr = column, expectedExpr = column) + } + private def testFilter(originalCond: Expression, expectedCond: Expression): Unit = { test((rel, exp) => rel.where(exp), originalCond, expectedCond) } @@ -310,6 +338,25 @@ class ReplaceNullWithFalseSuite extends PlanTest { test((rel, exp) => rel.select(exp), originalExpr, expectedExpr) } + private def testHigherOrderFunc( + argument: Expression, + createExpr: (Expression, Expression) => Expression, + lambdaArgs: Seq[NamedExpression]): Unit = { + val condArg = lambdaArgs.last + // the lambda body is: if(arg > 0, null, true) + val cond = GreaterThan(condArg, Literal(0)) + val lambda1 = LambdaFunction( + function = If(cond, Literal(null, BooleanType), TrueLiteral), + arguments = lambdaArgs) + // the optimized lambda body is: if(arg > 0, false, true) + val lambda2 = LambdaFunction( + function = If(cond, FalseLiteral, TrueLiteral), + arguments = lambdaArgs) + testProjection( + originalExpr = createExpr(argument, lambda1) as 'x, + expectedExpr = createExpr(argument, lambda2) as 'x) + } + private def test( func: (LogicalPlan, Expression) => LogicalPlan, originalExpr: Expression, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala index da3923f8d6477..17e00c9a3ead2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{Alias, And, GreaterThan, GreaterThanOrEqual, If, Literal, ReplicateRows} +import org.apache.spark.sql.catalyst.expressions.{And, GreaterThan, GreaterThanOrEqual, If, Literal, Rand, ReplicateRows} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -196,4 +196,31 @@ class SetOperationSuite extends PlanTest { )) comparePlans(expectedPlan, rewrittenPlan) } + + test("SPARK-23356 union: expressions with literal in project list are pushed down") { + val unionQuery = testUnion.select(('a + 1).as("aa")) + val unionOptimized = Optimize.execute(unionQuery.analyze) + val unionCorrectAnswer = + Union(testRelation.select(('a + 1).as("aa")) :: + testRelation2.select(('d + 1).as("aa")) :: + testRelation3.select(('g + 1).as("aa")) :: Nil).analyze + comparePlans(unionOptimized, unionCorrectAnswer) + } + + test("SPARK-23356 union: expressions in project list are pushed down") { + val unionQuery = testUnion.select(('a + 'b).as("ab")) + val unionOptimized = Optimize.execute(unionQuery.analyze) + val unionCorrectAnswer = + Union(testRelation.select(('a + 'b).as("ab")) :: + testRelation2.select(('d + 'e).as("ab")) :: + testRelation3.select(('g + 'h).as("ab")) :: Nil).analyze + comparePlans(unionOptimized, unionCorrectAnswer) + } + + test("SPARK-23356 union: no pushdown for non-deterministic expression") { + val unionQuery = testUnion.select('a, Rand(10).as("rnd")) + val unionOptimized = Optimize.execute(unionQuery.analyze) + val unionCorrectAnswer = unionQuery.analyze + comparePlans(unionOptimized, unionCorrectAnswer) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala index 10de90c6a44ca..8abd7625c21aa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala @@ -228,4 +228,15 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { val decimal = Decimal.apply(bigInt) assert(decimal.toJavaBigDecimal.unscaledValue.toString === "9223372036854775808") } + + test("SPARK-26038: toScalaBigInt/toJavaBigInteger") { + // not fitting long + val decimal = Decimal("1234568790123456789012348790.1234879012345678901234568790") + assert(decimal.toScalaBigInt == scala.math.BigInt("1234568790123456789012348790")) + assert(decimal.toJavaBigInteger == new java.math.BigInteger("1234568790123456789012348790")) + // fitting long + val decimalLong = Decimal(123456789123456789L, 18, 9) + assert(decimalLong.toScalaBigInt == scala.math.BigInt("123456789")) + assert(decimalLong.toJavaBigInteger == new java.math.BigInteger("123456789")) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/UtilSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/UtilSuite.scala new file mode 100644 index 0000000000000..9c162026942f6 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/UtilSuite.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or 
more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.util + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.util.truncatedString + +class UtilSuite extends SparkFunSuite { + test("truncatedString") { + assert(truncatedString(Nil, "[", ", ", "]", 2) == "[]") + assert(truncatedString(Seq(1, 2), "[", ", ", "]", 2) == "[1, 2]") + assert(truncatedString(Seq(1, 2, 3), "[", ", ", "]", 2) == "[1, ... 2 more fields]") + assert(truncatedString(Seq(1, 2, 3), "[", ", ", "]", -5) == "[, ... 3 more fields]") + assert(truncatedString(Seq(1, 2, 3), ", ") == "1, 2, 3") + } +} diff --git a/sql/core/benchmarks/DataSourceReadBenchmark-results.txt b/sql/core/benchmarks/DataSourceReadBenchmark-results.txt index 2d3bae442cc50..b07e8b1197ff0 100644 --- a/sql/core/benchmarks/DataSourceReadBenchmark-results.txt +++ b/sql/core/benchmarks/DataSourceReadBenchmark-results.txt @@ -2,268 +2,268 @@ SQL Single Numeric Column Scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz SQL Single TINYINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -SQL CSV 21508 / 22112 0.7 1367.5 1.0X -SQL Json 8705 / 8825 1.8 553.4 2.5X -SQL Parquet Vectorized 157 / 186 100.0 10.0 136.7X -SQL Parquet MR 1789 / 1794 8.8 113.8 12.0X -SQL ORC Vectorized 156 / 166 100.9 9.9 138.0X -SQL ORC Vectorized with copy 218 / 225 72.1 13.9 98.6X -SQL ORC MR 1448 / 1492 10.9 92.0 14.9X - -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +SQL CSV 26366 / 26562 0.6 1676.3 1.0X +SQL Json 8709 / 8724 1.8 553.7 3.0X +SQL Parquet Vectorized 166 / 187 94.8 10.5 159.0X +SQL Parquet MR 1706 / 1720 9.2 108.4 15.5X +SQL ORC Vectorized 167 / 174 94.2 10.6 157.9X +SQL ORC Vectorized with copy 226 / 231 69.6 14.4 116.7X +SQL ORC MR 1433 / 1465 11.0 91.1 18.4X + +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Parquet Reader Single TINYINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -ParquetReader Vectorized 202 / 211 77.7 12.9 1.0X -ParquetReader Vectorized -> Row 118 / 120 133.5 7.5 1.7X +ParquetReader Vectorized 200 / 207 78.7 12.7 1.0X +ParquetReader Vectorized -> Row 117 / 119 134.7 7.4 1.7X -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz 
SQL Single SMALLINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -SQL CSV 23282 / 23312 0.7 1480.2 1.0X -SQL Json 9187 / 9189 1.7 584.1 2.5X -SQL Parquet Vectorized 204 / 218 77.0 13.0 114.0X -SQL Parquet MR 1941 / 1953 8.1 123.4 12.0X -SQL ORC Vectorized 217 / 225 72.6 13.8 107.5X -SQL ORC Vectorized with copy 279 / 289 56.3 17.8 83.4X -SQL ORC MR 1541 / 1549 10.2 98.0 15.1X - -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +SQL CSV 26489 / 26547 0.6 1684.1 1.0X +SQL Json 8990 / 8998 1.7 571.5 2.9X +SQL Parquet Vectorized 209 / 221 75.1 13.3 126.5X +SQL Parquet MR 1949 / 1949 8.1 123.9 13.6X +SQL ORC Vectorized 221 / 228 71.3 14.0 120.1X +SQL ORC Vectorized with copy 315 / 319 49.9 20.1 84.0X +SQL ORC MR 1527 / 1549 10.3 97.1 17.3X + +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Parquet Reader Single SMALLINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -ParquetReader Vectorized 288 / 297 54.6 18.3 1.0X -ParquetReader Vectorized -> Row 255 / 257 61.7 16.2 1.1X +ParquetReader Vectorized 286 / 296 54.9 18.2 1.0X +ParquetReader Vectorized -> Row 249 / 253 63.1 15.8 1.1X -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz SQL Single INT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -SQL CSV 24990 / 25012 0.6 1588.8 1.0X -SQL Json 9837 / 9865 1.6 625.4 2.5X -SQL Parquet Vectorized 170 / 180 92.3 10.8 146.6X -SQL Parquet MR 2319 / 2328 6.8 147.4 10.8X -SQL ORC Vectorized 293 / 301 53.7 18.6 85.3X -SQL ORC Vectorized with copy 297 / 309 52.9 18.9 84.0X -SQL ORC MR 1667 / 1674 9.4 106.0 15.0X - -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +SQL CSV 27701 / 27744 0.6 1761.2 1.0X +SQL Json 9703 / 9733 1.6 616.9 2.9X +SQL Parquet Vectorized 176 / 182 89.2 11.2 157.0X +SQL Parquet MR 2164 / 2173 7.3 137.6 12.8X +SQL ORC Vectorized 307 / 314 51.2 19.5 90.2X +SQL ORC Vectorized with copy 312 / 319 50.4 19.8 88.7X +SQL ORC MR 1690 / 1700 9.3 107.4 16.4X + +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Parquet Reader Single INT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -ParquetReader Vectorized 257 / 274 61.3 16.3 1.0X -ParquetReader Vectorized -> Row 259 / 264 60.8 16.4 1.0X +ParquetReader Vectorized 259 / 277 60.7 16.5 1.0X +ParquetReader Vectorized -> Row 261 / 265 60.3 16.6 1.0X -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz SQL Single BIGINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -SQL CSV 32537 / 32554 0.5 2068.7 1.0X -SQL Json 12610 / 12668 1.2 801.7 2.6X -SQL Parquet Vectorized 258 / 276 61.0 16.4 126.2X -SQL Parquet MR 2422 / 2435 6.5 154.0 13.4X -SQL ORC 
Vectorized 378 / 385 41.6 24.0 86.2X -SQL ORC Vectorized with copy 381 / 389 41.3 24.2 85.4X -SQL ORC MR 1797 / 1819 8.8 114.3 18.1X - -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +SQL CSV 34813 / 34900 0.5 2213.3 1.0X +SQL Json 12570 / 12617 1.3 799.2 2.8X +SQL Parquet Vectorized 270 / 308 58.2 17.2 128.9X +SQL Parquet MR 2427 / 2431 6.5 154.3 14.3X +SQL ORC Vectorized 388 / 398 40.6 24.6 89.8X +SQL ORC Vectorized with copy 395 / 402 39.9 25.1 88.2X +SQL ORC MR 1819 / 1851 8.6 115.7 19.1X + +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Parquet Reader Single BIGINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -ParquetReader Vectorized 352 / 368 44.7 22.4 1.0X -ParquetReader Vectorized -> Row 351 / 359 44.8 22.3 1.0X +ParquetReader Vectorized 372 / 379 42.3 23.7 1.0X +ParquetReader Vectorized -> Row 357 / 368 44.1 22.7 1.0X -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz SQL Single FLOAT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -SQL CSV 27179 / 27184 0.6 1728.0 1.0X -SQL Json 12578 / 12585 1.3 799.7 2.2X -SQL Parquet Vectorized 161 / 171 97.5 10.3 168.5X -SQL Parquet MR 2361 / 2395 6.7 150.1 11.5X -SQL ORC Vectorized 473 / 480 33.3 30.0 57.5X -SQL ORC Vectorized with copy 478 / 483 32.9 30.4 56.8X -SQL ORC MR 1858 / 1859 8.5 118.2 14.6X - -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +SQL CSV 28753 / 28781 0.5 1828.0 1.0X +SQL Json 12039 / 12215 1.3 765.4 2.4X +SQL Parquet Vectorized 170 / 177 92.4 10.8 169.0X +SQL Parquet MR 2184 / 2196 7.2 138.9 13.2X +SQL ORC Vectorized 432 / 440 36.4 27.5 66.5X +SQL ORC Vectorized with copy 439 / 442 35.9 27.9 65.6X +SQL ORC MR 1812 / 1833 8.7 115.2 15.9X + +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Parquet Reader Single FLOAT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -ParquetReader Vectorized 251 / 255 62.7 15.9 1.0X -ParquetReader Vectorized -> Row 255 / 259 61.8 16.2 1.0X +ParquetReader Vectorized 253 / 260 62.2 16.1 1.0X +ParquetReader Vectorized -> Row 256 / 257 61.6 16.2 1.0X -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz SQL Single DOUBLE Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -SQL CSV 34797 / 34830 0.5 2212.3 1.0X -SQL Json 17806 / 17828 0.9 1132.1 2.0X -SQL Parquet Vectorized 260 / 269 60.6 16.5 134.0X -SQL Parquet MR 2512 / 2534 6.3 159.7 13.9X -SQL ORC Vectorized 582 / 593 27.0 37.0 59.8X -SQL ORC Vectorized with copy 576 / 584 27.3 36.6 60.4X -SQL ORC MR 2309 / 2313 6.8 146.8 15.1X - -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +SQL CSV 36177 / 36188 0.4 2300.1 1.0X +SQL Json 18895 / 18898 0.8 1201.3 1.9X +SQL Parquet Vectorized 267 / 276 58.9 17.0 135.6X +SQL Parquet 
MR 2355 / 2363 6.7 149.7 15.4X +SQL ORC Vectorized 543 / 546 29.0 34.5 66.6X +SQL ORC Vectorized with copy 548 / 557 28.7 34.8 66.0X +SQL ORC MR 2246 / 2258 7.0 142.8 16.1X + +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Parquet Reader Single DOUBLE Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -ParquetReader Vectorized 350 / 363 44.9 22.3 1.0X -ParquetReader Vectorized -> Row 350 / 366 44.9 22.3 1.0X +ParquetReader Vectorized 353 / 367 44.6 22.4 1.0X +ParquetReader Vectorized -> Row 351 / 357 44.7 22.3 1.0X ================================================================================================ Int and String Scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Int and String Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -SQL CSV 22486 / 22590 0.5 2144.5 1.0X -SQL Json 14124 / 14195 0.7 1347.0 1.6X -SQL Parquet Vectorized 2342 / 2347 4.5 223.4 9.6X -SQL Parquet MR 4660 / 4664 2.2 444.4 4.8X -SQL ORC Vectorized 2378 / 2379 4.4 226.8 9.5X -SQL ORC Vectorized with copy 2548 / 2571 4.1 243.0 8.8X -SQL ORC MR 4206 / 4211 2.5 401.1 5.3X +SQL CSV 21130 / 21246 0.5 2015.1 1.0X +SQL Json 12145 / 12174 0.9 1158.2 1.7X +SQL Parquet Vectorized 2363 / 2377 4.4 225.3 8.9X +SQL Parquet MR 4555 / 4557 2.3 434.4 4.6X +SQL ORC Vectorized 2361 / 2388 4.4 225.1 9.0X +SQL ORC Vectorized with copy 2540 / 2557 4.1 242.2 8.3X +SQL ORC MR 4186 / 4209 2.5 399.2 5.0X ================================================================================================ Repeated String Scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Repeated String: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -SQL CSV 12150 / 12178 0.9 1158.7 1.0X -SQL Json 7012 / 7014 1.5 668.7 1.7X -SQL Parquet Vectorized 792 / 796 13.2 75.5 15.3X -SQL Parquet MR 1961 / 1975 5.3 187.0 6.2X -SQL ORC Vectorized 482 / 485 21.8 46.0 25.2X -SQL ORC Vectorized with copy 710 / 715 14.8 67.7 17.1X -SQL ORC MR 2081 / 2083 5.0 198.5 5.8X +SQL CSV 11693 / 11729 0.9 1115.1 1.0X +SQL Json 7025 / 7025 1.5 669.9 1.7X +SQL Parquet Vectorized 803 / 821 13.1 76.6 14.6X +SQL Parquet MR 1776 / 1790 5.9 169.4 6.6X +SQL ORC Vectorized 491 / 494 21.4 46.8 23.8X +SQL ORC Vectorized with copy 723 / 725 14.5 68.9 16.2X +SQL ORC MR 2050 / 2063 5.1 195.5 5.7X ================================================================================================ Partitioned Table Scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Partitioned Table: Best/Avg Time(ms) Rate(M/s) Per 
Row(ns) Relative ------------------------------------------------------------------------------------------------ -Data column - CSV 31789 / 31791 0.5 2021.1 1.0X -Data column - Json 12873 / 12918 1.2 818.4 2.5X -Data column - Parquet Vectorized 267 / 280 58.9 17.0 119.1X -Data column - Parquet MR 3387 / 3402 4.6 215.3 9.4X -Data column - ORC Vectorized 391 / 453 40.2 24.9 81.2X -Data column - ORC Vectorized with copy 392 / 398 40.2 24.9 81.2X -Data column - ORC MR 2508 / 2512 6.3 159.4 12.7X -Partition column - CSV 6965 / 6977 2.3 442.8 4.6X -Partition column - Json 5563 / 5576 2.8 353.7 5.7X -Partition column - Parquet Vectorized 65 / 78 241.1 4.1 487.2X -Partition column - Parquet MR 1811 / 1811 8.7 115.1 17.6X -Partition column - ORC Vectorized 66 / 73 239.0 4.2 483.0X -Partition column - ORC Vectorized with copy 65 / 70 241.1 4.1 487.3X -Partition column - ORC MR 1775 / 1778 8.9 112.8 17.9X -Both columns - CSV 30032 / 30113 0.5 1909.4 1.1X -Both columns - Json 13941 / 13959 1.1 886.3 2.3X -Both columns - Parquet Vectorized 312 / 330 50.3 19.9 101.7X -Both columns - Parquet MR 3858 / 3862 4.1 245.3 8.2X -Both columns - ORC Vectorized 431 / 437 36.5 27.4 73.8X -Both column - ORC Vectorized with copy 523 / 529 30.1 33.3 60.7X -Both columns - ORC MR 2712 / 2805 5.8 172.4 11.7X +Data column - CSV 30965 / 31041 0.5 1968.7 1.0X +Data column - Json 12876 / 12882 1.2 818.6 2.4X +Data column - Parquet Vectorized 277 / 282 56.7 17.6 111.6X +Data column - Parquet MR 3398 / 3402 4.6 216.0 9.1X +Data column - ORC Vectorized 399 / 407 39.4 25.4 77.5X +Data column - ORC Vectorized with copy 407 / 447 38.6 25.9 76.0X +Data column - ORC MR 2583 / 2589 6.1 164.2 12.0X +Partition column - CSV 7403 / 7427 2.1 470.7 4.2X +Partition column - Json 5587 / 5625 2.8 355.2 5.5X +Partition column - Parquet Vectorized 71 / 78 222.6 4.5 438.3X +Partition column - Parquet MR 1798 / 1808 8.7 114.3 17.2X +Partition column - ORC Vectorized 72 / 75 219.0 4.6 431.2X +Partition column - ORC Vectorized with copy 71 / 77 221.1 4.5 435.4X +Partition column - ORC MR 1772 / 1778 8.9 112.6 17.5X +Both columns - CSV 30211 / 30212 0.5 1920.7 1.0X +Both columns - Json 13382 / 13391 1.2 850.8 2.3X +Both columns - Parquet Vectorized 321 / 333 49.0 20.4 96.4X +Both columns - Parquet MR 3656 / 3661 4.3 232.4 8.5X +Both columns - ORC Vectorized 443 / 448 35.5 28.2 69.9X +Both column - ORC Vectorized with copy 527 / 533 29.9 33.5 58.8X +Both columns - ORC MR 2626 / 2633 6.0 167.0 11.8X ================================================================================================ String with Nulls Scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz -String with Nulls Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +String with Nulls Scan (0.0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -SQL CSV 13525 / 13823 0.8 1289.9 1.0X -SQL Json 9913 / 9921 1.1 945.3 1.4X -SQL Parquet Vectorized 1517 / 1517 6.9 144.7 8.9X -SQL Parquet MR 3996 / 4008 2.6 381.1 3.4X -ParquetReader Vectorized 1120 / 1128 9.4 106.8 12.1X -SQL ORC Vectorized 1203 / 1224 8.7 114.7 11.2X -SQL ORC Vectorized with copy 1639 / 1646 6.4 156.3 8.3X -SQL ORC MR 3720 / 3780 2.8 354.7 3.6X - -OpenJDK 64-Bit Server VM 1.8.0_181-b13 
on Linux 3.10.0-862.3.2.el7.x86_64 +SQL CSV 13918 / 13979 0.8 1327.3 1.0X +SQL Json 10068 / 10068 1.0 960.1 1.4X +SQL Parquet Vectorized 1563 / 1564 6.7 149.0 8.9X +SQL Parquet MR 3835 / 3836 2.7 365.8 3.6X +ParquetReader Vectorized 1115 / 1118 9.4 106.4 12.5X +SQL ORC Vectorized 1172 / 1208 8.9 111.8 11.9X +SQL ORC Vectorized with copy 1630 / 1644 6.4 155.5 8.5X +SQL ORC MR 3708 / 3711 2.8 353.6 3.8X + +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz -String with Nulls Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +String with Nulls Scan (50.0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -SQL CSV 15860 / 15877 0.7 1512.5 1.0X -SQL Json 7676 / 7688 1.4 732.0 2.1X -SQL Parquet Vectorized 1072 / 1084 9.8 102.2 14.8X -SQL Parquet MR 2890 / 2897 3.6 275.6 5.5X -ParquetReader Vectorized 1052 / 1053 10.0 100.4 15.1X -SQL ORC Vectorized 1248 / 1248 8.4 119.0 12.7X -SQL ORC Vectorized with copy 1627 / 1637 6.4 155.2 9.7X -SQL ORC MR 3365 / 3369 3.1 320.9 4.7X - -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +SQL CSV 13972 / 14043 0.8 1332.5 1.0X +SQL Json 7436 / 7469 1.4 709.1 1.9X +SQL Parquet Vectorized 1103 / 1112 9.5 105.2 12.7X +SQL Parquet MR 2841 / 2847 3.7 271.0 4.9X +ParquetReader Vectorized 992 / 1012 10.6 94.6 14.1X +SQL ORC Vectorized 1275 / 1349 8.2 121.6 11.0X +SQL ORC Vectorized with copy 1631 / 1644 6.4 155.5 8.6X +SQL ORC MR 3244 / 3259 3.2 309.3 4.3X + +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz -String with Nulls Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +String with Nulls Scan (95.0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -SQL CSV 13401 / 13561 0.8 1278.1 1.0X -SQL Json 5253 / 5303 2.0 500.9 2.6X -SQL Parquet Vectorized 233 / 242 45.0 22.2 57.6X -SQL Parquet MR 1791 / 1796 5.9 170.8 7.5X -ParquetReader Vectorized 236 / 238 44.4 22.5 56.7X -SQL ORC Vectorized 453 / 473 23.2 43.2 29.6X -SQL ORC Vectorized with copy 573 / 577 18.3 54.7 23.4X -SQL ORC MR 1846 / 1850 5.7 176.0 7.3X +SQL CSV 11228 / 11244 0.9 1070.8 1.0X +SQL Json 5200 / 5247 2.0 495.9 2.2X +SQL Parquet Vectorized 238 / 242 44.1 22.7 47.2X +SQL Parquet MR 1730 / 1734 6.1 165.0 6.5X +ParquetReader Vectorized 237 / 238 44.3 22.6 47.4X +SQL ORC Vectorized 459 / 462 22.8 43.8 24.4X +SQL ORC Vectorized with copy 581 / 583 18.1 55.4 19.3X +SQL ORC MR 1767 / 1783 5.9 168.5 6.4X ================================================================================================ Single Column Scan From Wide Columns ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Single Column Scan from 10 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -SQL CSV 3147 / 3148 0.3 3001.1 1.0X -SQL Json 2666 / 2693 0.4 2542.9 1.2X -SQL Parquet Vectorized 54 / 58 19.5 51.3 58.5X -SQL Parquet MR 220 / 353 4.8 209.9 14.3X -SQL ORC Vectorized 63 / 77 16.8 59.7 50.3X -SQL ORC Vectorized with copy 63 / 66 16.7 59.8 
50.2X -SQL ORC MR 317 / 321 3.3 302.2 9.9X - -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +SQL CSV 3322 / 3356 0.3 3167.9 1.0X +SQL Json 2808 / 2843 0.4 2678.2 1.2X +SQL Parquet Vectorized 56 / 63 18.9 52.9 59.8X +SQL Parquet MR 215 / 219 4.9 205.4 15.4X +SQL ORC Vectorized 64 / 76 16.4 60.9 52.0X +SQL ORC Vectorized with copy 64 / 67 16.3 61.3 51.7X +SQL ORC MR 314 / 316 3.3 299.6 10.6X + +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Single Column Scan from 50 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -SQL CSV 7902 / 7921 0.1 7536.2 1.0X -SQL Json 9467 / 9491 0.1 9028.6 0.8X -SQL Parquet Vectorized 73 / 79 14.3 69.8 108.0X -SQL Parquet MR 239 / 247 4.4 228.0 33.1X -SQL ORC Vectorized 78 / 84 13.4 74.6 101.0X -SQL ORC Vectorized with copy 78 / 88 13.4 74.4 101.3X -SQL ORC MR 910 / 918 1.2 867.6 8.7X - -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +SQL CSV 7978 / 7989 0.1 7608.5 1.0X +SQL Json 10294 / 10325 0.1 9816.9 0.8X +SQL Parquet Vectorized 72 / 85 14.5 69.0 110.3X +SQL Parquet MR 237 / 241 4.4 226.4 33.6X +SQL ORC Vectorized 82 / 92 12.7 78.5 97.0X +SQL ORC Vectorized with copy 82 / 88 12.7 78.5 97.0X +SQL ORC MR 900 / 909 1.2 858.5 8.9X + +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Single Column Scan from 100 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -SQL CSV 13539 / 13543 0.1 12912.0 1.0X -SQL Json 17420 / 17446 0.1 16613.1 0.8X -SQL Parquet Vectorized 103 / 120 10.2 98.1 131.6X -SQL Parquet MR 250 / 258 4.2 238.9 54.1X -SQL ORC Vectorized 99 / 104 10.6 94.6 136.5X -SQL ORC Vectorized with copy 100 / 106 10.5 95.6 135.1X -SQL ORC MR 1653 / 1659 0.6 1576.3 8.2X +SQL CSV 13489 / 13508 0.1 12864.3 1.0X +SQL Json 18813 / 18827 0.1 17941.4 0.7X +SQL Parquet Vectorized 107 / 111 9.8 101.8 126.3X +SQL Parquet MR 275 / 286 3.8 262.3 49.0X +SQL ORC Vectorized 107 / 115 9.8 101.7 126.4X +SQL ORC Vectorized with copy 107 / 115 9.8 102.3 125.8X +SQL ORC MR 1659 / 1664 0.6 1582.3 8.1X diff --git a/sql/core/benchmarks/WideTableBenchmark-results.txt b/sql/core/benchmarks/WideTableBenchmark-results.txt index 3b41a3e036c4d..8c09f9ca11307 100644 --- a/sql/core/benchmarks/WideTableBenchmark-results.txt +++ b/sql/core/benchmarks/WideTableBenchmark-results.txt @@ -6,12 +6,12 @@ OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz projection on wide table: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -split threshold 10 38932 / 39307 0.0 37128.1 1.0X -split threshold 100 31991 / 32556 0.0 30508.8 1.2X -split threshold 1024 10993 / 11041 0.1 10483.5 3.5X -split threshold 2048 8959 / 8998 0.1 8543.8 4.3X -split threshold 4096 8116 / 8134 0.1 7739.8 4.8X -split threshold 8196 8069 / 8098 0.1 7695.5 4.8X -split threshold 65536 57068 / 57339 0.0 54424.3 0.7X +split threshold 10 40571 / 40937 0.0 38691.7 1.0X +split threshold 100 31116 / 31669 0.0 29674.6 1.3X +split threshold 1024 10077 / 10199 0.1 9609.7 4.0X +split threshold 2048 8654 / 8692 0.1 8253.2 4.7X +split threshold 4096 8006 / 8038 0.1 7634.7 5.1X +split threshold 
8192 8069 / 8107 0.1 7695.3 5.0X +split threshold 65536 56973 / 57204 0.0 54333.7 0.7X diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 95e98c5444721..ac5f1fc923e7d 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-sql_2.11 + spark-sql_2.12 jar Spark Project SQL http://spark.apache.org/ diff --git a/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java b/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java index 802949c0ddb60..d4e1d89491f43 100644 --- a/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java +++ b/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java @@ -20,8 +20,8 @@ import java.io.Serializable; import java.util.Iterator; +import org.apache.spark.annotation.Evolving; import org.apache.spark.annotation.Experimental; -import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.streaming.GroupState; /** @@ -33,7 +33,7 @@ * @since 2.1.1 */ @Experimental -@InterfaceStability.Evolving +@Evolving public interface FlatMapGroupsWithStateFunction extends Serializable { Iterator call(K key, Iterator values, GroupState state) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java b/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java index 353e9886a8a57..f0abfde843cc5 100644 --- a/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java +++ b/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java @@ -20,8 +20,8 @@ import java.io.Serializable; import java.util.Iterator; +import org.apache.spark.annotation.Evolving; import org.apache.spark.annotation.Experimental; -import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.streaming.GroupState; /** @@ -32,7 +32,7 @@ * @since 2.1.1 */ @Experimental -@InterfaceStability.Evolving +@Evolving public interface MapGroupsWithStateFunction extends Serializable { R call(K key, Iterator values, GroupState state) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/SaveMode.java b/sql/core/src/main/java/org/apache/spark/sql/SaveMode.java index 1c3c9794fb6bb..9cc073f53a3eb 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/SaveMode.java +++ b/sql/core/src/main/java/org/apache/spark/sql/SaveMode.java @@ -16,14 +16,14 @@ */ package org.apache.spark.sql; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * SaveMode is used to specify the expected behavior of saving a DataFrame to a data source. 
* * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable public enum SaveMode { /** * Append mode means that when saving a DataFrame to a data source, if data/table already exists, diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF0.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF0.java index 4eeb7be3f5abb..631d6eb1cfb03 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF0.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF0.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 0 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF0 extends Serializable { R call() throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF1.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF1.java index 1460daf27dc20..a5d01406edd8c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF1.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF1.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 1 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF1 extends Serializable { R call(T1 t1) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF10.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF10.java index 7c4f1e4897084..effe99e30b2a5 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF10.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF10.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 10 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF10 extends Serializable { R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF11.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF11.java index 26a05106aebd6..e70b18b84b08f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF11.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF11.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 11 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF11 extends Serializable { R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF12.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF12.java index 8ef7a99042025..339feb34135e1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF12.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF12.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 12 arguments. 
*/ -@InterfaceStability.Stable +@Stable public interface UDF12 extends Serializable { R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF13.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF13.java index 5c3b2ec1222e2..d346e5c908c6f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF13.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF13.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 13 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF13 extends Serializable { R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF14.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF14.java index 97e744d843466..d27f9f5270f4b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF14.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF14.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 14 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF14 extends Serializable { R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF15.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF15.java index 7ddbf914fc11a..b99b57a91d465 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF15.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF15.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 15 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF15 extends Serializable { R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF16.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF16.java index 0ae5dc7195ad6..7899fc4b7ad65 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF16.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF16.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 16 arguments. 
*/ -@InterfaceStability.Stable +@Stable public interface UDF16 extends Serializable { R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF17.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF17.java index 03543a556c614..40a7e95724fc2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF17.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF17.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 17 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF17 extends Serializable { R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF18.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF18.java index 46740d3443916..47935a935891c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF18.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF18.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 18 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF18 extends Serializable { R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF19.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF19.java index 33fefd8ecaf1d..578b796ff03a3 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF19.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF19.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 19 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF19 extends Serializable { R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18, T19 t19) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF2.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF2.java index 9822f19217d76..2f856aa3cf630 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF2.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF2.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 2 arguments. 
*/ -@InterfaceStability.Stable +@Stable public interface UDF2 extends Serializable { R call(T1 t1, T2 t2) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF20.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF20.java index 8c5e90182da1c..aa8a9fa897040 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF20.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF20.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 20 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF20 extends Serializable { R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18, T19 t19, T20 t20) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF21.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF21.java index e3b09f5167cff..0fe52bce2eca2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF21.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF21.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 21 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF21 extends Serializable { R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18, T19 t19, T20 t20, T21 t21) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF22.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF22.java index dc6cfa9097bab..69fd8ca422833 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF22.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF22.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 22 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF22 extends Serializable { R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18, T19 t19, T20 t20, T21 t21, T22 t22) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF3.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF3.java index 7c264b69ba195..84ffd655672a2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF3.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF3.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 3 arguments. 
*/ -@InterfaceStability.Stable +@Stable public interface UDF3 extends Serializable { R call(T1 t1, T2 t2, T3 t3) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF4.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF4.java index 58df38fc3c911..dd2dc285c226d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF4.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF4.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 4 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF4 extends Serializable { R call(T1 t1, T2 t2, T3 t3, T4 t4) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF5.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF5.java index 4146f96e2eed5..795cc21c3f76e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF5.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF5.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 5 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF5 extends Serializable { R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF6.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF6.java index 25d39654c1095..a954684c3c9a9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF6.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF6.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 6 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF6 extends Serializable { R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF7.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF7.java index ce63b6a91adbb..03761f2c9ebbf 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF7.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF7.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 7 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF7 extends Serializable { R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF8.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF8.java index 0e00209ef6b9f..8cd3583b2cbf0 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF8.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF8.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 8 arguments. 
*/ -@InterfaceStability.Stable +@Stable public interface UDF8 extends Serializable { R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF9.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF9.java index 077981bb3e3ee..78a7097791963 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF9.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF9.java @@ -19,12 +19,12 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Stable; /** * A Spark SQL UDF that has 9 arguments. */ -@InterfaceStability.Stable +@Stable public interface UDF9 extends Serializable { R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/SchemaColumnConvertNotSupportedException.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/SchemaColumnConvertNotSupportedException.java index 82a1169cbe7ae..7d1fbe64fc960 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/SchemaColumnConvertNotSupportedException.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/SchemaColumnConvertNotSupportedException.java @@ -17,12 +17,12 @@ package org.apache.spark.sql.execution.datasources; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Unstable; /** * Exception thrown when the parquet reader find column type mismatches. */ -@InterfaceStability.Unstable +@Unstable public class SchemaColumnConvertNotSupportedException extends RuntimeException { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index b0e119d658cb4..4f5e72c1326ac 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -101,10 +101,11 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { String message = "Cannot reserve additional contiguous bytes in the vectorized reader (" + (requiredCapacity >= 0 ? "requested " + requiredCapacity + " bytes" : "integer overflow") + "). As a workaround, you can reduce the vectorized reader batch size, or disable the " + - "vectorized reader. For parquet file format, refer to " + + "vectorized reader, or disable " + SQLConf.BUCKETING_ENABLED().key() + " if you read " + + "from bucket table. 
For Parquet file format, refer to " + SQLConf.PARQUET_VECTORIZED_READER_BATCH_SIZE().key() + " (default " + SQLConf.PARQUET_VECTORIZED_READER_BATCH_SIZE().defaultValueString() + - ") and " + SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key() + "; for orc file format, " + + ") and " + SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key() + "; for ORC file format, " + "refer to " + SQLConf.ORC_VECTORIZED_READER_BATCH_SIZE().key() + " (default " + SQLConf.ORC_VECTORIZED_READER_BATCH_SIZE().defaultValueString() + ") and " + SQLConf.ORC_VECTORIZED_READER_ENABLED().key() + "."; diff --git a/sql/core/src/main/java/org/apache/spark/sql/expressions/javalang/typed.java b/sql/core/src/main/java/org/apache/spark/sql/expressions/javalang/typed.java index ec9c107b1c119..5a72f0c6a2555 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/expressions/javalang/typed.java +++ b/sql/core/src/main/java/org/apache/spark/sql/expressions/javalang/typed.java @@ -17,8 +17,8 @@ package org.apache.spark.sql.expressions.javalang; +import org.apache.spark.annotation.Evolving; import org.apache.spark.annotation.Experimental; -import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.api.java.function.MapFunction; import org.apache.spark.sql.TypedColumn; import org.apache.spark.sql.execution.aggregate.TypedAverage; @@ -35,7 +35,7 @@ * @since 2.0.0 */ @Experimental -@InterfaceStability.Evolving +@Evolving public class typed { // Note: make sure to keep in sync with typed.scala diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchReadSupportProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchReadSupportProvider.java index f403dc619e86c..2a4933d75e8d0 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchReadSupportProvider.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchReadSupportProvider.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils; import org.apache.spark.sql.sources.v2.reader.BatchReadSupport; import org.apache.spark.sql.types.StructType; @@ -29,7 +29,7 @@ * This interface is used to create {@link BatchReadSupport} instances when end users run * {@code SparkSession.read.format(...).option(...).load()}. */ -@InterfaceStability.Evolving +@Evolving public interface BatchReadSupportProvider extends DataSourceV2 { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchWriteSupportProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchWriteSupportProvider.java index bd10c3353bf12..df439e2c02fe3 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchWriteSupportProvider.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchWriteSupportProvider.java @@ -19,7 +19,7 @@ import java.util.Optional; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.SaveMode; import org.apache.spark.sql.sources.v2.writer.BatchWriteSupport; import org.apache.spark.sql.types.StructType; @@ -31,7 +31,7 @@ * This interface is used to create {@link BatchWriteSupport} instances when end users run * {@code Dataset.write.format(...).option(...).save()}. 
*/ -@InterfaceStability.Evolving +@Evolving public interface BatchWriteSupportProvider extends DataSourceV2 { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupportProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupportProvider.java index 824c290518acf..b4f2eb34a1560 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupportProvider.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupportProvider.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils; import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReadSupport; import org.apache.spark.sql.types.StructType; @@ -29,7 +29,7 @@ * This interface is used to create {@link ContinuousReadSupport} instances when end users run * {@code SparkSession.readStream.format(...).option(...).load()} with a continuous trigger. */ -@InterfaceStability.Evolving +@Evolving public interface ContinuousReadSupportProvider extends DataSourceV2 { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java index 83df3be747085..1c5e3a0cd31e7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java @@ -26,7 +26,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; /** * An immutable string-to-string map in which keys are case-insensitive. This is used to represent @@ -73,7 +73,7 @@ * * */ -@InterfaceStability.Evolving +@Evolving public class DataSourceOptions { private final Map keyLowerCasedMap; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java index 6e31e84bf6c72..eae7a45d1d446 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; /** * The base interface for data source v2. Implementations must have a public, 0-arg constructor. @@ -30,5 +30,5 @@ * If Spark fails to execute any methods in the implementations of this interface (by throwing an * exception), the read action will fail and no Spark job will be submitted. 
*/ -@InterfaceStability.Evolving +@Evolving public interface DataSourceV2 {} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupportProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupportProvider.java index 61c08e7fa89df..c4d9ef88f607e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupportProvider.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupportProvider.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils; import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReadSupport; import org.apache.spark.sql.types.StructType; @@ -29,7 +29,7 @@ * This interface is used to create {@link MicroBatchReadSupport} instances when end users run * {@code SparkSession.readStream.format(...).option(...).load()} with a micro-batch trigger. */ -@InterfaceStability.Evolving +@Evolving public interface MicroBatchReadSupportProvider extends DataSourceV2 { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java index bbe430e299261..c00abd9b685b5 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java @@ -17,14 +17,14 @@ package org.apache.spark.sql.sources.v2; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; /** * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to * propagate session configs with the specified key-prefix to all data source operations in this * session. */ -@InterfaceStability.Evolving +@Evolving public interface SessionConfigSupport extends DataSourceV2 { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamingWriteSupportProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamingWriteSupportProvider.java index f9ca85d8089b4..8ac9c51750865 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamingWriteSupportProvider.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamingWriteSupportProvider.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.execution.streaming.BaseStreamingSink; import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport; import org.apache.spark.sql.streaming.OutputMode; @@ -30,7 +30,7 @@ * This interface is used to create {@link StreamingWriteSupport} instances when end users run * {@code Dataset.writeStream.format(...).option(...).start()}. 
*/ -@InterfaceStability.Evolving +@Evolving public interface StreamingWriteSupportProvider extends DataSourceV2, BaseStreamingSink { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/BatchReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/BatchReadSupport.java index 452ee86675b42..518a8b03a2c6e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/BatchReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/BatchReadSupport.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2.reader; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; /** * An interface that defines how to load the data from data source for batch processing. @@ -29,7 +29,7 @@ * {@link ScanConfig}. The {@link ScanConfig} will be used to create input partitions and reader * factory to scan data from the data source with a Spark job. */ -@InterfaceStability.Evolving +@Evolving public interface BatchReadSupport extends ReadSupport { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java index 95c30de907e44..5f5248084bad6 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java @@ -19,7 +19,7 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; /** * A serializable representation of an input partition returned by @@ -32,7 +32,7 @@ * the actual reading. So {@link InputPartition} must be serializable while {@link PartitionReader} * doesn't need to be. */ -@InterfaceStability.Evolving +@Evolving public interface InputPartition extends Serializable { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReader.java index 04ff8d0a19fc3..2945925959538 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReader.java @@ -20,7 +20,7 @@ import java.io.Closeable; import java.io.IOException; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; /** * A partition reader returned by {@link PartitionReaderFactory#createReader(InputPartition)} or @@ -32,7 +32,7 @@ * data sources(whose {@link PartitionReaderFactory#supportColumnarReads(InputPartition)} * returns true). 
*/ -@InterfaceStability.Evolving +@Evolving public interface PartitionReader extends Closeable { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReaderFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReaderFactory.java index f35de9310eee3..97f4a473953fc 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReaderFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReaderFactory.java @@ -19,7 +19,7 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.vectorized.ColumnarBatch; @@ -30,7 +30,7 @@ * {@link PartitionReader} (by throwing an exception), corresponding Spark task would fail and * get retried until hitting the maximum retry times. */ -@InterfaceStability.Evolving +@Evolving public interface PartitionReaderFactory extends Serializable { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadSupport.java index a58ddb288f1ed..b1f610a82e8a2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadSupport.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2.reader; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.types.StructType; /** @@ -27,7 +27,7 @@ * If Spark fails to execute any methods in the implementations of this interface (by throwing an * exception), the read action will fail and no Spark job will be submitted. */ -@InterfaceStability.Evolving +@Evolving public interface ReadSupport { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfig.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfig.java index 7462ce2820585..a69872a527746 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfig.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfig.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2.reader; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.types.StructType; /** @@ -31,7 +31,7 @@ * {@link SupportsReportStatistics#estimateStatistics(ScanConfig)}, implementations mostly need to * cast the input {@link ScanConfig} to the concrete {@link ScanConfig} class of the data source. */ -@InterfaceStability.Evolving +@Evolving public interface ScanConfig { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfigBuilder.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfigBuilder.java index 4c0eedfddfe22..4922962f70655 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfigBuilder.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfigBuilder.java @@ -17,14 +17,14 @@ package org.apache.spark.sql.sources.v2.reader; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; /** * An interface for building the {@link ScanConfig}. 
Implementations can mixin those * SupportsPushDownXYZ interfaces to do operator pushdown, and keep the operator pushdown result in * the returned {@link ScanConfig}. */ -@InterfaceStability.Evolving +@Evolving public interface ScanConfigBuilder { ScanConfig build(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java index 44799c7d49137..14776f37fed46 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java @@ -19,13 +19,13 @@ import java.util.OptionalLong; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; /** * An interface to represent statistics for a data source, which is returned by * {@link SupportsReportStatistics#estimateStatistics(ScanConfig)}. */ -@InterfaceStability.Evolving +@Evolving public interface Statistics { OptionalLong sizeInBytes(); OptionalLong numRows(); diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java index 5e7985f645a06..3a89baa1b44c2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java @@ -17,14 +17,14 @@ package org.apache.spark.sql.sources.v2.reader; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.sources.Filter; /** * A mix-in interface for {@link ScanConfigBuilder}. Data sources can implement this interface to * push down filters to the data source and reduce the size of the data to be read. */ -@InterfaceStability.Evolving +@Evolving public interface SupportsPushDownFilters extends ScanConfigBuilder { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java index edb164937d6ef..1934763224881 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2.reader; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.types.StructType; /** @@ -25,7 +25,7 @@ * interface to push down required columns to the data source and only read these columns during * scan to reduce the size of the data to be read. 
*/ -@InterfaceStability.Evolving +@Evolving public interface SupportsPushDownRequiredColumns extends ScanConfigBuilder { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java index db62cd4515362..0335c7775c2af 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2.reader; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.sources.v2.reader.partitioning.Partitioning; /** @@ -27,7 +27,7 @@ * Note that, when a {@link ReadSupport} implementation creates exactly one {@link InputPartition}, * Spark may avoid adding a shuffle even if the reader does not implement this interface. */ -@InterfaceStability.Evolving +@Evolving public interface SupportsReportPartitioning extends ReadSupport { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java index 1831488ba096f..917372cdd25b3 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2.reader; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; /** * A mix in interface for {@link BatchReadSupport}. Data sources can implement this interface to @@ -27,7 +27,7 @@ * data source. Implementations that return more accurate statistics based on pushed operators will * not improve query performance until the planner can push operators before getting stats. */ -@InterfaceStability.Evolving +@Evolving public interface SupportsReportStatistics extends ReadSupport { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java index 6764d4b7665c7..1cdc02f5736b1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2.reader.partitioning; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.sources.v2.reader.PartitionReader; /** @@ -25,7 +25,7 @@ * share the same values for the {@link #clusteredColumns} will be produced by the same * {@link PartitionReader}. 
*/ -@InterfaceStability.Evolving +@Evolving public class ClusteredDistribution implements Distribution { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java index 364a3f553923c..02b0e68974919 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2.reader.partitioning; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.sources.v2.reader.PartitionReader; /** @@ -37,5 +37,5 @@ *
*   <li>{@link ClusteredDistribution}</li>
  • * */ -@InterfaceStability.Evolving +@Evolving public interface Distribution {} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java index fb0b6f1df43bb..c9a00262c1287 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2.reader.partitioning; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.sources.v2.reader.InputPartition; import org.apache.spark.sql.sources.v2.reader.ScanConfig; import org.apache.spark.sql.sources.v2.reader.SupportsReportPartitioning; @@ -28,7 +28,7 @@ * like a snapshot. Once created, it should be deterministic and always report the same number of * partitions and the same "satisfy" result for a certain distribution. */ -@InterfaceStability.Evolving +@Evolving public interface Partitioning { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReader.java index 9101c8a44d34e..c7f6fce6e81af 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReader.java @@ -17,13 +17,13 @@ package org.apache.spark.sql.sources.v2.reader.streaming; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.sources.v2.reader.PartitionReader; /** * A variation on {@link PartitionReader} for use with continuous streaming processing. */ -@InterfaceStability.Evolving +@Evolving public interface ContinuousPartitionReader extends PartitionReader { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReaderFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReaderFactory.java index 2d9f1ca1686a1..41195befe5e57 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReaderFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReaderFactory.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2.reader.streaming; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.sources.v2.reader.InputPartition; import org.apache.spark.sql.sources.v2.reader.PartitionReaderFactory; @@ -28,7 +28,7 @@ * instead of {@link org.apache.spark.sql.sources.v2.reader.PartitionReader}. It's used for * continuous streaming processing. 
*/ -@InterfaceStability.Evolving +@Evolving public interface ContinuousPartitionReaderFactory extends PartitionReaderFactory { @Override ContinuousPartitionReader createReader(InputPartition partition); diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReadSupport.java index 9a3ad2eb8a801..2b784ac0e9f35 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReadSupport.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2.reader.streaming; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.execution.streaming.BaseStreamingSource; import org.apache.spark.sql.sources.v2.reader.InputPartition; import org.apache.spark.sql.sources.v2.reader.ScanConfig; @@ -36,7 +36,7 @@ * {@link #stop()} will be called when the streaming execution is completed. Note that a single * query may have multiple executions due to restart or failure recovery. */ -@InterfaceStability.Evolving +@Evolving public interface ContinuousReadSupport extends StreamingReadSupport, BaseStreamingSource { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReadSupport.java index edb0db11bff2c..f56066c639388 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReadSupport.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2.reader.streaming; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.execution.streaming.BaseStreamingSource; import org.apache.spark.sql.sources.v2.reader.*; @@ -33,7 +33,7 @@ * will be called when the streaming execution is completed. Note that a single query may have * multiple executions due to restart or failure recovery. */ -@InterfaceStability.Evolving +@Evolving public interface MicroBatchReadSupport extends StreamingReadSupport, BaseStreamingSource { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java index 6cf27734867cb..6104175d2c9e3 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2.reader.streaming; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; /** * An abstract representation of progress through a {@link MicroBatchReadSupport} or @@ -30,7 +30,7 @@ * maintain compatibility with DataSource V1 APIs. This extension will be removed once we * get rid of V1 completely. 
*/ -@InterfaceStability.Evolving +@Evolving public abstract class Offset extends org.apache.spark.sql.execution.streaming.Offset { /** * A JSON-serialized representation of an Offset that is diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/PartitionOffset.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/PartitionOffset.java index 383e73db6762b..2c97d924a0629 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/PartitionOffset.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/PartitionOffset.java @@ -19,7 +19,7 @@ import java.io.Serializable; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; /** * Used for per-partition offsets in continuous processing. ContinuousReader implementations will @@ -27,6 +27,6 @@ * * These offsets must be serializable. */ -@InterfaceStability.Evolving +@Evolving public interface PartitionOffset extends Serializable { } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/BatchWriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/BatchWriteSupport.java index 0ec9e05d6a02b..efe1ac4f78db1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/BatchWriteSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/BatchWriteSupport.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2.writer; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; /** * An interface that defines how to write the data to data source for batch processing. @@ -37,7 +37,7 @@ * * Please refer to the documentation of commit/abort methods for detailed specifications. */ -@InterfaceStability.Evolving +@Evolving public interface BatchWriteSupport { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java index 5fb067966ee67..d142ee523ef9f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java @@ -19,7 +19,7 @@ import java.io.IOException; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; /** * A data writer returned by {@link DataWriterFactory#createWriter(int, long)} and is @@ -55,7 +55,7 @@ * * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.catalyst.InternalRow}. */ -@InterfaceStability.Evolving +@Evolving public interface DataWriter { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java index 19a36dd232456..65105f46b82d5 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java @@ -20,7 +20,7 @@ import java.io.Serializable; import org.apache.spark.TaskContext; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.catalyst.InternalRow; /** @@ -31,7 +31,7 @@ * will be created on executors and do the actual writing. So this interface must be * serializable and {@link DataWriter} doesn't need to be. 
*/ -@InterfaceStability.Evolving +@Evolving public interface DataWriterFactory extends Serializable { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java index 123335c414e9f..9216e34399092 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java @@ -19,8 +19,8 @@ import java.io.Serializable; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport; -import org.apache.spark.annotation.InterfaceStability; /** * A commit message returned by {@link DataWriter#commit()} and will be sent back to the driver side @@ -30,5 +30,5 @@ * This is an empty interface, data sources should define their own message class and use it when * generating messages at executor side and handling the messages at driver side. */ -@InterfaceStability.Evolving +@Evolving public interface WriterCommitMessage extends Serializable {} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java index a4da24fc5ae68..7d3d21cb2b637 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java @@ -20,7 +20,7 @@ import java.io.Serializable; import org.apache.spark.TaskContext; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.sources.v2.writer.DataWriter; @@ -33,7 +33,7 @@ * will be created on executors and do the actual writing. So this interface must be * serializable and {@link DataWriter} doesn't need to be. */ -@InterfaceStability.Evolving +@Evolving public interface StreamingDataWriterFactory extends Serializable { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWriteSupport.java index 3fdfac5e1c84a..84cfbf2dda483 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWriteSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWriteSupport.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources.v2.writer.streaming; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.sources.v2.writer.DataWriter; import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage; @@ -27,7 +27,7 @@ * Streaming queries are divided into intervals of data called epochs, with a monotonically * increasing numeric ID. This writer handles commits and aborts for each successive epoch. 
*/ -@InterfaceStability.Evolving +@Evolving public interface StreamingWriteSupport { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java b/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java index 5371a23230c98..fd6f7be2abc5a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java +++ b/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java @@ -19,9 +19,9 @@ import java.util.concurrent.TimeUnit; +import org.apache.spark.annotation.Evolving; import scala.concurrent.duration.Duration; -import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger; import org.apache.spark.sql.execution.streaming.OneTimeTrigger$; @@ -30,7 +30,7 @@ * * @since 2.0.0 */ -@InterfaceStability.Evolving +@Evolving public class Trigger { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java index 5f58b031f6aef..906e9bc26ef53 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java @@ -22,7 +22,7 @@ import org.apache.arrow.vector.complex.*; import org.apache.arrow.vector.holders.NullableVarCharHolder; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.execution.arrow.ArrowUtils; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.types.UTF8String; @@ -31,7 +31,7 @@ * A column vector backed by Apache Arrow. Currently calendar interval type and map type are not * supported. */ -@InterfaceStability.Evolving +@Evolving public final class ArrowColumnVector extends ColumnVector { private final ArrowVectorAccessor accessor; diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java index ad99b450a4809..14caaeaedbe2b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java @@ -16,7 +16,7 @@ */ package org.apache.spark.sql.vectorized; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.Decimal; import org.apache.spark.unsafe.types.CalendarInterval; @@ -47,7 +47,7 @@ * format. Since it is expected to reuse the ColumnVector instance while loading data, the storage * footprint is negligible. 
*/ -@InterfaceStability.Evolving +@Evolving public abstract class ColumnVector implements AutoCloseable { /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java index 72a192d089b9f..dd2bd789c26d0 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java @@ -16,7 +16,7 @@ */ package org.apache.spark.sql.vectorized; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.catalyst.util.ArrayData; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.types.CalendarInterval; @@ -25,7 +25,7 @@ /** * Array abstraction in {@link ColumnVector}. */ -@InterfaceStability.Evolving +@Evolving public final class ColumnarArray extends ArrayData { // The data for this array. This array contains elements from // data[offset] to data[offset + length). diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java index d206c1df42abb..07546a54013ec 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java @@ -18,7 +18,7 @@ import java.util.*; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.execution.vectorized.MutableColumnarRow; @@ -27,7 +27,7 @@ * batch so that Spark can access the data row by row. Instance of it is meant to be reused during * the entire data loading process. */ -@InterfaceStability.Evolving +@Evolving public final class ColumnarBatch { private int numRows; private final ColumnVector[] columns; diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java index f2f2279590023..4b9d3c5f59915 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java @@ -16,7 +16,7 @@ */ package org.apache.spark.sql.vectorized; -import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; import org.apache.spark.sql.types.*; @@ -26,7 +26,7 @@ /** * Row abstraction in {@link ColumnVector}. */ -@InterfaceStability.Evolving +@Evolving public final class ColumnarRow extends InternalRow { // The data for this row. // E.g. the value of 3rd int field is `data.getChild(3).getInt(rowId)`. 
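Not part of this patch, but a minimal sketch of what the annotation migration in the hunks above means for downstream code: the nested `InterfaceStability.Evolving` references are replaced by the top-level `Evolving` annotation in `org.apache.spark.annotation`. `MyReadSupport` below is a hypothetical name used only for illustration.

    // Before (removed by this change):
    //   import org.apache.spark.annotation.InterfaceStability
    //   @InterfaceStability.Evolving
    //   trait MyReadSupport extends ReadSupport

    // After (top-level annotation in the same package):
    import org.apache.spark.annotation.Evolving
    import org.apache.spark.sql.sources.v2.reader.ReadSupport

    @Evolving
    trait MyReadSupport extends ReadSupport

Only the import and the annotation reference change; the annotated interfaces keep their names and packages.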
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index a9a19aa8a1001..5a408b29f9337 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql import scala.collection.JavaConverters._ import scala.language.implicitConversions -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} @@ -60,7 +60,7 @@ private[sql] object Column { * * @since 1.6.0 */ -@InterfaceStability.Stable +@Stable class TypedColumn[-T, U]( expr: Expression, private[sql] val encoder: ExpressionEncoder[U]) @@ -74,6 +74,9 @@ class TypedColumn[-T, U]( inputEncoder: ExpressionEncoder[_], inputAttributes: Seq[Attribute]): TypedColumn[T, U] = { val unresolvedDeserializer = UnresolvedDeserializer(inputEncoder.deserializer, inputAttributes) + + // This only inserts inputs into typed aggregate expressions. For untyped aggregate expressions, + // the resolving is handled in the analyzer directly. val newExpr = expr transform { case ta: TypedAggregateExpression if ta.inputDeserializer.isEmpty => ta.withInputInfo( @@ -127,7 +130,7 @@ class TypedColumn[-T, U]( * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable class Column(val expr: Expression) extends Logging { def this(name: String) = this(name match { @@ -1224,7 +1227,7 @@ class Column(val expr: Expression) extends Logging { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable class ColumnName(name: String) extends Column(name) { /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 5288907b7d7ff..53e9f810d7c85 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -22,18 +22,17 @@ import java.util.Locale import scala.collection.JavaConverters._ -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ - /** * Functionality for working with missing data in `DataFrame`s. 
* * @since 1.3.1 */ -@InterfaceStability.Stable +@Stable final class DataFrameNaFunctions private[sql](df: DataFrame) { /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 95c97e5c9433c..da88598eed061 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -24,11 +24,12 @@ import scala.collection.JavaConverters._ import com.fasterxml.jackson.databind.ObjectMapper import org.apache.spark.Partition -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVOptions, UnivocityParser} +import org.apache.spark.sql.catalyst.expressions.ExprUtils import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} import org.apache.spark.sql.catalyst.util.FailureSafeParser import org.apache.spark.sql.execution.command.DDLUtils @@ -48,7 +49,7 @@ import org.apache.spark.unsafe.types.UTF8String * * @since 1.4.0 */ -@InterfaceStability.Stable +@Stable class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { /** @@ -194,7 +195,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val cls = DataSource.lookupDataSource(source, sparkSession.sessionState.conf) if (classOf[DataSourceV2].isAssignableFrom(cls)) { - val ds = cls.newInstance().asInstanceOf[DataSourceV2] + val ds = cls.getConstructor().newInstance().asInstanceOf[DataSourceV2] if (ds.isInstanceOf[BatchReadSupportProvider]) { val sessionOptions = DataSourceV2Utils.extractSessionConfigs( ds = ds, conf = sparkSession.sessionState.conf) @@ -384,6 +385,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * for schema inferring. *
* <li>`dropFieldIfAllNull` (default `false`): whether to ignore column of all null values or * empty array/struct during schema inference.</li>
+ * <li>`locale` (default is `en-US`): sets a locale as language tag in IETF BCP 47 format. + * For instance, this is used while parsing dates and timestamps.</li>
  • * * * @since 2.0.0 @@ -440,7 +443,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { TextInputJsonDataSource.inferFromDataset(jsonDataset, parsedOptions) } - verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord) + ExprUtils.verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord) val actualSchema = StructType(schema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)) @@ -502,7 +505,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { parsedOptions) } - verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord) + ExprUtils.verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord) val actualSchema = StructType(schema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)) @@ -604,6 +607,10 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field having malformed string * created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`. *
* <li>`multiLine` (default `false`): parse one record, which may span multiple lines.</li>
+ * <li>`locale` (default is `en-US`): sets a locale as language tag in IETF BCP 47 format. + * For instance, this is used while parsing dates and timestamps.</li>
+ * <li>`lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator + * that should be used for parsing. Maximum length is 1 character.</li>
  • * * * @since 2.0.0 @@ -761,22 +768,6 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { } } - /** - * A convenient function for schema validation in datasources supporting - * `columnNameOfCorruptRecord` as an option. - */ - private def verifyColumnNameOfCorruptRecord( - schema: StructType, - columnNameOfCorruptRecord: String): Unit = { - schema.getFieldIndex(columnNameOfCorruptRecord).foreach { corruptFieldIndex => - val f = schema(corruptFieldIndex) - if (f.dataType != StringType || !f.nullable) { - throw new AnalysisException( - "The field for corrupt records must be string type and nullable") - } - } - } - /////////////////////////////////////////////////////////////////////////////////////// // Builder pattern config options /////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index 7c12432d33c33..0b22b898557f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -21,7 +21,7 @@ import java.{lang => jl, util => ju} import scala.collection.JavaConverters._ -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.stat._ import org.apache.spark.sql.functions.col @@ -33,7 +33,7 @@ import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch} * * @since 1.4.0 */ -@InterfaceStability.Stable +@Stable final class DataFrameStatFunctions private[sql](df: DataFrame) { /** @@ -51,7 +51,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * * This method implements a variation of the Greenwald-Khanna algorithm (with some speed * optimizations). - * The algorithm was first present in + * The algorithm was first present in * Space-efficient Online Computation of Quantile Summaries by Greenwald and Khanna. * * @param col the name of the numerical column @@ -218,7 +218,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { /** * Finding frequent items for columns, possibly with false positives. Using the * frequent element count algorithm described in - * here, proposed by Karp, + * here, proposed by Karp, * Schenker, and Papadimitriou. * The `support` should be greater than 1e-4. * @@ -265,7 +265,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { /** * Finding frequent items for columns, possibly with false positives. Using the * frequent element count algorithm described in - * here, proposed by Karp, + * here, proposed by Karp, * Schenker, and Papadimitriou. * Uses a `default` support of 1%. * @@ -284,7 +284,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { /** * (Scala-specific) Finding frequent items for columns, possibly with false positives. Using the * frequent element count algorithm described in - * here, proposed by Karp, Schenker, + * here, proposed by Karp, Schenker, * and Papadimitriou. * * This function is meant for exploratory data analysis, as we make no guarantee about the @@ -328,7 +328,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { /** * (Scala-specific) Finding frequent items for columns, possibly with false positives. 
Using the * frequent element count algorithm described in - * here, proposed by Karp, Schenker, + * here, proposed by Karp, Schenker, * and Papadimitriou. * Uses a `default` support of 1%. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 5a28870f5d3c2..5a807d3d4b93e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -21,7 +21,7 @@ import java.util.{Locale, Properties, UUID} import scala.collection.JavaConverters._ -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedRelation} import org.apache.spark.sql.catalyst.catalog._ @@ -40,7 +40,7 @@ import org.apache.spark.sql.types.StructType * * @since 1.4.0 */ -@InterfaceStability.Stable +@Stable final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { private val df = ds.toDF() @@ -243,7 +243,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { val cls = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) if (classOf[DataSourceV2].isAssignableFrom(cls)) { - val source = cls.newInstance().asInstanceOf[DataSourceV2] + val source = cls.getConstructor().newInstance().asInstanceOf[DataSourceV2] source match { case provider: BatchWriteSupportProvider => val sessionOptions = DataSourceV2Utils.extractSessionConfigs( @@ -658,6 +658,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * whitespaces from values being written should be skipped. *
* <li>`ignoreTrailingWhiteSpace` (default `true`): a flag indicating whether or not * trailing whitespaces from values being written should be skipped.</li>
+ * <li>`lineSep` (default `\n`): defines the line separator that should be used for writing. + * Maximum length is 1 character.</li>
  • * * * @since 2.0.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index f835343d6e067..b10d66dfb1aef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -21,19 +21,18 @@ import java.io.CharArrayWriter import scala.collection.JavaConverters._ import scala.language.implicitConversions -import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal import org.apache.commons.lang3.StringUtils import org.apache.spark.TaskContext -import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} +import org.apache.spark.annotation.{DeveloperApi, Evolving, Experimental, Stable, Unstable} import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.java.function._ import org.apache.spark.api.python.{PythonRDD, SerDeUtil} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst._ +import org.apache.spark.sql.catalyst.QueryPlanningTracker import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog.HiveTableRelation import org.apache.spark.sql.catalyst.encoders._ @@ -78,6 +77,14 @@ private[sql] object Dataset { qe.assertAnalyzed() new Dataset[Row](sparkSession, qe, RowEncoder(qe.analyzed.schema)) } + + /** A variant of ofRows that allows passing in a tracker so we can track query parsing time. */ + def ofRows(sparkSession: SparkSession, logicalPlan: LogicalPlan, tracker: QueryPlanningTracker) + : DataFrame = { + val qe = new QueryExecution(sparkSession, logicalPlan, tracker) + qe.assertAnalyzed() + new Dataset[Row](sparkSession, qe, RowEncoder(qe.analyzed.schema)) + } } /** @@ -166,10 +173,10 @@ private[sql] object Dataset { * * @since 1.6.0 */ -@InterfaceStability.Stable +@Stable class Dataset[T] private[sql]( @transient val sparkSession: SparkSession, - @DeveloperApi @InterfaceStability.Unstable @transient val queryExecution: QueryExecution, + @DeveloperApi @Unstable @transient val queryExecution: QueryExecution, encoder: Encoder[T]) extends Serializable { @@ -426,7 +433,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def as[U : Encoder]: Dataset[U] = Dataset[U](sparkSession, logicalPlan) /** @@ -544,7 +551,7 @@ class Dataset[T] private[sql]( * @group streaming * @since 2.0.0 */ - @InterfaceStability.Evolving + @Evolving def isStreaming: Boolean = logicalPlan.isStreaming /** @@ -557,7 +564,7 @@ class Dataset[T] private[sql]( * @since 2.1.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def checkpoint(): Dataset[T] = checkpoint(eager = true, reliableCheckpoint = true) /** @@ -570,7 +577,7 @@ class Dataset[T] private[sql]( * @since 2.1.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def checkpoint(eager: Boolean): Dataset[T] = checkpoint(eager = eager, reliableCheckpoint = true) /** @@ -583,7 +590,7 @@ class Dataset[T] private[sql]( * @since 2.3.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def localCheckpoint(): Dataset[T] = checkpoint(eager = true, reliableCheckpoint = false) /** @@ -596,7 +603,7 @@ class Dataset[T] private[sql]( * @since 2.3.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def localCheckpoint(eager: Boolean): Dataset[T] = checkpoint( eager = eager, reliableCheckpoint = false @@ -671,7 +678,7 @@ class Dataset[T] private[sql]( * @group streaming * @since 2.1.0 
*/ - @InterfaceStability.Evolving + @Evolving // We only accept an existing column name, not a derived column here as a watermark that is // defined on a derived column cannot referenced elsewhere in the plan. def withWatermark(eventTime: String, delayThreshold: String): Dataset[T] = withTypedPlan { @@ -1066,7 +1073,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def joinWith[U](other: Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] = { // Creates a Join node and resolve it first, to get join condition resolved, self-join resolved, // etc. @@ -1086,7 +1093,7 @@ class Dataset[T] private[sql]( // Note that we do this before joining them, to enable the join operator to return null for one // side, in cases like outer-join. val left = { - val combined = if (!this.exprEnc.isSerializedAsStruct) { + val combined = if (!this.exprEnc.isSerializedAsStructForTopLevel) { assert(joined.left.output.length == 1) Alias(joined.left.output.head, "_1")() } else { @@ -1096,7 +1103,7 @@ class Dataset[T] private[sql]( } val right = { - val combined = if (!other.exprEnc.isSerializedAsStruct) { + val combined = if (!other.exprEnc.isSerializedAsStructForTopLevel) { assert(joined.right.output.length == 1) Alias(joined.right.output.head, "_2")() } else { @@ -1109,14 +1116,14 @@ class Dataset[T] private[sql]( // combine the outputs of each join side. val conditionExpr = joined.condition.get transformUp { case a: Attribute if joined.left.outputSet.contains(a) => - if (!this.exprEnc.isSerializedAsStruct) { + if (!this.exprEnc.isSerializedAsStructForTopLevel) { left.output.head } else { val index = joined.left.output.indexWhere(_.exprId == a.exprId) GetStructField(left.output.head, index) } case a: Attribute if joined.right.outputSet.contains(a) => - if (!other.exprEnc.isSerializedAsStruct) { + if (!other.exprEnc.isSerializedAsStructForTopLevel) { right.output.head } else { val index = joined.right.output.indexWhere(_.exprId == a.exprId) @@ -1142,7 +1149,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = { joinWith(other, condition, "inner") } @@ -1384,12 +1391,12 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = { implicit val encoder = c1.encoder val project = Project(c1.withInputType(exprEnc, logicalPlan.output).named :: Nil, logicalPlan) - if (!encoder.isSerializedAsStruct) { + if (!encoder.isSerializedAsStructForTopLevel) { new Dataset[U1](sparkSession, project, encoder) } else { // Flattens inner fields of U1 @@ -1418,7 +1425,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def select[U1, U2](c1: TypedColumn[T, U1], c2: TypedColumn[T, U2]): Dataset[(U1, U2)] = selectUntyped(c1, c2).asInstanceOf[Dataset[(U1, U2)]] @@ -1430,7 +1437,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def select[U1, U2, U3]( c1: TypedColumn[T, U1], c2: TypedColumn[T, U2], @@ -1445,7 +1452,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def select[U1, U2, U3, U4]( c1: TypedColumn[T, U1], c2: TypedColumn[T, U2], @@ -1461,7 +1468,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - @InterfaceStability.Evolving + 
@Evolving def select[U1, U2, U3, U4, U5]( c1: TypedColumn[T, U1], c2: TypedColumn[T, U2], @@ -1632,7 +1639,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def reduce(func: (T, T) => T): T = withNewRDDExecutionId { rdd.reduce(func) } @@ -1647,7 +1654,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def reduce(func: ReduceFunction[T]): T = reduce(func.call(_, _)) /** @@ -1659,7 +1666,7 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = { val withGroupingKey = AppendColumns(func, logicalPlan) val executed = sparkSession.sessionState.executePlan(withGroupingKey) @@ -1681,7 +1688,7 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def groupByKey[K](func: MapFunction[T, K], encoder: Encoder[K]): KeyValueGroupedDataset[K, T] = groupByKey(func.call(_))(encoder) @@ -2497,7 +2504,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def filter(func: T => Boolean): Dataset[T] = { withTypedPlan(TypedFilter(func, logicalPlan)) } @@ -2511,7 +2518,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def filter(func: FilterFunction[T]): Dataset[T] = { withTypedPlan(TypedFilter(func, logicalPlan)) } @@ -2525,7 +2532,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def map[U : Encoder](func: T => U): Dataset[U] = withTypedPlan { MapElements[T, U](func, logicalPlan) } @@ -2539,7 +2546,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = { implicit val uEnc = encoder withTypedPlan(MapElements[T, U](func, logicalPlan)) @@ -2554,7 +2561,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = { new Dataset[U]( sparkSession, @@ -2571,7 +2578,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] = { val func: (Iterator[T]) => Iterator[U] = x => f.call(x.asJava).asScala mapPartitions(func)(encoder) @@ -2602,7 +2609,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def flatMap[U : Encoder](func: T => TraversableOnce[U]): Dataset[U] = mapPartitions(_.flatMap(func)) @@ -2616,7 +2623,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def flatMap[U](f: FlatMapFunction[T, U], encoder: Encoder[U]): Dataset[U] = { val func: (T) => Iterator[U] = x => f.call(x).asScala flatMap(func)(encoder) @@ -2803,6 +2810,12 @@ class Dataset[T] private[sql]( * When no explicit sort order is specified, "ascending nulls first" is assumed. * Note, the rows are not sorted in each partition of the resulting Dataset. * + * + * Note that due to performance reasons this method uses sampling to estimate the ranges. + * Hence, the output may not be consistent, since sampling can return different values. 
+ * The sample size can be controlled by the config + * `spark.sql.execution.rangeExchange.sampleSizePerPartition`. + * * @group typedrel * @since 2.3.0 */ @@ -2827,6 +2840,11 @@ class Dataset[T] private[sql]( * When no explicit sort order is specified, "ascending nulls first" is assumed. * Note, the rows are not sorted in each partition of the resulting Dataset. * + * Note that due to performance reasons this method uses sampling to estimate the ranges. + * Hence, the output may not be consistent, since sampling can return different values. + * The sample size can be controlled by the config + * `spark.sql.execution.rangeExchange.sampleSizePerPartition`. + * * @group typedrel * @since 2.3.0 */ @@ -3078,7 +3096,7 @@ class Dataset[T] private[sql]( * @group basic * @since 2.0.0 */ - @InterfaceStability.Evolving + @Evolving def writeStream: DataStreamWriter[T] = { if (!isStreaming) { logicalPlan.failAnalysis( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala index 08aa1bbe78fae..1c4ffefb897ea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable /** * A container for a [[Dataset]], used for implicit conversions in Scala. @@ -30,7 +30,7 @@ import org.apache.spark.annotation.InterfaceStability * * @since 1.6.0 */ -@InterfaceStability.Stable +@Stable case class DatasetHolder[T] private[sql](private val ds: Dataset[T]) { // This is declared with parentheses to prevent the Scala compiler from treating diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala index bd8dd6ea3fe0f..302d38cde1430 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.{Experimental, Unstable} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.rules.Rule * @since 1.3.0 */ @Experimental -@InterfaceStability.Unstable +@Unstable class ExperimentalMethods private[sql]() { /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala index 52b8c839643e7..5c0fe798b1044 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Evolving /** * The abstract class for writing custom logic to process data generated by a query. @@ -104,7 +104,7 @@ import org.apache.spark.annotation.InterfaceStability * * @since 2.0.0 */ -@InterfaceStability.Evolving +@Evolving abstract class ForeachWriter[T] extends Serializable { // TODO: Move this to org.apache.spark.sql.util or consolidate this with batch API. 
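Not part of this patch, but a minimal sketch of the `repartitionByRange` behaviour documented above and of the `spark.sql.execution.rangeExchange.sampleSizePerPartition` knob it references. The object and app names are illustrative assumptions, and the snippet assumes a local SparkSession.

    import org.apache.spark.sql.SparkSession

    object RangePartitionSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").appName("range-partition-sketch").getOrCreate()
        import spark.implicits._

        // A larger sample per partition gives more accurate range boundaries, at extra planning cost.
        spark.conf.set("spark.sql.execution.rangeExchange.sampleSizePerPartition", "200")

        val df = spark.range(0, 1000000).selectExpr("id % 97 as key", "id as value")

        // Boundaries are estimated by sampling, so partition contents can differ between runs
        // even for identical input; rows are not sorted within each resulting partition.
        val byRange = df.repartitionByRange(8, $"key")
        byRange.rdd.glom().map(_.length).collect().foreach(println)

        spark.stop()
      }
    }

With a larger sample size the printed partition sizes should come out more even, but the exact split stays nondeterministic, which is what the added documentation warns about.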
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 555bcdffb6ee4..a3cbea9021f22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -19,13 +19,14 @@ package org.apache.spark.sql import scala.collection.JavaConverters._ -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.{Evolving, Experimental} import org.apache.spark.api.java.function._ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CreateStruct} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.expressions.ReduceAggregator +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode} /** @@ -37,7 +38,7 @@ import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode * @since 2.0.0 */ @Experimental -@InterfaceStability.Evolving +@Evolving class KeyValueGroupedDataset[K, V] private[sql]( kEncoder: Encoder[K], vEncoder: Encoder[V], @@ -237,7 +238,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( * @since 2.2.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def mapGroupsWithState[S: Encoder, U: Encoder]( func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] = { val flatMapFunc = (key: K, it: Iterator[V], s: GroupState[S]) => Iterator(func(key, it, s)) @@ -272,7 +273,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( * @since 2.2.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def mapGroupsWithState[S: Encoder, U: Encoder]( timeoutConf: GroupStateTimeout)( func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] = { @@ -309,7 +310,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( * @since 2.2.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def mapGroupsWithState[S, U]( func: MapGroupsWithStateFunction[K, V, S, U], stateEncoder: Encoder[S], @@ -340,7 +341,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( * @since 2.2.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def mapGroupsWithState[S, U]( func: MapGroupsWithStateFunction[K, V, S, U], stateEncoder: Encoder[S], @@ -371,7 +372,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( * @since 2.2.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def flatMapGroupsWithState[S: Encoder, U: Encoder]( outputMode: OutputMode, timeoutConf: GroupStateTimeout)( @@ -413,7 +414,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( * @since 2.2.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def flatMapGroupsWithState[S, U]( func: FlatMapGroupsWithStateFunction[K, V, S, U], outputMode: OutputMode, @@ -457,9 +458,13 @@ class KeyValueGroupedDataset[K, V] private[sql]( val encoders = columns.map(_.encoder) val namedColumns = columns.map(_.withInputType(vExprEnc, dataAttributes).named) - val keyColumn = if (!kExprEnc.isSerializedAsStruct) { + val keyColumn = if (!kExprEnc.isSerializedAsStructForTopLevel) { assert(groupingAttributes.length == 1) - groupingAttributes.head + if (SQLConf.get.nameNonStructGroupingKeyAsValue) { + groupingAttributes.head + } else { + Alias(groupingAttributes.head, "key")() + } } else { Alias(CreateStruct(groupingAttributes), "key")() } diff 
--git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index d4e75b5ebd405..e85636d82a62c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -22,7 +22,7 @@ import java.util.Locale import scala.collection.JavaConverters._ import scala.language.implicitConversions -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.api.python.PythonEvalType import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction} @@ -45,7 +45,7 @@ import org.apache.spark.sql.types.{NumericType, StructType} * * @since 2.0.0 */ -@InterfaceStability.Stable +@Stable class RelationalGroupedDataset protected[sql]( df: DataFrame, groupingExprs: Seq[Expression], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala b/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala index 3c39579149fff..5a554eff02e3d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala @@ -17,11 +17,10 @@ package org.apache.spark.sql -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.internal.config.{ConfigEntry, OptionalConfigEntry} import org.apache.spark.sql.internal.SQLConf - /** * Runtime configuration interface for Spark. To access this, use `SparkSession.conf`. * @@ -29,7 +28,7 @@ import org.apache.spark.sql.internal.SQLConf * * @since 2.0.0 */ -@InterfaceStability.Stable +@Stable class RuntimeConfig private[sql](sqlConf: SQLConf = new SQLConf) { /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 9982b60fefe60..43f34e6ff4b85 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -23,7 +23,7 @@ import scala.collection.immutable import scala.reflect.runtime.universe.TypeTag import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} +import org.apache.spark.annotation._ import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging import org.apache.spark.internal.config.ConfigEntry @@ -54,7 +54,7 @@ import org.apache.spark.sql.util.ExecutionListenerManager * @groupname Ungrouped Support functions for language integrated queries * @since 1.0.0 */ -@InterfaceStability.Stable +@Stable class SQLContext private[sql](val sparkSession: SparkSession) extends Logging with Serializable { @@ -86,7 +86,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * that listen for execution metrics. 
*/ @Experimental - @InterfaceStability.Evolving + @Evolving def listenerManager: ExecutionListenerManager = sparkSession.listenerManager /** @@ -158,7 +158,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) */ @Experimental @transient - @InterfaceStability.Unstable + @Unstable def experimental: ExperimentalMethods = sparkSession.experimental /** @@ -244,7 +244,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @since 1.3.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving object implicits extends SQLImplicits with Serializable { protected override def _sqlContext: SQLContext = self } @@ -258,7 +258,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @since 1.3.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = { sparkSession.createDataFrame(rdd) } @@ -271,7 +271,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @since 1.3.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = { sparkSession.createDataFrame(data) } @@ -319,7 +319,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @since 1.3.0 */ @DeveloperApi - @InterfaceStability.Evolving + @Evolving def createDataFrame(rowRDD: RDD[Row], schema: StructType): DataFrame = { sparkSession.createDataFrame(rowRDD, schema) } @@ -363,7 +363,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @group dataset */ @Experimental - @InterfaceStability.Evolving + @Evolving def createDataset[T : Encoder](data: Seq[T]): Dataset[T] = { sparkSession.createDataset(data) } @@ -401,7 +401,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @group dataset */ @Experimental - @InterfaceStability.Evolving + @Evolving def createDataset[T : Encoder](data: java.util.List[T]): Dataset[T] = { sparkSession.createDataset(data) } @@ -428,7 +428,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @since 1.3.0 */ @DeveloperApi - @InterfaceStability.Evolving + @Evolving def createDataFrame(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = { sparkSession.createDataFrame(rowRDD, schema) } @@ -443,7 +443,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @since 1.6.0 */ @DeveloperApi - @InterfaceStability.Evolving + @Evolving def createDataFrame(rows: java.util.List[Row], schema: StructType): DataFrame = { sparkSession.createDataFrame(rows, schema) } @@ -507,7 +507,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * * @since 2.0.0 */ - @InterfaceStability.Evolving + @Evolving def readStream: DataStreamReader = sparkSession.readStream @@ -631,7 +631,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @group dataframe */ @Experimental - @InterfaceStability.Evolving + @Evolving def range(end: Long): DataFrame = sparkSession.range(end).toDF() /** @@ -643,7 +643,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @group dataframe */ @Experimental - @InterfaceStability.Evolving + @Evolving def range(start: Long, end: Long): DataFrame = sparkSession.range(start, end).toDF() /** @@ -655,7 +655,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @group dataframe */ @Experimental - @InterfaceStability.Evolving + @Evolving def range(start: Long, end: Long, step: Long): DataFrame = { sparkSession.range(start, end, step).toDF() } @@ -670,7 +670,7 @@ class SQLContext private[sql](val sparkSession: 
SparkSession) * @group dataframe */ @Experimental - @InterfaceStability.Evolving + @Evolving def range(start: Long, end: Long, step: Long, numPartitions: Int): DataFrame = { sparkSession.range(start, end, step, numPartitions).toDF() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index 05db292bd41b1..d329af0145c2f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -21,7 +21,7 @@ import scala.collection.Map import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Evolving import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder * * @since 1.6.0 */ -@InterfaceStability.Evolving +@Evolving abstract class SQLImplicits extends LowPrioritySQLImplicits { protected def _sqlContext: SQLContext diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 3f0b8208612d7..703272acda22c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -25,7 +25,7 @@ import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext, TaskContext} -import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} +import org.apache.spark.annotation.{DeveloperApi, Evolving, Experimental, Stable, Unstable} import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD @@ -73,7 +73,7 @@ import org.apache.spark.util.{CallSite, Utils} * @param parentSessionState If supplied, inherit all session state (i.e. temporary * views, SQL config, UDFs etc) from parent. 
*/ -@InterfaceStability.Stable +@Stable class SparkSession private( @transient val sparkContext: SparkContext, @transient private val existingSharedState: Option[SharedState], @@ -124,7 +124,7 @@ class SparkSession private( * * @since 2.2.0 */ - @InterfaceStability.Unstable + @Unstable @transient lazy val sharedState: SharedState = { existingSharedState.getOrElse(new SharedState(sparkContext)) @@ -145,7 +145,7 @@ class SparkSession private( * * @since 2.2.0 */ - @InterfaceStability.Unstable + @Unstable @transient lazy val sessionState: SessionState = { parentSessionState @@ -186,7 +186,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def listenerManager: ExecutionListenerManager = sessionState.listenerManager /** @@ -197,7 +197,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental - @InterfaceStability.Unstable + @Unstable def experimental: ExperimentalMethods = sessionState.experimentalMethods /** @@ -231,7 +231,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental - @InterfaceStability.Unstable + @Unstable def streams: StreamingQueryManager = sessionState.streamingQueryManager /** @@ -289,7 +289,7 @@ class SparkSession private( * @return 2.0.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def emptyDataset[T: Encoder]: Dataset[T] = { val encoder = implicitly[Encoder[T]] new Dataset(self, LocalRelation(encoder.schema.toAttributes), encoder) @@ -302,7 +302,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = { SparkSession.setActiveSession(this) val encoder = Encoders.product[A] @@ -316,7 +316,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = { SparkSession.setActiveSession(this) val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] @@ -356,7 +356,7 @@ class SparkSession private( * @since 2.0.0 */ @DeveloperApi - @InterfaceStability.Evolving + @Evolving def createDataFrame(rowRDD: RDD[Row], schema: StructType): DataFrame = { createDataFrame(rowRDD, schema, needsConversion = true) } @@ -370,7 +370,7 @@ class SparkSession private( * @since 2.0.0 */ @DeveloperApi - @InterfaceStability.Evolving + @Evolving def createDataFrame(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = { createDataFrame(rowRDD.rdd, schema) } @@ -384,7 +384,7 @@ class SparkSession private( * @since 2.0.0 */ @DeveloperApi - @InterfaceStability.Evolving + @Evolving def createDataFrame(rows: java.util.List[Row], schema: StructType): DataFrame = { Dataset.ofRows(self, LocalRelation.fromExternalRows(schema.toAttributes, rows.asScala)) } @@ -474,7 +474,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def createDataset[T : Encoder](data: Seq[T]): Dataset[T] = { val enc = encoderFor[T] val attributes = enc.schema.toAttributes @@ -493,7 +493,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def createDataset[T : Encoder](data: RDD[T]): Dataset[T] = { Dataset[T](self, ExternalRDD(data, self)) } @@ -515,7 +515,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def createDataset[T : Encoder](data: java.util.List[T]): Dataset[T] = { createDataset(data.asScala) } @@ -528,7 +528,7 @@ class 
SparkSession private( * @since 2.0.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def range(end: Long): Dataset[java.lang.Long] = range(0, end) /** @@ -539,7 +539,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def range(start: Long, end: Long): Dataset[java.lang.Long] = { range(start, end, step = 1, numPartitions = sparkContext.defaultParallelism) } @@ -552,7 +552,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def range(start: Long, end: Long, step: Long): Dataset[java.lang.Long] = { range(start, end, step, numPartitions = sparkContext.defaultParallelism) } @@ -566,7 +566,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def range(start: Long, end: Long, step: Long, numPartitions: Int): Dataset[java.lang.Long] = { new Dataset(self, Range(start, end, step, numPartitions), Encoders.LONG) } @@ -648,7 +648,11 @@ class SparkSession private( * @since 2.0.0 */ def sql(sqlText: String): DataFrame = { - Dataset.ofRows(self, sessionState.sqlParser.parsePlan(sqlText)) + val tracker = new QueryPlanningTracker + val plan = tracker.measureTime(QueryPlanningTracker.PARSING) { + sessionState.sqlParser.parsePlan(sqlText) + } + Dataset.ofRows(self, plan, tracker) } /** @@ -672,7 +676,7 @@ class SparkSession private( * * @since 2.0.0 */ - @InterfaceStability.Evolving + @Evolving def readStream: DataStreamReader = new DataStreamReader(self) /** @@ -706,7 +710,7 @@ class SparkSession private( * @since 2.0.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving object implicits extends SQLImplicits with Serializable { protected override def _sqlContext: SQLContext = SparkSession.this.sqlContext } @@ -775,13 +779,13 @@ class SparkSession private( } -@InterfaceStability.Stable +@Stable object SparkSession extends Logging { /** * Builder for [[SparkSession]]. 
*/ - @InterfaceStability.Stable + @Stable class Builder extends Logging { private[this] val options = new scala.collection.mutable.HashMap[String, String] @@ -1146,7 +1150,7 @@ object SparkSession extends Logging { val extensionConfClassName = extensionOption.get try { val extensionConfClass = Utils.classForName(extensionConfClassName) - val extensionConf = extensionConfClass.newInstance() + val extensionConf = extensionConfClass.getConstructor().newInstance() .asInstanceOf[SparkSessionExtensions => Unit] extensionConf(extensions) } catch { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala index a4864344b2d25..5ed76789786bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import scala.collection.mutable -import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} +import org.apache.spark.annotation.{DeveloperApi, Experimental, Unstable} import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder @@ -66,7 +66,7 @@ import org.apache.spark.sql.catalyst.rules.Rule */ @DeveloperApi @Experimental -@InterfaceStability.Unstable +@Unstable class SparkSessionExtensions { type RuleBuilder = SparkSession => Rule[LogicalPlan] type CheckRuleBuilder = SparkSession => LogicalPlan => Unit diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index aa3a6c3bf122f..5a3f556c9c074 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -22,7 +22,7 @@ import java.lang.reflect.ParameterizedType import scala.reflect.runtime.universe.TypeTag import scala.util.Try -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.api.python.PythonEvalType import org.apache.spark.internal.Logging import org.apache.spark.sql.api.java._ @@ -44,7 +44,7 @@ import org.apache.spark.util.Utils * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends Logging { protected[sql] def registerPython(name: String, udf: UserDefinedPythonFunction): Unit = { @@ -670,7 +670,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends throw new AnalysisException(s"It is invalid to implement multiple UDF interfaces, UDF class $className") } else { try { - val udf = clazz.newInstance() + val udf = clazz.getConstructor().newInstance() val udfReturnType = udfInterfaces(0).getActualTypeArguments.last var returnType = returnDataType if (returnType == null) { @@ -727,7 +727,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends if (!classOf[UserDefinedAggregateFunction].isAssignableFrom(clazz)) { throw new AnalysisException(s"class $className doesn't implement interface UserDefinedAggregateFunction") } - val udaf = clazz.newInstance().asInstanceOf[UserDefinedAggregateFunction] + val udaf = clazz.getConstructor().newInstance().asInstanceOf[UserDefinedAggregateFunction] register(name, udaf) } catch { case e: ClassNotFoundException => 
throw new AnalysisException(s"Can not load class ${className}, please make sure it is on the classpath") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index af20764f9a968..becb05cf72aba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -111,7 +111,7 @@ private[sql] object SQLUtils extends Logging { private[this] def doConversion(data: Object, dataType: DataType): Object = { data match { case d: java.lang.Double if dataType == FloatType => - new java.lang.Float(d) + java.lang.Float.valueOf(d.toFloat) // Scala Map is the only allowed external type of map type in Row. case m: java.util.Map[_, _] => m.asScala case _ => data diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala index ab81725def3f4..44668610d8052 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalog import scala.collection.JavaConverters._ -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.{Evolving, Experimental, Stable} import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset} import org.apache.spark.sql.types.StructType import org.apache.spark.storage.StorageLevel @@ -29,7 +29,7 @@ import org.apache.spark.storage.StorageLevel * * @since 2.0.0 */ -@InterfaceStability.Stable +@Stable abstract class Catalog { /** @@ -233,7 +233,7 @@ abstract class Catalog { * @since 2.2.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def createTable(tableName: String, path: String): DataFrame /** @@ -261,7 +261,7 @@ abstract class Catalog { * @since 2.2.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def createTable(tableName: String, path: String, source: String): DataFrame /** @@ -292,7 +292,7 @@ abstract class Catalog { * @since 2.2.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def createTable( tableName: String, source: String, @@ -330,7 +330,7 @@ abstract class Catalog { * @since 2.2.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def createTable( tableName: String, source: String, @@ -366,7 +366,7 @@ abstract class Catalog { * @since 2.2.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def createTable( tableName: String, source: String, @@ -406,7 +406,7 @@ abstract class Catalog { * @since 2.2.0 */ @Experimental - @InterfaceStability.Evolving + @Evolving def createTable( tableName: String, source: String, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalog/interface.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalog/interface.scala index c0c5ebc2ba2d6..cb270875228ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalog/interface.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalog/interface.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalog import javax.annotation.Nullable -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.DefinedByConstructorParams @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.DefinedByConstructorParams * @param locationUri path (in the form of a uri) to data files. 
* @since 2.0.0 */ -@InterfaceStability.Stable +@Stable class Database( val name: String, @Nullable val description: String, @@ -61,7 +61,7 @@ class Database( * @param isTemporary whether the table is a temporary table. * @since 2.0.0 */ -@InterfaceStability.Stable +@Stable class Table( val name: String, @Nullable val database: String, @@ -93,7 +93,7 @@ class Table( * @param isBucket whether the column is a bucket column. * @since 2.0.0 */ -@InterfaceStability.Stable +@Stable class Column( val name: String, @Nullable val description: String, @@ -126,7 +126,7 @@ class Column( * @param isTemporary whether the function is a temporary function or not. * @since 2.0.0 */ -@InterfaceStability.Stable +@Stable class Function( val name: String, @Nullable val database: String, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 9b500c1a040ae..9c04f8679b46d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning} +import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat => ParquetSource} import org.apache.spark.sql.execution.metric.SQLMetrics @@ -56,8 +57,8 @@ trait DataSourceScanExec extends LeafExecNode with CodegenSupport { case (key, value) => key + ": " + StringUtils.abbreviate(redact(value), 100) } - val metadataStr = Utils.truncatedString(metadataEntries, " ", ", ", "") - s"$nodeNamePrefix$nodeName${Utils.truncatedString(output, "[", ",", "]")}$metadataStr" + val metadataStr = truncatedString(metadataEntries, " ", ", ", "") + s"$nodeNamePrefix$nodeName${truncatedString(output, "[", ",", "]")}$metadataStr" } override def verboseString: String = redact(super.verboseString) @@ -83,7 +84,7 @@ case class RowDataSourceScanExec( rdd: RDD[InternalRow], @transient relation: BaseRelation, override val tableIdentifier: Option[TableIdentifier]) - extends DataSourceScanExec { + extends DataSourceScanExec with InputRDDCodegen { def output: Seq[Attribute] = requiredColumnsIndex.map(fullOutput) @@ -103,30 +104,10 @@ case class RowDataSourceScanExec( } } - override def inputRDDs(): Seq[RDD[InternalRow]] = { - rdd :: Nil - } + // Input can be InternalRow, has to be turned into UnsafeRows. 
+ override protected val createUnsafeProjection: Boolean = true - override protected def doProduce(ctx: CodegenContext): String = { - val numOutputRows = metricTerm(ctx, "numOutputRows") - // PhysicalRDD always just has one input - val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];") - val exprRows = output.zipWithIndex.map{ case (a, i) => - BoundReference(i, a.dataType, a.nullable) - } - val row = ctx.freshName("row") - ctx.INPUT_ROW = row - ctx.currentVars = null - val columnsRowInput = exprRows.map(_.genCode(ctx)) - s""" - |while ($input.hasNext()) { - | InternalRow $row = (InternalRow) $input.next(); - | $numOutputRows.add(1); - | ${consume(ctx, columnsRowInput).trim} - | if (shouldStop()) return; - |} - """.stripMargin - } + override def inputRDD: RDD[InternalRow] = rdd override val metadata: Map[String, String] = { val markedFilters = for (filter <- filters) yield { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 2962becb64e88..e214bfd050410 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -24,9 +24,9 @@ import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} +import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.DataType -import org.apache.spark.util.Utils object RDDConversions { def productToRowRdd[A <: Product](data: RDD[A], outputTypes: Seq[DataType]): RDD[InternalRow] = { @@ -175,7 +175,7 @@ case class RDDScanExec( rdd: RDD[InternalRow], name: String, override val outputPartitioning: Partitioning = UnknownPartitioning(0), - override val outputOrdering: Seq[SortOrder] = Nil) extends LeafExecNode { + override val outputOrdering: Seq[SortOrder] = Nil) extends LeafExecNode with InputRDDCodegen { private def rddName: String = Option(rdd.name).map(n => s" $n").getOrElse("") @@ -197,6 +197,11 @@ case class RDDScanExec( } override def simpleString: String = { - s"$nodeName${Utils.truncatedString(output, "[", ",", "]")}" + s"$nodeName${truncatedString(output, "[", ",", "]")}" } + + // Input can be InternalRow, has to be turned into UnsafeRows. 
+ override protected val createUnsafeProjection: Boolean = true + + override def inputRDD: RDD[InternalRow] = rdd } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala index 448eb703eacde..31640db3722ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.execution.metric.SQLMetrics */ case class LocalTableScanExec( output: Seq[Attribute], - @transient rows: Seq[InternalRow]) extends LeafExecNode { + @transient rows: Seq[InternalRow]) extends LeafExecNode with InputRDDCodegen { override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) @@ -76,4 +76,12 @@ case class LocalTableScanExec( longMetric("numOutputRows").add(taken.size) taken } + + // Input is already UnsafeRows. + override protected val createUnsafeProjection: Boolean = false + + // Do not codegen when there is no parent - to support the fast driver-local collect/take paths. + override def supportCodegen: Boolean = (parent != null) + + override def inputRDD: RDD[InternalRow] = rdd } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index c957285b2a315..714581ee10c32 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -17,16 +17,21 @@ package org.apache.spark.sql.execution +import java.io.{BufferedWriter, OutputStreamWriter, Writer} import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} +import org.apache.commons.io.output.StringBuilderWriter +import org.apache.hadoop.fs.Path + import org.apache.spark.rdd.RDD import org.apache.spark.sql.{AnalysisException, Row, SparkSession} -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.{InternalRow, QueryPlanningTracker} import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.adaptive.PlanQueryStage import org.apache.spark.sql.execution.command.{DescribeTableCommand, ExecutedCommandExec, ShowTablesCommand} import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange} @@ -40,7 +45,10 @@ import org.apache.spark.util.Utils * While this is not a public class, we should avoid changing the function names for the sake of * changing them, because a lot of developers use the feature for debugging. */ -class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { +class QueryExecution( + val sparkSession: SparkSession, + val logical: LogicalPlan, + val tracker: QueryPlanningTracker = new QueryPlanningTracker) { // TODO: Move the planner an optimizer into here from SessionState. 
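The `QueryPlanningTracker` threaded through `sql()` above and the `QueryExecution` phases below times each planning phase (parsing, analysis, optimization, planning). As a rough sketch of that measure-and-accumulate pattern only, not the actual tracker API:

```scala
import scala.collection.mutable

// Hypothetical phase timer illustrating the pattern: run each phase inside a closure
// and accumulate the elapsed nanoseconds under the phase name.
class PhaseTimer {
  private val durationsNs = mutable.Map.empty[String, Long].withDefaultValue(0L)

  def measure[T](phase: String)(block: => T): T = {
    val start = System.nanoTime()
    try block
    finally durationsNs(phase) += System.nanoTime() - start
  }

  def report(): Map[String, Long] = durationsNs.toMap
}

// Usage sketch mirroring the diff (the `parser` and `sqlText` values are assumed, not defined here):
//   val timer = new PhaseTimer
//   val plan  = timer.measure("parsing") { parser.parsePlan(sqlText) }
//   timer.report().foreach { case (phase, ns) => println(s"$phase: ${ns / 1e6} ms") }
```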
protected def planner = sparkSession.sessionState.planner @@ -53,9 +61,9 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { } } - lazy val analyzed: LogicalPlan = { + lazy val analyzed: LogicalPlan = tracker.measureTime(QueryPlanningTracker.ANALYSIS) { SparkSession.setActiveSession(sparkSession) - sparkSession.sessionState.analyzer.executeAndCheck(logical) + sparkSession.sessionState.analyzer.executeAndCheck(logical, tracker) } lazy val withCachedData: LogicalPlan = { @@ -64,9 +72,11 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { sparkSession.sharedState.cacheManager.useCachedData(analyzed) } - lazy val optimizedPlan: LogicalPlan = sparkSession.sessionState.optimizer.execute(withCachedData) + lazy val optimizedPlan: LogicalPlan = tracker.measureTime(QueryPlanningTracker.OPTIMIZATION) { + sparkSession.sessionState.optimizer.executeAndTrack(withCachedData, tracker) + } - lazy val sparkPlan: SparkPlan = { + lazy val sparkPlan: SparkPlan = tracker.measureTime(QueryPlanningTracker.PLANNING) { SparkSession.setActiveSession(sparkSession) // TODO: We use next(), i.e. take the first plan returned by the planner, here for now, // but we will implement to choose the best plan. @@ -75,7 +85,9 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { // executedPlan should not be used to initialize any SparkPlan. It should be // only used for execution. - lazy val executedPlan: SparkPlan = prepareForExecution(sparkPlan) + lazy val executedPlan: SparkPlan = tracker.measureTime(QueryPlanningTracker.PLANNING) { + prepareForExecution(sparkPlan) + } /** Internal version of the RDD. Avoids copies and has no schema */ lazy val toRdd: RDD[InternalRow] = executedPlan.execute() @@ -203,23 +215,38 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { """.stripMargin.trim } + private def writeOrError(writer: Writer)(f: Writer => Unit): Unit = { + try f(writer) + catch { + case e: AnalysisException => writer.write(e.toString) + } + } + + private def writePlans(writer: Writer): Unit = { + val (verbose, addSuffix) = (true, false) + + writer.write("== Parsed Logical Plan ==\n") + writeOrError(writer)(logical.treeString(_, verbose, addSuffix)) + writer.write("\n== Analyzed Logical Plan ==\n") + val analyzedOutput = stringOrError(truncatedString( + analyzed.output.map(o => s"${o.name}: ${o.dataType.simpleString}"), ", ")) + writer.write(analyzedOutput) + writer.write("\n") + writeOrError(writer)(analyzed.treeString(_, verbose, addSuffix)) + writer.write("\n== Optimized Logical Plan ==\n") + writeOrError(writer)(optimizedPlan.treeString(_, verbose, addSuffix)) + writer.write("\n== Physical Plan ==\n") + writeOrError(writer)(executedPlan.treeString(_, verbose, addSuffix)) + } + override def toString: String = withRedaction { - def output = Utils.truncatedString( - analyzed.output.map(o => s"${o.name}: ${o.dataType.simpleString}"), ", ") - val analyzedPlan = Seq( - stringOrError(output), - stringOrError(analyzed.treeString(verbose = true)) - ).filter(_.nonEmpty).mkString("\n") - - s"""== Parsed Logical Plan == - |${stringOrError(logical.treeString(verbose = true))} - |== Analyzed Logical Plan == - |$analyzedPlan - |== Optimized Logical Plan == - |${stringOrError(optimizedPlan.treeString(verbose = true))} - |== Physical Plan == - |${stringOrError(executedPlan.treeString(verbose = true))} - """.stripMargin.trim + val writer = new StringBuilderWriter() + try { + writePlans(writer) + writer.toString + } 
finally { + writer.close() + } } def stringWithStats: String = withRedaction { @@ -264,5 +291,22 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { def codegenToSeq(): Seq[(String, String)] = { org.apache.spark.sql.execution.debug.codegenStringSeq(executedPlan) } + + /** + * Dumps debug information about query execution into the specified file. + */ + def toFile(path: String): Unit = { + val filePath = new Path(path) + val fs = filePath.getFileSystem(sparkSession.sessionState.newHadoopConf()) + val writer = new BufferedWriter(new OutputStreamWriter(fs.create(filePath))) + + try { + writePlans(writer) + writer.write("\n== Whole Stage Codegen ==\n") + org.apache.spark.sql.execution.debug.writeCodegen(writer, executedPlan) + } finally { + writer.close() + } + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala index 862ee05392f37..9b05faaed0459 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala @@ -22,6 +22,7 @@ import java.util.Arrays import org.apache.spark._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLShuffleMetricsReporter} /** * The [[Partition]] used by [[ShuffledRowRDD]]. A post-shuffle partition @@ -112,6 +113,7 @@ class CoalescedPartitioner(val parent: Partitioner, val partitionStartIndices: A */ class ShuffledRowRDD( var dependency: ShuffleDependency[Int, InternalRow, InternalRow], + metrics: Map[String, SQLMetric], specifiedPartitionStartIndices: Option[Array[Int]] = None) extends RDD[InternalRow](dependency.rdd.context, Nil) { @@ -154,6 +156,10 @@ class ShuffledRowRDD( override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { val shuffledRowPartition = split.asInstanceOf[ShuffledRowRDDPartition] + val tempMetrics = context.taskMetrics().createTempShuffleReadMetrics() + // `SQLShuffleMetricsReporter` will update its own metrics for SQL exchange operator, + // as well as the `tempMetrics` for basic shuffle metrics. + val sqlMetricsReporter = new SQLShuffleMetricsReporter(tempMetrics, metrics) // The range of pre-shuffle partitions that we are fetching at here is // [startPreShufflePartitionIndex, endPreShufflePartitionIndex - 1]. val reader = @@ -161,7 +167,8 @@ class ShuffledRowRDD( dependency.shuffleHandle, shuffledRowPartition.startPreShufflePartitionIndex, shuffledRowPartition.endPreShufflePartitionIndex, - context) + context, + sqlMetricsReporter) reader.read().asInstanceOf[Iterator[Product2[Int, InternalRow]]].map(_._2) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 5f81b6fe743c9..fbda0d87a175f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution +import java.io.Writer import java.util.Locale import java.util.function.Supplier @@ -349,6 +350,15 @@ trait CodegenSupport extends SparkPlan { */ def needStopCheck: Boolean = parent.needStopCheck + /** + * Helper default should stop check code. 
+ */ + def shouldStopCheckCode: String = if (needStopCheck) { + "if (shouldStop()) return;" + } else { + "// shouldStop check is eliminated" + } + /** * A sequence of checks which evaluate to true if the downstream Limit operators have not received * enough records and reached the limit. If current node is a data producing node, it can leverage @@ -405,6 +415,53 @@ trait BlockingOperatorWithCodegen extends CodegenSupport { override def limitNotReachedChecks: Seq[String] = Nil } +/** + * Leaf codegen node reading from a single RDD. + */ +trait InputRDDCodegen extends CodegenSupport { + + def inputRDD: RDD[InternalRow] + + // If the input can be InternalRows, an UnsafeProjection needs to be created. + protected val createUnsafeProjection: Boolean + + override def inputRDDs(): Seq[RDD[InternalRow]] = { + inputRDD :: Nil + } + + override def doProduce(ctx: CodegenContext): String = { + // Inline mutable state since an InputRDDCodegen is used once in a task for WholeStageCodegen + val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];", + forceInline = true) + val row = ctx.freshName("row") + + val outputVars = if (createUnsafeProjection) { + // creating the vars will make the parent consume add an unsafe projection. + ctx.INPUT_ROW = row + ctx.currentVars = null + output.zipWithIndex.map { case (a, i) => + BoundReference(i, a.dataType, a.nullable).genCode(ctx) + } + } else { + null + } + + val updateNumOutputRowsMetrics = if (metrics.contains("numOutputRows")) { + val numOutputRows = metricTerm(ctx, "numOutputRows") + s"$numOutputRows.add(1);" + } else { + "" + } + s""" + | while ($limitNotReachedCond $input.hasNext()) { + | InternalRow $row = (InternalRow) $input.next(); + | ${updateNumOutputRowsMetrics} + | ${consume(ctx, outputVars, if (createUnsafeProjection) null else row).trim} + | ${shouldStopCheckCode} + | } + """.stripMargin + } +} /** * InputAdapter is used to hide a SparkPlan from a subtree that supports codegen. @@ -412,7 +469,7 @@ trait BlockingOperatorWithCodegen extends CodegenSupport { * This is the leaf node of a tree with WholeStageCodegen that is used to generate code * that consumes an RDD iterator of InternalRow. */ -case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupport { +case class InputAdapter(child: SparkPlan) extends UnaryExecNode with InputRDDCodegen { override def output: Seq[Attribute] = child.output @@ -428,33 +485,19 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupp child.doExecuteBroadcast() } - override def inputRDDs(): Seq[RDD[InternalRow]] = { - child.execute() :: Nil - } + override def inputRDD: RDD[InternalRow] = child.execute() - override def doProduce(ctx: CodegenContext): String = { - // Right now, InputAdapter is only used when there is one input RDD. - // Inline mutable state since an InputAdapter is used once in a task for WholeStageCodegen - val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];", - forceInline = true) - val row = ctx.freshName("row") - s""" - | while ($limitNotReachedCond $input.hasNext()) { - | InternalRow $row = (InternalRow) $input.next(); - | ${consume(ctx, null, row).trim} - | if (shouldStop()) return; - | } - """.stripMargin - } + // InputAdapter does not need UnsafeProjection. 
+ protected val createUnsafeProjection: Boolean = false override def generateTreeString( depth: Int, lastChildren: Seq[Boolean], - builder: StringBuilder, + writer: Writer, verbose: Boolean, prefix: String = "", - addSuffix: Boolean = false): StringBuilder = { - child.generateTreeString(depth, lastChildren, builder, verbose, "") + addSuffix: Boolean = false): Unit = { + child.generateTreeString(depth, lastChildren, writer, verbose, prefix = "", addSuffix = false) } override def needCopyResult: Boolean = false @@ -726,11 +769,11 @@ case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int) override def generateTreeString( depth: Int, lastChildren: Seq[Boolean], - builder: StringBuilder, + writer: Writer, verbose: Boolean, prefix: String = "", - addSuffix: Boolean = false): StringBuilder = { - child.generateTreeString(depth, lastChildren, builder, verbose, s"*($codegenStageId) ") + addSuffix: Boolean = false): Unit = { + child.generateTreeString(depth, lastChildren, writer, verbose, s"*($codegenStageId) ", false) } override def needStopCheck: Boolean = true diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStage.scala index 20f13c280c12c..33d001629f986 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStage.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.adaptive +import java.io.Writer import java.util.Properties import scala.concurrent.{ExecutionContext, Future} @@ -172,11 +173,11 @@ abstract class QueryStage extends UnaryExecNode { override def generateTreeString( depth: Int, lastChildren: Seq[Boolean], - builder: StringBuilder, + writer: Writer, verbose: Boolean, prefix: String = "", - addSuffix: Boolean = false): StringBuilder = { - child.generateTreeString(depth, lastChildren, builder, verbose, "*") + addSuffix: Boolean = false): Unit = { + child.generateTreeString(depth, lastChildren, writer, verbose, "*") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageInput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageInput.scala index 887f815b6117a..3a43b1cfe94bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageInput.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageInput.scala @@ -17,12 +17,15 @@ package org.apache.spark.sql.execution.adaptive +import java.io.Writer + import org.apache.spark.broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression, SortOrder} import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning} import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.metric.SQLMetrics /** * QueryStageInput is the leaf node of a QueryStage and is used to hide its child stage. 
It gets @@ -58,11 +61,11 @@ abstract class QueryStageInput extends LeafExecNode { override def generateTreeString( depth: Int, lastChildren: Seq[Boolean], - builder: StringBuilder, + writer: Writer, verbose: Boolean, prefix: String = "", - addSuffix: Boolean = false): StringBuilder = { - childStage.generateTreeString(depth, lastChildren, builder, verbose, "*") + addSuffix: Boolean = false): Unit = { + childStage.generateTreeString(depth, lastChildren, writer, verbose, "*") } } @@ -78,13 +81,15 @@ case class ShuffleQueryStageInput( partitionStartIndices: Option[Array[Int]] = None) extends QueryStageInput { + override lazy val metrics = SQLMetrics.getShuffleReadMetrics(sparkContext) + override def outputPartitioning: Partitioning = partitionStartIndices.map { indices => UnknownPartitioning(indices.length) }.getOrElse(super.outputPartitioning) override def doExecute(): RDD[InternalRow] = { val childRDD = childStage.execute().asInstanceOf[ShuffledRowRDD] - new ShuffledRowRDD(childRDD.dependency, partitionStartIndices) + new ShuffledRowRDD(childRDD.dependency, metrics, partitionStartIndices) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 25d8e7dff3d99..4827f838fc514 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.TaskContext -import org.apache.spark.memory.TaskMemoryManager +import org.apache.spark.memory.{SparkOutOfMemoryError, TaskMemoryManager} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.execution.vectorized.MutableColumnarRow @@ -762,6 +763,8 @@ case class HashAggregateExec( ("true", "true", "", "") } + val oomeClassName = classOf[SparkOutOfMemoryError].getName + val findOrInsertRegularHashMap: String = s""" |// generate grouping key @@ -787,7 +790,7 @@ case class HashAggregateExec( | $unsafeRowKeys, ${hashEval.value}); | if ($unsafeRowBuffer == null) { | // failed to allocate the first page - | throw new OutOfMemoryError("No enough memory for aggregation"); + | throw new $oomeClassName("No enough memory for aggregation"); | } |} """.stripMargin @@ -928,9 +931,9 @@ case class HashAggregateExec( testFallbackStartsAt match { case None => - val keyString = Utils.truncatedString(groupingExpressions, "[", ", ", "]") - val functionString = Utils.truncatedString(allAggregateExpressions, "[", ", ", "]") - val outputString = Utils.truncatedString(output, "[", ", ", "]") + val keyString = truncatedString(groupingExpressions, "[", ", ", "]") + val functionString = truncatedString(allAggregateExpressions, "[", ", ", "]") + val outputString = truncatedString(output, "[", ", ", "]") if (verbose) { s"HashAggregate(keys=$keyString, functions=$functionString, output=$outputString)" } else { 
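The HashAggregateExec hunk above, and the ObjectHashAggregateExec and SortAggregateExec hunks that follow, replace `Utils.truncatedString` with the relocated `org.apache.spark.sql.catalyst.util.truncatedString`. For context, a simplified sketch of what such a helper does; this is not the Spark implementation, and the field cutoff is illustrative:

```scala
object TruncatedStringSketch {
  // Simplified truncatedString-style helper: format a sequence with start/separator/end
  // delimiters, eliding the tail once a field cutoff is exceeded so that plan strings
  // built from very wide outputs stay readable.
  def truncatedString[T](
      seq: Seq[T],
      start: String,
      sep: String,
      end: String,
      maxFields: Int = 25): String = {
    if (seq.length <= maxFields) {
      seq.mkString(start, sep, end)
    } else {
      val shown = seq.take(maxFields - 1).map(_.toString)
      (shown :+ s"... ${seq.length - shown.length} more fields").mkString(start, sep, end)
    }
  }

  // e.g. building the bracketed keyString / functionString / outputString fragments above:
  //   truncatedString(groupingExpressions, "[", ", ", "]")
}
```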
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala index 66955b8ef723c..7145bb03028d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala @@ -23,9 +23,9 @@ import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.util.Utils /** * A hash-based aggregate operator that supports [[TypedImperativeAggregate]] functions that may @@ -143,9 +143,9 @@ case class ObjectHashAggregateExec( private def toString(verbose: Boolean): String = { val allAggregateExpressions = aggregateExpressions - val keyString = Utils.truncatedString(groupingExpressions, "[", ", ", "]") - val functionString = Utils.truncatedString(allAggregateExpressions, "[", ", ", "]") - val outputString = Utils.truncatedString(output, "[", ", ", "]") + val keyString = truncatedString(groupingExpressions, "[", ", ", "]") + val functionString = truncatedString(allAggregateExpressions, "[", ", ", "]") + val outputString = truncatedString(output, "[", ", ", "]") if (verbose) { s"ObjectHashAggregate(keys=$keyString, functions=$functionString, output=$outputString)" } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala index fc87de2c52e41..d732b905dcdd5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala @@ -23,9 +23,9 @@ import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.util.Utils /** * Sort-based aggregate operator. 
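The HashAggregateExec change above swaps the error thrown from generated code to `SparkOutOfMemoryError` by splicing the class name into the code template (`classOf[SparkOutOfMemoryError].getName`). Reduced to a toy string template, with a stand-in exception class and no code compilation involved, the idea looks like this:

```scala
// Toy illustration of emitting a fully-qualified exception class name into a generated
// Java fragment; IllegalStateException stands in for Spark's SparkOutOfMemoryError and
// the "buffer" variable name is hard-coded rather than generated.
object CodegenNameSpliceSketch extends App {
  val oomeClassName: String = classOf[IllegalStateException].getName

  val generatedFragment: String =
    s"""
       |if (buffer == null) {
       |  // failed to allocate the first page
       |  throw new $oomeClassName("No enough memory for aggregation");
       |}
     """.stripMargin

  println(generatedFragment)
}
```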
@@ -114,9 +114,9 @@ case class SortAggregateExec( private def toString(verbose: Boolean): String = { val allAggregateExpressions = aggregateExpressions - val keyString = Utils.truncatedString(groupingExpressions, "[", ", ", "]") - val functionString = Utils.truncatedString(allAggregateExpressions, "[", ", ", "]") - val outputString = Utils.truncatedString(output, "[", ", ", "]") + val keyString = truncatedString(groupingExpressions, "[", ", ", "]") + val functionString = truncatedString(allAggregateExpressions, "[", ", ", "]") + val outputString = truncatedString(output, "[", ", ", "]") if (verbose) { s"SortAggregate(key=$keyString, functions=$functionString, output=$outputString)" } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 72505f7fac0c6..6d849869b577a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -206,7 +206,9 @@ class TungstenAggregationIterator( buffer = hashMap.getAggregationBufferFromUnsafeRow(groupingKey) if (buffer == null) { // failed to allocate the first page + // scalastyle:off throwerror throw new SparkOutOfMemoryError("No enough memory for aggregation") + // scalastyle:on throwerror } } processRow(buffer, newInput) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index 39200ec00e152..b75752945a492 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -40,9 +40,9 @@ object TypedAggregateExpression { val outputEncoder = encoderFor[OUT] val outputType = outputEncoder.objSerializer.dataType - // Checks if the buffer object is simple, i.e. the buffer encoder is flat and the serializer - // expression is an alias of `BoundReference`, which means the buffer object doesn't need - // serialization. + // Checks if the buffer object is simple, i.e. the `BUF` type is not serialized as struct + // and the serializer expression is an alias of `BoundReference`, which means the buffer + // object doesn't need serialization. 
val isSimpleBuffer = { bufferSerializer.head match { case Alias(_: BoundReference, _) if !bufferEncoder.isSerializedAsStruct => true @@ -76,7 +76,7 @@ object TypedAggregateExpression { None, bufferSerializer, bufferEncoder.resolveAndBind().deserializer, - outputEncoder.serializer, + outputEncoder.objSerializer, outputType, outputEncoder.objSerializer.nullable) } @@ -213,7 +213,7 @@ case class ComplexTypedAggregateExpression( inputSchema: Option[StructType], bufferSerializer: Seq[NamedExpression], bufferDeserializer: Expression, - outputSerializer: Seq[Expression], + outputSerializer: Expression, dataType: DataType, nullable: Boolean, mutableAggBufferOffset: Int = 0, @@ -245,13 +245,7 @@ case class ComplexTypedAggregateExpression( aggregator.merge(buffer, input) } - private lazy val resultObjToRow = dataType match { - case _: StructType => - UnsafeProjection.create(CreateStruct(outputSerializer)) - case _ => - assert(outputSerializer.length == 1) - UnsafeProjection.create(outputSerializer.head) - } + private lazy val resultObjToRow = UnsafeProjection.create(outputSerializer) override def eval(buffer: Any): Any = { val resultObj = aggregator.finish(buffer) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 3b6588587c35a..73eb65f84489c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -27,9 +27,10 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, LogicalPlan, Statistics} +import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.{LongAccumulator, Utils} +import org.apache.spark.util.LongAccumulator /** @@ -209,5 +210,5 @@ case class InMemoryRelation( override protected def otherCopyArgs: Seq[AnyRef] = Seq(statsOfPlanToCache) override def simpleString: String = - s"InMemoryRelation [${Utils.truncatedString(output, ", ")}], ${cacheBuilder.storageLevel}" + s"InMemoryRelation [${truncatedString(output, ", ")}], ${cacheBuilder.storageLevel}" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 823dc0d5ed387..e2cd40906f401 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -231,7 +231,8 @@ case class AlterTableAddColumnsCommand( } if (DDLUtils.isDatasourceTable(catalogTable)) { - DataSource.lookupDataSource(catalogTable.provider.get, conf).newInstance() match { + DataSource.lookupDataSource(catalogTable.provider.get, conf). + getConstructor().newInstance() match { // For datasource table, this command can only support the following File format. 
// TextFileFormat only default to one column "value" // Hive type is already considered as hive serde table, so the logic will not diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 55e627135c877..346d1f36ad9b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -206,7 +206,7 @@ case class DataSource( /** Returns the name and schema of the source that can be used to continually read data. */ private def sourceSchema(): SourceInfo = { - providingClass.newInstance() match { + providingClass.getConstructor().newInstance() match { case s: StreamSourceProvider => val (name, schema) = s.sourceSchema( sparkSession.sqlContext, userSpecifiedSchema, className, caseInsensitiveOptions) @@ -252,7 +252,7 @@ case class DataSource( /** Returns a source that can be used to continually read data. */ def createSource(metadataPath: String): Source = { - providingClass.newInstance() match { + providingClass.getConstructor().newInstance() match { case s: StreamSourceProvider => s.createSource( sparkSession.sqlContext, @@ -281,7 +281,7 @@ case class DataSource( /** Returns a sink that can be used to continually write data. */ def createSink(outputMode: OutputMode): Sink = { - providingClass.newInstance() match { + providingClass.getConstructor().newInstance() match { case s: StreamSinkProvider => s.createSink(sparkSession.sqlContext, caseInsensitiveOptions, partitionColumns, outputMode) @@ -312,7 +312,7 @@ case class DataSource( * that files already exist, we don't need to check them again. */ def resolveRelation(checkFilesExist: Boolean = true): BaseRelation = { - val relation = (providingClass.newInstance(), userSpecifiedSchema) match { + val relation = (providingClass.getConstructor().newInstance(), userSpecifiedSchema) match { // TODO: Throw when too much is given. 
case (dataSource: SchemaRelationProvider, Some(schema)) => dataSource.createRelation(sparkSession.sqlContext, caseInsensitiveOptions, schema) @@ -481,7 +481,7 @@ case class DataSource( throw new AnalysisException("Cannot save interval data type into external storage.") } - providingClass.newInstance() match { + providingClass.getConstructor().newInstance() match { case dataSource: CreatableRelationProvider => dataSource.createRelation( sparkSession.sqlContext, mode, caseInsensitiveOptions, Dataset.ofRows(sparkSession, data)) @@ -518,7 +518,7 @@ case class DataSource( throw new AnalysisException("Cannot save interval data type into external storage.") } - providingClass.newInstance() match { + providingClass.getConstructor().newInstance() match { case dataSource: CreatableRelationProvider => SaveIntoDataSourceCommand(data, dataSource, caseInsensitiveOptions, mode) case format: FileFormat => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index fe27b78bf3360..62ab5c80d47cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -150,7 +150,7 @@ object FileSourceStrategy extends Strategy with Logging { // The attribute name of predicate could be different than the one in schema in case of // case insensitive, we should change them to match the one in schema, so we do not need to // worry about case sensitivity anymore. - val normalizedFilters = filters.map { e => + val normalizedFilters = filters.filterNot(SubqueryExpression.hasSubquery).map { e => e transform { case a: AttributeReference => a.withName(l.output.find(_.semanticEquals(a)).get.name) @@ -163,7 +163,6 @@ object FileSourceStrategy extends Strategy with Logging { val partitionSet = AttributeSet(partitionColumns) val partitionKeyFilters = ExpressionSet(normalizedFilters - .filterNot(SubqueryExpression.hasSubquery(_)) .filter(_.references.subsetOf(partitionSet))) logInfo(s"Pruning directories with: ${partitionKeyFilters.mkString(",")}") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala index 8d715f6342988..1023572d19e2e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala @@ -21,8 +21,8 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.expressions.{AttributeMap, AttributeReference} import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} +import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.sources.BaseRelation -import org.apache.spark.util.Utils /** * Used to link a [[BaseRelation]] in to a logical query plan. @@ -63,7 +63,7 @@ case class LogicalRelation( case _ => // Do nothing. 
} - override def simpleString: String = s"Relation[${Utils.truncatedString(output, ",")}] $relation" + override def simpleString: String = s"Relation[${truncatedString(output, ",")}] $relation" } object LogicalRelation { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala index 16b2367bfdd5c..329b9539f52e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala @@ -42,7 +42,7 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] { // The attribute name of predicate could be different than the one in schema in case of // case insensitive, we should change them to match the one in schema, so we donot need to // worry about case sensitivity anymore. - val normalizedFilters = filters.map { e => + val normalizedFilters = filters.filterNot(SubqueryExpression.hasSubquery).map { e => e transform { case a: AttributeReference => a.withName(logicalRelation.output.find(_.semanticEquals(a)).get.name) @@ -56,7 +56,6 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] { val partitionSet = AttributeSet(partitionColumns) val partitionKeyFilters = ExpressionSet(normalizedFilters - .filterNot(SubqueryExpression.hasSubquery(_)) .filter(_.references.subsetOf(partitionSet))) if (partitionKeyFilters.nonEmpty) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index 4808e8ef042d1..b35b8851918b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -95,7 +95,7 @@ object TextInputCSVDataSource extends CSVDataSource { headerChecker: CSVHeaderChecker, requiredSchema: StructType): Iterator[InternalRow] = { val lines = { - val linesReader = new HadoopFileLinesReader(file, conf) + val linesReader = new HadoopFileLinesReader(file, parser.options.lineSeparatorInRead, conf) Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => linesReader.close())) linesReader.map { line => new String(line.getBytes, 0, line.getLength, parser.options.charset) @@ -192,7 +192,8 @@ object MultiLineCSVDataSource extends CSVDataSource { UnivocityParser.tokenizeStream( CodecStreams.createInputStreamWithCloseResource(lines.getConfiguration, path), shouldDropHeader = false, - new CsvParser(parsedOptions.asParserSettings)) + new CsvParser(parsedOptions.asParserSettings), + encoding = parsedOptions.charset) }.take(1).headOption match { case Some(firstRow) => val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis @@ -203,7 +204,8 @@ object MultiLineCSVDataSource extends CSVDataSource { lines.getConfiguration, new Path(lines.getPath())), parsedOptions.headerFlag, - new CsvParser(parsedOptions.asParserSettings)) + new CsvParser(parsedOptions.asParserSettings), + encoding = parsedOptions.charset) } val sampled = CSVUtils.sample(tokenRDD, parsedOptions) CSVInferSchema.infer(sampled, header, parsedOptions) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index 964b56e706a0b..ff1911d69a6b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -27,6 +27,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVOptions, UnivocityGenerator, UnivocityParser} +import org.apache.spark.sql.catalyst.expressions.ExprUtils import org.apache.spark.sql.catalyst.util.CompressionCodecs import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ @@ -110,13 +111,7 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { sparkSession.sessionState.conf.columnNameOfCorruptRecord) // Check a field requirement for corrupt records here to throw an exception in a driver side - dataSchema.getFieldIndex(parsedOptions.columnNameOfCorruptRecord).foreach { corruptFieldIndex => - val f = dataSchema(corruptFieldIndex) - if (f.dataType != StringType || !f.nullable) { - throw new AnalysisException( - "The field for corrupt records must be string type and nullable") - } - } + ExprUtils.verifyColumnNameOfCorruptRecord(dataSchema, parsedOptions.columnNameOfCorruptRecord) if (requiredSchema.length == 1 && requiredSchema.head.name == parsedOptions.columnNameOfCorruptRecord) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala index 1723596de1db2..530d836d9fde3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala @@ -50,7 +50,7 @@ object DriverRegistry extends Logging { } else { synchronized { if (wrapperMap.get(className).isEmpty) { - val wrapper = new DriverWrapper(cls.newInstance().asInstanceOf[Driver]) + val wrapper = new DriverWrapper(cls.getConstructor().newInstance().asInstanceOf[Driver]) DriverManager.registerDriver(wrapper) wrapperMap(className) = wrapper logTrace(s"Wrapper for $className registered") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index f15014442e3fb..51c385e25bee3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -27,10 +27,10 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode, SparkSession, SQLContext} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.jdbc.JdbcDialects import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{DataType, DateType, NumericType, StructType, TimestampType} -import org.apache.spark.util.Utils /** * Instructions on how to partition the table among workers. 
@@ -159,8 +159,9 @@ private[sql] object JDBCRelation extends Logging { val column = schema.find { f => resolver(f.name, columnName) || resolver(dialect.quoteIdentifier(f.name), columnName) }.getOrElse { + val maxNumToStringFields = SQLConf.get.maxToStringFields throw new AnalysisException(s"User-defined partition column $columnName not " + - s"found in the JDBC relation: ${schema.simpleString(Utils.maxNumToStringFields)}") + s"found in the JDBC relation: ${schema.simpleString(maxNumToStringFields)}") } column.dataType match { case _: NumericType | DateType | TimestampType => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index 1f7c9d73f19fe..610f0d1619fc9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -26,7 +26,8 @@ import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JacksonParser, JSONOptions, JSONOptionsInRead} +import org.apache.spark.sql.catalyst.expressions.ExprUtils +import org.apache.spark.sql.catalyst.json._ import org.apache.spark.sql.catalyst.util.CompressionCodecs import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ @@ -107,13 +108,7 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { val actualSchema = StructType(requiredSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)) // Check a field requirement for corrupt records here to throw an exception in a driver side - dataSchema.getFieldIndex(parsedOptions.columnNameOfCorruptRecord).foreach { corruptFieldIndex => - val f = dataSchema(corruptFieldIndex) - if (f.dataType != StringType || !f.nullable) { - throw new AnalysisException( - "The field for corrupt records must be string type and nullable") - } - } + ExprUtils.verifyColumnNameOfCorruptRecord(dataSchema, parsedOptions.columnNameOfCorruptRecord) if (requiredSchema.length == 1 && requiredSchema.head.name == parsedOptions.columnNameOfCorruptRecord) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala index 84755bfa301f0..7e38fc651a31f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala @@ -20,8 +20,9 @@ package org.apache.spark.sql.execution.datasources.orc import org.apache.hadoop.fs.Path import org.apache.hadoop.io.NullWritable import org.apache.hadoop.mapreduce.TaskAttemptContext -import org.apache.orc.mapred.OrcStruct -import org.apache.orc.mapreduce.OrcOutputFormat +import org.apache.orc.OrcFile +import org.apache.orc.mapred.{OrcOutputFormat => OrcMapRedOutputFormat, OrcStruct} +import org.apache.orc.mapreduce.{OrcMapreduceRecordWriter, OrcOutputFormat} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.datasources.OutputWriter @@ -36,11 +37,17 @@ private[orc] class OrcOutputWriter( private[this] val serializer = new 
OrcSerializer(dataSchema) private val recordWriter = { - new OrcOutputFormat[OrcStruct]() { + val orcOutputFormat = new OrcOutputFormat[OrcStruct]() { override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { new Path(path) } - }.getRecordWriter(context) + } + val filename = orcOutputFormat.getDefaultWorkFile(context, ".orc") + val options = OrcMapRedOutputFormat.buildOptions(context.getConfiguration) + val writer = OrcFile.createWriter(filename, options) + val recordWriter = new OrcMapreduceRecordWriter[OrcStruct](writer) + OrcUtils.addSparkVersionMetadata(writer) + recordWriter } override def write(row: InternalRow): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala index 95fb25bf5addb..57d2c56e87b4d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala @@ -17,18 +17,19 @@ package org.apache.spark.sql.execution.datasources.orc +import java.nio.charset.StandardCharsets.UTF_8 import java.util.Locale import scala.collection.JavaConverters._ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} -import org.apache.orc.{OrcFile, Reader, TypeDescription} +import org.apache.orc.{OrcFile, Reader, TypeDescription, Writer} -import org.apache.spark.SparkException +import org.apache.spark.{SPARK_VERSION_SHORT, SparkException} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.{SPARK_VERSION_METADATA_KEY, SparkSession} import org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.types._ @@ -144,4 +145,11 @@ object OrcUtils extends Logging { } } } + + /** + * Add a metadata specifying Spark version. 
+ */ + def addSparkVersionMetadata(writer: Writer): Unit = { + writer.addUserMetadata(SPARK_VERSION_METADATA_KEY, UTF_8.encode(SPARK_VERSION_SHORT)) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala index 4096e88ddd123..bca48c33cd523 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala @@ -29,7 +29,9 @@ import org.apache.parquet.hadoop.api.WriteSupport import org.apache.parquet.hadoop.api.WriteSupport.WriteContext import org.apache.parquet.io.api.{Binary, RecordConsumer} +import org.apache.spark.SPARK_VERSION_SHORT import org.apache.spark.internal.Logging +import org.apache.spark.sql.SPARK_VERSION_METADATA_KEY import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecializedGetters import org.apache.spark.sql.catalyst.util.DateTimeUtils @@ -92,7 +94,10 @@ private[parquet] class ParquetWriteSupport extends WriteSupport[InternalRow] wit this.rootFieldWriters = schema.map(_.dataType).map(makeWriter).toArray[ValueWriter] val messageType = new SparkToParquetSchemaConverter(configuration).convert(schema) - val metadata = Map(ParquetReadSupport.SPARK_METADATA_KEY -> schemaString).asJava + val metadata = Map( + SPARK_VERSION_METADATA_KEY -> SPARK_VERSION_SHORT, + ParquetReadSupport.SPARK_METADATA_KEY -> schemaString + ).asJava logInfo( s"""Initialized Parquet WriteSupport with Catalyst schema: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala index 97e6c6d702acb..e829f621b4ea3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.commons.lang3.StringUtils import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} +import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.sources.v2.DataSourceV2 import org.apache.spark.util.Utils @@ -72,10 +73,10 @@ trait DataSourceV2StringFormat { }.mkString("[", ",", "]") } - val outputStr = Utils.truncatedString(output, "[", ", ", "]") + val outputStr = truncatedString(output, "[", ", ", "]") val entriesStr = if (entries.nonEmpty) { - Utils.truncatedString(entries.map { + truncatedString(entries.map { case (key, value) => key + ": " + StringUtils.abbreviate(value, 100) }, " (", ", ", ")") } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index 366e1fe6a4aaa..3511cefa7c292 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -17,10 +17,13 @@ package org.apache.spark.sql.execution +import java.io.Writer import java.util.Collections import scala.collection.JavaConverters._ +import org.apache.commons.io.output.StringBuilderWriter + import 
org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql._ @@ -30,7 +33,6 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodeFormatter, Codegen import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.trees.TreeNodeRef import org.apache.spark.sql.execution.streaming.{StreamExecution, StreamingQueryWrapper} -import org.apache.spark.sql.execution.streaming.continuous.WriteToContinuousDataSourceExec import org.apache.spark.sql.streaming.StreamingQuery import org.apache.spark.util.{AccumulatorV2, LongAccumulator} @@ -70,15 +72,25 @@ package object debug { * @return single String containing all WholeStageCodegen subtrees and corresponding codegen */ def codegenString(plan: SparkPlan): String = { + val writer = new StringBuilderWriter() + + try { + writeCodegen(writer, plan) + writer.toString + } finally { + writer.close() + } + } + + def writeCodegen(writer: Writer, plan: SparkPlan): Unit = { val codegenSeq = codegenStringSeq(plan) - var output = s"Found ${codegenSeq.size} WholeStageCodegen subtrees.\n" + writer.write(s"Found ${codegenSeq.size} WholeStageCodegen subtrees.\n") for (((subtree, code), i) <- codegenSeq.zipWithIndex) { - output += s"== Subtree ${i + 1} / ${codegenSeq.size} ==\n" - output += subtree - output += "\nGenerated code:\n" - output += s"${code}\n" + writer.write(s"== Subtree ${i + 1} / ${codegenSeq.size} ==\n") + writer.write(subtree) + writer.write("\nGenerated code:\n") + writer.write(s"${code}\n") } - output } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index bbd1a3f005d74..6387ea27c342a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -26,7 +26,7 @@ import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ -import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ @@ -47,7 +47,8 @@ case class ShuffleExchangeExec( // e.g. it can be null on the Executor side override lazy val metrics = Map( - "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size")) + "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size") + ) ++ SQLMetrics.getShuffleReadMetrics(sparkContext) override def nodeName: String = { "Exchange" @@ -85,7 +86,7 @@ case class ShuffleExchangeExec( assert(newPartitioning.isInstanceOf[HashPartitioning]) newPartitioning = UnknownPartitioning(indices.length) } - new ShuffledRowRDD(shuffleDependency, specifiedPartitionStartIndices) + new ShuffledRowRDD(shuffleDependency, metrics, specifiedPartitionStartIndices) } /** @@ -198,13 +199,21 @@ object ShuffleExchangeExec { override def getPartition(key: Any): Int = key.asInstanceOf[Int] } case RangePartitioning(sortingExpressions, numPartitions) => - // Internally, RangePartitioner runs a job on the RDD that samples keys to compute - // partition bounds. 
To get accurate samples, we need to copy the mutable keys. + // Extract only fields used for sorting to avoid collecting large fields that do not + // affect the sorting result when deciding partition bounds in RangePartitioner val rddForSampling = rdd.mapPartitionsInternal { iter => + val projection = + UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes) val mutablePair = new MutablePair[InternalRow, Null]() - iter.map(row => mutablePair.update(row.copy(), null)) + // Internally, RangePartitioner runs a job on the RDD that samples keys to compute + // partition bounds. To get accurate samples, we need to copy the mutable keys. + iter.map(row => mutablePair.update(projection(row).copy(), null)) } - implicit val ordering = new LazilyGeneratedOrdering(sortingExpressions, outputAttributes) + // Construct ordering on extracted sort key. + val orderingAttributes = sortingExpressions.zipWithIndex.map { case (ord, i) => + ord.copy(child = BoundReference(i, ord.dataType, ord.nullable)) + } + implicit val ordering = new LazilyGeneratedOrdering(orderingAttributes) new RangePartitioner( numPartitions, rddForSampling, @@ -230,7 +239,10 @@ object ShuffleExchangeExec { case h: HashPartitioning => val projection = UnsafeProjection.create(h.partitionIdExpression :: Nil, outputAttributes) row => projection(row).getInt(0) - case RangePartitioning(_, _) | SinglePartition => identity + case RangePartitioning(sortingExpressions, _) => + val projection = UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes) + row => projection(row) + case SinglePartition => identity case _ => sys.error(s"Exchange not implemented for $newPartitioning") } @@ -253,7 +265,7 @@ object ShuffleExchangeExec { } // The comparator for comparing row hashcode, which should always be Integer. val prefixComparator = PrefixComparators.LONG - val canUseRadixSort = SparkEnv.get.conf.get(SQLConf.RADIX_SORT_ENABLED) + val canUseRadixSort = SQLConf.get.enableRadixSort // The prefix computer generates row hashcode as the prefix, so we may decrease the // probability that the prefixes are equal when input rows choose column values from a // limited range. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index 256f12b605c92..3edfec5e5bb65 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -23,8 +23,9 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, LazilyGeneratedOrdering} import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec -import org.apache.spark.util.Utils +import org.apache.spark.sql.execution.metric.SQLMetrics /** * Take the first `limit` elements and collect them to a single partition.
@@ -40,11 +41,13 @@ case class CollectLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode LocalLimitExec(limit, child).executeToIterator().take(limit) } private val serializer: Serializer = new UnsafeRowSerializer(child.output.size) + override lazy val metrics = SQLMetrics.getShuffleReadMetrics(sparkContext) protected override def doExecute(): RDD[InternalRow] = { val locallyLimited = child.execute().mapPartitionsInternal(_.take(limit)) val shuffled = new ShuffledRowRDD( ShuffleExchangeExec.prepareShuffleDependency( - locallyLimited, child.output, SinglePartition, serializer)) + locallyLimited, child.output, SinglePartition, serializer), + metrics) shuffled.mapPartitionsInternal(_.take(limit)) } } @@ -154,6 +157,8 @@ case class TakeOrderedAndProjectExec( private val serializer: Serializer = new UnsafeRowSerializer(child.output.size) + override lazy val metrics = SQLMetrics.getShuffleReadMetrics(sparkContext) + protected override def doExecute(): RDD[InternalRow] = { val ord = new LazilyGeneratedOrdering(sortOrder, child.output) val localTopK: RDD[InternalRow] = { @@ -163,7 +168,8 @@ case class TakeOrderedAndProjectExec( } val shuffled = new ShuffledRowRDD( ShuffleExchangeExec.prepareShuffleDependency( - localTopK, child.output, SinglePartition, serializer)) + localTopK, child.output, SinglePartition, serializer), + metrics) shuffled.mapPartitions { iter => val topK = org.apache.spark.util.collection.Utils.takeOrdered(iter.map(_.copy()), limit)(ord) if (projectList != child.output) { @@ -180,8 +186,8 @@ case class TakeOrderedAndProjectExec( override def outputPartitioning: Partitioning = SinglePartition override def simpleString: String = { - val orderByString = Utils.truncatedString(sortOrder, "[", ",", "]") - val outputString = Utils.truncatedString(output, "[", ",", "]") + val orderByString = truncatedString(sortOrder, "[", ",", "]") + val outputString = truncatedString(output, "[", ",", "]") s"TakeOrderedAndProject(limit=$limit, orderBy=$orderByString, output=$outputString)" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index cbf707f4a9cfd..0b5ee3a5e0577 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -82,6 +82,14 @@ object SQLMetrics { private val baseForAvgMetric: Int = 10 + val REMOTE_BLOCKS_FETCHED = "remoteBlocksFetched" + val LOCAL_BLOCKS_FETCHED = "localBlocksFetched" + val REMOTE_BYTES_READ = "remoteBytesRead" + val REMOTE_BYTES_READ_TO_DISK = "remoteBytesReadToDisk" + val LOCAL_BYTES_READ = "localBytesRead" + val FETCH_WAIT_TIME = "fetchWaitTime" + val RECORDS_READ = "recordsRead" + /** * Converts a double value to long value by multiplying a base integer, so we can store it in * `SQLMetrics`. It only works for average metrics. When showing the metrics on UI, we restore @@ -194,4 +202,16 @@ object SQLMetrics { SparkListenerDriverAccumUpdates(executionId.toLong, metrics.map(m => m.id -> m.value))) } } + + /** + * Create all shuffle-read-related metrics and return them as a Map.
+ */ + def getShuffleReadMetrics(sc: SparkContext): Map[String, SQLMetric] = Map( + REMOTE_BLOCKS_FETCHED -> createMetric(sc, "remote blocks fetched"), + LOCAL_BLOCKS_FETCHED -> createMetric(sc, "local blocks fetched"), + REMOTE_BYTES_READ -> createSizeMetric(sc, "remote bytes read"), + REMOTE_BYTES_READ_TO_DISK -> createSizeMetric(sc, "remote bytes read to disk"), + LOCAL_BYTES_READ -> createSizeMetric(sc, "local bytes read"), + FETCH_WAIT_TIME -> createTimingMetric(sc, "fetch wait time"), + RECORDS_READ -> createMetric(sc, "records read")) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLShuffleMetricsReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLShuffleMetricsReporter.scala new file mode 100644 index 0000000000000..542141ea4b4e6 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLShuffleMetricsReporter.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.metric + +import org.apache.spark.executor.TempShuffleReadMetrics + +/** + * A shuffle metrics reporter for SQL exchange operators. + * @param tempMetrics [[TempShuffleReadMetrics]] created in TaskContext. + * @param metrics All metrics in the current SparkPlan. This param should not be empty and + * must contain all shuffle metrics defined in [[SQLMetrics.getShuffleReadMetrics]].
+ */ +private[spark] class SQLShuffleMetricsReporter( + tempMetrics: TempShuffleReadMetrics, + metrics: Map[String, SQLMetric]) extends TempShuffleReadMetrics { + private[this] val _remoteBlocksFetched = metrics(SQLMetrics.REMOTE_BLOCKS_FETCHED) + private[this] val _localBlocksFetched = metrics(SQLMetrics.LOCAL_BLOCKS_FETCHED) + private[this] val _remoteBytesRead = metrics(SQLMetrics.REMOTE_BYTES_READ) + private[this] val _remoteBytesReadToDisk = metrics(SQLMetrics.REMOTE_BYTES_READ_TO_DISK) + private[this] val _localBytesRead = metrics(SQLMetrics.LOCAL_BYTES_READ) + private[this] val _fetchWaitTime = metrics(SQLMetrics.FETCH_WAIT_TIME) + private[this] val _recordsRead = metrics(SQLMetrics.RECORDS_READ) + + override def incRemoteBlocksFetched(v: Long): Unit = { + _remoteBlocksFetched.add(v) + tempMetrics.incRemoteBlocksFetched(v) + } + override def incLocalBlocksFetched(v: Long): Unit = { + _localBlocksFetched.add(v) + tempMetrics.incLocalBlocksFetched(v) + } + override def incRemoteBytesRead(v: Long): Unit = { + _remoteBytesRead.add(v) + tempMetrics.incRemoteBytesRead(v) + } + override def incRemoteBytesReadToDisk(v: Long): Unit = { + _remoteBytesReadToDisk.add(v) + tempMetrics.incRemoteBytesReadToDisk(v) + } + override def incLocalBytesRead(v: Long): Unit = { + _localBytesRead.add(v) + tempMetrics.incLocalBytesRead(v) + } + override def incFetchWaitTime(v: Long): Unit = { + _fetchWaitTime.add(v) + tempMetrics.incFetchWaitTime(v) + } + override def incRecordsRead(v: Long): Unit = { + _recordsRead.add(v) + tempMetrics.incRecordsRead(v) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala index d2820ff335ecf..eb12641f548ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala @@ -23,7 +23,7 @@ import com.google.common.io.Closeables import org.apache.spark.{SparkEnv, SparkException} import org.apache.spark.io.NioBufferedFileInputStream -import org.apache.spark.memory.{MemoryConsumer, TaskMemoryManager} +import org.apache.spark.memory.{MemoryConsumer, SparkOutOfMemoryError, TaskMemoryManager} import org.apache.spark.serializer.SerializerManager import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.unsafe.Platform @@ -226,7 +226,7 @@ private[python] case class HybridRowQueue( val page = try { allocatePage(required) } catch { - case _: OutOfMemoryError => + case _: SparkOutOfMemoryError => null } val buffer = if (page != null) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala index 86f6307254332..420faa6f24734 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala @@ -69,7 +69,7 @@ object FrequentItems extends Logging { /** * Finding frequent items for columns, possibly with false positives. Using the * frequent element count algorithm described in - * here, proposed by Karp, Schenker, + * here, proposed by Karp, Schenker, * and Papadimitriou. * The `support` should be greater than 1e-4. * For Internal use only. 
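The new shuffle-read metrics wiring above is essentially a decorator: every `inc*` call first bumps the per-operator `SQLMetric` and then forwards the same delta to the task-level `TempShuffleReadMetrics`, so the SQL UI and the regular task metrics stay consistent. A minimal, self-contained sketch of that forwarding shape, using hypothetical stand-in types (`SimpleMetric`, `SimpleTaskMetrics`, `ForwardingReporter`) rather than the actual Spark classes:

```scala
// Illustration only: stand-ins for SQLMetric and TempShuffleReadMetrics.
final class SimpleMetric(val name: String) {
  private var value = 0L
  def add(v: Long): Unit = value += v
  def current: Long = value
}

final class SimpleTaskMetrics {
  var recordsRead = 0L
  def incRecordsRead(v: Long): Unit = recordsRead += v
}

// Mirrors the reporter pattern in the diff: update the SQL-side metric,
// then forward the same increment to the task-level metrics.
final class ForwardingReporter(task: SimpleTaskMetrics, sqlMetrics: Map[String, SimpleMetric]) {
  private val recordsRead = sqlMetrics("recordsRead")
  def incRecordsRead(v: Long): Unit = {
    recordsRead.add(v)
    task.incRecordsRead(v)
  }
}

object ForwardingReporterDemo extends App {
  val task = new SimpleTaskMetrics
  val sql = Map("recordsRead" -> new SimpleMetric("records read"))
  val reporter = new ForwardingReporter(task, sql)
  reporter.incRecordsRead(42L)
  // Both sides observe the same value: sql=42 task=42
  println(s"sql=${sql("recordsRead").current} task=${task.recordsRead}")
}
```

Keeping the forwarding in one place means each increment is counted exactly once on each side, which is why the exchange and limit operators only have to hand their metrics map to `ShuffledRowRDD`.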
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index bea652cc33076..ac25a8fd90bc7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -45,7 +45,7 @@ object StatFunctions extends Logging { * * This method implements a variation of the Greenwald-Khanna algorithm (with some speed * optimizations). - * The algorithm was first present in + * The algorithm was first present in * Space-efficient Online Computation of Quantile Summaries by Greenwald and Khanna. * * @param df the dataframe diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManager.scala index 606ba250ad9d2..b3e4240c315bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManager.scala @@ -56,7 +56,7 @@ trait CheckpointFileManager { * @param overwriteIfPossible If true, then the implementations must do a best-effort attempt to * overwrite the file if it already exists. It should not throw * any exception if the file exists. However, if false, then the - * implementation must not overwrite if the file alraedy exists and + * implementation must not overwrite if the file already exists and * must throw `FileAlreadyExistsException` in that case. */ def createAtomic(path: Path, overwriteIfPossible: Boolean): CancellableFSDataOutputStream diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 2cac86599ef19..5defca391a355 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -24,13 +24,14 @@ import org.apache.spark.sql.{Dataset, SparkSession} import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.sources.{MicroBatchWritSupport, RateControlMicroBatchReadSupport} import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset => OffsetV2} import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} -import org.apache.spark.util.{Clock, Utils} +import org.apache.spark.util.Clock class MicroBatchExecution( sparkSession: SparkSession, @@ -475,8 +476,8 @@ class MicroBatchExecution( case StreamingExecutionRelation(source, output) => newData.get(source).map { dataPlan => assert(output.size == dataPlan.output.size, - s"Invalid batch: ${Utils.truncatedString(output, ",")} != " + - s"${Utils.truncatedString(dataPlan.output, ",")}") + s"Invalid batch: ${truncatedString(output, ",")} != " + + 
s"${truncatedString(dataPlan.output, ",")}") val aliases = output.zip(dataPlan.output).map { case (to, from) => Alias(from, to.name)(exprId = to.exprId, explicitMetadata = Some(from.metadata)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 631a6eb649ffb..89b4f40c9c0b9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -88,6 +88,7 @@ abstract class StreamExecution( val resolvedCheckpointRoot = { val checkpointPath = new Path(checkpointRoot) val fs = checkpointPath.getFileSystem(sparkSession.sessionState.newHadoopConf()) + fs.mkdirs(checkpointPath) checkpointPath.makeQualified(fs.getUri, fs.getWorkingDirectory).toUri.toString } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamMetadata.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamMetadata.scala index 0bc54eac4ee8e..516afbea5d9de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamMetadata.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamMetadata.scala @@ -19,16 +19,18 @@ package org.apache.spark.sql.execution.streaming import java.io.{InputStreamReader, OutputStreamWriter} import java.nio.charset.StandardCharsets +import java.util.ConcurrentModificationException import scala.util.control.NonFatal import org.apache.commons.io.IOUtils import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, FSDataInputStream, FSDataOutputStream, Path} +import org.apache.hadoop.fs.{FileAlreadyExistsException, FSDataInputStream, Path} import org.json4s.NoTypeHints import org.json4s.jackson.Serialization import org.apache.spark.internal.Logging +import org.apache.spark.sql.execution.streaming.CheckpointFileManager.CancellableFSDataOutputStream import org.apache.spark.sql.streaming.StreamingQuery /** @@ -70,19 +72,26 @@ object StreamMetadata extends Logging { metadata: StreamMetadata, metadataFile: Path, hadoopConf: Configuration): Unit = { - var output: FSDataOutputStream = null + var output: CancellableFSDataOutputStream = null try { - val fs = metadataFile.getFileSystem(hadoopConf) - output = fs.create(metadataFile) + val fileManager = CheckpointFileManager.create(metadataFile.getParent, hadoopConf) + output = fileManager.createAtomic(metadataFile, overwriteIfPossible = false) val writer = new OutputStreamWriter(output) Serialization.write(metadata, writer) writer.close() } catch { - case NonFatal(e) => + case e: FileAlreadyExistsException => + if (output != null) { + output.cancel() + } + throw new ConcurrentModificationException( + s"Multiple streaming queries are concurrently using $metadataFile", e) + case e: Throwable => + if (output != null) { + output.cancel() + } logError(s"Error writing stream metadata $metadata to $metadataFile", e) throw e - } finally { - IOUtils.closeQuietly(output) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala index 19e3e55cb2829..4c0db3cb42a82 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala @@ -17,7 +17,7 @@ package 
org.apache.spark.sql.execution.streaming -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.{Evolving, Experimental} import org.apache.spark.sql.streaming.Trigger /** @@ -25,5 +25,5 @@ import org.apache.spark.sql.streaming.Trigger * the query. */ @Experimental -@InterfaceStability.Evolving +@Evolving case object OneTimeTrigger extends Trigger diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index f009c52449adc..1eab55122e84b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -28,6 +28,7 @@ import org.apache.spark.SparkEnv import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanExec, StreamingDataSourceV2Relation} import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _} @@ -35,7 +36,7 @@ import org.apache.spark.sql.sources.v2 import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions, StreamingWriteSupportProvider} import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, PartitionOffset} import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} -import org.apache.spark.util.{Clock, Utils} +import org.apache.spark.util.Clock class ContinuousExecution( sparkSession: SparkSession, @@ -164,8 +165,8 @@ class ContinuousExecution( val newOutput = readSupport.fullSchema().toAttributes assert(output.size == newOutput.size, - s"Invalid reader: ${Utils.truncatedString(output, ",")} != " + - s"${Utils.truncatedString(newOutput, ",")}") + s"Invalid reader: ${truncatedString(output, ",")} != " + + s"${truncatedString(newOutput, ",")}") replacements ++= output.zip(newOutput) val loggedOffset = offsets.offsets(0) @@ -262,7 +263,12 @@ class ContinuousExecution( reportTimeTaken("runContinuous") { SQLExecution.withNewExecutionId( - sparkSessionForQuery, lastExecution)(lastExecution.toRdd) + sparkSessionForQuery, lastExecution) { + // Materialize `executedPlan` so that accessing it when `toRdd` is running doesn't need to + // wait for a lock + lastExecution.executedPlan + lastExecution.toRdd + } } } catch { case t: Throwable if StreamExecution.isInterruptionException(t, sparkSession.sparkContext) && diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTrigger.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTrigger.scala index 90e1766c4d9f1..caffcc3c4c1a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTrigger.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTrigger.scala @@ -23,15 +23,15 @@ import scala.concurrent.duration.Duration import org.apache.commons.lang3.StringUtils -import org.apache.spark.annotation.{Experimental, InterfaceStability} -import 
org.apache.spark.sql.streaming.{ProcessingTime, Trigger} +import org.apache.spark.annotation.Evolving +import org.apache.spark.sql.streaming.Trigger import org.apache.spark.unsafe.types.CalendarInterval /** * A [[Trigger]] that continuously processes streaming data, asynchronously checkpointing at * the specified interval. */ -@InterfaceStability.Evolving +@Evolving case class ContinuousTrigger(intervalMs: Long) extends Trigger { require(intervalMs >= 0, "the interval of trigger should not be negative") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index adf52aba21a04..daee089f3871d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -31,11 +31,11 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ +import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset => OffsetV2} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType -import org.apache.spark.util.Utils object MemoryStream { protected val currentBlockId = new AtomicInteger(0) @@ -117,7 +117,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) } } - override def toString: String = s"MemoryStream[${Utils.truncatedString(output, ",")}]" + override def toString: String = s"MemoryStream[${truncatedString(output, ",")}]" override def deserializeOffset(json: String): OffsetV2 = LongOffset(json.toLong) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index d3313b8a315c9..7d785aa09cd9e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -213,7 +213,7 @@ object StateStoreProvider { */ def create(providerClassName: String): StateStoreProvider = { val providerClass = Utils.classForName(providerClassName) - providerClass.newInstance().asInstanceOf[StateStoreProvider] + providerClass.getConstructor().newInstance().asInstanceOf[StateStoreProvider] } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala index 2959eacf95dae..a656a2f53e0a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala @@ -159,7 +159,6 @@ class SQLAppStatusListener( } private def aggregateMetrics(exec: LiveExecutionData): Map[Long, String] = { - val metricIds = exec.metrics.map(_.accumulatorId).sorted val metricTypes = exec.metrics.map { m => (m.accumulatorId, m.metricType) }.toMap val metrics = exec.stages.toSeq .flatMap { stageId => Option(stageMetrics.get(stageId)) } @@ -167,10 +166,10 @@ class SQLAppStatusListener( .flatMap { 
metrics => metrics.ids.zip(metrics.values) } val aggregatedMetrics = (metrics ++ exec.driverAccumUpdates.toSeq) - .filter { case (id, _) => metricIds.contains(id) } + .filter { case (id, _) => metricTypes.contains(id) } .groupBy(_._1) .map { case (id, values) => - id -> SQLMetrics.stringValue(metricTypes(id), values.map(_._2).toSeq) + id -> SQLMetrics.stringValue(metricTypes(id), values.map(_._2)) } // Check the execution again for whether the aggregated metrics data has been calculated. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala index 1e076207bc607..6b4def35e1955 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala @@ -17,8 +17,8 @@ package org.apache.spark.sql.expressions -import org.apache.spark.annotation.{Experimental, InterfaceStability} -import org.apache.spark.sql.{Dataset, Encoder, TypedColumn} +import org.apache.spark.annotation.{Evolving, Experimental} +import org.apache.spark.sql.{Encoder, TypedColumn} import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete} import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression @@ -51,7 +51,7 @@ import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression * @since 1.6.0 */ @Experimental -@InterfaceStability.Evolving +@Evolving abstract class Aggregator[-IN, BUF, OUT] extends Serializable { /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index eb956c4b3e888..58a942afe28c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.expressions -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.ScalaUDF @@ -37,7 +37,7 @@ import org.apache.spark.sql.types.DataType * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case class UserDefinedFunction protected[sql] ( f: AnyRef, dataType: DataType, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala index 14dec8f0810f2..64375085a64ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.expressions -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.expressions._ @@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.expressions._ * * @since 1.4.0 */ -@InterfaceStability.Stable +@Stable object Window { /** @@ -243,5 +243,5 @@ object Window { * * @since 1.4.0 */ -@InterfaceStability.Stable +@Stable class Window private() // So we can see Window in JavaDoc. 
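The annotation swaps above (and in the WindowSpec, typed, and udaf files that follow) are source-compatible for callers; only the annotation class changes, not the public API. As a quick sanity check, a usage sketch against the stable `Window` entry point, assuming a local SparkSession is available; this example is not part of the change itself:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.row_number

object WindowAnnotationSanityCheck {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("window-check").getOrCreate()
    import spark.implicits._

    val df = Seq(("a", 3), ("a", 1), ("b", 2)).toDF("k", "v")
    // Window/WindowSpec are the entry points whose annotation moves to @Stable.
    val spec = Window.partitionBy($"k").orderBy($"v".desc)
    df.withColumn("rn", row_number().over(spec)).show()

    spark.stop()
  }
}
```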
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala index 0cc43a58237df..0ace9d25f9dc7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.expressions -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.sql.{AnalysisException, Column} import org.apache.spark.sql.catalyst.expressions._ @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions._ * * @since 1.4.0 */ -@InterfaceStability.Stable +@Stable class WindowSpec private[sql]( partitionSpec: Seq[Expression], orderSpec: Seq[SortOrder], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/scalalang/typed.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/scalalang/typed.scala index 3e637d594caf3..1cb579c4faa76 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/scalalang/typed.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/scalalang/typed.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.expressions.scalalang -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.{Evolving, Experimental} import org.apache.spark.sql._ import org.apache.spark.sql.execution.aggregate._ @@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.aggregate._ * @since 2.0.0 */ @Experimental -@InterfaceStability.Evolving +@Evolving // scalastyle:off object typed { // scalastyle:on diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala index 4976b875fa298..4e8cb3a6ddd66 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.expressions -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.sql.{Column, Row} import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete} import org.apache.spark.sql.execution.aggregate.ScalaUDAF @@ -28,7 +28,7 @@ import org.apache.spark.sql.types._ * * @since 1.5.0 */ -@InterfaceStability.Stable +@Stable abstract class UserDefinedAggregateFunction extends Serializable { /** @@ -159,7 +159,7 @@ abstract class UserDefinedAggregateFunction extends Serializable { * * @since 1.5.0 */ -@InterfaceStability.Stable +@Stable abstract class MutableAggregationBuffer extends Row { /** Update the ith value of this buffer. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 8edc57834b2f5..274be8ae7c835 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -23,7 +23,7 @@ import scala.reflect.runtime.universe.{typeTag, TypeTag} import scala.util.Try import scala.util.control.NonFatal -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction} @@ -68,7 +68,7 @@ import org.apache.spark.util.Utils * @groupname Ungrouped Support functions for DataFrames * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable // scalastyle:off object functions { // scalastyle:on @@ -3836,7 +3836,7 @@ object functions { /** * Returns an unordered array of all entries in the given map. * @group collection_funcs - * @since 2.4.0 + * @since 3.0.0 */ def map_entries(e: Column): Column = withExpr { MapEntries(e.expr) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index f67cc32c15dd2..ac07e1f6bb4f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.internal import org.apache.spark.SparkConf -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.{Experimental, Unstable} import org.apache.spark.sql.{ExperimentalMethods, SparkSession, UDFRegistration, _} import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry} import org.apache.spark.sql.catalyst.catalog.SessionCatalog @@ -50,7 +50,7 @@ import org.apache.spark.sql.util.ExecutionListenerManager * and `catalog` fields. Note that the state is cloned when `build` is called, and not before. */ @Experimental -@InterfaceStability.Unstable +@Unstable abstract class BaseSessionStateBuilder( val session: SparkSession, val parentState: Option[SessionState] = None) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index accbea41b9603..b34db581ca2c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -23,7 +23,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.spark.SparkContext -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.{Experimental, Unstable} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry} import org.apache.spark.sql.catalyst.catalog._ @@ -124,7 +124,7 @@ private[sql] object SessionState { * Concrete implementation of a [[BaseSessionStateBuilder]]. */ @Experimental -@InterfaceStability.Unstable +@Unstable class SessionStateBuilder( session: SparkSession, parentState: Option[SessionState] = None) @@ -135,7 +135,7 @@ class SessionStateBuilder( /** * Session shared [[FunctionResourceLoader]]. 
*/ -@InterfaceStability.Unstable +@Unstable class SessionResourceLoader(session: SparkSession) extends FunctionResourceLoader { override def loadResource(resource: FunctionResource): Unit = { resource.resourceType match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index f76c1fae562c6..230b43022b02b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -21,8 +21,7 @@ import java.sql.{Connection, Date, Timestamp} import org.apache.commons.lang3.StringUtils -import org.apache.spark.annotation.{DeveloperApi, InterfaceStability, Since} -import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions +import org.apache.spark.annotation.{DeveloperApi, Evolving, Since} import org.apache.spark.sql.types._ /** @@ -34,7 +33,7 @@ import org.apache.spark.sql.types._ * send a null value to the database. */ @DeveloperApi -@InterfaceStability.Evolving +@Evolving case class JdbcType(databaseTypeDefinition : String, jdbcNullType : Int) /** @@ -57,7 +56,7 @@ case class JdbcType(databaseTypeDefinition : String, jdbcNullType : Int) * for the given Catalyst type. */ @DeveloperApi -@InterfaceStability.Evolving +@Evolving abstract class JdbcDialect extends Serializable { /** * Check if this dialect instance can handle a certain jdbc url. @@ -197,7 +196,7 @@ abstract class JdbcDialect extends Serializable { * sure to register your dialects first. */ @DeveloperApi -@InterfaceStability.Evolving +@Evolving object JdbcDialects { /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index 161e0102f0b43..61875931d226e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -17,7 +17,7 @@ package org.apache.spark -import org.apache.spark.annotation.{DeveloperApi, InterfaceStability} +import org.apache.spark.annotation.{DeveloperApi, Unstable} import org.apache.spark.sql.execution.SparkStrategy /** @@ -40,8 +40,17 @@ package object sql { * [[org.apache.spark.sql.sources]] */ @DeveloperApi - @InterfaceStability.Unstable + @Unstable type Strategy = SparkStrategy type DataFrame = Dataset[Row] + + /** + * Metadata key which is used to write Spark version in the followings: + * - Parquet file metadata + * - ORC file metadata + * + * Note that Hive table property `spark.sql.create.version` also has Spark version. + */ + private[sql] val SPARK_VERSION_METADATA_KEY = "org.apache.spark.version" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala index bdd8c4da6bd30..3f941cc6e1072 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Stable //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines all the filters that we can push down to the data sources. 
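The `JdbcDialect`/`JdbcDialects` hunk above only swaps the stability annotation; the extension point itself is unchanged. For orientation, a minimal sketch of how a custom dialect plugs in (the dialect object, URL prefix and quoting rule below are invented for illustration, not taken from this diff):

    import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}

    // Hypothetical dialect for a fictional "jdbc:mydb" URL scheme.
    object MyDbDialect extends JdbcDialect {
      // Claim URLs that this dialect should handle.
      override def canHandle(url: String): Boolean = url.startsWith("jdbc:mydb")
      // Quote identifiers the way the target database expects.
      override def quoteIdentifier(colName: String): String = s"`$colName`"
    }

    // Register before reading or writing through the JDBC data source.
    JdbcDialects.registerDialect(MyDbDialect)

As the scaladoc in the hunk itself notes, dialects should be registered before the JDBC data source is used so that type mappings and identifier quoting take effect.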
@@ -28,7 +28,7 @@ import org.apache.spark.annotation.InterfaceStability * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable abstract class Filter { /** * List of columns that are referenced by this filter. @@ -48,7 +48,7 @@ abstract class Filter { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case class EqualTo(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) } @@ -60,7 +60,7 @@ case class EqualTo(attribute: String, value: Any) extends Filter { * * @since 1.5.0 */ -@InterfaceStability.Stable +@Stable case class EqualNullSafe(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) } @@ -71,7 +71,7 @@ case class EqualNullSafe(attribute: String, value: Any) extends Filter { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case class GreaterThan(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) } @@ -82,7 +82,7 @@ case class GreaterThan(attribute: String, value: Any) extends Filter { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case class GreaterThanOrEqual(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) } @@ -93,7 +93,7 @@ case class GreaterThanOrEqual(attribute: String, value: Any) extends Filter { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case class LessThan(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) } @@ -104,7 +104,7 @@ case class LessThan(attribute: String, value: Any) extends Filter { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case class LessThanOrEqual(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) } @@ -114,7 +114,7 @@ case class LessThanOrEqual(attribute: String, value: Any) extends Filter { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case class In(attribute: String, values: Array[Any]) extends Filter { override def hashCode(): Int = { var h = attribute.hashCode @@ -141,7 +141,7 @@ case class In(attribute: String, values: Array[Any]) extends Filter { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case class IsNull(attribute: String) extends Filter { override def references: Array[String] = Array(attribute) } @@ -151,7 +151,7 @@ case class IsNull(attribute: String) extends Filter { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case class IsNotNull(attribute: String) extends Filter { override def references: Array[String] = Array(attribute) } @@ -161,7 +161,7 @@ case class IsNotNull(attribute: String) extends Filter { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case class And(left: Filter, right: Filter) extends Filter { override def references: Array[String] = left.references ++ right.references } @@ -171,7 +171,7 @@ case class And(left: Filter, right: Filter) extends Filter { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case class Or(left: Filter, right: Filter) extends Filter { override def references: Array[String] = left.references ++ right.references } @@ -181,7 +181,7 @@ case class Or(left: Filter, right: Filter) extends Filter { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable case class Not(child: Filter) extends Filter { override def references: Array[String] = 
child.references } @@ -192,7 +192,7 @@ case class Not(child: Filter) extends Filter { * * @since 1.3.1 */ -@InterfaceStability.Stable +@Stable case class StringStartsWith(attribute: String, value: String) extends Filter { override def references: Array[String] = Array(attribute) } @@ -203,7 +203,7 @@ case class StringStartsWith(attribute: String, value: String) extends Filter { * * @since 1.3.1 */ -@InterfaceStability.Stable +@Stable case class StringEndsWith(attribute: String, value: String) extends Filter { override def references: Array[String] = Array(attribute) } @@ -214,7 +214,7 @@ case class StringEndsWith(attribute: String, value: String) extends Filter { * * @since 1.3.1 */ -@InterfaceStability.Stable +@Stable case class StringContains(attribute: String, value: String) extends Filter { override def references: Array[String] = Array(attribute) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 6057a795c8bf5..6ad054c9f6403 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources -import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} +import org.apache.spark.annotation._ import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow @@ -35,7 +35,7 @@ import org.apache.spark.sql.types.StructType * * @since 1.5.0 */ -@InterfaceStability.Stable +@Stable trait DataSourceRegister { /** @@ -65,7 +65,7 @@ trait DataSourceRegister { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable trait RelationProvider { /** * Returns a new base relation with the given parameters. @@ -96,7 +96,7 @@ trait RelationProvider { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable trait SchemaRelationProvider { /** * Returns a new base relation with the given parameters and user defined schema. 
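The source-provider traits re-tagged in the hunk above (`DataSourceRegister`, `RelationProvider`, `SchemaRelationProvider`) are the classic V1 extension points and keep their shape here. A compact, illustrative-only sketch of a relation provider serving a single `id` column (the class, the `tinyrange` short name and the row count are all hypothetical):

    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.{Row, SQLContext}
    import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, RelationProvider, TableScan}
    import org.apache.spark.sql.types.{LongType, StructField, StructType}

    class TinyRangeSource extends RelationProvider with DataSourceRegister {
      // Short alias usable in .format(...) once the provider is listed in a
      // META-INF/services registration file; the fully qualified class name also works.
      override def shortName(): String = "tinyrange"

      override def createRelation(
          ctx: SQLContext,
          parameters: Map[String, String]): BaseRelation = new BaseRelation with TableScan {
        override def sqlContext: SQLContext = ctx
        override def schema: StructType =
          StructType(StructField("id", LongType, nullable = false) :: Nil)
        // Full-table scan: rows 0..9 as a single bigint column.
        override def buildScan(): RDD[Row] =
          ctx.sparkContext.parallelize(0L until 10L).map(Row(_))
      }
    }

On the read path such a provider would be consumed as `spark.read.format(classOf[TinyRangeSource].getName).load()`; that usage is shown only for context and is likewise not part of this diff.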
@@ -117,7 +117,7 @@ trait SchemaRelationProvider { * @since 2.0.0 */ @Experimental -@InterfaceStability.Unstable +@Unstable trait StreamSourceProvider { /** @@ -148,7 +148,7 @@ trait StreamSourceProvider { * @since 2.0.0 */ @Experimental -@InterfaceStability.Unstable +@Unstable trait StreamSinkProvider { def createSink( sqlContext: SQLContext, @@ -160,7 +160,7 @@ trait StreamSinkProvider { /** * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable trait CreatableRelationProvider { /** * Saves a DataFrame to a destination (using data source-specific parameters) @@ -192,7 +192,7 @@ trait CreatableRelationProvider { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable abstract class BaseRelation { def sqlContext: SQLContext def schema: StructType @@ -242,7 +242,7 @@ abstract class BaseRelation { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable trait TableScan { def buildScan(): RDD[Row] } @@ -253,7 +253,7 @@ trait TableScan { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable trait PrunedScan { def buildScan(requiredColumns: Array[String]): RDD[Row] } @@ -271,7 +271,7 @@ trait PrunedScan { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable trait PrunedFilteredScan { def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] } @@ -293,7 +293,7 @@ trait PrunedFilteredScan { * * @since 1.3.0 */ -@InterfaceStability.Stable +@Stable trait InsertableRelation { def insert(data: DataFrame, overwrite: Boolean): Unit } @@ -309,7 +309,7 @@ trait InsertableRelation { * @since 1.3.0 */ @Experimental -@InterfaceStability.Unstable +@Unstable trait CatalystScan { def buildScan(requiredColumns: Seq[Attribute], filters: Seq[Expression]): RDD[Row] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 4c7dcedafeeae..c8e3e1c191044 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -21,7 +21,7 @@ import java.util.Locale import scala.collection.JavaConverters._ -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Evolving import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, SparkSession} import org.apache.spark.sql.execution.command.DDLUtils @@ -40,7 +40,7 @@ import org.apache.spark.util.Utils * * @since 2.0.0 */ -@InterfaceStability.Evolving +@Evolving final class DataStreamReader private[sql](sparkSession: SparkSession) extends Logging { /** * Specifies the input data source format. @@ -158,7 +158,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo "read files of Hive data source directly.") } - val ds = DataSource.lookupDataSource(source, sparkSession.sqlContext.conf).newInstance() + val ds = DataSource.lookupDataSource(source, sparkSession.sqlContext.conf). + getConstructor().newInstance() // We need to generate the V1 data source so we can pass it to the V2 relation as a shim. // We can't be sure at this point whether we'll actually want to use V2, since we don't know the // writer or whether the query is continuous. @@ -296,6 +297,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * that should be used for parsing. *
  • `dropFieldIfAllNull` (default `false`): whether to ignore column of all null values or * empty array/struct during schema inference.
  • + *
  • `locale` (default is `en-US`): sets a locale as language tag in IETF BCP 47 format. + * For instance, this is used while parsing dates and timestamps.
  • * * * @since 2.0.0 @@ -372,6 +375,10 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field having malformed string * created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`. *
  • `multiLine` (default `false`): parse one record, which may span multiple lines.
  • + *
  • `locale` (default is `en-US`): sets a locale as language tag in IETF BCP 47 format. + * For instance, this is used while parsing dates and timestamps.
  • + *
  • `lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator + * that should be used for parsing. Maximum length is 1 character.
  • * * * @since 2.0.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index 4a8c7fdb58ff1..5733258a6b310 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -21,7 +21,7 @@ import java.util.Locale import scala.collection.JavaConverters._ -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Evolving import org.apache.spark.api.java.function.VoidFunction2 import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes @@ -39,7 +39,7 @@ import org.apache.spark.sql.sources.v2.StreamingWriteSupportProvider * * @since 2.0.0 */ -@InterfaceStability.Evolving +@Evolving final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { private val df = ds.toDF() @@ -307,7 +307,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) val disabledSources = df.sparkSession.sqlContext.conf.disabledV2StreamingWriters.split(",") var options = extraOptions.toMap - val sink = ds.newInstance() match { + val sink = ds.getConstructor().newInstance() match { case w: StreamingWriteSupportProvider if !disabledSources.contains(w.getClass.getCanonicalName) => val sessionOptions = DataSourceV2Utils.extractSessionConfigs( @@ -365,7 +365,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { * * @since 2.4.0 */ - @InterfaceStability.Evolving + @Evolving def foreachBatch(function: (Dataset[T], Long) => Unit): DataStreamWriter[T] = { this.source = "foreachBatch" if (function == null) throw new IllegalArgumentException("foreachBatch function cannot be null") @@ -386,7 +386,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { * * @since 2.4.0 */ - @InterfaceStability.Evolving + @Evolving def foreachBatch(function: VoidFunction2[Dataset[T], java.lang.Long]): DataStreamWriter[T] = { foreachBatch((batchDs: Dataset[T], batchId: Long) => function.call(batchDs, batchId)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala index e9510c903acae..ab68eba81b843 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.streaming -import org.apache.spark.annotation.{Experimental, InterfaceStability} -import org.apache.spark.sql.KeyValueGroupedDataset +import org.apache.spark.annotation.{Evolving, Experimental} import org.apache.spark.sql.catalyst.plans.logical.LogicalGroupState /** @@ -192,7 +191,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalGroupState * @since 2.2.0 */ @Experimental -@InterfaceStability.Evolving +@Evolving trait GroupState[S] extends LogicalGroupState[S] { /** Whether state exists or not. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/ProcessingTime.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ProcessingTime.scala index a033575d3d38f..236bd55ee6212 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/ProcessingTime.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ProcessingTime.scala @@ -23,7 +23,7 @@ import scala.concurrent.duration.Duration import org.apache.commons.lang3.StringUtils -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Evolving import org.apache.spark.unsafe.types.CalendarInterval /** @@ -48,7 +48,7 @@ import org.apache.spark.unsafe.types.CalendarInterval * * @since 2.0.0 */ -@InterfaceStability.Evolving +@Evolving @deprecated("use Trigger.ProcessingTime(intervalMs)", "2.2.0") case class ProcessingTime(intervalMs: Long) extends Trigger { require(intervalMs >= 0, "the interval of trigger should not be negative") @@ -59,7 +59,7 @@ case class ProcessingTime(intervalMs: Long) extends Trigger { * * @since 2.0.0 */ -@InterfaceStability.Evolving +@Evolving @deprecated("use Trigger.ProcessingTime(intervalMs)", "2.2.0") object ProcessingTime { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala index f2dfbe42260d7..47ddc88e964e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.streaming import java.util.UUID -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Evolving import org.apache.spark.sql.SparkSession /** @@ -27,7 +27,7 @@ import org.apache.spark.sql.SparkSession * All these methods are thread-safe. * @since 2.0.0 */ -@InterfaceStability.Evolving +@Evolving trait StreamingQuery { /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala index 03aeb14de502a..646d6888b2a16 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.streaming -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Evolving /** * Exception that stopped a [[StreamingQuery]]. 
Use `cause` get the actual exception @@ -28,7 +28,7 @@ import org.apache.spark.annotation.InterfaceStability * @param endOffset Ending offset in json of the range of data in exception occurred * @since 2.0.0 */ -@InterfaceStability.Evolving +@Evolving class StreamingQueryException private[sql]( private val queryDebugString: String, val message: String, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala index 6aa82b89ede81..916d6a0365965 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.streaming import java.util.UUID -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Evolving import org.apache.spark.scheduler.SparkListenerEvent /** @@ -28,7 +28,7 @@ import org.apache.spark.scheduler.SparkListenerEvent * * @since 2.0.0 */ -@InterfaceStability.Evolving +@Evolving abstract class StreamingQueryListener { import StreamingQueryListener._ @@ -67,14 +67,14 @@ abstract class StreamingQueryListener { * Companion object of [[StreamingQueryListener]] that defines the listener events. * @since 2.0.0 */ -@InterfaceStability.Evolving +@Evolving object StreamingQueryListener { /** * Base type of [[StreamingQueryListener]] events * @since 2.0.0 */ - @InterfaceStability.Evolving + @Evolving trait Event extends SparkListenerEvent /** @@ -84,7 +84,7 @@ object StreamingQueryListener { * @param name User-specified name of the query, null if not specified. * @since 2.1.0 */ - @InterfaceStability.Evolving + @Evolving class QueryStartedEvent private[sql]( val id: UUID, val runId: UUID, @@ -95,7 +95,7 @@ object StreamingQueryListener { * @param progress The query progress updates. * @since 2.1.0 */ - @InterfaceStability.Evolving + @Evolving class QueryProgressEvent private[sql](val progress: StreamingQueryProgress) extends Event /** @@ -107,7 +107,7 @@ object StreamingQueryListener { * with an exception. Otherwise, it will be `None`. 
* @since 2.1.0 */ - @InterfaceStability.Evolving + @Evolving class QueryTerminatedEvent private[sql]( val id: UUID, val runId: UUID, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index cd52d991d55c9..d9fe1a992a093 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -25,7 +25,7 @@ import scala.collection.mutable import org.apache.hadoop.fs.Path import org.apache.spark.SparkException -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Evolving import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, DataFrame, SparkSession} import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker @@ -42,7 +42,7 @@ import org.apache.spark.util.{Clock, SystemClock, Utils} * * @since 2.0.0 */ -@InterfaceStability.Evolving +@Evolving class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Logging { private[sql] val stateStoreCoordinator = @@ -311,7 +311,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo outputMode: OutputMode, useTempCheckpointLocation: Boolean = false, recoverFromCheckpointLocation: Boolean = true, - trigger: Trigger = ProcessingTime(0), + trigger: Trigger = Trigger.ProcessingTime(0), triggerClock: Clock = new SystemClock()): StreamingQuery = { val query = createQuery( userSpecifiedName, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala index a0c9bcc8929eb..9dc62b7aac891 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala @@ -22,7 +22,7 @@ import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Evolving /** * Reports information about the instantaneous status of a streaming query. @@ -34,7 +34,7 @@ import org.apache.spark.annotation.InterfaceStability * * @since 2.1.0 */ -@InterfaceStability.Evolving +@Evolving class StreamingQueryStatus protected[sql]( val message: String, val isDataAvailable: Boolean, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala index f2173aa1e59c2..3cd6700efef5f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala @@ -29,12 +29,12 @@ import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.Evolving /** * Information about updates made to stateful operators in a [[StreamingQuery]] during a trigger. */ -@InterfaceStability.Evolving +@Evolving class StateOperatorProgress private[sql]( val numRowsTotal: Long, val numRowsUpdated: Long, @@ -94,7 +94,7 @@ class StateOperatorProgress private[sql]( * @param sources detailed statistics on data being read from each of the streaming sources. 
* @since 2.1.0 */ -@InterfaceStability.Evolving +@Evolving class StreamingQueryProgress private[sql]( val id: UUID, val runId: UUID, @@ -165,7 +165,7 @@ class StreamingQueryProgress private[sql]( * Spark. * @since 2.1.0 */ -@InterfaceStability.Evolving +@Evolving class SourceProgress protected[sql]( val description: String, val startOffset: String, @@ -209,7 +209,7 @@ class SourceProgress protected[sql]( * @param description Description of the source corresponding to this status. * @since 2.1.0 */ -@InterfaceStability.Evolving +@Evolving class SinkProgress protected[sql]( val description: String) extends Serializable { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala index 8bab7e1c58762..7beac16599de5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala @@ -45,7 +45,7 @@ private[sql] class ExamplePointUDT extends UserDefinedType[ExamplePoint] { override def sqlType: DataType = ArrayType(DoubleType, false) - override def pyUDT: String = "pyspark.sql.tests.ExamplePointUDT" + override def pyUDT: String = "pyspark.testing.sqlutils.ExamplePointUDT" override def serialize(p: ExamplePoint): GenericArrayData = { val output = new Array[Any](2) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala index 1310fdfa1356b..77ae047705de0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.util import scala.collection.JavaConverters._ -import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} +import org.apache.spark.annotation.{DeveloperApi, Evolving, Experimental} import org.apache.spark.internal.Logging import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent} import org.apache.spark.sql.SparkSession @@ -36,7 +36,7 @@ import org.apache.spark.util.{ListenerBus, Utils} * multiple different threads. */ @Experimental -@InterfaceStability.Evolving +@Evolving trait QueryExecutionListener { /** @@ -73,7 +73,7 @@ trait QueryExecutionListener { * Manager for [[QueryExecutionListener]]. See `org.apache.spark.sql.SQLContext.listenerManager`. */ @Experimental -@InterfaceStability.Evolving +@Evolving // The `session` is used to indicate which session carries this listener manager, and we only // catch SQL executions which are launched by the same session. 
// The `loadExtensions` flag is used to indicate whether we should load the pre-defined, diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java index 7f975a647c241..8f35abeb579b5 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java @@ -143,11 +143,16 @@ public void setIntervals(List intervals) { this.intervals = intervals; } + @Override + public int hashCode() { + return id ^ Objects.hashCode(intervals); + } + @Override public boolean equals(Object obj) { if (!(obj instanceof ArrayRecord)) return false; ArrayRecord other = (ArrayRecord) obj; - return (other.id == this.id) && other.intervals.equals(this.intervals); + return (other.id == this.id) && Objects.equals(other.intervals, this.intervals); } @Override @@ -184,6 +189,11 @@ public void setIntervals(Map intervals) { this.intervals = intervals; } + @Override + public int hashCode() { + return id ^ Objects.hashCode(intervals); + } + @Override public boolean equals(Object obj) { if (!(obj instanceof MapRecord)) return false; @@ -225,6 +235,11 @@ public void setEndTime(long endTime) { this.endTime = endTime; } + @Override + public int hashCode() { + return Long.hashCode(startTime) ^ Long.hashCode(endTime); + } + @Override public boolean equals(Object obj) { if (!(obj instanceof Interval)) return false; diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java index 3ab4db2a035d3..ca78d6489ef5c 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java @@ -67,20 +67,20 @@ public void setUp() { public void constructSimpleRow() { Row simpleRow = RowFactory.create( byteValue, // ByteType - new Byte(byteValue), + Byte.valueOf(byteValue), shortValue, // ShortType - new Short(shortValue), + Short.valueOf(shortValue), intValue, // IntegerType - new Integer(intValue), + Integer.valueOf(intValue), longValue, // LongType - new Long(longValue), + Long.valueOf(longValue), floatValue, // FloatType - new Float(floatValue), + Float.valueOf(floatValue), doubleValue, // DoubleType - new Double(doubleValue), + Double.valueOf(doubleValue), decimalValue, // DecimalType booleanValue, // BooleanType - new Boolean(booleanValue), + Boolean.valueOf(booleanValue), stringValue, // StringType binaryValue, // BinaryType dateValue, // DateType diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaStringLength.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaStringLength.java index b90224f2ae397..5955eabe496df 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaStringLength.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaStringLength.java @@ -25,6 +25,6 @@ public class JavaStringLength implements UDF1 { @Override public Integer call(String str) throws Exception { - return new Integer(str.length()); + return Integer.valueOf(str.length()); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaRangeInputPartition.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaRangeInputPartition.java new file mode 100644 index 0000000000000..438f489a3eea7 --- /dev/null +++ 
b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaRangeInputPartition.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql.sources.v2; + +import org.apache.spark.sql.sources.v2.reader.InputPartition; + +class JavaRangeInputPartition implements InputPartition { + int start; + int end; + + JavaRangeInputPartition(int start, int end) { + this.start = start; + this.end = end; + } +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleReadSupport.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleReadSupport.java index 685f9b9747e85..ced51dde6997b 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleReadSupport.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleReadSupport.java @@ -88,12 +88,3 @@ public void close() throws IOException { } } -class JavaRangeInputPartition implements InputPartition { - int start; - int end; - - JavaRangeInputPartition(int start, int end) { - this.start = start; - this.end = end; - } -} diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql index 69da67fc66fc0..60895020fcc83 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql @@ -13,7 +13,6 @@ CREATE TEMPORARY VIEW various_maps AS SELECT * FROM VALUES ( map('a', 'b'), map('c', 'd'), map(array('a', 'b'), array('c', 'd')), map(array('e'), array('f')), map(struct('a', 1), struct('b', 2)), map(struct('c', 3), struct('d', 4)), - map(map('a', 1), map('b', 2)), map(map('c', 3), map('d', 4)), map('a', 1), map('c', 2), map(1, 'a'), map(2, 'c') ) AS various_maps ( @@ -31,7 +30,6 @@ CREATE TEMPORARY VIEW various_maps AS SELECT * FROM VALUES ( string_map1, string_map2, array_map1, array_map2, struct_map1, struct_map2, - map_map1, map_map2, string_int_map1, string_int_map2, int_string_map1, int_string_map2 ); @@ -51,7 +49,6 @@ SELECT map_concat(string_map1, string_map2) string_map, map_concat(array_map1, array_map2) array_map, map_concat(struct_map1, struct_map2) struct_map, - map_concat(map_map1, map_map2) map_map, map_concat(string_int_map1, string_int_map2) string_int_map, map_concat(int_string_map1, int_string_map2) int_string_map FROM various_maps; @@ -71,7 +68,7 @@ FROM various_maps; -- Concatenate map of incompatible types 1 SELECT - map_concat(tinyint_map1, map_map2) tm_map + map_concat(tinyint_map1, array_map1) tm_map FROM various_maps; -- Concatenate map of incompatible types 2 @@ -86,10 +83,10 @@ FROM various_maps; -- Concatenate map of 
incompatible types 4 SELECT - map_concat(map_map1, array_map2) ma_map + map_concat(struct_map1, array_map2) ma_map FROM various_maps; -- Concatenate map of incompatible types 5 SELECT - map_concat(map_map1, struct_map2) ms_map + map_concat(int_map1, array_map2) ms_map FROM various_maps; diff --git a/sql/core/src/test/resources/sql-tests/inputs/window.sql b/sql/core/src/test/resources/sql-tests/inputs/window.sql index cda4db4b449fe..faab4c61c8640 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/window.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/window.sql @@ -109,3 +109,9 @@ last_value(false, false) OVER w AS last_value_contain_null FROM testData WINDOW w AS () ORDER BY cate, val; + +-- parentheses around window reference +SELECT cate, sum(val) OVER (w) +FROM testData +WHERE val is not null +WINDOW w AS (PARTITION BY cate ORDER BY val); diff --git a/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out b/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out index 270eee8680225..17dd317f63b70 100644 --- a/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out @@ -93,7 +93,7 @@ Partition Values [ds=2017-08-01, hr=10] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10 Created Time [not included in comparison] Last Access [not included in comparison] -Partition Statistics 1264 bytes, 3 rows +Partition Statistics [not included in comparison] bytes, 3 rows # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t @@ -128,7 +128,7 @@ Partition Values [ds=2017-08-01, hr=10] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10 Created Time [not included in comparison] Last Access [not included in comparison] -Partition Statistics 1264 bytes, 3 rows +Partition Statistics [not included in comparison] bytes, 3 rows # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t @@ -155,7 +155,7 @@ Partition Values [ds=2017-08-01, hr=11] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=11 Created Time [not included in comparison] Last Access [not included in comparison] -Partition Statistics 1278 bytes, 4 rows +Partition Statistics [not included in comparison] bytes, 4 rows # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t @@ -190,7 +190,7 @@ Partition Values [ds=2017-08-01, hr=10] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10 Created Time [not included in comparison] Last Access [not included in comparison] -Partition Statistics 1264 bytes, 3 rows +Partition Statistics [not included in comparison] bytes, 3 rows # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t @@ -217,7 +217,7 @@ Partition Values [ds=2017-08-01, hr=11] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=11 Created Time [not included in comparison] Last Access [not included in comparison] -Partition Statistics 1278 bytes, 4 rows +Partition Statistics [not included in comparison] bytes, 4 rows # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t @@ -244,7 +244,7 @@ Partition Values [ds=2017-09-01, hr=5] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-09-01/hr=5 Created Time [not included in comparison] Last Access 
[not included in comparison] -Partition Statistics 1250 bytes, 2 rows +Partition Statistics [not included in comparison] bytes, 2 rows # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t diff --git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out index fd1d0db9e3f78..570b281353f3d 100644 --- a/sql/core/src/test/resources/sql-tests/results/operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out @@ -201,7 +201,7 @@ struct -- !query 24 output == Physical Plan == *Project [null AS (CAST(concat(a, CAST(1 AS STRING)) AS DOUBLE) + CAST(2 AS DOUBLE))#x] -+- Scan OneRowRelation[] ++- *Scan OneRowRelation[] -- !query 25 @@ -211,7 +211,7 @@ struct -- !query 25 output == Physical Plan == *Project [-1b AS concat(CAST((1 - 2) AS STRING), b)#x] -+- Scan OneRowRelation[] ++- *Scan OneRowRelation[] -- !query 26 @@ -221,7 +221,7 @@ struct -- !query 26 output == Physical Plan == *Project [11b AS concat(CAST(((2 * 4) + 3) AS STRING), b)#x] -+- Scan OneRowRelation[] ++- *Scan OneRowRelation[] -- !query 27 @@ -231,7 +231,7 @@ struct -- !query 27 output == Physical Plan == *Project [4a2.0 AS concat(concat(CAST((3 + 1) AS STRING), a), CAST((CAST(4 AS DOUBLE) / CAST(2 AS DOUBLE)) AS STRING))#x] -+- Scan OneRowRelation[] ++- *Scan OneRowRelation[] -- !query 28 @@ -241,7 +241,7 @@ struct -- !query 28 output == Physical Plan == *Project [true AS ((1 = 1) OR (concat(a, b) = ab))#x] -+- Scan OneRowRelation[] ++- *Scan OneRowRelation[] -- !query 29 @@ -251,7 +251,7 @@ struct -- !query 29 output == Physical Plan == *Project [false AS ((concat(a, c) = ac) AND (2 = 3))#x] -+- Scan OneRowRelation[] ++- *Scan OneRowRelation[] -- !query 30 diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out index efc88e47209a6..79e00860e4c05 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out @@ -18,7 +18,6 @@ CREATE TEMPORARY VIEW various_maps AS SELECT * FROM VALUES ( map('a', 'b'), map('c', 'd'), map(array('a', 'b'), array('c', 'd')), map(array('e'), array('f')), map(struct('a', 1), struct('b', 2)), map(struct('c', 3), struct('d', 4)), - map(map('a', 1), map('b', 2)), map(map('c', 3), map('d', 4)), map('a', 1), map('c', 2), map(1, 'a'), map(2, 'c') ) AS various_maps ( @@ -36,7 +35,6 @@ CREATE TEMPORARY VIEW various_maps AS SELECT * FROM VALUES ( string_map1, string_map2, array_map1, array_map2, struct_map1, struct_map2, - map_map1, map_map2, string_int_map1, string_int_map2, int_string_map1, int_string_map2 ) @@ -61,14 +59,13 @@ SELECT map_concat(string_map1, string_map2) string_map, map_concat(array_map1, array_map2) array_map, map_concat(struct_map1, struct_map2) struct_map, - map_concat(map_map1, map_map2) map_map, map_concat(string_int_map1, string_int_map2) string_int_map, map_concat(int_string_map1, int_string_map2) int_string_map FROM various_maps -- !query 1 schema -struct,tinyint_map:map,smallint_map:map,int_map:map,bigint_map:map,decimal_map:map,float_map:map,double_map:map,date_map:map,timestamp_map:map,string_map:map,array_map:map,array>,struct_map:map,struct>,map_map:map,map>,string_int_map:map,int_string_map:map> 
+struct,tinyint_map:map,smallint_map:map,int_map:map,bigint_map:map,decimal_map:map,float_map:map,double_map:map,date_map:map,timestamp_map:map,string_map:map,array_map:map,array>,struct_map:map,struct>,string_int_map:map,int_string_map:map> -- !query 1 output -{false:true,true:false} {1:2,3:4} {1:2,3:4} {4:6,7:8} {6:7,8:9} {9223372036854775808:9223372036854775809,9223372036854775809:9223372036854775808} {1.0:2.0,3.0:4.0} {1.0:2.0,3.0:4.0} {2016-03-12:2016-03-11,2016-03-14:2016-03-13} {2016-11-11 20:54:00.0:2016-11-09 20:54:00.0,2016-11-15 20:54:00.0:2016-11-12 20:54:00.0} {"a":"b","c":"d"} {["a","b"]:["c","d"],["e"]:["f"]} {{"col1":"a","col2":1}:{"col1":"b","col2":2},{"col1":"c","col2":3}:{"col1":"d","col2":4}} {{"a":1}:{"b":2},{"c":3}:{"d":4}} {"a":1,"c":2} {1:"a",2:"c"} +{false:true,true:false} {1:2,3:4} {1:2,3:4} {4:6,7:8} {6:7,8:9} {9223372036854775808:9223372036854775809,9223372036854775809:9223372036854775808} {1.0:2.0,3.0:4.0} {1.0:2.0,3.0:4.0} {2016-03-12:2016-03-11,2016-03-14:2016-03-13} {2016-11-11 20:54:00.0:2016-11-09 20:54:00.0,2016-11-15 20:54:00.0:2016-11-12 20:54:00.0} {"a":"b","c":"d"} {["a","b"]:["c","d"],["e"]:["f"]} {{"col1":"a","col2":1}:{"col1":"b","col2":2},{"col1":"c","col2":3}:{"col1":"d","col2":4}} {"a":1,"c":2} {1:"a",2:"c"} -- !query 2 @@ -91,13 +88,13 @@ struct,si_map:map,ib_map:map -- !query 3 output org.apache.spark.sql.AnalysisException -cannot resolve 'map_concat(various_maps.`tinyint_map1`, various_maps.`map_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map, map,map>]; line 2 pos 4 +cannot resolve 'map_concat(various_maps.`tinyint_map1`, various_maps.`array_map1`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map, map,array>]; line 2 pos 4 -- !query 4 @@ -124,21 +121,21 @@ cannot resolve 'map_concat(various_maps.`int_map1`, various_maps.`struct_map2`)' -- !query 6 SELECT - map_concat(map_map1, array_map2) ma_map + map_concat(struct_map1, array_map2) ma_map FROM various_maps -- !query 6 schema struct<> -- !query 6 output org.apache.spark.sql.AnalysisException -cannot resolve 'map_concat(various_maps.`map_map1`, various_maps.`array_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map,map>, map,array>]; line 2 pos 4 +cannot resolve 'map_concat(various_maps.`struct_map1`, various_maps.`array_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map,struct>, map,array>]; line 2 pos 4 -- !query 7 SELECT - map_concat(map_map1, struct_map2) ms_map + map_concat(int_map1, array_map2) ms_map FROM various_maps -- !query 7 schema struct<> -- !query 7 output org.apache.spark.sql.AnalysisException -cannot resolve 'map_concat(various_maps.`map_map1`, various_maps.`struct_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map,map>, map,struct>]; line 2 pos 4 +cannot resolve 'map_concat(various_maps.`int_map1`, various_maps.`array_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map, map,array>]; line 2 pos 4 diff --git a/sql/core/src/test/resources/sql-tests/results/window.sql.out b/sql/core/src/test/resources/sql-tests/results/window.sql.out index 4afbcd62853dc..8190e21129b5c 100644 --- a/sql/core/src/test/resources/sql-tests/results/window.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/window.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by 
SQLQueryTestSuite --- Number of queries: 22 +-- Number of queries: 23 -- !query 0 @@ -363,3 +363,20 @@ NULL a false true false false true false 1 b false true false false true false 2 b false true false false true false 3 b false true false false true false + + +-- !query 22 +SELECT cate, sum(val) OVER (w) +FROM testData +WHERE val is not null +WINDOW w AS (PARTITION BY cate ORDER BY val) +-- !query 22 schema +struct +-- !query 22 output +NULL 3 +a 2 +a 2 +a 4 +b 1 +b 3 +b 6 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala index d635912cf7205..52708f5fe4108 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala @@ -208,7 +208,7 @@ class ApproximatePercentileQuerySuite extends QueryTest with SharedSQLContext { test("percentile_approx(col, ...), input rows contains null, with out group by") { withTempView(table) { - (1 to 1000).map(new Integer(_)).flatMap(Seq(null: Integer, _)).toDF("col") + (1 to 1000).map(Integer.valueOf(_)).flatMap(Seq(null: Integer, _)).toDF("col") .createOrReplaceTempView(table) checkAnswer( spark.sql( @@ -226,8 +226,8 @@ class ApproximatePercentileQuerySuite extends QueryTest with SharedSQLContext { withTempView(table) { val rand = new java.util.Random() (1 to 1000) - .map(new Integer(_)) - .map(v => (new Integer(v % 2), v)) + .map(Integer.valueOf(_)) + .map(v => (Integer.valueOf(v % 2), v)) // Add some nulls .flatMap(Seq(_, (null: Integer, null: Integer))) .toDF("key", "value").createOrReplaceTempView(table) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala index 1dd8ec31ee111..1c359ce1d2014 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql +import java.text.SimpleDateFormat +import java.util.Locale + import scala.collection.JavaConverters._ import org.apache.spark.SparkException @@ -117,4 +120,65 @@ class CsvFunctionsSuite extends QueryTest with SharedSQLContext { "Acceptable modes are PERMISSIVE and FAILFAST.")) } } + + test("from_csv uses DDL strings for defining a schema - java") { + val df = Seq("""1,"haa"""").toDS() + checkAnswer( + df.select( + from_csv($"value", lit("a INT, b STRING"), new java.util.HashMap[String, String]())), + Row(Row(1, "haa")) :: Nil) + } + + test("roundtrip to_csv -> from_csv") { + val df = Seq(Tuple1(Tuple1(1)), Tuple1(null)).toDF("struct") + val schema = df.schema(0).dataType.asInstanceOf[StructType] + val options = Map.empty[String, String] + val readback = df.select(to_csv($"struct").as("csv")) + .select(from_csv($"csv", schema, options).as("struct")) + + checkAnswer(df, readback) + } + + test("roundtrip from_csv -> to_csv") { + val df = Seq(Some("1"), None).toDF("csv") + val schema = new StructType().add("a", IntegerType) + val options = Map.empty[String, String] + val readback = df.select(from_csv($"csv", schema, options).as("struct")) + .select(to_csv($"struct").as("csv")) + + checkAnswer(df, readback) + } + + test("infers schemas of a CSV string and pass to to from_csv") { + val in = Seq("""0.123456789,987654321,"San Francisco"""").toDS() + val options = Map.empty[String, String].asJava + val out = 
in.select(from_csv('value, schema_of_csv("0.1,1,a"), options) as "parsed") + val expected = StructType(Seq(StructField( + "parsed", + StructType(Seq( + StructField("_c0", DoubleType, true), + StructField("_c1", IntegerType, true), + StructField("_c2", StringType, true)))))) + + assert(out.schema == expected) + } + + test("Support to_csv in SQL") { + val df1 = Seq(Tuple1(Tuple1(1))).toDF("a") + checkAnswer(df1.selectExpr("to_csv(a)"), Row("1") :: Nil) + } + + test("parse timestamps with locale") { + Seq("en-US", "ko-KR", "zh-CN", "ru-RU").foreach { langTag => + val locale = Locale.forLanguageTag(langTag) + val ts = new SimpleDateFormat("dd/MM/yyyy HH:mm").parse("06/11/2018 18:00") + val timestampFormat = "dd MMM yyyy HH:mm" + val sdf = new SimpleDateFormat(timestampFormat, locale) + val input = Seq(s"""${sdf.format(ts)}""").toDS() + val options = Map("timestampFormat" -> timestampFormat, "locale" -> langTag) + val df = input.select(from_csv($"value", lit("time timestamp"), options.asJava)) + + checkAnswer(df, Row(Row(java.sql.Timestamp.valueOf("2018-11-06 18:00:00.0")))) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index d9ba6e2ce5120..ff64edcd07f4b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -723,4 +723,18 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { "grouping expressions: [current_date(None)], value: [key: int, value: string], " + "type: GroupBy]")) } + + test("SPARK-26021: Double and Float 0.0/-0.0 should be equal when grouping") { + val colName = "i" + val doubles = Seq(0.0d, -0.0d, 0.0d).toDF(colName).groupBy(colName).count().collect() + val floats = Seq(0.0f, -0.0f, 0.0f).toDF(colName).groupBy(colName).count().collect() + + assert(doubles.length == 1) + assert(floats.length == 1) + // using compare since 0.0 == -0.0 is true + assert(java.lang.Double.compare(doubles(0).getDouble(0), 0.0d) == 0) + assert(java.lang.Float.compare(floats(0).getFloat(0), 0.0f) == 0) + assert(doubles(0).getLong(1) == 3) + assert(floats(0).getLong(1) == 3) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala new file mode 100644 index 0000000000000..30452af1fad64 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala @@ -0,0 +1,509 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import java.sql.{Date, Timestamp} + +import org.apache.spark.sql.catalyst.plans.logical.Union +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSQLContext} +import org.apache.spark.sql.test.SQLTestData.NullStrings +import org.apache.spark.sql.types._ + +class DataFrameSetOperationsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("except") { + checkAnswer( + lowerCaseData.except(upperCaseData), + Row(1, "a") :: + Row(2, "b") :: + Row(3, "c") :: + Row(4, "d") :: Nil) + checkAnswer(lowerCaseData.except(lowerCaseData), Nil) + checkAnswer(upperCaseData.except(upperCaseData), Nil) + + // check null equality + checkAnswer( + nullInts.except(nullInts.filter("0 = 1")), + nullInts) + checkAnswer( + nullInts.except(nullInts), + Nil) + + // check if values are de-duplicated + checkAnswer( + allNulls.except(allNulls.filter("0 = 1")), + Row(null) :: Nil) + checkAnswer( + allNulls.except(allNulls), + Nil) + + // check if values are de-duplicated + val df = Seq(("id1", 1), ("id1", 1), ("id", 1), ("id1", 2)).toDF("id", "value") + checkAnswer( + df.except(df.filter("0 = 1")), + Row("id1", 1) :: + Row("id", 1) :: + Row("id1", 2) :: Nil) + + // check if the empty set on the left side works + checkAnswer( + allNulls.filter("0 = 1").except(allNulls), + Nil) + } + + test("SPARK-23274: except between two projects without references used in filter") { + val df = Seq((1, 2, 4), (1, 3, 5), (2, 2, 3), (2, 4, 5)).toDF("a", "b", "c") + val df1 = df.filter($"a" === 1) + val df2 = df.filter($"a" === 2) + checkAnswer(df1.select("b").except(df2.select("b")), Row(3) :: Nil) + checkAnswer(df1.select("b").except(df2.select("c")), Row(2) :: Nil) + } + + test("except distinct - SQL compliance") { + val df_left = Seq(1, 2, 2, 3, 3, 4).toDF("id") + val df_right = Seq(1, 3).toDF("id") + + checkAnswer( + df_left.except(df_right), + Row(2) :: Row(4) :: Nil + ) + } + + test("except - nullability") { + val nonNullableInts = Seq(Tuple1(11), Tuple1(3)).toDF() + assert(nonNullableInts.schema.forall(!_.nullable)) + + val df1 = nonNullableInts.except(nullInts) + checkAnswer(df1, Row(11) :: Nil) + assert(df1.schema.forall(!_.nullable)) + + val df2 = nullInts.except(nonNullableInts) + checkAnswer(df2, Row(1) :: Row(2) :: Row(null) :: Nil) + assert(df2.schema.forall(_.nullable)) + + val df3 = nullInts.except(nullInts) + checkAnswer(df3, Nil) + assert(df3.schema.forall(_.nullable)) + + val df4 = nonNullableInts.except(nonNullableInts) + checkAnswer(df4, Nil) + assert(df4.schema.forall(!_.nullable)) + } + + test("except all") { + checkAnswer( + lowerCaseData.exceptAll(upperCaseData), + Row(1, "a") :: + Row(2, "b") :: + Row(3, "c") :: + Row(4, "d") :: Nil) + checkAnswer(lowerCaseData.exceptAll(lowerCaseData), Nil) + checkAnswer(upperCaseData.exceptAll(upperCaseData), Nil) + + // check null equality + checkAnswer( + nullInts.exceptAll(nullInts.filter("0 = 1")), + nullInts) + checkAnswer( + nullInts.exceptAll(nullInts), + Nil) + + // check that duplicate values are preserved + checkAnswer( + allNulls.exceptAll(allNulls.filter("0 = 1")), + Row(null) :: Row(null) :: Row(null) :: Row(null) :: Nil) + checkAnswer( + allNulls.exceptAll(allNulls.limit(2)), + Row(null) :: Row(null) :: Nil) + + // check that duplicates are retained. 
+ val df = spark.sparkContext.parallelize( + NullStrings(1, "id1") :: + NullStrings(1, "id1") :: + NullStrings(2, "id1") :: + NullStrings(3, null) :: Nil).toDF("id", "value") + + checkAnswer( + df.exceptAll(df.filter("0 = 1")), + Row(1, "id1") :: + Row(1, "id1") :: + Row(2, "id1") :: + Row(3, null) :: Nil) + + // check if the empty set on the left side works + checkAnswer( + allNulls.filter("0 = 1").exceptAll(allNulls), + Nil) + + } + + test("exceptAll - nullability") { + val nonNullableInts = Seq(Tuple1(11), Tuple1(3)).toDF() + assert(nonNullableInts.schema.forall(!_.nullable)) + + val df1 = nonNullableInts.exceptAll(nullInts) + checkAnswer(df1, Row(11) :: Nil) + assert(df1.schema.forall(!_.nullable)) + + val df2 = nullInts.exceptAll(nonNullableInts) + checkAnswer(df2, Row(1) :: Row(2) :: Row(null) :: Nil) + assert(df2.schema.forall(_.nullable)) + + val df3 = nullInts.exceptAll(nullInts) + checkAnswer(df3, Nil) + assert(df3.schema.forall(_.nullable)) + + val df4 = nonNullableInts.exceptAll(nonNullableInts) + checkAnswer(df4, Nil) + assert(df4.schema.forall(!_.nullable)) + } + + test("intersect") { + checkAnswer( + lowerCaseData.intersect(lowerCaseData), + Row(1, "a") :: + Row(2, "b") :: + Row(3, "c") :: + Row(4, "d") :: Nil) + checkAnswer(lowerCaseData.intersect(upperCaseData), Nil) + + // check null equality + checkAnswer( + nullInts.intersect(nullInts), + Row(1) :: + Row(2) :: + Row(3) :: + Row(null) :: Nil) + + // check if values are de-duplicated + checkAnswer( + allNulls.intersect(allNulls), + Row(null) :: Nil) + + // check if values are de-duplicated + val df = Seq(("id1", 1), ("id1", 1), ("id", 1), ("id1", 2)).toDF("id", "value") + checkAnswer( + df.intersect(df), + Row("id1", 1) :: + Row("id", 1) :: + Row("id1", 2) :: Nil) + } + + test("intersect - nullability") { + val nonNullableInts = Seq(Tuple1(1), Tuple1(3)).toDF() + assert(nonNullableInts.schema.forall(!_.nullable)) + + val df1 = nonNullableInts.intersect(nullInts) + checkAnswer(df1, Row(1) :: Row(3) :: Nil) + assert(df1.schema.forall(!_.nullable)) + + val df2 = nullInts.intersect(nonNullableInts) + checkAnswer(df2, Row(1) :: Row(3) :: Nil) + assert(df2.schema.forall(!_.nullable)) + + val df3 = nullInts.intersect(nullInts) + checkAnswer(df3, Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil) + assert(df3.schema.forall(_.nullable)) + + val df4 = nonNullableInts.intersect(nonNullableInts) + checkAnswer(df4, Row(1) :: Row(3) :: Nil) + assert(df4.schema.forall(!_.nullable)) + } + + test("intersectAll") { + checkAnswer( + lowerCaseDataWithDuplicates.intersectAll(lowerCaseDataWithDuplicates), + Row(1, "a") :: + Row(2, "b") :: + Row(2, "b") :: + Row(3, "c") :: + Row(3, "c") :: + Row(3, "c") :: + Row(4, "d") :: Nil) + checkAnswer(lowerCaseData.intersectAll(upperCaseData), Nil) + + // check null equality + checkAnswer( + nullInts.intersectAll(nullInts), + Row(1) :: + Row(2) :: + Row(3) :: + Row(null) :: Nil) + + // Duplicate nulls are preserved. 
+ checkAnswer( + allNulls.intersectAll(allNulls), + Row(null) :: Row(null) :: Row(null) :: Row(null) :: Nil) + + val df_left = Seq(1, 2, 2, 3, 3, 4).toDF("id") + val df_right = Seq(1, 2, 2, 3).toDF("id") + + checkAnswer( + df_left.intersectAll(df_right), + Row(1) :: Row(2) :: Row(2) :: Row(3) :: Nil) + } + + test("intersectAll - nullability") { + val nonNullableInts = Seq(Tuple1(1), Tuple1(3)).toDF() + assert(nonNullableInts.schema.forall(!_.nullable)) + + val df1 = nonNullableInts.intersectAll(nullInts) + checkAnswer(df1, Row(1) :: Row(3) :: Nil) + assert(df1.schema.forall(!_.nullable)) + + val df2 = nullInts.intersectAll(nonNullableInts) + checkAnswer(df2, Row(1) :: Row(3) :: Nil) + assert(df2.schema.forall(!_.nullable)) + + val df3 = nullInts.intersectAll(nullInts) + checkAnswer(df3, Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil) + assert(df3.schema.forall(_.nullable)) + + val df4 = nonNullableInts.intersectAll(nonNullableInts) + checkAnswer(df4, Row(1) :: Row(3) :: Nil) + assert(df4.schema.forall(!_.nullable)) + } + + test("SPARK-10539: Project should not be pushed down through Intersect or Except") { + val df1 = (1 to 100).map(Tuple1.apply).toDF("i") + val df2 = (1 to 30).map(Tuple1.apply).toDF("i") + val intersect = df1.intersect(df2) + val except = df1.except(df2) + assert(intersect.count() === 30) + assert(except.count() === 70) + } + + test("SPARK-10740: handle nondeterministic expressions correctly for set operations") { + val df1 = (1 to 20).map(Tuple1.apply).toDF("i") + val df2 = (1 to 10).map(Tuple1.apply).toDF("i") + + // When generating expected results at here, we need to follow the implementation of + // Rand expression. + def expected(df: DataFrame): Seq[Row] = { + df.rdd.collectPartitions().zipWithIndex.flatMap { + case (data, index) => + val rng = new org.apache.spark.util.random.XORShiftRandom(7 + index) + data.filter(_.getInt(0) < rng.nextDouble() * 10) + } + } + + val union = df1.union(df2) + checkAnswer( + union.filter('i < rand(7) * 10), + expected(union) + ) + checkAnswer( + union.select(rand(7)), + union.rdd.collectPartitions().zipWithIndex.flatMap { + case (data, index) => + val rng = new org.apache.spark.util.random.XORShiftRandom(7 + index) + data.map(_ => rng.nextDouble()).map(i => Row(i)) + } + ) + + val intersect = df1.intersect(df2) + checkAnswer( + intersect.filter('i < rand(7) * 10), + expected(intersect) + ) + + val except = df1.except(df2) + checkAnswer( + except.filter('i < rand(7) * 10), + expected(except) + ) + } + + test("SPARK-17123: Performing set operations that combine non-scala native types") { + val dates = Seq( + (new Date(0), BigDecimal.valueOf(1), new Timestamp(2)), + (new Date(3), BigDecimal.valueOf(4), new Timestamp(5)) + ).toDF("date", "timestamp", "decimal") + + val widenTypedRows = Seq( + (new Timestamp(2), 10.5D, "string") + ).toDF("date", "timestamp", "decimal") + + dates.union(widenTypedRows).collect() + dates.except(widenTypedRows).collect() + dates.intersect(widenTypedRows).collect() + } + + test("SPARK-19893: cannot run set operations with map type") { + val df = spark.range(1).select(map(lit("key"), $"id").as("m")) + val e = intercept[AnalysisException](df.intersect(df)) + assert(e.message.contains( + "Cannot have map type columns in DataFrame which calls set operations")) + val e2 = intercept[AnalysisException](df.except(df)) + assert(e2.message.contains( + "Cannot have map type columns in DataFrame which calls set operations")) + val e3 = intercept[AnalysisException](df.distinct()) + assert(e3.message.contains( + "Cannot 
have map type columns in DataFrame which calls set operations")) + withTempView("v") { + df.createOrReplaceTempView("v") + val e4 = intercept[AnalysisException](sql("SELECT DISTINCT m FROM v")) + assert(e4.message.contains( + "Cannot have map type columns in DataFrame which calls set operations")) + } + } + + test("union all") { + val unionDF = testData.union(testData).union(testData) + .union(testData).union(testData) + + // Before optimizer, Union should be combined. + assert(unionDF.queryExecution.analyzed.collect { + case j: Union if j.children.size == 5 => j }.size === 1) + + checkAnswer( + unionDF.agg(avg('key), max('key), min('key), sum('key)), + Row(50.5, 100, 1, 25250) :: Nil + ) + + // unionAll is an alias of union + val unionAllDF = testData.unionAll(testData).unionAll(testData) + .unionAll(testData).unionAll(testData) + + checkAnswer(unionDF, unionAllDF) + } + + test("union should union DataFrames with UDTs (SPARK-13410)") { + val rowRDD1 = sparkContext.parallelize(Seq(Row(1, new ExamplePoint(1.0, 2.0)))) + val schema1 = StructType(Array(StructField("label", IntegerType, false), + StructField("point", new ExamplePointUDT(), false))) + val rowRDD2 = sparkContext.parallelize(Seq(Row(2, new ExamplePoint(3.0, 4.0)))) + val schema2 = StructType(Array(StructField("label", IntegerType, false), + StructField("point", new ExamplePointUDT(), false))) + val df1 = spark.createDataFrame(rowRDD1, schema1) + val df2 = spark.createDataFrame(rowRDD2, schema2) + + checkAnswer( + df1.union(df2).orderBy("label"), + Seq(Row(1, new ExamplePoint(1.0, 2.0)), Row(2, new ExamplePoint(3.0, 4.0))) + ) + } + + test("union by name") { + var df1 = Seq((1, 2, 3)).toDF("a", "b", "c") + var df2 = Seq((3, 1, 2)).toDF("c", "a", "b") + val df3 = Seq((2, 3, 1)).toDF("b", "c", "a") + val unionDf = df1.unionByName(df2.unionByName(df3)) + checkAnswer(unionDf, + Row(1, 2, 3) :: Row(1, 2, 3) :: Row(1, 2, 3) :: Nil + ) + + // Check if adjacent unions are combined into a single one + assert(unionDf.queryExecution.optimizedPlan.collect { case u: Union => true }.size == 1) + + // Check failure cases + df1 = Seq((1, 2)).toDF("a", "c") + df2 = Seq((3, 4, 5)).toDF("a", "b", "c") + var errMsg = intercept[AnalysisException] { + df1.unionByName(df2) + }.getMessage + assert(errMsg.contains( + "Union can only be performed on tables with the same number of columns, " + + "but the first table has 2 columns and the second table has 3 columns")) + + df1 = Seq((1, 2, 3)).toDF("a", "b", "c") + df2 = Seq((4, 5, 6)).toDF("a", "c", "d") + errMsg = intercept[AnalysisException] { + df1.unionByName(df2) + }.getMessage + assert(errMsg.contains("""Cannot resolve column name "b" among (a, c, d)""")) + } + + test("union by name - type coercion") { + var df1 = Seq((1, "a")).toDF("c0", "c1") + var df2 = Seq((3, 1L)).toDF("c1", "c0") + checkAnswer(df1.unionByName(df2), Row(1L, "a") :: Row(1L, "3") :: Nil) + + df1 = Seq((1, 1.0)).toDF("c0", "c1") + df2 = Seq((8L, 3.0)).toDF("c1", "c0") + checkAnswer(df1.unionByName(df2), Row(1.0, 1.0) :: Row(3.0, 8.0) :: Nil) + + df1 = Seq((2.0f, 7.4)).toDF("c0", "c1") + df2 = Seq(("a", 4.0)).toDF("c1", "c0") + checkAnswer(df1.unionByName(df2), Row(2.0, "7.4") :: Row(4.0, "a") :: Nil) + + df1 = Seq((1, "a", 3.0)).toDF("c0", "c1", "c2") + df2 = Seq((1.2, 2, "bc")).toDF("c2", "c0", "c1") + val df3 = Seq(("def", 1.2, 3)).toDF("c1", "c2", "c0") + checkAnswer(df1.unionByName(df2.unionByName(df3)), + Row(1, "a", 3.0) :: Row(2, "bc", 1.2) :: Row(3, "def", 1.2) :: Nil + ) + } + + test("union by name - check case sensitivity") 
{ + def checkCaseSensitiveTest(): Unit = { + val df1 = Seq((1, 2, 3)).toDF("ab", "cd", "ef") + val df2 = Seq((4, 5, 6)).toDF("cd", "ef", "AB") + checkAnswer(df1.unionByName(df2), Row(1, 2, 3) :: Row(6, 4, 5) :: Nil) + } + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + val errMsg2 = intercept[AnalysisException] { + checkCaseSensitiveTest() + }.getMessage + assert(errMsg2.contains("""Cannot resolve column name "ab" among (cd, ef, AB)""")) + } + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + checkCaseSensitiveTest() + } + } + + test("union by name - check name duplication") { + Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) => + withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { + var df1 = Seq((1, 1)).toDF(c0, c1) + var df2 = Seq((1, 1)).toDF("c0", "c1") + var errMsg = intercept[AnalysisException] { + df1.unionByName(df2) + }.getMessage + assert(errMsg.contains("Found duplicate column(s) in the left attributes:")) + df1 = Seq((1, 1)).toDF("c0", "c1") + df2 = Seq((1, 1)).toDF(c0, c1) + errMsg = intercept[AnalysisException] { + df1.unionByName(df2) + }.getMessage + assert(errMsg.contains("Found duplicate column(s) in the right attributes:")) + } + } + } + + test("SPARK-25368 Incorrect predicate pushdown returns wrong result") { + def check(newCol: Column, filter: Column, result: Seq[Row]): Unit = { + val df1 = spark.createDataFrame(Seq( + (1, 1) + )).toDF("a", "b").withColumn("c", newCol) + + val df2 = df1.union(df1).withColumn("d", spark_partition_id).filter(filter) + checkAnswer(df2, result) + } + + check(lit(null).cast("int"), $"c".isNull, Seq(Row(1, 1, null, 0), Row(1, 1, null, 1))) + check(lit(null).cast("int"), $"c".isNotNull, Seq()) + check(lit(2).cast("int"), $"c".isNull, Seq()) + check(lit(2).cast("int"), $"c".isNotNull, Seq(Row(1, 1, 2, 0), Row(1, 1, 2, 1))) + check(lit(2).cast("int"), $"c" === 2, Seq(Row(1, 1, 2, 0), Row(1, 1, 2, 1))) + check(lit(2).cast("int"), $"c" =!= 2, Seq()) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 4447e688475d1..1bfe3818742db 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -528,259 +528,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { ) } - test("except") { - checkAnswer( - lowerCaseData.except(upperCaseData), - Row(1, "a") :: - Row(2, "b") :: - Row(3, "c") :: - Row(4, "d") :: Nil) - checkAnswer(lowerCaseData.except(lowerCaseData), Nil) - checkAnswer(upperCaseData.except(upperCaseData), Nil) - - // check null equality - checkAnswer( - nullInts.except(nullInts.filter("0 = 1")), - nullInts) - checkAnswer( - nullInts.except(nullInts), - Nil) - - // check if values are de-duplicated - checkAnswer( - allNulls.except(allNulls.filter("0 = 1")), - Row(null) :: Nil) - checkAnswer( - allNulls.except(allNulls), - Nil) - - // check if values are de-duplicated - val df = Seq(("id1", 1), ("id1", 1), ("id", 1), ("id1", 2)).toDF("id", "value") - checkAnswer( - df.except(df.filter("0 = 1")), - Row("id1", 1) :: - Row("id", 1) :: - Row("id1", 2) :: Nil) - - // check if the empty set on the left side works - checkAnswer( - allNulls.filter("0 = 1").except(allNulls), - Nil) - } - - test("SPARK-23274: except between two projects without references used in filter") { - val df = Seq((1, 2, 4), (1, 3, 5), (2, 2, 3), (2, 4, 5)).toDF("a", "b", "c") - val df1 = df.filter($"a" 
=== 1) - val df2 = df.filter($"a" === 2) - checkAnswer(df1.select("b").except(df2.select("b")), Row(3) :: Nil) - checkAnswer(df1.select("b").except(df2.select("c")), Row(2) :: Nil) - } - - test("except distinct - SQL compliance") { - val df_left = Seq(1, 2, 2, 3, 3, 4).toDF("id") - val df_right = Seq(1, 3).toDF("id") - - checkAnswer( - df_left.except(df_right), - Row(2) :: Row(4) :: Nil - ) - } - - test("except - nullability") { - val nonNullableInts = Seq(Tuple1(11), Tuple1(3)).toDF() - assert(nonNullableInts.schema.forall(!_.nullable)) - - val df1 = nonNullableInts.except(nullInts) - checkAnswer(df1, Row(11) :: Nil) - assert(df1.schema.forall(!_.nullable)) - - val df2 = nullInts.except(nonNullableInts) - checkAnswer(df2, Row(1) :: Row(2) :: Row(null) :: Nil) - assert(df2.schema.forall(_.nullable)) - - val df3 = nullInts.except(nullInts) - checkAnswer(df3, Nil) - assert(df3.schema.forall(_.nullable)) - - val df4 = nonNullableInts.except(nonNullableInts) - checkAnswer(df4, Nil) - assert(df4.schema.forall(!_.nullable)) - } - - test("except all") { - checkAnswer( - lowerCaseData.exceptAll(upperCaseData), - Row(1, "a") :: - Row(2, "b") :: - Row(3, "c") :: - Row(4, "d") :: Nil) - checkAnswer(lowerCaseData.exceptAll(lowerCaseData), Nil) - checkAnswer(upperCaseData.exceptAll(upperCaseData), Nil) - - // check null equality - checkAnswer( - nullInts.exceptAll(nullInts.filter("0 = 1")), - nullInts) - checkAnswer( - nullInts.exceptAll(nullInts), - Nil) - - // check that duplicate values are preserved - checkAnswer( - allNulls.exceptAll(allNulls.filter("0 = 1")), - Row(null) :: Row(null) :: Row(null) :: Row(null) :: Nil) - checkAnswer( - allNulls.exceptAll(allNulls.limit(2)), - Row(null) :: Row(null) :: Nil) - - // check that duplicates are retained. - val df = spark.sparkContext.parallelize( - NullStrings(1, "id1") :: - NullStrings(1, "id1") :: - NullStrings(2, "id1") :: - NullStrings(3, null) :: Nil).toDF("id", "value") - - checkAnswer( - df.exceptAll(df.filter("0 = 1")), - Row(1, "id1") :: - Row(1, "id1") :: - Row(2, "id1") :: - Row(3, null) :: Nil) - - // check if the empty set on the left side works - checkAnswer( - allNulls.filter("0 = 1").exceptAll(allNulls), - Nil) - - } - - test("exceptAll - nullability") { - val nonNullableInts = Seq(Tuple1(11), Tuple1(3)).toDF() - assert(nonNullableInts.schema.forall(!_.nullable)) - - val df1 = nonNullableInts.exceptAll(nullInts) - checkAnswer(df1, Row(11) :: Nil) - assert(df1.schema.forall(!_.nullable)) - - val df2 = nullInts.exceptAll(nonNullableInts) - checkAnswer(df2, Row(1) :: Row(2) :: Row(null) :: Nil) - assert(df2.schema.forall(_.nullable)) - - val df3 = nullInts.exceptAll(nullInts) - checkAnswer(df3, Nil) - assert(df3.schema.forall(_.nullable)) - - val df4 = nonNullableInts.exceptAll(nonNullableInts) - checkAnswer(df4, Nil) - assert(df4.schema.forall(!_.nullable)) - } - - test("intersect") { - checkAnswer( - lowerCaseData.intersect(lowerCaseData), - Row(1, "a") :: - Row(2, "b") :: - Row(3, "c") :: - Row(4, "d") :: Nil) - checkAnswer(lowerCaseData.intersect(upperCaseData), Nil) - - // check null equality - checkAnswer( - nullInts.intersect(nullInts), - Row(1) :: - Row(2) :: - Row(3) :: - Row(null) :: Nil) - - // check if values are de-duplicated - checkAnswer( - allNulls.intersect(allNulls), - Row(null) :: Nil) - - // check if values are de-duplicated - val df = Seq(("id1", 1), ("id1", 1), ("id", 1), ("id1", 2)).toDF("id", "value") - checkAnswer( - df.intersect(df), - Row("id1", 1) :: - Row("id", 1) :: - Row("id1", 2) :: Nil) - } - - 
test("intersect - nullability") { - val nonNullableInts = Seq(Tuple1(1), Tuple1(3)).toDF() - assert(nonNullableInts.schema.forall(!_.nullable)) - - val df1 = nonNullableInts.intersect(nullInts) - checkAnswer(df1, Row(1) :: Row(3) :: Nil) - assert(df1.schema.forall(!_.nullable)) - - val df2 = nullInts.intersect(nonNullableInts) - checkAnswer(df2, Row(1) :: Row(3) :: Nil) - assert(df2.schema.forall(!_.nullable)) - - val df3 = nullInts.intersect(nullInts) - checkAnswer(df3, Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil) - assert(df3.schema.forall(_.nullable)) - - val df4 = nonNullableInts.intersect(nonNullableInts) - checkAnswer(df4, Row(1) :: Row(3) :: Nil) - assert(df4.schema.forall(!_.nullable)) - } - - test("intersectAll") { - checkAnswer( - lowerCaseDataWithDuplicates.intersectAll(lowerCaseDataWithDuplicates), - Row(1, "a") :: - Row(2, "b") :: - Row(2, "b") :: - Row(3, "c") :: - Row(3, "c") :: - Row(3, "c") :: - Row(4, "d") :: Nil) - checkAnswer(lowerCaseData.intersectAll(upperCaseData), Nil) - - // check null equality - checkAnswer( - nullInts.intersectAll(nullInts), - Row(1) :: - Row(2) :: - Row(3) :: - Row(null) :: Nil) - - // Duplicate nulls are preserved. - checkAnswer( - allNulls.intersectAll(allNulls), - Row(null) :: Row(null) :: Row(null) :: Row(null) :: Nil) - - val df_left = Seq(1, 2, 2, 3, 3, 4).toDF("id") - val df_right = Seq(1, 2, 2, 3).toDF("id") - - checkAnswer( - df_left.intersectAll(df_right), - Row(1) :: Row(2) :: Row(2) :: Row(3) :: Nil) - } - - test("intersectAll - nullability") { - val nonNullableInts = Seq(Tuple1(1), Tuple1(3)).toDF() - assert(nonNullableInts.schema.forall(!_.nullable)) - - val df1 = nonNullableInts.intersectAll(nullInts) - checkAnswer(df1, Row(1) :: Row(3) :: Nil) - assert(df1.schema.forall(!_.nullable)) - - val df2 = nullInts.intersectAll(nonNullableInts) - checkAnswer(df2, Row(1) :: Row(3) :: Nil) - assert(df2.schema.forall(!_.nullable)) - - val df3 = nullInts.intersectAll(nullInts) - checkAnswer(df3, Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil) - assert(df3.schema.forall(_.nullable)) - - val df4 = nonNullableInts.intersectAll(nonNullableInts) - checkAnswer(df4, Row(1) :: Row(3) :: Nil) - assert(df4.schema.forall(!_.nullable)) - } - test("udf") { val foo = udf((a: Int, b: String) => a.toString + b) @@ -1782,56 +1529,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } } - test("SPARK-10539: Project should not be pushed down through Intersect or Except") { - val df1 = (1 to 100).map(Tuple1.apply).toDF("i") - val df2 = (1 to 30).map(Tuple1.apply).toDF("i") - val intersect = df1.intersect(df2) - val except = df1.except(df2) - assert(intersect.count() === 30) - assert(except.count() === 70) - } - - test("SPARK-10740: handle nondeterministic expressions correctly for set operations") { - val df1 = (1 to 20).map(Tuple1.apply).toDF("i") - val df2 = (1 to 10).map(Tuple1.apply).toDF("i") - - // When generating expected results at here, we need to follow the implementation of - // Rand expression. 
- def expected(df: DataFrame): Seq[Row] = { - df.rdd.collectPartitions().zipWithIndex.flatMap { - case (data, index) => - val rng = new org.apache.spark.util.random.XORShiftRandom(7 + index) - data.filter(_.getInt(0) < rng.nextDouble() * 10) - } - } - - val union = df1.union(df2) - checkAnswer( - union.filter('i < rand(7) * 10), - expected(union) - ) - checkAnswer( - union.select(rand(7)), - union.rdd.collectPartitions().zipWithIndex.flatMap { - case (data, index) => - val rng = new org.apache.spark.util.random.XORShiftRandom(7 + index) - data.map(_ => rng.nextDouble()).map(i => Row(i)) - } - ) - - val intersect = df1.intersect(df2) - checkAnswer( - intersect.filter('i < rand(7) * 10), - expected(intersect) - ) - - val except = df1.except(df2) - checkAnswer( - except.filter('i < rand(7) * 10), - expected(except) - ) - } - test("SPARK-10743: keep the name of expression if possible when do cast") { val df = (1 to 10).map(Tuple1.apply).toDF("i").as("src") assert(df.select($"src.i".cast(StringType)).columns.head === "i") @@ -1992,7 +1689,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("SPARK-11725: correctly handle null inputs for ScalaUDF") { val df = sparkContext.parallelize(Seq( - new java.lang.Integer(22) -> "John", + java.lang.Integer.valueOf(22) -> "John", null.asInstanceOf[java.lang.Integer] -> "Lucy")).toDF("age", "name") // passing null into the UDF that could handle it @@ -2225,9 +1922,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("SPARK-17957: no change on nullability in FilterExec output") { val df = sparkContext.parallelize(Seq( - null.asInstanceOf[java.lang.Integer] -> new java.lang.Integer(3), - new java.lang.Integer(1) -> null.asInstanceOf[java.lang.Integer], - new java.lang.Integer(2) -> new java.lang.Integer(4))).toDF() + null.asInstanceOf[java.lang.Integer] -> java.lang.Integer.valueOf(3), + java.lang.Integer.valueOf(1) -> null.asInstanceOf[java.lang.Integer], + java.lang.Integer.valueOf(2) -> java.lang.Integer.valueOf(4))).toDF() verifyNullabilityInFilterExec(df, expr = "Rand()", expectedNonNullableColumns = Seq.empty[String]) @@ -2242,9 +1939,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("SPARK-17957: set nullability to false in FilterExec output") { val df = sparkContext.parallelize(Seq( - null.asInstanceOf[java.lang.Integer] -> new java.lang.Integer(3), - new java.lang.Integer(1) -> null.asInstanceOf[java.lang.Integer], - new java.lang.Integer(2) -> new java.lang.Integer(4))).toDF() + null.asInstanceOf[java.lang.Integer] -> java.lang.Integer.valueOf(3), + java.lang.Integer.valueOf(1) -> null.asInstanceOf[java.lang.Integer], + java.lang.Integer.valueOf(2) -> java.lang.Integer.valueOf(4))).toDF() verifyNullabilityInFilterExec(df, expr = "_1 + _2 * 3", expectedNonNullableColumns = Seq("_1", "_2")) @@ -2280,21 +1977,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } } - test("SPARK-17123: Performing set operations that combine non-scala native types") { - val dates = Seq( - (new Date(0), BigDecimal.valueOf(1), new Timestamp(2)), - (new Date(3), BigDecimal.valueOf(4), new Timestamp(5)) - ).toDF("date", "timestamp", "decimal") - - val widenTypedRows = Seq( - (new Timestamp(2), 10.5D, "string") - ).toDF("date", "timestamp", "decimal") - - dates.union(widenTypedRows).collect() - dates.except(widenTypedRows).collect() - dates.intersect(widenTypedRows).collect() - } - test("SPARK-18070 binary operator should not consider nullability when comparing input types") { val rows = 
Seq(Row(Seq(1), Seq(1))) val schema = new StructType() @@ -2314,25 +1996,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { checkAnswer(df, Row(BigDecimal(0)) :: Nil) } - test("SPARK-19893: cannot run set operations with map type") { - val df = spark.range(1).select(map(lit("key"), $"id").as("m")) - val e = intercept[AnalysisException](df.intersect(df)) - assert(e.message.contains( - "Cannot have map type columns in DataFrame which calls set operations")) - val e2 = intercept[AnalysisException](df.except(df)) - assert(e2.message.contains( - "Cannot have map type columns in DataFrame which calls set operations")) - val e3 = intercept[AnalysisException](df.distinct()) - assert(e3.message.contains( - "Cannot have map type columns in DataFrame which calls set operations")) - withTempView("v") { - df.createOrReplaceTempView("v") - val e4 = intercept[AnalysisException](sql("SELECT DISTINCT m FROM v")) - assert(e4.message.contains( - "Cannot have map type columns in DataFrame which calls set operations")) - } - } - test("SPARK-20359: catalyst outer join optimization should not throw npe") { val df1 = Seq("a", "b", "c").toDF("x") .withColumn("y", udf{ (x: String) => x.substring(0, 1) + "!" }.apply($"x")) @@ -2517,24 +2180,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } } - test("SPARK-25368 Incorrect predicate pushdown returns wrong result") { - def check(newCol: Column, filter: Column, result: Seq[Row]): Unit = { - val df1 = spark.createDataFrame(Seq( - (1, 1) - )).toDF("a", "b").withColumn("c", newCol) - - val df2 = df1.union(df1).withColumn("d", spark_partition_id).filter(filter) - checkAnswer(df2, result) - } - - check(lit(null).cast("int"), $"c".isNull, Seq(Row(1, 1, null, 0), Row(1, 1, null, 1))) - check(lit(null).cast("int"), $"c".isNotNull, Seq()) - check(lit(2).cast("int"), $"c".isNull, Seq()) - check(lit(2).cast("int"), $"c".isNotNull, Seq(Row(1, 1, 2, 0), Row(1, 1, 2, 1))) - check(lit(2).cast("int"), $"c" === 2, Seq(Row(1, 1, 2, 0), Row(1, 1, 2, 1))) - check(lit(2).cast("int"), $"c" =!= 2, Seq()) - } - test("SPARK-25402 Null handling in BooleanSimplification") { val schema = StructType.fromDDL("a boolean, b int") val rows = Seq(Row(null, 1)) @@ -2560,4 +2205,29 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { checkAnswer(swappedDf.filter($"key"($"map") > "a"), Row(2, Map(2 -> "b"))) } + + test("SPARK-26057: attribute deduplication on already analyzed plans") { + withTempView("a", "b", "v") { + val df1 = Seq(("1-1", 6)).toDF("id", "n") + df1.createOrReplaceTempView("a") + val df3 = Seq("1-1").toDF("id") + df3.createOrReplaceTempView("b") + spark.sql( + """ + |SELECT a.id, n as m + |FROM a + |WHERE EXISTS( + | SELECT 1 + | FROM b + | WHERE b.id = a.id) + """.stripMargin).createOrReplaceTempView("v") + val res = spark.sql( + """ + |SELECT a.id, n, m + | FROM a + | LEFT OUTER JOIN v ON v.id = a.id + """.stripMargin) + checkAnswer(res, Row("1-1", 6, 6)) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index 538ea3c66c40e..97c3f358c0e76 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.expressions.scalalang.typed import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext 
-import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.types.{BooleanType, IntegerType, StringType, StructType} object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, Long)] { @@ -149,6 +149,7 @@ object VeryComplexResultAgg extends Aggregator[Row, String, ComplexAggData] { case class OptionBooleanData(name: String, isGood: Option[Boolean]) +case class OptionBooleanIntData(name: String, isGood: Option[(Boolean, Int)]) case class OptionBooleanAggregator(colName: String) extends Aggregator[Row, Option[Boolean], Option[Boolean]] { @@ -183,6 +184,43 @@ case class OptionBooleanAggregator(colName: String) def OptionalBoolEncoder: Encoder[Option[Boolean]] = ExpressionEncoder() } +case class OptionBooleanIntAggregator(colName: String) + extends Aggregator[Row, Option[(Boolean, Int)], Option[(Boolean, Int)]] { + + override def zero: Option[(Boolean, Int)] = None + + override def reduce(buffer: Option[(Boolean, Int)], row: Row): Option[(Boolean, Int)] = { + val index = row.fieldIndex(colName) + val value = if (row.isNullAt(index)) { + Option.empty[(Boolean, Int)] + } else { + val nestedRow = row.getStruct(index) + Some((nestedRow.getBoolean(0), nestedRow.getInt(1))) + } + merge(buffer, value) + } + + override def merge( + b1: Option[(Boolean, Int)], + b2: Option[(Boolean, Int)]): Option[(Boolean, Int)] = { + if ((b1.isDefined && b1.get._1) || (b2.isDefined && b2.get._1)) { + val newInt = b1.map(_._2).getOrElse(0) + b2.map(_._2).getOrElse(0) + Some((true, newInt)) + } else if (b1.isDefined) { + b1 + } else { + b2 + } + } + + override def finish(reduction: Option[(Boolean, Int)]): Option[(Boolean, Int)] = reduction + + override def bufferEncoder: Encoder[Option[(Boolean, Int)]] = OptionalBoolIntEncoder + override def outputEncoder: Encoder[Option[(Boolean, Int)]] = OptionalBoolIntEncoder + + def OptionalBoolIntEncoder: Encoder[Option[(Boolean, Int)]] = ExpressionEncoder() +} + class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -393,4 +431,28 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { assert(grouped.schema == df.schema) checkDataset(grouped.as[OptionBooleanData], OptionBooleanData("bob", Some(true))) } + + test("SPARK-24762: Aggregator should be able to use Option of Product encoder") { + val df = Seq( + OptionBooleanIntData("bob", Some((true, 1))), + OptionBooleanIntData("bob", Some((false, 2))), + OptionBooleanIntData("bob", None)).toDF() + + val group = df + .groupBy("name") + .agg(OptionBooleanIntAggregator("isGood").toColumn.alias("isGood")) + + val expectedSchema = new StructType() + .add("name", StringType, nullable = true) + .add("isGood", + new StructType() + .add("_1", BooleanType, nullable = false) + .add("_2", IntegerType, nullable = false), + nullable = true) + + assert(df.schema == expectedSchema) + assert(group.schema == expectedSchema) + checkAnswer(group, Row("bob", Row(true, 3)) :: Nil) + checkDataset(group.as[OptionBooleanIntData], OptionBooleanIntData("bob", Some((true, 3)))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index 96a6792f52f3e..0ded5d8ce1e28 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -393,4 +393,33 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { val ds = spark.createDataset(data) 
checkDataset(ds, data: _*) } + + test("special floating point values") { + import org.scalatest.exceptions.TestFailedException + + // Spark treats -0.0 as 0.0 + intercept[TestFailedException] { + checkDataset(Seq(-0.0d).toDS(), -0.0d) + } + intercept[TestFailedException] { + checkDataset(Seq(-0.0f).toDS(), -0.0f) + } + intercept[TestFailedException] { + checkDataset(Seq(Tuple1(-0.0)).toDS(), Tuple1(-0.0)) + } + + val floats = Seq[Float](-0.0f, 0.0f, Float.NaN).toDS() + checkDataset(floats, 0.0f, 0.0f, Float.NaN) + + val doubles = Seq[Double](-0.0d, 0.0d, Double.NaN).toDS() + checkDataset(doubles, 0.0, 0.0, Double.NaN) + + checkDataset(Seq(Tuple1(Float.NaN)).toDS(), Tuple1(Float.NaN)) + checkDataset(Seq(Tuple1(-0.0f)).toDS(), Tuple1(0.0f)) + checkDataset(Seq(Tuple1(Double.NaN)).toDS(), Tuple1(Double.NaN)) + checkDataset(Seq(Tuple1(-0.0)).toDS(), Tuple1(0.0)) + + val complex = Map(Array(Seq(Tuple1(Double.NaN))) -> Map(Tuple2(Float.NaN, null))) + checkDataset(Seq(complex).toDS(), complex) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 8d1e8ba8b1e1d..43849b9e61938 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -21,6 +21,7 @@ import java.io.{Externalizable, ObjectInput, ObjectOutput} import java.sql.{Date, Timestamp} import org.apache.spark.SparkException +import org.apache.spark.sql.catalyst.ScroogeLikeExample import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder} import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi} import org.apache.spark.sql.catalyst.util.sideBySide @@ -697,15 +698,15 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("SPARK-11894: Incorrect results are returned when using null") { val nullInt = null.asInstanceOf[java.lang.Integer] - val ds1 = Seq((nullInt, "1"), (new java.lang.Integer(22), "2")).toDS() - val ds2 = Seq((nullInt, "1"), (new java.lang.Integer(22), "2")).toDS() + val ds1 = Seq((nullInt, "1"), (java.lang.Integer.valueOf(22), "2")).toDS() + val ds2 = Seq((nullInt, "1"), (java.lang.Integer.valueOf(22), "2")).toDS() checkDataset( ds1.joinWith(ds2, lit(true), "cross"), ((nullInt, "1"), (nullInt, "1")), - ((nullInt, "1"), (new java.lang.Integer(22), "2")), - ((new java.lang.Integer(22), "2"), (nullInt, "1")), - ((new java.lang.Integer(22), "2"), (new java.lang.Integer(22), "2"))) + ((nullInt, "1"), (java.lang.Integer.valueOf(22), "2")), + ((java.lang.Integer.valueOf(22), "2"), (nullInt, "1")), + ((java.lang.Integer.valueOf(22), "2"), (java.lang.Integer.valueOf(22), "2"))) } test("change encoder with compatible schema") { @@ -881,7 +882,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { assert(ds.rdd.map(r => r.id).count === 2) assert(ds2.rdd.map(r => r.id).count === 2) - val ds3 = ds.map(g => new java.lang.Long(g.id)) + val ds3 = ds.map(g => java.lang.Long.valueOf(g.id)) assert(ds3.rdd.map(r => r).count === 2) } @@ -1311,15 +1312,6 @@ class DatasetSuite extends QueryTest with SharedSQLContext { checkDataset(dsString, arrayString) } - test("SPARK-18251: the type of Dataset can't be Option of Product type") { - checkDataset(Seq(Some(1), None).toDS(), Some(1), None) - - val e = intercept[UnsupportedOperationException] { - Seq(Some(1 -> "a"), None).toDS() - } - assert(e.getMessage.contains("Cannot create encoder for Option of Product type")) - } - test ("SPARK-17460: the sizeInBytes in Statistics 
shouldn't overflow to a negative number") { // Since the sizeInBytes in Statistics could exceed the limit of an Int, we should use BigInt // instead of Int for avoiding possible overflow. @@ -1499,7 +1491,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { assert(e.getCause.isInstanceOf[NullPointerException]) withTempPath { path => - Seq(new Integer(1), null).toDF("i").write.parquet(path.getCanonicalPath) + Seq(Integer.valueOf(1), null).toDF("i").write.parquet(path.getCanonicalPath) // If the primitive values are from files, we need to do runtime null check. val ds = spark.read.parquet(path.getCanonicalPath).as[Int] intercept[NullPointerException](ds.collect()) @@ -1553,9 +1545,108 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val df = Seq("Amsterdam", "San Francisco", "X").toDF("city") checkAnswer(df.where('city === 'X'), Seq(Row("X"))) checkAnswer( - df.where($"city".contains(new java.lang.Character('A'))), + df.where($"city".contains(java.lang.Character.valueOf('A'))), Seq(Row("Amsterdam"))) } + + test("SPARK-24762: Enable top-level Option of Product encoders") { + val data = Seq(Some((1, "a")), Some((2, "b")), None) + val ds = data.toDS() + + checkDataset( + ds, + data: _*) + + val schema = new StructType().add( + "value", + new StructType() + .add("_1", IntegerType, nullable = false) + .add("_2", StringType, nullable = true), + nullable = true) + + assert(ds.schema == schema) + + val nestedOptData = Seq(Some((Some((1, "a")), 2.0)), Some((Some((2, "b")), 3.0))) + val nestedDs = nestedOptData.toDS() + + checkDataset( + nestedDs, + nestedOptData: _*) + + val nestedSchema = StructType(Seq( + StructField("value", StructType(Seq( + StructField("_1", StructType(Seq( + StructField("_1", IntegerType, nullable = false), + StructField("_2", StringType, nullable = true)))), + StructField("_2", DoubleType, nullable = false) + )), nullable = true) + )) + assert(nestedDs.schema == nestedSchema) + } + + test("SPARK-24762: Resolving Option[Product] field") { + val ds = Seq((1, ("a", 1.0)), (2, ("b", 2.0)), (3, null)).toDS() + .as[(Int, Option[(String, Double)])] + checkDataset(ds, + (1, Some(("a", 1.0))), (2, Some(("b", 2.0))), (3, None)) + } + + test("SPARK-24762: select Option[Product] field") { + val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() + val ds1 = ds.select(expr("struct(_2, _2 + 1)").as[Option[(Int, Int)]]) + checkDataset(ds1, + Some((1, 2)), Some((2, 3)), Some((3, 4))) + + val ds2 = ds.select(expr("if(_2 > 2, struct(_2, _2 + 1), null)").as[Option[(Int, Int)]]) + checkDataset(ds2, + None, None, Some((3, 4))) + } + + test("SPARK-24762: joinWith on Option[Product]") { + val ds1 = Seq(Some((1, 2)), Some((2, 3)), None).toDS().as("a") + val ds2 = Seq(Some((1, 2)), Some((2, 3)), None).toDS().as("b") + val joined = ds1.joinWith(ds2, $"a.value._1" === $"b.value._2", "inner") + checkDataset(joined, (Some((2, 3)), Some((1, 2)))) + } + + test("SPARK-24762: typed agg on Option[Product] type") { + val ds = Seq(Some((1, 2)), Some((2, 3)), Some((1, 3))).toDS() + assert(ds.groupByKey(_.get._1).count().collect() === Seq((1, 2), (2, 1))) + + assert(ds.groupByKey(x => x).count().collect() === + Seq((Some((1, 2)), 1), (Some((2, 3)), 1), (Some((1, 3)), 1))) + } + + test("SPARK-25942: typed aggregation on primitive type") { + val ds = Seq(1, 2, 3).toDS() + + val agg = ds.groupByKey(_ >= 2) + .agg(sum("value").as[Long], sum($"value" + 1).as[Long]) + checkDatasetUnorderly(agg, (false, 1L, 2L), (true, 5L, 7L)) + } + + test("SPARK-25942: typed aggregation on product type") { + val 
ds = Seq((1, 2), (2, 3), (3, 4)).toDS() + val agg = ds.groupByKey(x => x).agg(sum("_1").as[Long], sum($"_2" + 1).as[Long]) + checkDatasetUnorderly(agg, ((1, 2), 1L, 3L), ((2, 3), 2L, 4L), ((3, 4), 3L, 5L)) + } + + test("SPARK-26085: fix key attribute name for atomic type for typed aggregation") { + val ds = Seq(1, 2, 3).toDS() + assert(ds.groupByKey(x => x).count().schema.head.name == "key") + + // Enable legacy flag to follow previous Spark behavior + withSQLConf(SQLConf.NAME_NON_STRUCT_GROUPING_KEY_AS_VALUE.key -> "true") { + assert(ds.groupByKey(x => x).count().schema.head.name == "value") + } + } + + test("SPARK-8288: class with only a companion object constructor") { + val data = Seq(ScroogeLikeExample(1), ScroogeLikeExample(2)) + val ds = data.toDS + checkDataset(ds, data: _*) + checkAnswer(ds.select("x"), Seq(Row(1), Row(2))) + } } case class TestDataUnion(x: Int, y: Int, z: Int) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index 79f3ef28a630f..65973648c3373 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -514,9 +514,9 @@ object TestingUDT { override def sqlType: DataType = CalendarIntervalType override def serialize(obj: IntervalData): Any = - throw new NotImplementedError("Not implemented") + throw new UnsupportedOperationException("Not implemented") override def deserialize(datum: Any): IntervalData = - throw new NotImplementedError("Not implemented") + throw new UnsupportedOperationException("Not implemented") override def userClass: Class[IntervalData] = classOf[IntervalData] } @@ -526,9 +526,10 @@ object TestingUDT { private[sql] class NullUDT extends UserDefinedType[NullData] { override def sqlType: DataType = NullType - override def serialize(obj: NullData): Any = throw new NotImplementedError("Not implemented") + override def serialize(obj: NullData): Any = + throw new UnsupportedOperationException("Not implemented") override def deserialize(datum: Any): NullData = - throw new NotImplementedError("Not implemented") + throw new UnsupportedOperationException("Not implemented") override def userClass: Class[NullData] = classOf[NullData] } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 2b09782faeeaa..24e7564259c83 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql +import java.text.SimpleDateFormat +import java.util.Locale + import collection.JavaConverters._ import org.apache.spark.SparkException @@ -578,4 +581,31 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { "Acceptable modes are PERMISSIVE and FAILFAST.")) } } + + test("corrupt record column in the middle") { + val schema = new StructType() + .add("a", IntegerType) + .add("_unparsed", StringType) + .add("b", IntegerType) + val badRec = """{"a" 1, "b": 11}""" + val df = Seq(badRec, """{"a": 2, "b": 12}""").toDS() + + checkAnswer( + df.select(from_json($"value", schema, Map("columnNameOfCorruptRecord" -> "_unparsed"))), + Row(Row(null, badRec, null)) :: Row(Row(2, null, 12)) :: Nil) + } + + test("parse timestamps with locale") { + Seq("en-US", "ko-KR", "zh-CN", "ru-RU").foreach { langTag => + 
val locale = Locale.forLanguageTag(langTag) + val ts = new SimpleDateFormat("dd/MM/yyyy HH:mm").parse("06/11/2018 18:00") + val timestampFormat = "dd MMM yyyy HH:mm" + val sdf = new SimpleDateFormat(timestampFormat, locale) + val input = Seq(s"""{"time": "${sdf.format(ts)}"}""").toDS() + val options = Map("timestampFormat" -> timestampFormat, "locale" -> langTag) + val df = input.select(from_json($"value", "time timestamp", options)) + + checkAnswer(df, Row(Row(java.sql.Timestamp.valueOf("2018-11-06 18:00:00.0")))) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index baca9c1cfb9a0..a547676c5ed5c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -132,6 +132,13 @@ abstract class QueryTest extends PlanTest { a.length == b.length && a.zip(b).forall { case (l, r) => compare(l, r)} case (a: Iterable[_], b: Iterable[_]) => a.size == b.size && a.zip(b).forall { case (l, r) => compare(l, r)} + case (a: Product, b: Product) => + compare(a.productIterator.toSeq, b.productIterator.toSeq) + // 0.0 == -0.0, turn float/double to binary before comparison, to distinguish 0.0 and -0.0. + case (a: Double, b: Double) => + java.lang.Double.doubleToRawLongBits(a) == java.lang.Double.doubleToRawLongBits(b) + case (a: Float, b: Float) => + java.lang.Float.floatToRawIntBits(a) == java.lang.Float.floatToRawIntBits(b) case (a, b) => a == b } @@ -289,7 +296,7 @@ object QueryTest { def prepareRow(row: Row): Row = { Row.fromSeq(row.toSeq.map { case null => null - case d: java.math.BigDecimal => BigDecimal(d) + case bd: java.math.BigDecimal => BigDecimal(bd) // Equality of WrappedArray differs for AnyVal and AnyRef in Scala 2.12.2+ case seq: Seq[_] => seq.map { case b: java.lang.Byte => b.byteValue @@ -303,6 +310,9 @@ object QueryTest { // Convert array to Seq for easy equality check. 
case b: Array[_] => b.toSeq case r: Row => prepareRow(r) + // spark treats -0.0 as 0.0 + case d: Double if d == -0.0d => 0.0d + case f: Float if f == -0.0f => 0.0f case o => o }) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseInPredicateEndToEndSuite.scala similarity index 63% rename from sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseEndToEndSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseInPredicateEndToEndSuite.scala index fc6ecc4e032f6..0f84b0c961a10 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseEndToEndSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseInPredicateEndToEndSuite.scala @@ -17,12 +17,13 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.expressions.{CaseWhen, If} +import org.apache.spark.sql.catalyst.expressions.{CaseWhen, If, Literal} import org.apache.spark.sql.execution.LocalTableScanExec import org.apache.spark.sql.functions.{lit, when} import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.BooleanType -class ReplaceNullWithFalseEndToEndSuite extends QueryTest with SharedSQLContext { +class ReplaceNullWithFalseInPredicateEndToEndSuite extends QueryTest with SharedSQLContext { import testImplicits._ test("SPARK-25860: Replace Literal(null, _) with FalseLiteral whenever possible") { @@ -68,4 +69,44 @@ class ReplaceNullWithFalseEndToEndSuite extends QueryTest with SharedSQLContext case p => fail(s"$p is not LocalTableScanExec") } } + + test("SPARK-26107: Replace Literal(null, _) with FalseLiteral in higher-order functions") { + def assertNoLiteralNullInPlan(df: DataFrame): Unit = { + df.queryExecution.executedPlan.foreach { p => + assert(p.expressions.forall(_.find { + case Literal(null, BooleanType) => true + case _ => false + }.isEmpty)) + } + } + + withTable("t1", "t2") { + // to test ArrayFilter and ArrayExists + spark.sql("select array(null, 1, null, 3) as a") + .write.saveAsTable("t1") + // to test MapFilter + spark.sql(""" + select map_from_entries(arrays_zip(a, transform(a, e -> if(mod(e, 2) = 0, null, e)))) as m + from (select array(0, 1, 2, 3) as a) + """).write.saveAsTable("t2") + + val df1 = spark.table("t1") + val df2 = spark.table("t2") + + // ArrayExists + val q1 = df1.selectExpr("EXISTS(a, e -> IF(e is null, null, true))") + checkAnswer(q1, Row(true) :: Nil) + assertNoLiteralNullInPlan(q1) + + // ArrayFilter + val q2 = df1.selectExpr("FILTER(a, e -> IF(e is null, null, true))") + checkAnswer(q2, Row(Seq[Any](1, 3)) :: Nil) + assertNoLiteralNullInPlan(q2) + + // MapFilter + val q3 = df2.selectExpr("MAP_FILTER(m, (k, v) -> IF(v is null, null, true))") + checkAnswer(q3, Row(Map[Any, Any](1 -> 1, 3 -> 3))) + assertNoLiteralNullInPlan(q3) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index c45f7012a80a9..07b222bfae3a5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -240,16 +240,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Seq(Row("1"), Row("2"))) } - test("SPARK-11226 Skip empty line in json file") { - spark.read - .json(Seq("{\"a\": \"1\"}}", "{\"a\": \"2\"}}", "{\"a\": \"3\"}}", "").toDS()) - .createOrReplaceTempView("d") - - checkAnswer( - 
sql("select count(1) from d"), - Seq(Row(3))) - } - test("SPARK-8828 sum should return null if all input values are null") { checkAnswer( sql("select sum(a), avg(a) from allNulls"), @@ -2864,6 +2854,59 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { checkAnswer(sql("select 26393499451 / (1e6 * 1000)"), Row(BigDecimal("26.3934994510000"))) } } + + test("SPARK-25988: self join with aliases on partitioned tables #1") { + withTempView("tmpView1", "tmpView2") { + withTable("tab1", "tab2") { + sql( + """ + |CREATE TABLE `tab1` (`col1` INT, `TDATE` DATE) + |USING CSV + |PARTITIONED BY (TDATE) + """.stripMargin) + spark.table("tab1").where("TDATE >= '2017-08-15'").createOrReplaceTempView("tmpView1") + sql("CREATE TABLE `tab2` (`TDATE` DATE) USING parquet") + sql( + """ + |CREATE OR REPLACE TEMPORARY VIEW tmpView2 AS + |SELECT N.tdate, col1 AS aliasCol1 + |FROM tmpView1 N + |JOIN tab2 Z + |ON N.tdate = Z.tdate + """.stripMargin) + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") { + sql("SELECT * FROM tmpView2 x JOIN tmpView2 y ON x.tdate = y.tdate").collect() + } + } + } + } + + test("SPARK-25988: self join with aliases on partitioned tables #2") { + withTempView("tmp") { + withTable("tab1", "tab2") { + sql( + """ + |CREATE TABLE `tab1` (`EX` STRING, `TDATE` DATE) + |USING parquet + |PARTITIONED BY (tdate) + """.stripMargin) + sql("CREATE TABLE `tab2` (`TDATE` DATE) USING parquet") + sql( + """ + |CREATE OR REPLACE TEMPORARY VIEW TMP as + |SELECT N.tdate, EX AS new_ex + |FROM tab1 N + |JOIN tab2 Z + |ON N.tdate = Z.tdate + """.stripMargin) + sql( + """ + |SELECT * FROM TMP x JOIN TMP y + |ON x.tdate = y.tdate + """.stripMargin).queryExecution.executedPlan + } + } + } } case class Foo(bar: Option[String]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index 826408c7161e9..6ca3ac596e5f4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -272,6 +272,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { .replaceAll("Created By.*", s"Created By $notIncludedMsg") .replaceAll("Created Time.*", s"Created Time $notIncludedMsg") .replaceAll("Last Access.*", s"Last Access $notIncludedMsg") + .replaceAll("Partition Statistics\t\\d+", s"Partition Statistics\t$notIncludedMsg") .replaceAll("\\*\\(\\d+\\) ", "*")) // remove the WholeStageCodegen codegenStageIds // If the output is not pre-sorted, sort it. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index cbffed994bb4f..5088821ad7361 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -1268,4 +1268,16 @@ class SubquerySuite extends QueryTest with SharedSQLContext { assert(getNumSortsInQuery(query5) == 1) } } + + test("SPARK-25482: Forbid pushdown to datasources of filters containing subqueries") { + withTempView("t1", "t2") { + sql("create temporary view t1(a int) using parquet") + sql("create temporary view t2(b int) using parquet") + val plan = sql("select * from t2 where b > (select max(a) from t1)") + val subqueries = plan.queryExecution.executedPlan.collect { + case p => p.subqueries + }.flatten + assert(subqueries.length == 1) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 3b301a4f8144a..20dcefa7e3cad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -413,7 +413,7 @@ class UDFSuite extends QueryTest with SharedSQLContext { test("SPARK-25044 Verify null input handling for primitive types - with udf.register") { withTable("t") { - Seq((null, new Integer(1), "x"), ("M", null, "y"), ("N", new Integer(3), null)) + Seq((null, Integer.valueOf(1), "x"), ("M", null, "y"), ("N", Integer.valueOf(3), null)) .toDF("a", "b", "c").write.format("json").saveAsTable("t") spark.udf.register("f", (a: String, b: Int, c: Any) => a + b + c) val df = spark.sql("SELECT f(a, b, c) FROM t") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index cc8b600efa46a..cf956316057eb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql -import scala.beans.{BeanInfo, BeanProperty} - import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{Cast, ExpressionEvalHelper, GenericInternalRow, Literal} @@ -28,10 +26,10 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -@BeanInfo -private[sql] case class MyLabeledPoint( - @BeanProperty label: Double, - @BeanProperty features: UDT.MyDenseVector) +private[sql] case class MyLabeledPoint(label: Double, features: UDT.MyDenseVector) { + def getLabel: Double = label + def getFeatures: UDT.MyDenseVector = features +} // Wrapped in an object to check Scala compatibility. 
See SPARK-13929 object UDT { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 409474552b1cf..c97041a8f341c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -787,6 +787,6 @@ private case class DummySparkPlan( override val requiredChildDistribution: Seq[Distribution] = Nil, override val requiredChildOrdering: Seq[Seq[SortOrder]] = Nil ) extends SparkPlan { - override protected def doExecute(): RDD[InternalRow] = throw new NotImplementedError + override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException override def output: Seq[Attribute] = Seq.empty } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala index 964440346deb0..0c47a2040f171 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala @@ -16,11 +16,96 @@ */ package org.apache.spark.sql.execution +import scala.io.Source + import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext +case class QueryExecutionTestRecord( + c0: Int, c1: Int, c2: Int, c3: Int, c4: Int, + c5: Int, c6: Int, c7: Int, c8: Int, c9: Int, + c10: Int, c11: Int, c12: Int, c13: Int, c14: Int, + c15: Int, c16: Int, c17: Int, c18: Int, c19: Int, + c20: Int, c21: Int, c22: Int, c23: Int, c24: Int, + c25: Int, c26: Int) + class QueryExecutionSuite extends SharedSQLContext { + import testImplicits._ + + def checkDumpedPlans(path: String, expected: Int): Unit = { + assert(Source.fromFile(path).getLines.toList + .takeWhile(_ != "== Whole Stage Codegen ==") == List( + "== Parsed Logical Plan ==", + s"Range (0, $expected, step=1, splits=Some(2))", + "", + "== Analyzed Logical Plan ==", + "id: bigint", + s"Range (0, $expected, step=1, splits=Some(2))", + "", + "== Optimized Logical Plan ==", + s"Range (0, $expected, step=1, splits=Some(2))", + "", + "== Physical Plan ==", + s"*(1) Range (0, $expected, step=1, splits=2)", + "")) + } + test("dumping query execution info to a file") { + withTempDir { dir => + val path = dir.getCanonicalPath + "/plans.txt" + val df = spark.range(0, 10) + df.queryExecution.debug.toFile(path) + + checkDumpedPlans(path, expected = 10) + } + } + + test("dumping query execution info to an existing file") { + withTempDir { dir => + val path = dir.getCanonicalPath + "/plans.txt" + val df = spark.range(0, 10) + df.queryExecution.debug.toFile(path) + + val df2 = spark.range(0, 1) + df2.queryExecution.debug.toFile(path) + checkDumpedPlans(path, expected = 1) + } + } + + test("dumping query execution info to non-existing folder") { + withTempDir { dir => + val path = dir.getCanonicalPath + "/newfolder/plans.txt" + val df = spark.range(0, 100) + df.queryExecution.debug.toFile(path) + checkDumpedPlans(path, expected = 100) + } + } + + test("dumping query execution info by invalid path") { + val path = "1234567890://plans.txt" + val exception = intercept[IllegalArgumentException] { + spark.range(0, 100).queryExecution.debug.toFile(path) + } + + 
assert(exception.getMessage.contains("Illegal character in scheme name")) + } + + test("limit number of fields by sql config") { + def relationPlans: String = { + val ds = spark.createDataset(Seq(QueryExecutionTestRecord( + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26))) + ds.queryExecution.toString + } + withSQLConf(SQLConf.MAX_TO_STRING_FIELDS.key -> "26") { + assert(relationPlans.contains("more fields")) + } + withSQLConf(SQLConf.MAX_TO_STRING_FIELDS.key -> "27") { + assert(!relationPlans.contains("more fields")) + } + } + test("toString() exception/error handling") { spark.experimental.extraStrategies = Seq( new SparkStrategy { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryPlanningTrackerEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryPlanningTrackerEndToEndSuite.scala new file mode 100644 index 0000000000000..0af4c85400e9e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryPlanningTrackerEndToEndSuite.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.test.SharedSQLContext + +class QueryPlanningTrackerEndToEndSuite extends SharedSQLContext { + + test("programmatic API") { + val df = spark.range(1000).selectExpr("count(*)") + df.collect() + val tracker = df.queryExecution.tracker + + assert(tracker.phases.size == 3) + assert(tracker.phases("analysis") > 0) + assert(tracker.phases("optimization") > 0) + assert(tracker.phases("planning") > 0) + + assert(tracker.rules.nonEmpty) + } + + test("sql") { + val df = spark.sql("select * from range(1)") + df.collect() + + val tracker = df.queryExecution.tracker + + assert(tracker.phases.size == 4) + assert(tracker.phases("parsing") > 0) + assert(tracker.phases("analysis") > 0) + assert(tracker.phases("optimization") > 0) + assert(tracker.phases("planning") > 0) + + assert(tracker.rules.nonEmpty) + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala index d305ce3e698ae..96b3aa5ee75b8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{LocalSparkSession, Row, SparkSession} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types._ import org.apache.spark.storage.ShuffleBlockId import org.apache.spark.util.collection.ExternalSorter @@ -137,7 +138,9 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkSession { rowsRDD, new PartitionIdPassthrough(2), new UnsafeRowSerializer(2)) - val shuffled = new ShuffledRowRDD(dependency) + val shuffled = new ShuffledRowRDD( + dependency, + SQLMetrics.getShuffleReadMetrics(spark.sparkContext)) shuffled.count() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkWithCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkWithCodegen.scala deleted file mode 100644 index 51331500479a3..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkWithCodegen.scala +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution.benchmark - -import org.apache.spark.SparkFunSuite -import org.apache.spark.benchmark.Benchmark -import org.apache.spark.sql.SparkSession - -/** - * Common base trait for micro benchmarks that are supposed to run standalone (i.e. not together - * with other test suites). - */ -private[benchmark] trait BenchmarkWithCodegen extends SparkFunSuite { - - lazy val sparkSession = SparkSession.builder - .master("local[1]") - .appName("microbenchmark") - .config("spark.sql.shuffle.partitions", 1) - .config("spark.sql.autoBroadcastJoinThreshold", 1) - .getOrCreate() - - /** Runs function `f` with whole stage codegen on and off. */ - def runBenchmark(name: String, cardinality: Long)(f: => Unit): Unit = { - val benchmark = new Benchmark(name, cardinality) - - benchmark.addCase(s"$name wholestage off", numIters = 2) { iter => - sparkSession.conf.set("spark.sql.codegen.wholeStage", value = false) - f - } - - benchmark.addCase(s"$name wholestage on", numIters = 5) { iter => - sparkSession.conf.set("spark.sql.codegen.wholeStage", value = true) - f - } - - benchmark.run() - } - -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala index a1f51f8e54805..ecd9ead0ae39a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala @@ -447,7 +447,9 @@ object DataSourceReadBenchmark extends BenchmarkBase with SQLHelper { } def stringWithNullsScanBenchmark(values: Int, fractionOfNulls: Double): Unit = { - val benchmark = new Benchmark("String with Nulls Scan", values, output = output) + val percentageOfNulls = fractionOfNulls * 100 + val benchmark = + new Benchmark(s"String with Nulls Scan ($percentageOfNulls%)", values, output = output) withTempPath { dir => withTempTable("t1", "csvTable", "jsonTable", "parquetTable", "orcTable") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/WideTableBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/WideTableBenchmark.scala index ffefef1d4fce3..52426d81bd1a7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/WideTableBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/WideTableBenchmark.scala @@ -42,7 +42,7 @@ object WideTableBenchmark extends SqlBasedBenchmark { Seq("10", "100", "1024", "2048", "4096", "8192", "65536").foreach { n => benchmark.addCase(s"split threshold $n", numIters = 5) { iter => withSQLConf(SQLConf.CODEGEN_METHOD_SPLIT_THRESHOLD.key -> n) { - df.selectExpr(columns: _*).foreach(identity(_)) + df.selectExpr(columns: _*).foreach(_ => ()) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala index d4e7e362c6c8c..3121b7e99c99d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala @@ -39,7 +39,7 @@ class ColumnStatsSuite extends SparkFunSuite { val columnStatsName = columnStatsClass.getSimpleName test(s"$columnStatsName: empty") { - val columnStats = columnStatsClass.newInstance() + val columnStats = 
columnStatsClass.getConstructor().newInstance() columnStats.collectedStatistics.zip(initialStatistics).foreach { case (actual, expected) => assert(actual === expected) } @@ -48,7 +48,7 @@ class ColumnStatsSuite extends SparkFunSuite { test(s"$columnStatsName: non-empty") { import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ - val columnStats = columnStatsClass.newInstance() + val columnStats = columnStatsClass.getConstructor().newInstance() val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1)) rows.foreach(columnStats.gatherStats(_, 0)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index ab38d5294fa3d..d35268580ba38 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -506,7 +506,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { case plan: InMemoryRelation => plan }.head // InMemoryRelation's stats is file size before the underlying RDD is materialized - assert(inMemoryRelation.computeStats().sizeInBytes === 848) + assert(inMemoryRelation.computeStats().sizeInBytes === 916) // InMemoryRelation's stats is updated after materializing RDD dfFromFile.collect() @@ -519,7 +519,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { // Even CBO enabled, InMemoryRelation's stats keeps as the file size before table's stats // is calculated - assert(inMemoryRelation2.computeStats().sizeInBytes === 848) + assert(inMemoryRelation2.computeStats().sizeInBytes === 916) // InMemoryRelation's stats should be updated after calculating stats of the table // clear cache to simulate a fresh environment diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala index 0d9f1fb0c02c9..fb3388452e4e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala @@ -46,6 +46,7 @@ class IntegralDeltaSuite extends SparkFunSuite { (input.tail, input.init).zipped.map { case (x: Int, y: Int) => (x - y).toLong case (x: Long, y: Long) => x - y + case other => fail(s"Unexpected input $other") } } @@ -116,7 +117,7 @@ class IntegralDeltaSuite extends SparkFunSuite { val row = new GenericInternalRow(1) val nullRow = new GenericInternalRow(1) nullRow.setNullAt(0) - input.map { value => + input.foreach { value => if (value == nullValue) { builder.appendFrom(nullRow, 0) } else { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index 0fa59c8d482cd..5ee66c0cccb33 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -629,7 +629,7 @@ class TestFileFormat extends TextBasedFileFormat { job: Job, options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { - throw new 
NotImplementedError("JUST FOR TESTING") + throw new UnsupportedOperationException("JUST FOR TESTING") } override def buildReader( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala index f62c544428c0c..1229d31f4cc07 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala @@ -45,7 +45,7 @@ class HadoopFsRelationSuite extends QueryTest with SharedSQLContext { import testImplicits._ Seq(1.0, 0.5).foreach { compressionFactor => withSQLConf("spark.sql.sources.fileCompressionFactor" -> compressionFactor.toString, - "spark.sql.autoBroadcastJoinThreshold" -> "424") { + "spark.sql.autoBroadcastJoinThreshold" -> "458") { withTempPath { workDir => // the file size is 740 bytes val workDirPath = workDir.getAbsolutePath diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index d43efc8776b00..c275d63d32cc8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.datasources.csv import java.io.File -import java.nio.charset.{Charset, UnsupportedCharsetException} +import java.nio.charset.{Charset, StandardCharsets, UnsupportedCharsetException} import java.nio.file.Files import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat @@ -33,7 +33,7 @@ import org.apache.hadoop.io.compress.GzipCodec import org.apache.log4j.{AppenderSkeleton, LogManager} import org.apache.log4j.spi.LoggingEvent -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, TestUtils} import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.internal.SQLConf @@ -1848,4 +1848,142 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te val schema = new StructType().add("a", StringType).add("b", IntegerType) checkAnswer(spark.read.schema(schema).option("delimiter", delimiter).csv(input), Row("abc", 1)) } + + test("using spark.sql.columnNameOfCorruptRecord") { + withSQLConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD.key -> "_unparsed") { + val csv = "\"" + val df = spark.read + .schema("a int, _unparsed string") + .csv(Seq(csv).toDS()) + + checkAnswer(df, Row(null, csv)) + } + } + + test("encoding in multiLine mode") { + val df = spark.range(3).toDF() + Seq("UTF-8", "ISO-8859-1", "CP1251", "US-ASCII", "UTF-16BE", "UTF-32LE").foreach { encoding => + Seq(true, false).foreach { header => + withTempPath { path => + df.write + .option("encoding", encoding) + .option("header", header) + .csv(path.getCanonicalPath) + val readback = spark.read + .option("multiLine", true) + .option("encoding", encoding) + .option("inferSchema", true) + .option("header", header) + .csv(path.getCanonicalPath) + checkAnswer(readback, df) + } + } + } + } + + test("""Support line separator - default value \r, \r\n and \n""") { + val data = "\"a\",1\r\"c\",2\r\n\"d\",3\n" + + withTempPath { path => + Files.write(path.toPath, data.getBytes(StandardCharsets.UTF_8)) + val df = 
spark.read.option("inferSchema", true).csv(path.getAbsolutePath) + val expectedSchema = + StructType(StructField("_c0", StringType) :: StructField("_c1", IntegerType) :: Nil) + checkAnswer(df, Seq(("a", 1), ("c", 2), ("d", 3)).toDF()) + assert(df.schema === expectedSchema) + } + } + + def testLineSeparator(lineSep: String, encoding: String, inferSchema: Boolean, id: Int): Unit = { + test(s"Support line separator in ${encoding} #${id}") { + // Read + val data = + s""""a",1$lineSep + |c,2$lineSep" + |d",3""".stripMargin + val dataWithTrailingLineSep = s"$data$lineSep" + + Seq(data, dataWithTrailingLineSep).foreach { lines => + withTempPath { path => + Files.write(path.toPath, lines.getBytes(encoding)) + val schema = StructType(StructField("_c0", StringType) + :: StructField("_c1", LongType) :: Nil) + + val expected = Seq(("a", 1), ("\nc", 2), ("\nd", 3)) + .toDF("_c0", "_c1") + Seq(false, true).foreach { multiLine => + val reader = spark + .read + .option("lineSep", lineSep) + .option("multiLine", multiLine) + .option("encoding", encoding) + val df = if (inferSchema) { + reader.option("inferSchema", true).csv(path.getAbsolutePath) + } else { + reader.schema(schema).csv(path.getAbsolutePath) + } + checkAnswer(df, expected) + } + } + } + + // Write + withTempPath { path => + Seq("a", "b", "c").toDF("value").coalesce(1) + .write + .option("lineSep", lineSep) + .option("encoding", encoding) + .csv(path.getAbsolutePath) + val partFile = TestUtils.recursiveList(path).filter(f => f.getName.startsWith("part-")).head + val readBack = new String(Files.readAllBytes(partFile.toPath), encoding) + assert( + readBack === s"a${lineSep}b${lineSep}c${lineSep}") + } + + // Roundtrip + withTempPath { path => + val df = Seq("a", "b", "c").toDF() + df.write + .option("lineSep", lineSep) + .option("encoding", encoding) + .csv(path.getAbsolutePath) + val readBack = spark + .read + .option("lineSep", lineSep) + .option("encoding", encoding) + .csv(path.getAbsolutePath) + checkAnswer(df, readBack) + } + } + } + + // scalastyle:off nonascii + List( + (0, "|", "UTF-8", false), + (1, "^", "UTF-16BE", true), + (2, ":", "ISO-8859-1", true), + (3, "!", "UTF-32LE", false), + (4, 0x1E.toChar.toString, "UTF-8", true), + (5, "아", "UTF-32BE", false), + (6, "у", "CP1251", true), + (8, "\r", "UTF-16LE", true), + (9, "\u000d", "UTF-32BE", false), + (10, "=", "US-ASCII", false), + (11, "$", "utf-32le", true) + ).foreach { case (testNum, sep, encoding, inferSchema) => + testLineSeparator(sep, encoding, inferSchema, testNum) + } + // scalastyle:on nonascii + + test("lineSep restrictions") { + val errMsg1 = intercept[IllegalArgumentException] { + spark.read.option("lineSep", "").csv(testFile(carsFile)).collect + }.getMessage + assert(errMsg1.contains("'lineSep' cannot be an empty string")) + + val errMsg2 = intercept[IllegalArgumentException] { + spark.read.option("lineSep", "123").csv(testFile(carsFile)).collect + }.getMessage + assert(errMsg2.contains("'lineSep' can contain only 1 character")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 06032ded42a53..9ea9189cdf7f4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -1115,6 +1115,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { Row(null, null, null), 
Row(null, null, null), Row(null, null, null), + Row(null, null, null), Row("str_a_4", "str_b_4", "str_c_4"), Row(null, null, null)) ) @@ -1136,6 +1137,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { checkAnswer( jsonDF.select($"a", $"b", $"c", $"_unparsed"), Row(null, null, null, "{") :: + Row(null, null, null, "") :: Row(null, null, null, """{"a":1, b:2}""") :: Row(null, null, null, """{"a":{, b:3}""") :: Row("str_a_4", "str_b_4", "str_c_4", null) :: @@ -1150,6 +1152,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { checkAnswer( jsonDF.filter($"_unparsed".isNotNull).select($"_unparsed"), Row("{") :: + Row("") :: Row("""{"a":1, b:2}""") :: Row("""{"a":{, b:3}""") :: Row("]") :: Nil @@ -1171,6 +1174,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { checkAnswer( jsonDF.selectExpr("a", "b", "c", "_malformed"), Row(null, null, null, "{") :: + Row(null, null, null, "") :: Row(null, null, null, """{"a":1, b:2}""") :: Row(null, null, null, """{"a":{, b:3}""") :: Row("str_a_4", "str_b_4", "str_c_4", null) :: @@ -1813,6 +1817,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val path = dir.getCanonicalPath primitiveFieldAndType .toDF("value") + .repartition(1) .write .option("compression", "GzIp") .text(path) @@ -1838,6 +1843,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val path = dir.getCanonicalPath primitiveFieldAndType .toDF("value") + .repartition(1) .write .text(path) @@ -1892,7 +1898,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { .text(path) val jsonDF = spark.read.option("multiLine", true).option("mode", "PERMISSIVE").json(path) - assert(jsonDF.count() === corruptRecordCount) + assert(jsonDF.count() === corruptRecordCount + 1) // null row for empty file assert(jsonDF.schema === new StructType() .add("_corrupt_record", StringType) .add("dummy", StringType)) @@ -1905,7 +1911,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { F.count($"dummy").as("valid"), F.count($"_corrupt_record").as("corrupt"), F.count("*").as("count")) - checkAnswer(counts, Row(1, 4, 6)) + checkAnswer(counts, Row(1, 5, 7)) // null row for empty file } } @@ -2513,7 +2519,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } checkCount(2) - countForMalformedJSON(0, Seq("")) + countForMalformedJSON(1, Seq("")) } test("SPARK-25040: empty strings should be disallowed") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala index dc81c0585bf18..48910103e702a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.orc import java.io.File +import java.nio.charset.StandardCharsets.UTF_8 import java.sql.Timestamp import java.util.Locale @@ -30,7 +31,8 @@ import org.apache.orc.OrcProto.Stream.Kind import org.apache.orc.impl.RecordReaderImpl import org.scalatest.BeforeAndAfterAll -import org.apache.spark.sql.Row +import org.apache.spark.SPARK_VERSION_SHORT +import org.apache.spark.sql.{Row, SPARK_VERSION_METADATA_KEY} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import 
org.apache.spark.util.Utils @@ -314,6 +316,22 @@ abstract class OrcSuite extends OrcTest with BeforeAndAfterAll { checkAnswer(spark.read.orc(path.getCanonicalPath), Row(ts)) } } + + test("Write Spark version into ORC file metadata") { + withTempPath { path => + spark.range(1).repartition(1).write.orc(path.getCanonicalPath) + + val partFiles = path.listFiles() + .filter(f => f.isFile && !f.getName.startsWith(".") && !f.getName.startsWith("_")) + assert(partFiles.length === 1) + + val orcFilePath = new Path(partFiles.head.getAbsolutePath) + val readerOptions = OrcFile.readerOptions(new Configuration()) + val reader = OrcFile.createReader(orcFilePath, readerOptions) + val version = UTF_8.decode(reader.getMetadataValue(SPARK_VERSION_METADATA_KEY)).toString + assert(version === SPARK_VERSION_SHORT) + } + } } class OrcSourceSuite extends OrcSuite with SharedSQLContext { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 002c42f23bd64..6b05b9c0f7207 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -27,6 +27,7 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} +import org.apache.parquet.HadoopReadOptions import org.apache.parquet.column.{Encoding, ParquetProperties} import org.apache.parquet.example.data.{Group, GroupWriter} import org.apache.parquet.example.data.simple.SimpleGroup @@ -34,10 +35,11 @@ import org.apache.parquet.hadoop._ import org.apache.parquet.hadoop.api.WriteSupport import org.apache.parquet.hadoop.api.WriteSupport.WriteContext import org.apache.parquet.hadoop.metadata.CompressionCodecName +import org.apache.parquet.hadoop.util.HadoopInputFile import org.apache.parquet.io.api.RecordConsumer import org.apache.parquet.schema.{MessageType, MessageTypeParser} -import org.apache.spark.SparkException +import org.apache.spark.{SPARK_VERSION_SHORT, SparkException} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection} import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow} @@ -799,6 +801,23 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { checkAnswer(spark.read.parquet(file.getAbsolutePath), Seq(Row(Row(1, null, "foo")))) } } + + test("Write Spark version into Parquet metadata") { + withTempPath { dir => + val path = dir.getAbsolutePath + spark.range(1).repartition(1).write.parquet(path) + val file = SpecificParquetRecordReaderBase.listDirectory(dir).get(0) + + val conf = new Configuration() + val hadoopInputFile = HadoopInputFile.fromPath(new Path(file), conf) + val parquetReadOptions = HadoopReadOptions.builder(conf).build() + val m = ParquetFileReader.open(hadoopInputFile, parquetReadOptions) + val metaData = m.getFileMetaData.getKeyValueMetaData + m.close() + + assert(metaData.get(SPARK_VERSION_METADATA_KEY) === SPARK_VERSION_SHORT) + } + } } class JobCommitFailureParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index b955c157a620e..0f1d08b6af5d5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -94,8 +94,13 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared "avg hash probe (min, med, max)" -> "\n(1, 1, 1)"), Map("number of output rows" -> 1L, "avg hash probe (min, med, max)" -> "\n(1, 1, 1)")) + val shuffleExpected1 = Map( + "records read" -> 2L, + "local blocks fetched" -> 2L, + "remote blocks fetched" -> 0L) testSparkPlanMetrics(df, 1, Map( 2L -> (("HashAggregate", expected1(0))), + 1L -> (("Exchange", shuffleExpected1)), 0L -> (("HashAggregate", expected1(1)))) ) @@ -106,8 +111,13 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared "avg hash probe (min, med, max)" -> "\n(1, 1, 1)"), Map("number of output rows" -> 3L, "avg hash probe (min, med, max)" -> "\n(1, 1, 1)")) + val shuffleExpected2 = Map( + "records read" -> 4L, + "local blocks fetched" -> 4L, + "remote blocks fetched" -> 0L) testSparkPlanMetrics(df2, 1, Map( 2L -> (("HashAggregate", expected2(0))), + 1L -> (("Exchange", shuffleExpected2)), 0L -> (("HashAggregate", expected2(1)))) ) } @@ -191,7 +201,11 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared testSparkPlanMetrics(df, 1, Map( 0L -> (("SortMergeJoin", Map( // It's 4 because we only read 3 rows in the first partition and 1 row in the second one - "number of output rows" -> 4L)))) + "number of output rows" -> 4L))), + 2L -> (("Exchange", Map( + "records read" -> 4L, + "local blocks fetched" -> 2L, + "remote blocks fetched" -> 0L)))) ) } } @@ -208,7 +222,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared "SELECT * FROM testData2 left JOIN testDataForJoin ON testData2.a = testDataForJoin.a") testSparkPlanMetrics(df, 1, Map( 0L -> (("SortMergeJoin", Map( - // It's 4 because we only read 3 rows in the first partition and 1 row in the second one + // It's 8 because we read 6 rows on the left side and 2 rows on the right side "number of output rows" -> 8L)))) ) @@ -216,7 +230,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared "SELECT * FROM testDataForJoin right JOIN testData2 ON testData2.a = testDataForJoin.a") testSparkPlanMetrics(df2, 1, Map( 0L -> (("SortMergeJoin", Map( - // It's 4 because we only read 3 rows in the first partition and 1 row in the second one + // It's 8 because we read 6 rows on the left side and 2 rows on the right side "number of output rows" -> 8L)))) ) } @@ -287,7 +301,6 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared // Assume the execution plan is // ...
-> ShuffledHashJoin(nodeId = 1) -> Project(nodeId = 0) val df = df1.join(df2, "key") - val metrics = getSparkPlanMetrics(df, 1, Set(1L)) testSparkPlanMetrics(df, 1, Map( 1L -> (("ShuffledHashJoin", Map( "number of output rows" -> 2L, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala index 80c76915e4c23..2d338ab92211e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala @@ -19,9 +19,6 @@ package org.apache.spark.sql.execution.streaming import java.util.concurrent.ConcurrentHashMap -import scala.collection.mutable - -import org.eclipse.jetty.util.ConcurrentHashSet import org.scalatest.concurrent.{Eventually, Signaler, ThreadSignaler, TimeLimits} import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.time.SpanSugar._ @@ -48,7 +45,7 @@ class ProcessingTimeExecutorSuite extends SparkFunSuite with TimeLimits { } test("trigger timing") { - val triggerTimes = new ConcurrentHashSet[Int] + val triggerTimes = ConcurrentHashMap.newKeySet[Int]() val clock = new StreamManualClock() @volatile var continueExecuting = true @volatile var clockIncrementInTrigger = 0L diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala index dd74af873c2e5..be3efed714030 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala @@ -53,7 +53,8 @@ class RateSourceSuite extends StreamTest { test("microbatch in registry") { withTempDir { temp => - DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { + DataSource.lookupDataSource("rate", spark.sqlContext.conf). + getConstructor().newInstance() match { case ds: MicroBatchReadSupportProvider => val readSupport = ds.createMicroBatchReadSupport( temp.getCanonicalPath, DataSourceOptions.empty()) @@ -66,7 +67,7 @@ class RateSourceSuite extends StreamTest { test("compatible with old path in registry") { DataSource.lookupDataSource("org.apache.spark.sql.execution.streaming.RateSourceProvider", - spark.sqlContext.conf).newInstance() match { + spark.sqlContext.conf).getConstructor().newInstance() match { case ds: MicroBatchReadSupportProvider => assert(ds.isInstanceOf[RateStreamProvider]) case _ => @@ -320,7 +321,8 @@ class RateSourceSuite extends StreamTest { } test("continuous in registry") { - DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { + DataSource.lookupDataSource("rate", spark.sqlContext.conf). 
+ getConstructor().newInstance() match { case ds: ContinuousReadSupportProvider => val readSupport = ds.createContinuousReadSupport( "", DataSourceOptions.empty()) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala index 409156e5ebc70..7db31f1f8f699 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala @@ -84,7 +84,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before test("backward compatibility with old path") { DataSource.lookupDataSource("org.apache.spark.sql.execution.streaming.TextSocketSourceProvider", - spark.sqlContext.conf).newInstance() match { + spark.sqlContext.conf).getConstructor().newInstance() match { case ds: MicroBatchReadSupportProvider => assert(ds.isInstanceOf[TextSocketSourceProvider]) case _ => @@ -382,10 +382,9 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before tasks.foreach { case t: TextSocketContinuousInputPartition => val r = readerFactory.createReader(t).asInstanceOf[TextSocketContinuousPartitionReader] - for (i <- 0 until numRecords / 2) { + for (_ <- 0 until numRecords / 2) { r.next() - assert(r.get().get(0, TextSocketReader.SCHEMA_TIMESTAMP) - .isInstanceOf[(String, Timestamp)]) + assert(r.get().get(0, TextSocketReader.SCHEMA_TIMESTAMP).isInstanceOf[(_, _)]) } case _ => throw new IllegalStateException("Unexpected task type") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index f57f07b498261..e8062dbb91e35 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -1123,7 +1123,7 @@ class ColumnarBatchSuite extends SparkFunSuite { compareStruct(childFields, r1.getStruct(ordinal, fields.length), r2.getStruct(ordinal), seed) case _ => - throw new NotImplementedError("Not implemented " + field.dataType) + throw new UnsupportedOperationException("Not implemented " + field.dataType) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala index 4911e3225552d..f903c17923d0f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala @@ -33,7 +33,7 @@ class DataSourceV2UtilsSuite extends SparkFunSuite { conf.setConfString(s"spark.sql.$keyPrefix.config.name", "false") conf.setConfString("spark.datasource.another.config.name", "123") conf.setConfString(s"spark.datasource.$keyPrefix.", "123") - val cs = classOf[DataSourceV2WithSessionConfig].newInstance() + val cs = classOf[DataSourceV2WithSessionConfig].getConstructor().newInstance() val confs = DataSourceV2Utils.extractSessionConfigs(cs.asInstanceOf[DataSourceV2], conf) assert(confs.size == 2) assert(confs.keySet.filter(_.startsWith("spark.datasource")).size == 0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenersConfSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenersConfSuite.scala index 1aaf8a9aa2d55..ddbc175e7ea48 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenersConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenersConfSuite.scala @@ -30,7 +30,6 @@ class StreamingQueryListenersConfSuite extends StreamTest with BeforeAndAfter { import testImplicits._ - override protected def sparkConf: SparkConf = super.sparkConf.set("spark.sql.streaming.streamingQueryListeners", "org.apache.spark.sql.streaming.TestListener") @@ -41,6 +40,8 @@ class StreamingQueryListenersConfSuite extends StreamTest with BeforeAndAfter { StopStream ) + spark.sparkContext.listenerBus.waitUntilEmpty(5000) + assert(TestListener.queryStartedEvent != null) assert(TestListener.queryTerminatedEvent != null) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala index 3a0e780a73915..31fce46c2daba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -261,7 +261,7 @@ class StreamingDataSourceV2Suite extends StreamTest { ).foreach { case (source, trigger) => test(s"SPARK-25460: session options are respected in structured streaming sources - $source") { // `keyPrefix` and `shortName` are the same in this test case - val readSource = source.newInstance().shortName() + val readSource = source.getConstructor().newInstance().shortName() val writeSource = "fake-write-microbatch-continuous" val readOptionName = "optionA" @@ -299,8 +299,10 @@ class StreamingDataSourceV2Suite extends StreamTest { for ((read, write, trigger) <- cases) { testQuietly(s"stream with read format $read, write format $write, trigger $trigger") { - val readSource = DataSource.lookupDataSource(read, spark.sqlContext.conf).newInstance() - val writeSource = DataSource.lookupDataSource(write, spark.sqlContext.conf).newInstance() + val readSource = DataSource.lookupDataSource(read, spark.sqlContext.conf). + getConstructor().newInstance() + val writeSource = DataSource.lookupDataSource(write, spark.sqlContext.conf). + getConstructor().newInstance() (readSource, writeSource, trigger) match { // Valid microbatch queries. 
case (_: MicroBatchReadSupportProvider, _: StreamingWriteSupportProvider, t) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala index e8710aeb40bd4..ddc5dbb148cb5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala @@ -150,7 +150,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { def getPeakExecutionMemory(stageId: Int): Long = { val peakMemoryAccumulator = sparkListener.getCompletedStageInfos(stageId).accumulables - .filter(_._2.name == InternalAccumulator.PEAK_EXECUTION_MEMORY) + .filter(_._2.name == Some(InternalAccumulator.PEAK_EXECUTION_MEMORY)) assert(peakMemoryAccumulator.size == 1) peakMemoryAccumulator.head._2.value.get.asInstanceOf[Long] diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index 55e051c3ed1be..4a4629fae2706 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-hive-thriftserver_2.11 + spark-hive-thriftserver_2.12 jar Spark Project Hive Thrift Server http://spark.apache.org/ diff --git a/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TColumn.java b/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TColumn.java index bfe50c7810f73..fc2171dc99e4c 100644 --- a/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TColumn.java +++ b/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TColumn.java @@ -148,7 +148,7 @@ public TColumn() { super(); } - public TColumn(_Fields setField, Object value) { + public TColumn(TColumn._Fields setField, Object value) { super(setField, value); } diff --git a/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TColumnValue.java b/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TColumnValue.java index 44da2cdd089d6..8504c6d608d42 100644 --- a/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TColumnValue.java +++ b/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TColumnValue.java @@ -142,7 +142,7 @@ public TColumnValue() { super(); } - public TColumnValue(_Fields setField, Object value) { + public TColumnValue(TColumnValue._Fields setField, Object value) { super(setField, value); } diff --git a/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TGetInfoValue.java b/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TGetInfoValue.java index 4fe59b1c51462..fe2a211c46309 100644 --- a/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TGetInfoValue.java +++ b/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TGetInfoValue.java @@ -136,7 +136,7 @@ public TGetInfoValue() { super(); } - public TGetInfoValue(_Fields setField, Object value) { + public TGetInfoValue(TGetInfoValue._Fields setField, Object value) { super(setField, value); } diff --git a/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TTypeEntry.java b/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TTypeEntry.java index af7c0b4f15d95..d0d70c1279572 100644 --- a/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TTypeEntry.java +++ 
b/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TTypeEntry.java @@ -136,7 +136,7 @@ public TTypeEntry() { super(); } - public TTypeEntry(_Fields setField, Object value) { + public TTypeEntry(TTypeEntry._Fields setField, Object value) { super(setField, value); } diff --git a/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TTypeQualifierValue.java b/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TTypeQualifierValue.java index 8c40687a0aab7..a3e3829372276 100644 --- a/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TTypeQualifierValue.java +++ b/sql/hive-thriftserver/src/gen/java/org/apache/hive/service/cli/thrift/TTypeQualifierValue.java @@ -112,7 +112,7 @@ public TTypeQualifierValue() { super(); } - public TTypeQualifierValue(_Fields setField, Object value) { + public TTypeQualifierValue(TTypeQualifierValue._Fields setField, Object value) { super(setField, value); } diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/AbstractService.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/AbstractService.java index 9dd0efc03968d..7e557aeccf5b0 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/AbstractService.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/AbstractService.java @@ -36,7 +36,7 @@ public abstract class AbstractService implements Service { /** * Service state: initially {@link STATE#NOTINITED}. */ - private STATE state = STATE.NOTINITED; + private Service.STATE state = STATE.NOTINITED; /** * Service name. @@ -70,7 +70,7 @@ public AbstractService(String name) { } @Override - public synchronized STATE getServiceState() { + public synchronized Service.STATE getServiceState() { return state; } @@ -159,7 +159,7 @@ public long getStartTime() { * if the service state is different from * the desired state */ - private void ensureCurrentState(STATE currentState) { + private void ensureCurrentState(Service.STATE currentState) { ServiceOperations.ensureCurrentState(state, currentState); } @@ -173,7 +173,7 @@ private void ensureCurrentState(STATE currentState) { * @param newState * new service state */ - private void changeState(STATE newState) { + private void changeState(Service.STATE newState) { state = newState; // notify listeners for (ServiceStateChangeListener l : listeners) { diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/FilterService.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/FilterService.java index 5a508745414a7..15551da4785f6 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/FilterService.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/FilterService.java @@ -71,7 +71,7 @@ public HiveConf getHiveConf() { } @Override - public STATE getServiceState() { + public Service.STATE getServiceState() { return service.getServiceState(); } diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/Column.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/Column.java index adb269aa235ea..26d0f718f383a 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/Column.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/Column.java @@ -349,7 +349,7 @@ public void addValue(Type type, Object field) { break; case FLOAT_TYPE: nulls.set(size, field == null); - doubleVars()[size] = field == null ? 
0 : new Double(field.toString()); + doubleVars()[size] = field == null ? 0 : Double.valueOf(field.toString()); break; case DOUBLE_TYPE: nulls.set(size, field == null); diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveCliSessionStateSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveCliSessionStateSuite.scala index 5f9ea4d26790b..035b71a37a692 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveCliSessionStateSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveCliSessionStateSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.hive.HiveUtils class HiveCliSessionStateSuite extends SparkFunSuite { def withSessionClear(f: () => Unit): Unit = { - try f finally SessionState.detachSession() + try f() finally SessionState.detachSession() } test("CliSessionState will be reused") { diff --git a/sql/hive/benchmarks/OrcReadBenchmark-results.txt b/sql/hive/benchmarks/OrcReadBenchmark-results.txt index c77f966723d71..80c2f5e93405a 100644 --- a/sql/hive/benchmarks/OrcReadBenchmark-results.txt +++ b/sql/hive/benchmarks/OrcReadBenchmark-results.txt @@ -2,172 +2,172 @@ SQL Single Numeric Column Scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz SQL Single TINYINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Native ORC MR 1630 / 1639 9.7 103.6 1.0X -Native ORC Vectorized 253 / 288 62.2 16.1 6.4X -Native ORC Vectorized with copy 227 / 244 69.2 14.5 7.2X -Hive built-in ORC 1980 / 1991 7.9 125.9 0.8X +Native ORC MR 1725 / 1759 9.1 109.7 1.0X +Native ORC Vectorized 272 / 316 57.8 17.3 6.3X +Native ORC Vectorized with copy 239 / 254 65.7 15.2 7.2X +Hive built-in ORC 1970 / 1987 8.0 125.3 0.9X -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz SQL Single SMALLINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Native ORC MR 1587 / 1589 9.9 100.9 1.0X -Native ORC Vectorized 227 / 242 69.2 14.5 7.0X -Native ORC Vectorized with copy 228 / 238 69.0 14.5 7.0X -Hive built-in ORC 2323 / 2332 6.8 147.7 0.7X +Native ORC MR 1633 / 1672 9.6 103.8 1.0X +Native ORC Vectorized 238 / 255 66.0 15.1 6.9X +Native ORC Vectorized with copy 235 / 253 66.8 15.0 6.9X +Hive built-in ORC 2293 / 2305 6.9 145.8 0.7X -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz SQL Single INT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Native ORC MR 1726 / 1771 9.1 109.7 1.0X -Native ORC Vectorized 309 / 333 50.9 19.7 5.6X -Native ORC Vectorized with copy 313 / 321 50.2 19.9 5.5X -Hive built-in ORC 2668 / 2672 5.9 169.6 0.6X +Native ORC MR 1677 / 1699 9.4 106.6 1.0X +Native ORC Vectorized 325 / 342 48.3 20.7 5.2X +Native ORC 
Vectorized with copy 328 / 341 47.9 20.9 5.1X +Hive built-in ORC 2561 / 2569 6.1 162.8 0.7X -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz SQL Single BIGINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Native ORC MR 1722 / 1747 9.1 109.5 1.0X -Native ORC Vectorized 395 / 403 39.8 25.1 4.4X -Native ORC Vectorized with copy 399 / 405 39.4 25.4 4.3X -Hive built-in ORC 2767 / 2777 5.7 175.9 0.6X +Native ORC MR 1791 / 1795 8.8 113.9 1.0X +Native ORC Vectorized 400 / 408 39.3 25.4 4.5X +Native ORC Vectorized with copy 410 / 417 38.4 26.1 4.4X +Hive built-in ORC 2713 / 2720 5.8 172.5 0.7X -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz SQL Single FLOAT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Native ORC MR 1797 / 1824 8.8 114.2 1.0X -Native ORC Vectorized 434 / 441 36.2 27.6 4.1X -Native ORC Vectorized with copy 437 / 447 36.0 27.8 4.1X -Hive built-in ORC 2701 / 2710 5.8 171.7 0.7X +Native ORC MR 1791 / 1805 8.8 113.8 1.0X +Native ORC Vectorized 433 / 438 36.3 27.5 4.1X +Native ORC Vectorized with copy 441 / 447 35.7 28.0 4.1X +Hive built-in ORC 2690 / 2803 5.8 171.0 0.7X -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz SQL Single DOUBLE Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Native ORC MR 1931 / 2028 8.1 122.8 1.0X -Native ORC Vectorized 542 / 557 29.0 34.5 3.6X -Native ORC Vectorized with copy 550 / 564 28.6 35.0 3.5X -Hive built-in ORC 2816 / 3206 5.6 179.1 0.7X +Native ORC MR 1911 / 1930 8.2 121.5 1.0X +Native ORC Vectorized 543 / 552 29.0 34.5 3.5X +Native ORC Vectorized with copy 547 / 555 28.8 34.8 3.5X +Hive built-in ORC 2967 / 3065 5.3 188.6 0.6X ================================================================================================ Int and String Scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Int and String Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Native ORC MR 4012 / 4068 2.6 382.6 1.0X -Native ORC Vectorized 2337 / 2339 4.5 222.9 1.7X -Native ORC Vectorized with copy 2520 / 2540 4.2 240.3 1.6X -Hive built-in ORC 5503 / 5575 1.9 524.8 0.7X +Native ORC MR 4160 / 4188 2.5 396.7 1.0X +Native ORC Vectorized 2405 / 2406 4.4 229.4 1.7X +Native ORC Vectorized with copy 2588 / 2592 4.1 246.8 1.6X +Hive built-in ORC 5514 / 5562 1.9 525.9 0.8X ================================================================================================ Partitioned Table Scan ================================================================================================ 
-OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Partitioned Table: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Data column - Native ORC MR 2020 / 2025 7.8 128.4 1.0X -Data column - Native ORC Vectorized 398 / 409 39.5 25.3 5.1X -Data column - Native ORC Vectorized with copy 406 / 411 38.8 25.8 5.0X -Data column - Hive built-in ORC 2967 / 2969 5.3 188.6 0.7X -Partition column - Native ORC MR 1494 / 1505 10.5 95.0 1.4X -Partition column - Native ORC Vectorized 73 / 82 216.3 4.6 27.8X -Partition column - Native ORC Vectorized with copy 71 / 80 221.4 4.5 28.4X -Partition column - Hive built-in ORC 1932 / 1937 8.1 122.8 1.0X -Both columns - Native ORC MR 2057 / 2071 7.6 130.8 1.0X -Both columns - Native ORC Vectorized 445 / 448 35.4 28.3 4.5X -Both column - Native ORC Vectorized with copy 534 / 539 29.4 34.0 3.8X -Both columns - Hive built-in ORC 2994 / 2994 5.3 190.3 0.7X +Data column - Native ORC MR 1863 / 1867 8.4 118.4 1.0X +Data column - Native ORC Vectorized 411 / 418 38.2 26.2 4.5X +Data column - Native ORC Vectorized with copy 417 / 422 37.8 26.5 4.5X +Data column - Hive built-in ORC 3297 / 3308 4.8 209.6 0.6X +Partition column - Native ORC MR 1505 / 1506 10.4 95.7 1.2X +Partition column - Native ORC Vectorized 80 / 93 195.6 5.1 23.2X +Partition column - Native ORC Vectorized with copy 78 / 86 201.4 5.0 23.9X +Partition column - Hive built-in ORC 1960 / 1979 8.0 124.6 1.0X +Both columns - Native ORC MR 2076 / 2090 7.6 132.0 0.9X +Both columns - Native ORC Vectorized 450 / 463 34.9 28.6 4.1X +Both column - Native ORC Vectorized with copy 532 / 538 29.6 33.8 3.5X +Both columns - Hive built-in ORC 3528 / 3548 4.5 224.3 0.5X ================================================================================================ Repeated String Scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Repeated String: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Native ORC MR 1771 / 1785 5.9 168.9 1.0X -Native ORC Vectorized 372 / 375 28.2 35.5 4.8X -Native ORC Vectorized with copy 543 / 576 19.3 51.8 3.3X -Hive built-in ORC 2671 / 2671 3.9 254.7 0.7X +Native ORC MR 1727 / 1733 6.1 164.7 1.0X +Native ORC Vectorized 375 / 379 28.0 35.7 4.6X +Native ORC Vectorized with copy 552 / 556 19.0 52.6 3.1X +Hive built-in ORC 2665 / 2666 3.9 254.2 0.6X ================================================================================================ String with Nulls Scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz String with Nulls Scan (0.0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Native ORC MR 3276 / 3302 3.2 312.5 1.0X -Native ORC Vectorized 1057 / 1080 9.9 100.8 3.1X -Native ORC 
Vectorized with copy 1420 / 1431 7.4 135.4 2.3X -Hive built-in ORC 5377 / 5407 2.0 512.8 0.6X +Native ORC MR 3324 / 3325 3.2 317.0 1.0X +Native ORC Vectorized 1085 / 1106 9.7 103.4 3.1X +Native ORC Vectorized with copy 1463 / 1471 7.2 139.5 2.3X +Hive built-in ORC 5272 / 5299 2.0 502.8 0.6X -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz -String with Nulls Scan (0.5%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +String with Nulls Scan (50.0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Native ORC MR 3147 / 3147 3.3 300.1 1.0X -Native ORC Vectorized 1305 / 1319 8.0 124.4 2.4X -Native ORC Vectorized with copy 1685 / 1686 6.2 160.7 1.9X -Hive built-in ORC 4077 / 4085 2.6 388.8 0.8X +Native ORC MR 3045 / 3046 3.4 290.4 1.0X +Native ORC Vectorized 1248 / 1260 8.4 119.0 2.4X +Native ORC Vectorized with copy 1609 / 1624 6.5 153.5 1.9X +Hive built-in ORC 3989 / 3999 2.6 380.4 0.8X -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz -String with Nulls Scan (0.95%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +String with Nulls Scan (95.0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Native ORC MR 1739 / 1744 6.0 165.8 1.0X -Native ORC Vectorized 500 / 501 21.0 47.7 3.5X -Native ORC Vectorized with copy 618 / 631 17.0 58.9 2.8X -Hive built-in ORC 2411 / 2427 4.3 229.9 0.7X +Native ORC MR 1692 / 1694 6.2 161.3 1.0X +Native ORC Vectorized 471 / 493 22.3 44.9 3.6X +Native ORC Vectorized with copy 588 / 590 17.8 56.1 2.9X +Hive built-in ORC 2398 / 2411 4.4 228.7 0.7X ================================================================================================ Single Column Scan From Wide Columns ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Single Column Scan from 100 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Native ORC MR 1348 / 1366 0.8 1285.3 1.0X -Native ORC Vectorized 119 / 134 8.8 113.5 11.3X -Native ORC Vectorized with copy 119 / 148 8.8 113.9 11.3X -Hive built-in ORC 487 / 507 2.2 464.8 2.8X +Native ORC MR 1371 / 1379 0.8 1307.5 1.0X +Native ORC Vectorized 121 / 135 8.6 115.8 11.3X +Native ORC Vectorized with copy 122 / 138 8.6 116.2 11.3X +Hive built-in ORC 521 / 561 2.0 497.1 2.6X -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Single Column Scan from 200 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Native ORC MR 2667 / 2837 0.4 2543.6 1.0X -Native ORC Vectorized 203 / 222 5.2 193.4 13.2X -Native ORC Vectorized with copy 217 / 255 4.8 207.0 12.3X -Hive built-in ORC 737 / 741 1.4 702.4 3.6X +Native 
ORC MR 2711 / 2767 0.4 2585.5 1.0X +Native ORC Vectorized 210 / 232 5.0 200.5 12.9X +Native ORC Vectorized with copy 208 / 219 5.0 198.4 13.0X +Hive built-in ORC 764 / 775 1.4 728.3 3.5X -OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64 +OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Single Column Scan from 300 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Native ORC MR 3954 / 3956 0.3 3770.4 1.0X -Native ORC Vectorized 348 / 360 3.0 331.7 11.4X -Native ORC Vectorized with copy 349 / 359 3.0 333.2 11.3X -Hive built-in ORC 1057 / 1067 1.0 1008.0 3.7X +Native ORC MR 3979 / 3988 0.3 3794.4 1.0X +Native ORC Vectorized 357 / 366 2.9 340.2 11.2X +Native ORC Vectorized with copy 361 / 371 2.9 344.5 11.0X +Hive built-in ORC 1091 / 1095 1.0 1040.5 3.6X diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 114a108ba2752..9642e5e3d035a 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../../pom.xml - spark-hive_2.11 + spark-hive_2.12 jar Spark Project Hive http://spark.apache.org/ diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 758b374c5f94d..66e4329861234 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -124,7 +124,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log val lazyPruningEnabled = sparkSession.sqlContext.conf.manageFilesourcePartitions val tablePath = new Path(relation.tableMeta.location) - val fileFormat = fileFormatClass.newInstance() + val fileFormat = fileFormatClass.getConstructor().newInstance() val result = if (relation.isPartitioned) { val partitionSchema = relation.tableMeta.partitionSchema diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 2882672f327c4..4f3914740ec20 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.{Experimental, Unstable} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.Analyzer import org.apache.spark.sql.catalyst.catalog.ExternalCatalogWithListener @@ -32,7 +32,7 @@ import org.apache.spark.sql.internal.{BaseSessionStateBuilder, SessionResourceLo * Builder that produces a Hive-aware `SessionState`. 
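The `fileFormatClass.getConstructor().newInstance()` change in HiveMetastoreCatalog above is one instance of a pattern applied throughout this diff (HiveShim, TableReader, ScriptTransformationExec, JobGenerator, and the test suites). A minimal, self-contained sketch of the pattern; `ExampleFormat` is invented purely for illustration:

// Hypothetical stand-in for the FileFormat/SerDe classes that get instantiated reflectively.
class ExampleFormat {
  override def toString: String = "ExampleFormat()"
}

object ReflectiveInstantiation {
  def main(args: Array[String]): Unit = {
    val clazz: Class[ExampleFormat] = classOf[ExampleFormat]

    // Old style, deprecated since Java 9; it also rethrows checked exceptions from the
    // constructor without declaring or wrapping them:
    //   val instance = clazz.newInstance()

    // Style this diff standardizes on: resolve the no-arg constructor explicitly, so a
    // failing constructor surfaces as java.lang.reflect.InvocationTargetException.
    val instance: ExampleFormat = clazz.getConstructor().newInstance()
    println(instance)
  }
}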
*/ @Experimental -@InterfaceStability.Unstable +@Unstable class HiveSessionStateBuilder(session: SparkSession, parentState: Option[SessionState] = None) extends BaseSessionStateBuilder(session, parentState) { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala index 11afe1af32809..c9fc3d4a02c4b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala @@ -217,7 +217,7 @@ private[hive] object HiveShim { instance.asInstanceOf[UDFType] } else { val func = Utils.getContextOrSparkClassLoader - .loadClass(functionClassName).newInstance.asInstanceOf[UDFType] + .loadClass(functionClassName).getConstructor().newInstance().asInstanceOf[UDFType] if (!func.isInstanceOf[UDF]) { // We cache the function if it's no the Simple UDF, // as we always have to create new instance for Simple UDF diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index 74f21532b22df..66067704195dd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -62,7 +62,7 @@ private[spark] object HiveUtils extends Logging { val HIVE_METASTORE_VERSION = buildConf("spark.sql.hive.metastore.version") .doc("Version of the Hive metastore. Available options are " + - s"0.12.0 through 2.3.3.") + s"0.12.0 through 2.3.4.") .stringConf .createWithDefault(builtinHiveVersion) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index 9443fbb4330a5..536bc4a3f4ec4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -132,7 +132,7 @@ class HadoopTableReader( val deserializedHadoopRDD = hadoopRDD.mapPartitions { iter => val hconf = broadcastedHadoopConf.value.value - val deserializer = deserializerClass.newInstance() + val deserializer = deserializerClass.getConstructor().newInstance() deserializer.initialize(hconf, localTableDesc.getProperties) HadoopTableReader.fillObject(iter, deserializer, attrsWithIndex, mutableRow, deserializer) } @@ -245,7 +245,7 @@ class HadoopTableReader( val localTableDesc = tableDesc createHadoopRdd(localTableDesc, inputPathStr, ifc).mapPartitions { iter => val hconf = broadcastedHiveConf.value.value - val deserializer = localDeserializer.newInstance() + val deserializer = localDeserializer.getConstructor().newInstance() // SPARK-13709: For SerDes like AvroSerDe, some essential information (e.g. Avro schema // information) may be defined in table properties. Here we should merge table properties // and partition properties before initializing the deserializer. 
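Since the supported Hive metastore range now extends to 2.3.4 (doc string above; the version mapping and built-in client version follow further below), a session pointed at such a metastore would look roughly like this. The application name is made up, and `maven` is only one of the documented sources for the client jars (`builtin` or an explicit classpath also work):

import org.apache.spark.sql.SparkSession

// Sketch only: ask Spark for an isolated Hive 2.3.4 metastore client, resolved from Maven.
val spark = SparkSession.builder()
  .appName("metastore-2.3.4-example")                    // hypothetical app name
  .config("spark.sql.hive.metastore.version", "2.3.4")
  .config("spark.sql.hive.metastore.jars", "maven")
  .enableHiveSupport()
  .getOrCreate()

spark.sql("SHOW DATABASES").show()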
Note that partition @@ -257,7 +257,7 @@ class HadoopTableReader( } deserializer.initialize(hconf, props) // get the table deserializer - val tableSerDe = localTableDesc.getDeserializerClass.newInstance() + val tableSerDe = localTableDesc.getDeserializerClass.getConstructor().newInstance() tableSerDe.initialize(hconf, localTableDesc.getProperties) // fill the non partition key attributes diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index bc9d4cd7f4181..4d484904d2c27 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -987,7 +987,7 @@ private[client] class Shim_v1_2 extends Shim_v1_1 { part: JList[String], deleteData: Boolean, purge: Boolean): Unit = { - val dropOptions = dropOptionsClass.newInstance().asInstanceOf[Object] + val dropOptions = dropOptionsClass.getConstructor().newInstance().asInstanceOf[Object] dropOptionsDeleteData.setBoolean(dropOptions, deleteData) dropOptionsPurge.setBoolean(dropOptions, purge) dropPartitionMethod.invoke(hive, dbName, tableName, part, dropOptions) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index 51f944374d76b..6fcd5f9f096d2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -65,7 +65,7 @@ private[hive] object IsolatedClientLoader extends Logging { case e: RuntimeException if e.getMessage.contains("hadoop") => // If the error message contains hadoop, it is probably because the hadoop // version cannot be resolved. - val fallbackVersion = "2.7.3" + val fallbackVersion = "2.7.4" logWarning(s"Failed to resolve Hadoop artifacts for the version $hadoopVersion. We " + s"will change the hadoop version from $hadoopVersion to $fallbackVersion and try " + "again. Hadoop classes will not be shared between Spark and Hive metastore client. " + @@ -99,7 +99,7 @@ private[hive] object IsolatedClientLoader extends Logging { case "2.0" | "2.0.0" | "2.0.1" => hive.v2_0 case "2.1" | "2.1.0" | "2.1.1" => hive.v2_1 case "2.2" | "2.2.0" => hive.v2_2 - case "2.3" | "2.3.0" | "2.3.1" | "2.3.2" | "2.3.3" => hive.v2_3 + case "2.3" | "2.3.0" | "2.3.1" | "2.3.2" | "2.3.3" | "2.3.4" => hive.v2_3 case version => throw new UnsupportedOperationException(s"Unsupported Hive Metastore version ($version). 
" + s"Please set ${HiveUtils.HIVE_METASTORE_VERSION.key} with a valid version.") @@ -187,7 +187,6 @@ private[hive] class IsolatedClientLoader( name.startsWith("org.slf4j") || name.startsWith("org.apache.log4j") || // log4j1.x name.startsWith("org.apache.logging.log4j") || // log4j2 - name.startsWith("org.apache.derby.") || name.startsWith("org.apache.spark.") || (sharesHadoopClasses && isHadoopClass) || name.startsWith("scala.") || diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala index 25e9886fa6576..e4cf7299d2af6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala @@ -75,7 +75,7 @@ package object client { exclusions = Seq("org.apache.curator:*", "org.pentaho:pentaho-aggdesigner-algorithm")) - case object v2_3 extends HiveVersion("2.3.3", + case object v2_3 extends HiveVersion("2.3.4", exclusions = Seq("org.apache.curator:*", "org.pentaho:pentaho-aggdesigner-algorithm")) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala index aa573b54a2b62..630bea5161f19 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala @@ -96,7 +96,7 @@ case class CreateHiveTableAsSelectCommand( } override def argString: String = { - s"[Database:${tableDesc.database}}, " + + s"[Database:${tableDesc.database}, " + s"TableName: ${tableDesc.identifier.table}, " + s"InsertIntoHiveTable]" } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveFileFormat.scala index 4a7cd6901923b..d8d2a80e0e8b7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveFileFormat.scala @@ -115,7 +115,8 @@ class HiveOutputWriter( private def tableDesc = fileSinkConf.getTableInfo private val serializer = { - val serializer = tableDesc.getDeserializerClass.newInstance().asInstanceOf[Serializer] + val serializer = tableDesc.getDeserializerClass.getConstructor(). + newInstance().asInstanceOf[Serializer] serializer.initialize(jobConf, tableDesc.getProperties) serializer } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala index 92c6632ad7863..fa940fe73bd13 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala @@ -120,7 +120,7 @@ case class HiveTableScanExec( HiveShim.appendReadColumns(hiveConf, neededColumnIDs, output.map(_.name)) - val deserializer = tableDesc.getDeserializerClass.newInstance + val deserializer = tableDesc.getDeserializerClass.getConstructor().newInstance() deserializer.initialize(hiveConf, tableDesc.getProperties) // Specifies types and object inspectors of columns to be scanned. 
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala index 3328400b214fb..7b35a5f920ae9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala @@ -123,7 +123,7 @@ case class ScriptTransformationExec( var scriptOutputWritable: Writable = null val reusedWritableObject: Writable = if (null != outputSerde) { - outputSerde.getSerializedClass().newInstance + outputSerde.getSerializedClass().getConstructor().newInstance() } else { null } @@ -404,7 +404,8 @@ case class HiveScriptIOSchema ( columnTypes: Seq[DataType], serdeProps: Seq[(String, String)]): AbstractSerDe = { - val serde = Utils.classForName(serdeClassName).newInstance.asInstanceOf[AbstractSerDe] + val serde = Utils.classForName(serdeClassName).getConstructor(). + newInstance().asInstanceOf[AbstractSerDe] val columnTypesNames = columnTypes.map(_.toTypeInfo.getTypeName()).mkString(",") @@ -424,7 +425,8 @@ case class HiveScriptIOSchema ( inputStream: InputStream, conf: Configuration): Option[RecordReader] = { recordReaderClass.map { klass => - val instance = Utils.classForName(klass).newInstance().asInstanceOf[RecordReader] + val instance = Utils.classForName(klass).getConstructor(). + newInstance().asInstanceOf[RecordReader] val props = new Properties() // Can not use props.putAll(outputSerdeProps.toMap.asJava) in scala-2.12 // See https://github.com/scala/bug/issues/10418 @@ -436,7 +438,8 @@ case class HiveScriptIOSchema ( def recordWriter(outputStream: OutputStream, conf: Configuration): Option[RecordWriter] = { recordWriterClass.map { klass => - val instance = Utils.classForName(klass).newInstance().asInstanceOf[RecordWriter] + val instance = Utils.classForName(klass).getConstructor(). 
+ newInstance().asInstanceOf[RecordWriter] instance.initialize(outputStream, conf) instance } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index 89e6ea8604974..4e641e34c18d9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -18,9 +18,11 @@ package org.apache.spark.sql.hive.orc import java.net.URI +import java.nio.charset.StandardCharsets.UTF_8 import java.util.Properties import scala.collection.JavaConverters._ +import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} @@ -31,10 +33,12 @@ import org.apache.hadoop.hive.serde2.typeinfo.{StructTypeInfo, TypeInfoUtils} import org.apache.hadoop.io.{NullWritable, Writable} import org.apache.hadoop.mapred.{JobConf, OutputFormat => MapRedOutputFormat, RecordWriter, Reporter} import org.apache.hadoop.mapreduce._ -import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, FileSplit} +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.orc.OrcConf.COMPRESS -import org.apache.spark.TaskContext +import org.apache.spark.{SPARK_VERSION_SHORT, TaskContext} +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SPARK_VERSION_METADATA_KEY import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -274,12 +278,14 @@ private[orc] class OrcOutputWriter( override def close(): Unit = { if (recordWriterInstantiated) { + // Hive 1.2.1 ORC initializes its private `writer` field at the first write. + OrcFileFormat.addSparkVersionMetadata(recordWriter) recordWriter.close(Reporter.NULL) } } } -private[orc] object OrcFileFormat extends HiveInspectors { +private[orc] object OrcFileFormat extends HiveInspectors with Logging { // This constant duplicates `OrcInputFormat.SARG_PUSHDOWN`, which is unfortunately not public. private[orc] val SARG_PUSHDOWN = "sarg.pushdown" @@ -339,4 +345,18 @@ private[orc] object OrcFileFormat extends HiveInspectors { val (sortedIDs, sortedNames) = ids.zip(requestedSchema.fieldNames).sorted.unzip HiveShim.appendReadColumns(conf, sortedIDs, sortedNames) } + + /** + * Add a metadata specifying Spark version. 
+ */ + def addSparkVersionMetadata(recordWriter: RecordWriter[NullWritable, Writable]): Unit = { + try { + val writerField = recordWriter.getClass.getDeclaredField("writer") + writerField.setAccessible(true) + val writer = writerField.get(recordWriter).asInstanceOf[Writer] + writer.addUserMetadata(SPARK_VERSION_METADATA_KEY, UTF_8.encode(SPARK_VERSION_SHORT)) + } catch { + case NonFatal(e) => log.warn(e.toString, e) + } + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 634b3db19ec27..3508affda241a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -33,9 +33,9 @@ import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.internal.Logging -import org.apache.spark.sql.{SparkSession, SQLContext} +import org.apache.spark.sql.{DataFrame, Dataset, SparkSession, SQLContext} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation -import org.apache.spark.sql.catalyst.catalog.{ExternalCatalog, ExternalCatalogWithListener} +import org.apache.spark.sql.catalyst.catalog.ExternalCatalogWithListener import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation} import org.apache.spark.sql.execution.{QueryExecution, SQLExecution} @@ -219,6 +219,16 @@ private[hive] class TestHiveSparkSession( sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client.newSession() } + /** + * This is a temporary hack to override SparkSession.sql so we can still use the version of + * Dataset.ofRows that creates a TestHiveQueryExecution (rather than a normal QueryExecution + * which wouldn't load all the test tables). + */ + override def sql(sqlText: String): DataFrame = { + val plan = sessionState.sqlParser.parsePlan(sqlText) + Dataset.ofRows(self, plan) + } + override def newSession(): TestHiveSparkSession = { new TestHiveSparkSession(sc, Some(sharedState), None, loadTestTables) } @@ -586,7 +596,7 @@ private[hive] class TestHiveQueryExecution( logDebug(s"Query references test tables: ${referencedTestTables.mkString(", ")}") referencedTestTables.foreach(sparkSession.loadTestTable) // Proceed with analysis. 
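To confirm what `addSparkVersionMetadata` above writes, the ORC user metadata can be read back with the ORC reader API. This is a sketch only: the file path is hypothetical and it assumes `orc-core` (`org.apache.orc.OrcFile`) is on the classpath; the Spark version shows up under the key referenced as SPARK_VERSION_METADATA_KEY above.

import java.nio.charset.StandardCharsets.UTF_8
import scala.collection.JavaConverters._

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.orc.OrcFile

// Dump all user metadata from one ORC file written by the patched writer.
val reader = OrcFile.createReader(
  new Path("/tmp/example-output/part-00000.orc"),        // hypothetical path
  OrcFile.readerOptions(new Configuration()))
for (key <- reader.getMetadataKeys.asScala) {
  val value = UTF_8.decode(reader.getMetadataValue(key)).toString
  println(s"$key -> $value")
}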
- sparkSession.sessionState.analyzer.executeAndCheck(logical) + sparkSession.sessionState.analyzer.executeAndCheck(logical, tracker) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala index 1de258f060943..0a522b6a11c80 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala @@ -113,10 +113,4 @@ class HiveExternalCatalogSuite extends ExternalCatalogSuite { catalog.createDatabase(newDb("dbWithNullDesc").copy(description = null), ignoreIfExists = false) assert(catalog.getDatabase("dbWithNullDesc").description == "") } - - test("SPARK-23831: Add org.apache.derby to IsolatedClientLoader") { - val client1 = HiveUtils.newClientForMetadata(new SparkConf, new Configuration) - val client2 = HiveUtils.newClientForMetadata(new SparkConf, new Configuration) - assert(!client1.equals(client2)) - } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index fd4985d131885..f1e842334416c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -206,7 +206,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { object PROCESS_TABLES extends QueryTest with SQLTestUtils { // Tests the latest version of every release line. - val testingVersions = Seq("2.1.3", "2.2.2", "2.3.2") + val testingVersions = Seq("2.2.2", "2.3.2", "2.4.0") protected var spark: SparkSession = _ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetMetastoreSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetMetastoreSuite.scala index 8aef3621d57c5..da3f0bc205c05 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetMetastoreSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetMetastoreSuite.scala @@ -152,7 +152,7 @@ class HiveParquetMetastoreSuite extends ParquetPartitioningTest { } (1 to 10).map(i => (i, s"str$i")).toDF("a", "b").createOrReplaceTempView("jt") - (1 to 10).map(i => Tuple1(Seq(new Integer(i), null))).toDF("a") + (1 to 10).map(i => Tuple1(Seq(Integer.valueOf(i), null))).toDF("a") .createOrReplaceTempView("jt_array") assert(spark.sqlContext.getConf(HiveUtils.CONVERT_METASTORE_PARQUET.key) == "true") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala index ec13288f759a6..eb3cde8472dac 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala @@ -32,9 +32,11 @@ import org.apache.spark.sql.types._ * Benchmark to measure ORC read performance. * {{{ * To run this benchmark: - * 1. without sbt: bin/spark-submit --class - * 2. build/sbt "sql/test:runMain " - * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain " + * 1. without sbt: bin/spark-submit --class + * --jars ,,,, + * + * 2. build/sbt "hive/test:runMain " + * 3. 
generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "hive/test:runMain " * Results will be written to "benchmarks/OrcReadBenchmark-results.txt". * }}} * @@ -266,8 +268,9 @@ object OrcReadBenchmark extends BenchmarkBase with SQLHelper { s"SELECT IF(RAND(1) < $fractionOfNulls, NULL, CAST(id as STRING)) AS c1, " + s"IF(RAND(2) < $fractionOfNulls, NULL, CAST(id as STRING)) AS c2 FROM t1")) + val percentageOfNulls = fractionOfNulls * 100 val benchmark = - new Benchmark(s"String with Nulls Scan ($fractionOfNulls%)", values, output = output) + new Benchmark(s"String with Nulls Scan ($percentageOfNulls%)", values, output = output) benchmark.addCase("Native ORC MR") { _ => withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { diff --git a/streaming/pom.xml b/streaming/pom.xml index f9a5029a8e818..1d1ea469f7d18 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -20,12 +20,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../pom.xml - spark-streaming_2.11 + spark-streaming_2.12 streaming diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala index 844760ab61d2e..f677c492d561f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala @@ -136,7 +136,7 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag]( // this dummy directory should not already exist otherwise the WAL will try to recover // past events from the directory and throw errors. val nonExistentDirectory = new File( - System.getProperty("java.io.tmpdir"), UUID.randomUUID().toString).getAbsolutePath + System.getProperty("java.io.tmpdir"), UUID.randomUUID().toString).toURI.toString writeAheadLog = WriteAheadLogUtils.createLogForReceiver( SparkEnv.get.conf, nonExistentDirectory, hadoopConf) dataRead = writeAheadLog.read(partition.walRecordHandle) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index 8d83dc8a8fc04..6f0b46b6a4cb3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -49,11 +49,11 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { val clockClass = ssc.sc.conf.get( "spark.streaming.clock", "org.apache.spark.util.SystemClock") try { - Utils.classForName(clockClass).newInstance().asInstanceOf[Clock] + Utils.classForName(clockClass).getConstructor().newInstance().asInstanceOf[Clock] } catch { case e: ClassNotFoundException if clockClass.startsWith("org.apache.spark.streaming") => val newClockClass = clockClass.replace("org.apache.spark.streaming", "org.apache.spark") - Utils.classForName(newClockClass).newInstance().asInstanceOf[Clock] + Utils.classForName(newClockClass).getConstructor().newInstance().asInstanceOf[Clock] } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala index 89524cd84ff32..618c036377aee 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala @@ -70,7 +70,7 @@ private[streaming] object StateMap { 
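The WriteAheadLogBackedBlockRDD change above passes the dummy WAL directory as a URI string rather than a bare local path; the practical difference is the explicit `file:` scheme, as this small sketch shows (the generated names are whatever the JVM produces):

import java.io.File
import java.util.UUID

val dir = new File(System.getProperty("java.io.tmpdir"), UUID.randomUUID().toString)

// Old value: a bare absolute path with no scheme; Hadoop resolves such paths against
// whatever the default FileSystem happens to be.
println(dir.getAbsolutePath)   // e.g. /tmp/1b4e28ba-2fa1-11d2-883f-0016d3cca427

// New value: an explicit file: URI string, which pins the non-existent WAL directory
// to the local filesystem regardless of the default.
println(dir.toURI.toString)    // e.g. file:/tmp/1b4e28ba-2fa1-11d2-883f-0016d3cca427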
/** Implementation of StateMap interface representing an empty map */ private[streaming] class EmptyStateMap[K, S] extends StateMap[K, S] { override def put(key: K, session: S, updateTime: Long): Unit = { - throw new NotImplementedError("put() should not be called on an EmptyStateMap") + throw new UnsupportedOperationException("put() should not be called on an EmptyStateMap") } override def get(key: K): Option[S] = None override def getByTime(threshUpdatedTime: Long): Iterator[(K, S, Long)] = Iterator.empty diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 19b621f11759d..2332ee2ab9de1 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -808,7 +808,8 @@ class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester // visible to mutableURLClassLoader val loader = new MutableURLClassLoader( Array(jar), appClassLoader) - assert(loader.loadClass("testClz").newInstance().toString == "testStringValue") + assert(loader.loadClass("testClz").getConstructor().newInstance().toString === + "testStringValue") // create and serialize Array[testClz] // scalastyle:off classforname diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index 1cf21e8a28033..7376741f64a12 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -31,6 +31,7 @@ import org.apache.commons.io.IOUtils import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.io.{LongWritable, Text} import org.apache.hadoop.mapreduce.lib.input.TextInputFormat +import org.scalatest.Assertions import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.Eventually._ @@ -532,7 +533,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { /** This is a server to test the network input stream */ -class TestServer(portToBind: Int = 0) extends Logging { +class TestServer(portToBind: Int = 0) extends Logging with Assertions { val queue = new ArrayBlockingQueue[String](100) @@ -592,7 +593,7 @@ class TestServer(portToBind: Int = 0) extends Logging { servingThread.start() if (!waitForStart(10000)) { stop() - throw new AssertionError("Timeout: TestServer cannot start in 10 seconds") + fail("Timeout: TestServer cannot start in 10 seconds") } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala index 484f3733e8423..e444132d3a626 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala @@ -35,7 +35,7 @@ class StateMapSuite extends SparkFunSuite { test("EmptyStateMap") { val map = new EmptyStateMap[Int, Int] - intercept[scala.NotImplementedError] { + intercept[UnsupportedOperationException] { map.put(1, 1, 1) } assert(map.get(1) === None) diff --git a/tools/pom.xml b/tools/pom.xml index 247f5a6df4b08..6286fad403c83 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -19,12 +19,12 @@ 4.0.0 org.apache.spark - spark-parent_2.11 + spark-parent_2.12 3.0.0-SNAPSHOT ../pom.xml - spark-tools_2.11 + spark-tools_2.12 tools
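On the EmptyStateMap and StateMapSuite changes above: `scala.NotImplementedError` extends `java.lang.Error`, so handlers that catch `Exception` never see it, whereas `UnsupportedOperationException` is a `RuntimeException` and is also straightforward to assert on with `intercept`. A small sketch; the `describe` helper is made up:

def describe(op: => Unit): String =
  try { op; "ok" } catch { case e: Exception => s"caught ${e.getClass.getSimpleName}" }

// RuntimeException: handled like any other exception, and easy to test via
// intercept[UnsupportedOperationException], as StateMapSuite now does.
println(describe(throw new UnsupportedOperationException("put() on EmptyStateMap")))
// prints: caught UnsupportedOperationException

// An Error would escape this handler entirely:
// describe(throw new NotImplementedError("put() on EmptyStateMap"))  // propagates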